from typing import Dict, List, TypedDict, Sequence
from langgraph.graph import StateGraph, END
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.tools.tavily_search import TavilySearchResults
import models
import prompts
import json
from operator import itemgetter
from langgraph.errors import GraphRecursionError


#######################################
###     Research Team Components    ###
#######################################
class ResearchState(TypedDict):
    """State schema for the research-team subgraph.

    Every research node receives this dict, mutates it in place and returns it.
    """
    # Names of the nodes that have run, appended by each node in order.
    workflow: List[str]
    # Topic being researched; the overall supervisor fills it from its "Extracted Topic:" line.
    topic: str
    # Collected findings, written under "qdrant_results" / "web_search_results".
    research_data: Dict[str, str]
    # Routing key read by the research graph's conditional edge
    # ("query_qdrant", "web_search" or "FINISH").
    next: str
    # Text from the research supervisor's "Message to project manager:" line.
    message_to_manager: str
    # Instructions from the overall supervisor's "Message to supervisor:" line.
    message_from_manager: str

#
#   Research Chains and Tools
#
# Qdrant RAG chain.
# Input: {"topic": str}. Output: {"response": str, "context": <compression_retriever output>}.
qdrant_research_chain = (
    {
        "context": itemgetter("topic") | models.compression_retriever,
        "topic": itemgetter("topic"),
    }
    # Identity step: re-assigns "context" to itself. It is still needed, because
    # it turns the leading dict literal into a Runnable; without it `dict | dict`
    # would be a plain Python dict merge instead of a pipeline.
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {
        "response": prompts.research_query_prompt | models.gpt4o_mini | StrOutputParser(),
        "context": itemgetter("context"),
    }
)

# Web search pipeline: query_chain produces the text handed to Tavily under the
# "query" key (at most 3 results), and the raw results are then summarized by
# a second prompt/model pass into a string.
tavily_tool = TavilySearchResults(max_results=3)
query_chain = prompts.search_query_prompt | models.gpt4o_mini | StrOutputParser()
tavily_simple = (
    {"tav_results": tavily_tool}
    | prompts.tavily_prompt
    | models.gpt4o_mini
    | StrOutputParser()
)
tavily_chain = {"query": query_chain} | tavily_simple

# Research supervisor: returns the raw reply text, parsed line by line by research_supervisor().
research_supervisor_chain = prompts.research_supervisor_prompt | models.gpt4o | StrOutputParser()

#
#   Research Node Defs
#
def query_qdrant(state: ResearchState) -> ResearchState:
    """Answer the research topic from the Qdrant-backed retriever.

    Runs ``qdrant_research_chain`` on ``state["topic"]`` and stores the
    chain's ``"response"`` string in ``state["research_data"]["qdrant_results"]``.
    The state is mutated in place and returned.
    """
    chain_output = qdrant_research_chain.invoke({"topic": state["topic"]})
    print(chain_output)
    state["research_data"]["qdrant_results"] = chain_output["response"]
    steps = state['workflow']
    steps.append("query_qdrant")
    print(steps)
    return state

def web_search(state: ResearchState) -> ResearchState:
    """Search the web for the topic, informed by any earlier Qdrant findings.

    Passes the topic and the Qdrant answer (or a placeholder if Qdrant has not
    run yet) to ``tavily_chain``. The resulting string is stored in
    ``state["research_data"]["web_search_results"]``. The state is mutated in
    place and returned.
    """
    payload = {"topic": state["topic"]}
    payload["qdrant_results"] = state["research_data"].get(
        "qdrant_results", "No previous results available."
    )
    search_summary = tavily_chain.invoke(payload)
    print(search_summary)
    state["research_data"]["web_search_results"] = search_summary
    steps = state['workflow']
    steps.append("web_search")
    print(steps)
    return state

def research_supervisor(state):
    """Ask the research supervisor LLM which research step to run next.

    Sends the manager's message, the collected research data and the topic to
    ``research_supervisor_chain``. Reply lines of the form
    ``Next Action: <value>`` and ``Message to project manager: <text>`` set
    ``state['next']`` and ``state['message_to_manager']``. If a label appears
    more than once, the last occurrence wins. A missing label leaves the
    previous value untouched. The state is mutated in place and returned.
    """
    # Reply-line label (text before the first ": ") -> state key it populates.
    label_to_key = {
        'Next Action': 'next',
        'Message to project manager': 'message_to_manager',
    }
    reply = research_supervisor_chain.invoke({
        "message_from_manager": state["message_from_manager"],
        "collected_data": state["research_data"],
        "topic": state['topic'],
    })
    print(reply)
    for entry in reply.split('\n'):
        label, separator, value = entry.partition(': ')
        if separator and label in label_to_key:
            state[label_to_key[label]] = value.strip()
    trail = state['workflow']
    trail.append("research_supervisor")
    print(trail)
    return state

def research_end(state):
    """Terminal node of the research subgraph: record the step and pass *state* through."""
    trail = state['workflow']
    trail.append("research_end")
    print(trail)
    return state

#######################################
###     Writing Team Components     ###
#######################################
class WritingState(TypedDict):
    """State schema for the writing-team subgraph.

    Every writing node receives this dict, mutates it in place and returns it.
    """
    # Names of the nodes that have run, appended by each node in order.
    workflow: List[str]
    topic: str
    # Research findings gathered by the research team.
    research_data: Dict[str, str]
    # All drafts so far, newest last. This is a List rather than a Sequence
    # because post_creation and post_editor call .append() on it.
    draft_posts: List[str]
    # Set to the latest draft once post_review judges it acceptable.
    final_post: str
    # Routing key for the writing graph's conditional edge ("NEW DRAFT" or "FINISH").
    next: str
    message_to_manager: str
    message_from_manager: str
    # Reviewer comments on the most recent draft.
    review_comments: str
    # NOTE(review): not read or written by any node visible in this file.
    style_checked: bool

#
#   Writing Chains
#
# Writing-team chains. Each pipes a prompt from `prompts` into a chat model from
# `models` and parses the reply to a plain string with StrOutputParser.
writing_supervisor_chain = prompts.writing_supervisor_prompt | models.gpt4o | StrOutputParser()
post_creation_chain = prompts.post_creation_prompt | models.gpt4o_mini | StrOutputParser()
post_editor_chain = prompts.post_editor_prompt | models.gpt4o | StrOutputParser()
post_review_chain = prompts.post_review_prompt | models.gpt4o | StrOutputParser()

#
#   Writing Node Defs
#
def post_creation(state):
    """Generate a new draft post and append it to ``state['draft_posts']``.

    Sends the topic, the collected research data, all previous drafts and the
    latest review comments to ``post_creation_chain``. The chain's string output
    is printed and stored as the newest draft. The state is mutated in place and
    returned.
    """
    subject = state['topic']
    prior_drafts = state['draft_posts']
    research = state["research_data"]
    feedback = state['review_comments']
    new_draft = post_creation_chain.invoke({
        "topic": subject,
        "collected_data": research,
        "drafts": prior_drafts,
        "review_comments": feedback,
    })
    print(new_draft)
    prior_drafts.append(new_draft)
    trail = state['workflow']
    trail.append("post_creation")
    print(trail)
    return state

def post_editor(state):
    """Edit the newest draft against the style guide and append the edited version.

    Sends the latest draft, ``prompts.style_guide_text`` and the latest review
    comments to ``post_editor_chain``. The result is appended to
    ``state['draft_posts']``, so the edited version becomes the newest draft.
    The state is mutated in place and returned.
    """
    latest = state['draft_posts'][-1]
    guide = prompts.style_guide_text
    feedback = state['review_comments']
    edited = post_editor_chain.invoke({
        "current_draft": latest,
        "styleguide": guide,
        "review_comments": feedback,
    })
    print(edited)
    state['draft_posts'].append(edited)
    trail = state['workflow']
    trail.append("post_editor")
    print(trail)
    return state

def _strip_code_fence(text):
    """Return *text* trimmed of whitespace and of one enclosing Markdown code fence, if present."""
    cleaned = text.strip()
    if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
        inner = cleaned[3:-3]
        # Drop an optional info string on the opening fence line (e.g. "json"),
        # but keep the first line if it already looks like the JSON payload.
        first_line, newline, rest = inner.partition("\n")
        if newline and not first_line.lstrip().startswith(("{", "[")):
            inner = rest
        cleaned = inner.strip()
    return cleaned


def post_review(state):
    """Review the newest draft and record the reviewer's verdict.

    Sends the latest draft and ``prompts.style_guide_text`` to
    ``post_review_chain`` and parses the reply as JSON with the keys
    ``"Comments on current draft"`` and ``"Draft Acceptable"``. The comments
    are stored in ``state['review_comments']``. If the verdict is exactly
    ``'Yes'``, the draft is copied to ``state['final_post']``. The state is
    mutated in place and returned.

    Raises:
        json.JSONDecodeError: if the reply is not valid JSON once an optional
            surrounding Markdown code fence has been removed.
        KeyError: if either expected key is missing from the JSON object.
    """
    print("post_review node")
    current_draft = state['draft_posts'][-1]
    styleguide = prompts.style_guide_text
    results = post_review_chain.invoke({"current_draft": current_draft, "styleguide": styleguide})
    print(results)
    # Chat models often wrap JSON answers in ```json fences, which json.loads
    # rejects; unwrap them before parsing. Bare JSON passes through unchanged.
    data = json.loads(_strip_code_fence(results))
    state['review_comments'] = data["Comments on current draft"]
    if data["Draft Acceptable"] == 'Yes':
        # No draft is appended in this node, so current_draft is still the newest one.
        state['final_post'] = current_draft
    state['workflow'].append("post_review")
    print(state['workflow'])
    return state

def writing_end(state):
    """Terminal node of the writing subgraph: log, record the step and pass *state* through."""
    print("writing_end node")
    trail = state['workflow']
    trail.append("writing_end")
    print(trail)
    return state

def writing_supervisor(state):
    """Ask the writing-team supervisor LLM what to do next.

    Sends the manager's message, the topic, the drafts so far, the current final
    post and the latest review comments to ``writing_supervisor_chain``. Reply
    lines of the form ``Next Action: <value>`` and
    ``Message to project manager: <text>`` set ``state['next']`` and
    ``state['message_to_manager']``. If a label appears more than once, the
    last occurrence wins. A missing label leaves the previous value untouched.
    The state is mutated in place and returned.
    """
    print("writing_supervisor node")
    # Reply-line label (text before the first ": ") -> state key it populates.
    label_to_key = {
        'Next Action': 'next',
        'Message to project manager': 'message_to_manager',
    }
    manager_note = state['message_from_manager']
    subject = state['topic']
    drafts_so_far = state['draft_posts']
    accepted_post = state['final_post']
    feedback = state['review_comments']
    reply = writing_supervisor_chain.invoke({
        "review_comments": feedback,
        "message_from_manager": manager_note,
        "topic": subject,
        "drafts": drafts_so_far,
        "final_draft": accepted_post,
    })
    print(reply)
    for entry in reply.split('\n'):
        label, separator, value = entry.partition(': ')
        if separator and label in label_to_key:
            state[label_to_key[label]] = value.strip()
    trail = state['workflow']
    trail.append("writing_supervisor")
    print(trail)
    return state

#######################################
###  Overarching Graph Components   ###
#######################################
class State(TypedDict):
    """State schema for the top-level graph.

    It is a superset of the keys used by the research and writing subgraphs,
    plus the keys that only the overall supervisor uses.
    """
    # Names of the nodes that have run, appended by each node in order.
    workflow: List[str]
    topic: str
    research_data: Dict[str, str]
    # List rather than Sequence: the writing nodes call .append() on it.
    draft_posts: List[str]
    final_post: str
    # Routing key used inside the team subgraphs.
    next: str
    # The original user request.
    user_input: str
    # Message from a team supervisor to the overall supervisor.
    message_to_manager: str
    # Message from the overall supervisor to the team supervisors.
    message_from_manager: str
    last_active_team: str
    # Routing key for the top-level conditional edge
    # ("research_team", "writing_team" or "FINISH").
    next_team: str
    review_comments: str

#
#   Complete Graph Chains
#
# Top-level supervisor: returns the raw reply text, parsed line by line by overall_supervisor().
overall_supervisor_chain = prompts.overall_supervisor_prompt | models.gpt4o | StrOutputParser()

#
#   Complete Graph Node defs
#
def overall_supervisor(state):
    """Ask the top-level supervisor LLM which team should act next.

    Sends the user's request, the last team's message, the name of the last
    active team and the current final post to ``overall_supervisor_chain``.
    These reply lines are parsed; if a label appears more than once, the last
    occurrence wins:

    * ``Next Action: <team>``           -> ``state['next_team']`` (routing key)
    * ``Extracted Topic: <topic>``      -> ``state['topic']``
    * ``Message to supervisor: <text>`` -> ``state['message_from_manager']``

    A missing label leaves the previous value untouched. The routed team is
    then recorded in ``state['last_active_team']``. The state is mutated in
    place and returned.
    """
    # Reply-line label (text before the first ": ") -> state key it populates.
    label_to_key = {
        'Next Action': 'next_team',
        'Extracted Topic': 'topic',
        'Message to supervisor': 'message_from_manager',
    }
    init_user_query = state["user_input"]
    message_to_manager = state['message_to_manager']
    last_active_team = state['last_active_team']
    final_post = state['final_post']
    supervisor_result = overall_supervisor_chain.invoke({
        "query": init_user_query,
        "message_to_manager": message_to_manager,
        "last_active_team": last_active_team,
        "final_post": final_post,
    })
    print(supervisor_result)
    for line in supervisor_result.split('\n'):
        label, separator, value = line.partition(': ')
        if separator and label in label_to_key:
            state[label_to_key[label]] = value.strip()
    # No other code in this file writes last_active_team, so without this the
    # prompt always received "". The team routed to now is the one that will
    # have just run when control next returns to this node.
    state['last_active_team'] = state['next_team']
    state['workflow'].append("overall_supervisor")
    print(state['workflow'])
    return state

#######################################
###         Graph structures        ###
#######################################

#
#   Research Graph Nodes
#
research_graph = StateGraph(ResearchState)
for _node_name, _node_fn in (
    ("query_qdrant", query_qdrant),
    ("web_search", web_search),
    ("research_supervisor", research_supervisor),
    ("research_end", research_end),
):
    research_graph.add_node(_node_name, _node_fn)
del _node_name, _node_fn
#
#   Research Graph Edges
#
# The supervisor is the entry point and both workers report back to it. It
# routes on state["next"]; "FINISH" leads to research_end.
research_graph.set_entry_point("research_supervisor")
for _worker in ("query_qdrant", "web_search"):
    research_graph.add_edge(_worker, "research_supervisor")
del _worker
research_graph.add_conditional_edges(
    "research_supervisor",
    lambda s: s["next"],
    {
        "query_qdrant": "query_qdrant",
        "web_search": "web_search",
        "FINISH": "research_end",
    },
)
research_graph_comp = research_graph.compile()

#
#   Writing Graph Nodes
#
writing_graph = StateGraph(WritingState)
for _node_name, _node_fn in (
    ("post_creation", post_creation),
    ("post_editor", post_editor),
    ("post_review", post_review),
    ("writing_supervisor", writing_supervisor),
    ("writing_end", writing_end),
):
    writing_graph.add_node(_node_name, _node_fn)
del _node_name, _node_fn
#
#   Writing Graph Edges
#
# Each drafting round runs create -> edit -> review and then returns to the
# supervisor, which routes on state["next"]. "NEW DRAFT" starts another round;
# "FINISH" leads to writing_end.
writing_graph.set_entry_point("writing_supervisor")
for _src, _dst in (
    ("post_creation", "post_editor"),
    ("post_editor", "post_review"),
    ("post_review", "writing_supervisor"),
):
    writing_graph.add_edge(_src, _dst)
del _src, _dst
writing_graph.add_conditional_edges(
    "writing_supervisor",
    lambda s: s["next"],
    {
        "NEW DRAFT": "post_creation",
        "FINISH": "writing_end",
    },
)
writing_graph_comp = writing_graph.compile()

#
#   Complete Graph Nodes
#
overall_graph = StateGraph(State)
# The top-level supervisor plus the two compiled team subgraphs, each mounted as a node.
for _node_name, _node_impl in (
    ("overall_supervisor", overall_supervisor),
    ("research_team_graph", research_graph_comp),
    ("writing_team_graph", writing_graph_comp),
):
    overall_graph.add_node(_node_name, _node_impl)
del _node_name, _node_impl
#
#   Complete Graph Edges
#
# Both teams hand control back to the supervisor, which routes on
# state["next_team"]; "FINISH" ends the run.
overall_graph.set_entry_point("overall_supervisor")
for _team_node in ("research_team_graph", "writing_team_graph"):
    overall_graph.add_edge(_team_node, "overall_supervisor")
del _team_node
overall_graph.add_conditional_edges(
    "overall_supervisor",
    lambda s: s["next_team"],
    {
        "research_team": "research_team_graph",
        "writing_team": "writing_team_graph",
        "FINISH": END,
    },
)
app = overall_graph.compile()


#######################################
###         Run method              ###
#######################################

def getSocialMediaPost(userInput: str, recursion_limit: int = 40) -> str:
    """Run the full multi-team graph for *userInput* and return the final post.

    Args:
        userInput: The raw user request, stored in the state as ``user_input``.
        recursion_limit: Maximum number of graph steps before LangGraph raises
            ``GraphRecursionError``. Defaults to the previously hard-coded 40.

    Returns:
        The ``final_post`` value of the final graph state. This is ``""`` if no
        draft was ever accepted. If the step limit is exceeded, returns the
        string ``"Recursion Error"`` instead.
    """
    # String-typed fields start as "" rather than []. Prompts then see an empty
    # value instead of "[]", and the function always returns a str.
    initial_state = State(
        workflow=[],
        topic="",
        research_data={},
        draft_posts=[],
        final_post="",
        next="",
        next_team="",
        user_input=userInput,
        message_to_manager="",
        message_from_manager="",
        last_active_team="",
        review_comments="",
    )
    # Invoke exactly once, inside the try. The previous unguarded extra invoke
    # doubled every LLM/search call, let GraphRecursionError escape, and could
    # start the second run from lists the nodes had already appended to in place.
    try:
        results = app.invoke(initial_state, {"recursion_limit": recursion_limit})
    except GraphRecursionError:
        return "Recursion Error"
    return results['final_post']