from dotenv import load_dotenv
import os
import uuid
from PyPDF2 import PdfReader
from docx import Document
from docx.text.paragraph import Paragraph
from docx.table import Table
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OpenAIEmbeddings
import streamlit as st
from textwrap import dedent
from Prompts_and_Chains import LLMChains


def extract_text_from_file(file):
    """Read an uploaded plain-text file and return its contents as a string."""
    return file.read().decode("utf-8")


def process_paragraph(paragraph):
    # Process the content of the paragraph as needed
    return paragraph.text


def process_table(table):
    # Flatten a docx table into a single string, separating cells with spaces
    # so adjacent cell contents do not run together
    text = ""
    for row in table.rows:
        for cell in row.cells:
            text += cell.text + " "
    return text.strip()


def read_docx(file_path):
    doc = Document(file_path)
    data = []
    for element in doc.iter_inner_content():
        if isinstance(element, Paragraph):
            data.append(process_paragraph(element))
        elif isinstance(element, Table):
            data.append(process_table(element))
    return "\n".join(data)


def get_pdf_text(pdf):
    """This function extracts the text from the PDF file"""
    text = []
    pdf_reader = PdfReader(pdf)
    for page in pdf_reader.pages:
        text.append(page.extract_text())
    return "\n".join(text)
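
# Each extractor above returns the document's text as a single string, e.g.
# (hypothetical file names, shown only for illustration):
#   docx_text = read_docx("contract.docx")
#   pdf_text = get_pdf_text("brief.pdf")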


class RFPProcessor:
    def __init__(self):
        load_dotenv()
        self.openai_api_key = os.getenv("OPENAI_API_KEY")
        self.chains_obj = LLMChains()

    def process_case_data(self, case_name, files):
        if case_name and files:
            # Generate a unique identifier for this case's data set
            case_id = str(uuid.uuid4())
            extracted_data = []
            all_texts = []
            text_splitter = CharacterTextSplitter(
                separator="\n", chunk_size=1000, chunk_overlap=150, length_function=len
            )
            for file in files:
                if file.name.endswith(".docx"):
                    file_text = read_docx(file)
                elif file.name.endswith(".pdf"):
                    file_text = get_pdf_text(file)
                else:
                    file_text = extract_text_from_file(file)
                # Chunk each document for embedding; keep the full text for the summary
                all_texts.extend(text_splitter.split_text(file_text))
                extracted_data.append(file_text)

            project_dir = os.path.dirname(os.path.abspath(__file__))
            vectorstore = Chroma(
                persist_directory=os.path.join(project_dir, "vector_stores", case_name),
                embedding_function=OpenAIEmbeddings(openai_api_key=self.openai_api_key),
            )
            vectorstore.add_texts(all_texts)

            st.session_state[case_id] = {
                "vectorstore": vectorstore,
                "extracted_data": extracted_data,
            }
            all_text = " ".join(extracted_data)
            st.session_state["case_summary"] = self.chains_obj.case_summary_chain.run(
                {
                    "case_name": case_name,
                    "case_info": dedent(all_text),
                }
            )
            st.session_state["is_data_processed"] = True
            st.session_state["case_name"] = case_name
            st.session_state["case_details"] = dedent(all_text)
            st.session_state["current_case_id"] = case_id
            st.success("Data processed successfully")

    def genrate_legal_bot_result(self):
        if st.session_state["bot_input"]:
            case_id = st.session_state.get("current_case_id")
            if case_id:
                vector_store = st.session_state[case_id]["vectorstore"]
                query = st.session_state["bot_input"]
                results = vector_store.similarity_search(query, k=3)
                # Join the retrieved chunks into a single context string
                source_knowledge = "\n".join([x.page_content for x in results])
                inputs = {
                    "case_summary": st.session_state["case_summary"],
                    "context": source_knowledge,
                    "input": query,
                }
                output = self.chains_obj.legal_case_bot_chain.run(inputs)
                st.session_state.past_bot_results.append(st.session_state["bot_input"])
                st.session_state.generated_bot_results.append(output)
                st.session_state["bot_input"] = ""
            else:
                st.warning("No vector store found for the current case ID")
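
# Example wiring (a minimal sketch, assuming a Streamlit page; the widget labels
# and keys below are illustrative and not part of the original app):
#
#   processor = RFPProcessor()
#   case_name = st.text_input("Case name")
#   uploaded_files = st.file_uploader("Upload case documents", accept_multiple_files=True)
#   if st.button("Process case"):
#       processor.process_case_data(case_name, uploaded_files)
#   st.text_input("Ask about the case", key="bot_input",
#                 on_change=processor.genrate_legal_bot_result)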