# Imports
import io
import os
import tempfile

import streamlit as st
from transformers import pipeline
from PIL import Image
from gtts import gTTS


# Preload and cache all models at app startup
@st.cache_resource
def load_models():
    """Load all models once and cache them for faster execution."""
    models = {
        "image_captioner": pipeline("image-to-text", model="sooh-j/blip-image-captioning-base"),
        "story_generator": pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
    }
    return models


# Convert PIL Image to bytes for caching compatibility
def get_image_bytes(pil_img):
    """Convert a PIL image to JPEG bytes so Streamlit can hash it for caching."""
    buf = io.BytesIO()
    # Convert to RGB first: JPEG cannot encode images with an alpha channel (e.g. RGBA PNGs)
    pil_img.convert("RGB").save(buf, format="JPEG")
    return buf.getvalue()


# Simple image-to-text function using the cached model
@st.cache_data
def img2text(image_bytes, _models):
    """Convert an image to a caption, with caching.

    The leading underscore on `_models` tells Streamlit not to hash the
    (unhashable) pipeline objects.
    """
    pil_img = Image.open(io.BytesIO(image_bytes))
    result = _models["image_captioner"](pil_img)
    return result[0]["generated_text"]


@st.cache_data
def text2story(caption, _models):
    """Generate a short story from an image caption.

    Args:
        caption (str): Caption describing the image.
        _models (dict): Dictionary containing the loaded models.

    Returns:
        str: A generated story that expands on the image caption.
    """
    story_generator = _models["story_generator"]

    # The prompt explicitly instructs the model to expand on the caption
    # without introductory phrases, explanations, or separators
    prompt = (
        "<|system|>\n"
        "You are a creative story-generating assistant.\n"
        "<|user|>\n"
        "Please write a brief story of under 100 words that expands on the "
        "following image caption, without adding any introductory phrases, "
        f"explanations, or separators:\n\"{caption}\"\n"
        "<|assistant|>"
    )

    # Generate the story with parameters tuned for creativity and coherence
    response = story_generator(
        prompt,
        max_new_tokens=120,       # Enough tokens for a complete short story
        do_sample=True,
        temperature=0.7,          # Balanced creativity
        top_p=0.9,                # Focus on more likely tokens for coherence
        repetition_penalty=1.2,   # Reduce repetitive patterns
        eos_token_id=story_generator.tokenizer.eos_token_id,
    )

    # The pipeline returns the prompt plus the completion; keep only the
    # assistant's reply
    raw_story = response[0]["generated_text"]
    story_text = raw_story.split("<|assistant|>")[-1].strip()
    return story_text


# Text-to-speech function
@st.cache_data
def text2audio(story_text):
    """Convert the story text to MP3 audio bytes via gTTS."""
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
    temp_filename = temp_file.name
    temp_file.close()

    tts = gTTS(text=story_text, lang="en")
    tts.save(temp_filename)

    with open(temp_filename, "rb") as audio_file:
        audio_bytes = audio_file.read()
    os.unlink(temp_filename)  # Clean up the temporary file

    return audio_bytes


# Load models at startup - this happens before the app interface is displayed
models = load_models()

# Streamlit app interface
st.title("Image to Audio Story for Kids")

# File uploader
uploaded_file = st.file_uploader("Upload an image")

if uploaded_file is not None:
    # Display the uploaded image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Process the image
    with st.spinner("Processing..."):
        # Convert to bytes so the cached functions can hash the input
        image_bytes = get_image_bytes(image)

        # Generate a caption
        caption = img2text(image_bytes, models)
        st.write(f"**Caption:** {caption}")

        # Generate a story that expands on the caption
        story = text2story(caption, models)
        word_count = len(story.split())
        st.write(f"**Story ({word_count} words):**")
        st.write(story)

        # Pre-generate the audio, refreshing it whenever the story changes
        # (guarding only on 'audio' being absent would replay stale audio
        # after a new image is uploaded)
        if st.session_state.get("story") != story:
            st.session_state.story = story
            st.session_state.audio = text2audio(story)

    # Play audio button
    if st.button("Play the audio"):
        st.audio(st.session_state.audio, format="audio/mp3")