import base64
import io
import os
import time
from typing import Dict, List, Optional, Union

import gradio as gr
from google import genai
from google.genai import types  # Types module from the google-genai SDK
from PIL import Image

# Retrieve the Google GenAI API key from the environment variables.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")

# Initialize the client once so it can be reused across functions.
CLIENT = genai.Client(api_key=GOOGLE_API_KEY)
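
# A valid key must be present in the environment before launching the app;
# a minimal sketch of the required setup (the value below is a placeholder,
# not a real credential):
#
#   export GOOGLE_API_KEY="your-api-key"
#
# If the variable is unset, GOOGLE_API_KEY is None and the first client call
# will fail with an authentication error.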

# General constants for the UI.
TITLE = """<h1 align="center">Gemini 2.0 Pro Multi-modal Chatbot</h1>"""
AVATAR_IMAGES = (None, "https://media.roboflow.com/spaces/gemini-icon.png")
IMAGE_WIDTH = 512

def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
    """
    Convert a comma-separated string of stop sequences into a list.

    Parameters:
        stop_sequences (str): A string containing stop sequences separated by commas.

    Returns:
        Optional[List[str]]: A list of trimmed stop sequences if provided; otherwise, None.
    """
    if not stop_sequences:
        return None
    return [sequence.strip() for sequence in stop_sequences.split(",")]
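
# For example, preprocess_stop_sequences("###, END") returns ["###", "END"],
# while an empty string returns None. (This helper is defined as a utility and
# is not currently wired into the UI below.)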

def preprocess_image(image: Image.Image) -> Image.Image:
    """
    Resize an image to a fixed width while maintaining the aspect ratio.

    Parameters:
        image (Image.Image): The original image.

    Returns:
        Image.Image: The resized image with width fixed at IMAGE_WIDTH.
    """
    image_height = int(image.height * IMAGE_WIDTH / image.width)
    return image.resize((IMAGE_WIDTH, image_height))
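
# For example, a 1024x768 upload becomes 512x384: the width is pinned to
# IMAGE_WIDTH (512) and the height is scaled to preserve the aspect ratio.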

def image_to_base64_html_from_pil(image: Image.Image, max_width: int = 150) -> str:
    """
    Convert a PIL Image to an HTML <img> tag with base64-encoded image data.

    Parameters:
        image (Image.Image): The image to encode.
        max_width (int): Maximum width (in pixels) for the displayed image.

    Returns:
        str: An HTML string with the embedded image.
    """
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    b64_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return (
        f'<img src="data:image/jpeg;base64,{b64_data}" alt="Uploaded Image" '
        f'style="max-width:{max_width}px;">'
    )
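
# Saving as JPEG requires an RGB image, which is why callers convert uploads
# with .convert("RGB") first. The returned markup looks like
# '<img src="data:image/jpeg;base64,/9j/4AAQ..." alt="Uploaded Image" ...>'.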

def preprocess_chat_history_messages(
    chat_history: List[Union[dict, gr.ChatMessage]],
) -> List[Dict[str, Union[str, List[str]]]]:
    """
    Normalize chat history messages into a consistent list of dictionaries.

    Each message (whether a dict or a gr.ChatMessage) is converted into a
    dictionary containing a role and a list of parts (message content).

    Parameters:
        chat_history (List[Union[dict, gr.ChatMessage]]): The conversation history.

    Returns:
        List[Dict[str, Union[str, List[str]]]]: A normalized list of messages.
    """
    messages = []
    for msg in chat_history:
        if isinstance(msg, dict):
            content = msg.get("content")
            role = msg.get("role")
        else:
            content = msg.content
            role = msg.role
        if content is not None:
            # The Gemini API uses "model" where Gradio uses "assistant".
            role = "model" if role == "assistant" else role
            messages.append({"role": role, "parts": [content]})
    return messages
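
# The normalized shape mirrors the role/parts structure used by the Gemini API,
# e.g. [{"role": "user", "parts": ["Hi"]}, {"role": "model", "parts": ["Hello!"]}].
# Like preprocess_stop_sequences, this utility is not referenced by the UI below.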

def chat_history_to_prompt(chat_history: List[Union[dict, gr.ChatMessage]]) -> str:
    """
    Convert the entire chat conversation into a single text prompt.

    Each message is prefixed with "User:" or "Assistant:" to form a full
    conversation transcript.

    Parameters:
        chat_history (List[Union[dict, gr.ChatMessage]]): The conversation history.

    Returns:
        str: A string that concatenates the conversation history.
    """
    conversation = ""
    for msg in chat_history:
        content = get_message_content(msg)
        role = msg.get("role") if isinstance(msg, dict) else msg.role
        if role in ["assistant", "model"]:
            conversation += f"Assistant: {content}\n"
        else:
            conversation += f"User: {content}\n"
    return conversation
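
# A short history therefore flattens to a transcript such as:
#
#   User: What is 2 + 2?
#   Assistant: 4
#
# with a trailing newline after each turn.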

def upload(files: Optional[List[str]], chatbot: List[Union[dict, gr.ChatMessage]]):
    """
    Process uploaded image files: resize them, convert each to an HTML <img> tag
    (with base64 data), and append it as a new user message to the chat history.

    Parameters:
        files (Optional[List[str]]): List of image file paths.
        chatbot (List[Union[dict, gr.ChatMessage]]): The current conversation history.

    Returns:
        List[Union[dict, gr.ChatMessage]]: Updated conversation history.
    """
    # files is Optional, so guard against None before iterating.
    for file in files or []:
        image = Image.open(file).convert("RGB")
        image = preprocess_image(image)
        image_html = image_to_base64_html_from_pil(image)
        chatbot.append(gr.ChatMessage(role="user", content=image_html))
    return chatbot

def upload_audio(
    files: Optional[List[str]], chatbot: List[Union[dict, gr.ChatMessage]]
):
    """
    Process uploaded audio files: read and base64-encode them, wrap the data in
    an HTML audio player, and append it as a new user message.

    Parameters:
        files (Optional[List[str]]): List of audio file paths.
        chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.

    Returns:
        List[Union[dict, gr.ChatMessage]]: The updated chatbot history.
    """
    for file in files or []:
        with open(file, "rb") as f:
            audio_bytes = f.read()
        b64_data = base64.b64encode(audio_bytes).decode("utf-8")
        # "audio/mpeg" is the standard MIME type for MP3 in HTML <source> tags.
        audio_html = f"""<audio controls style="max-width:150px;">
            <source src="data:audio/mpeg;base64,{b64_data}" type="audio/mpeg">
            Your browser does not support the audio element.
        </audio>"""
        chatbot.append(gr.ChatMessage(role="user", content=audio_html))
    return chatbot

def upload_document(
    files: Optional[List[str]], chatbot: List[Union[dict, gr.ChatMessage]]
):
    """
    Process uploaded document files (assumed to be PDFs) and add a notification
    message (an HTML snippet) indicating that the document has been uploaded.
    The document bytes themselves are sent to Gemini later, in bot().

    Parameters:
        files (Optional[List[str]]): List of document file paths.
        chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.

    Returns:
        List[Union[dict, gr.ChatMessage]]: The updated chatbot history.
    """
    for file in files or []:
        filename = os.path.basename(file)
        doc_html = f"<p>📄 Document uploaded: {filename}</p>"
        chatbot.append(gr.ChatMessage(role="user", content=doc_html))
    return chatbot

def user(text_prompt: str, chatbot: List[gr.ChatMessage]):
    """
    Append a new user text message to the chat history.

    Parameters:
        text_prompt (str): The input text provided by the user.
        chatbot (List[gr.ChatMessage]): The existing conversation history.

    Returns:
        Tuple[str, List[gr.ChatMessage]]: A tuple of an empty string (clearing
        the prompt box) and the updated conversation history.
    """
    if text_prompt:
        chatbot.append(gr.ChatMessage(role="user", content=text_prompt))
    return "", chatbot

def get_message_content(msg):
    """
    Retrieve the content of a message that can be either a dict or a gr.ChatMessage.

    Parameters:
        msg (Union[dict, gr.ChatMessage]): The message object.

    Returns:
        str: The textual content of the message.
    """
    if isinstance(msg, dict):
        return msg.get("content", "")
    return msg.content

def _stream_to_chatbot(response, chatbot):
    """
    Stream a generate_content_stream response into the last chatbot message,
    yielding the updated history after each small slice of text.
    """
    for chunk in response:
        if not chunk.text:
            # Chunks can arrive without text (e.g. metadata-only); skip them.
            continue
        # Emit the text in 10-character slices so the UI streams smoothly.
        for i in range(0, len(chunk.text), 10):
            section = chunk.text[i : i + 10]
            if isinstance(chatbot[-1], dict):
                chatbot[-1]["content"] += section
            else:
                chatbot[-1].content += section
            time.sleep(0.01)
            yield chatbot


def bot(
    image_files: Optional[List[str]],
    audio_files: Optional[List[str]],
    doc_files: Optional[List[str]],
    chatbot: List[Union[dict, gr.ChatMessage]],
):
    """
    Generate a chatbot response from Gemini 2.0 based on the provided inputs.

    This function supports three branches:
      1. Document branch: when doc_files are provided.
      2. Multi-modal branch: when image and/or audio files are provided.
      3. Text-only conversation branch.

    All branches use generate_content_stream to yield incremental responses.

    Parameters:
        image_files (Optional[List[str]]): List of image file paths.
        audio_files (Optional[List[str]]): List of audio file paths.
        doc_files (Optional[List[str]]): List of document file paths.
        chatbot (List[Union[dict, gr.ChatMessage]]): The conversation history.

    Yields:
        List[Union[dict, gr.ChatMessage]]: The updated conversation history with
        streamed responses.
    """
    if len(chatbot) == 0:
        # Yield once so Gradio still receives an output for an empty history.
        yield chatbot
        return

    # Append a placeholder for the assistant's response.
    chatbot.append(gr.ChatMessage(role="assistant", content=""))

    generation_config = types.GenerateContentConfig(
        temperature=0.4,
        max_output_tokens=4096,
        top_k=32,
        top_p=1,
    )

    # Branch 1: Document uploads.
    if doc_files:
        # chatbot[-1] is the assistant placeholder, so chatbot[-2] holds the
        # most recent user message.
        prev_msg_content = get_message_content(chatbot[-2]) if len(chatbot) >= 2 else ""
        prompt = [prev_msg_content] if prev_msg_content else []
        doc_parts = []
        for file in doc_files:
            with open(file, "rb") as f:
                doc_bytes = f.read()
            doc_parts.append(
                types.Part.from_bytes(
                    data=doc_bytes,
                    mime_type="application/pdf",
                )
            )
        # Combine document parts and the previous text prompt.
        contents = doc_parts + prompt
    # Branch 2: Image or audio uploads.
    elif image_files or audio_files:
        prev_msg_content = get_message_content(chatbot[-2]) if len(chatbot) >= 2 else ""
        text_prompt = [prev_msg_content] if prev_msg_content else []
        image_prompt = (
            [Image.open(file).convert("RGB") for file in image_files]
            if image_files
            else []
        )
        audio_prompt = []
        if audio_files:
            for file in audio_files:
                with open(file, "rb") as f:
                    audio_bytes = f.read()
                audio_prompt.append(
                    types.Part.from_bytes(
                        data=audio_bytes,
                        mime_type="audio/mp3",
                    )
                )
        # Combine all inputs into a single multi-modal prompt.
        contents = text_prompt + image_prompt + audio_prompt
    # Branch 3: Text-only conversation.
    else:
        contents = [chat_history_to_prompt(chatbot)]

    # Use the streaming endpoint and surface the text incrementally.
    response = CLIENT.models.generate_content_stream(
        model="gemini-2.0-pro-exp-02-05",
        contents=contents,
        config=generation_config,
    )
    yield from _stream_to_chatbot(response, chatbot)
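
# Note: "gemini-2.0-pro-exp-02-05" is an experimental model ID and may be
# retired; switching to another Gemini model (for example, a stable
# "gemini-2.0-flash" release) only requires changing the model= argument.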

def run_code_execution(code_prompt: str, chatbot: List[Union[dict, gr.ChatMessage]]):
    """
    Append the user's code execution query to the chat history, then call Gemini
    with code execution enabled. The results (including any generated code and
    execution output) are appended as a new assistant message.
    """
    # Only add a user message if there is content.
    if code_prompt.strip():
        chatbot.append(gr.ChatMessage(role="user", content=code_prompt))

    # Append an empty assistant message to fill with the code execution response.
    chatbot.append(gr.ChatMessage(role="assistant", content=""))

    generation_config = types.GenerateContentConfig(
        # Instantiate the code-execution tool so Gemini can run generated code.
        tools=[types.Tool(code_execution=types.ToolCodeExecution())]
    )
    response = CLIENT.models.generate_content(
        model="gemini-2.0-pro-exp-02-05",
        contents=code_prompt,
        config=generation_config,
    )

    output_text = ""
    for part in response.candidates[0].content.parts:
        if part.text is not None:
            output_text += f"{part.text}\n"
        if part.executable_code is not None:
            # Display the generated code in a markdown code block.
            output_text += (
                f"\n**Generated Code:**\n```python\n{part.executable_code.code}\n```\n"
            )
        if part.code_execution_result is not None:
            output_text += (
                f"\n**Output:**\n```\n{part.code_execution_result.output}\n```\n"
            )
        if part.inline_data is not None:
            # Inline data (e.g. a generated plot) is re-encoded as a PNG <img> tag.
            image_data = base64.b64decode(part.inline_data.data)
            image = Image.open(io.BytesIO(image_data))
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            b64_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
            output_text += (
                f'\n<img src="data:image/png;base64,{b64_data}" '
                f'alt="Inline Image" style="max-width:300px;"/>\n'
            )
        # Visually separate consecutive response parts.
        output_text += "\n---\n"

    # Update the last assistant message with the code execution result.
    if isinstance(chatbot[-1], dict):
        chatbot[-1]["content"] = output_text
    else:
        chatbot[-1].content = output_text

    # Clear the text prompt after processing.
    return "", chatbot

# Define the Gradio UI components.
chatbot_component = gr.Chatbot(
    label="Gemini 2.0 Pro",
    type="messages",  # Use message objects (dict / gr.ChatMessage) as history.
    bubble_full_width=False,
    avatar_images=AVATAR_IMAGES,
    scale=2,
    height=400,
)
text_prompt_component = gr.Textbox(
    placeholder="Enter your message or code query here...",
    show_label=False,
    autofocus=True,
    scale=19,
)
upload_button_component = gr.UploadButton(
    label="Upload Images",
    file_count="multiple",
    file_types=["image"],
    scale=1,
)
upload_audio_button_component = gr.UploadButton(
    label="Upload Audio",
    file_count="multiple",
    file_types=["audio"],
    scale=1,
)
upload_doc_button_component = gr.UploadButton(
    label="Upload Documents",
    file_count="multiple",
    file_types=[".pdf"],
    scale=1,
)
run_button_component = gr.Button(value="Run", variant="primary", scale=1, min_width=60)
run_code_execution_button = gr.Button(
    value="Run Code Execution", variant="secondary", scale=1
)

# Define input lists for chaining button callbacks.
user_inputs = [text_prompt_component, chatbot_component]
bot_inputs = [
    upload_button_component,
    upload_audio_button_component,
    upload_doc_button_component,
    chatbot_component,
]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML(TITLE)
    with gr.Column():
        chatbot_component.render()
        with gr.Row(equal_height=True):
            text_prompt_component.render()
            run_button_component.render()
        with gr.Row():
            # Render the file-upload buttons and the code execution button in a single row.
            upload_button_component.render()
            upload_audio_button_component.render()
            upload_doc_button_component.render()
            run_code_execution_button.render()

    # When the Run button is clicked, first process the user text, then stream a response.
    run_button_component.click(
        fn=user,
        inputs=user_inputs,
        outputs=[text_prompt_component, chatbot_component],
        queue=False,
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=[chatbot_component],
    )

    # Allow submission with the Enter key.
    text_prompt_component.submit(
        fn=user,
        inputs=user_inputs,
        outputs=[text_prompt_component, chatbot_component],
        queue=False,
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=[chatbot_component],
    )

    # Handle image uploads.
    upload_button_component.upload(
        fn=upload,
        inputs=[upload_button_component, chatbot_component],
        outputs=[chatbot_component],
        queue=False,
    )

    # Handle audio uploads.
    upload_audio_button_component.upload(
        fn=upload_audio,
        inputs=[upload_audio_button_component, chatbot_component],
        outputs=[chatbot_component],
        queue=False,
    )

    # Handle document uploads.
    upload_doc_button_component.upload(
        fn=upload_document,
        inputs=[upload_doc_button_component, chatbot_component],
        outputs=[chatbot_component],
        queue=False,
    )

    # When the Code Execution button is clicked, run the prompt with the
    # code-execution tool and display the result (this call is not streamed).
    run_code_execution_button.click(
        fn=run_code_execution,
        inputs=[text_prompt_component, chatbot_component],
        outputs=[text_prompt_component, chatbot_component],
        queue=False,
    )

# Launch the demo interface with queuing enabled.
demo.queue(max_size=99, api_open=False).launch(debug=False, pwa=True, show_error=True)