CR7CAD committed
Commit 4e987e0 · verified · 1 Parent(s): 9287e9e

Update app.py

Files changed (1):
  1. app.py +61 -165

app.py CHANGED
@@ -1,86 +1,22 @@
-# Imports
+# Imports - just the essentials
 import streamlit as st
 from transformers import pipeline
 from PIL import Image
-import torch
 import os
 import tempfile
-import time
-import numpy as np
+from gtts import gTTS
 
-# Use Streamlit's caching mechanisms to optimize model loading
+# Preload and cache all models at app startup
 @st.cache_resource
-def load_image_to_text_pipeline():
-    """Load and cache the image-to-text model"""
-    return pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")
-
-@st.cache_resource
-def load_text_generation_pipeline():
-    """Load and cache the text generation model"""
-    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
-
-@st.cache_resource
-def load_tts_pipeline():
-    """Load and cache the text-to-speech pipeline as fallback"""
-    try:
-        return pipeline("text-to-speech", model="facebook/mms-tts-eng")
-    except:
-        # Return None if loading fails
-        return None
-
-# Initialize all models at app startup
-with st.spinner("Loading models (this may take a moment the first time)..."):
-    # Load all models at startup and cache them
-    img2text_model = load_image_to_text_pipeline()
-    story_generator_model = load_text_generation_pipeline()
-    tts_fallback_model = load_tts_pipeline()
-
-# For TTS, try multiple options in order of preference
-try:
-    # Try importing gTTS
-    from gtts import gTTS
-    has_gtts = True
-except ImportError:
-    has_gtts = False
-    if tts_fallback_model is None:
-        st.warning("No text-to-speech capability available. Audio generation will be disabled.")
-
-# Cache the text-to-audio conversion
-@st.cache_data
-def text2audio(story_text):
-    """Convert text to audio with caching to avoid regenerating the same audio"""
-    if has_gtts:
-        # Use gTTS
-        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
-        temp_filename = temp_file.name
-        temp_file.close()
-
-        # Use gTTS to convert text to speech
-        tts = gTTS(text=story_text, lang='en', slow=False)
-        tts.save(temp_filename)
-
-        # Read the audio file
-        with open(temp_filename, 'rb') as audio_file:
-            audio_bytes = audio_file.read()
-
-        # Clean up the temporary file
-        os.unlink(temp_filename)
-
-        return audio_bytes, 'audio/mp3'
-    elif tts_fallback_model is not None:
-        # Use transformers TTS
-        speech = tts_fallback_model(story_text)
-
-        # Return the audio data
-        if 'audio' in speech:
-            return speech['audio'], speech.get('sampling_rate', 16000)
-        elif 'audio_array' in speech:
-            return speech['audio_array'], speech.get('sampling_rate', 16000)
-
-    # If we got here, no TTS method worked
-    raise Exception("No text-to-speech capability available")
+def load_models():
+    """Load all models and cache them for faster execution"""
+    models = {
+        "image_captioner": pipeline("image-to-text", model="sooh-j/blip-image-captioning-base"),
+        "story_generator": pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    }
+    return models
 
-# Convert PIL Image to bytes for hashing in cache
+# Convert PIL Image to bytes for caching compatibility
 def get_image_bytes(pil_img):
     """Convert PIL image to bytes for hashing"""
     import io
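Why @st.cache_resource fits the new load_models: Streamlit re-executes the whole script on every user interaction, so an undecorated pipeline(...) call would reload the model on each rerun. The decorator runs the function body once per server process and hands every later rerun (and session) the same cached object. A minimal sketch of that behavior, reusing the captioning checkpoint above; load_captioner is illustrative, not part of the commit:

    import streamlit as st
    from transformers import pipeline

    @st.cache_resource
    def load_captioner():
        # Body runs only on the first call; later reruns reuse the cached pipeline.
        return pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")

    captioner = load_captioner()  # effectively free after the first rerun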
@@ -90,29 +26,21 @@ def get_image_bytes(pil_img):
 
 # Simple image-to-text function using cached model
 @st.cache_data
-def img2text(image_bytes):
-    """Convert image to text with caching - using bytes for caching compatibility"""
-    # Convert bytes back to PIL image for processing
+def img2text(image_bytes, models):
+    """Convert image to text with caching"""
     import io
-    from PIL import Image
     pil_img = Image.open(io.BytesIO(image_bytes))
-
-    # Process with the model
-    result = img2text_model(pil_img)
+    result = models["image_captioner"](pil_img)
     return result[0]["generated_text"]
 
-# Helper function to count words
-def count_words(text):
-    return len(text.split())
-
-# Improved text-to-story function without "Once upon a time" constraint
+# Generate story from text - using your approach with caching
 @st.cache_data
-def text2story(text):
-    generator = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+def text2story(text, models):
+    """Generate a story from text with sensible endings"""
     prompt = f"Write a short children's story based on this: {text}. The story should have a clear beginning, middle, and end. Keep it under 150 words. "
 
     # Generate a longer text to ensure we get a complete story
-    story_result = generator(
+    story_result = models["story_generator"](
         prompt,
         max_length=300,
         num_return_sequences=1,
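One caveat on the hunk above: st.cache_data builds its cache key by hashing the function's arguments, and a dict of transformers pipelines is not something Streamlit can hash, so calls like img2text(image_bytes, models) can raise an UnhashableParamError at runtime. Streamlit's documented convention is a leading underscore, which excludes that parameter from the cache key. A hedged sketch of img2text with that fix (same logic, only the parameter is renamed):

    import io
    import streamlit as st
    from PIL import Image

    @st.cache_data
    def img2text(image_bytes, _models):
        # The leading underscore tells Streamlit not to hash _models;
        # only image_bytes determines the cache key.
        pil_img = Image.open(io.BytesIO(image_bytes))
        result = _models["image_captioner"](pil_img)
        return result[0]["generated_text"]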
@@ -148,89 +76,57 @@ def text2story(text):
     # If no good ending is found, return as is
     return story_text
 
-# Function to reset progress when a new file is uploaded
-def reset_progress():
-    st.session_state.progress = {
-        'caption_generated': False,
-        'story_generated': False,
-        'audio_generated': False,
-        'caption': '',
-        'story': '',
-        'audio_data': None,
-        'audio_format': None
-    }
-
-# Basic Streamlit interface
-st.title("Image to Audio Story")
+# Text-to-speech function
+@st.cache_data
+def text2audio(story_text):
+    """Convert text to audio"""
+    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
+    temp_filename = temp_file.name
+    temp_file.close()
+
+    tts = gTTS(text=story_text, lang='en')
+    tts.save(temp_filename)
+
+    with open(temp_filename, 'rb') as audio_file:
+        audio_bytes = audio_file.read()
+
+    os.unlink(temp_filename)
+    return audio_bytes
 
-# Add processing status indicator
-status_container = st.empty()
+# Load models at startup - this happens before the app interface is displayed
+models = load_models()
+st.write("✅ Models loaded and cached!")
 
-# Initialize session state for tracking progress
-if 'progress' not in st.session_state:
-    st.session_state.progress = {
-        'caption_generated': False,
-        'story_generated': False,
-        'audio_generated': False,
-        'caption': '',
-        'story': '',
-        'audio_data': None,
-        'audio_format': None
-    }
+# Streamlit app interface
+st.title("Image to Audio Story")
 
 # File uploader
-uploaded_file = st.file_uploader("Upload an image", on_change=reset_progress)
+uploaded_file = st.file_uploader("Upload an image")
 
-# Process the image if uploaded
 if uploaded_file is not None:
     # Display image
-    st.image(uploaded_file, caption="Uploaded Image")
-
-    # Convert to PIL Image
     image = Image.open(uploaded_file)
+    st.image(image, caption="Uploaded Image", width=300)
 
-    # Convert image to bytes for caching compatibility
-    image_bytes = get_image_bytes(image)
-
-    # Image to Text (if not already done)
-    if not st.session_state.progress['caption_generated']:
-        status_container.info("Generating caption...")
-        st.session_state.progress['caption'] = img2text(image_bytes)
-        st.session_state.progress['caption_generated'] = True
-
-    st.write(f"Caption: {st.session_state.progress['caption']}")
-
-    # Text to Story (if not already done)
-    if not st.session_state.progress['story_generated']:
-        status_container.info("Creating story...")
-        st.session_state.progress['story'] = text2story(st.session_state.progress['caption'])
-        st.session_state.progress['story_generated'] = True
-
-    # Display word count for transparency
-    word_count = count_words(st.session_state.progress['story'])
-    st.write(f"Story ({word_count} words):")
-    st.write(st.session_state.progress['story'])
-
-    # Pre-generate audio in background (if not already done)
-    if not st.session_state.progress['audio_generated'] and (has_gtts or tts_fallback_model is not None):
-        status_container.info("Pre-generating audio in background...")
-        try:
-            st.session_state.progress['audio_data'], st.session_state.progress['audio_format'] = text2audio(st.session_state.progress['story'])
-            st.session_state.progress['audio_generated'] = True
-            status_container.success("Ready to play audio!")
-        except Exception as e:
-            status_container.error(f"Error pre-generating audio: {e}")
+    # Process image
+    with st.spinner("Processing..."):
+        # Convert to bytes for caching
+        image_bytes = get_image_bytes(image)
+
+        # Generate caption
+        caption = img2text(image_bytes, models)
+        st.write(f"**Caption:** {caption}")
+
+        # Generate story
+        story = text2story(caption, models)
+        word_count = len(story.split())
+        st.write(f"**Story ({word_count} words):**")
+        st.write(story)
+
+        # Pre-generate audio
+        if 'audio' not in st.session_state:
+            st.session_state.audio = text2audio(story)
 
-    # Button to play audio
+    # Play audio button
     if st.button("Play the audio"):
-        if st.session_state.progress['audio_generated']:
-            # Display the audio player
-            if isinstance(st.session_state.progress['audio_format'], str) and st.session_state.progress['audio_format'].startswith('audio/'):
-                st.audio(st.session_state.progress['audio_data'], format=st.session_state.progress['audio_format'])
-            else:
-                st.audio(st.session_state.progress['audio_data'], sample_rate=st.session_state.progress['audio_format'])
-        else:
-            # Handle case where audio generation failed or is not available
-            st.error("Unable to play audio. Audio generation was not successful.")
-else:
-    status_container.info("Upload an image to begin")
+        st.audio(st.session_state.audio, format="audio/mp3")
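A behavioral note on the upload flow above: st.session_state.audio is only assigned when the key is absent, and session state survives across reruns, so uploading a second image in the same session would replay the first story's audio. One way to keep the audio in sync is to re-key it on the story text; a sketch, where audio_story is a hypothetical session-state key the commit does not define:

    # Regenerate audio only when the story actually changes.
    # 'audio_story' is an illustrative key, not one used by the commit.
    if st.session_state.get('audio_story') != story:
        st.session_state.audio = text2audio(story)
        st.session_state.audio_story = story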
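Finally, a possible simplification of the new text2audio (an assumption on my part, not something this commit does): gTTS can stream MP3 data straight into an in-memory buffer via write_to_fp, which drops the temp-file create/read/unlink sequence while still returning bytes that st.audio accepts:

    import io
    from gtts import gTTS

    def text2audio(story_text):
        """Return MP3 bytes without touching the filesystem."""
        buffer = io.BytesIO()
        gTTS(text=story_text, lang='en').write_to_fp(buffer)  # streams MP3 into the buffer
        return buffer.getvalue()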