"""This script provides a simple web interface that allows users to interact with""" |

import argparse
import base64
import hashlib
import json
import logging
import os
import textwrap
from collections import namedtuple
from functools import partial

import faiss
import gradio as gr
import numpy as np

from bot_requests import BotClient

# Keep requests to the locally hosted demo from being routed through a proxy.
os.environ["NO_PROXY"] = "localhost,127.0.0.1"

logging.root.setLevel(logging.INFO)

FILE_URL_DEFAULT = "data/coffee.txt"
RELEVANT_PASSAGE_DEFAULT = textwrap.dedent("""\
1675年时,英格兰就有3000多家咖啡馆;启蒙运动时期,咖啡馆成为民众深入讨论宗教和政治的聚集地,
1670年代的英国国王查理二世就曾试图取缔咖啡馆。这一时期的英国人认为咖啡具有药用价值,
甚至名医也会推荐将咖啡用于医疗。"""
)

QUERY_REWRITE_PROMPT = textwrap.dedent("""\
你是一个擅长问答系统和信息检索的大模型助手。

请根据用户提出的问题,判断是否需要调用文档检索系统来获取答案:
- 若问题属于常识性、定义性或答案明确,不依赖外部资料,请标记为 "is_search": false;
- 若问题涉及事实查证、具体数据、文档内容等,必须依赖资料检索,请标记为 "is_search": true,并将问题拆解成多个可用于检索的子问题。

要求:
1. 子问题应语义清晰、独立,适合用于检索;
2. 只在**确有必要**的情况下拆解,最多不超过 5 个,不要为了凑满数量而输出冗余子问题;
3. 输出为严格的 JSON 格式,无多余注释。

【用户当前问题】:
{query}

【输出格式】:
请仅输出如下格式的内容(符合 JSON 规范,无多余注释):
```
{{
    "is_search": true 或 false,
    "sub_query_list": ["子问题1","子问题2","..."]
}}
```"""
)
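
# An illustrative rewrite result for the default query (an assumption about
# model behavior, not a recorded response):
#   {"is_search": true, "sub_query_list": ["1675年英格兰咖啡馆数量"]}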

ANSWER_PROMPT = textwrap.dedent("""\
你是一个乐于助人且信息丰富的机器人,使用下面提供的参考段落中的文本来回答问题。
请务必用完整的句子回答,内容要全面,包括所有相关的背景信息。

然而,你的对话对象是非技术人员,所以请务必分解复杂的概念,并使用友好和对话式的语气。
如果段落与答案无关,你可以忽略它。

问题:'{query}'
段落:'{relevant_passage}'

答案:"""
)

QUERY_DEFAULT = "1675 年时,英格兰有多少家咖啡馆?"


def get_args() -> argparse.Namespace:
    """
    Parse and return command line arguments for the ERNIE models web chat demo.
    Configures server settings, model endpoints, and document processing parameters.

    Returns:
        argparse.Namespace: Parsed command line arguments containing:
            - server_port: Demo server port (default: 7860)
            - server_name: Demo server host (default: "0.0.0.0")
            - model_urls: Endpoints for ERNIE and Qianfan models
            - document_processing: Chunk size, FAISS index and text DB paths
    """
    parser = argparse.ArgumentParser(description="ERNIE models web chat demo.")

    parser.add_argument(
        "--server-port", type=int, default=7860, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="0.0.0.0", help="Demo server name."
    )
    parser.add_argument(
        "--max_char", type=int, default=8000, help="Maximum character limit for messages."
    )
    parser.add_argument(
        "--max_retry_num", type=int, default=3, help="Maximum retry number for request."
    )
    parser.add_argument(
        "--eb45t_model_url",
        type=str,
        default="https://qianfan.baidubce.com/v2",
        help="Model URL for multimodal model."
    )
    parser.add_argument(
        "--qianfan_url",
        type=str,
        default="https://qianfan.baidubce.com/v2",
        help="Qianfan URL."
    )
    parser.add_argument(
        "--qianfan_api_key",
        type=str,
        default=os.environ.get("API_KEY"),
        help="Qianfan API key."
    )
    parser.add_argument(
        "--embedding_model",
        type=str,
        default="embedding-v1",
        help="Embedding model name."
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=512,
        help="Chunk size for splitting long documents."
    )
    parser.add_argument(
        "--faiss_index_path",
        type=str,
        default="data/faiss_index",
        help="Faiss index path."
    )
    parser.add_argument(
        "--text_db_path",
        type=str,
        default="data/text_db.jsonl",
        help="Text database path."
    )

    args = parser.parse_args()
    return args


class FaissTextDatabase:
    """
    A vector database for text retrieval using FAISS (Facebook AI Similarity Search).
    Provides efficient similarity search and document management capabilities.
    """

    def __init__(self, args, bot_client: BotClient, embedding_dim: int = 384):
        """
        Initialize the FaissTextDatabase.

        Args:
            args: arguments for initialization
            bot_client: instance of BotClient
            embedding_dim: dimension of the embedding vector
        """
        self.logger = logging.getLogger(__name__)

        self.bot_client = bot_client
        self.faiss_index_path = getattr(args, "faiss_index_path", "data/faiss_index")
        self.text_db_path = getattr(args, "text_db_path", "data/text_db.jsonl")
        self.embedding_dim = embedding_dim

        # Reload a previously persisted index and metadata store if both exist;
        # otherwise start from an empty inner-product index, which ranks vectors
        # by dot-product similarity.
        if os.path.exists(self.faiss_index_path) and os.path.exists(self.text_db_path):
            self.index = faiss.read_index(self.faiss_index_path)
            with open(self.text_db_path, 'r', encoding='utf-8') as f:
                self.text_db = json.load(f)
        else:
            self.index = faiss.IndexFlatIP(self.embedding_dim)
            self.text_db = {
                "file_md5s": [],
                "chunks": []
            }

    def calculate_md5(self, file_path: str) -> str:
        """
        Calculate the MD5 hash of a file.

        Args:
            file_path: the path of the source file

        Returns:
            str: the MD5 hash
        """
        with open(file_path, "rb") as f:
            return hashlib.md5(f.read()).hexdigest()

    def is_file_processed(self, file_path: str) -> bool:
        """
        Check if the file has been processed before.

        Args:
            file_path: the path of the source file

        Returns:
            bool: whether the file has been processed
        """
        file_md5 = self.calculate_md5(file_path)
        return file_md5 in self.text_db["file_md5s"]

    def add_embeddings(self, file_path: str, segments: list[str], progress_bar: gr.Progress = None) -> bool:
        """
        Stores document embeddings in FAISS database after checking for duplicates.
        Generates embeddings for each text segment, updates the FAISS index and metadata database,
        and persists changes to disk. Includes optional progress tracking for Gradio interfaces.

        Args:
            file_path: the path of the source file
            segments: the list of segments
            progress_bar: the progress bar object

        Returns:
            bool: whether the operation was successful
        """
        file_md5 = self.calculate_md5(file_path)
        if file_md5 in self.text_db["file_md5s"]:
            self.logger.info("File already processed: {file_path} (MD5: {file_md5})".format(
                file_path=file_path,
                file_md5=file_md5
            ))
            return False

        vectors = []
        file_name = os.path.basename(file_path)
        for i, segment in enumerate(segments):
            vectors.append(self.bot_client.embed_fn(segment))
            if progress_bar is not None:
                progress_bar((i + 1) / len(segments), desc=file_name + " Processing...")
        # FAISS indexes operate on float32 arrays.
        vectors = np.array(vectors)
        self.index.add(vectors.astype('float32'))

        start_id = len(self.text_db["chunks"])
        for i, text in enumerate(segments):
            self.text_db["chunks"].append({
                "file_md5": file_md5,
                "text": text,
                "vector_id": start_id + i
            })

        self.text_db["file_md5s"].append(file_md5)
        self.save()
        return True

    def search_with_context(self, query: str, context_size: int = 2) -> str:
        """
        Finds the most relevant text chunk for a query and includes surrounding context.
        Uses FAISS to find the closest matching embedding, then retrieves adjacent chunks
        from the same source document to provide better context understanding.

        Args:
            query: the input query string
            context_size: the number of surrounding chunks to include

        Returns:
            str: the relevant chunk with context
        """
        query_vector = np.array([self.bot_client.embed_fn(query)]).astype('float32')
        distances, indices = self.index.search(query_vector, 1)

        target_idx = indices[0][0]
        target_chunk = self.text_db["chunks"][target_idx]
        target_file_md5 = target_chunk["file_md5"]
        self.logger.info("Similarity: {}".format(distances[0][0]))
        self.logger.info("Target Chunk: {}".format(target_chunk["text"]))

        # Stitch the best match together with its neighbors, keeping only chunks
        # that come from the same source file.
        start = max(0, target_idx - context_size)
        end = min(len(self.text_db["chunks"]) - 1, target_idx + context_size)
        result = ""
        for pos in range(start, end + 1):
            if self.text_db["chunks"][pos]["file_md5"] == target_file_md5:
                result += self.text_db["chunks"][pos]["text"] + "\n"

        return result

    def save(self) -> None:
        """Save the database to disk."""
        faiss.write_index(self.index, self.faiss_index_path)

        with open(self.text_db_path, 'w', encoding='utf-8') as f:
            json.dump(self.text_db, f, ensure_ascii=False, indent=2)
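

# Illustrative use of FaissTextDatabase outside the demo flow (a sketch; assumes
# BotClient.embed_fn returns a length-`embedding_dim` float vector):
#   db = FaissTextDatabase(args, bot_client)
#   db.add_embeddings("data/coffee.txt", ["first chunk", "second chunk"])
#   print(db.search_with_context("coffee houses in England"))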


class GradioEvents(object):
    """
    Manages event handling and UI interactions for Gradio applications.
    Provides methods to process user inputs, trigger callbacks, and update interface components.
    """

    @staticmethod
    def chat_stream(
        query: str,
        task_history: list,
        model: str,
        bot_client: BotClient,
        faiss_db: FaissTextDatabase,
    ) -> dict:
        """
        Streams chatbot responses by processing queries with context from history and FAISS database.
        Integrates language model generation with knowledge retrieval to produce dynamic responses.
        Yields response events in real-time for interactive conversation experiences.

        Args:
            query (str): The query string.
            task_history (list): The task history record list.
            model (str): The name of the model used to generate responses.
            bot_client (BotClient): The chatbot client object.
            faiss_db (FaissTextDatabase): The FAISS database object.

        Yields:
            dict: A dictionary containing the event type and its corresponding content.
        """
        search_info_result = GradioEvents.get_sub_query(query, model, bot_client)
        if search_info_result.get("is_search", False) and search_info_result.get("sub_query_list", []):
            relevant_passage = GradioEvents.get_relevant_passage(
                search_info_result["sub_query_list"],
                faiss_db
            )
            yield {"type": "relevant_passage", "content": relevant_passage}
            prompt = ANSWER_PROMPT.format(query=query, relevant_passage=relevant_passage)
        else:
            prompt = query

        conversation = []
        for query_h, response_h in task_history:
            conversation.append({"role": "user", "content": query_h})
            conversation.append({"role": "assistant", "content": response_h})
        conversation.append({"role": "user", "content": prompt})

        try:
            req_data = {"messages": conversation}
            for chunk in bot_client.process_stream(model, req_data):
                if "error" in chunk:
                    raise Exception(chunk["error"])

                message = chunk.get("choices", [{}])[0].get("delta", {})
                content = message.get("content", "")
                reasoning_content = message.get("reasoning_content", "")

                if reasoning_content:
                    yield {"type": "thinking", "content": reasoning_content}
                if content:
                    yield {"type": "answer", "content": content}

        except Exception as e:
            raise gr.Error("Exception: " + repr(e))
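
    # chat_stream produces a sequence of event dicts, e.g.:
    #   {"type": "relevant_passage", "content": "..."}
    #   {"type": "thinking", "content": "..."}
    #   {"type": "answer", "content": "..."}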

    @staticmethod
    def predict_stream(
        query: str,
        chatbot: list,
        task_history: list,
        model: str,
        bot_client: BotClient,
        faiss_db: FaissTextDatabase
    ) -> tuple:
        """
        Generates streaming responses by combining model predictions with knowledge retrieval.
        Processes user queries using conversation history and FAISS database context,
        yielding updated chat messages and relevant passages in real-time.

        Args:
            query (str): The content of the user's input query.
            chatbot (list): The chatbot's historical message list.
            task_history (list): The task history record list.
            model (str): The name of the model used to generate responses.
            bot_client (BotClient): The chatbot client object.
            faiss_db (FaissTextDatabase): The FAISS database instance.

        Yields:
            tuple: A tuple containing the updated chatbot's message list and the relevant passage.
        """
        query = query if query else QUERY_DEFAULT

        logging.info("User: {}".format(query))
        chatbot.append({"role": "user", "content": query})

        # Show the user's message immediately, before the model starts streaming.
        yield chatbot, None

        new_texts = GradioEvents.chat_stream(
            query,
            task_history,
            model,
            bot_client,
            faiss_db,
        )
        reasoning_content = ""
        response = ""
        has_thinking = False
        current_relevant_passage = None
        for new_text in new_texts:
            if not isinstance(new_text, dict):
                continue

            if new_text.get("type") == "relevant_passage":
                current_relevant_passage = new_text["content"]
                yield chatbot, current_relevant_passage
                continue
            elif new_text.get("type") == "thinking":
                has_thinking = True
                reasoning_content += new_text["content"]
            elif new_text.get("type") == "answer":
                response += new_text["content"]

            # Rebuild the assistant message from the accumulated thinking and
            # answer fragments on every streamed chunk.
            if chatbot[-1].get("role") == "assistant":
                chatbot.pop(-1)

            content = ""
            if has_thinking:
                content = "**思考过程:**<br>{}<br>".format(reasoning_content)
            if response:
                if has_thinking:
                    content += "<br><br>**最终回答:**<br>{}".format(response)
                else:
                    content = response

            if content:
                chatbot.append({"role": "assistant", "content": content})
                yield chatbot, current_relevant_passage

        logging.info("History: {}".format(task_history))
        task_history.append((query, response))
        logging.info("ERNIE models: {}".format(response))

    @staticmethod
    def regenerate(
        chatbot: list,
        task_history: list,
        model: str,
        bot_client: BotClient,
        faiss_db: FaissTextDatabase
    ) -> tuple:
        """
        Regenerate the chatbot's response based on the latest user query.

        Args:
            chatbot (list): Chat history list
            task_history (list): Task history
            model (str): Model name to use
            bot_client (BotClient): Bot request client instance
            faiss_db (FaissTextDatabase): Faiss database instance

        Yields:
            tuple: Updated chatbot and relevant_passage
        """
        if not task_history:
            yield chatbot, None
            return

        # Drop the last turn (assistant messages plus the user message), then replay it.
        item = task_history.pop(-1)
        while len(chatbot) != 0 and chatbot[-1].get("role") == "assistant":
            chatbot.pop(-1)
        chatbot.pop(-1)

        yield from GradioEvents.predict_stream(
            item[0],
            chatbot,
            task_history,
            model,
            bot_client,
            faiss_db
        )

    @staticmethod
    def reset_user_input() -> gr.update:
        """
        Reset user input box content.

        Returns:
            gr.update: An update object representing the cleared value
        """
        return gr.update(value="")

    @staticmethod
    def reset_state() -> tuple:
        """
        Reset chat state and clear all history.

        Returns:
            tuple: A named tuple containing the updated values for chatbot,
                task_history, file_btn, and relevant_passage
        """
        GradioEvents.gc()

        reset_result = namedtuple("reset_result",
                                  ["chatbot",
                                   "task_history",
                                   "file_btn",
                                   "relevant_passage"])
        return reset_result(
            [],
            [],
            gr.update(value=None),
            gr.update(value=None)
        )

    @staticmethod
    def gc():
        """Force garbage collection to free memory."""
        import gc

        gc.collect()

    @staticmethod
    def get_image_url(image_path: str) -> str:
        """
        Encode image file to Base64 format and generate data URL.
        Reads an image file from disk, encodes it as Base64, and formats it
        as a data URL that can be used directly in HTML or API requests.

        Args:
            image_path (str): Path to the image file. Must be a valid file path.

        Returns:
            str: Data URL string in format "data:image/{ext};base64,{encoded_data}"
        """
        extension = image_path.split(".")[-1]
        with open(image_path, "rb") as image_file:
            base64_image = base64.b64encode(image_file.read()).decode("utf-8")
        url = "data:image/{ext};base64,{img}".format(ext=extension, img=base64_image)
        return url

    @staticmethod
    def get_relevant_passage(
        sub_query_list: list,
        faiss_db: FaissTextDatabase
    ) -> str:
        """
        Retrieve the relevant passage from the database based on the query.

        Args:
            sub_query_list (list): List of sub-queries.
            faiss_db (FaissTextDatabase): The FAISS database instance.

        Returns:
            str: The relevant passage.
        """
        relevant_passages = ""
        for idx, query_item in enumerate(sub_query_list):
            relevant_passage = faiss_db.search_with_context(query_item)
            relevant_passages += "\n段落{idx}:\n{relevant_passage}".format(idx=idx + 1, relevant_passage=relevant_passage)

        return relevant_passages

    @staticmethod
    def get_sub_query(query: str, model_name: str, bot_client: BotClient) -> dict:
        """
        Decides whether a query needs document retrieval and, if so, decomposes it.
        Prompts the language model with QUERY_REWRITE_PROMPT, which asks for an
        "is_search" flag and a list of retrieval-friendly sub-queries.

        Args:
            query (str): The query to analyze and decompose.
            model_name (str): The name of the model to use for rewriting.
            bot_client (BotClient): The bot client instance.

        Returns:
            dict: The parsed rewrite result, e.g. {"is_search": ..., "sub_query_list": [...]}.
        """
        query = QUERY_REWRITE_PROMPT.format(query=query)
        conversation = [{"role": "user", "content": query}]
        req_data = {"messages": conversation}
        try:
            response = bot_client.process(model_name, req_data)
            search_info_res = response["choices"][0]["message"]["content"]
            # The model may wrap the JSON in a code fence; keep only the span
            # between the first "{" and the last "}".
            start = search_info_res.find("{")
            end = search_info_res.rfind("}") + 1
            if start >= 0 and end > start:
                search_info_res = search_info_res[start:end]
            search_info_res = json.loads(search_info_res)
            if search_info_res.get("sub_query_list", []):
                unique_list = list(set(search_info_res["sub_query_list"]))
                search_info_res["sub_query_list"] = unique_list
            return search_info_res
        except Exception:
            raise gr.Error("Error: Model output is not a valid JSON")

    @staticmethod
    def split_oversized_line(line: str, chunk_size: int) -> tuple:
        """
        Split a line into two parts based on punctuation marks or whitespace while preserving
        natural language boundaries and maintaining the original content structure.

        Args:
            line (str): The line to split.
            chunk_size (int): The maximum length of each chunk.

        Returns:
            tuple: Two strings, the first part of the original line and the rest of the line.
        """
        PUNCTUATIONS = [".", "。", "!", "!", "?", "?", ",", ",", ";", ";", ":", ":"]

        if len(line) <= chunk_size:
            return line, ""

        # Prefer to split just after the last punctuation mark within the limit.
        split_pos = -1
        for i in range(chunk_size - 1, -1, -1):
            if line[i] in PUNCTUATIONS:
                split_pos = i + 1
                break

        # Otherwise fall back to the last space, then to a hard cut at chunk_size.
        if split_pos == -1:
            split_pos = line.rfind(" ", 0, chunk_size)
            if split_pos == -1:
                split_pos = chunk_size

        return line[:split_pos], line[split_pos:]
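
    # For example, split_oversized_line("你好。世界", 3) returns ("你好。", "世界"):
    # the split lands just after the last punctuation mark within the first
    # three characters.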

    @staticmethod
    def split_text_into_chunks(text: str, chunk_size: int) -> list:
        """
        Split text into chunks of a specified size while respecting natural language boundaries
        and avoiding mid-word splits whenever possible.

        Args:
            text (str): The text to split.
            chunk_size (int): The maximum length of each chunk.

        Returns:
            list: A list of strings, where each element represents a chunk of the original text.
        """
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        chunks = []
        current_chunk = []
        current_length = 0

        for line in lines:
            # Flush the current chunk if adding this line would overflow it.
            if current_length + len(line) > chunk_size and current_chunk:
                chunks.append(" ".join(current_chunk))
                current_chunk = []
                current_length = 0

            # Lines longer than chunk_size are split at punctuation or whitespace.
            while len(line) > chunk_size:
                head, line = GradioEvents.split_oversized_line(line, chunk_size)
                chunks.append(head)

            if line:
                current_chunk.append(line)
                current_length += len(line) + 1

        if current_chunk:
            chunks.append(" ".join(current_chunk))
        return chunks
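
    # For example, split_text_into_chunks("a\nbb\nccc", 4) returns ["a bb", "ccc"]:
    # "a" and "bb" fit in one 4-character chunk, and "ccc" starts a new one.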

    @staticmethod
    def file_upload(
        files_url: list,
        chunk_size: int,
        faiss_db: FaissTextDatabase,
        progress_bar: gr.Progress = gr.Progress()
    ):
        """
        Uploads and processes multiple files by splitting them into semantically meaningful chunks,
        then indexes them in the FAISS database with progress tracking.

        Args:
            files_url (list): List of file URLs.
            chunk_size (int): Maximum chunk size.
            faiss_db (FaissTextDatabase): FAISS database instance.
            progress_bar (gr.Progress): Progress bar instance.

        Yields:
            gr.update: Updates that show the progress box while files are being
                processed and hide it again when done.
        """
        if not files_url:
            return
        yield gr.update(visible=True)
        for file_url in files_url:
            if not GradioEvents.save_file_to_db(file_url, chunk_size, faiss_db, progress_bar):
                file_name = os.path.basename(file_url)
                gr.Info("{} already processed.".format(file_name))

        yield gr.update(visible=False)

    @staticmethod
    def save_file_to_db(file_url: str, chunk_size: int, faiss_db: FaissTextDatabase,
                        progress_bar: gr.Progress = None) -> bool:
        """
        Processes and indexes document content into FAISS database with semantic-aware chunking.
        Handles file validation, text segmentation, embedding generation and storage operations.

        Args:
            file_url (str): File URL.
            chunk_size (int): Chunk size.
            faiss_db (FaissTextDatabase): FAISS database instance.
            progress_bar (gr.Progress): Progress bar instance.

        Returns:
            bool: True if the file was saved successfully, otherwise False.
        """
        file_name = os.path.basename(file_url)
        if not faiss_db.is_file_processed(file_url):
            logging.info("{} not processed yet, processing now...".format(file_url))
            try:
                with open(file_url, "r", encoding="utf-8") as f:
                    text = f.read()
                segments = GradioEvents.split_text_into_chunks(text, chunk_size)
                faiss_db.add_embeddings(file_url, segments, progress_bar)

                logging.info("{} processed successfully.".format(file_url))
                return True
            except Exception as e:
                logging.error("Error processing {}: {}".format(file_url, str(e)))
                # gr.Error must be raised (not just constructed) to reach the UI.
                raise gr.Error("Error processing file: {}".format(file_name))
        else:
            logging.info("{} already processed.".format(file_url))
            return False


def launch_demo(args: argparse.Namespace, bot_client: BotClient, faiss_db: FaissTextDatabase):
    """
    Launch demo program.

    Args:
        args (argparse.Namespace): argparse Namespace object containing parsed command line arguments
        bot_client (BotClient): Bot client instance
        faiss_db (FaissTextDatabase): FAISS database instance
    """
    css = """
    /* Hide original Chinese text */
    #file-upload .wrap {
        font-size: 0 !important;
        position: relative;
        display: flex;
        flex-direction: column;
        align-items: center;
        justify-content: center;
    }

    /* Insert English prompt text below the SVG icon */
    #file-upload .wrap::after {
        content: "Drag and drop files here or click to upload";
        font-size: 18px;
        color: #555;
        margin-top: 8px;
        white-space: nowrap;
    }
    """
    with gr.Blocks(css=css) as demo:
        model_name = gr.State("eb-45t")

        logo_url = GradioEvents.get_image_url("assets/logo.png")
        gr.Markdown("""\
<p align="center"><img src="{}" \
style="height: 60px"/><p>""".format(logo_url))
        gr.Markdown(
            """\
<center><font size=3>This demo is based on ERNIE models. \
(本演示基于文心大模型实现。)</center>"""
        )

        chatbot = gr.Chatbot(
            label="ERNIE",
            type="messages"
        )

        with gr.Row(equal_height=True):
            file_btn = gr.File(
                label="Knowledge Base Upload (System default will be used if none provided. "
                      "Accepted formats: TXT, MD)",
                height="150px",
                file_types=[".txt", ".md"],
                elem_id="file-upload",
                file_count="multiple"
            )
            relevant_passage = gr.Textbox(
                label="Relevant Passage",
                lines=5,
                max_lines=5,
                placeholder=RELEVANT_PASSAGE_DEFAULT,
                interactive=False
            )
        with gr.Row():
            progress_bar = gr.Textbox(label="Progress", visible=False)

        query = gr.Textbox(label="Query", elem_id="text_input", value=QUERY_DEFAULT)

        with gr.Row():
            empty_btn = gr.Button("🧹 Clear History(清除历史)")
            submit_btn = gr.Button("🚀 Submit(发送)", elem_id="submit-button")
            regen_btn = gr.Button("🤔️ Regenerate(重试)")

        task_history = gr.State([])

        # Bind the shared clients once so the Gradio callbacks only receive UI inputs.
        predict_with_clients = partial(
            GradioEvents.predict_stream,
            bot_client=bot_client,
            faiss_db=faiss_db
        )
        regenerate_with_clients = partial(
            GradioEvents.regenerate,
            bot_client=bot_client,
            faiss_db=faiss_db
        )
        file_upload_with_clients = partial(
            GradioEvents.file_upload,
            faiss_db=faiss_db
        )

        chunk_size = gr.State(args.chunk_size)
        file_btn.change(
            fn=file_upload_with_clients,
            inputs=[file_btn, chunk_size],
            outputs=[progress_bar],
        )
        query.submit(
            predict_with_clients,
            inputs=[query, chatbot, task_history, model_name],
            outputs=[chatbot, relevant_passage],
            show_progress=True
        )
        query.submit(GradioEvents.reset_user_input, [], [query])
        submit_btn.click(
            predict_with_clients,
            inputs=[query, chatbot, task_history, model_name],
            outputs=[chatbot, relevant_passage],
            show_progress=True,
        )
        submit_btn.click(GradioEvents.reset_user_input, [], [query])
        empty_btn.click(
            GradioEvents.reset_state,
            outputs=[chatbot, task_history, file_btn, relevant_passage],
            show_progress=True
        )
        regen_btn.click(
            regenerate_with_clients,
            inputs=[chatbot, task_history, model_name],
            outputs=[chatbot, relevant_passage],
            show_progress=True
        )

    demo.queue().launch(
        server_port=args.server_port,
        server_name=args.server_name
    )


def main():
    """Main function that runs when this script is executed."""
    args = get_args()
    bot_client = BotClient(args)
    faiss_db = FaissTextDatabase(args, bot_client)

    # Index the default knowledge base up front so first queries have context.
    GradioEvents.save_file_to_db(FILE_URL_DEFAULT, args.chunk_size, faiss_db)

    launch_demo(args, bot_client, faiss_db)


if __name__ == "__main__":
    main()