nklomp committed on
Commit 9eaa8cc · verified · 1 Parent(s): 4d2daf5

Create app.py

Files changed (1)
  1. app.py +217 -0
app.py ADDED
@@ -0,0 +1,217 @@
+ import os
+ import streamlit as st
+ from dotenv import load_dotenv
+ from PyPDF2 import PdfReader
+ from langchain.text_splitter import CharacterTextSplitter
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+ from langchain_community.vectorstores import FAISS
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chains import ConversationalRetrievalChain
+ from htmlTemplates import css, bot_template, user_template
+ from langchain_community.llms import HuggingFaceHub
+
+ # Llama2
+ import torch
+ import transformers
+ from langchain_community.llms import HuggingFacePipeline
+ from transformers import AutoTokenizer
+ from torch import cuda, bfloat16
+ import langchain
+ langchain.verbose = False
+
+
+
+ def get_pdf_text(pdf_docs):
+     text = ""
+     for pdf in pdf_docs:
+         pdf_reader = PdfReader(pdf)
+         for page in pdf_reader.pages:
+             text += page.extract_text()
+     return text
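+
+ # Note: extract_text() returns an empty string for image-only (scanned)
+ # pages, so such PDFs contribute no text here; OCR is out of scope.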
+
+ def get_text_chunks(text):
+     text_splitter = CharacterTextSplitter(
+         separator="\n",
+         chunk_size=1000,  # the character length of the chunk
+         chunk_overlap=200,  # the character length of the overlap between chunks
+         length_function=len  # the length function - in this case, character length (aka the python len() fn.)
+     )
+     chunks = text_splitter.split_text(text)
+     return chunks
+
+ def get_vectorstore(text_chunks, selected_embedding):
+     if selected_embedding == 'OpenAI':
+         print('OpenAI embedding')
+         embeddings = OpenAIEmbeddings()
+     elif selected_embedding == 'Instructor-xl':
+         print('Instructor-xl embedding')
+         embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
+
+     vectorstore = FAISS.from_texts(texts=text_chunks, embedding=embeddings)
+     vectorstore.save_local("faiss_index")
+     return vectorstore
+
+ def load_vectorstore(selected_embedding):
+     if selected_embedding == 'OpenAI':
+         print('OpenAI embedding')
+         embeddings = OpenAIEmbeddings()
+     elif selected_embedding == 'Instructor-xl':
+         print('Instructor-xl embedding')
+         embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
+
+     vectorstore = FAISS.load_local("faiss_index", embeddings)
+     return vectorstore
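+
+ # Note: an index must be loaded with the same embedding model it was saved
+ # with; mixing OpenAI and Instructor-xl embeddings here would make the
+ # similarity scores meaningless.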
+
+ def get_conversation_chain(vectorstore, selected_llm):
+     if selected_llm == 'OpenAI':
+         print('OpenAI LLM')
+         llm = ChatOpenAI()
+
+     elif selected_llm == 'Llama2':
+         print('Llama2 LLM')
+         model_id = 'meta-llama/Llama-2-7b-chat-hf'
+         # the variable name HF_AUTH_TOKEN is an assumption; set whichever
+         # key your .env uses for the gated Llama 2 weights
+         hf_auth = os.environ.get('HF_AUTH_TOKEN')
+
+         model_config = transformers.AutoConfig.from_pretrained(
+             model_id,
+             token=hf_auth
+         )
+
+         device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
+
+         if 'cuda' in device:
+             # set quantization configuration to load large model with less GPU memory
+             # this requires the `bitsandbytes` library
+             bnb_config = transformers.BitsAndBytesConfig(
+                 load_in_4bit=True,
+                 bnb_4bit_quant_type='nf4',
+                 bnb_4bit_use_double_quant=True,
+                 bnb_4bit_compute_dtype=bfloat16
+             )
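+
+             # Rough memory math (an estimate, not measured here): a
+             # 7B-parameter model is ~14 GB in fp16, while 4-bit NF4 with
+             # double quantization brings the weights down to roughly 4 GB.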
+
+             model = transformers.AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 trust_remote_code=True,
+                 config=model_config,
+                 quantization_config=bnb_config,
+                 device_map='auto',
+                 token=hf_auth
+             )
+         else:
+             model = transformers.AutoModelForCausalLM.from_pretrained(
+                 model_id,
+                 trust_remote_code=True,
+                 config=model_config,
+                 device_map='auto',
+                 token=hf_auth
+             )
+
+         # enable evaluation mode to allow model inference
+         model.eval()
+         print(f"Model loaded on {device}")
+
+         tokenizer = transformers.AutoTokenizer.from_pretrained(
+             model_id,
+             token=hf_auth
+         )
+
+         pipeline = transformers.pipeline(
+             task='text-generation',
+             model=model,
+             tokenizer=tokenizer,
+             torch_dtype=torch.float32,
+             return_full_text=True,  # langchain expects the full text
+             temperature=0.1,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
+             max_new_tokens=512,  # max number of tokens to generate in the output
+             repetition_penalty=1.1  # without this output begins repeating
+         )
+
+         llm = HuggingFacePipeline(pipeline=pipeline)
+
+     # Generic LLM
+     memory = ConversationBufferMemory(
+         memory_key='chat_history', return_messages=True)
+
+     conversation_chain = ConversationalRetrievalChain.from_llm(
+         llm=llm,
+         retriever=vectorstore.as_retriever(),
+         memory=memory,
+         return_source_documents=False
+     )
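+     # The chain condenses the running chat history plus the new question into
+     # a standalone query, retrieves the most similar chunks from the FAISS
+     # index, and passes both to the LLM; memory carries the history between turns.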
+     # print(conversation_chain)
+
+     return conversation_chain
+
+
+ def handle_userinput(user_question):
+     print('Question: ' + user_question)
+     response = st.session_state.conversation({'question': user_question})
+     st.session_state.chat_history = response['chat_history']
+
+     # chat_history alternates turns: even indices are the user's questions,
+     # odd indices are the bot's answers
+     for i, message in enumerate(st.session_state.chat_history):
+         if i % 2 == 0:
+             st.write(user_template.replace(
+                 "{{MSG}}", message.content), unsafe_allow_html=True)
+         else:
+             st.write(bot_template.replace(
+                 "{{MSG}}", message.content), unsafe_allow_html=True)
+
+
+ def main():
+     load_dotenv()
+     st.set_page_config(page_title="VerAi",
+                        page_icon=":books:")
+     st.write(css, unsafe_allow_html=True)
+
+     if "conversation" not in st.session_state:
+         st.session_state.conversation = None
+     if "chat_history" not in st.session_state:
+         st.session_state.chat_history = None
+
+     with st.sidebar:
+         st.subheader("Your documents")
+         pdf_docs = st.file_uploader(
+             "Upload your new PDFs here and click on 'Process' or load the last upload by clicking on 'Load'", accept_multiple_files=True)
+
+         selected_embedding = st.radio("Which Embedding?", ["OpenAI", "Instructor-xl"])
+         selected_llm = st.radio("Which LLM?", ["OpenAI", "Llama2"])
+
+         if st.button("Process"):
+             with st.spinner("Processing"):
+                 # get pdf text
+                 raw_text = get_pdf_text(pdf_docs)
+
+                 # get the text chunks
+                 text_chunks = get_text_chunks(raw_text)
+
+                 # create vector store
+                 vectorstore = get_vectorstore(text_chunks, selected_embedding)
+
+                 # create conversation chain
+                 st.session_state.conversation = get_conversation_chain(
+                     vectorstore, selected_llm)
+
+         if st.button("Load"):
+             with st.spinner("Processing"):
+                 # load vector store (rebuilt from the index saved by 'Process')
+                 vectorstore = load_vectorstore(selected_embedding)
+
+                 # create conversation chain
+                 st.session_state.conversation = get_conversation_chain(
+                     vectorstore, selected_llm)
+
+     if st.session_state.conversation:
+         st.header("VerAi :books:")
+         user_question = st.text_input("Ask a question below")
+         # e.g. "Tell me something about Wettelijke uren" (statutory hours)
+         # e.g. "What are the Overige verloftypes (other leave types) at kpn?"
+         if user_question:
+             handle_userinput(user_question)
+
+
+ if __name__ == '__main__':
+     main()
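+
+ # To try the app locally (standard Streamlit invocation): streamlit run app.py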