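# (Apparently residue copied from the hosting page's file view: author, last commit message, commit id.)
# Nbfit's picture
# Update app.py
# f1b936e verified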
# ======================
# Import Section
# ======================
# Core Libraries
import io # Input/output operations for byte streams
# AI/ML Frameworks
from transformers import pipeline # Hugging Face transformers pipeline
import torch # PyTorch tensor operations
# Audio Processing
import soundfile as sf # Audio file I/O operations
# Image Processing
from PIL import Image # Image manipulation library
# Data Handling
from datasets import load_dataset # Hugging Face datasets loader
# Web Interface
import streamlit as st # Web app framework
# ======================
# Model Loading Functions
# ======================
@st.cache_resource
def load_caption_pipeline():
    """Build the image-captioning pipeline, cached by Streamlit.

    ``st.cache_resource`` keeps the returned object alive across reruns, so
    the model weights are only loaded on the first call.

    Returns:
        Pipeline: ``image-to-text`` pipeline backed by
        ``Salesforce/blip-image-captioning-large``.
    """
    caption_model_id = "Salesforce/blip-image-captioning-large"
    return pipeline(task="image-to-text", model=caption_model_id, use_fast=True)
@st.cache_resource
def load_story_pipeline():
    """Build the story-writing text-generation pipeline, cached by Streamlit.

    ``st.cache_resource`` keeps the returned object alive across reruns, so
    the model weights are only loaded on the first call.

    Returns:
        Pipeline: ``text-generation`` pipeline for the
        ``wy2001/storygenratorllama3.21b`` checkpoint (judging by its name, a
        Llama-based fine-tune for children's stories -- not verified here).
    """
    story_model_id = "wy2001/storygenratorllama3.21b"
    return pipeline(task="text-generation", model=story_model_id, use_fast=True)
@st.cache_resource
def load_tts_pipeline():
    """Build the text-to-speech pipeline, cached by Streamlit.

    ``st.cache_resource`` keeps the returned object alive across reruns, so
    the model weights are only loaded on the first call.

    Returns:
        Pipeline: ``text-to-speech`` pipeline backed by
        ``microsoft/speecht5_tts``. Callers must supply speaker embeddings
        via ``forward_params`` (see ``read_story``).
    """
    tts_model_id = "microsoft/speecht5_tts"
    return pipeline(task="text-to-speech", model=tts_model_id, use_fast=True)
# ======================
# Core Processing Functions
# ======================
@st.cache_data(show_spinner=False, max_entries=3)
def generate_image_caption(image: Image.Image) -> str:
    """Describe an uploaded picture with a short caption.

    Results are memoised by ``st.cache_data`` (at most 3 entries), so a rerun
    with the same image skips the model call.

    Args:
        image (PIL.Image.Image): Picture to caption; ``main`` converts uploads
            to RGB before calling this.

    Returns:
        str: ``generated_text`` of the first result the captioning pipeline
        returns.

    Note:
        Failures are not raised to the caller: an error message is shown in
        the app and the current script run is halted with ``st.stop()``.
    """
    try:
        captioner = load_caption_pipeline()
        results = captioner(image)
        return results[0]['generated_text']
    except Exception as err:
        st.error(f"πŸ” The caption fairy is confused about the picture! says: {str(err)}")
        st.stop()
@st.cache_data(show_spinner=False, max_entries=3)
def generate_story(caption: str) -> str:
    """Ask the story model for a short children's story about *caption*.

    Results are memoised by ``st.cache_data`` (at most 3 entries).

    Args:
        caption (str): Image description produced by ``generate_image_caption``.

    Returns:
        str: The generated reply with surrounding whitespace stripped. The
        prompt asks for 60-80 words with a happy ending, but the model is not
        guaranteed to comply; generation is capped at 200 new tokens.

    Note:
        Failures are not raised to the caller: an error message is shown in
        the app and the current script run is halted with ``st.stop()``.
    """
    prompt = (
        f"Creating a story for 3-10 years old kids about {caption} "
        "between 60 to 80 words with friendly words and happy ending. "
        "present the story itself only."
    )
    messages = [{"role": "user", "content": prompt}]
    try:
        cap2story = load_story_pipeline()
        output = cap2story(messages, max_new_tokens=200, num_return_sequences=1)
        # With chat-style input the pipeline returns the whole conversation
        # (prompt messages followed by the generated reply). Take the LAST
        # message instead of hard-coding index 1, which silently returns the
        # wrong message if the prompt ever gains e.g. a system message.
        story = output[0]['generated_text'][-1]['content']
        return story.strip()
    except Exception as e:
        st.error(f"🧚 The writing fairy is sleeping! says: {str(e)}")
        st.stop()
@st.cache_resource
def load_speaker_embeddings():
    """Load the fixed speaker x-vector passed to the SpeechT5 pipeline.

    Cached with ``st.cache_resource`` so the dataset is fetched only once.

    Returns:
        torch.Tensor: The ``xvector`` of row 7306 in the ``validation`` split
        of ``Matthijs/cmu-arctic-xvectors``, with a leading size-1 dimension
        added by ``unsqueeze(0)``.
    """
    xvector_split = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    # Row 7306 picks one particular recorded voice -- presumably chosen for
    # how it sounds; changing the index changes the narrator's voice.
    chosen_row = xvector_split[7306]
    return torch.tensor(chosen_row["xvector"]).unsqueeze(0)
@st.cache_data(show_spinner=False, max_entries=3)
def read_story(story):
    """Synthesise *story* as speech and return it as an in-memory WAV file.

    Results are memoised by ``st.cache_data`` (at most 3 entries).

    Args:
        story (str): Story text to be read aloud.

    Returns:
        io.BytesIO: WAV-encoded audio, rewound to position 0 so it can be
        handed directly to ``st.audio`` / ``st.download_button``.

    Note:
        Failures are not raised to the caller: an error message is shown in
        the app and the current script run is halted with ``st.stop()``.
    """
    try:
        tts = load_tts_pipeline()
        voice = load_speaker_embeddings()
        result = tts(story, forward_params={"speaker_embeddings": voice})
        wav_buffer = io.BytesIO()
        sf.write(
            wav_buffer,
            result["audio"],
            samplerate=result["sampling_rate"],
            format='WAV',
        )
        # Rewind so consumers read from the start of the encoded file.
        wav_buffer.seek(0)
        return wav_buffer
    except Exception as err:
        st.error(f"πŸ”Š The reading fairy is sneezing! says: {str(err)}")
        st.stop()
# ======================
# Main Application
# ======================
def _inject_custom_css():
    """Register the ``story-box`` and ``upload-section`` CSS classes used by the page."""
    st.markdown("""
<style>
.story-box {
background: linear-gradient(145deg, #fff1eb 0%, #ace0f9 100%);
border-radius: 15px;
padding: 25px;
font-size: 1.1em;
line-height: 1.8;
color: #2c3e50;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
margin: 20px 0;
}
.upload-section {
border: 2px dashed #4CAF50;
border-radius: 10px;
padding: 20px;
background: rgba(76, 175, 80, 0.05);
}
</style>
""", unsafe_allow_html=True)


def _render_upload_sidebar():
    """Draw the sidebar uploader and return the uploaded file, or ``None``."""
    with st.sidebar:
        st.header("πŸ–ΌοΈ Upload Your Magic Drawing Paper")
        uploaded_image = st.file_uploader(
            label="Upload an image",
            type=["jpg", "jpeg", "png"],
            help="Format in JPEG/JPG/PNG, max 1MB",
            key="image_uploader",
            accept_multiple_files=False,
        )
        if uploaded_image:
            st.success(f"πŸ” The caption fairy received your image: {uploaded_image.name}")
    return uploaded_image


def _validate_upload(uploaded_image):
    """Halt the run with an error unless the upload is <= 1 MiB and JPEG/PNG."""
    if uploaded_image.size > 1024 * 1024:
        st.error("πŸ” The caption fairy says the image is too big! please give me image under 1MB")
        st.stop()
    if uploaded_image.type not in ["image/jpeg", "image/png"]:
        st.error("πŸ” The caption fairy says only JPG/PNG allowed!")
        st.stop()


def _story_html(story):
    """Return *story* as a single-line, HTML-escaped ``story-box`` div."""
    import html

    # The story is model output, so escape it before it is injected with
    # unsafe_allow_html. Line breaks become <br> because, in Markdown, a blank
    # line terminates a raw-HTML block and would spill the rest of the story
    # outside the styled div.
    body = "<br>".join(html.escape(story).splitlines())
    return f'<div class="story-box">{body}</div>'


def _run_story_pipeline(uploaded_image):
    """Caption the image, write a story, voice it, and display the results."""
    with st.spinner("πŸ§™ The fairies are casting magic spells, it may take some time⏳..."):
        try:
            # Convert to RGB format for model compatibility
            image = Image.open(uploaded_image).convert("RGB")
            # Display processing UI elements
            status_display = st.empty()
            progress_bar = st.progress(0)
            # Image preview expander
            with st.expander("view the image", expanded=True):
                st.image(image, use_container_width=True)
            # Stage 1: Image Captioning
            status_display.markdown("πŸ” **The caption fairy is viewing the image...**")
            progress_bar.progress(25)
            caption = generate_image_caption(image)
            # Stage 2: Story Generation
            status_display.markdown("🧚 **The writing fairy is writing the story...**")
            progress_bar.progress(50)
            story = generate_story(caption)
            # Stage 3: Audio Synthesis
            status_display.markdown("πŸ”Š **The reading fairy is preparing audio magic...**")
            progress_bar.progress(75)
            speech = read_story(story)
            # Finish
            progress_bar.progress(100)
            status_display.markdown("🧚 **The Story is ready!**")
            # Display formatted story
            st.markdown("### πŸ“– Your Magic story")
            st.markdown(_story_html(story), unsafe_allow_html=True)
            # Audio playback and download
            st.audio(speech, format="audio/wav")
            st.download_button(
                "🎡 Download Story",
                data=speech,
                file_name="magic_story.wav",
                mime="audio/wav",
                help="click to download your story",
            )
        except Exception as e:
            st.error(f"πŸ’₯ The magic spell broke! Please try again. {str(e)}")
            st.stop()


def _render_instructions():
    """Show the how-to panel displayed while no image has been uploaded."""
    st.markdown("""
<div class="upload-section">
<h3 style="color:#4CAF50; text-align:center;">❓ guidance</h3>
1. πŸ–ΌοΈ Upload Your Picture in the sidebar<br>
2. Wait for the magic sparkles may take 10 min βœ¨οΌ‰<br>
3. Read/listen to your story and download with 🎡 button!<br>
<br>
Note: First-time model loading may take longer.<br>
Please have a glass of juice and be patient for a few moments<br>
</div>
""", unsafe_allow_html=True)


def main():
    """Main application flow and UI configuration.

    Configures the page and draws the sidebar uploader. Once an image is
    uploaded and passes validation, it runs caption -> story -> speech and
    shows the results; otherwise it shows usage instructions.
    """
    # Page-level settings, issued before any other element is drawn.
    st.set_page_config(
        page_title="Magic Story Time",
        page_icon="🧚",
        layout="centered",
        initial_sidebar_state="expanded",
    )
    _inject_custom_css()
    uploaded_image = _render_upload_sidebar()

    # Main content area
    st.title("🧚 Magic Story Camp")
    st.markdown("---")

    if uploaded_image:
        _validate_upload(uploaded_image)
        _run_story_pipeline(uploaded_image)
    else:
        _render_instructions()
if __name__ == "__main__":
    # Entry point; `streamlit run` executes this script as __main__.
    main()