# Imports
import io
import os
import tempfile

import streamlit as st
from PIL import Image
from gtts import gTTS
from transformers import pipeline
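# Pipeline overview: BLIP captions the uploaded image, TinyLlama expands the caption
# into a short children's story, and gTTS reads the story aloud.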
# Preload and cache all models at app startup
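# Note: the first run downloads both checkpoints from the Hugging Face Hub, which can
# take a while; @st.cache_resource then keeps the loaded pipelines in memory across reruns.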
@st.cache_resource
def load_models():
"""Load all models and cache them for faster execution"""
models = {
"image_captioner": pipeline("image-to-text", model="sooh-j/blip-image-captioning-base"),
"story_generator": pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
}
return models
# Convert PIL Image to bytes for caching compatibility
def get_image_bytes(pil_img):
"""Convert PIL image to bytes for hashing"""
import io
buf = io.BytesIO()
pil_img.save(buf, format='JPEG')
return buf.getvalue()
# Simple image-to-text function using cached model
@st.cache_data
def img2text(image_bytes, _models):
"""Convert image to text with caching - using underscore for unhashable arg"""
import io
pil_img = Image.open(io.BytesIO(image_bytes))
result = _models["image_captioner"](pil_img)
return result[0]["generated_text"]
@st.cache_data
def text2story(caption, _models):
"""Generate a short story from an image caption.
Args:
caption (str): Caption describing the image.
_models (dict): Dictionary containing loaded models.
Returns:
str: A generated story that expands on the image caption.
"""
story_generator = _models["story_generator"]
# Updated prompt includes explicit instructions to incorporate details from the caption
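    # Note: the <|system|>/<|user|>/<|assistant|> markers below follow the Zephyr-style
    # chat template that TinyLlama-1.1B-Chat-v1.0 was fine-tuned with.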
    prompt = (
        "<|system|>\n"
        "You are a creative story-generating assistant.\n"
        "<|user|>\n"
        f"Please generate a brief story of no more than 100 words that is based on the image caption below and expands on it, without adding any introductory phrases, explanations, or separators:\n\"{caption}\"\n"
        "<|assistant|>"
    )
    # Generate the story with parameters tuned for creativity and coherence
    response = story_generator(
        prompt,
        max_new_tokens=120,        # Enough tokens for a complete short story
        do_sample=True,
        temperature=0.7,           # Balanced creativity
        top_p=0.9,                 # Focus on more likely tokens for coherence
        repetition_penalty=1.2,    # Reduce repetitive patterns
        eos_token_id=story_generator.tokenizer.eos_token_id,
    )
    # The pipeline returns the prompt plus the completion, so keep only the assistant's reply
    raw_story = response[0]["generated_text"]
    story_text = raw_story.split("<|assistant|>")[-1].strip()
    return story_text
# Text-to-speech function
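# st.cache_data keys the result on the story text, so each unique story is only
# synthesized once; note that gTTS sends the text to Google's TTS service over the network.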
@st.cache_data
def text2audio(story_text):
"""Convert text to audio"""
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
temp_filename = temp_file.name
temp_file.close()
tts = gTTS(text=story_text, lang='en')
tts.save(temp_filename)
with open(temp_filename, 'rb') as audio_file:
audio_bytes = audio_file.read()
os.unlink(temp_filename)
return audio_bytes
# Load models at startup - this happens before the app interface is displayed
models = load_models()
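# Streamlit reruns this script from top to bottom on every interaction; the cached
# resources above keep the models from being reloaded on each rerun.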
# Streamlit app interface
st.title("Image to Audio Story for Kids")
# File uploader
uploaded_file = st.file_uploader("Upload an image")
if uploaded_file is not None:
    # Display the uploaded image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)
    # Process the image
    with st.spinner("Processing..."):
        # Convert to bytes so the caption can be cached per image
        image_bytes = get_image_bytes(image)
        # Generate a caption
        caption = img2text(image_bytes, models)
        st.write(f"**Caption:** {caption}")
        # Generate a story that expands on the caption
        story = text2story(caption, models)
        word_count = len(story.split())
        st.write(f"**Story ({word_count} words):**")
        st.write(story)
        # Pre-generate the audio, regenerating it whenever the story changes
        if st.session_state.get("story") != story:
            st.session_state.story = story
            st.session_state.audio = text2audio(story)
    # Play audio button
    if st.button("Play the audio"):
        st.audio(st.session_state.audio, format="audio/mp3")