Spaces:
Build error
Build error
"""Streamlit front end for the myRetrievalGPT demo.

The page either lets the user load new documents, or lets them pick a
language model and a document collection and then ask a question that is
answered by a retrieval chain built from that model and collection.
"""
import streamlit as st
import load_model
import utils as ut
import elements as el
import os
import torch

persist_directory = load_model.persist_directory

st.title('myRetrievalGPT')
st.header('An GPT Retrieval example brought to you by Heiko Wagner')
# Raw string: the LaTeX backslashes (\phi, \mathbb) are not Python escapes.
st.markdown(r'*Let $\phi$ be a word embedding mapping $W$ → $\mathbb{R}^n$ where $W$ is the word space and $\mathbb{R}^n$ is an $n$-dimensional vector space then: $\phi(king)-\phi(man)+\phi(woman)=\phi(queen)$.* ')

agree = st.checkbox('Load new Documents')

# NOTE(review): the nesting below assumes the whole question-answering flow
# belongs to the ``else`` branch (no model is loaded while documents are
# being uploaded) -- confirm against the intended layout.
if agree:
    el.load_files()
else:
    torch.cuda.empty_cache()

    model_type = st.selectbox(
        'Select the model to be used to answer your question',
        ('OpenAI', 'decapoda-research/llama-7b-hf (gpu+cpu)', 'llama-7b 4bit (cpu only)',))

    if model_type == 'OpenAI':
        if 'openai_key' in st.session_state:
            os.environ["OPENAI_API_KEY"] = st.session_state.openai_key
        else:
            openai_key = st.text_area('OpenAI Key:', '').strip()
            if openai_key:
                # Only remember a real key; storing an empty one would hide
                # the input on the next rerun and leave the app without a key.
                st.session_state['openai_key'] = openai_key
                os.environ["OPENAI_API_KEY"] = openai_key
            elif not os.environ.get("OPENAI_API_KEY"):
                st.info('Please enter your OpenAI key to continue.')
                st.stop()
        llm = load_model.load_openai_model()
    elif model_type == 'decapoda-research/llama-7b-hf (gpu+cpu)':
        # Add more models here
        llm = load_model.load_gpu_model("decapoda-research/llama-7b-hf")
    else:
        llm = load_model.load_cpu_model()

    collections = ut.retrieve_collections()
    option = st.selectbox(
        'Select the Documents to be used to answer your question',
        collections)

    # st.selectbox yields None when there is nothing to choose from.
    if option is None:
        st.warning('No document collections found. Tick "Load new Documents" to add some.')
        st.stop()

    st.write('You selected:', option['name'])

    chain = load_model.create_chain(
        llm,
        collection=option['name'],
        model_name=option['model_name'],
        metadata=option['metadata'])

    query = st.text_area('Ask a question:', 'Hallo how are you today?')
    result = chain({"query": query + " Add a Score of the propability that your answer is correct to your answer"})
    ut.format_result_set(result)

# Disabled sketch of a conversational (memory-backed) chain, kept for reference:
#from langchain.chains import ConversationChain
#from langchain.memory import ConversationBufferMemory
#conversation = ConversationChain(
#    llm=chat,
#    memory=ConversationBufferMemory()
#)