ragalicious-app / app.py
mickkhaw's picture
Merge changes from main
bb8f798
raw
history blame contribute delete
No virus
5.75 kB
from io import BytesIO
import os
from pprint import pprint
import uuid
import chainlit as cl
from chainlit.element import ElementBased
from dotenv import load_dotenv
# modules for audio processing
from langchain.schema.runnable.config import RunnableConfig
from langchain_openai.chat_models import ChatOpenAI
from openai import AsyncOpenAI
from utils.graph import generate_workflow
# ---- ENV VARIABLES ---- #
# Load .env first: AsyncOpenAI() reads OPENAI_API_KEY from the environment
# when it is constructed and raises if the key is missing, so the client must
# not be created before the .env file has populated the environment.
load_dotenv()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Async OpenAI client, used below for Whisper speech-to-text.
client = AsyncOpenAI(api_key=OPENAI_API_KEY)

# -- AUGMENTED -- #
# Define the LLMs: gpt-4o-mini (temperature 0) and gpt-4o. Both are handed to
# generate_workflow when a chat session starts.
base_llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key=OPENAI_API_KEY, tags=["base_llm"], temperature=0)
# NOTE(review): power_llm is also tagged "base_llm" - possibly a copy-paste
# slip, but callback/tag filtering elsewhere may rely on it; confirm before changing.
power_llm = ChatOpenAI(model="gpt-4o", openai_api_key=OPENAI_API_KEY, tags=["base_llm"])
# Conversation starters for the 1st screen
@cl.set_starters
async def set_starters():
    """Return the starter prompts offered on the initial, empty chat screen.

    Each starter is described as a (label, message, icon path) triple and
    turned into a ``cl.Starter``.
    """
    starter_specs = (
        (
            "Plan your quick daily meals",
            "Give me ideas for making an easy weeknight dinner that takes less than 25 minutes to prepare",
            "/public/meals4.svg",
        ),
        (
            "Ideas for special occasions",
            "What are good Middle Eastern dishes to make for Thanksgiving?",
            "/public/occasion4.svg",
        ),
        (
            "Use ingredients you have",
            "Suggest Spanish recipes that are good for the summer that makes use of tomatoes",
            "/public/ingredients4.svg",
        ),
    )
    return [
        cl.Starter(label=label, message=prompt, icon=icon_path)
        for label, prompt, icon_path in starter_specs
    ]
# Chat Start Function: build a LangGraph workflow for each new chat session.
@cl.on_chat_start
async def start_chat():
    """Initialise per-session state when a user opens a new chat.

    Builds a fresh workflow from the module-level ``base_llm`` and
    ``power_llm`` and stores it in the Chainlit user session under
    ``"langgraph_chain"``. A new random UUID string is stored under
    ``"thread_id"``; it is later passed as the graph's ``thread_id`` config
    (presumably to key conversation memory - verify in utils.graph).
    The user session is unique to each user and lives in server memory.
    """
    session = cl.user_session
    session.set("langgraph_chain", generate_workflow(base_llm, power_llm))
    new_thread_id = str(uuid.uuid4())
    session.set("thread_id", new_thread_id)
# Message Handling Function: answer a text message with the session's LangGraph workflow.
@cl.on_message
async def main(message: cl.Message):
    """Answer a text message received from the user's session.

    Runs the session's LangGraph workflow (stored by ``start_chat``) on
    ``message.content`` and sends the reply message once the stream ends.
    The initially empty reply message is passed to the graph through the
    ``cl_msg`` config key, presumably so graph nodes can write/stream into
    it - verify in utils.graph.
    """
    langgraph_chain = cl.user_session.get("langgraph_chain")
    thread_id = cl.user_session.get("thread_id")
    msg = cl.Message(content="")
    langgraph_config = {"configurable": {"thread_id": thread_id, "cl_msg": msg}}
    async for output in langgraph_chain.astream({"question": message.content}, langgraph_config):
        # Each streamed update is a mapping keyed by node name; only the node
        # names are logged, so iterate keys instead of unused (key, value) pairs.
        for node_name in output:
            print(f"================== Node: '{node_name}':")
    await msg.send()
# Speech-to-Text Function: transcribe recorded audio with Whisper.
@cl.step(type="tool")
async def speech_to_text(audio_file):
    """Transcribe ``audio_file`` with the ``whisper-1`` model and return the text.

    ``audio_file`` is forwarded unchanged as the transcription API's ``file``
    argument; the caller in this file passes a (filename, bytes, mime type)
    tuple.
    """
    transcription = await client.audio.transcriptions.create(
        model="whisper-1",
        file=audio_file,
    )
    return transcription.text
# Generate Text Answer Function: stream an answer to a transcription from the session's base RAG chain.
@cl.step(type="tool")
async def generate_text_answer(transcription):
    """Stream an answer to ``transcription`` into a new message and return its text.

    Uses the chain stored under ``"base_rag_chain"`` in the user session.
    NOTE(review): nothing in this file stores that key (``start_chat`` stores
    ``"langgraph_chain"``) and nothing here calls this function - it may be
    dead code; confirm before relying on it.

    Raises:
        RuntimeError: if no ``"base_rag_chain"`` is present in the user session.
    """
    base_rag_chain = cl.user_session.get("base_rag_chain")
    if base_rag_chain is None:
        # Fail with an actionable message instead of AttributeError on None.
        raise RuntimeError("No 'base_rag_chain' found in the user session")
    msg = cl.Message(content="")
    async for chunk in base_rag_chain.astream(
        {"question": transcription},
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        # Only stream chunks that carry a textual "response" field.
        if isinstance(chunk, dict) and isinstance(chunk.get("response"), str):
            await msg.stream_token(chunk["response"])
    # Finalize the streamed message; without this it is never completed in the UI.
    await msg.send()
    return msg.content
# Audio Chunk Function: accumulate audio chunks from the user's microphone as they arrive.
@cl.on_audio_chunk
async def on_audio_chunk(chunk: cl.AudioChunk):
    """Append one incoming audio chunk to the session's in-memory audio buffer.

    On the first chunk of a recording (``chunk.isStart``) a new ``BytesIO``
    buffer is created and stored with the chunk's MIME type. Every chunk's
    bytes are then written to that buffer. The complete recording is
    transcribed afterwards in the audio-end handler.
    """
    if chunk.isStart:
        buffer = BytesIO()
        # Whisper needs a file name to recognize the file type. Derive the
        # extension from the MIME subtype, dropping any parameters such as
        # ";codecs=opus" that would otherwise end up in the extension.
        subtype = chunk.mimeType.split("/")[1].split(";")[0].strip()
        buffer.name = f"input_audio.{subtype}"
        # Initialize the session for a new audio stream
        cl.user_session.set("audio_buffer", buffer)
        cl.user_session.set("audio_mime_type", chunk.mimeType)

    # For now, write the chunks to a buffer and transcribe the whole audio at the end
    buffer = cl.user_session.get("audio_buffer")
    if buffer is None:
        # A chunk arrived without a preceding start chunk: there is no
        # recording to append to, so drop it instead of crashing on None.
        return
    buffer.write(chunk.data)
# Audio End Function: transcribe the finished recording and answer it with the LangGraph workflow.
@cl.on_audio_end
async def on_audio_end(elements: list[ElementBased]):
    """Handle the end of a voice recording.

    Steps:
    1. Echo the recorded audio, plus any ``elements``, into the chat as a
       user message.
    2. Transcribe the audio with Whisper.
    3. Run the transcription through the session's LangGraph workflow and
       send the reply.

    Returns early, doing nothing, if no audio was buffered for this session.
    """
    audio_buffer = cl.user_session.get("audio_buffer")
    if audio_buffer is None:
        # No chunks were received for this recording; nothing to process.
        return
    audio_buffer.seek(0)  # Move the file pointer to the beginning
    audio_file = audio_buffer.read()
    audio_mime_type: str = cl.user_session.get("audio_mime_type")
    # The bytes have been read; release the recording from the session (the
    # next recording creates a fresh buffer on its start chunk).
    cl.user_session.set("audio_buffer", None)

    input_audio_el = cl.Audio(mime=audio_mime_type, content=audio_file, name=audio_buffer.name)
    await cl.Message(author="You", type="user_message", content="", elements=[input_audio_el, *elements]).send()

    # The (filename, bytes, mime type) tuple lets Whisper identify the format.
    whisper_input = (audio_buffer.name, audio_file, audio_mime_type)
    transcription = await speech_to_text(whisper_input)
    print("Transcription: ", transcription)

    langgraph_chain = cl.user_session.get("langgraph_chain")
    thread_id = cl.user_session.get("thread_id")
    msg = cl.Message(content="")
    langgraph_config = {"configurable": {"thread_id": thread_id, "cl_msg": msg}}
    async for output in langgraph_chain.astream({"question": transcription}, langgraph_config):
        # Log only the names of the nodes that produced each streamed update.
        for node_name in output:
            print(f"================== Node: '{node_name}':")
    await msg.send()