Spaces:
Runtime error
Runtime error
Update pages/5_Gemini-Chat.py
Browse files- pages/5_Gemini-Chat.py +75 -23
pages/5_Gemini-Chat.py
CHANGED
@@ -5,7 +5,7 @@ import base64
|
|
5 |
import uuid
|
6 |
from gtts import gTTS
|
7 |
import google.generativeai as genai
|
8 |
-
from io import BytesIO
|
9 |
|
10 |
# Load the Gemini API key from the environment instead of hard-coding it.
# A key committed to source is readable by anyone who can see the file; the
# key that used to be embedded here should be treated as compromised and revoked.
import os  # imported here because the file's import header is not fully visible

api_key = os.environ.get("GOOGLE_API_KEY", "")  # empty string when the variable is unset
|
@@ -39,17 +39,12 @@ safety_settings = [
|
|
39 |
|
40 |
# Initialize session state
|
41 |
if 'chat_history' not in st.session_state:
|
42 |
-
st.session_state
|
43 |
if 'file_uploader_key' not in st.session_state:
|
44 |
-
st.session_state
|
45 |
|
46 |
st.title("Gemini Chatbot")
|
47 |
|
48 |
-
# Displaying the system message for users
|
49 |
-
st.markdown("""
|
50 |
-
**AI Planner System Prompt:** As the AI Planner, your primary task is to assist in the development of a coherent and engaging book. You will be responsible for organizing the overall structure, defining the plot or narrative, and outlining the chapters or sections. To accomplish this, you will need to use your understanding of storytelling principles and genre conventions, as well as any specific information provided by the user, to create a well-structured framework for the book.
|
51 |
-
""")
|
52 |
-
|
53 |
# Helper functions for image processing and chat history management
|
54 |
def get_image_base64(image):
|
55 |
image = image.convert("RGB")
|
@@ -59,11 +54,11 @@ def get_image_base64(image):
|
|
59 |
return img_str
|
60 |
|
61 |
def clear_conversation():
|
62 |
-
st.session_state
|
63 |
-
st.session_state
|
64 |
|
65 |
def display_chat_history():
|
66 |
-
for entry in st.session_state
|
67 |
role = entry["role"]
|
68 |
parts = entry["parts"][0]
|
69 |
if 'text' in parts:
|
@@ -75,34 +70,50 @@ def get_chat_history_str():
|
|
75 |
chat_history_str = "\n".join(
|
76 |
f"{entry['role'].title()}: {part['text']}" if 'text' in part
|
77 |
else f"{entry['role'].title()}: (Image)"
|
78 |
-
for entry in st.session_state
|
79 |
for part in entry['parts']
|
80 |
)
|
81 |
return chat_history_str
|
82 |
|
83 |
-
#
|
84 |
def send_message():
|
85 |
user_input = st.session_state.user_input
|
|
|
86 |
prompts = []
|
87 |
prompt_parts = []
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
if user_input:
|
90 |
prompts.append(user_input)
|
91 |
-
st.session_state
|
|
|
92 |
prompt_parts.append({"text": user_input})
|
93 |
|
94 |
-
#
|
95 |
if uploaded_files:
|
96 |
for uploaded_file in uploaded_files:
|
97 |
base64_image = get_image_base64(Image.open(uploaded_file))
|
98 |
prompts.append("[Image]")
|
99 |
prompt_parts.append({"data": base64_image, "mime_type": "image/jpeg"})
|
100 |
-
st.session_state
|
101 |
"role": "user",
|
102 |
"parts": [{"mime_type": uploaded_file.type, "data": base64_image}]
|
103 |
})
|
104 |
|
|
|
105 |
use_vision_model = any(part.get('mime_type') == 'image/jpeg' for part in prompt_parts)
|
|
|
|
|
106 |
model_name = 'gemini-pro-vision' if use_vision_model else 'gemini-pro'
|
107 |
model = genai.GenerativeModel(
|
108 |
model_name=model_name,
|
@@ -111,35 +122,75 @@ def send_message():
|
|
111 |
)
|
112 |
chat_history_str = "\n".join(prompts)
|
113 |
if use_vision_model:
|
|
|
114 |
generated_prompt = {"role": "user", "parts": prompt_parts}
|
115 |
else:
|
|
|
116 |
generated_prompt = {"role": "user", "parts": [{"text": chat_history_str}]}
|
117 |
|
118 |
response = model.generate_content([generated_prompt])
|
119 |
response_text = response.text if hasattr(response, "text") else "No response text found."
|
120 |
|
|
|
121 |
if response_text:
|
122 |
-
st.session_state
|
|
|
|
|
123 |
tts = gTTS(text=response_text, lang='en')
|
124 |
tts_file = BytesIO()
|
125 |
tts.write_to_fp(tts_file)
|
126 |
tts_file.seek(0)
|
127 |
st.audio(tts_file, format='audio/mp3')
|
128 |
|
|
|
129 |
st.session_state.user_input = ''
|
|
|
|
|
|
|
|
|
130 |
display_chat_history()
|
131 |
|
132 |
-
#
|
133 |
-
user_input = st.text_area(
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
clear_button = st.button("Clear Conversation", on_click=clear_conversation)
|
137 |
|
|
|
138 |
def download_chat_history():
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
-
|
|
|
142 |
|
|
|
143 |
st.markdown(
|
144 |
"""
|
145 |
<script>
|
@@ -155,3 +206,4 @@ st.markdown(
|
|
155 |
""",
|
156 |
unsafe_allow_html=True
|
157 |
)
|
|
|
|
5 |
import uuid
|
6 |
from gtts import gTTS
|
7 |
import google.generativeai as genai
|
8 |
+
from io import BytesIO # Import BytesIO
|
9 |
|
10 |
# Load the Gemini API key from the environment instead of hard-coding it.
# A key committed to source is readable by anyone who can see the file; the
# key that used to be embedded here should be treated as compromised and revoked.
import os  # imported here because the file's import header is not fully visible

api_key = os.environ.get("GOOGLE_API_KEY", "")  # empty string when the variable is unset
|
|
|
39 |
|
40 |
# Seed the per-session defaults. Each entry pairs a session-state key with a
# zero-argument factory that is only called when the key is still absent.
for _state_key, _make_default in (
    ('chat_history', list),
    ('file_uploader_key', lambda: str(uuid.uuid4())),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
|
45 |
|
46 |
st.title("Gemini Chatbot")
|
47 |
|
|
|
|
|
|
|
|
|
|
|
48 |
# Helper functions for image processing and chat history management
|
49 |
def get_image_base64(image):
|
50 |
image = image.convert("RGB")
|
|
|
54 |
return img_str
|
55 |
|
56 |
def clear_conversation():
    """Reset the chat: issue a new uploader key and empty the stored history."""
    fresh_uploader_key = str(uuid.uuid4())
    st.session_state['file_uploader_key'] = fresh_uploader_key
    st.session_state['chat_history'] = []
|
59 |
|
60 |
def display_chat_history():
|
61 |
+
for entry in st.session_state['chat_history']:
|
62 |
role = entry["role"]
|
63 |
parts = entry["parts"][0]
|
64 |
if 'text' in parts:
|
|
|
70 |
chat_history_str = "\n".join(
|
71 |
f"{entry['role'].title()}: {part['text']}" if 'text' in part
|
72 |
else f"{entry['role'].title()}: (Image)"
|
73 |
+
for entry in st.session_state['chat_history']
|
74 |
for part in entry['parts']
|
75 |
)
|
76 |
return chat_history_str
|
77 |
|
78 |
+
# Send message function with TTS integration
|
79 |
def send_message():
|
80 |
user_input = st.session_state.user_input
|
81 |
+
uploaded_files = st.session_state.uploaded_files
|
82 |
prompts = []
|
83 |
prompt_parts = []
|
84 |
|
85 |
+
# Populate the prompts list with the existing chat history
|
86 |
+
for entry in st.session_state['chat_history']:
|
87 |
+
for part in entry['parts']:
|
88 |
+
if 'text' in part:
|
89 |
+
prompts.append(part['text'])
|
90 |
+
elif 'data' in part:
|
91 |
+
# Add the image in base64 format to prompt_parts for vision model
|
92 |
+
prompt_parts.append({"data": part['data'], "mime_type": "image/jpeg"})
|
93 |
+
prompts.append("[Image]")
|
94 |
+
|
95 |
+
# Append the user input to the prompts list
|
96 |
if user_input:
|
97 |
prompts.append(user_input)
|
98 |
+
st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
|
99 |
+
# Also add the user text input to prompt_parts
|
100 |
prompt_parts.append({"text": user_input})
|
101 |
|
102 |
+
# Handle uploaded files
|
103 |
if uploaded_files:
|
104 |
for uploaded_file in uploaded_files:
|
105 |
base64_image = get_image_base64(Image.open(uploaded_file))
|
106 |
prompts.append("[Image]")
|
107 |
prompt_parts.append({"data": base64_image, "mime_type": "image/jpeg"})
|
108 |
+
st.session_state['chat_history'].append({
|
109 |
"role": "user",
|
110 |
"parts": [{"mime_type": uploaded_file.type, "data": base64_image}]
|
111 |
})
|
112 |
|
113 |
+
# Determine if vision model should be used
|
114 |
use_vision_model = any(part.get('mime_type') == 'image/jpeg' for part in prompt_parts)
|
115 |
+
|
116 |
+
# Set up the model and generate a response
|
117 |
model_name = 'gemini-pro-vision' if use_vision_model else 'gemini-pro'
|
118 |
model = genai.GenerativeModel(
|
119 |
model_name=model_name,
|
|
|
122 |
)
|
123 |
chat_history_str = "\n".join(prompts)
|
124 |
if use_vision_model:
|
125 |
+
# Include text and images for vision model
|
126 |
generated_prompt = {"role": "user", "parts": prompt_parts}
|
127 |
else:
|
128 |
+
# Include text only for standard model
|
129 |
generated_prompt = {"role": "user", "parts": [{"text": chat_history_str}]}
|
130 |
|
131 |
response = model.generate_content([generated_prompt])
|
132 |
response_text = response.text if hasattr(response, "text") else "No response text found."
|
133 |
|
134 |
+
# After generating the response from the model, append it to the chat history
|
135 |
if response_text:
|
136 |
+
st.session_state['chat_history'].append({"role": "model", "parts":[{"text": response_text}]})
|
137 |
+
|
138 |
+
# Convert the response text to speech
|
139 |
tts = gTTS(text=response_text, lang='en')
|
140 |
tts_file = BytesIO()
|
141 |
tts.write_to_fp(tts_file)
|
142 |
tts_file.seek(0)
|
143 |
st.audio(tts_file, format='audio/mp3')
|
144 |
|
145 |
+
# Clear the input fields after sending the message
|
146 |
st.session_state.user_input = ''
|
147 |
+
st.session_state.uploaded_files = []
|
148 |
+
st.session_state.file_uploader_key = str(uuid.uuid4())
|
149 |
+
|
150 |
+
# Display the updated chat history
|
151 |
display_chat_history()
|
152 |
|
153 |
+
# Free-text prompt, stored under the "user_input" key that send_message
# reads and then resets to an empty string.
user_input = st.text_area("Enter your message here:", value="", key="user_input")

# Image attachments. The widget key is taken from session state; both
# send_message and clear_conversation replace it with a new UUID
# (presumably so the uploader comes back empty -- confirm against Streamlit docs).
uploaded_files = st.file_uploader(
    "Upload images:",
    accept_multiple_files=True,
    type=["png", "jpg", "jpeg"],
    key=st.session_state.file_uploader_key,
)

# Action buttons wired to their callbacks.
send_button = st.button("Send", on_click=send_message)
clear_button = st.button("Clear Conversation", on_click=clear_conversation)
|
176 |
|
177 |
+
# Function to download the chat history as a text file
|
178 |
def download_chat_history():
    """Return the chat transcript text produced by get_chat_history_str()."""
    return get_chat_history_str()
|
181 |
+
|
182 |
+
# Offer the current transcript as a plain-text file download.
download_button = st.download_button(
    data=download_chat_history(),
    label="Download Chat",
    mime="text/plain",
    file_name="chat_history.txt",
)

# Mirror the uploader's current selection into session state, which is
# where send_message looks for the files to attach.
st.session_state.uploaded_files = uploaded_files
|
192 |
|
193 |
+
# JavaScript to capture the Ctrl+Enter event and trigger a button click
|
194 |
st.markdown(
|
195 |
"""
|
196 |
<script>
|
|
|
206 |
""",
|
207 |
unsafe_allow_html=True
|
208 |
)
|
209 |
+
|