Update search.py
search.py CHANGED
@@ -1,16 +1,12 @@
-from transformers import RagTokenizer, RagTokenForGeneration,
+from transformers import RagTokenizer, RagTokenForGeneration, pipeline
+import pandas as pd
 from pdfminer.high_level import extract_text
 from docx import Document
 from dataclasses import dataclass
-import pandas as pd
 
-#
-
-
-
-# Initialize Phi-2
-phi_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
-phi_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+# RAG setup
+tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
+model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
 
 @dataclass
 class Paragraph:
@@ -18,41 +14,19 @@ class Paragraph:
     paragraph_num: int
     content: str
 
-def read_pdf_pdfminer(file_path) -> list
+def read_pdf_pdfminer(file_path) -> list:
     text = extract_text(file_path).replace('\n', ' ').strip()
-    paragraphs = text.split(
-    return
+    paragraphs = [Paragraph(0, i, para) for i, para in enumerate(text.split('. '), start=1)]
+    return paragraphs
 
-def read_docx(file) -> list
+def read_docx(file) -> list:
     doc = Document(file)
-    return [Paragraph(1, i, para.text.strip()) for i, para in enumerate(doc.paragraphs, 1) if para.text.strip()]
-
-def generate_context_with_rag(question: str, documents: List[str]) -> str:
-    combined_text = " ".join(documents)
-    if not combined_text.strip(): # Ensure combined_text is not empty
-        return "No context available."
-
-    inputs = rag_tokenizer(question + " " + combined_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
-
-    # Ensure inputs are correctly prepared
-    if "input_ids" not in inputs or "attention_mask" not in inputs:
-        return "Invalid input for model."
-
-    output_ids = rag_model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
-    context = rag_tokenizer.decode(output_ids[0], skip_special_tokens=True)
-    return context
-
-
-def generate_answer_with_phi(question: str, context: str) -> str:
-    enhanced_question = f"Question: {question}\nContext: {context}\nAnswer:"
-    inputs = phi_tokenizer.encode(enhanced_question, return_tensors="pt", max_length=512, truncation=True)
-    outputs = phi_model.generate(inputs, max_length=600)
-    answer = phi_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return answer
+    return [Paragraph(1, i, para.text.strip()) for i, para in enumerate(doc.paragraphs, start=1) if para.text.strip()]
 
 def answer_question(question: str, documents_df: pd.DataFrame) -> str:
-
-
-
-
+    document_texts = " ".join(documents_df['content'].tolist())
+    context = f"{question} {document_texts}"
+    inputs = tokenizer(context, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
+    output_ids = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
+    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    return answer
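
For reference, a minimal usage sketch of the updated module, assuming search.py is importable as-is. The file name report.pdf, the question string, and the asdict-based DataFrame construction are illustrative additions, not part of this change; answer_question only reads the 'content' column, so the sketch does not depend on the name of Paragraph's first field (which lies outside this diff's context lines).

# Hypothetical driver script; only read_pdf_pdfminer, answer_question and the
# 'content' column come from search.py. Everything else here is assumed.
from dataclasses import asdict
import pandas as pd
from search import read_pdf_pdfminer, answer_question

paragraphs = read_pdf_pdfminer("report.pdf")                    # path is illustrative
documents_df = pd.DataFrame([asdict(p) for p in paragraphs])    # produces a 'content' column
print(answer_question("What is the report about?", documents_df))

read_docx returns the same Paragraph shape, so a DOCX file could feed the identical DataFrame construction.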