Nbfit commited on
Commit
1cfd505
·
verified ·
1 Parent(s): 338f7c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -242
app.py CHANGED
@@ -1,282 +1,268 @@
 
 
 
 
 
 
 
 
 
1
 
2
  # ======================
3
- # Import Section
4
  # ======================
5
 
6
- # Core Libraries
7
- import io # Input/output operations for byte streams
8
-
9
- # AI/ML Frameworks
10
- from transformers import pipeline # Hugging Face transformers pipeline
11
- import torch # PyTorch tensor operations
12
-
13
- # Audio Processing
14
- import numpy as np
15
- from scipy.io.wavfile import write as write_wav # Audio file I/O operations
16
-
17
- # Image Processing
18
- from PIL import Image # Image manipulation library
19
 
 
 
 
 
20
 
21
- # Web Interface
22
- import streamlit as st # Web app framework
 
 
23
 
 
 
 
 
24
 
25
  # ======================
26
- # Model Loading Functions
27
  # ======================
28
 
29
  @st.cache_resource
30
- def load_caption_pipeline():
31
- """Initialize and cache the image captioning pipeline.
32
-
33
- Returns:
34
- Pipeline: BLIP model for image-to-text generation
35
- """
36
- return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base",use_fast=True)
37
-
38
- @st.cache_resource
39
- def load_story_pipeline():
40
- """Initialize and cache the story generation pipeline.
41
-
42
- Returns:
43
- Pipeline: Fine-tuned LLaMA model for children's story generation
44
- """
45
- return pipeline("text-generation", model="aspis/gpt2-genre-story-generation",use_fast=True)
46
-
47
- @st.cache_resource
48
- def load_tts_pipeline():
49
- """Initialize and cache the text-to-speech pipeline.
50
-
51
- Returns:
52
- Pipeline: Microsoft's SpeechT5 for high-quality speech synthesis
53
- """
54
- return pipeline("text-to-speech", model="facebook/mms-tts-eng",use_fast=True)
55
-
56
 
57
  # ======================
58
- # Core Processing Functions
59
  # ======================
60
 
61
  @st.cache_data(show_spinner=False, max_entries=3)
62
- def generate_image_caption(image: Image.Image) -> str:
63
- """Generate descriptive caption for uploaded image.
64
-
65
- Args:
66
- image (PIL.Image): RGB formatted input image
67
-
68
- Returns:
69
- str: Generated image caption
70
-
71
- Raises:
72
- StreamlitError: If caption generation fails
73
- """
74
  try:
75
- img2caption = load_caption_pipeline()
76
- # Generate caption
77
- caption = img2caption(image)[0]['generated_text']
78
- return caption
79
  except Exception as e:
80
- st.error(f"🔍 The caption fairy is confused about the picture! says: {str(e)}")
81
  st.stop()
82
 
83
  @st.cache_data(show_spinner=False, max_entries=3)
84
- def generate_story(caption: str) -> str:
85
- """Generate child-friendly story from image caption.
86
-
87
- Args:
88
- caption (str): Image description from previous step
89
-
90
- Returns:
91
- str: Generated story (60-80 words) with happy ending
92
-
93
- Raises:
94
- StreamlitError: If story generation fails
95
- """
96
  try:
97
- cap2story = load_story_pipeline()
98
- output = cap2story(f"<BOS> <adventure>{caption}",truncation=True,max_length=80,do_sample=True,repetition_penalty=1.5, temperature=1.2,top_p=0.95,top_k=50,no_repeat_ngram_size=2)
99
- story = output[0]['generated_text'].replace("<BOS> <adventure>","")
100
- return story
101
  except Exception as e:
102
- st.error(f"🧚 The writing fairy is sleeping! says: {str(e)}")
103
  st.stop()
104
 
105
  @st.cache_data(show_spinner=False, max_entries=3)
106
- def read_story(story):
107
- """Convert generated story to speech audio.
108
-
109
- Args:
110
- story (str): Generated story text
111
-
112
- Returns:
113
- io.BytesIO: Audio buffer in WAV format
114
-
115
- Raises:
116
- StreamlitError: If audio generation fails
117
- """
118
  try:
119
- text2speech = load_tts_pipeline()
120
- audio_data = text2speech(story)
121
- audio_buffer = io.BytesIO()
122
- audio_array = np.array(audio_data["audio"])
123
- if len(audio_array.shape) == 2:
124
- audio_array = audio_array[:, 0]
125
- if audio_array.dtype == np.float32:
126
- audio_array = (audio_array * 32767).astype(np.int16)
127
- elif audio_array.dtype != np.int16:
128
- audio_array = audio_array.astype(np.int16)
129
- channels = 1
130
- bit_depth = 16
131
- sample_rate = int(audio_data["sampling_rate"])
132
- block_align = channels * (bit_depth // 8)
133
- bytes_per_second = sample_rate * block_align
134
- write_wav(audio_buffer, sample_rate, audio_array)
135
- audio_buffer.seek(0)
136
- return audio_buffer
137
  except Exception as e:
138
- st.error(f"🔊 The reading fairy is sneezing! says: {str(e)}")
139
- st.stop()
140
 
 
 
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  # ======================
145
- # Main Application
146
  # ======================
147
 
148
  def main():
149
- """Main application flow and UI configuration."""
150
-
151
- # Configure page settings
152
- st.set_page_config(
153
- page_title="Magic Story Time",
154
- page_icon="🧚",
155
- layout="centered",
156
- initial_sidebar_state="expanded"
157
- )
158
-
159
- # Custom CSS styling
160
- st.markdown("""
161
- <style>
162
- .story-box {
163
- background: linear-gradient(145deg, #fff1eb 0%, #ace0f9 100%);
164
- border-radius: 15px;
165
- padding: 25px;
166
- font-size: 1.1em;
167
- line-height: 1.8;
168
- color: #2c3e50;
169
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
170
- margin: 20px 0;
171
- }
172
- .upload-section {
173
- border: 2px dashed #4CAF50;
174
- border-radius: 10px;
175
- padding: 20px;
176
- background: rgba(76, 175, 80, 0.05);
177
- }
178
- </style>
179
- """, unsafe_allow_html=True)
180
-
181
-
182
- # Sidebar - Image Upload
183
- with st.sidebar:
184
- st.header("🖼️ Upload Your Magic Drawing Paper")
185
- uploaded_image = st.file_uploader(
186
- label="Upload an image",
187
- type=["jpg", "jpeg", "png"],
188
- help="Format in JPEG/JPG/PNG, max 1MB",
189
- key="image_uploader",
190
- accept_multiple_files=False
191
- )
192
-
193
- if uploaded_image:
194
- st.success(f"🔍 The caption fairy received your image: {uploaded_image.name}")
195
-
196
- # Main Content Area
197
- # App title
198
- st.title("🧚 Magic Story Camp")
199
- st.markdown("---")
200
-
201
-
202
- # input validation
203
- if uploaded_image:
204
-
205
- # Validate file specifications
206
- if uploaded_image.size > 1024* 1024:
207
- st.error("🔍 The caption fairy says the image is too big! please give me image under 1MB")
208
- st.stop()
209
- if uploaded_image.type not in ["image/jpeg", "image/png"]:
210
- st.error("🔍 The caption fairy says only JPG/PNG allowed!")
211
- st.stop()
212
-
213
- # Processing pipeline
214
- with st.spinner("🧙 The fairies are casting magic spells, it may take some time⏳..."):
215
- try:
216
- # Convert to RGB format for model compatibility
217
- image = Image.open(uploaded_image).convert("RGB")
218
-
219
- # Display processing UI elements
220
- status_display = st.empty()
221
- progress_bar = st.progress(0)
222
-
223
- # Image preview expander
224
- with st.expander("view the image", expanded=True):
225
- st.image(image, use_container_width=True)
226
-
227
-
228
- # Processing stages
229
- # Stage 1: Image Captioning
230
- status_display.markdown("🔍 **The caption fairy is viewing the image...**")
231
- progress_bar.progress(25)
232
- caption = generate_image_caption(image)
233
-
234
-
235
- # Stage 2: Story Generation
236
- status_display.markdown("🧚 **The writing fairy is writing the story...**")
237
- progress_bar.progress(50)
238
- story = generate_story(caption)
239
-
240
-
241
- # Stage 3: Audio Synthesis
242
- status_display.markdown("🔊 **The reading fairy is preparing audio magic...**")
243
- progress_bar.progress(75)
244
- speech = read_story(story)
245
-
246
-
247
- # Finish
248
- progress_bar.progress(100)
249
- status_display.markdown("🧚 **The Story is ready!**")
250
-
251
- # Display formatted story
252
- st.markdown("### 📖 Your Magic story")
253
- st.markdown(f'<div class="story-box">{story}</div>', unsafe_allow_html=True)
254
-
255
- # Audio playback and download
256
- st.audio(speech, format="audio/wav")
257
- st.download_button(
258
- "🎵 Download Story",
259
- data=speech,
260
- file_name="magic_story.wav",
261
- mime="audio/wav",
262
- help="click to download your story"
263
- )
264
- except Exception as e:
265
- st.error(f"💥 The magic spell broke! Please try again. {str(e)}")
266
- st.stop()
267
- else:
268
- # Page instructions
269
- st.markdown("""
270
- <div class="upload-section">
271
- <h3 style="color:#4CAF50; text-align:center;">❓ guidance</h3>
272
- 1. 🖼️ Upload Your Picture in the sidebar<br>
273
- 2. Wait for the magic sparkles ✨)<br>
274
- 3. Read/listen to your story and download with 🎵 button!<br>
275
- <br>
276
- Note: First-time model loading may take longer.<br>
277
- Please have a glass of juice and be patient for a few moments<br>
278
- </div>
279
- """, unsafe_allow_html=True)
280
 
281
  if __name__ == "__main__":
282
  main()
 
1
+ import streamlit as st
2
+ import cv2
3
+ import time
4
+ from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
5
+ from PIL import Image
6
+ from transformers import pipeline
7
+ import os
8
+ from collections import Counter
9
+ import base64
10
 
11
  # ======================
12
+ # 模型加载函数(缓存)
13
  # ======================
14
 
15
+ @st.cache_resource
16
+ def load_smoke_pipeline():
17
+ """初始化并缓存吸烟图片分类 pipeline。"""
18
+ return pipeline("image-classification", model="ccclllwww/smoker_cls_base_V9", use_fast=True)
 
 
 
 
 
 
 
 
 
19
 
20
+ @st.cache_resource
21
+ def load_gender_pipeline():
22
+ """初始化并缓存性别图片分类 pipeline。"""
23
+ return pipeline("image-classification", model="rizvandwiki/gender-classification-2", use_fast=True)
24
 
25
+ @st.cache_resource
26
+ def load_age_pipeline():
27
+ """初始化并缓存年龄图片分类 pipeline。"""
28
+ return pipeline("image-classification", model="akashmaggon/vit-base-age-classification", use_fast=True)
29
 
30
+ # 预先加载所有模型
31
+ load_smoke_pipeline()
32
+ load_gender_pipeline()
33
+ load_age_pipeline()
34
 
35
  # ======================
36
+ # 音频加载函数(缓存)
37
  # ======================
38
 
39
  @st.cache_resource
40
+ def load_all_audios():
41
+ """加载 audio 目录中的所有 .wav 文件,并返回一个字典,
42
+ 键为文件名(不带扩展名),值为音频字节数据。"""
43
+ audio_dir = "audio"
44
+ audio_files = [f for f in os.listdir(audio_dir) if f.endswith(".wav")]
45
+ audio_dict = {}
46
+ for audio_file in audio_files:
47
+ file_path = os.path.join(audio_dir, audio_file)
48
+ with open(file_path, "rb") as af:
49
+ audio_bytes = af.read()
50
+ # 去掉扩展名作为键
51
+ key = os.path.splitext(audio_file)[0]
52
+ audio_dict[key] = audio_bytes
53
+ return audio_dict
54
+
55
+ # 应用启动时加载所有音频
56
+ audio_data = load_all_audios()
 
 
 
 
 
 
 
 
 
57
 
58
  # ======================
59
+ # 核心处理函数
60
  # ======================
61
 
62
  @st.cache_data(show_spinner=False, max_entries=3)
63
+ def smoking_classification(image: Image.Image) -> str:
64
+ """接受 PIL 图片并利用吸烟分类 pipeline 进行判定,返回标签(如 "smoking")。"""
 
 
 
 
 
 
 
 
 
 
65
  try:
66
+ smoke_pipeline = load_smoke_pipeline()
67
+ output = smoke_pipeline(image)
68
+ status = max(output, key=lambda x: x["score"])['label']
69
+ return status
70
  except Exception as e:
71
+ st.error(f"🔍 图像处理错误: {str(e)}")
72
  st.stop()
73
 
74
  @st.cache_data(show_spinner=False, max_entries=3)
75
+ def gender_classification(image: Image.Image) -> str:
76
+ """进行性别分类,返回模型输出的性别(依模型输出)。"""
 
 
 
 
 
 
 
 
 
 
77
  try:
78
+ gender_pipeline = load_gender_pipeline()
79
+ output = gender_pipeline(image)
80
+ status = max(output, key=lambda x: x["score"])['label']
81
+ return status
82
  except Exception as e:
83
+ st.error(f"🔍 图像处理错误: {str(e)}")
84
  st.stop()
85
 
86
  @st.cache_data(show_spinner=False, max_entries=3)
87
+ def age_classification(image: Image.Image) -> str:
88
+ """进行年龄分类,返回年龄范围,例如 "10-19" 等。"""
 
 
 
 
 
 
 
 
 
 
89
  try:
90
+ age_pipeline = load_age_pipeline()
91
+ output = age_pipeline(image)
92
+ age_range = max(output, key=lambda x: x["score"])['label']
93
+ return age_range
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  except Exception as e:
95
+ st.error(f"🔍 图像处理错误: {str(e)}")
96
+ st.stop()
97
 
98
+ # ======================
99
+ # 自定义JS播放音频函数
100
+ # ======================
101
 
102
+ @st.cache_resource
103
+ def play_audio_via_js(audio_bytes):
104
+ """
105
+ 利用自定义 HTML 和 JavaScript 播放音频。
106
+ 将二进制音频数据转换为 Base64 后嵌入 audio 标签,
107
+ 并用 JS 在页面加载后模拟点击进行播放。
108
+ """
109
+ audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
110
+ html_content = f"""
111
+ <audio id="audio_player" controls style="width: 100%;">
112
+ <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
113
+ Your browser does not support the audio element.
114
+ </audio>
115
+ <script type="text/javascript">
116
+ // 等待 DOMContentLoaded 事件,并在1秒后自动调用 play() 方法
117
+ window.addEventListener('DOMContentLoaded', function() {{
118
+ setTimeout(function() {{
119
+ var audioElement = document.getElementById("audio_player");
120
+ if (audioElement) {{
121
+ audioElement.play().catch(function(e) {{
122
+ console.log("播放被浏览器阻止:", e);
123
+ }});
124
+ }}
125
+ }}, 1000);
126
+ }});
127
+ </script>
128
+ """
129
+ st.components.v1.html(html_content, height=150)
130
+
131
+ # ======================
132
+ # VideoTransformer 定义:处理摄像头帧与快照捕获
133
+ # ======================
134
 
135
+ class VideoTransformer(VideoTransformerBase):
136
+ def __init__(self):
137
+ self.snapshots = [] # 存储捕获的快照
138
+ self.last_capture_time = time.time() # 上次捕获时间
139
+ self.capture_interval = 0.5 # 每0.5秒捕获一张快照
140
+
141
+ def transform(self, frame):
142
+ """从摄像头流捕获单帧图像,并转换为 PIL Image。"""
143
+ img = frame.to_ndarray(format="bgr24")
144
+ current_time = time.time()
145
+ # 每隔 capture_interval 秒捕获一张快照,直到捕获20张
146
+ if current_time - self.last_capture_time >= self.capture_interval and len(self.snapshots) < 20:
147
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
148
+ self.snapshots.append(Image.fromarray(img_rgb))
149
+ self.last_capture_time = current_time
150
+ st.write(f"已捕获快照 {len(self.snapshots)}/20")
151
+ return img # 返回原始帧以供前端显示
152
 
153
  # ======================
154
+ # 主函数:整合视频流、自动图片分类并展示结果
155
  # ======================
156
 
157
  def main():
158
+ st.title("Streamlit-WebRTC 自动图片分类示例")
159
+ st.write("程序在一分钟内捕获20张快照进行图片分类,首先判定是否吸烟。若检测到吸烟的快照超过2次,则展示年龄与性别分类结果。")
160
+
161
+ # 创建用于显示进度文字和进度条的占位容器
162
+ capture_text_placeholder = st.empty()
163
+ capture_progress_placeholder = st.empty()
164
+ classification_text_placeholder = st.empty()
165
+ classification_progress_placeholder = st.empty()
166
+ detection_info_placeholder = st.empty() # 用于显示“开始侦测”
167
+
168
+ # 启动实时视频流
169
+ ctx = webrtc_streamer(key="unique_example", video_transformer_factory=VideoTransformer)
170
+ image_placeholder = st.empty()
171
+ audio_placeholder = st.empty()
172
+
173
+ capture_target = 10 # 本轮捕获目标张数
174
+
175
+ if ctx.video_transformer is not None:
176
+ classification_result_placeholder = st.empty() # 用于显示分类结果
177
+ detection_info_placeholder.info("开始侦测")
178
+
179
+ while True:
180
+ snapshots = ctx.video_transformer.snapshots
181
+
182
+ # 更新捕获阶段进度:同时显示文字和进度条
183
+ if len(snapshots) < capture_target:
184
+ capture_text_placeholder.text(f"捕获进度: {len(snapshots)}/{capture_target} 张快照")
185
+ progress_value = int(len(snapshots) / capture_target * 100)
186
+ capture_progress_placeholder.progress(progress_value)
187
+ else:
188
+ # 捕获完成,清空捕获进度条,并显示完成提示
189
+ capture_text_placeholder.text("捕获进度: 捕获完成!")
190
+ capture_progress_placeholder.empty()
191
+ detection_info_placeholder.empty() # 清除“开始侦测”提示
192
+
193
+ # ---------- 分类阶段进度 ----------
194
+ total = len(snapshots)
195
+ classification_text_placeholder.text("分类进度: 正在分类...")
196
+ classification_progress = classification_progress_placeholder.progress(0)
197
+
198
+ # 1. 吸烟分类 (0 ~ 33%)
199
+ smoke_results = []
200
+ for idx, img in enumerate(snapshots):
201
+ smoke_results.append(smoking_classification(img))
202
+ smoking_count = sum(1 for result in smoke_results if result.lower() == "smoking")
203
+ classification_progress.progress(33)
204
+
205
+ # 2. 若吸烟次数超过2,再进行性别和年龄分类 (33% ~ 100%)
206
+ if smoking_count > 2:
207
+ gender_results = []
208
+ for idx, img in enumerate(snapshots):
209
+ gender_results.append(gender_classification(img))
210
+ classification_progress.progress(66)
211
+
212
+ age_results = []
213
+ for idx, img in enumerate(snapshots):
214
+ age_results.append(age_classification(img))
215
+ classification_progress.progress(100)
216
+ classification_text_placeholder.text("分类进度: 分类完成!")
217
+
218
+ most_common_gender = Counter(gender_results).most_common(1)[0][0]
219
+ most_common_age = Counter(age_results).most_common(1)[0][0]
220
+
221
+ result_text = (
222
+ f"**吸烟状态:** Smoking (检测到 {smoking_count} 次)\n\n"
223
+ f"**性别:** {most_common_gender}\n\n"
224
+ f"**年龄范围:** {most_common_age}"
225
+ )
226
+ classification_result_placeholder.markdown(result_text)
227
+
228
+ # 选择第一张分类结果为 "smoking" 的快照,如未检测到,则显示第一张
229
+ smoking_image = None
230
+ for idx, label in enumerate(smoke_results):
231
+ if label.lower() == "smoking":
232
+ smoking_image = snapshots[idx]
233
+ break
234
+ if smoking_image is None:
235
+ smoking_image = snapshots[0]
236
+ image_placeholder.image(smoking_image, caption="捕获的快照示例", use_container_width=True)
237
+
238
+ # 清空播放区域后再播放对应音频
239
+ audio_placeholder.empty()
240
+ audio_key = f"{most_common_age} {most_common_gender.lower()}"
241
+ if audio_key in audio_data:
242
+ audio_bytes = audio_data[audio_key]
243
+ play_audio_via_js(audio_bytes)
244
+ else:
245
+ st.error(f"音频文件不存在: {audio_key}.wav")
246
+ else:
247
+ result_text = "**吸烟状态:** Not Smoking"
248
+ classification_result_placeholder.markdown(result_text)
249
+ image_placeholder.empty()
250
+ audio_placeholder.empty()
251
+ classification_text_placeholder.text("分类进度: 分类完成!")
252
+ classification_progress.progress(100)
253
+
254
+ # 分类阶段结束后清空分类进度占位区
255
+ time.sleep(1)
256
+ classification_progress_placeholder.empty()
257
+ classification_text_placeholder.empty()
258
+ capture_text_placeholder.empty()
259
+
260
+
261
+ # 重置快照列表,准备下一轮捕获
262
+ detection_info_placeholder.info("开始侦测")
263
+ ctx.video_transformer.snapshots = []
264
+ ctx.video_transformer.last_capture_time = time.time()
265
+ time.sleep(0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
267
  if __name__ == "__main__":
268
  main()