Spaces:
Sleeping
Sleeping
| # Imports | |
| import streamlit as st | |
| from transformers import pipeline | |
| from PIL import Image | |
| import os | |
| import tempfile | |
| from gtts import gTTS | |
# Preload all models once and keep them cached across Streamlit reruns
@st.cache_resource
def load_models():
    """Build the captioning and story-generation pipelines.

    Streamlit re-executes the script on every widget interaction. Without a
    cache decorator both pipelines would be rebuilt on each rerun, so the
    result is wrapped in ``st.cache_resource`` and constructed only once.

    Returns:
        dict: ``"image_captioner"`` maps to the image-to-text pipeline and
        ``"story_generator"`` maps to the text-generation pipeline.
    """
    return {
        "image_captioner": pipeline(
            "image-to-text",
            model="sooh-j/blip-image-captioning-base",
        ),
        "story_generator": pipeline(
            "text-generation",
            model="trl-internal-testing/tiny-MistralForCausalLM-0.2",
        ),
    }
# Convert PIL Image to bytes for caching compatibility
def get_image_bytes(pil_img):
    """Serialize a PIL image to JPEG-encoded bytes.

    JPEG cannot store an alpha channel or a palette, so saving an image in a
    mode such as RGBA, LA or P (typical for PNG uploads) raises an error.
    Images in those modes are converted to RGB before encoding.

    Args:
        pil_img: Image object exposing ``mode``, ``convert`` and ``save``.

    Returns:
        bytes: The JPEG-encoded image data.
    """
    import io

    # Modes the Pillow JPEG writer is documented to accept as-is.
    jpeg_modes = ("L", "RGB", "CMYK")
    if pil_img.mode not in jpeg_modes:
        pil_img = pil_img.convert("RGB")

    buf = io.BytesIO()
    pil_img.save(buf, format='JPEG')
    return buf.getvalue()
# Simple image-to-text function using cached model
@st.cache_data
def img2text(image_bytes, _models):
    """Caption an image, caching the result per distinct image.

    The image is passed as bytes so Streamlit can hash it as the cache key.
    ``_models`` carries a leading underscore, which tells Streamlit's cache
    not to hash that argument.

    Args:
        image_bytes (bytes): Encoded image data readable by ``Image.open``.
        _models (dict): Loaded models; ``"image_captioner"`` is used here.

    Returns:
        str: The ``"generated_text"`` of the captioner's first result.
    """
    import io

    # The context manager releases the decoded image once captioning is done.
    with Image.open(io.BytesIO(image_bytes)) as pil_img:
        result = _models["image_captioner"](pil_img)
    return result[0]["generated_text"]
def text2story(caption, _models):
    """Turn an image caption into a short generated story.

    Args:
        caption (str): Caption describing the image.
        _models (dict): Loaded models; ``"story_generator"`` is used here.

    Returns:
        str: The text that follows the final ``<|assistant|>`` marker in the
        generator's output, with surrounding whitespace removed.
    """
    assistant_tag = "<|assistant|>"
    generator = _models["story_generator"]

    # Chat-style prompt: system role, user request quoting the caption,
    # then an open assistant turn for the model to complete.
    prompt_parts = [
        "<|system|>\n",
        "You are a creative story generating assistant.\n",
        "<|user|>\n",
        f"Please generate a short story within 100 words that clearly utilizes the following image caption and expands upon it with no introduction or closing remarks:\n\"{caption}\"\n",
        assistant_tag,
    ]
    prompt = "".join(prompt_parts)

    # Sampling settings: moderate temperature, nucleus sampling, and a
    # repetition penalty, with room for a complete short story.
    generation_kwargs = {
        "max_new_tokens": 120,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.2,
        "eos_token_id": generator.tokenizer.eos_token_id,
    }
    outputs = generator(prompt, **generation_kwargs)

    # Keep only what comes after the last assistant marker.
    full_text = outputs[0]['generated_text']
    _, _, reply = full_text.rpartition(assistant_tag)
    return reply.strip()
# Text-to-speech function
def text2audio(story_text):
    """Convert text to MP3 audio bytes with gTTS.

    The audio is written to an in-memory buffer rather than a temporary
    file, so nothing is left on disk if synthesis fails part-way.

    Args:
        story_text (str): The text to be spoken (English).

    Returns:
        bytes: The MP3 data produced by gTTS.

    Any exception raised by gTTS propagates to the caller.
    """
    import io

    mp3_buffer = io.BytesIO()
    tts = gTTS(text=story_text, lang='en')
    tts.write_to_fp(mp3_buffer)
    return mp3_buffer.getvalue()
# Load models at startup - this happens before the app interface is displayed
models = load_models()

# Streamlit app interface
st.title("Image to Audio Story for Kids")

# File uploader
uploaded_file = st.file_uploader("Upload an image")

if uploaded_file is not None:
    # Display image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Encoded bytes feed the captioner and also identify which upload the
    # stored results belong to.
    image_bytes = get_image_bytes(image)

    # Streamlit reruns this script on every interaction, including the
    # "Play the audio" click. The story is sampled, so regenerating it on
    # each rerun would leave the displayed story out of step with the stored
    # audio, and a new upload would keep playing the previous image's audio.
    # Results are therefore regenerated only when the image changes.
    if st.session_state.get("source_image") != image_bytes:
        with st.spinner("Processing..."):
            caption = img2text(image_bytes, models)
            story = text2story(caption, models)
            audio = text2audio(story)
        # Commit only after every step succeeded so a failure is retried
        # on the next rerun instead of leaving partial state behind.
        st.session_state.source_image = image_bytes
        st.session_state.caption = caption
        st.session_state.story = story
        st.session_state.audio = audio

    # Show the caption and the story that expands on it
    st.write(f"**Caption:** {st.session_state.caption}")
    story = st.session_state.story
    word_count = len(story.split())
    st.write(f"**Story ({word_count} words):**")
    st.write(story)

    # Play audio button
    if st.button("Play the audio"):
        st.audio(st.session_state.audio, format="audio/mp3")