albhu committed on
Commit d0eefa8 · verified · 1 Parent(s): e22de9e

Update search.py

Files changed (1)
  1. search.py +20 -35
search.py CHANGED
@@ -1,32 +1,32 @@
-from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, RagConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from docx import Document
 from pdfminer.high_level import extract_text
+from transformers import GPT2Tokenizer
+from dataclasses import dataclass
 from typing import List
+from tqdm import tqdm
+import os
 import pandas as pd
 import re
-from datasets import load_dataset
+from sklearn.feature_extraction.text import TfidfVectorizer
+import numpy as np
 
-# Initialize RAG components
-rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
-rag_token_for_generation = RagTokenForGeneration.from_pretrained("facebook/rag-token-base")
-rag_config = RagConfig.from_pretrained("facebook/rag-token-base")
+tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained("impira/layoutlm-document-qa", trust_remote_code=True)
 
-# Download and prepare the wiki_dpr dataset
-dpr_dataset = load_dataset("wiki_dpr")
-passages = dpr_dataset["train"]["passage"]
-titles = dpr_dataset["train"]["title"]
+EMBEDDING_SEG_LEN = 1500
+EMBEDDING_MODEL = "gpt-4"
 
-# Initialize the RagRetriever
-rag_retriever = RagRetriever(passages=passages, titles=titles, config=rag_config)
+EMBEDDING_CTX_LENGTH = 8191
+EMBEDDING_ENCODING = "cl100k_base"
+ENCODING = "gpt2"
 
-# Dataclass for paragraph
 @dataclass
 class Paragraph:
     page_num: int
     paragraph_num: int
     content: str
 
-# Read PDF using pdfminer
 def read_pdf_pdfminer(file_path) -> List[Paragraph]:
     text = extract_text(file_path).replace('\n', ' ').strip()
     paragraphs = batched(text, EMBEDDING_SEG_LEN)
@@ -38,7 +38,6 @@ def read_pdf_pdfminer(file_path) -> List[Paragraph]:
         paragraph_num += 1
     return paragraphs_objs
 
-# Read DOCX file
 def read_docx(file) -> List[Paragraph]:
     doc = Document(file)
     paragraphs = []
@@ -49,17 +48,14 @@ def read_docx(file) -> List[Paragraph]:
         paragraphs.append(para)
     return paragraphs
 
-# Count tokens
 def count_tokens(text, tokenizer):
     return len(tokenizer.encode(text))
 
-# Batched processing
 def batched(iterable, n):
     l = len(iterable)
     for ndx in range(0, l, n):
         yield iterable[ndx : min(ndx + n, l)]
 
-# Compute document embeddings
 def compute_doc_embeddings(df, tokenizer):
     embeddings = {}
     for index, row in tqdm(df.iterrows(), total=df.shape[0]):
@@ -68,7 +64,6 @@ def compute_doc_embeddings(df, tokenizer):
         embeddings[index] = doc_embedding
     return embeddings
 
-# Enhanced context extraction
 def enhanced_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):
     paragraphs = [para for para in document.split("\n") if para]
     scores = [sum([para.lower().count(keyword) * tfidf_scores[vectorizer.vocabulary_[keyword]] for keyword in keywords if keyword in para.lower()]) for para in paragraphs]
@@ -78,7 +73,6 @@ def enhanced_context_extraction(document, keywords, vectorizer, tfidf_scores, to
 
     return " ".join(relevant_paragraphs)
 
-# Targeted context extraction
 def targeted_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):
     paragraphs = [para for para in document.split("\n") if para]
     scores = [sum([para.lower().count(keyword) * tfidf_scores[vectorizer.vocabulary_[keyword]] for keyword in keywords]) for para in paragraphs]
@@ -88,7 +82,7 @@ def targeted_context_extraction(document, keywords, vectorizer, tfidf_scores, to
 
     return " ".join(relevant_paragraphs)
 
-# Extract page and clause references
+
 def extract_page_and_clause_references(paragraph: str) -> str:
     page_matches = re.findall(r'Page (\d+)', paragraph)
     clause_matches = re.findall(r'Clause (\d+\.\d+)', paragraph)
@@ -98,7 +92,6 @@ def extract_page_and_clause_references(paragraph: str) -> str:
 
     return f"({page_ref}, {clause_ref})".strip(", ")
 
-# Refine answer based on question
 def refine_answer_based_on_question(question: str, answer: str) -> str:
     if "Does the agreement contain" in question:
         if "not" in answer or "No" in answer:
@@ -110,8 +103,7 @@ def refine_answer_based_on_question(question: str, answer: str) -> str:
 
     return refined_answer
 
-# Answer query with context using RAG
-def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, retriever, generator, top_n_paragraphs: int = 5) -> str:
+def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, model, top_n_paragraphs: int = 5) -> str:
     question_words = set(question.split())
 
     priority_keywords = ["duration", "term", "period", "month", "year", "day", "week", "agreement", "obligation", "effective date"]
@@ -121,18 +113,17 @@ def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, retrie
     most_relevant_paragraphs = df.sort_values(by='relevance_score', ascending=False).iloc[:top_n_paragraphs]['content'].tolist()
 
     context = "\n\n".join(most_relevant_paragraphs)
+    prompt = f"Question: {question}\n\nContext: {context}\n\nAnswer:"
 
-    # Generate answer using RAG
-    inputs = rag_tokenizer(question, context, return_tensors="pt", max_length=512, truncation=True)
-    outputs = rag_token_for_generation.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=200, num_return_sequences=1)
-    answer = rag_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
+    outputs = model.generate(inputs, max_length=200)
+    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
     references = extract_page_and_clause_references(context)
     answer = refine_answer_based_on_question(question, answer) + " " + references
 
     return answer
 
-# Get embedding
 def get_embedding(text, tokenizer):
     try:
         inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
@@ -142,9 +133,3 @@ def get_embedding(text, tokenizer):
         print("Error obtaining embedding:", e)
         embedding = []
     return embedding
-
-# Example usage
-question = "What is the duration of the agreement?"
-df = pd.DataFrame(...) # Assuming you have a DataFrame with content
-answer = answer_query_with_context(question, df, rag_tokenizer, rag_retriever, rag_token_for_generation)
-print("Answer:", answer)
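
For reference, a minimal usage sketch for the updated answer_query_with_context follows; it is not part of the commit. It assumes the module-level tokenizer and model loaded at the top of the new search.py, plus a DataFrame with one paragraph per row in a 'content' column. A 'relevance_score' column is pre-filled here as an assumption, in case the elided part of the function expects it rather than deriving it from the priority keywords.

# Illustrative only: toy paragraphs standing in for parsed contract text.
import pandas as pd

df = pd.DataFrame({
    "content": [
        "Clause 2.1 The agreement shall remain in force for a period of 24 months from the effective date.",
        "Clause 7.4 Either party may terminate this agreement with 30 days' written notice.",
    ],
    # Assumption: supplied up front in case answer_query_with_context does not compute it itself.
    "relevance_score": [0.9, 0.4],
})

question = "What is the duration of the agreement?"
answer = answer_query_with_context(question, df, tokenizer, model)
print("Answer:", answer)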