File size: 11,338 Bytes
1cfd505
 
 
 
 
 
 
 
 
0e3c4ab
4f563d7
1cfd505
4f563d7
0e3c4ab
1cfd505
 
 
 
4f563d7
1cfd505
 
 
 
4f563d7
1cfd505
 
 
 
4f563d7
1cfd505
 
 
 
4f563d7
 
1cfd505
4f563d7
8140f15
768e131
1cfd505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
976f7e9
4f563d7
1cfd505
4f563d7
 
768e131
1cfd505
 
aa50810
1cfd505
 
 
 
aa50810
1cfd505
aa50810
0e3c4ab
768e131
1cfd505
 
aa50810
1cfd505
 
 
 
aa50810
1cfd505
aa50810
0e3c4ab
768e131
1cfd505
 
aa50810
1cfd505
 
 
 
aa50810
1cfd505
 
0e3c4ab
1cfd505
 
 
4f563d7
1cfd505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f563d7
1cfd505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
768e131
4f563d7
1cfd505
4f563d7
 
0e3c4ab
1cfd505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e3c4ab
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import streamlit as st
import cv2
import time
from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
from PIL import Image
from transformers import pipeline
import os
from collections import Counter
import base64

# ======================
# 模型加载函数(缓存)
# ======================

@st.cache_resource
def load_smoke_pipeline():
    """Build (once per session) and cache the smoking image-classification pipeline."""
    model_id = "ccclllwww/smoker_cls_base_V9"
    return pipeline(
        "image-classification",
        model=model_id,
        use_fast=True,
    )

@st.cache_resource
def load_gender_pipeline():
    """Build (once per session) and cache the gender image-classification pipeline."""
    model_id = "rizvandwiki/gender-classification-2"
    return pipeline(
        "image-classification",
        model=model_id,
        use_fast=True,
    )

@st.cache_resource
def load_age_pipeline():
    """Build (once per session) and cache the age image-classification pipeline."""
    model_id = "akashmaggon/vit-base-age-classification"
    return pipeline(
        "image-classification",
        model=model_id,
        use_fast=True,
    )

# Warm up all three models at startup so the first user interaction
# does not pay the download/initialization cost (results are cached
# by @st.cache_resource on each loader).
load_smoke_pipeline()
load_gender_pipeline()
load_age_pipeline()

# ======================
# 音频加载函数(缓存)
# ======================

@st.cache_resource
def load_all_audios(audio_dir: str = "audio") -> dict:
    """Load every ``.wav`` file in *audio_dir* into memory.

    Returns a dict mapping the file name without its extension to the raw
    audio bytes. Fix over the original: if the directory does not exist,
    ``os.listdir`` used to raise ``FileNotFoundError`` at import time and
    kill the whole app — now an empty dict is returned instead, and a
    missing clip is reported later when playback is requested.

    Args:
        audio_dir: Directory to scan for ``.wav`` files (default ``"audio"``,
            matching the original hard-coded path).
    """
    if not os.path.isdir(audio_dir):
        # Graceful degradation: the app can still run without audio clips.
        return {}
    audio_dict = {}
    for audio_file in os.listdir(audio_dir):
        if not audio_file.endswith(".wav"):
            continue
        file_path = os.path.join(audio_dir, audio_file)
        with open(file_path, "rb") as af:
            audio_bytes = af.read()
        # Key by the bare file name (no extension), e.g. "20-29 male".
        audio_dict[os.path.splitext(audio_file)[0]] = audio_bytes
    return audio_dict

# Load all alert audio clips once at app start; keys are file names without
# the .wav extension, looked up later in main() via f"{age} {gender}".
audio_data = load_all_audios()

# ======================
# 核心处理函数
# ======================

@st.cache_data(show_spinner=False, max_entries=3)
def smoking_classification(image: Image.Image) -> str:
    """Run the smoking classifier on *image* and return the top-scoring label
    (e.g. "smoking"). On any failure, surface the error and halt the script run.
    """
    try:
        predictions = load_smoke_pipeline()(image)
        best = max(predictions, key=lambda pred: pred["score"])
        return best["label"]
    except Exception as exc:
        st.error(f"🔍 图像处理错误: {str(exc)}")
        st.stop()

@st.cache_data(show_spinner=False, max_entries=3)
def gender_classification(image: Image.Image) -> str:
    """Run the gender classifier on *image* and return the top-scoring label.
    On any failure, surface the error and halt the script run.
    """
    try:
        predictions = load_gender_pipeline()(image)
        best = max(predictions, key=lambda pred: pred["score"])
        return best["label"]
    except Exception as exc:
        st.error(f"🔍 图像处理错误: {str(exc)}")
        st.stop()

@st.cache_data(show_spinner=False, max_entries=3)
def age_classification(image: Image.Image) -> str:
    """Run the age classifier on *image* and return the top-scoring age-range
    label (e.g. "10-19"). On any failure, surface the error and halt the run.
    """
    try:
        predictions = load_age_pipeline()(image)
        best = max(predictions, key=lambda pred: pred["score"])
        return best["label"]
    except Exception as exc:
        st.error(f"🔍 图像处理错误: {str(exc)}")
        st.stop()

# ======================
# 自定义JS播放音频函数
# ======================

def play_audio_via_js(audio_bytes):
    """Render an HTML5 audio player and auto-start playback via JavaScript.

    The raw bytes are base64-embedded into an ``<audio>`` tag; a small script
    calls ``play()`` one second after DOMContentLoaded (the catch handles
    browsers that block autoplay).

    Fix over the original: it was decorated with ``@st.cache_resource``,
    which caches the function *call* itself. Since this function's whole
    purpose is a render side effect (``st.components.v1.html``) and it
    returns nothing, the cached second call for the same clip rendered
    nothing and no audio played. Rendering is not a resource, so the
    decorator is removed — the component now renders on every call.

    Args:
        audio_bytes: Raw WAV file contents.
    """
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    html_content = f"""
    <audio id="audio_player" controls style="width: 100%;">
        <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
        Your browser does not support the audio element.
    </audio>
    <script type="text/javascript">
        // 等待 DOMContentLoaded 事件,并在1秒后自动调用 play() 方法
        window.addEventListener('DOMContentLoaded', function() {{
            setTimeout(function() {{
                var audioElement = document.getElementById("audio_player");
                if (audioElement) {{
                    audioElement.play().catch(function(e) {{
                        console.log("播放被浏览器阻止:", e);
                    }});
                }}
            }}, 1000);
        }});
    </script>
    """
    st.components.v1.html(html_content, height=150)

# ======================
# VideoTransformer 定义:处理摄像头帧与快照捕获
# ======================

class VideoTransformer(VideoTransformerBase):
    """Grabs periodic snapshots from the live camera stream.

    ``transform()`` runs on streamlit-webrtc's background worker thread,
    so it must not call Streamlit UI APIs; the main script polls
    ``self.snapshots`` and reports progress itself.
    """

    def __init__(self):
        self.snapshots = []                   # captured PIL images, read/reset by main()
        self.last_capture_time = time.time()  # timestamp of the previous capture
        self.capture_interval = 0.5           # seconds between captures
        self.max_snapshots = 20               # stop capturing after this many frames

    def transform(self, frame):
        """Return the incoming frame unchanged, snapshotting it at most every
        ``capture_interval`` seconds until ``max_snapshots`` are collected.
        """
        img = frame.to_ndarray(format="bgr24")
        now = time.time()
        if (now - self.last_capture_time >= self.capture_interval
                and len(self.snapshots) < self.max_snapshots):
            # OpenCV delivers BGR; convert so PIL stores correct colors.
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.snapshots.append(Image.fromarray(img_rgb))
            self.last_capture_time = now
            # Fix over the original: it called st.write() here, but Streamlit
            # commands issued from this worker thread have no ScriptRunContext
            # and do not reliably render — progress is shown by main() instead.
        return img  # unchanged frame goes back to the browser for display

# ======================
# 主函数:整合视频流、自动图片分类并展示结果
# ======================

def main():
    """Streamlit entry point: stream the camera, periodically snapshot it,
    classify the snapshots (smoking first, then gender/age on a positive),
    and display/play the results, looping forever in detection rounds.
    """
    st.title("Streamlit-WebRTC 自动图片分类示例")
    st.write("程序在一分钟内捕获20张快照进行图片分类,首先判定是否吸烟。若检测到吸烟的快照超过2次,则展示年龄与性别分类结果。")

    # Placeholders for progress text/bars so they can be updated in place.
    capture_text_placeholder = st.empty()
    capture_progress_placeholder = st.empty()
    classification_text_placeholder = st.empty()
    classification_progress_placeholder = st.empty()
    detection_info_placeholder = st.empty()  # shows the "detecting" banner

    # Start the live video stream; VideoTransformer collects snapshots.
    ctx = webrtc_streamer(key="unique_example", video_transformer_factory=VideoTransformer)
    image_placeholder = st.empty()
    audio_placeholder = st.empty()

    capture_target = 10  # snapshots per detection round
    # NOTE(review): the transformer caps captures at 20 and the intro text
    # says 20, but this loop triggers classification after 10 — confirm
    # which count is intended.

    if ctx.video_transformer is not None:
        classification_result_placeholder = st.empty()  # classification results area
        detection_info_placeholder.info("开始侦测")

        # NOTE(review): this infinite loop keeps the script run alive forever,
        # polling the transformer's snapshot buffer; it blocks normal Streamlit
        # reruns by design — confirm this is the intended interaction model.
        while True:
            snapshots = ctx.video_transformer.snapshots

            # Capture phase: update both the progress text and the bar.
            if len(snapshots) < capture_target:
                capture_text_placeholder.text(f"捕获进度: {len(snapshots)}/{capture_target} 张快照")
                progress_value = int(len(snapshots) / capture_target * 100)
                capture_progress_placeholder.progress(progress_value)
            else:
                # Capture finished: clear the capture bar, show a done message.
                capture_text_placeholder.text("捕获进度: 捕获完成!")
                capture_progress_placeholder.empty()
                detection_info_placeholder.empty()  # remove the "detecting" banner

                # ---------- classification phase ----------
                total = len(snapshots)  # NOTE(review): unused — candidate for removal
                classification_text_placeholder.text("分类进度: 正在分类...")
                classification_progress = classification_progress_placeholder.progress(0)

                # 1. Smoking classification (fills 0 ~ 33% of the bar).
                smoke_results = []
                for idx, img in enumerate(snapshots):
                    smoke_results.append(smoking_classification(img))
                smoking_count = sum(1 for result in smoke_results if result.lower() == "smoking")
                classification_progress.progress(33)

                # 2. Only if smoking was detected more than twice, also run
                #    gender and age classification (33% ~ 100%).
                if smoking_count > 2:
                    gender_results = []
                    for idx, img in enumerate(snapshots):
                        gender_results.append(gender_classification(img))
                    classification_progress.progress(66)

                    age_results = []
                    for idx, img in enumerate(snapshots):
                        age_results.append(age_classification(img))
                    classification_progress.progress(100)
                    classification_text_placeholder.text("分类进度: 分类完成!")

                    # Majority vote across the snapshots for the final labels.
                    most_common_gender = Counter(gender_results).most_common(1)[0][0]
                    most_common_age = Counter(age_results).most_common(1)[0][0]

                    result_text = (
                        f"**吸烟状态:** Smoking (检测到 {smoking_count} 次)\n\n"
                        f"**性别:** {most_common_gender}\n\n"
                        f"**年龄范围:** {most_common_age}"
                    )
                    classification_result_placeholder.markdown(result_text)

                    # Show the first snapshot labelled "smoking"; fall back to
                    # the first snapshot if none matched.
                    smoking_image = None
                    for idx, label in enumerate(smoke_results):
                        if label.lower() == "smoking":
                            smoking_image = snapshots[idx]
                            break
                    if smoking_image is None:
                        smoking_image = snapshots[0]
                    image_placeholder.image(smoking_image, caption="捕获的快照示例", use_container_width=True)

                    # Clear the audio area, then play the clip keyed by
                    # "<age-range> <gender>" if one was loaded at startup.
                    audio_placeholder.empty()
                    audio_key = f"{most_common_age} {most_common_gender.lower()}"
                    if audio_key in audio_data:
                        audio_bytes = audio_data[audio_key]
                        play_audio_via_js(audio_bytes)
                    else:
                        st.error(f"音频文件不存在: {audio_key}.wav")
                else:
                    # Not enough smoking detections: report and clear visuals.
                    result_text = "**吸烟状态:** Not Smoking"
                    classification_result_placeholder.markdown(result_text)
                    image_placeholder.empty()
                    audio_placeholder.empty()
                    classification_text_placeholder.text("分类进度: 分类完成!")
                    classification_progress.progress(100)

                # Round finished: let results linger briefly, then clear the
                # progress placeholders for the next round.
                time.sleep(1)
                classification_progress_placeholder.empty()
                classification_text_placeholder.empty()
                capture_text_placeholder.empty()


                # Reset the snapshot buffer and timer for the next round.
                detection_info_placeholder.info("开始侦测")
                ctx.video_transformer.snapshots = []
                ctx.video_transformer.last_capture_time = time.time()
            time.sleep(0.1)  # poll at ~10 Hz to avoid busy-waiting

# Script entry point.
if __name__ == "__main__":
    main()