import os
import json
import base64
import uuid
import logging
from typing import Generator

from sentence_transformers import SentenceTransformer
from openai import OpenAI
import gradio as gr
from dotenv import load_dotenv

from utils.utils import (
    get_keys_chunks,
    get_docs,
    get_top_chunk_keys,
    get_messages,
    load_knowledge_base,
)
from utils.chatLogger import ChatUploader


def initialize():
    """
    Initializes embedding model, encodes document chunks, loads environment
    variables, and initializes clients.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logger.info("Initializing application...")

    logger.info("Loading sentence embedding model...")
    embedding_model_path = "ibm-granite/granite-embedding-125m-english"
    embedding_model = SentenceTransformer(embedding_model_path)

    logger.info("Loading and encoding document chunks...")
    knowledge_base = load_knowledge_base()
    keys, chunks = zip(*get_keys_chunks(knowledge_base))
    chunks_encoded = embedding_model.encode(chunks)
    keys_chunksEncoded = list(zip(keys, chunks_encoded))

    logger.info("Loading env variables...")
    if not os.getenv("SPACE_ID"):
        load_dotenv()

    logger.info("Initializing OpenAI client...")
    openAI_client = OpenAI(
        base_url="https://api.inference.net/v1",
        api_key=os.getenv("INFERENCE_API_KEY"),
    )

    logger.info("Loading Drive service account details...")
    drive_creds_encoded = os.getenv(
        "GOOGLE_DRIVE_SERVICE_ACCOUNT_CREDENTIALS_BASE64"
    ).strip()
    service_account_json = json.loads(base64.b64decode(drive_creds_encoded).decode())

    logger.info("Initializing ChatUploader instance...")
    chat_uploader = ChatUploader(service_account_json)

    logger.info("Ready for user query...")
    return (
        embedding_model,
        keys_chunksEncoded,
        knowledge_base,
        openAI_client,
        logger,
        chat_uploader,
    )


(
    embedding_model,
    keys_chunksEncoded,
    knowledge_base,
    openAI_client,
    logger,
    chat_uploader,
) = initialize()
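# Note: get_top_chunk_keys is imported from utils.utils, whose source is not
# shown in this file. A minimal sketch, for reference only, of the assumed
# behavior (ranking chunk embeddings against the query embedding by cosine
# similarity and returning the top_n keys):
#
#     import numpy as np
#     from sentence_transformers import util
#
#     def get_top_chunk_keys(query_encoded, keys_chunksEncoded, top_n):
#         keys, vectors = zip(*keys_chunksEncoded)
#         scores = util.cos_sim(query_encoded, np.stack(vectors))[0]
#         ranked = sorted(
#             zip(keys, scores.tolist()), key=lambda kv: kv[1], reverse=True
#         )
#         return [key for key, _ in ranked[:top_n]]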
def rag_chatbot(
    user_message: str,
    chat_history: list,
    browser_id: str,
) -> Generator[list, None, None]:
    """
    Retrieves documents relevant to the user's query and streams the LLM
    response, catching errors along the way.
    """
    try:
        logger.info("Trying to encode user query and retrieve related docs...")
        user_query_encoded = embedding_model.encode(user_message)
        top_chunk_keys = get_top_chunk_keys(
            user_query_encoded, keys_chunksEncoded, top_n=5
        )
        docs = get_docs(top_chunk_keys, knowledge_base)
    except Exception:
        logger.exception("Error during document retrieval")
        yield [
            {
                "role": "assistant",
                "content": "⚠️ An error occurred during document retrieval. Please try again later.",
            }
        ]
        return

    try:
        logger.info("Trying to call OpenAI chat API...")
        messages = get_messages(docs, user_message, chat_history)
        chatCompletion_response = openAI_client.chat.completions.create(
            model="deepseek/r1-distill-llama-70b/fp-8",
            messages=messages,
            stream=True,
        )
    except Exception:
        logger.exception("Error during call to OpenAI Chat API")
        yield [
            {
                "role": "assistant",
                "content": "⚠️ An error occurred during the client API call. Please try again later.",
            }
        ]
        return

    try:
        logger.info("Trying to stream LLM response...")
        llm_thinking = False
        buffer = ""
        chat_history.append({"role": "user", "content": user_message})
        chat_history.append({"role": "assistant", "content": ""})
        for chunk in chatCompletion_response:
            chunk_content = chunk.choices[0].delta.content
            if not chunk_content:
                continue
            # DeepSeek R1 distill models wrap their chain-of-thought in
            # <think>...</think> tags; show a status message instead of
            # streaming the reasoning itself.
            if chunk_content == "<think>":
                llm_thinking = True
                yield [{"role": "assistant", "content": "Thinking..."}]
                continue
            if llm_thinking and chunk_content == "</think>":
                llm_thinking = False
                yield [{"role": "assistant", "content": "Finished thinking."}]
                continue
            if not llm_thinking:
                buffer += chunk_content
                # Flush in ~20-character pieces to limit UI update frequency.
                if len(buffer) > 20 or "\n" in buffer:
                    chat_history[-1]["content"] += buffer
                    yield [chat_history[-1]]
                    buffer = ""
        if buffer:
            chat_history[-1]["content"] += buffer
            yield [chat_history[-1]]
    except Exception:
        logger.exception("Error during LLM response streaming")
        yield [
            {
                "role": "assistant",
                "content": "⚠️ An error occurred during LLM response streaming. Please try again later.",
            }
        ]

    try:
        logger.info("Trying to upload chat history to Drive...")
        chat_uploader.upload_chat_history(chat_history, browser_id)
    except Exception as e:
        logger.warning(f"Warning: error during Google Drive upload: {e}")

    logger.info("Returning chat history...")
    return chat_history


# Gradio app code
with gr.Blocks() as demo:
    browser_id_state = gr.BrowserState(default_value=None)

    @demo.load(inputs=browser_id_state, outputs=browser_id_state)
    def load_browser_id(current_id):
        # Assign a persistent per-browser ID on first visit.
        if not current_id:
            return str(uuid.uuid4())
        return current_id

    gr.ChatInterface(
        fn=rag_chatbot,
        title="Matthew Schulz's RAG Chatbot 💬🤖",
        additional_inputs=browser_id_state,
        type="messages",
        examples=[
            ["What is Matthew's educational background?"],
            ["What machine learning projects has Matthew worked on?"],
            ["What experience does Matthew have in software engineering?"],
            ["Why did Matthew choose to pursue a degree in computer science?"],
            ["Does Matthew have any leadership experience?"],
            ["Has Matthew completed any Summer internships?"],
            ["Tell me about some real-world projects Matthew has worked on."],
            ["What is Matthew's greatest strength and weakness?"],
        ],
        save_history=True,
        run_examples_on_click=False,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()
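# Local run note: when SPACE_ID is unset (i.e., not running on a Hugging Face
# Space), initialize() falls back to a .env file for INFERENCE_API_KEY and
# GOOGLE_DRIVE_SERVICE_ACCOUNT_CREDENTIALS_BASE64.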