# Last commit: dc9e27a ("feat: update") by vansin
import os
from copy import deepcopy
from datetime import datetime
from lagent.actions import AsyncWebBrowser, WebBrowser
from lagent.agents.stream import get_plugin_prompt
from lagent.prompts import InterpreterParser, PluginParser
from lagent.utils import create_object
from . import models as llm_factory
from .mindsearch_agent import AsyncMindSearchAgent, MindSearchAgent
from .mindsearch_prompt import (
FINAL_RESPONSE_CN,
FINAL_RESPONSE_EN,
GRAPH_PROMPT_CN,
GRAPH_PROMPT_EN,
searcher_context_template_cn,
searcher_context_template_en,
searcher_input_template_cn,
searcher_input_template_en,
searcher_system_prompt_cn,
searcher_system_prompt_en,
)
# Module-level cache of constructed LLM instances, nested as
# LLM[model_format][mode] where mode is "sync" or "async".
# init_agent reads it before building an LLM and fills it afterwards,
# so each (model_format, mode) pair is instantiated at most once per process.
LLM = {}
def init_agent(lang="cn",
               model_format="internlm_server",
               search_engine="BingSearch",
               use_async=False):
    """Build a MindSearch agent wired to an LLM and a web-search plugin.

    The LLM instance is created once per ``(model_format, mode)`` pair and
    cached in the module-level ``LLM`` dict, so repeated calls reuse it.

    Args:
        lang (str): ``"cn"`` selects the Chinese prompt set; any other value
            selects the English prompts.
        model_format (str): Name of an LLM config attribute defined on the
            local ``models`` module (imported as ``llm_factory``).
        search_engine (str): ``searcher_type`` passed to the web-browser
            plugin. ``"TencentSearch"`` reads ``TENCENT_SEARCH_SECRET_ID`` and
            ``TENCENT_SEARCH_SECRET_KEY`` from the environment; every other
            engine reads ``WEB_SEARCH_API_KEY``.
        use_async (bool): Build the async agent, browser and LLM variants
            instead of the sync ones.

    Returns:
        AsyncMindSearchAgent | MindSearchAgent: the configured agent.

    Raises:
        NotImplementedError: if ``model_format`` is not defined on the
            ``models`` module (or is defined as ``None``).
    """
    mode = "async" if use_async else "sync"
    llm = LLM.get(model_format, {}).get(mode)
    if llm is None:
        # Default of None so an unknown model_format reaches the explicit
        # NotImplementedError below instead of escaping as AttributeError.
        llm_cfg = deepcopy(getattr(llm_factory, model_format, None))
        if llm_cfg is None:
            raise NotImplementedError(
                f"Unsupported model_format: {model_format!r}")
        if use_async:
            # Swap the configured LLM class for its lagent "Async" counterpart;
            # the config may hold either a dotted-path string or a class.
            llm_type = llm_cfg["type"]
            cls_name = (llm_type.split(".")[-1]
                        if isinstance(llm_type, str) else llm_type.__name__)
            llm_cfg["type"] = f"lagent.llms.Async{cls_name}"
        llm = create_object(llm_cfg)
        LLM.setdefault(model_format, {}).setdefault(mode, llm)

    date = datetime.now().strftime("The current date is %Y-%m-%d.")
    is_cn = lang == "cn"

    # Single web-browser plugin; only the credential keys depend on the engine.
    plugin_cfg = dict(
        type=AsyncWebBrowser if use_async else WebBrowser,
        searcher_type=search_engine,
        topk=6,
    )
    if search_engine == "TencentSearch":
        plugin_cfg.update(
            secret_id=os.getenv("TENCENT_SEARCH_SECRET_ID"),
            secret_key=os.getenv("TENCENT_SEARCH_SECRET_KEY"),
        )
    else:
        plugin_cfg["api_key"] = os.getenv("WEB_SEARCH_API_KEY")
    plugins = [plugin_cfg]

    agent_cls = AsyncMindSearchAgent if use_async else MindSearchAgent
    return agent_cls(
        llm=llm,
        template=date,
        output_format=InterpreterParser(
            template=GRAPH_PROMPT_CN if is_cn else GRAPH_PROMPT_EN),
        searcher_cfg=dict(
            llm=llm,
            plugins=plugins,
            template=date,
            output_format=PluginParser(
                template=(searcher_system_prompt_cn
                          if is_cn else searcher_system_prompt_en),
                tool_info=get_plugin_prompt(plugins),
            ),
            user_input_template=(searcher_input_template_cn
                                 if is_cn else searcher_input_template_en),
            user_context_template=(searcher_context_template_cn
                                   if is_cn else searcher_context_template_en),
        ),
        summary_prompt=FINAL_RESPONSE_CN if is_cn else FINAL_RESPONSE_EN,
        max_turn=10,
    )