albhu committed
Commit daa1093 · verified · 1 Parent(s): 95fca27

Files changed (1)
search.py  +31 -21
search.py CHANGED
@@ -1,32 +1,24 @@
- from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, RagConfig
  from docx import Document
  from pdfminer.high_level import extract_text
- from transformers import GPT2Tokenizer
- from dataclasses import dataclass
  from typing import List
- from tqdm import tqdm
- import os
  import pandas as pd
  import re
- from sklearn.feature_extraction.text import TfidfVectorizer
- import numpy as np

- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
-
- EMBEDDING_SEG_LEN = 1500
- EMBEDDING_MODEL = "gpt-4"
-
- EMBEDDING_CTX_LENGTH = 8191
- EMBEDDING_ENCODING = "cl100k_base"
- ENCODING = "gpt2"
+ # Initialize RAG components
+ rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
+ rag_retriever = RagRetriever.from_pretrained("facebook/rag-token-base")
+ rag_token_for_generation = RagTokenForGeneration.from_pretrained("facebook/rag-token-base")
+ rag_config = RagConfig.from_pretrained("facebook/rag-token-base")

+ # Dataclass for paragraph
  @dataclass
  class Paragraph:
      page_num: int
      paragraph_num: int
      content: str

+ # Read PDF using pdfminer
  def read_pdf_pdfminer(file_path) -> List[Paragraph]:
      text = extract_text(file_path).replace('\n', ' ').strip()
      paragraphs = batched(text, EMBEDDING_SEG_LEN)
@@ -38,6 +30,7 @@ def read_pdf_pdfminer(file_path) -> List[Paragraph]:
          paragraph_num += 1
      return paragraphs_objs

+ # Read DOCX file
  def read_docx(file) -> List[Paragraph]:
      doc = Document(file)
      paragraphs = []
@@ -48,14 +41,17 @@ def read_docx(file) -> List[Paragraph]:
          paragraphs.append(para)
      return paragraphs

+ # Count tokens
  def count_tokens(text, tokenizer):
      return len(tokenizer.encode(text))

+ # Batched processing
  def batched(iterable, n):
      l = len(iterable)
      for ndx in range(0, l, n):
          yield iterable[ndx : min(ndx + n, l)]

+ # Compute document embeddings
  def compute_doc_embeddings(df, tokenizer):
      embeddings = {}
      for index, row in tqdm(df.iterrows(), total=df.shape[0]):
@@ -64,6 +60,7 @@ def compute_doc_embeddings(df, tokenizer):
          embeddings[index] = doc_embedding
      return embeddings

+ # Enhanced context extraction
  def enhanced_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):
      paragraphs = [para for para in document.split("\n") if para]
      scores = [sum([para.lower().count(keyword) * tfidf_scores[vectorizer.vocabulary_[keyword]] for keyword in keywords if keyword in para.lower()]) for para in paragraphs]
@@ -73,6 +70,7 @@ def enhanced_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):

      return " ".join(relevant_paragraphs)

+ # Targeted context extraction
  def targeted_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):
      paragraphs = [para for para in document.split("\n") if para]
      scores = [sum([para.lower().count(keyword) * tfidf_scores[vectorizer.vocabulary_[keyword]] for keyword in keywords]) for para in paragraphs]
@@ -82,7 +80,7 @@ def targeted_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):

      return " ".join(relevant_paragraphs)

-
+ # Extract page and clause references
  def extract_page_and_clause_references(paragraph: str) -> str:
      page_matches = re.findall(r'Page (\d+)', paragraph)
      clause_matches = re.findall(r'Clause (\d+\.\d+)', paragraph)
@@ -92,6 +90,7 @@ def extract_page_and_clause_references(paragraph: str) -> str:

      return f"({page_ref}, {clause_ref})".strip(", ")

+ # Refine answer based on question
  def refine_answer_based_on_question(question: str, answer: str) -> str:
      if "Does the agreement contain" in question:
          if "not" in answer or "No" in answer:
@@ -103,7 +102,8 @@ def refine_answer_based_on_question(question: str, answer: str) -> str:

      return refined_answer

- def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, model, top_n_paragraphs: int = 5) -> str:
+ # Answer query with context using RAG
+ def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, retriever, generator, top_n_paragraphs: int = 5) -> str:
      question_words = set(question.split())

      priority_keywords = ["duration", "term", "period", "month", "year", "day", "week", "agreement", "obligation", "effective date"]
@@ -113,10 +113,13 @@ def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, model, top_n_paragraphs: int = 5) -> str:
      most_relevant_paragraphs = df.sort_values(by='relevance_score', ascending=False).iloc[:top_n_paragraphs]['content'].tolist()

      context = "\n\n".join(most_relevant_paragraphs)
-     prompt = f"Question: {question}\n\nContext: {context}\n\nAnswer:"

-     inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
-     outputs = model.generate(inputs, max_length=200)
+     # Retrieve documents relevant to the question
+     documents = retriever.retrieve(question)
+
+     # Generate answer using RAG
+     inputs = tokenizer(question, context, return_tensors="pt", max_length=512, truncation=True)
+     outputs = generator.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=200, num_return_sequences=1)
      answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

      references = extract_page_and_clause_references(context)
@@ -124,6 +127,7 @@ def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, model, top_n_paragraphs: int = 5) -> str:

      return answer

+ # Get embedding
  def get_embedding(text, tokenizer):
      try:
          inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
@@ -133,3 +137,9 @@ def get_embedding(text, tokenizer):
          print("Error obtaining embedding:", e)
          embedding = []
      return embedding
+
+ # Example usage
+ question = "What is the duration of the agreement?"
+ df = pd.DataFrame(...) # Assuming you have a DataFrame with content
+ answer = answer_query_with_context(question, df, rag_tokenizer, rag_retriever, rag_token_for_generation)
+ print("Answer:", answer)