Spaces:
Sleeping
Sleeping
File size: 11,338 Bytes
1cfd505 0e3c4ab 4f563d7 1cfd505 4f563d7 0e3c4ab 1cfd505 4f563d7 1cfd505 4f563d7 1cfd505 4f563d7 1cfd505 4f563d7 1cfd505 4f563d7 8140f15 768e131 1cfd505 976f7e9 4f563d7 1cfd505 4f563d7 768e131 1cfd505 aa50810 1cfd505 aa50810 1cfd505 aa50810 0e3c4ab 768e131 1cfd505 aa50810 1cfd505 aa50810 1cfd505 aa50810 0e3c4ab 768e131 1cfd505 aa50810 1cfd505 aa50810 1cfd505 0e3c4ab 1cfd505 4f563d7 1cfd505 4f563d7 1cfd505 768e131 4f563d7 1cfd505 4f563d7 0e3c4ab 1cfd505 0e3c4ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 |
import streamlit as st
import cv2
import time
from streamlit_webrtc import VideoTransformerBase, webrtc_streamer
from PIL import Image
from transformers import pipeline
import os
from collections import Counter
import base64
# ======================
# 模型加载函数(缓存)
# ======================
@st.cache_resource
def load_smoke_pipeline():
    """Build the smoking image classifier once and reuse it across reruns."""
    model_id = "ccclllwww/smoker_cls_base_V9"
    return pipeline(task="image-classification", model=model_id, use_fast=True)
@st.cache_resource
def load_gender_pipeline():
    """Build the gender image classifier once and reuse it across reruns."""
    model_id = "rizvandwiki/gender-classification-2"
    return pipeline(task="image-classification", model=model_id, use_fast=True)
@st.cache_resource
def load_age_pipeline():
    """Build the age-range image classifier once and reuse it across reruns."""
    model_id = "akashmaggon/vit-base-age-classification"
    return pipeline(task="image-classification", model=model_id, use_fast=True)
# Warm all three model caches at startup so the first classification round
# does not pay the model-loading cost.
for _warmup_loader in (load_smoke_pipeline, load_gender_pipeline, load_age_pipeline):
    _warmup_loader()
del _warmup_loader
# ======================
# 音频加载函数(缓存)
# ======================
@st.cache_resource
def load_all_audios(audio_dir="audio"):
    """Load every ``.wav`` clip in *audio_dir* into memory.

    Args:
        audio_dir: directory to scan (non-recursive). A relative path resolves
            against the process working directory.

    Returns:
        dict mapping each file name without its extension (``"x.wav"`` ->
        ``"x"``) to the file's raw bytes. The extension match is
        case-insensitive. If *audio_dir* does not exist the dict is empty, so
        a missing clip is reported where it is looked up instead of the whole
        app failing at import time.
    """
    audio_dict = {}
    if not os.path.isdir(audio_dir):
        return audio_dict
    # Sorted so the result is deterministic if two names collide after the
    # extension is stripped (e.g. "a.wav" and "a.WAV").
    for file_name in sorted(os.listdir(audio_dir)):
        file_path = os.path.join(audio_dir, file_name)
        # Skip non-WAV entries and directories whose name merely ends in ".wav".
        if not file_name.lower().endswith(".wav") or not os.path.isfile(file_path):
            continue
        with open(file_path, "rb") as af:
            audio_dict[os.path.splitext(file_name)[0]] = af.read()
    return audio_dict


# Load all clips once at startup; main() looks them up by "<age> <gender>" keys.
audio_data = load_all_audios()
# ======================
# 核心处理函数
# ======================
@st.cache_data(show_spinner=False, max_entries=3)
def smoking_classification(image: Image.Image) -> str:
    """Return the highest-scoring label of the smoking classifier for *image*.

    On any failure the error is shown on the page and the current script run
    is halted with ``st.stop()``.
    """
    try:
        predictions = load_smoke_pipeline()(image)
        best = max(predictions, key=lambda pred: pred["score"])
        return best["label"]
    except Exception as err:
        st.error(f"🔍 图像处理错误: {str(err)}")
        st.stop()
@st.cache_data(show_spinner=False, max_entries=3)
def gender_classification(image: Image.Image) -> str:
    """Return the highest-scoring label of the gender classifier for *image*.

    The label text is whatever the model emits. On any failure the error is
    shown on the page and the current script run is halted with ``st.stop()``.
    """
    try:
        candidates = load_gender_pipeline()(image)
        winner = max(candidates, key=lambda cand: cand["score"])
        return winner["label"]
    except Exception as err:
        st.error(f"🔍 图像处理错误: {str(err)}")
        st.stop()
@st.cache_data(show_spinner=False, max_entries=3)
def age_classification(image: Image.Image) -> str:
    """Return the highest-scoring age-range label (e.g. "10-19") for *image*.

    On any failure the error is shown on the page and the current script run
    is halted with ``st.stop()``.
    """
    try:
        scored = load_age_pipeline()(image)
        top = max(scored, key=lambda entry: entry["score"])
        return top["label"]
    except Exception as err:
        st.error(f"🔍 图像处理错误: {str(err)}")
        st.stop()
# ======================
# 自定义JS播放音频函数
# ======================
def play_audio_via_js(audio_bytes, height=150):
    """Render an HTML5 audio player for *audio_bytes* and try to start playback.

    The clip is embedded as a base64 ``data:audio/wav`` URI in an
    ``st.components.v1.html`` component. A small script calls ``play()`` one
    second after ``DOMContentLoaded``. If the browser rejects that call
    (e.g. an autoplay policy), the rejection is only logged to the browser
    console, and the visible controls still allow manual playback.

    Deliberately not cached: the function exists only for its rendering side
    effect, and caching it hashed the entire clip on every call just to store
    a ``None`` result.

    Args:
        audio_bytes: raw WAV file contents.
        height: height of the component iframe in pixels.
    """
    audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
    html_content = f"""
    <audio id="audio_player" controls style="width: 100%;">
        <source src="data:audio/wav;base64,{audio_base64}" type="audio/wav">
        Your browser does not support the audio element.
    </audio>
    <script type="text/javascript">
        // 等待 DOMContentLoaded 事件,并在1秒后自动调用 play() 方法
        window.addEventListener('DOMContentLoaded', function() {{
            setTimeout(function() {{
                var audioElement = document.getElementById("audio_player");
                if (audioElement) {{
                    audioElement.play().catch(function(e) {{
                        console.log("播放被浏览器阻止:", e);
                    }});
                }}
            }}, 1000);
        }});
    </script>
    """
    st.components.v1.html(html_content, height=height)
# ======================
# VideoTransformer 定义:处理摄像头帧与快照捕获
# ======================
class VideoTransformer(VideoTransformerBase):
    """Pass webcam frames through unchanged while sampling periodic snapshots.

    At most once every ``capture_interval`` seconds, a frame is converted to
    an RGB PIL image and appended to ``snapshots``, until ``max_snapshots``
    are buffered. ``main`` reads ``snapshots`` and resets it (together with
    ``last_capture_time``) between classification rounds.
    """

    def __init__(self, capture_interval=0.5, max_snapshots=20):
        """Create an empty snapshot buffer.

        Args:
            capture_interval: minimum seconds between two captured snapshots.
            max_snapshots: stop capturing once this many snapshots are buffered.
        """
        self.snapshots = []  # captured PIL images, oldest first
        self.last_capture_time = time.time()  # time of the most recent capture
        self.capture_interval = capture_interval
        self.max_snapshots = max_snapshots

    def transform(self, frame):
        """Possibly record *frame* as a snapshot; return it as a BGR ndarray for display.

        No ``st.*`` calls here. streamlit-webrtc invokes this callback outside
        the Streamlit script thread, where Streamlit commands lack a script-run
        context and never reach the page. Capture progress is shown by ``main``.
        """
        img = frame.to_ndarray(format="bgr24")
        now = time.time()
        due = now - self.last_capture_time >= self.capture_interval
        if due and len(self.snapshots) < self.max_snapshots:
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # PIL expects RGB channel order
            self.snapshots.append(Image.fromarray(rgb))
            self.last_capture_time = now
        return img  # original frame goes back to the browser unchanged
# ======================
# 主函数:整合视频流、自动图片分类并展示结果
# ======================
def main():
    """Stream the webcam, classify captured snapshots and show the results.

    Each round:
      1. wait until the video transformer has buffered ``capture_target``
         snapshots, showing capture progress meanwhile;
      2. run the smoking classifier on exactly those snapshots;
      3. if more than two are labelled "smoking", also run the gender and age
         classifiers, show the majority labels, display the first "smoking"
         snapshot and play the audio clip keyed "<age> <gender>";
      4. clear the progress widgets and reset the transformer's buffer.

    The loop keeps running for as long as ``ctx.video_transformer`` exists.
    """
    capture_target = 10  # snapshots classified per round
    st.title("Streamlit-WebRTC 自动图片分类示例")
    # The snapshot count is taken from capture_target so the text stays accurate.
    st.write(
        f"程序每轮捕获 {capture_target} 张快照进行图片分类,首先判定是否吸烟。"
        "若检测到吸烟的快照超过2次,则展示年龄与性别分类结果。"
    )

    # Single-slot placeholders, updated in place on every loop iteration.
    capture_text_placeholder = st.empty()
    capture_progress_placeholder = st.empty()
    classification_text_placeholder = st.empty()
    classification_progress_placeholder = st.empty()
    detection_info_placeholder = st.empty()  # shows the "开始侦测" notice

    ctx = webrtc_streamer(key="unique_example", video_transformer_factory=VideoTransformer)
    image_placeholder = st.empty()
    audio_placeholder = st.empty()

    if ctx.video_transformer is None:
        return  # stream not started yet

    classification_result_placeholder = st.empty()
    detection_info_placeholder.info("开始侦测")
    while True:
        transformer = ctx.video_transformer
        if transformer is None:
            break  # stream stopped; nothing left to read from

        captured = len(transformer.snapshots)
        if captured < capture_target:
            capture_text_placeholder.text(f"捕获进度: {captured}/{capture_target} 张快照")
            capture_progress_placeholder.progress(int(captured / capture_target * 100))
        else:
            # Freeze exactly `capture_target` snapshots. The transformer keeps
            # appending (up to its own cap) while we classify. Iterating the
            # live list would let the smoking, gender and age passes see
            # different image sets.
            snapshots = transformer.snapshots[:capture_target]
            capture_text_placeholder.text("捕获进度: 捕获完成!")
            capture_progress_placeholder.empty()
            detection_info_placeholder.empty()

            # ---------- classification progress ----------
            classification_text_placeholder.text("分类进度: 正在分类...")
            classification_progress = classification_progress_placeholder.progress(0)

            # 1. Smoking classification (0 -> 33%).
            smoke_results = [smoking_classification(img) for img in snapshots]
            smoking_count = sum(1 for label in smoke_results if label.lower() == "smoking")
            classification_progress.progress(33)

            # 2. Gender and age only when more than two snapshots show smoking (33 -> 100%).
            if smoking_count > 2:
                gender_results = [gender_classification(img) for img in snapshots]
                classification_progress.progress(66)
                age_results = [age_classification(img) for img in snapshots]
                classification_progress.progress(100)
                classification_text_placeholder.text("分类进度: 分类完成!")

                most_common_gender = Counter(gender_results).most_common(1)[0][0]
                most_common_age = Counter(age_results).most_common(1)[0][0]
                classification_result_placeholder.markdown(
                    f"**吸烟状态:** Smoking (检测到 {smoking_count} 次)\n\n"
                    f"**性别:** {most_common_gender}\n\n"
                    f"**年龄范围:** {most_common_age}"
                )

                # Show the first snapshot labelled "smoking". The fallback to the
                # first snapshot is defensive; smoking_count > 2 implies a match.
                smoking_image = next(
                    (img for img, label in zip(snapshots, smoke_results)
                     if label.lower() == "smoking"),
                    snapshots[0],
                )
                image_placeholder.image(smoking_image, caption="捕获的快照示例", use_container_width=True)

                # Render player/error inside the placeholder so each round
                # replaces the previous one instead of appending below it.
                audio_placeholder.empty()
                audio_key = f"{most_common_age} {most_common_gender.lower()}"
                if audio_key in audio_data:
                    with audio_placeholder.container():
                        play_audio_via_js(audio_data[audio_key])
                else:
                    audio_placeholder.error(f"音频文件不存在: {audio_key}.wav")
            else:
                classification_result_placeholder.markdown("**吸烟状态:** Not Smoking")
                image_placeholder.empty()
                audio_placeholder.empty()
                classification_text_placeholder.text("分类进度: 分类完成!")
                classification_progress.progress(100)

            # Let the finished progress bar stay visible briefly, then clear it.
            time.sleep(1)
            classification_progress_placeholder.empty()
            classification_text_placeholder.empty()
            capture_text_placeholder.empty()

            # Reset the snapshot buffer for the next round.
            detection_info_placeholder.info("开始侦测")
            transformer.snapshots = []
            transformer.last_capture_time = time.time()
        time.sleep(0.1)


if __name__ == "__main__":
    main()
|