Spaces:
Paused
Paused
# file: rag_pipeline.py | |
import torch | |
import json | |
import faiss | |
import numpy as np | |
import re | |
from unsloth import FastLanguageModel | |
from sentence_transformers import SentenceTransformer | |
from rank_bm25 import BM25Okapi | |
from transformers import TextStreamer | |
# Import các hàm từ file khác | |
from data_processor import process_law_data_to_chunks | |
from retriever import search_relevant_laws, tokenize_vi_for_bm25_setup | |
def initialize_components(data_path): | |
""" | |
Khởi tạo và tải tất cả các thành phần cần thiết cho RAG pipeline. | |
Hàm này chỉ nên được gọi một lần khi ứng dụng khởi động. | |
""" | |
print("--- Bắt đầu khởi tạo các thành phần ---") | |
# 1. Tải LLM và Tokenizer (Processor) từ Unsloth | |
print("1. Tải mô hình LLM (Unsloth)...") | |
model, tokenizer = FastLanguageModel.from_pretrained( | |
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit", | |
max_seq_length=4096, # Có thể tăng cho các mô hình mới | |
dtype=None, | |
load_in_4bit=True, | |
) | |
print("✅ Tải LLM và Tokenizer thành công.") | |
# 2. Tải mô hình Embedding | |
print("2. Tải mô hình Embedding...") | |
embedding_model = SentenceTransformer( | |
"bkai-foundation-models/vietnamese-bi-encoder", | |
device="cuda" if torch.cuda.is_available() else "cpu" | |
) | |
print("✅ Tải mô hình Embedding thành công.") | |
# 3. Tải và xử lý dữ liệu JSON | |
print(f"3. Tải và xử lý dữ liệu từ {data_path}...") | |
with open(data_path, 'r', encoding='utf-8') as f: | |
raw_data = json.load(f) | |
chunks_data = process_law_data_to_chunks(raw_data) | |
print(f"✅ Xử lý dữ liệu thành công, có {len(chunks_data)} chunks.") | |
# 4. Tạo Embeddings và FAISS Index | |
print("4. Tạo embeddings và FAISS index...") | |
texts_to_encode = [chunk.get('text', '') for chunk in chunks_data] | |
chunk_embeddings_tensor = embedding_model.encode( | |
texts_to_encode, | |
convert_to_tensor=True, | |
device=embedding_model.device | |
) | |
chunk_embeddings_np = chunk_embeddings_tensor.cpu().numpy().astype('float32') | |
faiss.normalize_L2(chunk_embeddings_np) | |
dimension = chunk_embeddings_np.shape[1] | |
faiss_index = faiss.IndexFlatIP(dimension) | |
faiss_index.add(chunk_embeddings_np) | |
print(f"✅ Tạo FAISS index thành công với {faiss_index.ntotal} vector.") | |
# 5. Tạo BM25 Model | |
print("5. Tạo mô hình BM25...") | |
corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data] | |
tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25] | |
bm25_model = BM25Okapi(tokenized_corpus_bm25) | |
print("✅ Tạo mô hình BM25 thành công.") | |
print("--- ✅ Khởi tạo tất cả thành phần hoàn tất ---") | |
return { | |
"llm_model": model, | |
"tokenizer": tokenizer, | |
"embedding_model": embedding_model, | |
"chunks_data": chunks_data, | |
"faiss_index": faiss_index, | |
"bm25_model": bm25_model | |
} | |
def _format_context_with_summary(retrieved_results: list[dict]) -> str: | |
""" | |
Hàm phụ trợ: Định dạng ngữ cảnh từ kết quả truy xuất, bổ sung tóm tắt từ metadata. | |
Hàm này được tách ra để làm cho code sạch sẽ và dễ quản lý hơn. | |
""" | |
if not retrieved_results: | |
return "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu." | |
context_parts = [] | |
for i, res in enumerate(retrieved_results): | |
metadata = res.get('metadata', {}) | |
text = res.get('text', '*Nội dung không có*') | |
# Tạo header rõ ràng | |
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Điểm {metadata.get('point_id', '')} Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})" | |
# --- LOGIC TÓM TẮT THÔNG MINH TỪ METADATA --- | |
metadata_summary = "" | |
penalty_details_list = metadata.get("penalties_detail", []) | |
if penalty_details_list: | |
summary_parts = [] | |
# Chỉ lấy thông tin từ mục hình phạt đầu tiên trong danh sách | |
details = penalty_details_list[0].get('details', {}) | |
# Tóm tắt mức phạt tiền cho cá nhân (phổ biến nhất) | |
i_min = details.get("individual_fine_min") | |
i_max = details.get("individual_fine_max") | |
if i_min is not None and i_max is not None: | |
summary_parts.append(f"Phạt tiền cá nhân từ {i_min:,} - {i_max:,} đồng.") | |
# Tóm tắt mức trừ điểm | |
points = details.get("points_deducted") | |
if points is not None: | |
summary_parts.append(f"Trừ {points} điểm GPLX.") | |
if summary_parts: | |
# Chèn dòng tóm tắt vào giữa header và text | |
metadata_summary = f"\n[Tóm tắt từ metadata: {' '.join(summary_parts)}]" | |
context_parts.append(f"{header}{metadata_summary}\n{text}") | |
return "\n\n---\n\n".join(context_parts) | |
def generate_response(query: str, components: dict) -> str: | |
""" | |
Tạo câu trả lời (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo. | |
Phiên bản đã được tối ưu và tái cấu trúc. | |
""" | |
print("--- Bắt đầu quy trình RAG cho query mới ---") | |
# 1. Truy xuất ngữ cảnh bằng retriever đã được nâng cấp | |
# (Giả định search_relevant_laws đã được sửa để ưu tiên loại xe) | |
retrieved_results = search_relevant_laws( | |
query_text=query, | |
embedding_model=components["embedding_model"], | |
faiss_index=components["faiss_index"], | |
chunks_data=components["chunks_data"], | |
bm25_model=components["bm25_model"], | |
k=5, | |
initial_k_multiplier=15 | |
) | |
# 2. Định dạng Context một cách thông minh bằng hàm phụ trợ | |
context = _format_context_with_summary(retrieved_results) | |
# 3. Xây dựng Prompt | |
prompt = f"""Bạn là một trợ lý pháp luật chuyên trả lời các câu hỏi về luật giao thông Việt Nam. Dựa vào các trích dẫn luật dưới đây để trả lời câu hỏi của người dùng một cách chính xác. | |
### Thông tin luật: | |
{context} | |
### Câu hỏi: | |
{query} | |
### Trả lời:""" | |
# 4. Tạo câu trả lời từ LLM | |
llm_model = components["llm_model"] | |
tokenizer = components["tokenizer"] | |
# Chuyển input lên cùng device với model | |
inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device) | |
# Cấu hình generation tối ưu cho việc trả lời câu hỏi dựa trên facts | |
generation_config = dict( | |
max_new_tokens=256, | |
temperature=0.1, # Rất thấp để câu trả lời bám sát ngữ cảnh | |
repetition_penalty=1.1, # Phạt nhẹ việc lặp từ | |
do_sample=True, # Vẫn cần bật để temperature và các tham số khác có hiệu lực | |
pad_token_id=tokenizer.eos_token_id | |
) | |
output_ids = llm_model.generate(**inputs, **generation_config) | |
# Chỉ decode phần văn bản được sinh ra mới, bỏ qua phần prompt | |
response_text = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) | |
print("--- Tạo câu trả lời hoàn tất ---") | |
return response_text |