albhu committed on
Commit
cf4a404
·
verified ·
1 Parent(s): 0ce91d9
Files changed (1) hide show
  1. search.py +13 -3
search.py CHANGED
@@ -27,12 +27,22 @@ def read_docx(file) -> list[Paragraph]:
27
  doc = Document(file)
28
  return [Paragraph(1, i, para.text.strip()) for i, para in enumerate(doc.paragraphs, 1) if para.text.strip()]
29
 
30
- def generate_context_with_rag(question: str) -> str:
31
- inputs = rag_tokenizer(question, return_tensors="pt")
32
- output_ids = rag_model.generate(**inputs)
 
 
 
 
 
 
 
 
 
33
  context = rag_tokenizer.decode(output_ids[0], skip_special_tokens=True)
34
  return context
35
 
 
36
  def generate_answer_with_phi(question: str, context: str) -> str:
37
  enhanced_question = f"Question: {question}\nContext: {context}\nAnswer:"
38
  inputs = phi_tokenizer.encode(enhanced_question, return_tensors="pt", max_length=512, truncation=True)
 
27
  doc = Document(file)
28
  return [Paragraph(1, i, para.text.strip()) for i, para in enumerate(doc.paragraphs, 1) if para.text.strip()]
29
 
30
def generate_context_with_rag(question: str, documents: list[str]) -> str:
    """Generate a context string for *question* from the supplied documents.

    The documents are joined with single spaces, appended to the question,
    tokenized (truncated to at most 512 tokens) and handed to
    ``rag_model.generate``. The first generated sequence is decoded with
    special tokens stripped and returned.

    Args:
        question: The question to build context for.
        documents: Text passages to combine into the model input.

    Returns:
        The decoded model output. ``"No context available."`` is returned
        when the joined documents contain only whitespace, and
        ``"Invalid input for model."`` when the tokenizer output lacks
        ``input_ids`` or ``attention_mask``.
    """
    combined_text = " ".join(documents)
    if not combined_text.strip():
        # No usable document text: skip the tokenizer and model entirely.
        return "No context available."

    inputs = rag_tokenizer(
        question + " " + combined_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # generate() below is called with both fields explicitly, so bail out
    # with a fallback message if the tokenizer did not produce them.
    if "input_ids" not in inputs or "attention_mask" not in inputs:
        return "Invalid input for model."

    output_ids = rag_model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    context = rag_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return context
44
 
45
+
46
  def generate_answer_with_phi(question: str, context: str) -> str:
47
  enhanced_question = f"Question: {question}\nContext: {context}\nAnswer:"
48
  inputs = phi_tokenizer.encode(enhanced_question, return_tensors="pt", max_length=512, truncation=True)