albhu committed on
Commit 229e240 · verified · 1 Parent(s): cf4a404

Update search.py

Files changed (1)
  1. search.py +15 -41
search.py CHANGED
@@ -1,16 +1,12 @@
- from transformers import RagTokenizer, RagTokenForGeneration, AutoTokenizer, AutoModelForCausalLM, pipeline
+ from transformers import RagTokenizer, RagTokenForGeneration, pipeline
+ import pandas as pd
  from pdfminer.high_level import extract_text
  from docx import Document
  from dataclasses import dataclass
- import pandas as pd

- # Initialize RAG
- rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
- rag_model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
-
- # Initialize Phi-2
- phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
- phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+ # RAG setup
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
+ model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")

  @dataclass
  class Paragraph:
@@ -18,41 +14,19 @@ class Paragraph:
      paragraph_num: int
      content: str

- def read_pdf_pdfminer(file_path) -> list[Paragraph]:
+ def read_pdf_pdfminer(file_path) -> list:
      text = extract_text(file_path).replace('\n', ' ').strip()
-     paragraphs = text.split(". ")
-     return [Paragraph(0, i, para) for i, para in enumerate(paragraphs, 1)]
+     paragraphs = [Paragraph(0, i, para) for i, para in enumerate(text.split('. '), start=1)]
+     return paragraphs

- def read_docx(file) -> list[Paragraph]:
+ def read_docx(file) -> list:
      doc = Document(file)
-     return [Paragraph(1, i, para.text.strip()) for i, para in enumerate(doc.paragraphs, 1) if para.text.strip()]
-
- def generate_context_with_rag(question: str, documents: List[str]) -> str:
-     combined_text = " ".join(documents)
-     if not combined_text.strip(): # Ensure combined_text is not empty
-         return "No context available."
-
-     inputs = rag_tokenizer(question + " " + combined_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
-
-     # Ensure inputs are correctly prepared
-     if "input_ids" not in inputs or "attention_mask" not in inputs:
-         return "Invalid input for model."
-
-     output_ids = rag_model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
-     context = rag_tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     return context
-
-
- def generate_answer_with_phi(question: str, context: str) -> str:
-     enhanced_question = f"Question: {question}\nContext: {context}\nAnswer:"
-     inputs = phi_tokenizer.encode(enhanced_question, return_tensors="pt", max_length=512, truncation=True)
-     outputs = phi_model.generate(inputs, max_length=600)
-     answer = phi_tokenizer.decode(outputs[0], skip_special_tokens=True)
-     return answer
+     return [Paragraph(1, i, para.text.strip()) for i, para in enumerate(doc.paragraphs, start=1) if para.text.strip()]

  def answer_question(question: str, documents_df: pd.DataFrame) -> str:
-     # Assuming documents_df contains the text from uploaded files
-     combined_text = " ".join(documents_df['content'].tolist())
-     context = generate_context_with_rag(combined_text + " " + question)
-     answer = generate_answer_with_phi(question, context)
+     document_texts = " ".join(documents_df['content'].tolist())
+     context = f"{question} {document_texts}"
+     inputs = tokenizer(context, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
+     output_ids = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
+     answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
      return answer
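
For context, a minimal sketch of how the updated search.py might be driven end to end. It assumes the module is importable as `search`; the file name and question below are placeholders, while read_docx, answer_question, and the 'content' column come from the diff above.

import pandas as pd

from search import read_docx, answer_question  # assumes search.py is importable

# Parse an uploaded document into Paragraph records ("contract.docx" is a hypothetical file name).
paragraphs = read_docx("contract.docx")

# answer_question expects a DataFrame with the paragraph text in a 'content' column.
documents_df = pd.DataFrame({"content": [p.content for p in paragraphs]})

# Hypothetical question; the answer is produced by the facebook/rag-token-nq model loaded at import time.
print(answer_question("What is the notice period?", documents_df))

Note that importing search triggers the from_pretrained calls in the RAG setup, so the checkpoint is downloaded when the module is first loaded.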