from dotenv import load_dotenv
import os
import streamlit as st
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import AzureOpenAI
from langchain.callbacks import get_openai_callback
from langchain.chat_models import ChatOpenAI
from docx import Document
from openpyxl import load_workbook
import pdfplumber
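# Streamlit app: upload a PDF, Word (.docx), or Excel (.xlsx) document, split its
# text into chunks, index the chunks in a FAISS vector store with OpenAI
# embeddings, and answer questions about the content via a LangChain QA chain.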


def extract_text_from_pdf(pdf_file):
    with pdfplumber.open(pdf_file) as pdf:
        text = ""
        for page in pdf.pages:
            # extract_text() returns None for pages with no extractable text
            text += page.extract_text() or ""
    return text


def extract_text_from_docx(docx_file):
    doc = Document(docx_file)
    paragraphs = [paragraph.text for paragraph in doc.paragraphs]
    return "\n".join(paragraphs)


def extract_text_from_excel(excel_file):
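    # Collect every non-empty cell value across all worksheets, one value per line.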
    workbook = load_workbook(excel_file)
    text = ""
    for sheet in workbook.sheetnames:
        worksheet = workbook[sheet]
        for row in worksheet.iter_rows():
            for cell in row:
                if cell.value:
                    text += str(cell.value) + "\n"
    return text


def split_text_into_chunks(text, chunk_size=1000, chunk_overlap=200):
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len
    )
    return text_splitter.split_text(text)


def create_knowledge_base(chunks, api_key=None):
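    # Embed each chunk with OpenAI embeddings and index the vectors in an
    # in-memory FAISS store for similarity search.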
    embeddings = OpenAIEmbeddings(openai_api_key=api_key)
    knowledge_base = FAISS.from_texts(chunks, embeddings)
    return knowledge_base


def answer_question(question, knowledge_base, model):
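    # Retrieve the chunks most similar to the question from the FAISS index,
    # then run a "stuff" QA chain over them with the selected LLM.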
    docs = knowledge_base.similarity_search(question)
    llm = model(model_name="gpt-3.5-turbo", openai_api_key=st.session_state.api_key)
    chain = load_qa_chain(llm, chain_type="stuff")
    with get_openai_callback() as cb:
        response = chain.run(input_documents=docs, question=question)
    return response


def save_api_key(api_key):
    st.session_state.api_key = api_key


def main():
    load_dotenv()
    st.set_page_config(page_title="Ask Your Document", layout="wide")

    # Sidebar
    st.sidebar.title("Settings")

    # API Key input
    st.sidebar.subheader("API Key")
    api_key = st.sidebar.text_input("Insert your API Key", type="password")
    st.sidebar.button("Save API Key", on_click=save_api_key, args=(api_key,))

    model_type = st.sidebar.selectbox("Select Language Model", ["OpenAI", "AzureOpenAI"])
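    # Note: the AzureOpenAI backend also expects Azure endpoint/deployment settings
    # (typically supplied via environment variables) in addition to the API key above.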
    if model_type == "AzureOpenAI":
        model = AzureOpenAI
    else:
        model = ChatOpenAI

    chunk_size = st.sidebar.slider("Chunk Size", min_value=500, max_value=2000, value=1000, step=100)
    chunk_overlap = st.sidebar.slider("Chunk Overlap", min_value=100, max_value=500, value=200, step=50)
    show_content = st.sidebar.checkbox("Show Document Content")
    show_answers = st.sidebar.checkbox("Show Previous Answers")

    # Main content
    st.title("Ask Your Document 💭")
    file_format = st.selectbox("Select File Format", ["PDF", "docx", "xlsx"])
    document = st.file_uploader("Upload Document", type=[file_format.lower()])

    if not st.session_state.get("api_key"):
        st.warning("You need to insert your API Key first.")
    elif document is not None:
        if file_format == "PDF":
            text = extract_text_from_pdf(document)
        elif file_format == "docx":
            text = extract_text_from_docx(document)
        elif file_format == "xlsx":
            text = extract_text_from_excel(document)
        else:
            text = ""

        if show_content:
            st.subheader("Document Text:")
            st.text_area("Content", value=text, height=300)

        chunks = split_text_into_chunks(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        knowledge_base = create_knowledge_base(chunks, api_key=st.session_state.api_key)

        user_question = st.text_input("Ask a question based on the document content:")

        if user_question:
            response = answer_question(user_question, knowledge_base, model)
            st.subheader("Answer:")
            st.write(response)

            # Store and display previous answers
            if "answers" not in st.session_state:
                st.session_state.answers = []
            st.session_state.answers.append((user_question, response))

        if show_answers and "answers" in st.session_state:
            st.subheader("Previous Answers:")
            for question, answer in st.session_state.answers:
                st.write(f"Question: {question}")
                st.write(f"Answer: {answer}")
                st.write("------")


if __name__ == '__main__':
    main()