RAG fix
search.py (CHANGED)
@@ -27,12 +27,22 @@ def read_docx(file) -> list[Paragraph]:
     doc = Document(file)
     return [Paragraph(1, i, para.text.strip()) for i, para in enumerate(doc.paragraphs, 1) if para.text.strip()]
 
-def generate_context_with_rag(question: str) -> str:
-
-
+def generate_context_with_rag(question: str, documents: List[str]) -> str:
+    combined_text = " ".join(documents)
+    if not combined_text.strip():  # Ensure combined_text is not empty
+        return "No context available."
+
+    inputs = rag_tokenizer(question + " " + combined_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+
+    # Ensure inputs are correctly prepared
+    if "input_ids" not in inputs or "attention_mask" not in inputs:
+        return "Invalid input for model."
+
+    output_ids = rag_model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
     context = rag_tokenizer.decode(output_ids[0], skip_special_tokens=True)
     return context
 
+
 def generate_answer_with_phi(question: str, context: str) -> str:
     enhanced_question = f"Question: {question}\nContext: {context}\nAnswer:"
     inputs = phi_tokenizer.encode(enhanced_question, return_tensors="pt", max_length=512, truncation=True)
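The hunk references four module-level globals that it never defines (rag_tokenizer, rag_model, phi_tokenizer, phi_model), and the new signature uses List[str], which presumes a `from typing import List` earlier in search.py. The sketch below is one hypothetical way to wire those globals inside search.py so the patched function runs end to end; the model names are assumptions, not part of this commit — any Hugging Face seq2seq tokenizer/model pair satisfies the rag_* call sites, and microsoft/phi-2 is one plausible choice behind the phi_* names.

# Hypothetical setup for the globals search.py assumes; model choices
# are illustrative and not taken from this commit.
from typing import List

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

# The rag_* call sites only need tokenizer(...) -> input_ids/attention_mask
# and model.generate(...), so any seq2seq pair fits; BART is used here.
rag_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
rag_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")

# Assumed causal LM behind the phi_* names.
phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")

if __name__ == "__main__":
    # Usage of the patched function: documents are now passed explicitly
    # instead of being read from module state.
    docs = ["The Eiffel Tower is in Paris.", "It opened in 1889."]
    print(generate_context_with_rag("Where is the Eiffel Tower?", docs))

Note that despite the rag_* naming, the hunk performs no retrieval step: generate_context_with_rag concatenates the question with every document and asks the seq2seq model to compress the result into a context string, so the max_length=512 truncation silently drops anything beyond roughly the first 512 tokens of the combined text.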