# Imports
import io
import os
import tempfile

import streamlit as st
from transformers import pipeline
from PIL import Image
from gtts import gTTS
# Preload and cache all models at app startup
@st.cache_resource
def load_models():
    """Load all models and cache them for faster execution."""
    models = {
        "image_captioner": pipeline("image-to-text", model="sooh-j/blip-image-captioning-base"),
        "story_generator": pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
    }
    return models
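
# A quick note on the caching above: st.cache_resource memoizes the returned
# dict across Streamlit reruns, so the Hugging Face pipelines are built once
# per server process rather than on every interaction. A minimal (hypothetical)
# sanity check:
#
#   m1 = load_models()
#   m2 = load_models()
#   assert m1 is m2  # cache_resource returns the same object, no reload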
# Convert PIL Image to bytes for caching compatibility
def get_image_bytes(pil_img):
    """Convert a PIL image to bytes so it can be hashed for caching."""
    buf = io.BytesIO()
    # JPEG has no alpha channel, so convert e.g. RGBA PNG uploads to RGB first
    pil_img.convert("RGB").save(buf, format="JPEG")
    return buf.getvalue()
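
# Why bytes? st.cache_data builds its cache key by hashing the function's
# arguments, and raw bytes are hashable while PIL.Image objects are not.
# Sketch of the intended flow (hypothetical REPL usage):
#
#   img = Image.open("photo.png")
#   caption = img2text(get_image_bytes(img), models)  # repeat calls hit the cache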
# Simple image-to-text function using the cached model
@st.cache_data
def img2text(image_bytes, _models):
    """Convert an image to a caption, with caching.

    The leading underscore on `_models` tells Streamlit not to hash that
    argument (pipelines are unhashable); only `image_bytes` forms the cache key.
    """
    pil_img = Image.open(io.BytesIO(image_bytes))
    result = _models["image_captioner"](pil_img)
    return result[0]["generated_text"]
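
# For reference, the image-to-text pipeline returns a list of dicts, e.g.
# [{"generated_text": "a dog running on the beach"}] (illustrative output),
# which is why the caption is read from result[0]["generated_text"] above.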
def text2story(caption, _models):
    """Generate a short story from an image caption.

    Args:
        caption (str): Caption describing the image.
        _models (dict): Dictionary containing the loaded models.

    Returns:
        str: A generated story that expands on the image caption.
    """
    story_generator = _models["story_generator"]
    # The prompt explicitly instructs the model to incorporate details from the caption
    prompt = (
        "<|system|>\n"
        "You are a creative story-generating assistant.\n"
        "<|user|>\n"
        "Please write a brief story (within 100 words) based on the following image caption, "
        "expanding on it without adding any introductory phrases, explanations, or separators:\n"
        f'"{caption}"\n'
        "<|assistant|>"
    )
    # Generate the story with parameters tuned for creativity and coherence
    response = story_generator(
        prompt,
        max_new_tokens=120,       # enough tokens for a complete short story
        do_sample=True,
        temperature=0.7,          # balanced creativity
        top_p=0.9,                # focus on likelier tokens for coherence
        repetition_penalty=1.2,   # reduce repetitive patterns
        eos_token_id=story_generator.tokenizer.eos_token_id,
    )
    # The pipeline returns the prompt plus the completion; keep only the
    # assistant's reply after the final <|assistant|> tag
    raw_story = response[0]["generated_text"]
    story_text = raw_story.split("<|assistant|>")[-1].strip()
    return story_text
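
# Illustrative usage (hypothetical caption and output):
#
#   story = text2story("a dog running on the beach", models)
#   # -> roughly 100 words about the dog's day at the shore; exact length
#   #    varies with sampling, so the app reports the real word count below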
# Text-to-speech function
def text2audio(story_text):
    """Convert text to MP3 audio bytes using gTTS (requires internet access)."""
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
    temp_filename = temp_file.name
    temp_file.close()
    tts = gTTS(text=story_text, lang="en")
    tts.save(temp_filename)
    with open(temp_filename, "rb") as audio_file:
        audio_bytes = audio_file.read()
    os.unlink(temp_filename)
    return audio_bytes
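
# gTTS only writes to a file path or file object, so the helper above
# round-trips through a temporary file. An equivalent in-memory sketch using
# gTTS's write_to_fp (assuming the same story_text input):
#
#   buf = io.BytesIO()
#   gTTS(text=story_text, lang="en").write_to_fp(buf)
#   audio_bytes = buf.getvalue()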
# Load models at startup - this happens before the app interface is displayed
models = load_models()
# Streamlit app interface
st.title("Image to Audio Story for Kids")

# File uploader
uploaded_file = st.file_uploader("Upload an image")

if uploaded_file is not None:
    # Display the image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Process the image
    with st.spinner("Processing..."):
        # Convert to bytes for caching
        image_bytes = get_image_bytes(image)

        # Generate a caption (cached per unique image)
        caption = img2text(image_bytes, models)
        st.write(f"**Caption:** {caption}")
        # Generate the story and pre-generate its audio once per caption.
        # Keying on the caption in session_state prevents a button click
        # (which reruns the whole script) from resampling a new story, and
        # keeps the audio in sync with the current upload instead of
        # replaying a stale clip from a previous image.
        if st.session_state.get("caption") != caption:
            st.session_state.caption = caption
            st.session_state.story = text2story(caption, models)
            st.session_state.audio = text2audio(st.session_state.story)

        story = st.session_state.story
        word_count = len(story.split())
        st.write(f"**Story ({word_count} words):**")
        st.write(story)

    # Play audio button
    if st.button("Play the audio"):
        st.audio(st.session_state.audio, format="audio/mp3")