import os
import uuid
from textwrap import dedent

import streamlit as st
from dotenv import load_dotenv
from docx import Document
from docx.table import Table
from docx.text.paragraph import Paragraph
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from PyPDF2 import PdfReader

from Prompts_and_Chains import LLMChains
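# NOTE: LLMChains comes from the local Prompts_and_Chains module; it is assumed
# to expose `case_summary_chain` and `legal_case_bot_chain`, the two chains
# invoked below.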

def extract_text_from_file(file):
    """Read an uploaded plain-text file and decode it as UTF-8."""
    text = file.read().decode("utf-8")
    return text

def process_paragraph(paragraph):
    """Return the plain text of a docx paragraph."""
    return paragraph.text

def process_table(table):
    """Flatten a docx table by joining the text of every cell.

    Cells are separated by a single space so adjacent values do not run
    together in the extracted text.
    """
    text = ""
    for row in table.rows:
        for cell in row.cells:
            text += cell.text + " "
    return text.strip()

def read_docx(file_path):
    """Extract text from a .docx file, keeping paragraphs and tables in
    document order.

    Note: ``Document.iter_inner_content()`` requires a recent python-docx
    release (1.1+).
    """
    doc = Document(file_path)
    data = []
    for element in doc.iter_inner_content():
        if isinstance(element, Paragraph):
            data.append(process_paragraph(element))
        elif isinstance(element, Table):
            data.append(process_table(element))
    return "\n".join(data)

def get_pdf_text(pdf):
    """Extract the text from every page of a PDF file."""
    text = []
    pdf_reader = PdfReader(pdf)
    for page in pdf_reader.pages:
        # extract_text() can return None for pages without extractable text.
        text.append(page.extract_text() or "")
    return "\n".join(text)

class RFPProcessor:
    """Processes case documents and answers questions against a per-case
    Chroma vector store."""

    def __init__(self):
        load_dotenv()
        self.openai_api_key = os.getenv("OPENAI_API_KEY")
        self.chains_obj = LLMChains()

    def process_case_data(self, case_name, files):
        """Extract text from the uploaded files, index it in a Chroma vector
        store, and cache the results in the Streamlit session state."""
        if case_name and files:
            case_id = str(uuid.uuid4())
            extracted_data = []
            all_texts = []
            text_splitter = CharacterTextSplitter(
                separator="\n", chunk_size=1000, chunk_overlap=150, length_function=len
            )

            for file in files:
                if file.name.endswith(".docx"):
                    file_text = read_docx(file)
                elif file.name.endswith(".pdf"):
                    file_text = get_pdf_text(file)
                else:
                    file_text = extract_text_from_file(file)

                # Each extractor returns a single string; split it into
                # overlapping chunks for the vector store.
                all_texts.extend(text_splitter.split_text(file_text))
                extracted_data.append(file_text)

            # Persist one vector store per case, named after the case.
            project_dir = os.path.dirname(os.path.abspath(__file__))
            vectorstore = Chroma(
                persist_directory=os.path.join(project_dir, "vector_stores", case_name),
                embedding_function=OpenAIEmbeddings(openai_api_key=self.openai_api_key),
            )
            vectorstore.add_texts(all_texts)

            st.session_state[case_id] = {
                "vectorstore": vectorstore,
                "extracted_data": extracted_data,
            }
            all_text = " ".join(extracted_data)
            st.session_state["case_summary"] = self.chains_obj.case_summary_chain.run(
                {
                    "case_name": case_name,
                    "case_info": dedent(all_text),
                }
            )
            st.session_state["is_data_processed"] = True
            st.session_state["case_name"] = case_name
            st.session_state["case_details"] = dedent(all_text)
            st.session_state["current_case_id"] = case_id
            st.success("Data processed successfully")

    def genrate_legal_bot_result(self):
        """Answer the current bot query using retrieval over the active
        case's vector store."""
        if len(st.session_state["bot_input"]) > 0:
            case_id = st.session_state.get("current_case_id")
            if case_id:
                vector_store = st.session_state[case_id]["vectorstore"]
                query = st.session_state["bot_input"]
                # Retrieve the three most similar chunks as grounding context.
                results = vector_store.similarity_search(query, k=3)
                source_knowledge = "\n".join([x.page_content for x in results])

                inputs = {
                    "case_summary": st.session_state["case_summary"],
                    "context": source_knowledge,
                    "input": query,
                }
                output = self.chains_obj.legal_case_bot_chain.run(inputs)
                st.session_state.past_bot_results.append(st.session_state["bot_input"])
                st.session_state.generated_bot_results.append(output)
                st.session_state["bot_input"] = ""
            else:
                st.warning("No vector store found for the current case ID")
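
# --- Usage sketch (assumption) ----------------------------------------------
# The Streamlit page that drives RFPProcessor is not part of this file; the
# commented snippet below is only a minimal sketch of how it might be wired
# up. The widget labels, the module name `rfp_processor`, and the explicit
# initialization of `past_bot_results` / `generated_bot_results` are
# illustrative assumptions, not part of the original code.
#
#     import streamlit as st
#     from rfp_processor import RFPProcessor  # hypothetical module name
#
#     processor = RFPProcessor()
#     st.session_state.setdefault("past_bot_results", [])
#     st.session_state.setdefault("generated_bot_results", [])
#
#     case_name = st.text_input("Case name")
#     files = st.file_uploader("Upload case documents", accept_multiple_files=True)
#     if st.button("Process case"):
#         processor.process_case_data(case_name, files)
#
#     st.text_input(
#         "Ask a question about the case",
#         key="bot_input",
#         on_change=processor.genrate_legal_bot_result,
#     )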