eaglelandsonce commited on
Commit
c5c69f5
·
verified ·
1 Parent(s): ac4d3aa

Update pages/5_Gemini-Chat.py

Browse files
Files changed (1) hide show
  1. pages/5_Gemini-Chat.py +75 -23
pages/5_Gemini-Chat.py CHANGED
@@ -5,7 +5,7 @@ import base64
5
  import uuid
6
  from gtts import gTTS
7
  import google.generativeai as genai
8
- from io import BytesIO
9
 
10
  # Set your API key
11
  api_key = "AIzaSyC70u1sN87IkoxOoIj4XCAPw97ae2LZwNM" # Replace with your actual API key
@@ -39,17 +39,12 @@ safety_settings = [
39
 
40
  # Initialize session state
41
  if 'chat_history' not in st.session_state:
42
- st.session_state.chat_history = []
43
  if 'file_uploader_key' not in st.session_state:
44
- st.session_state.file_uploader_key = str(uuid.uuid4())
45
 
46
  st.title("Gemini Chatbot")
47
 
48
- # Displaying the system message for users
49
- st.markdown("""
50
- **AI Planner System Prompt:** As the AI Planner, your primary task is to assist in the development of a coherent and engaging book. You will be responsible for organizing the overall structure, defining the plot or narrative, and outlining the chapters or sections. To accomplish this, you will need to use your understanding of storytelling principles and genre conventions, as well as any specific information provided by the user, to create a well-structured framework for the book.
51
- """)
52
-
53
  # Helper functions for image processing and chat history management
54
  def get_image_base64(image):
55
  image = image.convert("RGB")
@@ -59,11 +54,11 @@ def get_image_base64(image):
59
  return img_str
60
 
61
  def clear_conversation():
62
- st.session_state.chat_history = []
63
- st.session_state.file_uploader_key = str(uuid.uuid4())
64
 
65
  def display_chat_history():
66
- for entry in st.session_state.chat_history:
67
  role = entry["role"]
68
  parts = entry["parts"][0]
69
  if 'text' in parts:
@@ -75,34 +70,50 @@ def get_chat_history_str():
75
  chat_history_str = "\n".join(
76
  f"{entry['role'].title()}: {part['text']}" if 'text' in part
77
  else f"{entry['role'].title()}: (Image)"
78
- for entry in st.session_state.chat_history
79
  for part in entry['parts']
80
  )
81
  return chat_history_str
82
 
83
- # Function to send messages
84
  def send_message():
85
  user_input = st.session_state.user_input
 
86
  prompts = []
87
  prompt_parts = []
88
 
 
 
 
 
 
 
 
 
 
 
 
89
  if user_input:
90
  prompts.append(user_input)
91
- st.session_state.chat_history.append({"role": "user", "parts": [{"text": user_input}]})
 
92
  prompt_parts.append({"text": user_input})
93
 
94
- # Handling uploaded files directly
95
  if uploaded_files:
96
  for uploaded_file in uploaded_files:
97
  base64_image = get_image_base64(Image.open(uploaded_file))
98
  prompts.append("[Image]")
99
  prompt_parts.append({"data": base64_image, "mime_type": "image/jpeg"})
100
- st.session_state.chat_history.append({
101
  "role": "user",
102
  "parts": [{"mime_type": uploaded_file.type, "data": base64_image}]
103
  })
104
 
 
105
  use_vision_model = any(part.get('mime_type') == 'image/jpeg' for part in prompt_parts)
 
 
106
  model_name = 'gemini-pro-vision' if use_vision_model else 'gemini-pro'
107
  model = genai.GenerativeModel(
108
  model_name=model_name,
@@ -111,35 +122,75 @@ def send_message():
111
  )
112
  chat_history_str = "\n".join(prompts)
113
  if use_vision_model:
 
114
  generated_prompt = {"role": "user", "parts": prompt_parts}
115
  else:
 
116
  generated_prompt = {"role": "user", "parts": [{"text": chat_history_str}]}
117
 
118
  response = model.generate_content([generated_prompt])
119
  response_text = response.text if hasattr(response, "text") else "No response text found."
120
 
 
121
  if response_text:
122
- st.session_state.chat_history.append({"role": "model", "parts":[{"text": response_text}]})
 
 
123
  tts = gTTS(text=response_text, lang='en')
124
  tts_file = BytesIO()
125
  tts.write_to_fp(tts_file)
126
  tts_file.seek(0)
127
  st.audio(tts_file, format='audio/mp3')
128
 
 
129
  st.session_state.user_input = ''
 
 
 
 
130
  display_chat_history()
131
 
132
- # UI components for user input, file uploader, send and clear buttons
133
- user_input = st.text_area("Enter your message here:", value="", key="user_input")
134
- uploaded_files = st.file_uploader("Upload images:", type=["png", "jpg", "jpeg"], accept_multiple_files=True, key="file_uploader_key")
135
- send_button = st.button("Send", on_click=send_message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  clear_button = st.button("Clear Conversation", on_click=clear_conversation)
137
 
 
138
  def download_chat_history():
139
- return get_chat_history_str()
 
 
 
 
 
 
 
 
 
140
 
141
- download_button = st.download_button("Download Chat", data=download_chat_history(), file_name="chat_history.txt", mime="text/plain")
 
142
 
 
143
  st.markdown(
144
  """
145
  <script>
@@ -155,3 +206,4 @@ st.markdown(
155
  """,
156
  unsafe_allow_html=True
157
  )
 
 
5
  import uuid
6
  from gtts import gTTS
7
  import google.generativeai as genai
8
+ from io import BytesIO # Import BytesIO
9
 
10
  # Set your API key
11
  api_key = "AIzaSyC70u1sN87IkoxOoIj4XCAPw97ae2LZwNM" # Replace with your actual API key
 
39
 
40
  # Initialize session state
41
  if 'chat_history' not in st.session_state:
42
+ st.session_state['chat_history'] = []
43
  if 'file_uploader_key' not in st.session_state:
44
+ st.session_state['file_uploader_key'] = str(uuid.uuid4())
45
 
46
  st.title("Gemini Chatbot")
47
 
 
 
 
 
 
48
  # Helper functions for image processing and chat history management
49
  def get_image_base64(image):
50
  image = image.convert("RGB")
 
54
  return img_str
55
 
56
  def clear_conversation():
57
+ st.session_state['chat_history'] = []
58
+ st.session_state['file_uploader_key'] = str(uuid.uuid4())
59
 
60
  def display_chat_history():
61
+ for entry in st.session_state['chat_history']:
62
  role = entry["role"]
63
  parts = entry["parts"][0]
64
  if 'text' in parts:
 
70
  chat_history_str = "\n".join(
71
  f"{entry['role'].title()}: {part['text']}" if 'text' in part
72
  else f"{entry['role'].title()}: (Image)"
73
+ for entry in st.session_state['chat_history']
74
  for part in entry['parts']
75
  )
76
  return chat_history_str
77
 
78
+ # Send message function with TTS integration
79
  def send_message():
80
  user_input = st.session_state.user_input
81
+ uploaded_files = st.session_state.uploaded_files
82
  prompts = []
83
  prompt_parts = []
84
 
85
+ # Populate the prompts list with the existing chat history
86
+ for entry in st.session_state['chat_history']:
87
+ for part in entry['parts']:
88
+ if 'text' in part:
89
+ prompts.append(part['text'])
90
+ elif 'data' in part:
91
+ # Add the image in base64 format to prompt_parts for vision model
92
+ prompt_parts.append({"data": part['data'], "mime_type": "image/jpeg"})
93
+ prompts.append("[Image]")
94
+
95
+ # Append the user input to the prompts list
96
  if user_input:
97
  prompts.append(user_input)
98
+ st.session_state['chat_history'].append({"role": "user", "parts": [{"text": user_input}]})
99
+ # Also add the user text input to prompt_parts
100
  prompt_parts.append({"text": user_input})
101
 
102
+ # Handle uploaded files
103
  if uploaded_files:
104
  for uploaded_file in uploaded_files:
105
  base64_image = get_image_base64(Image.open(uploaded_file))
106
  prompts.append("[Image]")
107
  prompt_parts.append({"data": base64_image, "mime_type": "image/jpeg"})
108
+ st.session_state['chat_history'].append({
109
  "role": "user",
110
  "parts": [{"mime_type": uploaded_file.type, "data": base64_image}]
111
  })
112
 
113
+ # Determine if vision model should be used
114
  use_vision_model = any(part.get('mime_type') == 'image/jpeg' for part in prompt_parts)
115
+
116
+ # Set up the model and generate a response
117
  model_name = 'gemini-pro-vision' if use_vision_model else 'gemini-pro'
118
  model = genai.GenerativeModel(
119
  model_name=model_name,
 
122
  )
123
  chat_history_str = "\n".join(prompts)
124
  if use_vision_model:
125
+ # Include text and images for vision model
126
  generated_prompt = {"role": "user", "parts": prompt_parts}
127
  else:
128
+ # Include text only for standard model
129
  generated_prompt = {"role": "user", "parts": [{"text": chat_history_str}]}
130
 
131
  response = model.generate_content([generated_prompt])
132
  response_text = response.text if hasattr(response, "text") else "No response text found."
133
 
134
+ # After generating the response from the model, append it to the chat history
135
  if response_text:
136
+ st.session_state['chat_history'].append({"role": "model", "parts":[{"text": response_text}]})
137
+
138
+ # Convert the response text to speech
139
  tts = gTTS(text=response_text, lang='en')
140
  tts_file = BytesIO()
141
  tts.write_to_fp(tts_file)
142
  tts_file.seek(0)
143
  st.audio(tts_file, format='audio/mp3')
144
 
145
+ # Clear the input fields after sending the message
146
  st.session_state.user_input = ''
147
+ st.session_state.uploaded_files = []
148
+ st.session_state.file_uploader_key = str(uuid.uuid4())
149
+
150
+ # Display the updated chat history
151
  display_chat_history()
152
 
153
+ # User input text area
154
+ user_input = st.text_area(
155
+ "Enter your message here:",
156
+ value="",
157
+ key="user_input"
158
+ )
159
+
160
+ # File uploader for images
161
+ uploaded_files = st.file_uploader(
162
+ "Upload images:",
163
+ type=["png", "jpg", "jpeg"],
164
+ accept_multiple_files=True,
165
+ key=st.session_state.file_uploader_key
166
+ )
167
+
168
+ # Send message button
169
+ send_button = st.button(
170
+ "Send",
171
+ on_click=send_message
172
+ )
173
+
174
+ # Clear conversation button
175
  clear_button = st.button("Clear Conversation", on_click=clear_conversation)
176
 
177
+ # Function to download the chat history as a text file
178
  def download_chat_history():
179
+ chat_history_str = get_chat_history_str()
180
+ return chat_history_str
181
+
182
+ # Download button for the chat history
183
+ download_button = st.download_button(
184
+ label="Download Chat",
185
+ data=download_chat_history(),
186
+ file_name="chat_history.txt",
187
+ mime="text/plain"
188
+ )
189
 
190
+ # Ensure the file_uploader widget state is tied to the randomly generated key
191
+ st.session_state.uploaded_files = uploaded_files
192
 
193
+ # JavaScript to capture the Ctrl+Enter event and trigger a button click
194
  st.markdown(
195
  """
196
  <script>
 
206
  """,
207
  unsafe_allow_html=True
208
  )
209
+