Spaces:
Sleeping
Sleeping
File size: 4,192 Bytes
cf468d0 90bef38 8d5fabf ab8ead3 fbf01fc 8fe6281 4e987e0 fc13d66 fbf01fc fc13d66 4e987e0 fbf01fc 4e987e0 b62ae95 4e987e0 fbf01fc 1a8c2bf fbf01fc fc13d66 fbf01fc 4f45e40 8d5fabf fc13d66 4f45e40 090d705 fbf01fc 090d705 fbf01fc 090d705 fbf01fc 090d705 61b9f37 090d705 3ab269d 090d705 fbf01fc 4f45e40 45f5b70 4f45e40 fbf01fc 090d705 fbf01fc 5f21a2d 4f45e40 090d705 fbf01fc 5f21a2d fbf01fc 4e987e0 4f45e40 72c189d 4f45e40 fbf01fc 4f45e40 72c189d 4f45e40 72c189d fc13d66 fbf01fc 4e987e0 15c1038 fbf01fc c93e946 fc13d66 fbf01fc 4f45e40 15c1038 4f45e40 fbf01fc ab8ead3 fbf01fc 4f45e40 fbf01fc 1a8c2bf fbf01fc 4f45e40 fbf01fc 4e987e0 fbf01fc 4f45e40 fbf01fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
# Imports
import io
import os
import tempfile

import streamlit as st
from gtts import gTTS
from PIL import Image
from transformers import pipeline
# Preload and cache all models at app startup
@st.cache_resource
def load_models():
    """Build the pipeline registry once; Streamlit caches the result.

    Returns:
        dict: "image_captioner" -> BLIP image-to-text pipeline,
              "story_generator" -> TinyLlama text-generation pipeline.
    """
    captioner = pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")
    generator = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    return {"image_captioner": captioner, "story_generator": generator}
# Convert PIL Image to bytes for caching compatibility
def get_image_bytes(pil_img):
    """Serialize a PIL image to JPEG bytes so Streamlit can hash it for caching.

    Args:
        pil_img: A PIL.Image.Image (or compatible object exposing
            ``mode``, ``convert`` and ``save``).

    Returns:
        bytes: JPEG-encoded image data.
    """
    # JPEG supports neither alpha channels nor palettes; without this
    # conversion, uploads such as transparent PNGs make ``save`` raise
    # OSError ("cannot write mode RGBA as JPEG").
    if getattr(pil_img, "mode", "") in ("RGBA", "LA", "P"):
        pil_img = pil_img.convert("RGB")
    buf = io.BytesIO()
    pil_img.save(buf, format='JPEG')
    return buf.getvalue()
# Simple image-to-text function using cached model
@st.cache_data
def img2text(image_bytes, _models):
    """Caption an image with the cached BLIP pipeline.

    Args:
        image_bytes (bytes): JPEG-encoded image; passed as bytes (not a
            PIL object) so st.cache_data can hash the argument.
        _models (dict): Loaded pipelines; the leading underscore tells
            st.cache_data to skip hashing this unhashable argument.

    Returns:
        str: The generated caption text.
    """
    pil_img = Image.open(io.BytesIO(image_bytes))
    result = _models["image_captioner"](pil_img)
    return result[0]["generated_text"]
@st.cache_data
def text2story(caption, _models):
    """Generate a short story from an image caption.

    Args:
        caption (str): Caption describing the image.
        _models (dict): Dictionary containing loaded models (underscore
            prefix excludes it from Streamlit's cache hashing).

    Returns:
        str: A generated story that expands on the image caption.
    """
    generator = _models["story_generator"]

    # TinyLlama chat format: system turn, user turn, then the assistant
    # marker the model continues from. The instruction explicitly asks the
    # model to expand on the caption without preamble or separators.
    system_turn = "<|system|>\nYou are a creative story generating assistant.\n"
    user_turn = (
        "<|user|>\n"
        f"Please generate a brief story within 100 words based on image caption and expands it without adding any introductory phrases, explanations, or separators:\n\"{caption}\"\n"
    )
    prompt = system_turn + user_turn + "<|assistant|>"

    # Sampling parameters tuned for creativity while staying coherent.
    outputs = generator(
        prompt,
        max_new_tokens=120,           # enough tokens for a complete story
        do_sample=True,
        temperature=0.7,              # balanced creativity
        top_p=0.9,                    # focus on likelier tokens for coherence
        repetition_penalty=1.2,       # reduce repetitive patterns
        eos_token_id=generator.tokenizer.eos_token_id,
    )

    # The pipeline echoes the prompt; everything after the last assistant
    # marker is the model's actual reply.
    full_text = outputs[0]['generated_text']
    return full_text.split("<|assistant|>")[-1].strip()
# Text-to-speech function
@st.cache_data
def text2audio(story_text):
    """Convert story text to MP3 audio bytes via Google TTS.

    Args:
        story_text (str): The text to narrate.

    Returns:
        bytes: MP3-encoded audio data.
    """
    # Stream the MP3 straight into an in-memory buffer via gTTS's
    # write_to_fp. The previous temp-file round-trip leaked the file on
    # disk whenever tts.save() or the read raised (no try/finally).
    buf = io.BytesIO()
    gTTS(text=story_text, lang='en').write_to_fp(buf)
    return buf.getvalue()
# Load models at startup - this happens before the app interface is displayed
models = load_models()

# Streamlit app interface
st.title("Image to Audio Story for Kids")

# File uploader
uploaded_file = st.file_uploader("Upload an image")

if uploaded_file is not None:
    # Display image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Process image
    with st.spinner("Processing..."):
        # Convert to bytes so the caption call is cacheable
        image_bytes = get_image_bytes(image)

        # Generate caption
        caption = img2text(image_bytes, models)
        st.write(f"**Caption:** {caption}")

        # Generate story that expands on the caption
        story = text2story(caption, models)
        word_count = len(story.split())
        st.write(f"**Story ({word_count} words):**")
        st.write(story)

        # Regenerate audio whenever the story changes. The original
        # `'audio' not in st.session_state` guard cached the FIRST
        # story's audio forever, so uploading a new image played stale
        # narration on every rerun.
        if st.session_state.get("audio_story") != story:
            st.session_state.audio = text2audio(story)
            st.session_state.audio_story = story

    # Play audio button
    if st.button("Play the audio"):
        st.audio(st.session_state.audio, format="audio/mp3")