Spaces:
Runtime error
Runtime error
| from langchain_huggingface import HuggingFaceEmbeddings | |
| import config | |
| import prompt_templete | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.runnables import RunnableLambda, RunnablePassthrough | |
| from langchain_core.documents import Document | |
| import logging | |
| from langchain_core.output_parsers import StrOutputParser, JsonOutputParser | |
| from typing import List,Any,Dict | |
| from langchain_weaviate.vectorstores import WeaviateVectorStore | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from utils.process_data import filter_and_serialize_complex_metadata | |
| import weaviate | |
| import weaviate.classes.config as wvc_config | |
| from weaviate.exceptions import WeaviateQueryException | |
| import time | |
| from operator import itemgetter | |
| logger = logging.getLogger(__name__) | |
| WEAVIATE_SCHEMA_CONFIG: List[Dict[str, Any]] = [ | |
| # Tên trường, Kiểu dữ liệu trong Weaviate, Có nên vector hóa trường này không? | |
| {"name": "source", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
| {"name": "title", "dataType": wvc_config.DataType.TEXT, "index_searchable": True, "tokenization": wvc_config.Tokenization.WORD, "vectorize": True}, | |
| {"name": "field", "dataType": wvc_config.DataType.TEXT,"index_searchable": True, "vectorize": True}, | |
| {"name": "so_hieu", "dataType": wvc_config.DataType.TEXT, "index_searchable": False,"vectorize": False}, | |
| {"name": "loai_van_ban", "dataType": wvc_config.DataType.TEXT, "index_searchable": True,"vectorize": True}, | |
| {"name": "ten_van_ban", "dataType": wvc_config.DataType.TEXT,"index_searchable": True, "tokenization": wvc_config.Tokenization.WORD, "vectorize": True}, | |
| {"name": "co_quan_ban_hanh", "dataType": wvc_config.DataType.TEXT, "index_searchable": False,"vectorize": False}, | |
| {"name": "ngay_ban_hanh_str", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
| {"name": "nam_ban_hanh", "dataType": wvc_config.DataType.INT,"index_searchable": True, "vectorize": False}, | |
| {"name": "phan_code", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
| {"name": "chuong_code", "dataType": wvc_config.DataType.TEXT, "index_searchable": False,"vectorize": False}, | |
| {"name": "muc_code", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
| {"name": "dieu_code", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
| {"name": "entity_type", "dataType": wvc_config.DataType.TEXT,"index_searchable": True, "vectorize": False}, | |
| {"name": "penalties", "dataType": wvc_config.DataType.TEXT,"index_searchable": False, "vectorize": False}, | |
| {"name": "cross_references", "dataType": wvc_config.DataType.TEXT, "index_searchable": False, "vectorize": False}, | |
| ] | |
| # Hàm get_huggingface_embeddings giữ nguyên | |
| def get_huggingface_embeddings(model_name: str, device: str = 'cpu'): | |
| logger.info(f"🔸Đang khởi tạo model embedding: {model_name} trên thiết bị {device}...") | |
| model_kwargs = { | |
| 'device': device, | |
| 'trust_remote_code': True # thêm để đảm bảo load được những model custom | |
| } | |
| encode_kwargs = { | |
| 'batch_size': 32, # kích thước batch cho embedding | |
| 'normalize_embeddings': True # normalize để cosine similarity chuẩn | |
| } | |
| try: | |
| embeddings = HuggingFaceEmbeddings( | |
| model_name=model_name, | |
| model_kwargs=model_kwargs, | |
| encode_kwargs=encode_kwargs | |
| ) | |
| logger.info("🔸Khởi tạo model embedding thành công.") | |
| return embeddings | |
| except Exception as e: | |
| logger.error(f"🔸Lỗi khi khởi tạo model embedding: {e}") | |
| raise Exception(f"Khởi tạo model embedding thất bại: {str(e)}") | |
| # Begin New | |
| def create_weaviate_schema_if_not_exists(client: weaviate.WeaviateClient, collection_name: str): | |
| """ | |
| CẢI TIẾN: Tạo schema với cấu hình chi tiết cho filtering và hybrid search. | |
| """ | |
| if client.collections.exists(collection_name): | |
| logger.info(f"✅ Schema for collection '{collection_name}' already exists.") | |
| return | |
| logger.info(f"🔸 Schema for collection '{collection_name}' not found. Creating...") | |
| try: | |
| properties = [] | |
| for prop_config in WEAVIATE_SCHEMA_CONFIG: | |
| properties.append( | |
| wvc_config.Property( | |
| name=prop_config["name"], | |
| data_type=prop_config["dataType"], | |
| # Bỏ qua vector hóa nếu vectorize=False hoặc không được định nghĩa | |
| skip_vectorization=not prop_config.get("vectorize", False), | |
| # Kích hoạt tokenization cho các trường cần tìm kiếm từ khóa | |
| tokenization=prop_config.get("tokenization") | |
| ) | |
| ) | |
| # Thêm trường 'text' chính, tối ưu cho cả vector và keyword search | |
| properties.append( | |
| wvc_config.Property( | |
| name="text", | |
| data_type=wvc_config.DataType.TEXT, | |
| skip_vectorization=False, # Luôn vector hóa nội dung chính | |
| tokenization=wvc_config.Tokenization.WORD # Cho phép tìm kiếm BM25 trên nội dung | |
| ) | |
| ) | |
| client.collections.create( | |
| name=collection_name, | |
| properties=properties, | |
| # Kích hoạt inverted index (bắt buộc cho filtering và BM25) | |
| inverted_index_config=wvc_config.Configure.inverted_index( | |
| index_null_state=True, | |
| index_property_length=True, | |
| index_timestamps=True, | |
| bm25_b=0.75, # Tham số BM25, có thể điều chỉnh | |
| bm25_k1=1.2 # Tham số K1 cho BM25 | |
| ), | |
| vectorizer_config=wvc_config.Configure.Vectorizer.none(), | |
| vector_index_config=wvc_config.Configure.VectorIndex.hnsw( | |
| distance_metric=wvc_config.VectorDistances.COSINE | |
| ) | |
| ) | |
| logger.info(f"✅ Successfully created schema for collection '{collection_name}'.") | |
| except WeaviateQueryException as e: | |
| logger.error(f"❌ Error creating schema: {e}", exc_info=True) | |
| raise | |
| def ingest_chunks_with_native_batching(client: weaviate.WeaviateClient, collection_name: str, chunks: List[Document], embeddings_model): | |
| """Sử dụng API batch gốc của Weaviate, an toàn và hiệu suất cao.""" | |
| logger.info(f"🚀 Bắt đầu quá trình ingestion cho {len(chunks)} chunks...") | |
| texts_to_embed = [chunk.page_content for chunk in chunks] | |
| logger.info(f"🧠 Đang tạo embeddings cho {len(texts_to_embed)} chunks...") | |
| start_embed_time = time.time() | |
| chunk_vectors = embeddings_model.embed_documents(texts_to_embed) | |
| logger.info(f"⏱️ Thời gian tạo embedding: {time.time() - start_embed_time:.2f} giây.") | |
| # 3. CẢI TIẾN: Đảm bảo chỉ ingest các thuộc tính hợp lệ | |
| valid_property_names = {prop["name"] for prop in WEAVIATE_SCHEMA_CONFIG} | |
| valid_property_names.add("text") # Thêm trường 'text' | |
| with client.batch.dynamic() as batch: | |
| for i, chunk in enumerate(chunks): | |
| if not isinstance(chunk,Document) or not hasattr(chunk, 'id') or not chunk.id: | |
| logger.warning(f"Bỏ qua chunk ở vị trí {i} do không hợp lệ (sai type hoặc thiếu ID).") | |
| continue | |
| properties = {"text": chunk.page_content} | |
| # Lọc metadata để chỉ giữ lại các key hợp lệ đã định nghĩa trong schema | |
| filtered_metadata = { | |
| k: v for k, v in chunk.metadata.items() if k in valid_property_names | |
| } | |
| properties.update(filtered_metadata) | |
| batch.add_object( | |
| collection=collection_name, | |
| properties=properties, | |
| uuid=chunk.id, | |
| vector=chunk_vectors[i] | |
| ) | |
| logger.info(f"✅ Batching hoàn tất. Đã gửi {len(chunks)} objects.") | |
| if batch.number_errors > 0: | |
| logger.error(f"❌ Có {batch.number_errors} lỗi xảy ra trong quá trình batching.") | |
| # Log ra 5 lỗi đầu tiên để dễ gỡ lỗi | |
| for i, error_msg in enumerate(batch.errors): | |
| if i >= 5: break | |
| logger.error(f" - Lỗi {i+1}: {error_msg}") | |
| # End new | |
| def create_or_load_vectorstore(embeddings, weaviate_url, collection_name, weaviate_client, chunks=None): | |
| vectorstore = None | |
| if not embeddings: | |
| logger.error("🔸Không có model embedding để tạo/tải vector store.") | |
| return None | |
| logger.info(f"🔸Truy cập Weaviate tại: {weaviate_url} với collection: {collection_name}") | |
| try: | |
| # Kết nối tới Weaviate | |
| client = weaviate_client | |
| if not client: | |
| logger.error("🔸Không thể kết nối tới Weaviate.") | |
| return None | |
| # Tên collection cần kiểm tra | |
| collection_name = config.WEAVIATE_COLLECTION_NAME | |
| # Kiểm tra xem collection có tồn tại không | |
| collection_exists = client.collections.exists(collection_name) | |
| logger.info(f"Collection {collection_name} exists: {collection_exists}") | |
| if chunks is not None and not collection_exists: | |
| logger.info(f"🔸Tạo Weaviate collection mới từ {len(chunks)} chunks...") | |
| # Kiểm tra mẫu dữ liệu đầu tiên | |
| logger.info(f"🔸Chunk đầu tiên:\n{chunks[0].metadata}") | |
| logger.info(f"🔸Nội dung:\n{chunks[0].page_content[:500]}...") | |
| # Lọc metadata để đảm bảo tương thích với Weaviate | |
| chunks = filter_and_serialize_complex_metadata(chunks) | |
| logger.info(f"🔸Metadata chunk đầu tiên sau khi lọc/serialize:\n{chunks[0].metadata}") | |
| # KIỂM TRA TYPE | |
| if chunks: | |
| logger.info(f"Type của chunk đầu tiên: {type(chunks[0])}") | |
| # Kiểm tra xem có phải là langchain Document không | |
| from langchain_core.documents import Document as LangchainDocument | |
| is_langchain_doc = isinstance(chunks[0], LangchainDocument) | |
| logger.info(f"Chunk đầu tiên có phải là langchain_core.documents.Document không? {is_langchain_doc}") | |
| if not is_langchain_doc: | |
| logger.error("!!! LỖI NGHIÊM TRỌNG: Chunks không phải là instance của langchain_core.documents.Document") | |
| # In ra các attribute của object để xem nó là gì | |
| try: | |
| logger.error(f"Attributes của chunk[0]: {dir(chunks[0])}") | |
| if hasattr(chunks[0], "metadata"): | |
| logger.error(f"Metadata của chunk[0] (nếu có): {chunks[0].metadata}") | |
| if hasattr(chunks[0], "page_content"): | |
| logger.error(f"Page_content của chunk[0] (nếu có): {chunks[0].page_content[:100]}") | |
| except: | |
| pass # Bỏ qua nếu không thể dir() | |
| return None # Dừng ở đây nếu type sai | |
| # Tạo vectorstore | |
| max_batch_size = 1000 # Kích thước batch an toàn | |
| total_chunks = len(chunks) | |
| logger.info("🔸Đang nhúng dữ liệu...") | |
| # Tạo collection mới | |
| vectorstore = WeaviateVectorStore.from_documents( | |
| documents=chunks[:1], # Khởi tạo với 1 tài liệu để tạo schema | |
| embedding=embeddings, | |
| client=client, | |
| index_name=collection_name, | |
| text_key="text", # Tên trường văn bản trong tài liệu | |
| # by_texts=False # Nếu dùng ids thì không cần by_texts, nhưng để rõ ràng | |
| ) | |
| # Thêm tài liệu theo batch | |
| for i in range(1, total_chunks, max_batch_size): | |
| end_idx = min(i + max_batch_size, total_chunks) | |
| current_batch = chunks[i:end_idx] | |
| logger.info(f"🔸Đang xử lý batch {i//max_batch_size + 1}/{(total_chunks-1)//max_batch_size + 1}: từ {i} đến {end_idx-1}") | |
| try: | |
| vectorstore.add_documents(current_batch) | |
| logger.info(f"🔸Đã thêm batch {i//max_batch_size + 1} thành công") | |
| except Exception as batch_error: | |
| logger.error(f"🔸Lỗi khi xử lý batch từ {i} đến {end_idx-1}: {str(batch_error)}") | |
| # Thử với batch nhỏ hơn | |
| smaller_batch_size = max_batch_size // 2 | |
| if smaller_batch_size >= 10: | |
| logger.info(f"🔸Thử lại với batch size nhỏ hơn: {smaller_batch_size}") | |
| for j in range(i, end_idx, smaller_batch_size): | |
| end_j = min(j + smaller_batch_size, end_idx) | |
| smaller_batch = chunks[j:end_j] | |
| try: | |
| vectorstore.add_documents(smaller_batch) | |
| logger.info(f"🔸Đã thêm batch nhỏ từ {j} đến {end_j-1} thành công") | |
| except Exception as small_batch_error: | |
| logger.error(f"🔸Vẫn lỗi với batch nhỏ hơn từ {j} đến {end_j-1}: {str(small_batch_error)}") | |
| else: | |
| logger.error(f"🔸Batch size đã quá nhỏ, không thể giảm thêm. Bỏ qua batch này.") | |
| logger.info(f"🔸Tạo Weaviate collection thành công: {collection_name}") | |
| elif collection_exists: | |
| logger.info(f"🔸Tải Weaviate collection đã tồn tại: {collection_name}") | |
| vectorstore = WeaviateVectorStore( | |
| client=client, | |
| index_name=collection_name, | |
| embedding=embeddings, | |
| text_key="text", | |
| attributes=[ # Liệt kê TẤT CẢ các metadata bạn cần để retriever hoạt động | |
| "nam_ban_hanh", "title", "source", "field", "loai_van_ban", "so_hieu", | |
| "ten_van_ban", "ngay_ban_hanh_str", "co_quan_ban_hanh", "entity_type", | |
| # Các trường serialize thành JSON cũng cần được liệt kê nếu muốn lấy về | |
| "cross_references", "penalties" | |
| ] | |
| ) | |
| logger.info("🔸Tải Weaviate collection thành công.") | |
| else: | |
| logger.error(f"🔸Collection '{collection_name}' không tồn tại và không có dữ liệu chunks để tạo mới.") | |
| return None | |
| logger.info("🔸Vectorstore sẵn sàng.") | |
| return vectorstore | |
| except Exception as e: | |
| if client: | |
| client.close() | |
| logger.info("🔸Đã đóng kết nối tới Weaviate.") | |
| logger.error(f"🔸Lỗi khi tạo/tải Weaviate vector store: {e}") | |
| return None | |
| def get_google_llm(google_api_key): | |
| logger.info("🔸Đang khởi tạo LLM từ Google Generative AI...") | |
| if not google_api_key: | |
| logger.error("🔸Google API Key không được cung cấp.") | |
| return None | |
| try: | |
| def create_chat_google(): | |
| return ChatGoogleGenerativeAI( | |
| model="gemini-2.5-flash-preview-05-20", | |
| google_api_key=google_api_key, | |
| temperature=0.0, # Điều chỉnh nhiệt độ nếu cần, 0.1-0.3 thường tốt cho RAG | |
| safety_settings={ }, | |
| ) | |
| llm = create_chat_google() | |
| logger.info("🔸Khởi tạo Google Generative AI LLM thành công.") | |
| return llm | |
| except Exception as e: | |
| logger.error(f"🔸Lỗi khi khởi tạo Google Generative AI LLM: {e}") | |
| return None | |
| def create_qa_chain( | |
| llm: Any, | |
| retriever: Any, # Nhận retriever nâng cao đã được khởi tạo | |
| process_input_llm: Any = None | |
| ): | |
| """ | |
| PHIÊN BẢN CUỐI CÙNG: Tạo ra một RAG chain hoàn chỉnh, tối ưu hóa với: | |
| 1. Unified Pre-processing: Một lệnh gọi LLM để hiểu lịch sử, "dịch" thuật ngữ, và phân loại. | |
| 2. Multi-route: Định tuyến thông minh đến các nhánh xử lý chuyên biệt. | |
| 3. Advanced Retriever: Sử dụng retriever tùy chỉnh cho nhánh pháp luật. | |
| """ | |
| if not all([llm, retriever]): | |
| logger.error("🔸 Thiếu LLM hoặc Retriever chính để tạo QA Chain.") | |
| return None | |
| try: | |
| logger.info("🔸 Bắt đầu tạo QA Chain Tối ưu (phiên bản cuối cùng)...") | |
| # LLM cho bước tiền xử lý (thường là model mạnh nhất) | |
| preprocessing_llm = process_input_llm or llm | |
| # ----- PROMPTS (Sử dụng các phiên bản đã cải tiến) ----- | |
| # 1. Prompt tiền xử lý hợp nhất | |
| # Sử dụng phiên bản V5 mạnh mẽ nhất để "dịch" thuật ngữ hiệu quả | |
| unified_preprocessing_prompt = ChatPromptTemplate.from_template( | |
| prompt_templete.UNIFIED_PREPROCESSING_PROMPT | |
| ) | |
| # 2. Prompt để tạo câu trả lời RAG từ context | |
| # Sử dụng phiên bản V4 để "dạy" LLM cách phân tích và ưu tiên thông tin | |
| qa_prompt = ChatPromptTemplate.from_template( | |
| prompt_templete.QA_PROMPT_TEMPLATE | |
| ) | |
| # 3. Các prompt cho các nhánh khác | |
| persona_prompt = ChatPromptTemplate.from_messages([ | |
| ("system", prompt_templete.GENERAL_PROMPT), | |
| ("human", "{input}") | |
| ]) | |
| # ----- STEP 1: UNIFIED PREPROCESSING CHAIN ----- | |
| # Đây là bộ não xử lý đầu vào, thay thế cho 3 lệnh gọi LLM cũ | |
| unified_preprocessing_chain = ( | |
| unified_preprocessing_prompt | |
| | preprocessing_llm | |
| | JsonOutputParser() | |
| ).with_config({"run_name": "UnifiedQuestionPreprocessor"}) | |
| # ----- STEP 2: DEFINE BRANCHES (CÁC NHÁNH XỬ LÝ) ----- | |
| # --- Nhánh 1: LEGAL (RAG) --- | |
| # Sử dụng retriever nâng cao đã được truyền vào | |
| legal_chain = ( | |
| # `retriever` nhận `rewritten_question` từ dict đầu vào | |
| RunnablePassthrough.assign(context=itemgetter("rewritten_question") | retriever) | |
| # Chuẩn bị input cho qa_prompt cuối cùng | |
| .assign(input=itemgetter("rewritten_question")) | |
| | { | |
| "answer": qa_prompt | llm | StrOutputParser(), | |
| "context": itemgetter("context") # Giữ lại context để có thể hiển thị nguồn | |
| } | |
| ).with_config({"run_name": "AdvancedLegalRAGChain"}) | |
| # --- Nhánh 3: GENERAL CHAT --- | |
| general_chat_chain = ( | |
| {"input": itemgetter("rewritten_question")} | |
| | persona_prompt | |
| | llm | |
| | StrOutputParser() | |
| | (lambda answer: {"answer": answer, "context": []}) | |
| ).with_config({"run_name": "GeneralChatChain"}) | |
| # ----- STEP 3: ROUTER ----- | |
| # Định nghĩa các nhánh mà router có thể chọn | |
| branches = { | |
| "legal_rag": legal_chain, | |
| "general_chat": general_chat_chain, | |
| # Thêm nhánh legal_term_explanation ở đây nếu bạn triển khai nó | |
| } | |
| def route_branches(info: dict): | |
| """Hàm định tuyến, chọn chain phù hợp dựa trên kết quả phân loại.""" | |
| classification = info.get("classification", "general_chat") | |
| logger.info(f"Routing to branch: '{classification}'") | |
| # Chọn chain, mặc định là general_chat nếu có lỗi | |
| return branches.get(classification, general_chat_chain) | |
| # ----- STEP 4: FULL CHAIN ----- | |
| # Kết hợp thành một chuỗi xử lý duy nhất và liền mạch | |
| # Luồng: Input -> Tiền xử lý (Viết lại + Phân loại) -> Router -> Chạy nhánh được chọn | |
| full_chain = unified_preprocessing_chain | RunnableLambda(route_branches) | |
| logger.info("✅ Successfully created Final Optimized QA Chain.") | |
| return full_chain | |
| except Exception as e: | |
| logger.error(f"❌ Error creating QA Chain: {e}", exc_info=True) | |
| return None |