H committed on
Commit 21583be · verified · 1 Parent(s): 61dc92d
Files changed (3)
  1. app.py +140 -0
  2. requirements.txt +23 -0
  3. search.py +143 -0
app.py ADDED
@@ -0,0 +1,140 @@
+ import time
+ import os
+
+ import pandas as pd
+ import streamlit as st
+ from dotenv import load_dotenv
+ from docx import Document
+ from reportlab.lib.pagesizes import letter
+ from reportlab.pdfgen import canvas
+
+ import search  # Local module: document parsing and question answering
+
+ load_dotenv()
+
+ st.set_page_config(
+     page_title="DocGPT GT",
+     page_icon="speech_balloon",
+     layout="wide",
+ )
+
+ hide_streamlit_style = """
+ <style>
+ #MainMenu {visibility: hidden;}
+ footer {visibility: hidden;}
+ footer:after {
+     content: '2023';
+     visibility: visible;
+     display: block;
+     position: relative;
+     padding: 5px;
+     top: 2px;
+ }
+ </style>
+ """
+ st.markdown(hide_streamlit_style, unsafe_allow_html=True)
+
+
+ def save_as_pdf(conversation):
+     pdf_filename = "conversation.pdf"
+     c = canvas.Canvas(pdf_filename, pagesize=letter)
+
+     c.drawString(100, 750, "Conversation:")
+     y_position = 730
+     for q, a in conversation:
+         c.drawString(120, y_position, f"Q: {q}")
+         c.drawString(120, y_position - 20, f"A: {a}")
+         y_position -= 40
+
+     c.save()
+     st.markdown(f"Download [PDF](./{pdf_filename})")
+
+
+ def save_as_docx(conversation):
+     doc = Document()
+     doc.add_heading("Conversation", 0)
+
+     for q, a in conversation:
+         doc.add_paragraph(f"Q: {q}")
+         doc.add_paragraph(f"A: {a}")
+
+     doc_filename = "conversation.docx"
+     doc.save(doc_filename)
+     st.markdown(f"Download [DOCX](./{doc_filename})")
+
+
+ def save_as_xlsx(conversation):
+     df = pd.DataFrame(conversation, columns=["Question", "Answer"])
+     xlsx_filename = "conversation.xlsx"
+     df.to_excel(xlsx_filename, index=False)  # Requires the openpyxl engine
+     st.markdown(f"Download [XLSX](./{xlsx_filename})")
+
+
+ def save_as_txt(conversation):
+     txt_filename = "conversation.txt"
+     with open(txt_filename, "w") as txt_file:
+         for q, a in conversation:
+             txt_file.write(f"Q: {q}\nA: {a}\n\n")
+     st.markdown(f"Download [TXT](./{txt_filename})")
+
+
+ def main():
+     st.markdown("<h1>Ask anything from Legal Texts</h1>", unsafe_allow_html=True)
+     st.markdown("<h2>Upload documents</h2>", unsafe_allow_html=True)
+     uploaded_files = st.file_uploader("Upload one or more documents", type=["pdf", "docx"], accept_multiple_files=True)
+     question = st.text_input("Ask a question based on the documents", key="question_input")
+
+     # Cosmetic progress bar shown on each rerun.
+     progress = st.progress(0)
+     for i in range(100):
+         progress.progress(i + 1)
+         time.sleep(0.01)
+
+     if uploaded_files:
+         # Parse each uploaded file into paragraph rows with token counts.
+         df = pd.DataFrame(columns=["page_num", "paragraph_num", "content", "tokens"])
+         for uploaded_file in uploaded_files:
+             if uploaded_file.type == "application/pdf":
+                 paragraphs = search.read_pdf_pdfminer(uploaded_file)
+             else:
+                 paragraphs = search.read_docx(uploaded_file)
+             temp_df = pd.DataFrame(
+                 [(p.page_num, p.paragraph_num, p.content, search.count_tokens(p.content)) for p in paragraphs],
+                 columns=["page_num", "paragraph_num", "content", "tokens"],
+             )
+             df = pd.concat([df, temp_df], ignore_index=True)
+
+         if "interactions" not in st.session_state:
+             st.session_state["interactions"] = []
+
+         # Only query when the question has changed since the previous rerun.
+         if question and question != st.session_state.get("last_question", ""):
+             st.text("Searching...")
+             answer = search.answer_query_with_context(question, df)
+             st.session_state["interactions"].append((question, answer))
+             st.write(answer)
+
+         st.markdown("### Interaction History")
+         for q, a in st.session_state["interactions"]:
+             st.write(f"**Q:** {q}\n\n**A:** {a}")
+
+         st.session_state["last_question"] = question
+
+         st.markdown("<h2>Sample paragraphs</h2>", unsafe_allow_html=True)
+         sample_size = min(len(df), 5)
+         st.dataframe(df.sample(n=sample_size))
+
+         if st.button("Save as PDF"):
+             save_as_pdf(st.session_state["interactions"])
+         if st.button("Save as DOCX"):
+             save_as_docx(st.session_state["interactions"])
+         if st.button("Save as XLSX"):
+             save_as_xlsx(st.session_state["interactions"])
+         if st.button("Save as TXT"):
+             save_as_txt(st.session_state["interactions"])
+     else:
+         st.markdown("<h2>Please upload a document to proceed.</h2>", unsafe_allow_html=True)
+
+
+ if __name__ == "__main__":
+     main()
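
Note: the save_as_* helpers write each file into the app's working directory and then emit a relative markdown link, which Streamlit does not serve as a static file, so the link will usually 404 in a deployed app. A minimal sketch of the same flow using st.download_button follows (the helper name, labels, and MIME strings are illustrative, not part of the commit):

import streamlit as st

def offer_download(path: str, label: str, mime: str) -> None:
    # Stream the file that save_as_* just wrote directly to the browser.
    with open(path, "rb") as f:
        st.download_button(label=label, data=f.read(), file_name=path, mime=mime)

# For example, right after save_as_pdf(...) has written conversation.pdf:
# offer_download("conversation.pdf", "Download PDF", "application/pdf")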
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ certifi==2021.5.30
+ charset-normalizer==2.0.6
+ idna==3.2
+ openai==0.27.0
+ openpyxl
+ pandas==2.0.3
+ pdfminer.six
+ Pillow==10.0.0
+ PyPDF2==1.26.0
+ python-docx
+ python-dotenv
+ regex==2023.6.3
+ reportlab
+ requests==2.26.0
+ sentencepiece==0.1.99
+ six==1.16.0
+ streamlit==1.25.0
+ tenacity==8.2.2
+ tiktoken==0.4.0
+ tqdm==4.65.0
+ transformers==4.31.0
+ urllib3==1.26.6
+ dataclasses; python_version < "3.7"
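
The openai==0.27.0 pin matters: search.py depends on the legacy module-level Completion and Embedding interfaces, which were removed in openai>=1.0. A quick sanity check for a fresh environment (a sketch, assuming a standard pip install):

from importlib.metadata import version
import openai

# The legacy 0.x SDK exposes these module-level resources; openai>=1.0 does not.
assert hasattr(openai, "Completion") and hasattr(openai, "Embedding")
print(version("openai"))  # Expected: 0.27.0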
search.py ADDED
@@ -0,0 +1,143 @@
+ import os
+ import re
+ from dataclasses import dataclass
+ from typing import List
+
+ import openai
+ import pandas as pd
+ from docx import Document
+ from pdfminer.high_level import extract_text
+ from tqdm import tqdm
+ from transformers import GPT2Tokenizer
+
+ EMBEDDING_SEG_LEN = 1500
+ COMPLETIONS_MODEL = "text-davinci-003"
+ EMBEDDING_MODEL = "text-embedding-ada-002"  # cl100k_base encoding, 8191-token context
+ EMBEDDING_CTX_LENGTH = 8191
+ EMBEDDING_ENCODING = "cl100k_base"
+ ENCODING = "gpt2"
+
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+
+ _tokenizer = None  # Lazily loaded GPT-2 tokenizer, shared across calls
+
+
+ @dataclass
+ class Paragraph:
+     page_num: int
+     paragraph_num: int
+     content: str
+
+
+ def read_pdf_pdfminer(file_path) -> List[Paragraph]:
+     # pdfminer returns a single string, so split it into fixed-length segments.
+     text = extract_text(file_path).replace("\n", " ").strip()
+     paragraphs_objs = []
+     for paragraph_num, p in enumerate(batched(text, EMBEDDING_SEG_LEN), start=1):
+         paragraphs_objs.append(Paragraph(0, paragraph_num, p))
+     return paragraphs_objs
+
+
+ def read_docx(file) -> List[Paragraph]:
+     doc = Document(file)
+     paragraphs = []
+     for paragraph_num, paragraph in enumerate(doc.paragraphs, start=1):
+         content = paragraph.text.strip()
+         if content:
+             paragraphs.append(Paragraph(1, paragraph_num, content))
+     return paragraphs
+
+
+ def count_tokens(text):
+     # Load the tokenizer once instead of on every call.
+     global _tokenizer
+     if _tokenizer is None:
+         _tokenizer = GPT2Tokenizer.from_pretrained(ENCODING)
+     return len(_tokenizer.encode(text))
+
+
+ def batched(iterable, n):
+     l = len(iterable)
+     for ndx in range(0, l, n):
+         yield iterable[ndx : min(ndx + n, l)]
+
+
+ def compute_doc_embeddings(df):
+     embeddings = {}
+     for index, row in tqdm(df.iterrows(), total=df.shape[0]):
+         embeddings[index] = get_embedding(row["content"])
+     return embeddings
+
+
+ def enhanced_context_extraction(document, keywords, top_n=5):
+     paragraphs = [para for para in document.split("\n") if para]
+
+     def score_paragraph(para, keywords):
+         keyword_count = sum(para.lower().count(keyword) for keyword in keywords)
+         positions = [para.lower().find(keyword) for keyword in keywords if keyword in para.lower()]
+         proximity_score = 1 if positions else 0  # Bonus when any keyword is present
+         return keyword_count + proximity_score
+
+     scores = [score_paragraph(para, keywords) for para in paragraphs]
+     top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
+     return " ".join(paragraphs[i] for i in top_indices)
+
+
+ def targeted_context_extraction(document, keywords, top_n=5):
+     paragraphs = [para for para in document.split("\n") if para]
+     scores = [sum(para.lower().count(keyword) for keyword in keywords) for para in paragraphs]
+     top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
+     return " ".join(paragraphs[i] for i in top_indices)
+
+
+ def extract_page_and_clause_references(paragraph: str) -> str:
+     page_matches = re.findall(r"Page (\d+)", paragraph)
+     clause_matches = re.findall(r"Clause (\d+\.\d+)", paragraph)
+
+     page_ref = f"Page {page_matches[0]}" if page_matches else ""
+     clause_ref = f"Clause {clause_matches[0]}" if clause_matches else ""
+
+     refs = ", ".join(ref for ref in (page_ref, clause_ref) if ref)
+     return f"({refs})" if refs else ""
+
+
+ def refine_answer_based_on_question(question: str, answer: str) -> str:
+     # Heuristic rephrasing for yes/no questions about the agreement.
+     if "Does the agreement contain" in question:
+         if "not" in answer or "No" in answer:
+             return f"No, the agreement does not contain {answer}"
+         return f"Yes, the agreement contains {answer}"
+     return answer
+
+
+ def answer_query_with_context(question: str, df: pd.DataFrame, top_n_paragraphs: int = 5) -> str:
+     question_words = set(question.split())
+
+     # Prioritize certain keywords for better context extraction.
+     priority_keywords = ["duration", "term", "period", "month", "year", "day", "week", "agreement", "obligation", "effective date"]
+
+     df["relevance_score"] = df["content"].apply(
+         lambda x: len(question_words.intersection(set(x.split()))) + sum(x.lower().count(pk) for pk in priority_keywords)
+     )
+     most_relevant_paragraphs = df.sort_values(by="relevance_score", ascending=False).iloc[:top_n_paragraphs]["content"].tolist()
+
+     context = "\n\n".join(most_relevant_paragraphs)
+     prompt = f"Question: {question}\n\nContext: {context}\n\nAnswer:"
+     response = openai.Completion.create(model=COMPLETIONS_MODEL, prompt=prompt, max_tokens=150)
+     answer = response.choices[0].text.strip()
+
+     # Add page/clause references and match the phrasing of the question.
+     references = extract_page_and_clause_references(context)
+     answer = refine_answer_based_on_question(question, answer) + " " + references
+     return answer
+
+
+ def get_embedding(text):
+     try:
+         # Legacy openai<1.0 embeddings endpoint (openai==0.27.0).
+         response = openai.Embedding.create(model=EMBEDDING_MODEL, input=text)
+         embedding = response["data"][0]["embedding"]
+     except Exception as e:
+         print("Error obtaining embedding:", e)
+         embedding = []
+     return embedding
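
A quick end-to-end exercise of the retrieval-plus-completion path above (a sketch: it assumes OPENAI_API_KEY is set in the environment, and the rows and token counts are made up for illustration):

import pandas as pd
import search

df = pd.DataFrame([
    {"page_num": 1, "paragraph_num": 1, "tokens": 16,
     "content": "Clause 2.1: The term of this agreement is twelve months from the effective date."},
    {"page_num": 2, "paragraph_num": 2, "tokens": 13,
     "content": "Clause 3.4: Either party may terminate upon thirty days written notice."},
])

print(search.answer_query_with_context("What is the duration of the agreement?", df))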