ababio committed
Commit
e24982e
1 Parent(s): 560dc37

Update app.py

Files changed (1)
  1. app.py +56 -58
app.py CHANGED
@@ -1,73 +1,71 @@
- import os
  import streamlit as st
- from dotenv import load_dotenv
- from pinecone.grpc import PineconeGRPC
- from pinecone import ServerlessSpec
  from llama_index.embeddings import OpenAIEmbedding
  from llama_index.ingestion import IngestionPipeline
- from llama_index.query_engine import RetrieverQueryEngine
  from llama_index.vector_stores import PineconeVectorStore
- from llama_index.node_parser import SemanticSplitterNodeParser
  from llama_index.retrievers import VectorIndexRetriever
- from htmlTemplates import css, bot_template, user_template
-
- # Load environment variables
- load_dotenv()
- pinecone_api_key = os.getenv("PINECONE_API_KEY")
- openai_api_key = os.getenv("OPENAI_API_KEY")
- index_name = os.getenv("INDEX_NAME")

- # Initialize OpenAI embedding model
- embed_model = OpenAIEmbedding(api_key=openai_api_key)

- # Initialize connection to Pinecone
- pinecone_client = PineconeGRPC(api_key=pinecone_api_key)
- pinecone_index = pinecone_client.Index(index_name)
- vector_store = PineconeVectorStore(pinecone_index=pinecone_index)

- # Define the initial pipeline
- pipeline = IngestionPipeline(
-     transformations=[
-         SemanticSplitterNodeParser(
-             buffer_size=1,
-             breakpoint_percentile_threshold=95,
-             embed_model=embed_model,
-         ),
-         embed_model,
-     ],
- )

- # Initialize LlamaIndex components
- vector_index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
- retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=5)
- query_engine = RetrieverQueryEngine(retriever=retriever)

- # Function to handle user input and return the query response
- def handle_userinput(user_question):
-     response = st.session_state.conversation({'question': user_question})
-     st.session_state.chat_history = response['chat_history']

-     for i, message in enumerate(st.session_state.chat_history):
-         if i % 2 == 0:
-             st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
-         else:
-             st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)

- # Main function to run the Streamlit app
- def main():
-     load_dotenv()
-     st.set_page_config(page_title="Chat with Annual Reports", page_icon=":books:")
-     st.write(css, unsafe_allow_html=True)

-     if "conversation" not in st.session_state:
-         st.session_state.conversation = None
-     if "chat_history" not in st.session_state:
-         st.session_state.chat_history = None

-     st.header("Chat with Annual Report Documents")
-     user_question = st.text_input("Ask a question about your documents:")
-     if user_question:
-         handle_userinput(user_question)

- if __name__ == "__main__":
-     main()
+ # Streamlit application
  import streamlit as st
+ import os
+ from getpass import getpass
+ from transformers import pipeline, Conversation
+
+ from llama_index.node_parser import SemanticSplitterNodeParser
  from llama_index.embeddings import OpenAIEmbedding
  from llama_index.ingestion import IngestionPipeline
+ from pinecone.grpc import PineconeGRPC
+ from pinecone import ServerlessSpec
  from llama_index.vector_stores import PineconeVectorStore
+ from llama_index import VectorStoreIndex
  from llama_index.retrievers import VectorIndexRetriever
+ from llama_index.query_engine import RetrieverQueryEngine

+ # Function to initialize the Pinecone and LlamaIndex setup
+ def initialize_pipeline():
+     pinecone_api_key = os.getenv("PINECONE_API_KEY")
+     openai_api_key = os.getenv("OPENAI_API_KEY")
+
+     embed_model = OpenAIEmbedding(api_key=openai_api_key)
+     # Ingestion pipeline: semantic splitting followed by embedding
+     pipeline = IngestionPipeline(
+         transformations=[
+             SemanticSplitterNodeParser(
+                 buffer_size=1,
+                 breakpoint_percentile_threshold=95,
+                 embed_model=embed_model,
+             ),
+             embed_model,
+         ],
+     )
+
+     pc = PineconeGRPC(api_key=pinecone_api_key)
+     index_name = "anualreport"
+     pinecone_index = pc.Index(index_name)
+     vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
+     pinecone_index.describe_index_stats()
+
+     # Export the key for downstream components; prompt for it rather than
+     # writing None into the environment when it is missing
+     if not os.getenv("OPENAI_API_KEY"):
+         openai_api_key = getpass("OPENAI_API_KEY: ")
+         os.environ["OPENAI_API_KEY"] = openai_api_key
+
+     vector_index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
+     retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=5)
+     query_engine = RetrieverQueryEngine(retriever=retriever)
+
+     return query_engine
+
+ # Streamlit UI
+ st.title("Chat with Annual Reports")
+
+ # Initialize the query engine
+ query_engine = initialize_pipeline()
+
+ # Conversation model using Hugging Face transformers
+ conversation_pipeline = pipeline("conversational", model="microsoft/DialoGPT-medium")
+
+ # User input
+ user_input = st.text_input("You: ", "")
+
+ if user_input:
+     # Query the vector DB
+     llm_query = query_engine.query(user_input)
+     response = llm_query.response
+
+     # Generate a reply with the Hugging Face conversation model; the
+     # conversational pipeline expects a Conversation object, not a list of strings
+     conversation = conversation_pipeline(Conversation(f"{user_input}\n{response}"))
+     bot_response = conversation.generated_responses[-1]
+
+     # Display response
+     st.text_area("Bot: ", bot_response, height=200)
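
Note that the IngestionPipeline is built inside initialize_pipeline() but never run in this version of app.py. A minimal sketch of how it could populate the Pinecone index under the same legacy llama_index API; SimpleDirectoryReader and the "./data" path are illustrative assumptions, not part of the commit:

# Hypothetical ingestion step, placed inside initialize_pipeline() where
# `pipeline` and `vector_store` are in scope; not part of this commit.
from llama_index import SimpleDirectoryReader

documents = SimpleDirectoryReader("./data").load_data()  # "./data" is a placeholder
nodes = pipeline.run(documents=documents)                # semantic split + embed
vector_store.add(nodes)                                  # upsert the nodes into Pinecone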
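The rewritten app reads PINECONE_API_KEY and OPENAI_API_KEY directly from the environment and, unlike the previous version, no longer calls load_dotenv(). A minimal pre-flight check before launching with "streamlit run app.py" (illustrative only):

import os

# Illustrative check, not part of the commit: both keys must already be
# exported in the shell that launches Streamlit.
for var in ("PINECONE_API_KEY", "OPENAI_API_KEY"):
    if not os.getenv(var):
        raise RuntimeError(f"{var} is not set; export it before running the app")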