# Imports
import io
import os
import tempfile

import streamlit as st
from gtts import gTTS
from PIL import Image
from transformers import pipeline

# Preload and cache all models at app startup
@st.cache_resource
def load_models():
    """Load all models and cache them for faster execution"""
    models = {
        "image_captioner": pipeline("image-to-text", model="sooh-j/blip-image-captioning-base"),
        "story_generator": pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    }
    return models
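
# Note: on the first run, pipeline() downloads both checkpoints from the
# Hugging Face Hub, so startup can take a while; @st.cache_resource then keeps
# the loaded models in memory across Streamlit reruns.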

# Convert PIL Image to bytes for caching compatibility
def get_image_bytes(pil_img):
    """Convert PIL image to bytes for hashing"""
    buf = io.BytesIO()
    # JPEG has no alpha channel, so normalize to RGB first; PNG uploads are
    # often RGBA and would make save(..., format='JPEG') raise otherwise
    pil_img.convert('RGB').save(buf, format='JPEG')
    return buf.getvalue()

# Simple image-to-text function using cached model
@st.cache_data
def img2text(image_bytes, _models):
    """Convert image to text with caching - using underscore for unhashable arg"""
    pil_img = Image.open(io.BytesIO(image_bytes))
    result = _models["image_captioner"](pil_img)
    return result[0]["generated_text"]

@st.cache_data
def text2story(caption, _models):
    """Generate a short story from an image caption.
    
    Args:
        caption (str): Caption describing the image.
        _models (dict): Dictionary containing loaded models.
        
    Returns:
        str: A generated story that expands on the image caption.
    """
    story_generator = _models["story_generator"]

    # The prompt explicitly instructs the model to build on details from the caption
    prompt = (
        "<|system|>\n"
        "You are a creative story-generating assistant.\n"
        "<|user|>\n"
        f"Please write a brief story of no more than 100 words that builds on the following image caption, without adding any introductory phrases, explanations, or separators:\n\"{caption}\"\n"
        "<|assistant|>"
    )
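
    # Note: transformers also provides tokenizer.apply_chat_template() to build
    # chat prompts from role/content messages; the manual template above matches
    # TinyLlama-Chat's <|system|>/<|user|>/<|assistant|> format.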

    # Generate story with parameters tuned for creativity and coherence
    response = story_generator(
        prompt,
        max_new_tokens=120,  # Enough tokens for a complete story
        do_sample=True,
        temperature=0.7,     # Balanced creativity
        top_p=0.9,           # Focus on more likely tokens for coherence
        repetition_penalty=1.2,  # Reduce repetitive patterns
        eos_token_id=story_generator.tokenizer.eos_token_id
    )
    
    # Extract the generated story text by separating the assistant's reply
    raw_story = response[0]['generated_text']
    story_text = raw_story.split("<|assistant|>")[-1].strip()
    
    return story_text
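
# Optional word-count guard, a sketch not wired into the app: the prompt only
# asks for "no more than 100 words" and max_new_tokens=120 does not enforce
# it, so a hard cap like this could be applied to text2story's result.
def cap_words(text, limit=100):
    """Truncate text to at most `limit` words."""
    words = text.split()
    return text if len(words) <= limit else " ".join(words[:limit]) + "..."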

# Text-to-speech function
@st.cache_data
def text2audio(story_text):
    """Convert text to audio"""
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
    temp_filename = temp_file.name
    temp_file.close()
    
    tts = gTTS(text=story_text, lang='en')
    tts.save(temp_filename)
    
    with open(temp_filename, 'rb') as audio_file:
        audio_bytes = audio_file.read()
    
    os.unlink(temp_filename)
    return audio_bytes
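
# Alternative sketch using gTTS's documented write_to_fp(): it streams the MP3
# into an in-memory buffer, avoiding the temp-file round trip above. Could be
# swapped in as a drop-in replacement for text2audio().
@st.cache_data
def text2audio_in_memory(story_text):
    """Convert text to audio entirely in memory"""
    buf = io.BytesIO()
    gTTS(text=story_text, lang='en').write_to_fp(buf)
    return buf.getvalue()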

# Load models at startup - this happens before the app interface is displayed
models = load_models()

# Streamlit app interface
st.title("Image to Audio Story for Kids")

# File uploader
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Display image
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_container_width=True)
    
    # Process image
    with st.spinner("Processing..."):
        # Convert to bytes for caching
        image_bytes = get_image_bytes(image)
        
        # Generate caption
        caption = img2text(image_bytes, models)
        st.write(f"**Caption:** {caption}")
        
        # Generate story that expands on the caption
        story = text2story(caption, models)
        word_count = len(story.split())
        st.write(f"**Story ({word_count} words):**")
        st.write(story)
        
        # Pre-generate audio, regenerating whenever the story changes;
        # checking only for the key's presence would keep replaying stale
        # audio after a new image is uploaded
        if st.session_state.get('story') != story:
            st.session_state.story = story
            st.session_state.audio = text2audio(story)
    
    # Play audio button
    if st.button("Play the audio"):
        st.audio(st.session_state.audio, format="audio/mp3")
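
# To run locally (filename is an assumption; use whatever this script is saved as):
#   streamlit run app.py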