albhu committed on
Commit
e22de9e
·
verified ·
1 Parent(s): c5ed69f

Update search.py

Browse files
Files changed (1) hide show
  1. search.py +12 -7
search.py CHANGED
@@ -4,13 +4,21 @@ from pdfminer.high_level import extract_text
4
  from typing import List
5
  import pandas as pd
6
  import re
 
7
 
8
  # Initialize RAG components
9
  rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
10
- rag_retriever = RagRetriever.from_pretrained("facebook/rag-token-base", trust_remote_code=True)
11
  rag_token_for_generation = RagTokenForGeneration.from_pretrained("facebook/rag-token-base")
12
  rag_config = RagConfig.from_pretrained("facebook/rag-token-base")
13
 
 
 
 
 
 
 
 
 
14
  # Dataclass for paragraph
15
  @dataclass
16
  class Paragraph:
@@ -114,13 +122,10 @@ def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, retrie
114
 
115
  context = "\n\n".join(most_relevant_paragraphs)
116
 
117
- # Retrieve documents relevant to the question
118
- documents = retriever.retrieve(question)
119
-
120
  # Generate answer using RAG
121
- inputs = tokenizer(question, context, return_tensors="pt", max_length=512, truncation=True)
122
- outputs = generator.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=200, num_return_sequences=1)
123
- answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
124
 
125
  references = extract_page_and_clause_references(context)
126
  answer = refine_answer_based_on_question(question, answer) + " " + references
 
4
  from typing import List
5
  import pandas as pd
6
  import re
7
+ from datasets import load_dataset
8
 
9
  # Initialize RAG components
10
  rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
 
11
  rag_token_for_generation = RagTokenForGeneration.from_pretrained("facebook/rag-token-base")
12
  rag_config = RagConfig.from_pretrained("facebook/rag-token-base")
13
 
14
+ # Download and prepare the wiki_dpr dataset
15
+ dpr_dataset = load_dataset("wiki_dpr")
16
+ passages = dpr_dataset["train"]["passage"]
17
+ titles = dpr_dataset["train"]["title"]
18
+
19
+ # Initialize the RagRetriever
20
+ rag_retriever = RagRetriever(passages=passages, titles=titles, config=rag_config)
21
+
22
  # Dataclass for paragraph
23
  @dataclass
24
  class Paragraph:
 
122
 
123
  context = "\n\n".join(most_relevant_paragraphs)
124
 
 
 
 
125
  # Generate answer using RAG
126
+ inputs = rag_tokenizer(question, context, return_tensors="pt", max_length=512, truncation=True)
127
+ outputs = rag_token_for_generation.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=200, num_return_sequences=1)
128
+ answer = rag_tokenizer.decode(outputs[0], skip_special_tokens=True)
129
 
130
  references = extract_page_and_clause_references(context)
131
  answer = refine_answer_based_on_question(question, answer) + " " + references