# Hugging Face Spaces artifact header:
#   GrokAgenticWorkforce / pages/5_Gemini-Chat.py
#   uploaded by eaglelandsonce — commit c5c69f5 (verified),
#   "Update pages/5_Gemini-Chat.py"
import streamlit as st
from PIL import Image
import io
import base64
import uuid
from gtts import gTTS
import google.generativeai as genai
from io import BytesIO # Import BytesIO
# Configure the Gemini API key.
# SECURITY FIX: an API key was previously hardcoded here — a key committed to
# a public repository is compromised and must be rotated.  Read the key from
# the environment instead (set GOOGLE_API_KEY before launching the app).
import os

api_key = os.environ.get("GOOGLE_API_KEY", "")
genai.configure(api_key=api_key)

# Generation parameters shared by every model call on this page.
generation_config = genai.GenerationConfig(
    temperature=0.9,         # high temperature -> more varied, creative replies
    max_output_tokens=3000   # cap on reply length
)
# Disable all four adjustable Gemini safety filters so responses are never
# blocked by content moderation.
safety_settings = [
    {"category": category, "threshold": "BLOCK_NONE"}
    for category in (
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_HARASSMENT",
    )
]
# First-run initialisation of per-session state.
st.session_state.setdefault('chat_history', [])  # list of {"role", "parts"} turns
if 'file_uploader_key' not in st.session_state:
    # Random widget key: rotating it later forces Streamlit to rebuild the
    # file uploader empty.
    st.session_state['file_uploader_key'] = str(uuid.uuid4())

st.title("Gemini Chatbot")
# Helper functions for image processing and chat history management
def get_image_base64(image):
    """Encode an image as a base64 JPEG string.

    The image is first converted to RGB (JPEG has no alpha channel), saved
    into an in-memory buffer, and the buffer's bytes are base64-encoded.
    """
    rgb_image = image.convert("RGB")
    buffer = io.BytesIO()
    rgb_image.save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode()
def clear_conversation():
    """Wipe the chat history and rotate the uploader key so that widget resets."""
    st.session_state.file_uploader_key = str(uuid.uuid4())
    st.session_state.chat_history = []
def display_chat_history():
    """Render every stored chat turn in order: text as markdown, images inline.

    Only the first part of each turn is inspected, matching how turns are
    appended (one part per history entry).
    """
    for message in st.session_state['chat_history']:
        speaker = message["role"]
        first_part = message["parts"][0]
        if 'text' in first_part:
            st.markdown(f"{speaker.title()}: {first_part['text']}")
        elif 'data' in first_part:
            raw_bytes = base64.b64decode(first_part['data'])
            st.image(Image.open(io.BytesIO(raw_bytes)), caption='Uploaded Image')
def get_chat_history_str():
    """Flatten the chat history into a plain-text transcript.

    Text parts are rendered as "Role: text"; image parts appear as a
    "(Image)" placeholder.  Lines are joined with newlines.
    """
    lines = []
    for entry in st.session_state['chat_history']:
        speaker = entry['role'].title()
        for part in entry['parts']:
            if 'text' in part:
                lines.append(f"{speaker}: {part['text']}")
            else:
                lines.append(f"{speaker}: (Image)")
    return "\n".join(lines)
# Send message function with TTS integration
def send_message():
    """on_click callback for the Send button.

    Reads the current text input and uploaded images from session state,
    replays the whole chat history as the prompt, calls the appropriate
    Gemini model (vision when any image is involved, text otherwise),
    appends the reply to the history, speaks it via gTTS, and resets the
    input widgets.
    """
    user_input = st.session_state.user_input
    uploaded_files = st.session_state.uploaded_files

    prompts = []       # flat text transcript (fed to the text-only model)
    prompt_parts = []  # interleaved text/image parts (fed to the vision model)

    # Replay the stored history into both prompt representations.
    for entry in st.session_state['chat_history']:
        for part in entry['parts']:
            if 'text' in part:
                prompts.append(part['text'])
            elif 'data' in part:
                # Base64 image goes to the vision model; the text transcript
                # only carries a placeholder.
                prompt_parts.append({"data": part['data'], "mime_type": "image/jpeg"})
                prompts.append("[Image]")

    # Record the new user text in the history and in both prompt forms.
    if user_input:
        prompts.append(user_input)
        st.session_state['chat_history'].append(
            {"role": "user", "parts": [{"text": user_input}]}
        )
        prompt_parts.append({"text": user_input})

    # Record any newly uploaded images (re-encoded as base64 JPEG).
    if uploaded_files:
        for uploaded_file in uploaded_files:
            base64_image = get_image_base64(Image.open(uploaded_file))
            prompts.append("[Image]")
            prompt_parts.append({"data": base64_image, "mime_type": "image/jpeg"})
            st.session_state['chat_history'].append({
                "role": "user",
                "parts": [{"mime_type": uploaded_file.type, "data": base64_image}]
            })

    # FIX: previously a click with neither text nor images still hit the API
    # with an empty prompt; bail out instead.
    if not user_input and not uploaded_files:
        return

    # Any image anywhere in the prompt requires the vision model.
    use_vision_model = any(part.get('mime_type') == 'image/jpeg' for part in prompt_parts)
    model_name = 'gemini-pro-vision' if use_vision_model else 'gemini-pro'
    model = genai.GenerativeModel(
        model_name=model_name,
        generation_config=generation_config,
        safety_settings=safety_settings
    )

    chat_history_str = "\n".join(prompts)
    if use_vision_model:
        # Text and images together for the vision model.
        generated_prompt = {"role": "user", "parts": prompt_parts}
    else:
        # A single text part containing the whole transcript.
        generated_prompt = {"role": "user", "parts": [{"text": chat_history_str}]}
    response = model.generate_content([generated_prompt])

    # FIX: response.text is a property that RAISES (rather than being absent)
    # when the response was blocked or returned no candidates, so the old
    # hasattr(response, "text") check did not protect against that case.
    try:
        response_text = response.text
    except (ValueError, AttributeError):
        response_text = "No response text found."

    if response_text:
        st.session_state['chat_history'].append(
            {"role": "model", "parts": [{"text": response_text}]}
        )
        # Speak the reply: synthesize MP3 in memory and embed an audio player.
        tts = gTTS(text=response_text, lang='en')
        tts_file = BytesIO()
        tts.write_to_fp(tts_file)
        tts_file.seek(0)
        st.audio(tts_file, format='audio/mp3')

    # Reset the inputs.  Writing to a widget-backed key is allowed inside an
    # on_click callback (it runs before the rerun); rotating the uploader key
    # recreates that widget empty.
    st.session_state.user_input = ''
    st.session_state.uploaded_files = []
    st.session_state.file_uploader_key = str(uuid.uuid4())
# --- Page widgets -----------------------------------------------------------

# Render the conversation so far (text turns and uploaded images).
display_chat_history()

# Free-form text entry; the "user_input" key lets send_message() read the
# current value from st.session_state.
user_input = st.text_area(
    "Enter your message here:",
    value="",
    key="user_input"
)

# Image uploads for the vision model.  The widget key is a random UUID kept
# in session state; rotating it (see clear_conversation / send_message)
# forces Streamlit to recreate the uploader empty.
uploaded_files = st.file_uploader(
    "Upload images:",
    type=["png", "jpg", "jpeg"],
    accept_multiple_files=True,
    key=st.session_state.file_uploader_key
)

# Send runs as an on_click callback, i.e. before the script reruns.
send_button = st.button(
    "Send",
    on_click=send_message
)

# Reset the conversation and the uploader widget.
clear_button = st.button("Clear Conversation", on_click=clear_conversation)
# Function to download the chat history as a text file
def download_chat_history():
    """Return the plain-text chat transcript for the download button."""
    return get_chat_history_str()
# Download button for the chat history
download_button = st.download_button(
label="Download Chat",
data=download_chat_history(),
file_name="chat_history.txt",
mime="text/plain"
)
# Ensure the file_uploader widget state is tied to the randomly generated key
st.session_state.uploaded_files = uploaded_files
# JavaScript to capture the Ctrl+Enter event and trigger a button click
st.markdown(
"""
<script>
document.addEventListener('DOMContentLoaded', (event) => {
document.querySelector('.stTextArea textarea').addEventListener('keydown', function(e) {
if (e.key === 'Enter' && e.ctrlKey) {
document.querySelector('.stButton > button').click();
e.preventDefault();
}
});
});
</script>
""",
unsafe_allow_html=True
)