CR7CAD committed on
Commit 1a8c2bf · verified · 1 parent: 4f45e40

Update app.py

Files changed (1)
  app.py: +46 -21
app.py CHANGED
@@ -5,7 +5,6 @@ from PIL import Image
 import os
 import tempfile
 from gtts import gTTS
-import io
 
 # Preload and cache all models at app startup
 @st.cache_resource
@@ -13,23 +12,41 @@ def load_models():
     """Load all models and cache them for faster execution"""
     models = {
         "image_captioner": pipeline("image-to-text", model="sooh-j/blip-image-captioning-base"),
-        "story_generator": pipeline("text-generation", model="Qwen/Qwen2.5-1.5B-Instruct")
+        "story_generator": pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
     }
     return models
 
+# Convert PIL Image to bytes for caching compatibility
+def get_image_bytes(pil_img):
+    """Convert PIL image to bytes for hashing"""
+    import io
+    buf = io.BytesIO()
+    pil_img.save(buf, format='JPEG')
+    return buf.getvalue()
+
 # Simple image-to-text function using cached model
 @st.cache_data
-def img2text(image, _models):
-    """Convert image to text with caching"""
-    result = _models["image_captioner"](image)
+def img2text(image_bytes, _models):
+    """Convert image to text with caching - using underscore for unhashable arg"""
+    import io
+    pil_img = Image.open(io.BytesIO(image_bytes))
+    result = _models["image_captioner"](pil_img)
     return result[0]["generated_text"]
 
 @st.cache_data
 def text2story(caption, _models):
-    """Generate a short story from image caption"""
+    """Generate a short story from image caption.
+
+    Args:
+        caption: Caption describing the image
+        _models: Dictionary containing loaded models
+
+    Returns:
+        A generated story that expands on the image caption
+    """
     story_generator = _models["story_generator"]
 
-    # Format prompt
+    # Format prompt to ensure the story expands on the image caption
     prompt = f"""<|system|>
 You are a creative short story writer. Write a brief, engaging story that expands on the given image caption.
 The story should be under 100 words and have a natural beginning, middle, and end.
@@ -38,19 +55,23 @@ Image caption: "{caption}"
 Create a short story that expands on this image caption and brings it to life.
 <|assistant|>"""
 
-    # Generate story
+    # Generate story with parameters tuned for brevity and coherence
     response = story_generator(
         prompt,
-        max_new_tokens=100,
+        max_new_tokens=100,  # Allow enough tokens for a complete story
         do_sample=True,
-        temperature=0.7,
-        top_p=0.9,
-        repetition_penalty=1.2,
+        temperature=0.7,  # Balanced creativity
+        top_p=0.9,  # Focus on more likely tokens
+        repetition_penalty=1.2,  # Avoid repetitive patterns
         eos_token_id=story_generator.tokenizer.eos_token_id
     )
 
-    # Extract just the assistant's response
-    story_text = response[0]['generated_text'].split("<|assistant|>")[-1].strip()
+    # Extract just the generated story text
+    raw_story = response[0]['generated_text']
+
+    # Parse out just the assistant's response from the conversation format
+    story_text = raw_story.split("<|assistant|>")[-1].strip()
+
     return story_text
 
 # Text-to-speech function
@@ -70,8 +91,9 @@ def text2audio(story_text):
     os.unlink(temp_filename)
     return audio_bytes
 
-# Load models at startup
+# Load models at startup - this happens before the app interface is displayed
 models = load_models()
+st.write("✅ Models loaded and cached!")
 
 # Streamlit app interface
 st.title("Image to Audio Story")
@@ -80,23 +102,26 @@ st.title("Image to Audio Story")
 uploaded_file = st.file_uploader("Upload an image")
 
 if uploaded_file is not None:
-    # Display image at a smaller size (200px width instead of 300px)
+    # Display image
     image = Image.open(uploaded_file)
-    st.image(image, caption="Uploaded Image", width=200)
+    st.image(image, caption="Uploaded Image", width=300)
 
     # Process image
     with st.spinner("Processing..."):
-        # Generate caption directly from the image (no need to convert to bytes)
-        caption = img2text(image, models)
+        # Convert to bytes for caching
+        image_bytes = get_image_bytes(image)
+
+        # Generate caption
+        caption = img2text(image_bytes, models)
         st.write(f"**Caption:** {caption}")
 
-        # Generate story
+        # Generate story that expands on the caption
         story = text2story(caption, models)
         word_count = len(story.split())
         st.write(f"**Story ({word_count} words):**")
         st.write(story)
 
-        # Generate audio
+        # Pre-generate audio
         if 'audio' not in st.session_state:
             st.session_state.audio = text2audio(story)
 
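
The central pattern in this change is how st.cache_data handles unhashable arguments: Streamlit builds the cache key by hashing the function's parameters, skips any parameter whose name starts with an underscore (hence _models), and errors on values it cannot hash, which is why the image is serialized to bytes first. A minimal standalone sketch of that pattern (the names describe_image and _resources are illustrative, not part of the commit):

import io

import streamlit as st
from PIL import Image

@st.cache_data
def describe_image(image_bytes, _resources):
    # image_bytes is hashable and forms the cache key; the leading
    # underscore tells st.cache_data not to hash _resources.
    img = Image.open(io.BytesIO(image_bytes))
    return f"mode={img.mode}, size={img.size}"  # stand-in for a model call

Since st.file_uploader returns an UploadedFile whose getvalue() already yields raw bytes, those bytes could be passed to the cached function directly; note that the JPEG re-encode in get_image_bytes can raise for RGBA uploads unless the image is converted to RGB first.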