maxiaolong03
		
	commited on
		
		
					Commit 
							
							·
						
						3a5faf4
	
1
								Parent(s):
							
							bd55f1d
								
add files
Browse files- .gitignore +1 -0
- app.py +881 -0
- assets/logo.png +0 -0
- bot_requests.py +388 -0
- data/coffee.txt +63 -0
- requirements.txt +14 -0
    	
        .gitignore
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            .idea
         | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,881 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            """This script provides a simple web interface that allows users to interact with"""
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import argparse
         | 
| 18 | 
            +
            import base64
         | 
| 19 | 
            +
            from collections import namedtuple
         | 
| 20 | 
            +
            from functools import partial
         | 
| 21 | 
            +
            import hashlib
         | 
| 22 | 
            +
            import json
         | 
| 23 | 
            +
            import logging
         | 
| 24 | 
            +
            import faiss
         | 
| 25 | 
            +
            import os
         | 
| 26 | 
            +
            from argparse import ArgumentParser
         | 
| 27 | 
            +
            import textwrap
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            import gradio as gr
         | 
| 30 | 
            +
            import numpy as np
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            from bot_requests import BotClient
         | 
| 33 | 
            +
            # from faiss_text_database import FaissTextDatabase
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            os.environ["NO_PROXY"] = "localhost,127.0.0.1"  # Disable proxy
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            logging.root.setLevel(logging.INFO)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            FILE_URL_DEFAULT = "data/coffee.txt"
         | 
| 40 | 
            +
            RELEVANT_PASSAGE_DEFAULT = textwrap.dedent("""\
         | 
| 41 | 
            +
                1675年时,英格兰就有3000多家咖啡馆;启蒙运动时期,咖啡馆成为民众深入讨论宗教和政治的聚集地,
         | 
| 42 | 
            +
                1670年代的英国国王查理二世就曾试图取缔咖啡馆。这一时期的英国人认为咖啡具有药用价值,
         | 
| 43 | 
            +
                甚至名医也会推荐将咖啡用于医疗。"""
         | 
| 44 | 
            +
            )
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            QUERY_REWRITE_PROMPT = textwrap.dedent("""\
         | 
| 47 | 
            +
                你是一个擅长问答系统和信息检索的大模型助手。
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                请根据用户提出的问题,判断是否需要调用文档检索系统来获取答案:
         | 
| 50 | 
            +
                - 若问题属于常识性、定义性或答案明确,不依赖外部资料,请标记为 "is_search": false;
         | 
| 51 | 
            +
                - 若问题涉及事实查证、具体数据、文档内容等,必须依赖资料检索,请标记为 "is_search": true,并将问题拆解成多个可用于检索的子问题。
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                要求:
         | 
| 54 | 
            +
                1. 子问题应语义清晰、独立,适合用于检索;
         | 
| 55 | 
            +
                2. 只在**确有必要**的情况下拆解,最多不超过 5 个,不要为了凑满数量而输出冗余子问题;
         | 
| 56 | 
            +
                3. 输出为严格的 JSON 格式,无多余注释。
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                【用户当前问题】:
         | 
| 59 | 
            +
                {query}
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                【输出格式】:
         | 
| 62 | 
            +
                请仅输出如下格式的内容(符合 JSON 规范,无多余注释):
         | 
| 63 | 
            +
                ```
         | 
| 64 | 
            +
                {{
         | 
| 65 | 
            +
                    "is_search": true 或 false,
         | 
| 66 | 
            +
                    "sub_query_list": ["子问题1","子问题2","..."]
         | 
| 67 | 
            +
                }}
         | 
| 68 | 
            +
                ```"""
         | 
| 69 | 
            +
            )
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            ANSWER_PROMPT = textwrap.dedent(
         | 
| 72 | 
            +
                """\
         | 
| 73 | 
            +
                你是一个乐于助人且信息丰富的机器人,使用下面提供的参考段落中的文本来回答问题。
         | 
| 74 | 
            +
                请务必用完整的句子回答,内容要全面,包括所有相关的背景信息。
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                然而,你的对话对象是非技术人员,所以请务必分解复杂的概念,并使用友好和对话式的语气。
         | 
| 77 | 
            +
                如果段落与答案无关,你可以忽略它。
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                问题:'{query}'
         | 
| 80 | 
            +
                段落:'{relevant_passage}'
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                答案:"""
         | 
| 83 | 
            +
            )
         | 
| 84 | 
            +
            QUERY_DEFAULT = "1675 年时,英格兰有多少家咖啡馆?"
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            def get_args() -> argparse.Namespace:
         | 
| 88 | 
            +
                """
         | 
| 89 | 
            +
                Parse and return command line arguments for the ERNIE models web chat demo.
         | 
| 90 | 
            +
                Configures server settings, model endpoints, and document processing parameters.
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                Returns:
         | 
| 93 | 
            +
                    argparse.Namespace: Parsed command line arguments containing:
         | 
| 94 | 
            +
                    - server_port: Demo server port (default: 8333)
         | 
| 95 | 
            +
                    - server_name: Demo server host (default: "0.0.0.0")
         | 
| 96 | 
            +
                    - model_urls: Endpoints for ERNIE and Qianfan models
         | 
| 97 | 
            +
                    - document_processing: Chunk size, FAISS index and text DB paths
         | 
| 98 | 
            +
                """
         | 
| 99 | 
            +
                parser = ArgumentParser(description="ERNIE models web chat demo.")
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                parser.add_argument(
         | 
| 102 | 
            +
                    "--server-port", type=int, default=7860, help="Demo server port."
         | 
| 103 | 
            +
                )
         | 
| 104 | 
            +
                parser.add_argument(
         | 
| 105 | 
            +
                    "--server-name", type=str, default="0.0.0.0", help="Demo server name."
         | 
| 106 | 
            +
                )
         | 
| 107 | 
            +
                parser.add_argument(
         | 
| 108 | 
            +
                    "--max_char", type=int, default=8000, help="Maximum character limit for messages."
         | 
| 109 | 
            +
                )
         | 
| 110 | 
            +
                parser.add_argument(
         | 
| 111 | 
            +
                    "--max_retry_num", type=int, default=3, help="Maximum retry number for request."
         | 
| 112 | 
            +
                )
         | 
| 113 | 
            +
                parser.add_argument(
         | 
| 114 | 
            +
                    "--eb45t_model_url", 
         | 
| 115 | 
            +
                    type=str, 
         | 
| 116 | 
            +
                    default="https://qianfan.baidubce.com/v2",
         | 
| 117 | 
            +
                    help="Model URL for multimodal model."
         | 
| 118 | 
            +
                )
         | 
| 119 | 
            +
                parser.add_argument(
         | 
| 120 | 
            +
                    "--qianfan_url", 
         | 
| 121 | 
            +
                    type=str, 
         | 
| 122 | 
            +
                    default="https://qianfan.baidubce.com/v2", 
         | 
| 123 | 
            +
                    help="Qianfan URL."
         | 
| 124 | 
            +
                )
         | 
| 125 | 
            +
                parser.add_argument(
         | 
| 126 | 
            +
                    "--qianfan_api_key", 
         | 
| 127 | 
            +
                    type=str, 
         | 
| 128 | 
            +
                    default=os.environ.get("API_KEY"),
         | 
| 129 | 
            +
                    help="Qianfan API key."
         | 
| 130 | 
            +
                )
         | 
| 131 | 
            +
                parser.add_argument(
         | 
| 132 | 
            +
                    "--embedding_model", 
         | 
| 133 | 
            +
                    type=str, 
         | 
| 134 | 
            +
                    default="embedding-v1", 
         | 
| 135 | 
            +
                    help="Embedding model name."
         | 
| 136 | 
            +
                )
         | 
| 137 | 
            +
                parser.add_argument(
         | 
| 138 | 
            +
                    "--chunk_size", 
         | 
| 139 | 
            +
                    type=int, 
         | 
| 140 | 
            +
                    default=512, 
         | 
| 141 | 
            +
                    help="Chunk size for splitting long documents."
         | 
| 142 | 
            +
                )
         | 
| 143 | 
            +
                parser.add_argument(
         | 
| 144 | 
            +
                    "--faiss_index_path", 
         | 
| 145 | 
            +
                    type=str, 
         | 
| 146 | 
            +
                    default="data/faiss_index", 
         | 
| 147 | 
            +
                    help="Faiss index path."
         | 
| 148 | 
            +
                )
         | 
| 149 | 
            +
                parser.add_argument(
         | 
| 150 | 
            +
                    "--text_db_path", 
         | 
| 151 | 
            +
                    type=str, 
         | 
| 152 | 
            +
                    default="data/text_db.jsonl", 
         | 
| 153 | 
            +
                    help="Text database path."
         | 
| 154 | 
            +
                )
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                args = parser.parse_args()
         | 
| 157 | 
            +
                return args
         | 
| 158 | 
            +
             | 
| 159 | 
            +
             | 
| 160 | 
            +
            class FaissTextDatabase:
         | 
| 161 | 
            +
                """
         | 
| 162 | 
            +
                A vector database for text retrieval using FAISS (Facebook AI Similarity Search).
         | 
| 163 | 
            +
                Provides efficient similarity search and document management capabilities.
         | 
| 164 | 
            +
                """
         | 
| 165 | 
            +
                def __init__(self, args, bot_client: BotClient, embedding_dim: int=384):
         | 
| 166 | 
            +
                    """
         | 
| 167 | 
            +
                    Initialize the FaissTextDatabase.
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                    Args:
         | 
| 170 | 
            +
                        args: arguments for initialization
         | 
| 171 | 
            +
                        bot_client: instance of BotClient
         | 
| 172 | 
            +
                        embedding_dim: dimension of the embedding vector
         | 
| 173 | 
            +
                    """
         | 
| 174 | 
            +
                    self.logger = logging.getLogger(__name__)
         | 
| 175 | 
            +
                    
         | 
| 176 | 
            +
                    self.bot_client = bot_client
         | 
| 177 | 
            +
                    self.faiss_index_path = getattr(args, "faiss_index_path", "data/faiss_index")
         | 
| 178 | 
            +
                    self.text_db_path = getattr(args, "text_db_path", "data/text_db.jsonl")
         | 
| 179 | 
            +
                    self.embedding_dim = embedding_dim
         | 
| 180 | 
            +
                    
         | 
| 181 | 
            +
                    # If faiss_index_path exists, load it and text_db_path
         | 
| 182 | 
            +
                    if os.path.exists(self.faiss_index_path) and os.path.exists(self.text_db_path):
         | 
| 183 | 
            +
                        self.index = faiss.read_index(self.faiss_index_path)
         | 
| 184 | 
            +
                        with open(self.text_db_path, 'r', encoding='utf-8') as f:
         | 
| 185 | 
            +
                            self.text_db = json.load(f)
         | 
| 186 | 
            +
                    else:
         | 
| 187 | 
            +
                        self.index = faiss.IndexFlatIP(self.embedding_dim)
         | 
| 188 | 
            +
                        self.text_db = {
         | 
| 189 | 
            +
                            "file_md5s": [],  # Save file_md5s to avoid duplicates
         | 
| 190 | 
            +
                            "chunks": []      # Save chunks
         | 
| 191 | 
            +
                        }
         | 
| 192 | 
            +
                
         | 
| 193 | 
            +
                def calculate_md5(self, file_path: str) -> str:
         | 
| 194 | 
            +
                    """
         | 
| 195 | 
            +
                    Calculate the MD5 hash of a file
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    Args:
         | 
| 198 | 
            +
                        file_path: the path of the source file
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    Returns:
         | 
| 201 | 
            +
                        str: the MD5 hash
         | 
| 202 | 
            +
                    """
         | 
| 203 | 
            +
                    with open(file_path, "rb") as f:
         | 
| 204 | 
            +
                        return hashlib.md5(f.read()).hexdigest()
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                def is_file_processed(self, file_path: str) -> bool:
         | 
| 207 | 
            +
                    """
         | 
| 208 | 
            +
                    Check if the file has been processed before
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    Args:
         | 
| 211 | 
            +
                        file_path: the path of the source file
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    Returns:
         | 
| 214 | 
            +
                        bool: whether the file has been processed
         | 
| 215 | 
            +
                    """
         | 
| 216 | 
            +
                    file_md5 = self.calculate_md5(file_path)
         | 
| 217 | 
            +
                    return file_md5 in self.text_db["file_md5s"]
         | 
| 218 | 
            +
                
         | 
| 219 | 
            +
                def add_embeddings(self, file_path: str, segments: list[str], progress_bar: gr.Progress=None) -> bool:
         | 
| 220 | 
            +
                    """
         | 
| 221 | 
            +
                    Stores document embeddings in FAISS database after checking for duplicates.
         | 
| 222 | 
            +
                    Generates embeddings for each text segment, updates the FAISS index and metadata database,
         | 
| 223 | 
            +
                    and persists changes to disk. Includes optional progress tracking for Gradio interfaces.
         | 
| 224 | 
            +
                    
         | 
| 225 | 
            +
                    Args:
         | 
| 226 | 
            +
                        file_path: the path of the source file
         | 
| 227 | 
            +
                        segments: the list of segments
         | 
| 228 | 
            +
                        progress_bar: the progress bar object
         | 
| 229 | 
            +
                    
         | 
| 230 | 
            +
                    Returns:
         | 
| 231 | 
            +
                        bool: whether the operation was successful
         | 
| 232 | 
            +
                    """
         | 
| 233 | 
            +
                    file_md5 = self.calculate_md5(file_path)
         | 
| 234 | 
            +
                    if file_md5 in self.text_db["file_md5s"]:
         | 
| 235 | 
            +
                        self.logger.info("File already processed: {file_path} (MD5: {file_md5})".format(
         | 
| 236 | 
            +
                            file_path=file_path, 
         | 
| 237 | 
            +
                            file_md5=file_md5
         | 
| 238 | 
            +
                        ))
         | 
| 239 | 
            +
                        return False
         | 
| 240 | 
            +
                    
         | 
| 241 | 
            +
                    # Generate embeddings
         | 
| 242 | 
            +
                    vectors = []
         | 
| 243 | 
            +
                    file_name = os.path.basename(file_path)
         | 
| 244 | 
            +
                    for i, segment in  enumerate(segments):
         | 
| 245 | 
            +
                        vectors.append(self.bot_client.embed_fn(segment))
         | 
| 246 | 
            +
                        if progress_bar is not None:
         | 
| 247 | 
            +
                            progress_bar((i + 1) / len(segments), desc=file_name + " Processing...")
         | 
| 248 | 
            +
                    vectors = np.array(vectors)
         | 
| 249 | 
            +
                    self.index.add(vectors.astype('float32'))
         | 
| 250 | 
            +
                    
         | 
| 251 | 
            +
                    start_id = len(self.text_db["chunks"])
         | 
| 252 | 
            +
                    for i, text in enumerate(segments):
         | 
| 253 | 
            +
                        self.text_db["chunks"].append({
         | 
| 254 | 
            +
                            "file_md5": file_md5,
         | 
| 255 | 
            +
                            "text": text,
         | 
| 256 | 
            +
                            "vector_id": start_id + i
         | 
| 257 | 
            +
                        })
         | 
| 258 | 
            +
                    
         | 
| 259 | 
            +
                    self.text_db["file_md5s"].append(file_md5)
         | 
| 260 | 
            +
                    self.save()
         | 
| 261 | 
            +
                    return True
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                def search_with_context(self, query: str, context_size: int=2) -> str:
         | 
| 264 | 
            +
                    """
         | 
| 265 | 
            +
                    Finds the most relevant text chunk for a query and includes surrounding context.
         | 
| 266 | 
            +
                    Uses FAISS to find the closest matching embedding, then retrieves adjacent chunks
         | 
| 267 | 
            +
                    from the same source document to provide better context understanding.
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                    Args:
         | 
| 270 | 
            +
                        query: the input query string
         | 
| 271 | 
            +
                        context_size: the number of surrounding chunks to include
         | 
| 272 | 
            +
                    
         | 
| 273 | 
            +
                    Returns:
         | 
| 274 | 
            +
                        str: the relevant chunk with context
         | 
| 275 | 
            +
                    """
         | 
| 276 | 
            +
                    query_vector = np.array([self.bot_client.embed_fn(query)]).astype('float32')
         | 
| 277 | 
            +
                    distances, indices = self.index.search(query_vector, 1)
         | 
| 278 | 
            +
                    
         | 
| 279 | 
            +
                    target_idx = indices[0][0]
         | 
| 280 | 
            +
                    target_chunk = self.text_db["chunks"][target_idx]
         | 
| 281 | 
            +
                    target_file_md5 = target_chunk["file_md5"]
         | 
| 282 | 
            +
                    self.logger.info("Similarity: {}".format(distances[0][0]))
         | 
| 283 | 
            +
                    self.logger.info("Target Chunk: {}".format(self.text_db["chunks"][target_idx]["text"]))
         | 
| 284 | 
            +
                    
         | 
| 285 | 
            +
                    # Get the context 
         | 
| 286 | 
            +
                    start = max(0, target_idx - context_size)
         | 
| 287 | 
            +
                    end = min(len(self.text_db["chunks"]) - 1, target_idx + context_size)
         | 
| 288 | 
            +
                    result = ""
         | 
| 289 | 
            +
                    for pos in range(start, end + 1):
         | 
| 290 | 
            +
                        if self.text_db["chunks"][pos]["file_md5"] == target_file_md5:
         | 
| 291 | 
            +
                            result += self.text_db["chunks"][pos]["text"] + "\n"
         | 
| 292 | 
            +
                    
         | 
| 293 | 
            +
                    return result
         | 
| 294 | 
            +
                
         | 
| 295 | 
            +
                def save(self) -> None:
         | 
| 296 | 
            +
                    """Save the database to disk."""
         | 
| 297 | 
            +
                    faiss.write_index(self.index, self.faiss_index_path)
         | 
| 298 | 
            +
                    
         | 
| 299 | 
            +
                    with open(self.text_db_path, 'w', encoding='utf-8') as f:
         | 
| 300 | 
            +
                        json.dump(self.text_db, f, ensure_ascii=False, indent=2)
         | 
| 301 | 
            +
             | 
| 302 | 
            +
             | 
| 303 | 
            +
            class GradioEvents(object):
         | 
| 304 | 
            +
                """
         | 
| 305 | 
            +
                Manages event handling and UI interactions for Gradio applications.
         | 
| 306 | 
            +
                Provides methods to process user inputs, trigger callbacks, and update interface components.
         | 
| 307 | 
            +
                """
         | 
| 308 | 
            +
                @staticmethod
         | 
| 309 | 
            +
                def chat_stream(
         | 
| 310 | 
            +
                    query: str, 
         | 
| 311 | 
            +
                    task_history: list, 
         | 
| 312 | 
            +
                    model: str, 
         | 
| 313 | 
            +
                    bot_client: BotClient, 
         | 
| 314 | 
            +
                    faiss_db: FaissTextDatabase,
         | 
| 315 | 
            +
                ) -> dict:
         | 
| 316 | 
            +
                    """
         | 
| 317 | 
            +
                    Streams chatbot responses by processing queries with context from history and FAISS database.
         | 
| 318 | 
            +
                    Integrates language model generation with knowledge retrieval to produce dynamic responses.
         | 
| 319 | 
            +
                    Yields response events in real-time for interactive conversation experiences.
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    Args:
         | 
| 322 | 
            +
                        query (str): The query string.
         | 
| 323 | 
            +
                        task_history (list): The task history record list.
         | 
| 324 | 
            +
                        model (Model): The model used to generate responses.
         | 
| 325 | 
            +
                        bot_client (BotClient): The chatbot client object.
         | 
| 326 | 
            +
                        faiss_db (FaissTextDatabase): The FAISS database object.
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                    Yields:
         | 
| 329 | 
            +
                        dict: A dictionary containing the event type and its corresponding content.
         | 
| 330 | 
            +
                    """
         | 
| 331 | 
            +
                    search_info_result = GradioEvents.get_sub_query(query, model, bot_client)
         | 
| 332 | 
            +
                    if search_info_result.get("is_search", False) and search_info_result.get("sub_query_list", []):
         | 
| 333 | 
            +
                        relevant_passage = GradioEvents.get_relevant_passage(
         | 
| 334 | 
            +
                            search_info_result["sub_query_list"], 
         | 
| 335 | 
            +
                            faiss_db
         | 
| 336 | 
            +
                        )
         | 
| 337 | 
            +
                        yield {"type": "relevant_passage", "content": relevant_passage}
         | 
| 338 | 
            +
                        input = ANSWER_PROMPT.format(query=query, relevant_passage=relevant_passage)
         | 
| 339 | 
            +
                    else:
         | 
| 340 | 
            +
                        input = query
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                    conversation = []
         | 
| 343 | 
            +
                    for query_h, response_h in task_history:
         | 
| 344 | 
            +
                        conversation.append({"role": "user", "content": query_h})
         | 
| 345 | 
            +
                        conversation.append({"role": "assistant", "content": response_h})
         | 
| 346 | 
            +
                    conversation.append({"role": "user", "content": input})
         | 
| 347 | 
            +
                    
         | 
| 348 | 
            +
                    try:
         | 
| 349 | 
            +
                        req_data = {"messages": conversation}
         | 
| 350 | 
            +
                        for chunk in bot_client.process_stream(model, req_data):
         | 
| 351 | 
            +
                            if "error" in chunk:
         | 
| 352 | 
            +
                                raise Exception(chunk["error"])
         | 
| 353 | 
            +
                            
         | 
| 354 | 
            +
                            message = chunk.get("choices", [{}])[0].get("delta", {})
         | 
| 355 | 
            +
                            content = message.get("content", "")
         | 
| 356 | 
            +
                            reasoning_content = message.get("reasoning_content", "")
         | 
| 357 | 
            +
                            
         | 
| 358 | 
            +
                            if reasoning_content:
         | 
| 359 | 
            +
                                yield {"type": "thinking", "content": reasoning_content}
         | 
| 360 | 
            +
                            if content:
         | 
| 361 | 
            +
                                yield {"type": "answer", "content": content}
         | 
| 362 | 
            +
                                
         | 
| 363 | 
            +
                    except Exception as e:
         | 
| 364 | 
            +
                        raise gr.Error("Exception: " + repr(e))
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                @staticmethod
         | 
| 367 | 
            +
                def predict_stream(
         | 
| 368 | 
            +
                    query: str, 
         | 
| 369 | 
            +
                    chatbot: list, 
         | 
| 370 | 
            +
                    task_history: list,
         | 
| 371 | 
            +
                    model: str, 
         | 
| 372 | 
            +
                    bot_client: BotClient,
         | 
| 373 | 
            +
                    faiss_db: FaissTextDatabase
         | 
| 374 | 
            +
                ) -> tuple:
         | 
| 375 | 
            +
                    """
         | 
| 376 | 
            +
                    Generates streaming responses by combining model predictions with knowledge retrieval.
         | 
| 377 | 
            +
                    Processes user queries using conversation history and FAISS database context,
         | 
| 378 | 
            +
                    yielding updated chat messages and relevant passages in real-time.
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                    Args:
         | 
| 381 | 
            +
                        query (str): The content of the user's input query.
         | 
| 382 | 
            +
                        chatbot (list): The chatbot's historical message list.
         | 
| 383 | 
            +
                        task_history (list): The task history record list. 
         | 
| 384 | 
            +
                        model (Model): The model used to generate responses.
         | 
| 385 | 
            +
                        bot_client (object): The chatbot client object.
         | 
| 386 | 
            +
                        faiss_db (FaissTextDatabase): The FAISS database instance.
         | 
| 387 | 
            +
             | 
| 388 | 
            +
                    Yields:
         | 
| 389 | 
            +
                        tuple: A tuple containing the updated chatbot's message list and the relevant passage.
         | 
| 390 | 
            +
                    """
         | 
| 391 | 
            +
                    query = query if query else QUERY_DEFAULT
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                    logging.info("User: {}".format(query))
         | 
| 394 | 
            +
                    chatbot.append({"role": "user", "content": query})  
         | 
| 395 | 
            +
                    
         | 
| 396 | 
            +
                    # First yield the chatbot with user message
         | 
| 397 | 
            +
                    yield chatbot, None
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                    new_texts = GradioEvents.chat_stream(
         | 
| 400 | 
            +
                        query, 
         | 
| 401 | 
            +
                        task_history, 
         | 
| 402 | 
            +
                        model, 
         | 
| 403 | 
            +
                        bot_client,
         | 
| 404 | 
            +
                        faiss_db,
         | 
| 405 | 
            +
                    )
         | 
| 406 | 
            +
                    reasoning_content = ""
         | 
| 407 | 
            +
                    response = ""
         | 
| 408 | 
            +
                    has_thinking = False
         | 
| 409 | 
            +
                    current_relevant_passage = None
         | 
| 410 | 
            +
                    for new_text in new_texts:
         | 
| 411 | 
            +
                        if not isinstance(new_text, dict):
         | 
| 412 | 
            +
                            continue
         | 
| 413 | 
            +
                            
         | 
| 414 | 
            +
                        if new_text.get("type") == "embedding":
         | 
| 415 | 
            +
                            current_relevant_passage = new_text["content"]
         | 
| 416 | 
            +
                            yield chatbot, current_relevant_passage
         | 
| 417 | 
            +
                            continue
         | 
| 418 | 
            +
                        elif new_text.get("type") == "relevant_passage":
         | 
| 419 | 
            +
                            current_relevant_passage = new_text["content"]
         | 
| 420 | 
            +
                            yield chatbot, current_relevant_passage
         | 
| 421 | 
            +
                            continue
         | 
| 422 | 
            +
                        elif new_text.get("type") == "thinking":
         | 
| 423 | 
            +
                            has_thinking = True
         | 
| 424 | 
            +
                            reasoning_content += new_text["content"]
         | 
| 425 | 
            +
                        elif new_text.get("type") == "answer":
         | 
| 426 | 
            +
                            response += new_text["content"]
         | 
| 427 | 
            +
                        
         | 
| 428 | 
            +
                        # Remove previous thinking message if exists
         | 
| 429 | 
            +
                        if chatbot[-1].get("role") == "assistant":
         | 
| 430 | 
            +
                            chatbot.pop(-1)
         | 
| 431 | 
            +
                        
         | 
| 432 | 
            +
                        content = ""
         | 
| 433 | 
            +
                        if has_thinking:
         | 
| 434 | 
            +
                            content = "**思考过程:**<br>{}<br>".format(reasoning_content)
         | 
| 435 | 
            +
                        if response:
         | 
| 436 | 
            +
                            if has_thinking:
         | 
| 437 | 
            +
                                content += "<br><br>**最终回答:**<br>{}".format(response)
         | 
| 438 | 
            +
                            else:
         | 
| 439 | 
            +
                                content = response
         | 
| 440 | 
            +
                        
         | 
| 441 | 
            +
                        if content:
         | 
| 442 | 
            +
                            chatbot.append({"role": "assistant", "content": content})
         | 
| 443 | 
            +
                            yield chatbot, current_relevant_passage
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                    logging.info("History: {}".format(task_history))
         | 
| 446 | 
            +
                    task_history.append((query, response))    
         | 
| 447 | 
            +
                    logging.info("ERNIE models: {}".format(response))
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                @staticmethod
         | 
| 450 | 
            +
                def regenerate(
         | 
| 451 | 
            +
                    chatbot: list, 
         | 
| 452 | 
            +
                    task_history: list, 
         | 
| 453 | 
            +
                    model: str, 
         | 
| 454 | 
            +
                    bot_client: BotClient,
         | 
| 455 | 
            +
                    faiss_db: FaissTextDatabase
         | 
| 456 | 
            +
                ) -> tuple:
         | 
| 457 | 
            +
                    """
         | 
| 458 | 
            +
                    Regenerate the chatbot's response based on the latest user query
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                    Args:
         | 
| 461 | 
            +
                        chatbot (list): Chat history list
         | 
| 462 | 
            +
                        task_history (list): Task history
         | 
| 463 | 
            +
                        model (str): Model name to use
         | 
| 464 | 
            +
                        bot_client (BotClient): Bot request client instance
         | 
| 465 | 
            +
                        faiss_db (FaissTextDatabase): Faiss database instance
         | 
| 466 | 
            +
             | 
| 467 | 
            +
                    Yields:
         | 
| 468 | 
            +
                        tuple: Updated chatbot and relevant_passage
         | 
| 469 | 
            +
                    """
         | 
| 470 | 
            +
                    if not task_history:
         | 
| 471 | 
            +
                        yield chatbot, None
         | 
| 472 | 
            +
                        return
         | 
| 473 | 
            +
                    # Pop the last user query and bot response from task_history
         | 
| 474 | 
            +
                    item = task_history.pop(-1)
         | 
| 475 | 
            +
                    while len(chatbot) != 0 and chatbot[-1].get("role") == "assistant":
         | 
| 476 | 
            +
                        chatbot.pop(-1)
         | 
| 477 | 
            +
                    chatbot.pop(-1)
         | 
| 478 | 
            +
             | 
| 479 | 
            +
                    for chunk, relevant_passage in GradioEvents.predict_stream(
         | 
| 480 | 
            +
                        item[0], 
         | 
| 481 | 
            +
                        chatbot, 
         | 
| 482 | 
            +
                        task_history, 
         | 
| 483 | 
            +
                        model, 
         | 
| 484 | 
            +
                        bot_client,
         | 
| 485 | 
            +
                        faiss_db
         | 
| 486 | 
            +
                    ):
         | 
| 487 | 
            +
                        yield chunk, relevant_passage
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                @staticmethod
         | 
| 490 | 
            +
                def reset_user_input() -> gr.update:
         | 
| 491 | 
            +
                    """
         | 
| 492 | 
            +
                    Reset user input box content.
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                    Returns:
         | 
| 495 | 
            +
                        gr.update: An update object representing the cleared value
         | 
| 496 | 
            +
                    """
         | 
| 497 | 
            +
                    return gr.update(value="")
         | 
| 498 | 
            +
             | 
| 499 | 
            +
                @staticmethod
         | 
| 500 | 
            +
                def reset_state() -> namedtuple:
         | 
| 501 | 
            +
                    """
         | 
| 502 | 
            +
                    Reset chat state and clear all history.
         | 
| 503 | 
            +
             | 
| 504 | 
            +
                    Returns:
         | 
| 505 | 
            +
                        tuple: A named tuple containing the updated values for chatbot, task_history, file_btn, and relevant_passage
         | 
| 506 | 
            +
                    """
         | 
| 507 | 
            +
                    GradioEvents.gc()
         | 
| 508 | 
            +
                    
         | 
| 509 | 
            +
                    reset_result = namedtuple("reset_result", 
         | 
| 510 | 
            +
                                       ["chatbot", 
         | 
| 511 | 
            +
                                        "task_history", 
         | 
| 512 | 
            +
                                        "file_btn", 
         | 
| 513 | 
            +
                                        "relevant_passage"])
         | 
| 514 | 
            +
                    return reset_result(
         | 
| 515 | 
            +
                        [],  # clear chatbot
         | 
| 516 | 
            +
                        [],  # clear task_history
         | 
| 517 | 
            +
                        gr.update(value=None),  # clear file_btn
         | 
| 518 | 
            +
                        gr.update(value=None)  # reset relevant_passage
         | 
| 519 | 
            +
                    )
         | 
| 520 | 
            +
                
         | 
| 521 | 
            +
                @staticmethod
         | 
| 522 | 
            +
                def gc():
         | 
| 523 | 
            +
                    """
         | 
| 524 | 
            +
                    Force garbage collection to free memory.
         | 
| 525 | 
            +
                    """
         | 
| 526 | 
            +
                    import gc
         | 
| 527 | 
            +
             | 
| 528 | 
            +
                    gc.collect()
         | 
| 529 | 
            +
             | 
| 530 | 
            +
                @staticmethod
         | 
| 531 | 
            +
                def get_image_url(image_path: str) -> str:
         | 
| 532 | 
            +
                    """
         | 
| 533 | 
            +
                    Encode image file to Base64 format and generate data URL.
         | 
| 534 | 
            +
                    Reads an image file from disk, encodes it as Base64, and formats it
         | 
| 535 | 
            +
                    as a data URL that can be used directly in HTML or API requests.
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                    Args:
         | 
| 538 | 
            +
                        image_path (str): Path to the image file. Must be a valid file path.
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                    Returns:
         | 
| 541 | 
            +
                        str: Data URL string in format "data:image/{ext};base64,{encoded_data}"
         | 
| 542 | 
            +
                    """
         | 
| 543 | 
            +
                    base64_image = ""
         | 
| 544 | 
            +
                    extension = image_path.split(".")[-1]
         | 
| 545 | 
            +
                    with open(image_path, "rb") as image_file:
         | 
| 546 | 
            +
                        base64_image = base64.b64encode(image_file.read()).decode("utf-8")
         | 
| 547 | 
            +
                    url = "data:image/{ext};base64,{img}".format(ext=extension, img=base64_image)
         | 
| 548 | 
            +
                    return url
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                @staticmethod
         | 
| 551 | 
            +
                def get_relevant_passage(
         | 
| 552 | 
            +
                    sub_query_list: list,
         | 
| 553 | 
            +
                    faiss_db: FaissTextDatabase
         | 
| 554 | 
            +
                ) -> str:
         | 
| 555 | 
            +
                    """
         | 
| 556 | 
            +
                    Retrieve the relevant passage from the database based on the query.
         | 
| 557 | 
            +
             | 
| 558 | 
            +
                    Args:
         | 
| 559 | 
            +
                        sub_query_list (list): List of sub-queries.
         | 
| 560 | 
            +
                        faiss_db (FaissTextDatabase): The FAISS database instance.
         | 
| 561 | 
            +
             | 
| 562 | 
            +
                    Returns:
         | 
| 563 | 
            +
                        str: The relevant passage.
         | 
| 564 | 
            +
                    """
         | 
| 565 | 
            +
                    relevant_passages = ""
         | 
| 566 | 
            +
                    for idx, query_item in enumerate(sub_query_list):
         | 
| 567 | 
            +
                        relevant_passage = faiss_db.search_with_context(query_item)
         | 
| 568 | 
            +
                        relevant_passages += "\n段落{idx}:\n{relevant_passage}".format(idx=idx + 1, relevant_passage=relevant_passage)
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                    return relevant_passages
         | 
| 571 | 
            +
             | 
| 572 | 
            +
                @staticmethod
         | 
| 573 | 
            +
                def get_sub_query(query: str, model_name: str, bot_client: BotClient) -> dict:
         | 
| 574 | 
            +
                    """
         | 
| 575 | 
            +
                    Enhances user queries by generating alternative phrasings using language models.
         | 
| 576 | 
            +
                    Creates semantically similar variations of the original query to improve retrieval accuracy.
         | 
| 577 | 
            +
                    Returns structured dictionary containing both original and rephrased queries.
         | 
| 578 | 
            +
             | 
| 579 | 
            +
                    Args:
         | 
| 580 | 
            +
                        query (str): The query to rephrase.
         | 
| 581 | 
            +
                        model_name (str): The name of the model to use for rephrasing.
         | 
| 582 | 
            +
                        bot_client (BotClient): The bot client instance.
         | 
| 583 | 
            +
             | 
| 584 | 
            +
                    Returns:
         | 
| 585 | 
            +
                        dict: The rephrased query.
         | 
| 586 | 
            +
                    """
         | 
| 587 | 
            +
                    query = QUERY_REWRITE_PROMPT.format(query=query)
         | 
| 588 | 
            +
                    conversation = [{"role": "user", "content": query}]
         | 
| 589 | 
            +
                    req_data = {"messages": conversation}
         | 
| 590 | 
            +
                    try:
         | 
| 591 | 
            +
                        response = bot_client.process(model_name, req_data)
         | 
| 592 | 
            +
                        search_info_res = response["choices"][0]["message"]["content"]
         | 
| 593 | 
            +
                        start = search_info_res.find("{")
         | 
| 594 | 
            +
                        end = search_info_res.rfind("}") + 1
         | 
| 595 | 
            +
                        if start >= 0 and end > start:
         | 
| 596 | 
            +
                            search_info_res = search_info_res[start:end]
         | 
| 597 | 
            +
                        search_info_res = json.loads(search_info_res)
         | 
| 598 | 
            +
                        if search_info_res.get("sub_query_list", []):
         | 
| 599 | 
            +
                            unique_list = list(set(search_info_res["sub_query_list"]))
         | 
| 600 | 
            +
                            search_info_res["sub_query_list"] = unique_list
         | 
| 601 | 
            +
                        return search_info_res
         | 
| 602 | 
            +
                    except Exception:
         | 
| 603 | 
            +
                        raise gr.Error("Error: Model output is not a valid JSON")
         | 
| 604 | 
            +
             | 
| 605 | 
            +
                @staticmethod
         | 
| 606 | 
            +
                def split_oversized_line(line: str, chunk_size: int) -> tuple:
         | 
| 607 | 
            +
                    """
         | 
| 608 | 
            +
                    Split a line into two parts based on punctuation marks or whitespace while preserving
         | 
| 609 | 
            +
                    natural language boundaries and maintaining the original content structure.
         | 
| 610 | 
            +
             | 
| 611 | 
            +
                    Args:
         | 
| 612 | 
            +
                        line (str): The line to split.
         | 
| 613 | 
            +
                        chunk_size (int): The maximum length of each chunk.
         | 
| 614 | 
            +
             | 
| 615 | 
            +
                    Returns:
         | 
| 616 | 
            +
                        tuple: Two strings, the first part of the original line and the rest of the line.
         | 
| 617 | 
            +
                    """
         | 
| 618 | 
            +
                    PUNCTUATIONS = [".", "。", "!", "!", "?", "?", ",", ",", ";", ";", ":", ":"]
         | 
| 619 | 
            +
             | 
| 620 | 
            +
                    if len(line) <= chunk_size:
         | 
| 621 | 
            +
                        return line, ""
         | 
| 622 | 
            +
                    
         | 
| 623 | 
            +
                    # Search from chunk_size position backwards
         | 
| 624 | 
            +
                    split_pos = chunk_size
         | 
| 625 | 
            +
                    for i in range(chunk_size, 0, -1):
         | 
| 626 | 
            +
                        if line[i] in PUNCTUATIONS:
         | 
| 627 | 
            +
                            split_pos = i + 1  # Include punctuation
         | 
| 628 | 
            +
                            break
         | 
| 629 | 
            +
                    
         | 
| 630 | 
            +
                    # Fallback to whitespace if no punctuation found
         | 
| 631 | 
            +
                    if split_pos == chunk_size:
         | 
| 632 | 
            +
                        split_pos = line.rfind(" ", 0, chunk_size)
         | 
| 633 | 
            +
                        if split_pos == -1:
         | 
| 634 | 
            +
                            split_pos = chunk_size  # Hard split
         | 
| 635 | 
            +
                    
         | 
| 636 | 
            +
                    return line[:split_pos], line[split_pos:]
         | 
| 637 | 
            +
             | 
| 638 | 
            +
                @staticmethod
         | 
| 639 | 
            +
                def split_text_into_chunks(text: str, chunk_size: int) -> list:
         | 
| 640 | 
            +
                    """
         | 
| 641 | 
            +
                    Split text into chunks of a specified size while respecting natural language boundaries
         | 
| 642 | 
            +
                    and avoiding mid-word splits whenever possible.
         | 
| 643 | 
            +
             | 
| 644 | 
            +
                    Args:
         | 
| 645 | 
            +
                        text (str): The text to split.
         | 
| 646 | 
            +
                        chunk_size (int): The maximum length of each chunk.
         | 
| 647 | 
            +
             | 
| 648 | 
            +
                    Returns:
         | 
| 649 | 
            +
                        list: A list of strings, where each element represents a chunk of the original text.
         | 
| 650 | 
            +
                    """
         | 
| 651 | 
            +
                    lines = [line.strip() for line in text.split('\n') if line.strip()]
         | 
| 652 | 
            +
                    chunks = []
         | 
| 653 | 
            +
                    current_chunk = []
         | 
| 654 | 
            +
                    current_length = 0
         | 
| 655 | 
            +
                    
         | 
| 656 | 
            +
                    for line in lines:
         | 
| 657 | 
            +
                        
         | 
| 658 | 
            +
                        # If adding this line would exceed chunk size (and we have content)
         | 
| 659 | 
            +
                        if current_length + len(line) > chunk_size and current_chunk:
         | 
| 660 | 
            +
                            chunks.append(" ".join(current_chunk))
         | 
| 661 | 
            +
                            current_chunk = []
         | 
| 662 | 
            +
                            current_length = 0
         | 
| 663 | 
            +
             | 
| 664 | 
            +
                        # Process oversized lines first
         | 
| 665 | 
            +
                        while len(line) > chunk_size:
         | 
| 666 | 
            +
                            head, line = GradioEvents.split_oversized_line(line, chunk_size)
         | 
| 667 | 
            +
                            chunks.append(head)
         | 
| 668 | 
            +
                        
         | 
| 669 | 
            +
                        # Add remaining line content
         | 
| 670 | 
            +
                        if line:
         | 
| 671 | 
            +
                            current_chunk.append(line)
         | 
| 672 | 
            +
                            current_length += len(line) + 1
         | 
| 673 | 
            +
                    
         | 
| 674 | 
            +
                    if current_chunk:
         | 
| 675 | 
            +
                        chunks.append(" ".join(current_chunk))
         | 
| 676 | 
            +
                    return chunks
         | 
| 677 | 
            +
             | 
| 678 | 
            +
                @staticmethod
         | 
| 679 | 
            +
                def file_upload(
         | 
| 680 | 
            +
                    files_url: list, 
         | 
| 681 | 
            +
                    chunk_size: int, 
         | 
| 682 | 
            +
                    faiss_db: FaissTextDatabase, 
         | 
| 683 | 
            +
                    progress_bar: gr.Progress = gr.Progress()
         | 
| 684 | 
            +
                ) -> str:
         | 
| 685 | 
            +
                    """
         | 
| 686 | 
            +
                    Uploads and processes multiple files by splitting them into semantically meaningful chunks,
         | 
| 687 | 
            +
                    then indexes them in the FAISS database with progress tracking.
         | 
| 688 | 
            +
             | 
| 689 | 
            +
                    Args:
         | 
| 690 | 
            +
                        files_url (list): List of file URLs.
         | 
| 691 | 
            +
                        chunk_size (int): Maximum chunk size.
         | 
| 692 | 
            +
                        faiss_db (FaissTextDatabase): FAISS database instance.
         | 
| 693 | 
            +
                        progress_bar (gr.Progress): Progress bar instance.
         | 
| 694 | 
            +
             | 
| 695 | 
            +
                    Returns:
         | 
| 696 | 
            +
                        str: Message indicating successful completion.
         | 
| 697 | 
            +
                    """
         | 
| 698 | 
            +
                    if not files_url:
         | 
| 699 | 
            +
                        return
         | 
| 700 | 
            +
                    yield gr.update(visible=True)
         | 
| 701 | 
            +
                    for file_url in files_url:
         | 
| 702 | 
            +
                        if not GradioEvents.save_file_to_db(file_url, chunk_size, faiss_db, progress_bar):
         | 
| 703 | 
            +
                            file_name = os.path.basename(file_url)
         | 
| 704 | 
            +
                            gr.Info("{} already processed.".format(file_name))
         | 
| 705 | 
            +
             | 
| 706 | 
            +
                    yield gr.update(visible=False)
         | 
| 707 | 
            +
             | 
| 708 | 
            +
                @staticmethod
         | 
| 709 | 
            +
                def save_file_to_db(file_url: str, chunk_size: int, faiss_db: FaissTextDatabase, progress_bar: gr.Progress=None):
         | 
| 710 | 
            +
                    """
         | 
| 711 | 
            +
                    Processes and indexes document content into FAISS database with semantic-aware chunking.
         | 
| 712 | 
            +
                    Handles file validation, text segmentation, embedding generation and storage operations.
         | 
| 713 | 
            +
             | 
| 714 | 
            +
                    Args:
         | 
| 715 | 
            +
                        file_url (str): File URL.
         | 
| 716 | 
            +
                        chunk_size (int): Chunk size.
         | 
| 717 | 
            +
                        faiss_db (FaissTextDatabase): FAISS database instance.
         | 
| 718 | 
            +
                        progress_bar (gr.Progress): Progress bar instance.
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                    Returns:
         | 
| 721 | 
            +
                        bool: True if the file was saved successfully, otherwise False.
         | 
| 722 | 
            +
                    """
         | 
| 723 | 
            +
                    file_name = os.path.basename(file_url)
         | 
| 724 | 
            +
                    if not faiss_db.is_file_processed(file_url):
         | 
| 725 | 
            +
                        logging.info("{} not processed yet, processing now...".format(file_url))
         | 
| 726 | 
            +
                        try:
         | 
| 727 | 
            +
                            with open(file_url, "r", encoding="utf-8") as f:
         | 
| 728 | 
            +
                                text = f.read()
         | 
| 729 | 
            +
                            segments = GradioEvents.split_text_into_chunks(text, chunk_size)
         | 
| 730 | 
            +
                            faiss_db.add_embeddings(file_url, segments, progress_bar)
         | 
| 731 | 
            +
             | 
| 732 | 
            +
                            logging.info("{} processed successfully.".format(file_url))
         | 
| 733 | 
            +
                            return True
         | 
| 734 | 
            +
                        except Exception as e:
         | 
| 735 | 
            +
                            logging.error("Error processing {}: {}".format(file_url, str(e)))
         | 
| 736 | 
            +
                            gr.Error("Error processing file: {}".format(file_name))
         | 
| 737 | 
            +
                            raise
         | 
| 738 | 
            +
                    else:
         | 
| 739 | 
            +
                        logging.info("{} already processed.".format(file_url))          
         | 
| 740 | 
            +
                        return False
         | 
| 741 | 
            +
             | 
| 742 | 
            +
             | 
| 743 | 
            +
            def launch_demo(args: argparse.Namespace, bot_client: BotClient, faiss_db: FaissTextDatabase):
         | 
| 744 | 
            +
                """
         | 
| 745 | 
            +
                Launch demo program
         | 
| 746 | 
            +
                
         | 
| 747 | 
            +
                Args:
         | 
| 748 | 
            +
                    args (argparse.Namespace): argparse Namespace object containing parsed command line arguments
         | 
| 749 | 
            +
                    bot_client (BotClient): Bot client instance
         | 
| 750 | 
            +
                    faiss_db (FaissTextDatabase): FAISS database instance
         | 
| 751 | 
            +
                """
         | 
| 752 | 
            +
                css = """
         | 
| 753 | 
            +
                /* Hide original Chinese text */
         | 
| 754 | 
            +
                #file-upload .wrap {
         | 
| 755 | 
            +
                    font-size: 0 !important;
         | 
| 756 | 
            +
                    position: relative;
         | 
| 757 | 
            +
                    display: flex;
         | 
| 758 | 
            +
                    flex-direction: column;
         | 
| 759 | 
            +
                    align-items: center;
         | 
| 760 | 
            +
                    justify-content: center;
         | 
| 761 | 
            +
                }
         | 
| 762 | 
            +
             | 
| 763 | 
            +
                /* Insert English prompt text below the SVG icon */
         | 
| 764 | 
            +
                #file-upload .wrap::after {
         | 
| 765 | 
            +
                    content: "Drag and drop files here or click to upload";
         | 
| 766 | 
            +
                    font-size: 18px;
         | 
| 767 | 
            +
                    color: #555;
         | 
| 768 | 
            +
                    margin-top: 8px;
         | 
| 769 | 
            +
                    white-space: nowrap;
         | 
| 770 | 
            +
                }
         | 
| 771 | 
            +
                """
         | 
| 772 | 
            +
                with gr.Blocks(css=css) as demo:
         | 
| 773 | 
            +
                    model_name = gr.State("eb-45t")
         | 
| 774 | 
            +
             | 
| 775 | 
            +
                    logo_url = GradioEvents.get_image_url("assets/logo.png")
         | 
| 776 | 
            +
                    gr.Markdown("""\
         | 
| 777 | 
            +
                            <p align="center"><img src="{}" \
         | 
| 778 | 
            +
                            style="height: 60px"/><p>""".format(logo_url))
         | 
| 779 | 
            +
                    gr.Markdown(
         | 
| 780 | 
            +
                        """\
         | 
| 781 | 
            +
            <center><font size=3>This demo is based on ERNIE models. \
         | 
| 782 | 
            +
            (本演示基于文心大模型实现。)</center>"""
         | 
| 783 | 
            +
                    )
         | 
| 784 | 
            +
             | 
| 785 | 
            +
                    chatbot = gr.Chatbot(
         | 
| 786 | 
            +
                        label="ERNIE", 
         | 
| 787 | 
            +
                        type="messages"
         | 
| 788 | 
            +
                    )
         | 
| 789 | 
            +
             | 
| 790 | 
            +
                    with gr.Row(equal_height=True):
         | 
| 791 | 
            +
                        file_btn = gr.File(
         | 
| 792 | 
            +
                            label="Knowledge Base Upload (System default will be used if none provided. Accepted formats: TXT, MD)", 
         | 
| 793 | 
            +
                            height="150px", 
         | 
| 794 | 
            +
                            file_types=[".txt", ".md"],
         | 
| 795 | 
            +
                            elem_id="file-upload",
         | 
| 796 | 
            +
                            file_count="multiple"
         | 
| 797 | 
            +
                        )
         | 
| 798 | 
            +
                        relevant_passage = gr.Textbox(
         | 
| 799 | 
            +
                            label="Relevant Passage",
         | 
| 800 | 
            +
                            lines=5,
         | 
| 801 | 
            +
                            max_lines=5,
         | 
| 802 | 
            +
                            placeholder=RELEVANT_PASSAGE_DEFAULT,
         | 
| 803 | 
            +
                            interactive=False
         | 
| 804 | 
            +
                        )
         | 
| 805 | 
            +
                    with gr.Row():
         | 
| 806 | 
            +
                        progress_bar = gr.Textbox(label="Progress", visible=False)
         | 
| 807 | 
            +
             | 
| 808 | 
            +
                    query = gr.Textbox(label="Query", elem_id="text_input", value=QUERY_DEFAULT)
         | 
| 809 | 
            +
             | 
| 810 | 
            +
                    with gr.Row():
         | 
| 811 | 
            +
                        empty_btn = gr.Button("🧹 Clear History(清除历史)")
         | 
| 812 | 
            +
                        submit_btn = gr.Button("🚀 Submit(发送)", elem_id="submit-button")
         | 
| 813 | 
            +
                        regen_btn = gr.Button("🤔️ Regenerate(重试)")
         | 
| 814 | 
            +
                    
         | 
| 815 | 
            +
                    task_history = gr.State([])
         | 
| 816 | 
            +
                    
         | 
| 817 | 
            +
                    predict_with_clients = partial(
         | 
| 818 | 
            +
                        GradioEvents.predict_stream,
         | 
| 819 | 
            +
                        bot_client=bot_client,
         | 
| 820 | 
            +
                        faiss_db=faiss_db
         | 
| 821 | 
            +
                    )
         | 
| 822 | 
            +
                    regenerate_with_clients = partial(
         | 
| 823 | 
            +
                        GradioEvents.regenerate,
         | 
| 824 | 
            +
                        bot_client=bot_client,
         | 
| 825 | 
            +
                        faiss_db=faiss_db
         | 
| 826 | 
            +
                    )
         | 
| 827 | 
            +
                    file_upload_with_clients = partial(
         | 
| 828 | 
            +
                        GradioEvents.file_upload,
         | 
| 829 | 
            +
                        faiss_db=faiss_db
         | 
| 830 | 
            +
                    )
         | 
| 831 | 
            +
                    
         | 
| 832 | 
            +
                    chunk_size = gr.State(args.chunk_size)
         | 
| 833 | 
            +
                    file_btn.change(
         | 
| 834 | 
            +
                        fn=file_upload_with_clients,
         | 
| 835 | 
            +
                        inputs=[file_btn, chunk_size],
         | 
| 836 | 
            +
                        outputs=[progress_bar],
         | 
| 837 | 
            +
                    )
         | 
| 838 | 
            +
                    query.submit(
         | 
| 839 | 
            +
                        predict_with_clients, 
         | 
| 840 | 
            +
                        inputs=[query, chatbot, task_history, model_name], 
         | 
| 841 | 
            +
                        outputs=[chatbot, relevant_passage],
         | 
| 842 | 
            +
                        show_progress=True
         | 
| 843 | 
            +
                    )
         | 
| 844 | 
            +
                    query.submit(GradioEvents.reset_user_input, [], [query])
         | 
| 845 | 
            +
                    submit_btn.click(
         | 
| 846 | 
            +
                        predict_with_clients, 
         | 
| 847 | 
            +
                        inputs=[query, chatbot, task_history, model_name],
         | 
| 848 | 
            +
                        outputs=[chatbot, relevant_passage],
         | 
| 849 | 
            +
                        show_progress=True,
         | 
| 850 | 
            +
                    )
         | 
| 851 | 
            +
                    submit_btn.click(GradioEvents.reset_user_input, [], [query])
         | 
| 852 | 
            +
                    empty_btn.click(
         | 
| 853 | 
            +
                        GradioEvents.reset_state,
         | 
| 854 | 
            +
                        outputs=[chatbot, task_history, file_btn, relevant_passage], show_progress=True
         | 
| 855 | 
            +
                    )
         | 
| 856 | 
            +
                    regen_btn.click(
         | 
| 857 | 
            +
                        regenerate_with_clients, 
         | 
| 858 | 
            +
                        inputs=[chatbot, task_history, model_name],
         | 
| 859 | 
            +
                        outputs=[chatbot, relevant_passage],
         | 
| 860 | 
            +
                        show_progress=True
         | 
| 861 | 
            +
                    )
         | 
| 862 | 
            +
             | 
| 863 | 
            +
                demo.queue().launch(
         | 
| 864 | 
            +
                    server_port=args.server_port,
         | 
| 865 | 
            +
                    server_name=args.server_name
         | 
| 866 | 
            +
                )
         | 
| 867 | 
            +
             | 
| 868 | 
            +
             | 
| 869 | 
            +
            def main():
         | 
| 870 | 
            +
                """Main function that runs when this script is executed."""
         | 
| 871 | 
            +
                args = get_args()
         | 
| 872 | 
            +
                bot_client = BotClient(args)
         | 
| 873 | 
            +
                faiss_db = FaissTextDatabase(args, bot_client)
         | 
| 874 | 
            +
             | 
| 875 | 
            +
                # Run file upload function to save default knowledge base.
         | 
| 876 | 
            +
                GradioEvents.save_file_to_db(FILE_URL_DEFAULT, args.chunk_size, faiss_db)
         | 
| 877 | 
            +
             | 
| 878 | 
            +
                launch_demo(args, bot_client, faiss_db)
         | 
| 879 | 
            +
             | 
| 880 | 
            +
            if __name__ == "__main__":
         | 
| 881 | 
            +
                main()
         | 
    	
        assets/logo.png
    ADDED
    
    |   | 
    	
        bot_requests.py
    ADDED
    
    | @@ -0,0 +1,388 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 5 | 
            +
            # You may obtain a copy of the License at
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 10 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 11 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 12 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 13 | 
            +
            # limitations under the License.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            """BotClient class for interacting with bot models."""
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import os
         | 
| 18 | 
            +
            import argparse
         | 
| 19 | 
            +
            import logging
         | 
| 20 | 
            +
            import traceback
         | 
| 21 | 
            +
            import json
         | 
| 22 | 
            +
            import jieba
         | 
| 23 | 
            +
            from openai import OpenAI
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            from appbuilder.mcp_server.client import MCPClient
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            class BotClient(object):
         | 
| 28 | 
            +
                """Client for interacting with various AI models."""
         | 
| 29 | 
            +
                def __init__(self, args: argparse.Namespace):
         | 
| 30 | 
            +
                    """
         | 
| 31 | 
            +
                    Initializes the BotClient instance by configuring essential parameters from command line arguments 
         | 
| 32 | 
            +
                    including retry limits, character constraints, model endpoints and API credentials while setting up 
         | 
| 33 | 
            +
                    default values for missing arguments to ensure robust operation.
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    Args:
         | 
| 36 | 
            +
                        args (argparse.Namespace): Command line arguments containing configuration parameters.
         | 
| 37 | 
            +
                                                  Uses getattr() to safely retrieve values with fallback defaults.
         | 
| 38 | 
            +
                    """
         | 
| 39 | 
            +
                    self.logger = logging.getLogger(__name__)
         | 
| 40 | 
            +
                    
         | 
| 41 | 
            +
                    self.max_retry_num = getattr(args, 'max_retry_num', 3)
         | 
| 42 | 
            +
                    self.max_char = getattr(args, 'max_char', 8000)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    self.eb45t_model_url = getattr(args, 'eb45t_model_url', 'eb45t_model_url')
         | 
| 45 | 
            +
                    self.x1_model_url = getattr(args, 'x1_model_url', 'x1_model_url')
         | 
| 46 | 
            +
                    self.api_key = os.environ.get("API_KEY")
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    self.qianfan_url = getattr(args, 'qianfan_url', 'qianfan_url')
         | 
| 49 | 
            +
                    self.qianfan_api_key = getattr(args, 'qianfan_api_key', 'qianfan_api_key')
         | 
| 50 | 
            +
                    self.embedding_model = getattr(args, 'embedding_model', 'embedding_model')
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    self.ai_search_service_url = getattr(args, 'ai_search_service_url', 'ai_search_service_url')
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def call_back(self, host_url: str, req_data: dict) -> dict:
         | 
| 55 | 
            +
                    """
         | 
| 56 | 
            +
                    Executes an HTTP request to the specified endpoint using the OpenAI client, handles the response 
         | 
| 57 | 
            +
                    conversion to a compatible dictionary format, and manages any exceptions that may occur during 
         | 
| 58 | 
            +
                    the request process while logging errors appropriately.
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    Args:
         | 
| 61 | 
            +
                        host_url (str): The URL to send the request to.
         | 
| 62 | 
            +
                        req_data (dict): The data to send in the request body.
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    Returns:
         | 
| 65 | 
            +
                        dict: Parsed JSON response from the server. Returns empty dict
         | 
| 66 | 
            +
                            if request fails or response is invalid.
         | 
| 67 | 
            +
                    """
         | 
| 68 | 
            +
                    try:
         | 
| 69 | 
            +
                        client = OpenAI(base_url=host_url, api_key=self.api_key)
         | 
| 70 | 
            +
                        response = client.chat.completions.create(
         | 
| 71 | 
            +
                            **req_data
         | 
| 72 | 
            +
                        )
         | 
| 73 | 
            +
                            
         | 
| 74 | 
            +
                        # Convert OpenAI response to compatible format
         | 
| 75 | 
            +
                        return response.model_dump()
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    except Exception as e:
         | 
| 78 | 
            +
                        self.logger.error("Stream request failed: {}".format(e))
         | 
| 79 | 
            +
                        raise
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                def call_back_stream(self, host_url: str, req_data: dict) -> dict:
         | 
| 82 | 
            +
                    """
         | 
| 83 | 
            +
                    Makes a streaming HTTP request to the specified host URL using the OpenAI client and yields response chunks 
         | 
| 84 | 
            +
                    in real-time while handling any exceptions that may occur during the streaming process.
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    Args:
         | 
| 87 | 
            +
                        host_url (str): The URL to send the request to.
         | 
| 88 | 
            +
                        req_data (dict): The data to send in the request body.
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    Returns:
         | 
| 91 | 
            +
                        generator: Generator that yields parsed JSON responses from the server.
         | 
| 92 | 
            +
                    """
         | 
| 93 | 
            +
                    try:
         | 
| 94 | 
            +
                        client = OpenAI(base_url=host_url, api_key=self.api_key)
         | 
| 95 | 
            +
                        response = client.chat.completions.create(
         | 
| 96 | 
            +
                            **req_data,
         | 
| 97 | 
            +
                            stream=True,
         | 
| 98 | 
            +
                        )
         | 
| 99 | 
            +
                        for chunk in response:
         | 
| 100 | 
            +
                            if not chunk.choices:
         | 
| 101 | 
            +
                                continue
         | 
| 102 | 
            +
                            
         | 
| 103 | 
            +
                            # Convert OpenAI response to compatible format
         | 
| 104 | 
            +
                            yield chunk.model_dump()
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    except Exception as e:
         | 
| 107 | 
            +
                        self.logger.error("Stream request failed: {}".format(e))
         | 
| 108 | 
            +
                        raise
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                def process(
         | 
| 111 | 
            +
                    self, 
         | 
| 112 | 
            +
                    model_name: str, 
         | 
| 113 | 
            +
                    req_data: dict, 
         | 
| 114 | 
            +
                    max_tokens: int=2048, 
         | 
| 115 | 
            +
                    temperature: float=1.0, 
         | 
| 116 | 
            +
                    top_p: float=0.7
         | 
| 117 | 
            +
                ) -> dict:
         | 
| 118 | 
            +
                    """
         | 
| 119 | 
            +
                    Handles chat completion requests by mapping the model name to its endpoint, preparing request parameters 
         | 
| 120 | 
            +
                    including token limits and sampling settings, truncating messages to fit character limits, making API calls 
         | 
| 121 | 
            +
                    with built-in retry mechanism, and logging the full request/response cycle for debugging purposes.
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    Args:
         | 
| 124 | 
            +
                        model_name (str): Name of the model, used to look up the model URL from model_map.
         | 
| 125 | 
            +
                        req_data (dict): Dictionary containing request data, including information to be processed.
         | 
| 126 | 
            +
                        max_tokens (int): Maximum number of tokens to generate.
         | 
| 127 | 
            +
                        temperature (float): Sampling temperature to control the diversity of generated text.
         | 
| 128 | 
            +
                        top_p (float): Cumulative probability threshold to control the diversity of generated text.
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    Returns:
         | 
| 131 | 
            +
                        dict: Dictionary containing the model's processing results.
         | 
| 132 | 
            +
                    """
         | 
| 133 | 
            +
                    model_map = {
         | 
| 134 | 
            +
                        "eb-45t": self.eb45t_model_url,
         | 
| 135 | 
            +
                        "eb-x1": self.x1_model_url
         | 
| 136 | 
            +
                    }
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    model_url = model_map[model_name]
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    req_data["model"] = "ernie-4.5-turbo-32k" if "eb-45t" == model_name else "ernie-x1-turbo-32k"
         | 
| 141 | 
            +
                    req_data["max_tokens"] = max_tokens
         | 
| 142 | 
            +
                    req_data["temperature"] = temperature
         | 
| 143 | 
            +
                    req_data["top_p"] = top_p
         | 
| 144 | 
            +
                    req_data["messages"] = self.truncate_messages(req_data["messages"])
         | 
| 145 | 
            +
                    for _ in range(self.max_retry_num):
         | 
| 146 | 
            +
                        try:
         | 
| 147 | 
            +
                            self.logger.info("[MODEL] {}".format(model_url))
         | 
| 148 | 
            +
                            self.logger.info("[req_data]====>")
         | 
| 149 | 
            +
                            self.logger.info(json.dumps(req_data, ensure_ascii=False))
         | 
| 150 | 
            +
                            res = self.call_back(model_url, req_data)
         | 
| 151 | 
            +
                            self.logger.info("model response")
         | 
| 152 | 
            +
                            self.logger.info(res)
         | 
| 153 | 
            +
                            self.logger.info("-" * 30)
         | 
| 154 | 
            +
                        except Exception as e:
         | 
| 155 | 
            +
                            self.logger.info(e)
         | 
| 156 | 
            +
                            self.logger.info(traceback.format_exc())
         | 
| 157 | 
            +
                            res = {}
         | 
| 158 | 
            +
                        if len(res) != 0 and "error" not in res:
         | 
| 159 | 
            +
                            break
         | 
| 160 | 
            +
                    self.logger.info(json.dumps(res, ensure_ascii=False))
         | 
| 161 | 
            +
                    
         | 
| 162 | 
            +
                    return res
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                def process_stream(
         | 
| 165 | 
            +
                    self, model_name: str, 
         | 
| 166 | 
            +
                    req_data: dict, 
         | 
| 167 | 
            +
                    max_tokens: int=2048, 
         | 
| 168 | 
            +
                    temperature: float=1.0, 
         | 
| 169 | 
            +
                    top_p: float=0.7
         | 
| 170 | 
            +
                ) -> dict:
         | 
| 171 | 
            +
                    """
         | 
| 172 | 
            +
                    Processes streaming requests by mapping the model name to its endpoint, configuring request parameters,
         | 
| 173 | 
            +
                    implementing a retry mechanism with logging, and streaming back response chunks in real-time while
         | 
| 174 | 
            +
                    handling any errors that may occur during the streaming session.
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    Args:
         | 
| 177 | 
            +
                        model_name (str): Name of the model, used to look up the model URL from model_map.
         | 
| 178 | 
            +
                        req_data (dict): Dictionary containing request data, including information to be processed.
         | 
| 179 | 
            +
                        max_tokens (int): Maximum number of tokens to generate.
         | 
| 180 | 
            +
                        temperature (float): Sampling temperature to control the diversity of generated text.
         | 
| 181 | 
            +
                        top_p (float): Cumulative probability threshold to control the diversity of generated text.
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                    Yields:
         | 
| 184 | 
            +
                        dict: Dictionary containing the model's processing results.
         | 
| 185 | 
            +
                    """
         | 
| 186 | 
            +
                    model_map = {
         | 
| 187 | 
            +
                        "eb-45t": self.eb45t_model_url,
         | 
| 188 | 
            +
                        "eb-x1": self.x1_model_url
         | 
| 189 | 
            +
                    }
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    model_url = model_map[model_name]
         | 
| 192 | 
            +
                    req_data["model"] = "ernie-4.5-turbo-32k" if "eb-45t" == model_name else "ernie-x1-turbo-32k"
         | 
| 193 | 
            +
                    req_data["max_tokens"] = max_tokens
         | 
| 194 | 
            +
                    req_data["temperature"] = temperature
         | 
| 195 | 
            +
                    req_data["top_p"] = top_p
         | 
| 196 | 
            +
                    req_data["messages"] = self.truncate_messages(req_data["messages"])
         | 
| 197 | 
            +
                        
         | 
| 198 | 
            +
                    last_error = None
         | 
| 199 | 
            +
                    for _ in range(self.max_retry_num):
         | 
| 200 | 
            +
                        try:
         | 
| 201 | 
            +
                            self.logger.info("[MODEL] {}".format(model_url))
         | 
| 202 | 
            +
                            self.logger.info("[req_data]====>")
         | 
| 203 | 
            +
                            self.logger.info(json.dumps(req_data, ensure_ascii=False))
         | 
| 204 | 
            +
                            
         | 
| 205 | 
            +
                            for chunk in self.call_back_stream(model_url, req_data):
         | 
| 206 | 
            +
                                yield chunk
         | 
| 207 | 
            +
                            return
         | 
| 208 | 
            +
                            
         | 
| 209 | 
            +
                        except Exception as e:
         | 
| 210 | 
            +
                            last_error = e
         | 
| 211 | 
            +
                            self.logger.error("Stream request failed (attempt {}/{}): {}".format(_ + 1, self.max_retry_num, e))
         | 
| 212 | 
            +
                    
         | 
| 213 | 
            +
                    self.logger.error("All retry attempts failed for stream request")
         | 
| 214 | 
            +
                    yield {"error": str(last_error)}
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                def cut_chinese_english(self, text: str) -> list:
         | 
| 217 | 
            +
                    """
         | 
| 218 | 
            +
                    Segments mixed Chinese and English text into individual components using Jieba for Chinese words 
         | 
| 219 | 
            +
                    while preserving English words as whole units, with special handling for Unicode character ranges 
         | 
| 220 | 
            +
                    to distinguish between the two languages.
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                    Args:
         | 
| 223 | 
            +
                        text (str): Input string to be segmented.
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    Returns:
         | 
| 226 | 
            +
                        list: A list of segments, where each segment is either a letter or a word.
         | 
| 227 | 
            +
                    """
         | 
| 228 | 
            +
                    words = jieba.lcut(text)
         | 
| 229 | 
            +
                    en_ch_words = []
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                    for word in words:
         | 
| 232 | 
            +
                        if word.isalpha() and not any("\u4e00" <= char <= "\u9fff" for char in word):
         | 
| 233 | 
            +
                            en_ch_words.append(word)
         | 
| 234 | 
            +
                        else:
         | 
| 235 | 
            +
                            en_ch_words.extend(list(word))
         | 
| 236 | 
            +
                    return en_ch_words
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                def truncate_messages(self, messages: list[dict]) -> list:
         | 
| 239 | 
            +
                    """
         | 
| 240 | 
            +
                    Truncates conversation messages to fit within the maximum character limit (self.max_char)
         | 
| 241 | 
            +
                    by intelligently removing content while preserving message structure. The truncation follows
         | 
| 242 | 
            +
                    a prioritized order: historical messages first, then system message, and finally the last message.
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                    Args:
         | 
| 245 | 
            +
                        messages (list[dict]): List of messages to be truncated.
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                    Returns:
         | 
| 248 | 
            +
                        list[dict]: Modified list of messages after truncation.
         | 
| 249 | 
            +
                    """
         | 
| 250 | 
            +
                    if not messages:
         | 
| 251 | 
            +
                        return messages
         | 
| 252 | 
            +
                
         | 
| 253 | 
            +
                    processed = []
         | 
| 254 | 
            +
                    total_units = 0
         | 
| 255 | 
            +
                    
         | 
| 256 | 
            +
                    for msg in messages:
         | 
| 257 | 
            +
                        # Handle two different content formats
         | 
| 258 | 
            +
                        if isinstance(msg["content"], str):
         | 
| 259 | 
            +
                            text_content = msg["content"]
         | 
| 260 | 
            +
                        elif isinstance(msg["content"], list):
         | 
| 261 | 
            +
                            text_content = msg["content"][1]["text"]
         | 
| 262 | 
            +
                        else:
         | 
| 263 | 
            +
                            text_content = ""
         | 
| 264 | 
            +
                        
         | 
| 265 | 
            +
                        # Calculate unit count after tokenization
         | 
| 266 | 
            +
                        units = self.cut_chinese_english(text_content)
         | 
| 267 | 
            +
                        unit_count = len(units)
         | 
| 268 | 
            +
                        
         | 
| 269 | 
            +
                        processed.append({
         | 
| 270 | 
            +
                            "role": msg["role"],
         | 
| 271 | 
            +
                            "original_content": msg["content"],  # Preserve original content
         | 
| 272 | 
            +
                            "text_content": text_content,        # Extracted plain text
         | 
| 273 | 
            +
                            "units": units,
         | 
| 274 | 
            +
                            "unit_count": unit_count
         | 
| 275 | 
            +
                        })
         | 
| 276 | 
            +
                        total_units += unit_count
         | 
| 277 | 
            +
                    
         | 
| 278 | 
            +
                    if total_units <= self.max_char:
         | 
| 279 | 
            +
                        return messages
         | 
| 280 | 
            +
                    
         | 
| 281 | 
            +
                    # Number of units to remove
         | 
| 282 | 
            +
                    to_remove = total_units - self.max_char
         | 
| 283 | 
            +
                    
         | 
| 284 | 
            +
                    # 1. Truncate historical messages
         | 
| 285 | 
            +
                    for i in range(1, len(processed) - 1):
         | 
| 286 | 
            +
                        if to_remove <= 0:
         | 
| 287 | 
            +
                            break
         | 
| 288 | 
            +
                        
         | 
| 289 | 
            +
                        # current = processed[i]
         | 
| 290 | 
            +
                        if processed[i]["unit_count"] <= to_remove:
         | 
| 291 | 
            +
                            processed[i]["text_content"] = ""
         | 
| 292 | 
            +
                            to_remove -= processed[i]["unit_count"]
         | 
| 293 | 
            +
                            if isinstance(processed[i]["original_content"], str):
         | 
| 294 | 
            +
                                processed[i]["original_content"] = ""
         | 
| 295 | 
            +
                            elif isinstance(processed[i]["original_content"], list):
         | 
| 296 | 
            +
                                processed[i]["original_content"][1]["text"] = ""
         | 
| 297 | 
            +
                        else:
         | 
| 298 | 
            +
                            kept_units = processed[i]["units"][:-to_remove]
         | 
| 299 | 
            +
                            new_text = "".join(kept_units)
         | 
| 300 | 
            +
                            processed[i]["text_content"] = new_text
         | 
| 301 | 
            +
                            if isinstance(processed[i]["original_content"], str):
         | 
| 302 | 
            +
                                processed[i]["original_content"] = new_text
         | 
| 303 | 
            +
                            elif isinstance(processed[i]["original_content"], list):
         | 
| 304 | 
            +
                                processed[i]["original_content"][1]["text"] = new_text
         | 
| 305 | 
            +
                            to_remove = 0
         | 
| 306 | 
            +
                    
         | 
| 307 | 
            +
                    # 2. Truncate system message
         | 
| 308 | 
            +
                    if to_remove > 0:
         | 
| 309 | 
            +
                        system_msg = processed[0]
         | 
| 310 | 
            +
                        if system_msg["unit_count"] <= to_remove:
         | 
| 311 | 
            +
                            processed[0]["text_content"] = ""
         | 
| 312 | 
            +
                            to_remove -= system_msg["unit_count"]
         | 
| 313 | 
            +
                            if isinstance(processed[0]["original_content"], str):
         | 
| 314 | 
            +
                                processed[0]["original_content"] = ""
         | 
| 315 | 
            +
                            elif isinstance(processed[0]["original_content"], list):
         | 
| 316 | 
            +
                                processed[0]["original_content"][1]["text"] = ""
         | 
| 317 | 
            +
                        else:
         | 
| 318 | 
            +
                            kept_units = system_msg["units"][:-to_remove]
         | 
| 319 | 
            +
                            new_text = "".join(kept_units)
         | 
| 320 | 
            +
                            processed[0]["text_content"] = new_text
         | 
| 321 | 
            +
                            if isinstance(processed[0]["original_content"], str):
         | 
| 322 | 
            +
                                processed[0]["original_content"] = new_text
         | 
| 323 | 
            +
                            elif isinstance(processed[0]["original_content"], list):
         | 
| 324 | 
            +
                                processed[0]["original_content"][1]["text"] = new_text
         | 
| 325 | 
            +
                            to_remove = 0
         | 
| 326 | 
            +
                    
         | 
| 327 | 
            +
                    # 3. Truncate last message
         | 
| 328 | 
            +
                    if to_remove > 0 and len(processed) > 1:
         | 
| 329 | 
            +
                        last_msg = processed[-1]
         | 
| 330 | 
            +
                        if last_msg["unit_count"] > to_remove:
         | 
| 331 | 
            +
                            kept_units = last_msg["units"][:-to_remove]
         | 
| 332 | 
            +
                            new_text = "".join(kept_units)
         | 
| 333 | 
            +
                            last_msg["text_content"] = new_text
         | 
| 334 | 
            +
                            if isinstance(last_msg["original_content"], str):
         | 
| 335 | 
            +
                                last_msg["original_content"] = new_text
         | 
| 336 | 
            +
                            elif isinstance(last_msg["original_content"], list):
         | 
| 337 | 
            +
                                last_msg["original_content"][1]["text"] = new_text
         | 
| 338 | 
            +
                        else:
         | 
| 339 | 
            +
                            last_msg["text_content"] = ""
         | 
| 340 | 
            +
                            if isinstance(last_msg["original_content"], str):
         | 
| 341 | 
            +
                                last_msg["original_content"] = ""
         | 
| 342 | 
            +
                            elif isinstance(last_msg["original_content"], list):
         | 
| 343 | 
            +
                                last_msg["original_content"][1]["text"] = ""
         | 
| 344 | 
            +
                    
         | 
| 345 | 
            +
                    result = []
         | 
| 346 | 
            +
                    for msg in processed:
         | 
| 347 | 
            +
                        if msg["text_content"]:
         | 
| 348 | 
            +
                            result.append({
         | 
| 349 | 
            +
                                "role": msg["role"],
         | 
| 350 | 
            +
                                "content": msg["original_content"]
         | 
| 351 | 
            +
                            })
         | 
| 352 | 
            +
                    
         | 
| 353 | 
            +
                    return result
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                def embed_fn(self, text: str) -> list:
         | 
| 356 | 
            +
                    """
         | 
| 357 | 
            +
                    Generate an embedding for the given text using the QianFan API.
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                    Args:
         | 
| 360 | 
            +
                        text (str): The input text to be embedded.
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                    Returns:
         | 
| 363 | 
            +
                        list: A list of floats representing the embedding.
         | 
| 364 | 
            +
                    """
         | 
| 365 | 
            +
                    client = OpenAI(base_url=self.qianfan_url, api_key=self.qianfan_api_key)
         | 
| 366 | 
            +
                    response = client.embeddings.create(input=[text], model=self.embedding_model)
         | 
| 367 | 
            +
                    return response.data[0].embedding
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                async def get_ai_search_res(self, query_list: list) -> list:
         | 
| 370 | 
            +
                    """
         | 
| 371 | 
            +
                    Get AI search results for the given queries using the MCPClient.
         | 
| 372 | 
            +
                    
         | 
| 373 | 
            +
                    Args:
         | 
| 374 | 
            +
                        query_list (list): List of queries to search for.
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    Returns:
         | 
| 377 | 
            +
                        list: List of search results as strings.
         | 
| 378 | 
            +
                    """
         | 
| 379 | 
            +
                    try:
         | 
| 380 | 
            +
                        client = MCPClient()
         | 
| 381 | 
            +
                        await client.connect_to_server(service_url=self.ai_search_service_url)
         | 
| 382 | 
            +
                        result = []
         | 
| 383 | 
            +
                        for query in query_list:
         | 
| 384 | 
            +
                            response = await client.call_tool("AIsearch", {"query": query})
         | 
| 385 | 
            +
                            result.append(response.content[0].text)
         | 
| 386 | 
            +
                    finally:
         | 
| 387 | 
            +
                        await client.cleanup()
         | 
| 388 | 
            +
                    return result
         | 
    	
        data/coffee.txt
    ADDED
    
    | @@ -0,0 +1,63 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            咖啡(英语:coffee)是指咖啡植物的种子即咖啡豆在经过烘焙磨粉后通过冲泡制成的饮料,是世界上流行范围最为广泛的饮料之一。咖啡在人类饮食中一般为日常的饮品,人们通常会为了提振精神,或在用餐和社交、阅读时饮用。咖啡原产于非洲东岸的埃塞俄比亚,咖啡起源于15-16世纪,从也门被传播至穆斯林世界,16世纪的威尼斯商人将咖啡引入意大利,随后17-18世纪由于欧洲对咖啡的需求,促使殖民者将咖啡树传播并栽种到美洲、东南亚和印度等热带地区,现今有超过70个国家种植咖啡树。未经烘焙的 咖啡生豆作为世界上最大的出口农产品,以及世界上交易量为广泛的热带农产品之一,也是发展中国家出口中最有价值的商品之一。采收的成熟咖啡果会经过剥离果肉的初步加工,再经过烘焙的工序,而成为能制作咖啡的咖啡豆。透过不同的冲泡方式与成分比例,咖啡有浓缩咖啡、卡布奇诺和拿铁咖啡等变化。咖啡豆的品种可大致分为两种:最为普遍的小果咖啡(阿拉比卡),以及颗粒较粗且酸味较低而苦味较浓的中果咖啡(罗布斯塔)。一些争议指咖啡的种植与它环境影响有关,例如肯亚咖啡豆在移植种植后失去了独有的肯亚酸,而肯亚的原种地土壤含有较高浓度的磷酸。因此,公平贸易咖啡与有机咖啡是一个不断扩大的市场。
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            传说9世纪的埃塞俄比亚的牧羊人发现并咀嚼了咖啡果实,随后将咖啡果实带给了附近修道院的僧侣,但僧侣起初不愿食用果实,并把果实扔进火里,经过火烤的咖啡果中冒出香气引来僧侣前来查看,僧侣从余烬中捞出咖啡豆,并将其磨碎溶解在热水中,这才制成了世界上第一杯咖啡。但此故事截至1671年并没有得到任何记载,因此可能是杜撰的。亦有研究认为最初栽培的咖啡源自埃塞俄比亚的哈勒尔。埃塞俄比亚的阿克苏姆王国兴盛时曾一度占据也门南部,6世纪中期,萨珊帝国攻占也门后将阿克苏姆赶出南阿拉伯半岛,可以肯定的是咖啡是从埃塞俄比亚传播到也门的。
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            咖啡传播到穆斯林世界后伊斯兰医学认可了咖啡的好处,认为其可以提振精神并防止酒和大麻对穆斯林的诱惑,15世纪的也门苏菲派修道院在祈祷时使用咖啡来帮助集中注意力。 16世纪初咖啡从也门的摩卡港传播到埃及,随后咖啡馆还出现在叙利亚阿勒颇,并于1554年在奥斯曼帝国首都伊斯坦布尔开业。1511年,由于也门麦加的宗教领袖认为咖啡具有刺激作用,便开始禁止穆斯林饮用咖啡,造成其余阿拉伯世界的苏丹和宗教领袖也相继效仿;其中两位奥斯曼帝国苏丹更是同样出于政治考量,而在1517年和1623年两度禁止咖啡。
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            同样在16世纪,与阿拉伯世界的贸易令威尼斯获得了包括咖啡在内的非洲商品,威尼斯商人则向威尼斯的上流阶级高价推销咖啡。起初意大利的宗教人士对咖啡这种穆斯林饮料持怀疑态度,并称咖啡为“撒旦的苦涩发明(bitter invention of Satan)”或是“阿拉伯酒(wine of Araby)”,1600年,教宗克莱孟八世对咖啡的争议作出裁决,在教宗品尝咖啡后认为可以饮用,并祝福了咖啡。 1616年,荷兰商人彼得·范登布罗克从也门摩卡获得了一些阿拉比卡咖啡树苗并带回了阿姆斯特丹,还在当地植物园种植成功。1658年,荷兰人首先在其殖民地锡兰和印度南部开始种植咖啡,但出于避免供应过剩而降低价格的考量,最终放弃了在锡兰种植,专注于爪哇和苏里南的种植园。
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            1675年时,英格兰就有3000多家咖啡馆;启蒙运动时期,咖啡馆成为民众深入讨论宗教和政治的聚集地,1670年代的英国国王查理二世就曾试图取缔咖啡馆。这一时期的英国人认为咖啡具有药用价值,甚至名医也会推荐将咖啡用于医疗。
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            1773年,波士顿倾茶事件后约翰·亚当斯和许多美国人认为喝茶是不爱国的,令大量美国人在美国独立战争期间改喝咖啡。
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            18世纪,葡萄牙人首先在巴西里约热内卢附近,后来则是圣保罗种植咖啡并建设种植园。1852-1950年,巴西主导了世界咖啡生产,其出口的咖啡比世界其他地区的总和还多。1950年以来,由于哥伦比亚和越南等主要生产国相继出现,而越南在1999年超过哥伦比亚成为世界第二大咖啡生产国,并在2011年达到15%的市场份额,而同年巴西的市场份额仅占33%。
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            在咖啡的原产地埃塞俄比亚,18世纪前咖啡曾被埃塞俄比亚正教会所禁止,直至19世纪后期叶埃塞俄比亚皇帝孟尼利克二世的统治时期才有所开放。
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            咖啡在19世纪中已经引入中国上海,1843年—44年上海对外贸易文献就有记载“枷榧豆5��,每包70斤”,表明当时上海已经从外国进口咖啡豆。
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            香港英文报章在1866年刊登关于coffee shop的报导。1885年香港中文报章以“咖啡”为中文名,此后逐渐成为华语地区普及使用的中文译名。香港在1840年代起有英国人聚居,由于饮食文化的差异,最初被输入到香港的咖啡豆是主要供应西方人饮用,而一般本地华人则不喜欢咖啡苦涩的味道,在早年的香港常有大量从事搬运工作的苦力在码头聚集,为来港的货轮搬运货物,从事体力劳动的苦力比一般华人更容易接触到刚循海路进口的咖啡豆,所以在华人社会中最早有饮用咖啡习惯的群体,却是社会地位低下的码头搬运工。
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            不同地区和民族之间的口味偏好,令咖啡冲泡方式以及调味品的使用多种多样,通常热咖啡添加砂糖、牛奶、奶油、奶精等调味,冷饮咖啡则有更多选择,如酒、薄荷、丁香、柠檬汁等。而不同冲泡和调味方式亦产生出了许多咖啡品类:
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            土耳其咖啡:是种具有古老历史的咖啡饮品和冲泡方式,而土耳其以外的中东国家以及东南欧皆有流行过此种冲泡方式。土耳其咖啡冲泡好后未经过滤即可直接饮用,土耳其传统上会将土耳其咖啡倒入小瓷杯中慢慢啜饮,而处于悬浊状的咖啡残留有少量咖啡渣亦成为土耳其咖啡独特风味与口感的来源。冲泡土耳其咖啡的方法为将咖啡豆研磨成粉末后装入土耳其壶中,倒入热水并与咖啡末搅拌均匀,再加人豆蔻粉充分搅拌,对土耳其壶加热并充分搅拌。咖啡煮至冒泡后停止加热,待泡沫消失,此时可短暂重复加热2次;或是将三分之一的咖啡先倒入到各个杯子中,壶中剩余的咖啡则再度加热,直到沸腾后倒入之前的杯子里。
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            浓缩咖啡:是一种通过迫使接近沸腾的高压水流通过咖啡末制作而成的咖啡,拿铁咖啡和卡布奇诺、玛琪雅朵等皆是以浓缩咖啡为基本制成的。
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            拿铁咖啡:拿铁咖啡是由浓缩咖啡和热牛奶以1:2的比例冲泡,并加入些许奶泡制成的。也可依需求加上两份浓缩咖啡,意大利语称之为“Double”。
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            卡布奇诺:卡布奇诺是一种意大利咖啡,是由在浓缩咖啡上倒入奶泡制成,由于咖啡的颜色就像方济嘉布遣会修士深褐色外衣上覆的头巾一样,卡布奇诺也因此得名。其与拿铁咖啡类似,区别仅是卡布奇诺在咖啡、牛奶、奶泡的比例为1:1:1。卡布奇诺咖啡奶泡多,而拿铁咖啡的奶泡少。口味上卡布奇诺咖啡的咖啡味重,而拿铁较为清淡一些,这是因为拿铁的牛奶更多。
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            摩卡咖啡:通常是由三分之一的意式浓缩咖啡和三分之二的奶泡配成,并加入少量巧克力糖浆或速溶巧克力粉。拉夫咖啡是在单杯浓缩咖啡中添加带有少量泡沫(0.5 厘米)的奶油而制成的咖啡。通常与香草糖一起喝用但通常使用糖浆代替香草糖。
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            玛琪雅朵咖啡:在冲泡好的浓缩咖啡上加入鲜奶并倒入一层较薄的奶泡的意大利咖啡。
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            焦糖玛琪雅朵:是一种在浓缩咖啡加入热牛奶和香草,最后淋上焦糖制成的玛琪雅朵咖啡。
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            欧蕾咖啡:是一种咖啡和牛奶的比例为1:1的牛奶咖啡,在冲泡时,需要牛奶壶和咖啡壶从两旁同时注入到咖啡杯。在星巴克则被称为Caffè Misto,以1:1比例的法式压滤咖啡搭配奶泡而成。
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            美式咖啡:是一种浓缩咖啡以1:5比例加入热水稀释制成的咖啡饮料。冲泡美式咖啡亦可使用意式咖啡机萃取浓缩咖啡,而在咖啡萃取完成后,继续使用咖啡机向浓缩咖啡加入热水稀释到合适比例即可。其浓度随浓缩咖啡的冲泡次数和添加的水量而变化,美式咖啡具有浓缩咖啡风味但却更为柔和。
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            长黑咖啡:是澳大利亚和新西兰常见的一种咖啡,是将双份浓缩咖啡倒入热水中制成的,其恰好与美式咖啡截然相反。长黑咖啡通常使用约100–120毫升的水,但水量可根据个人口味灵活调整。
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            维也纳咖啡:其制作方式为将糖或粗砂糖放入杯内再倒入热咖啡,杯上挤入鲜奶油以及巧克力膏,最终撒上彩色糖粒装饰即可。此种制法可追溯至1683年,当时乌克兰裔波兰军官耶日·弗朗西泽克·库奇茨基开设了奥地利首家咖啡馆并在维也纳开业,其普及了在咖啡中加糖和牛奶的制作和饮用方式。而维也纳咖啡传说是由奥地利马车夫爱因·舒伯纳发明。
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            爱尔兰咖啡:在咖啡中加入威士忌后在其顶部放上奶油。而加入威士忌的爱尔兰咖啡能将咖啡的酸甜味衬托出来。
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            调味咖啡:依据口味的不同在咖啡中加入巧克力、糖浆、果汁、肉桂、肉豆蔻、橘子花等不同调味料。
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            康宝蓝:康宝蓝是一种在意大利浓缩咖啡上倒入适量奶油的咖啡,并用玻璃咖啡杯盛装,由于鲜奶油具有甜味因此通常无需加糖。
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            白咖啡:起源于马来西亚怡保,其使用经过人造黄油烘培的咖啡豆,冲泡好后加入甜炼乳的饮品。19世纪和20世纪初英国锡矿公司在怡保设立锡矿,而中国移民则在怡保锡矿工作,白咖啡是19世纪中后期移民马来亚的海南人出于华人不习惯咖啡味道而发明。从本质上是一种拿铁咖啡。在美国,白咖啡也指轻度烘培的咖啡豆,使用意式冲煮,具有较强酸味的咖啡。
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            越南咖啡:是一种滴漏咖啡,冲泡时先在盛装咖啡的杯子中倒入炼乳,将滴漏壶置于盛装的杯上,并向滴漏壶加入咖啡末,再以压板压住咖啡末,倒入热水后等待滴漏。越南常用的咖啡豆品种为罗布斯塔,因其带有较重的酸味与苦味以及烘焙时间较长,使得风味较重,因此需要加入炼乳饮用。
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            印度滴漏咖啡:其通常是阿拉比卡咖啡或咖啡公豆制作的;咖啡豆经过深度烘焙、研磨并与菊苣混合,咖啡占混合物的80-90%,其余的为菊苣。菊苣的轻微苦味有助于产生印度滴漏咖啡的风味,传统上使用粗糖或蜂蜜作为甜味剂,但自1900年代中期改为白糖。
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            皇家咖啡:据说拿破仑在俄法战争时,因遭遇俄国酷寒的冬天,于是命令下属在咖啡里倒入白兰地取暖而发明。其制作方式为,在预热好的咖啡杯中倒入热咖啡,将咖啡匙架在杯缘上,在咖啡匙上放置方糖后淋上白兰地并点火燃烧,火焰熄灭后将咖啡匙放入咖啡搅拌至方糖溶解即可饮用。
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            黑咖啡:是使用滴滤法、渗滤法、虹吸法或加法冲泡的咖啡,在饮用时不添加牛奶、糖等调味品。速溶咖啡是不属于黑咖啡的范围的。
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            希腊法拉沛咖啡:通常由速溶咖啡、糖和牛奶制成的冰咖啡,咖啡中也会倒入奶泡;其口感微甜凉爽,适宜在夏季饮用。
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            阿芙佳朵:是种近乎甜点的冰咖啡,由冰淇淋上加入意大利浓缩咖啡制成。会加入焦糖来增加甜味和促进口感,或加入巧克力酱、可可粉、肉桂粉等。
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,14 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Requires Python 3.10-3.12
         | 
| 2 | 
            +
            appbuilder_sdk==1.0.6
         | 
| 3 | 
            +
            crawl4ai==0.6.3
         | 
| 4 | 
            +
            docx==0.2.4
         | 
| 5 | 
            +
            faiss-cpu==1.9.0
         | 
| 6 | 
            +
            gradio==5.27.1
         | 
| 7 | 
            +
            jieba==0.42.1
         | 
| 8 | 
            +
            mcp==1.9.4
         | 
| 9 | 
            +
            numpy==2.2.6
         | 
| 10 | 
            +
            openai==1.88.0
         | 
| 11 | 
            +
            pdfplumber==0.11.7
         | 
| 12 | 
            +
            python_docx==1.1.2
         | 
| 13 | 
            +
            Requests==2.32.4
         | 
| 14 | 
            +
            sse-starlette==2.3.6
         | 
