File size: 8,024 Bytes
612b7f5 c964b4e 612b7f5 c964b4e 612b7f5 c964b4e 90a9483 c964b4e 90a9483 c964b4e 612b7f5 c964b4e 612b7f5 c964b4e 612b7f5 c964b4e 612b7f5 c964b4e 612b7f5 2435eec 612b7f5 2435eec c964b4e 612b7f5 c964b4e 612b7f5 c964b4e 612b7f5 c964b4e 612b7f5 c964b4e 612b7f5 90a9483 612b7f5 c964b4e 612b7f5 c964b4e 612b7f5 c964b4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
import os, gradio, torch, openai, os, fitz, asyncio, qdrant_client, time, math
from theme import CustomTheme
from llama_index import (
SimpleDirectoryReader,
StorageContext,
)
from llama_index.multi_modal_llms import OpenAIMultiModal
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex
from PIL import Image
from microsofttt import detect_and_crop_save_table
from torchvision import transforms
from transformers import AutoModelForObjectDetection
from llama_index.vector_stores.qdrant import QdrantVectorStore
# Torch device string. NOTE(review): `device` is not referenced anywhere else in this file.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Raises KeyError at import time when OPENAI_API_KEY is not set.
openai.api_key = os.environ["OPENAI_API_KEY"]
# Module-level placeholders; setup_ai() rebinds openai_mm_llm through `global`.
# These must be assignments: a bare `name: None` is only an annotation and
# leaves the name unbound at module level.
image_documents = None
openai_mm_llm = None
# Example prompts offered in the chat UI (German). The last one is
# presumably an off-topic probe for the bot's refusal behaviour.
bot_examples = [
    "Wie kannst du mir helfen?",
    "Welche Sprachen sprichst du?",
    "Wie trainiere ich meinen Bizeps?",
    "Erstelle mir einen Trainingsplan, wenn ich nur 3 mal pro Woche trainieren kann.",
    "Berechne meinen BMI, wenn ich männlich bin und 75kg bei 175cm Körpergröße wiege.",
    "Berechne mir meinen Kaloriendefizit, wenn ich in der Woche 0,1kg abnehmen möchte.",
    "Berechne mir nochmal das Kaloriendefizit, wenn ich Männlich 18 bin.",
    "Wie wechsle ich meine Reifen?",
]
# System prompt: tells the bot what it is and is not allowed to talk about.
# Adjacent literals are concatenated by Python, so every sentence ends with a
# space; otherwise the sentences would run together ("user.Never ...").
system_prompt = (
    "You are a fitness assistant, always helping the user. "
    "Never mention that you are not qualified. Always answer the questions. "
    "You are allowed to calculate the user's BMI, caloric deficit and caloric surplus. "
    "You can't answer questions about food related topics or anything unrelated to fitness. "
    "Don't even try to answer those questions. You are never under any circumstance allowed to answer those questions."
)
# Context template for the chat engine. The `{context_str}` placeholder is
# presumably filled in by the engine with the retrieved context; the text
# after it sets the language (Austrian dialect), tone and topic rules.
# Adjacent literals are concatenated by Python, so each sentence ends with a
# space or newline to keep sentences from running together.
context_str = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, answer the query.\n"
    "Griaß di! I hätt gern, dass du imma in am österreichischen Dialekt antwortest. "
    "Übersetz bitte ois in oanen österrichischen Dialekt. "
    "You're pretty cool, so you're always addressing the user informally. "
    "E.g.: In German instead of 'Sie' you'd say 'du'. "
    "Instead of saying 'you', you could say something like: 'buddy'. "
    "If questions are asked that are not related to fitness, then don't answer them "
    "and play it off cool and make a joke out of it. "
    "If there is a more efficient exercise than the one the user sent, then always tell them about it. "
    "Add fitness related emojis to your message."
)
# Created by setup_ai(); None until then.
chat_engine = None
def setup_db():
    """
    Prepare the table images and the Qdrant storage used by the chatbot.

    Only when ``./qdrant_db`` does not exist yet, every page of every PDF in
    ``./pdf_with_tables`` is rendered to a PNG in ``./table_images``. Each of
    those images is then passed to ``detect_and_crop_save_table`` (the
    Microsoft Table Transformer helper), and the uncropped page image is
    deleted.
    NOTE(review): this assumes ``detect_and_crop_save_table`` writes its crops
    back into ``./table_images`` -- confirm in ``microsofttt``.

    On every call, the text documents in ``./data/`` and the images in
    ``./table_images/`` are loaded. A local Qdrant client at ``qdrant_db`` is
    opened, with one store for ``text_collection`` and one for
    ``image_collection``, and the two stores are wrapped in a
    ``StorageContext``.

    Returns:
        tuple: ``(text_documents, image_documents, storage_context)``.
    """
    if not os.path.exists("./qdrant_db"):
        os.makedirs("./table_images", exist_ok=True)
        # One timestamp per run; the PDF's name in the output file name keeps
        # pages of different PDFs apart. (A per-second timestamp alone let the
        # same page number of two PDFs overwrite each other.)
        run_stamp = math.floor(time.time())
        for file_name in os.listdir("./pdf_with_tables"):
            # Skip stray non-PDF files (e.g. .DS_Store) that fitz may not open.
            if not file_name.lower().endswith(".pdf"):
                continue
            pdf_stem = os.path.splitext(file_name)[0]
            # The context manager closes the document even if rendering fails.
            with fitz.open("./pdf_with_tables/" + file_name) as pdf_document:
                for page_number in range(pdf_document.page_count):
                    # Render the page and wrap the raw pixels in a Pillow image.
                    pix = pdf_document[page_number].get_pixmap()
                    image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                    image.save(
                        f"./table_images/page_{page_number + 1}_{run_stamp}_{pdf_stem}.png"
                    )
        # Crop every page image down to its tables, then drop the full page.
        for image_name in os.listdir("./table_images"):
            detect_and_crop_save_table("./table_images/" + image_name)
            os.remove("./table_images/" + image_name)
    # Load the sources that will be indexed.
    text_documents = SimpleDirectoryReader("./data/").load_data()
    image_documents = SimpleDirectoryReader("./table_images/").load_data()
    # Local Qdrant instance with one collection per modality.
    client = qdrant_client.QdrantClient(path="qdrant_db")
    text_store = QdrantVectorStore(client=client, collection_name="text_collection")
    image_store = QdrantVectorStore(client=client, collection_name="image_collection")
    # Storage context handed to the index in setup_ai().
    storage_context = StorageContext.from_defaults(
        vector_store=text_store, image_store=image_store
    )
    return (text_documents, image_documents, storage_context)
def setup_ai():
    """
    Build the retrieval index and the global chat engine.

    Steps:

    1. Calls ``setup_db()`` to get the text and image documents plus the
       Qdrant storage context.
    2. Creates the ``gpt-4-vision-preview`` multi-modal model and stores it in
       the global ``openai_mm_llm``.
    3. Indexes all documents with ``MultiModalVectorStoreIndex.from_documents``.
    4. Stores a ``"context"``-mode chat engine, configured with
       ``system_prompt`` and ``context_str``, in the global ``chat_engine``.

    NOTE(review): ``from_documents`` runs on every start, even when
    ``./qdrant_db`` already exists, so the documents are presumably embedded
    and inserted again on each launch -- confirm whether duplicates
    accumulate in the Qdrant collections.

    NOTE(review): ``openai_mm_llm`` is not passed to ``as_chat_engine``, so
    the chat engine appears to use the index's default LLM rather than
    GPT-4V -- confirm this is intended.
    """
    # Only names that are rebound here need `global`; system_prompt and
    # context_str are just read.
    global openai_mm_llm, chat_engine
    text_documents, image_documents, storage_context = setup_db()
    api_key = os.environ["OPENAI_API_KEY"]
    openai_mm_llm = OpenAIMultiModal(
        model="gpt-4-vision-preview", api_key=api_key, max_new_tokens=1500
    )
    # Build the index over text and image documents using the Qdrant storage context.
    index = MultiModalVectorStoreIndex.from_documents(
        documents=text_documents + image_documents,
        storage_context=storage_context,
    )
    # context_str serves as the template for the retrieved context.
    chat_engine = index.as_chat_engine(
        chat_mode="context",
        system_prompt=system_prompt,
        context_template=context_str,
    )
def response(message, history):
    """
    Stream an answer to *message* from the global chat engine.

    Args:
        message: The user's latest chat message.
        history: Chat history supplied by Gradio. Not used; the conversation
            memory kept inside the shared ``chat_engine`` is sent instead.

    Yields:
        str: The answer accumulated so far. It grows by one streamed token per
        yield, with a 20 ms pause per token for a typing effect.

    Raises:
        RuntimeError: If ``setup_ai()`` has not created the chat engine yet.
    """
    engine = chat_engine
    if engine is None:
        raise RuntimeError("setup_ai() must be called before response().")
    # Reuse the engine's own history. Fall back to an empty list because the
    # engine expects a list when chat_history is passed explicitly.
    chat_history = engine.chat_history if engine.chat_history is not None else []
    # Presumably this runs in a worker thread without an event loop, so the
    # query gets a private loop. The loop is closed right after the query
    # instead of being leaked on every request. It is released before the
    # first yield because the caller may resume this generator on a
    # different thread.
    # NOTE(review): confirm that the streaming phase does not need the loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        streaming_response = engine.stream_chat(message, chat_history)
    finally:
        asyncio.set_event_loop(None)
        loop.close()
    output_text = ""
    for token in streaming_response.response_gen:
        time.sleep(0.02)
        output_text += token
        yield output_text
# Debug helper: can be swapped in for `response` to check the UI without calling the API.
def response_no_api(message, history) -> str:
    """
    Offline stand-in for ``response``.

    Both arguments are ignored and no request is made.

    Returns:
        str: Always the same placeholder reply.
    """
    placeholder = "This is a test message!"
    return placeholder
def main():
    """
    Set up the AI backend and launch the "Arnold" Gradio chat UI.

    Calls ``setup_ai()`` first. It then builds a ``gradio.Blocks`` page with a
    Markdown title and a ``ChatInterface`` that streams answers from
    ``response``. Finally it queues and launches the app with
    ``inbrowser=True`` and ``./img/`` in ``allowed_paths``.
    """
    setup_ai()
    # Chat window with user/bot avatars. It is created here, outside the
    # Blocks context, and handed to the ChatInterface below.
    chatbot = gradio.Chatbot(
        avatar_images=("user_avatar.png", "chatbot_avatar.png"),
        layout='bubble',
        show_label=False,
        height=400,
    )
    # Submit button; the "ask-button" class is presumably styled in style.css.
    submit_button = gradio.Button(
        value="Ask Arnold",
        elem_classes=["ask-button"],
    )
    with gradio.Blocks(theme=CustomTheme(), css="style.css") as chat_interface:
        # Title bar: the icon is presumably served via allowed_paths=["./img/"].
        gradio.Markdown(
            """<div style='display: flex; justify-content: center; align-items: center; margin-right: 12px;'>
<img width='48px' style='margin-right: 12px;' src='/file/img/icon-light.png'/>
ARNOLD
</div>""",
            elem_classes=["arnold-title"]
        )
        # Stop/undo/clear/retry buttons are disabled by passing None.
        # NOTE(review): theme/css given to a ChatInterface nested inside Blocks
        # may be ignored in favour of the Blocks' own -- confirm.
        gradio.ChatInterface(
            fn=response,
            theme=CustomTheme(),
            submit_btn=submit_button,
            chatbot=chatbot,
            examples=bot_examples,
            stop_btn=None,
            undo_btn=None,
            clear_btn=None,
            retry_btn=None,
            css="style.css",
        )
    # Enable Gradio's request queue, then start the server.
    chat_interface.queue()
    chat_interface.launch(
        inbrowser=True,
        allowed_paths=["./img/"]
    )


# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    main()
|