Commit: "Update search.py" — diff view of search.py (changed)
@@ -4,13 +4,21 @@ from pdfminer.high_level import extract_text
|
|
4 |
from typing import List
|
5 |
import pandas as pd
|
6 |
import re
|
|
|
7 |
|
8 |
# Initialize RAG components
|
9 |
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
|
10 |
-
rag_retriever = RagRetriever.from_pretrained("facebook/rag-token-base", trust_remote_code=True)
|
11 |
rag_token_for_generation = RagTokenForGeneration.from_pretrained("facebook/rag-token-base")
|
12 |
rag_config = RagConfig.from_pretrained("facebook/rag-token-base")
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# Dataclass for paragraph
|
15 |
@dataclass
|
16 |
class Paragraph:
|
@@ -114,13 +122,10 @@ def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, retrie
|
|
114 |
|
115 |
context = "\n\n".join(most_relevant_paragraphs)
|
116 |
|
117 |
-
# Retrieve documents relevant to the question
|
118 |
-
documents = retriever.retrieve(question)
|
119 |
-
|
120 |
# Generate answer using RAG
|
121 |
-
inputs =
|
122 |
-
outputs =
|
123 |
-
answer =
|
124 |
|
125 |
references = extract_page_and_clause_references(context)
|
126 |
answer = refine_answer_based_on_question(question, answer) + " " + references
|
|
|
4 |
from typing import List
|
5 |
import pandas as pd
|
6 |
import re
|
7 |
+
from datasets import load_dataset
|
8 |
|
9 |
# Initialize the pretrained RAG (retrieval-augmented generation) components.
# All three pieces are loaded from the same Hugging Face checkpoint,
# "facebook/rag-token-base", so the tokenizer, generator model, and config
# stay mutually consistent.
rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
rag_token_for_generation = RagTokenForGeneration.from_pretrained("facebook/rag-token-base")
rag_config = RagConfig.from_pretrained("facebook/rag-token-base")
13 |
|
14 |
+
# Download and prepare the wiki_dpr passage corpus used by the retriever.
#
# FIX: the original code had three API-misuse bugs:
#   1. load_dataset("wiki_dpr") omitted the mandatory config name — wiki_dpr
#      has several configs; "psgs_w100.nq.exact" matches the exact index the
#      pretrained RAG question encoder was trained against.
#   2. It read dpr_dataset["train"]["passage"], but wiki_dpr has no "passage"
#      column — the passage text column is named "text".
#   3. RagRetriever(passages=..., titles=..., config=...) is not a valid
#      constructor signature; RagRetriever is created via from_pretrained(),
#      which loads the matching wiki_dpr index itself.
# NOTE(review): the full wiki_dpr download is very large (tens of GB) —
# confirm this is intended before running outside a prepared environment.
dpr_dataset = load_dataset("wiki_dpr", "psgs_w100.nq.exact", trust_remote_code=True)
passages = dpr_dataset["train"]["text"]   # raw passage texts (kept for downstream use)
titles = dpr_dataset["train"]["title"]    # titles aligned 1:1 with `passages`

# Initialize the RagRetriever from the same checkpoint as the other RAG
# components; index_name="exact" selects the exact-match wiki_dpr index.
rag_retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base",
    index_name="exact",
    trust_remote_code=True,
)
21 |
+
|
22 |
# Dataclass for paragraph
|
23 |
@dataclass
|
24 |
class Paragraph:
|
|
|
122 |
|
123 |
context = "\n\n".join(most_relevant_paragraphs)
|
124 |
|
|
|
|
|
|
|
125 |
# Generate answer using RAG
|
126 |
+
inputs = rag_tokenizer(question, context, return_tensors="pt", max_length=512, truncation=True)
|
127 |
+
outputs = rag_token_for_generation.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=200, num_return_sequences=1)
|
128 |
+
answer = rag_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
129 |
|
130 |
references = extract_page_and_clause_references(context)
|
131 |
answer = refine_answer_based_on_question(question, answer) + " " + references
|