albhu committed · verified · Commit a6fe316 · 1 Parent(s): 9d48d5a

Update search.py

Files changed (1)
  1. search.py +77 -110
search.py CHANGED
@@ -1,135 +1,102 @@
- from transformers import AutoTokenizer, GPT2LMHeadModel
  from docx import Document
  from pdfminer.high_level import extract_text
- from transformers import GPT2Tokenizer
  from dataclasses import dataclass
- from typing import List
- from tqdm import tqdm
- import os
- import pandas as pd
- import re
- from sklearn.feature_extraction.text import TfidfVectorizer
- import numpy as np

- tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", trust_remote_code=True)
- model = GPT2LMHeadModel.from_pretrained("impira/layoutlm-document-qa", trust_remote_code=True)
-
- EMBEDDING_SEG_LEN = 1500
- EMBEDDING_MODEL = "gpt-4"
-
- EMBEDDING_CTX_LENGTH = 8191
- EMBEDDING_ENCODING = "cl100k_base"
- ENCODING = "gpt2"

  @dataclass
  class Paragraph:
      page_num: int
      paragraph_num: int
      content: str

- def read_pdf_pdfminer(file_path) -> List[Paragraph]:
      text = extract_text(file_path).replace('\n', ' ').strip()
-     paragraphs = batched(text, EMBEDDING_SEG_LEN)
-     paragraphs_objs = []
-     paragraph_num = 1
-     for p in paragraphs:
-         para = Paragraph(0, paragraph_num, p)
-         paragraphs_objs.append(para)
-         paragraph_num += 1
-     return paragraphs_objs

- def read_docx(file) -> List[Paragraph]:
-     doc = Document(file)
-     paragraphs = []
-     for paragraph_num, paragraph in enumerate(doc.paragraphs, start=1):
-         content = paragraph.text.strip()
-         if content:
-             para = Paragraph(1, paragraph_num, content)
-             paragraphs.append(para)
      return paragraphs

- def count_tokens(text, tokenizer):
-     return len(tokenizer.encode(text))

  def batched(iterable, n):
      l = len(iterable)
      for ndx in range(0, l, n):
-         yield iterable[ndx : min(ndx + n, l)]
-
- def compute_doc_embeddings(df, tokenizer):
-     embeddings = {}
-     for index, row in tqdm(df.iterrows(), total=df.shape[0]):
-         doc = row["content"]
-         doc_embedding = get_embedding(doc, tokenizer)
-         embeddings[index] = doc_embedding
-     return embeddings
-
- def enhanced_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):
-     paragraphs = [para for para in document.split("\n") if para]
-     scores = [sum([para.lower().count(keyword) * tfidf_scores[vectorizer.vocabulary_[keyword]] for keyword in keywords if keyword in para.lower()]) for para in paragraphs]
-
-     top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
-     relevant_paragraphs = [paragraphs[i] for i in top_indices]
-
-     return " ".join(relevant_paragraphs)
-
- def targeted_context_extraction(document, keywords, vectorizer, tfidf_scores, top_n=5):
-     paragraphs = [para for para in document.split("\n") if para]
-     scores = [sum([para.lower().count(keyword) * tfidf_scores[vectorizer.vocabulary_[keyword]] for keyword in keywords]) for para in paragraphs]
-
-     top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
-     relevant_paragraphs = [paragraphs[i] for i in top_indices]
-
-     return " ".join(relevant_paragraphs)

-
- def extract_page_and_clause_references(paragraph: str) -> str:
-     page_matches = re.findall(r'Page (\d+)', paragraph)
-     clause_matches = re.findall(r'Clause (\d+\.\d+)', paragraph)
-
-     page_ref = f"Page {page_matches[0]}" if page_matches else ""
-     clause_ref = f"Clause {clause_matches[0]}" if clause_matches else ""
-
-     return f"({page_ref}, {clause_ref})".strip(", ")
-
- def refine_answer_based_on_question(question: str, answer: str) -> str:
-     if "Does the agreement contain" in question:
-         if "not" in answer or "No" in answer:
-             refined_answer = f"No, the agreement does not contain {answer}"
-         else:
-             refined_answer = f"Yes, the agreement contains {answer}"
-     else:
-         refined_answer = answer
-
-     return refined_answer
-
- def answer_query_with_context(question: str, df: pd.DataFrame, tokenizer, model, top_n_paragraphs: int = 5) -> str:
-     question_words = set(question.split())
-
-     priority_keywords = ["duration", "term", "period", "month", "year", "day", "week", "agreement", "obligation", "effective date"]
-
-     df['relevance_score'] = df['content'].apply(lambda x: len(question_words.intersection(set(x.split()))) + sum([x.lower().count(pk) for pk in priority_keywords]))
-
-     most_relevant_paragraphs = df.sort_values(by='relevance_score', ascending=False).iloc[:top_n_paragraphs]['content'].tolist()
-
-     context = "\n\n".join(most_relevant_paragraphs)
-     prompt = f"Question: {question}\n\nContext: {context}\n\nAnswer:"
-
-     inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
-     outputs = model.generate(inputs, max_length=200)
-     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     references = extract_page_and_clause_references(context)
-     answer = refine_answer_based_on_question(question, answer) + " " + references
-
-     return answer
-
- def get_embedding(text, tokenizer):
      try:
-         inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
          outputs = model(**inputs)
          embedding = outputs.last_hidden_state
      except Exception as e:
          print("Error obtaining embedding:", e)
-         embedding = []
-     return embedding

+ from transformers import AutoTokenizer, AutoModelForCausalLM
  from docx import Document
  from pdfminer.high_level import extract_text
+ from typing import List, Union
  from dataclasses import dataclass

+ # Initialize the tokenizer and model
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)

+ # Define the Paragraph data class
  @dataclass
  class Paragraph:
      page_num: int
      paragraph_num: int
      content: str
+     embedding: Union[list, None] = None

+ # Function to read text from a PDF file
+ def read_pdf(file_path: str) -> List[Paragraph]:
      text = extract_text(file_path).replace('\n', ' ').strip()
+     return create_paragraphs(text)

+ # Function to read text from a DOCX file
+ def read_docx(file_path: str) -> List[Paragraph]:
+     doc = Document(file_path)
+     paragraphs = [Paragraph(1, idx + 1, para.text.strip()) for idx, para in enumerate(doc.paragraphs) if para.text.strip()]
      return paragraphs

+ # Helper function to split text into paragraphs
+ def create_paragraphs(text: str, max_length: int = 1500) -> List[Paragraph]:
+     paragraphs = []
+     paragraph_num = 1
+     for chunk in batched(text, max_length):
+         para = Paragraph(0, paragraph_num, chunk)
+         paragraphs.append(para)
+         paragraph_num += 1
+     return paragraphs

+ # Helper function to batch an iterable
  def batched(iterable, n):
      l = len(iterable)
      for ndx in range(0, l, n):
+         yield iterable[ndx: min(ndx + n, l)]

+ # Function to obtain embeddings for a given text
+ def get_embedding(text: str, tokenizer, max_length: int = 512) -> Union[list, None]:
      try:
+         inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
          outputs = model(**inputs)
          embedding = outputs.last_hidden_state
+         return embedding
      except Exception as e:
          print("Error obtaining embedding:", e)
+         return None
+
+ # Function to process a single paragraph and obtain its embedding
+ def process_paragraph(paragraph: Paragraph) -> Union[list, None]:
+     try:
+         embedding = get_embedding(paragraph.content, tokenizer)
+         return embedding
+     except Exception as e:
+         print(f"Error processing paragraph {paragraph.paragraph_num}: {e}")
+         return None
+
+ # Main function to process a document and obtain embeddings for each paragraph
+ def process_document(file_path: str, file_type: str = None) -> List[Paragraph]:
+     supported_types = ['pdf', 'docx']
+     if file_type not in supported_types:
+         print(f"Unsupported file type. Please provide one of the following supported types: {', '.join(supported_types)}")
+         return []
+
+     if file_type == 'pdf':
+         paragraphs = read_pdf(file_path)
+     elif file_type == 'docx':
+         paragraphs = read_docx(file_path)
+
+     if not paragraphs:
+         print("No paragraphs found in the document.")
+         return []
+
+     # Process each paragraph and obtain embeddings
+     for idx, paragraph in enumerate(paragraphs):
+         print(f"Processing paragraph {idx + 1}...")
+         embedding = process_paragraph(paragraph)
+         if embedding:
+             paragraph.embedding = embedding
+         else:
+             print(f"Embedding for paragraph {idx + 1} could not be obtained.")
+     return paragraphs
+
+ # Example usage
+ if __name__ == "__main__":
+     file_path = "example.pdf"
+     file_type = file_path.split(".")[-1]
+     paragraphs = process_document(file_path, file_type)
+     for para in paragraphs:
+         print(para.content)
+         if hasattr(para, 'embedding') and para.embedding is not None:
+             print("Embedding:", para.embedding)
+         else:
+             print("Embedding could not be obtained.")
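After this change, get_embedding returns the model's last_hidden_state for a whole paragraph, i.e. one vector per token rather than one vector per paragraph, and microsoft/phi-2 is loaded as a causal LM, so hidden states have to be requested explicitly. The sketch below, which is not part of the commit, shows one way the new pipeline could be used for retrieval. The names embed and rank_paragraphs are hypothetical helpers introduced here, and the approach assumed is mean pooling of the final hidden layer plus cosine similarity, with output_hidden_states=True passed on the forward call.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Assumption: same checkpoint as in the commit; any causal LM whose hidden states
# can be requested via output_hidden_states=True would work the same way.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)

def embed(text: str, max_length: int = 512) -> torch.Tensor:
    # Hypothetical variant of get_embedding: request hidden states explicitly and
    # mean-pool the last layer into a single vector per text.
    inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]        # shape: (1, seq_len, hidden_dim)
    return last_hidden.mean(dim=1).squeeze(0)      # shape: (hidden_dim,)

def rank_paragraphs(query: str, paragraphs, top_n: int = 5):
    # Hypothetical helper: score each Paragraph by cosine similarity to the query.
    query_vec = embed(query)
    scored = []
    for para in paragraphs:
        para_vec = embed(para.content)
        score = torch.nn.functional.cosine_similarity(query_vec, para_vec, dim=0).item()
        scored.append((score, para))
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [para for _, para in scored[:top_n]]

# Usage sketch with the functions defined in search.py ("example.pdf" is a placeholder):
# paragraphs = process_document("example.pdf", "pdf")
# for para in rank_paragraphs("What is the term of the agreement?", paragraphs):
#     print(para.page_num, para.paragraph_num, para.content[:80])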