chatbot_demo / rag_pipeline.py
deddoggo's picture
update
0002c1d
raw
history blame
6.1 kB
# file: rag_pipeline.py
import torch
import json
import faiss
import numpy as np
import re
from unsloth import FastLanguageModel
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import TextStreamer
# Import các hàm từ file khác
from data_processor import process_law_data_to_chunks
from retriever import search_relevant_laws, tokenize_vi_for_bm25_setup
def initialize_components(data_path):
"""
Khởi tạo và tải tất cả các thành phần cần thiết cho RAG pipeline.
Hàm này chỉ nên được gọi một lần khi ứng dụng khởi động.
"""
print("--- Bắt đầu khởi tạo các thành phần ---")
# 1. Tải LLM và Tokenizer (Processor) từ Unsloth
print("1. Tải mô hình LLM (Unsloth)...")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
max_seq_length=4096, # Có thể tăng cho các mô hình mới
dtype=None,
load_in_4bit=True,
)
print("✅ Tải LLM và Tokenizer thành công.")
# 2. Tải mô hình Embedding
print("2. Tải mô hình Embedding...")
embedding_model = SentenceTransformer(
"bkai-foundation-models/vietnamese-bi-encoder",
device="cuda" if torch.cuda.is_available() else "cpu"
)
print("✅ Tải mô hình Embedding thành công.")
# 3. Tải và xử lý dữ liệu JSON
print(f"3. Tải và xử lý dữ liệu từ {data_path}...")
with open(data_path, 'r', encoding='utf-8') as f:
raw_data = json.load(f)
chunks_data = process_law_data_to_chunks(raw_data)
print(f"✅ Xử lý dữ liệu thành công, có {len(chunks_data)} chunks.")
# 4. Tạo Embeddings và FAISS Index
print("4. Tạo embeddings và FAISS index...")
texts_to_encode = [chunk.get('text', '') for chunk in chunks_data]
chunk_embeddings_tensor = embedding_model.encode(
texts_to_encode,
convert_to_tensor=True,
device=embedding_model.device
)
chunk_embeddings_np = chunk_embeddings_tensor.cpu().numpy().astype('float32')
faiss.normalize_L2(chunk_embeddings_np)
dimension = chunk_embeddings_np.shape[1]
faiss_index = faiss.IndexFlatIP(dimension)
faiss_index.add(chunk_embeddings_np)
print(f"✅ Tạo FAISS index thành công với {faiss_index.ntotal} vector.")
# 5. Tạo BM25 Model
print("5. Tạo mô hình BM25...")
corpus_texts_for_bm25 = [chunk.get('text', '') for chunk in chunks_data]
tokenized_corpus_bm25 = [tokenize_vi_for_bm25_setup(text) for text in corpus_texts_for_bm25]
bm25_model = BM25Okapi(tokenized_corpus_bm25)
print("✅ Tạo mô hình BM25 thành công.")
print("--- ✅ Khởi tạo tất cả thành phần hoàn tất ---")
return {
"llm_model": model,
"tokenizer": tokenizer,
"embedding_model": embedding_model,
"chunks_data": chunks_data,
"faiss_index": faiss_index,
"bm25_model": bm25_model
}
def generate_response(query, components):
"""
Tạo câu trả lời cho một query (single-turn) bằng cách sử dụng các thành phần đã được khởi tạo.
"""
print("--- Bắt đầu quy trình RAG (Single-turn) cho query mới ---")
# Unpack các thành phần
llm_model = components["llm_model"]
tokenizer = components["tokenizer"]
# 1. Truy xuất ngữ cảnh trực tiếp từ câu hỏi của người dùng
retrieved_results = search_relevant_laws(
query_text=query,
embedding_model=components["embedding_model"],
faiss_index=components["faiss_index"],
chunks_data=components["chunks_data"],
bm25_model=components["bm25_model"],
k=5,
initial_k_multiplier=18
)
# 2. Định dạng Context
if not retrieved_results:
context = "Không tìm thấy thông tin luật liên quan trong cơ sở dữ liệu."
else:
context_parts = []
for i, res in enumerate(retrieved_results):
metadata = res.get('metadata', {})
header = f"Trích dẫn {i+1}: Điều {metadata.get('article', 'N/A')}, Khoản {metadata.get('clause_number', 'N/A')} (Nguồn: {metadata.get('source_document', 'N/A')})"
text = res.get('text', '*Nội dung không có*')
context_parts.append(f"{header}\n{text}")
context = "\n\n---\n\n".join(context_parts)
# 3. Xây dựng Prompt đơn giản (không có lịch sử trò chuyện)
prompt = f"""Bạn là trợ lý pháp luật chuyên trả lời các câu hỏi liên quan đến luật giao thông đường bộ Việt Nam.
Dựa trên các đoạn luật dưới đây:
{context}
Hãy trả lời câu hỏi của người dùng bằng tiếng Việt, chính xác và dễ hiểu. Nếu cần, hãy trích dẫn điều, khoản hoặc điểm tương ứng trong văn bản luật. Nếu không đủ thông tin trong các đoạn trên, hãy trả lời “Tôi không chắc, cần kiểm tra thêm văn bản luật liên quan.”
Câu hỏi: {query}
"""
print("--- Bắt đầu tạo câu trả lời từ LLM ---")
# SỬA LỖI CHO VISION MODEL: Sử dụng API tường minh
inputs = tokenizer(
text=prompt,
images=None,
return_tensors="pt"
).to("cuda" if torch.cuda.is_available() else "cpu")
generation_config = dict(
max_new_tokens=256,
temperature=0.5,
top_p=0.7,
top_k=50,
repetition_penalty=1.1,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
output_ids = llm_model.generate(**inputs, **generation_config)
input_length = inputs.input_ids.shape[1]
generated_ids = output_ids[0][input_length:]
response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print("--- Tạo câu trả lời hoàn tất ---")
return response_text