# app.py — HuggingFace Space "ryanrwatkins/needs"
# (header converted to comments: the raw page chrome — "raw / history / blame",
# file size, commit id — was scraped into the file and is not valid Python)
# Standard library
import csv
import glob
import os
#import pickle

# Third-party
import gradio as gr
import openai
import requests
import langchain
import chromadb
from PyPDF2 import PdfReader
from PyPDF2 import PdfWriter
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
#from langchain.vectorstores import ElasticVectorSearch, Pinecone, Weaviate, FAISS
from langchain.vectorstores import FAISS  # required: FAISS is used in submit_message()
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
#from langchain.vectorstores import Chroma
#from langchain.text_splitter import TokenTextSplitter
#from langchain.llms import OpenAI
from langchain import OpenAI
from langchain.chat_models import ChatOpenAI
#from langchain.chains import ChatVectorDBChain
#from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFLoader
# Use Chroma in Colab to create vector embeddings, I then saved them to HuggingFace so now I have to set it use them here.
#from chromadb.config import Settings
#client = chromadb.Client(Settings(
## chroma_db_impl="duckdb+parquet",
# persist_directory="./embeddings" # Optional, defaults to .chromadb/ in the current directory
#))
# Read the OpenAI key from the Space's secrets, and mirror it into the
# environment variable that langchain's OpenAI integrations read.
openai.api_key = os.environ['openai_key']
os.environ["OPENAI_API_KEY"] = os.environ['openai_key']
def get_empty_state():
    """Return a fresh per-session conversation state.

    Returns:
        dict: ``{"total_tokens": 0, "messages": []}`` — a new dict and a new
        list each call, so sessions never share mutable state.
    """
    return dict(total_tokens=0, messages=[])
# Initial prompt template; the rest are downloaded from a TXT/CSV file by
# download_prompt_templates() at app load.
# Fix: "combiation" -> "combination" (typo in the system prompt sent to the model).
prompt_templates = {"All Needs Experts": "Respond as if you are combination of all needs assessment experts."}
# Maps expert name -> short description; populated by download_prompt_templates().
actor_description = {}
def download_prompt_templates():
    """Download the expert prompt list and populate the persona dropdown.

    Fetches a headered CSV (columns: name, prompt, description) from the
    Space's repo, fills the module-level ``prompt_templates`` and
    ``actor_description`` dicts, and returns a ``gr.update`` selecting the
    first choice.

    Returns:
        A ``gr.update`` with the dropdown's value and choices. On download
        failure the built-in default persona(s) are still returned, so the
        dropdown stays usable (the original returned None here, which broke
        the gradio output binding).
    """
    url = "https://huggingface.co/spaces/ryanrwatkins/needs/raw/main/gurus.txt"
    try:
        response = requests.get(url)
        reader = csv.reader(response.text.splitlines())
        next(reader)  # skip the header row
        for row in reader:
            # Fix: require 3 columns — the original checked len(row) >= 2 but
            # read row[2], raising IndexError on two-column rows.
            if len(row) >= 3:
                act = row[0].strip('"')
                prompt_templates[act] = row[1].strip('"')
                actor_description[act] = row[2].strip('"')
    except requests.exceptions.RequestException as e:
        print(f"An error occurred while downloading prompt templates: {e}")

    choices = list(prompt_templates.keys())
    # Keep the default "All Needs Experts" entry first; sort the rest.
    choices = choices[:1] + sorted(choices[1:])
    return gr.update(value=choices[0], choices=choices)
def on_prompt_template_change(prompt_template):
    """Return the prompt text for the selected expert.

    Gradio sometimes fires change events with non-string payloads; those are
    ignored (returns None, leaving the preview unchanged).
    """
    if isinstance(prompt_template, str):
        return prompt_templates[prompt_template]
    return None
def on_prompt_template_change_description(prompt_template):
    """Return the description text for the selected expert.

    Non-string change payloads from gradio are ignored (returns None).
    """
    if isinstance(prompt_template, str):
        return actor_description[prompt_template]
    return None
def submit_message(prompt, prompt_template, temperature, max_tokens, context_length, state):
    """Handle one chat turn with retrieval-augmented QA over the PDF library.

    Merges every PDF in ./files, splits the text into chunks, embeds them into
    a FAISS index, retrieves chunks similar to the conversation, and answers
    via a "stuff" QA chain on gpt-3.5-turbo.

    NOTE(review): the PDF merge + embedding is rebuilt on every submission,
    which is slow and re-bills embeddings each turn — candidate for caching.

    Args:
        prompt: The user's question (empty string -> no-op refresh).
        prompt_template: Key into the module-level prompt_templates dict.
        temperature, max_tokens: Passed to ChatOpenAI.
        context_length: Number of previous Q/A pairs to include in the query.
        state: Session dict {"total_tokens": int, "messages": list}.

    Returns:
        (textbox_value, chat_messages, total_tokens_msg, state) matching the
        gradio output bindings.
    """
    # Merge all source PDFs into one file, then extract its raw text.
    path = './files'
    pdf_files = glob.glob(os.path.join(path, "*.pdf"))
    merger = PdfWriter()
    for pdf in pdf_files:
        merger.append(pdf)
    merger.write("merged-pdf.pdf")
    merger.close()

    reader = PdfReader("merged-pdf.pdf")
    raw_text = ''
    for page in reader.pages:
        text = page.extract_text()
        if text:  # skip pages with no extractable text (e.g. scans)
            raw_text += text

    # Chunk the corpus for embedding; overlap preserves cross-chunk context.
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    texts = text_splitter.split_text(raw_text)
    embeddings = OpenAIEmbeddings()

    history = state['messages']

    # Empty prompt: just re-render the existing conversation.
    if not prompt:
        return gr.update(value=''), [(history[i]['content'], history[i + 1]['content']) for i in range(0, len(history) - 1, 2)], f"Total tokens used: {state['total_tokens']}", state

    prompt_template = prompt_templates[prompt_template]
    system_prompt = []
    if prompt_template:
        system_prompt = [{"role": "system", "content": prompt_template}]

    prompt_msg = {"role": "user", "content": prompt}

    try:
        docsearch = FAISS.from_texts(texts, embeddings)
        # Fix: honor the context_length slider (the original used the whole
        # history, making the slider a no-op).
        query = str(system_prompt + history[-context_length * 2:] + [prompt_msg])
        docs = docsearch.similarity_search(query)
        chain = load_qa_chain(ChatOpenAI(temperature=temperature, max_tokens=max_tokens, model_name="gpt-3.5-turbo"), chain_type="stuff")
        answer = chain.run(input_documents=docs, question=query)
        completion = {"role": "assistant", "content": answer}
        # Fix: record the turn in history (was state.append(...), which raised
        # AttributeError — state is a dict).
        history.append(prompt_msg.copy())
        history.append(completion.copy())
        # NOTE(review): chain.run returns plain text with no usage metadata,
        # so the original completion['usage']['total_tokens'] lookup always
        # raised KeyError; token accounting is not available on this path.
    except Exception as e:
        # Surface the error in the transcript instead of crashing the app.
        history.append(prompt_msg.copy())
        history.append({"role": "system", "content": f"Error: {e}"})

    total_tokens_used_msg = f"Total tokens used: {state['total_tokens']}"
    # Fix: render pairs from history — the original referenced `completion`,
    # which is unbound (NameError) when the exception fired before it was set.
    chat_messages = [(history[i]['content'], history[i + 1]['content']) for i in range(0, len(history) - 1, 2)]
    return '', chat_messages, total_tokens_used_msg, state
def clear_conversation():
    """Reset the UI for a new conversation.

    Returns updates for (input textbox, chatbot, token counter, state):
    cleared-but-visible textbox, empty chat, empty counter text, fresh state.
    """
    cleared_input = gr.update(value=None, visible=True)
    return cleared_input, None, "", get_empty_state()
# Page styling for the Gradio Blocks layout; element IDs are referenced by
# the components created below (elem_id=...).
css = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#chatbox {min-height: 400px;}
#header {text-align: center;}
#prompt_template_preview {padding: 1em; border-width: 1px; border-style: solid; border-color: #e0e0e0; border-radius: 4px;}
#total_tokens_str {text-align: right; font-size: 0.8em; color: #666;}
#label {font-size: 0.8em; padding: 0.5em; margin: 0;}
.message { font-size: 1.2em; }
"""
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    # Per-session state: {"total_tokens": int, "messages": list of role dicts}.
    state = gr.State(get_empty_state())

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# Chat with Needs Assessment Experts (Past and Present)
                    ## Ask questions of experts on needs assessments, get responses from *needs assessment* version of ChatGPT.
                    Ask questions of all of them, or pick your expert.""",
                    elem_id="header")

        with gr.Row():
            with gr.Column():
                chatbot = gr.Chatbot(elem_id="chatbox")
                input_message = gr.Textbox(show_label=False, placeholder="Enter your needs assessment question and press enter", visible=True).style(container=False)
                btn_submit = gr.Button("Submit")
                total_tokens_str = gr.Markdown(elem_id="total_tokens_str")
                btn_clear_conversation = gr.Button("Start New Conversation")
            with gr.Column():
                # Fix: label grammar ("a expert" -> "an expert").
                prompt_template = gr.Dropdown(label="Choose an expert:", choices=list(prompt_templates.keys()))
                prompt_template_preview = gr.Markdown(elem_id="prompt_template_preview")
                with gr.Accordion("Advanced parameters", open=False):
                    temperature = gr.Slider(minimum=0, maximum=2.0, value=0.7, step=0.1, label="Flexibility", info="Higher = more creative/chaotic, Lower = just the guru")
                    max_tokens = gr.Slider(minimum=100, maximum=400, value=200, step=1, label="Max tokens per response")
                    context_length = gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Context length", info="Number of previous questions you have asked. Be careful with high values, it can blow up the token budget quickly.")

    # Event wiring: button click and textbox Enter both submit a message.
    btn_submit.click(submit_message, [input_message, prompt_template, temperature, max_tokens, context_length, state], [input_message, chatbot, total_tokens_str, state])
    input_message.submit(submit_message, [input_message, prompt_template, temperature, max_tokens, context_length, state], [input_message, chatbot, total_tokens_str, state])
    btn_clear_conversation.click(clear_conversation, [], [input_message, chatbot, total_tokens_str, state])
    prompt_template.change(on_prompt_template_change, inputs=[prompt_template], outputs=[prompt_template_preview])
    # Populate the expert dropdown on page load.
    # Fix: keyword was misspelled `queur=False`; corrected to `queue=False`.
    demo.load(download_prompt_templates, inputs=None, outputs=[prompt_template], queue=False)

demo.queue(concurrency_count=10)
demo.launch(height='800px')