toby007 commited on
Commit
92fc27b
·
1 Parent(s): 3c71e14

update gene mask script

Browse files
Files changed (1) hide show
  1. generated_mask.py +163 -0
generated_mask.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 一键生成mask工具:上传原图 -> 自动做人像分割 -> 生成mask图
2
+
3
+ import os
4
+ import uuid
5
+ import numpy as np
6
+ from PIL import Image
7
+ import mediapipe as mp
8
+ from flask import Flask, request, jsonify, send_file, render_template_string
9
+
10
+ app = Flask(__name__)
11
+
12
+ # 初始化 mediapipe 模型
13
+ mp_selfie_segmentation = mp.solutions.selfie_segmentation.SelfieSegmentation(model_selection=1)
14
+
15
+ UPLOAD_FOLDER = "uploads"
16
+ MASK_FOLDER = "masks"
17
+
18
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
19
+ os.makedirs(MASK_FOLDER, exist_ok=True)
20
+
21
+ # 简单的HTML上传表单
22
+ HTML_TEMPLATE = """
23
+ <!DOCTYPE html>
24
+ <html>
25
+ <head>
26
+ <title>生成图像分割Mask</title>
27
+ <style>
28
+ body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
29
+ h1 { color: #333; }
30
+ .container { margin-top: 20px; }
31
+ .result { margin-top: 20px; }
32
+ img { max-width: 100%; margin-top: 10px; }
33
+ </style>
34
+ </head>
35
+ <body>
36
+ <h1>上传图像生成Mask</h1>
37
+ <div class="container">
38
+ <form action="/upload" method="post" enctype="multipart/form-data">
39
+ <input type="file" name="file" accept="image/*" required>
40
+ <button type="submit">生成Mask</button>
41
+ </form>
42
+ </div>
43
+ {% if original_image %}
44
+ <div class="result">
45
+ <h2>处理结果</h2>
46
+ <div>
47
+ <h3>原始图像</h3>
48
+ <img src="{{ original_image }}" alt="原始图像">
49
+ </div>
50
+ <div>
51
+ <h3>生成的Mask</h3>
52
+ <img src="{{ mask_image }}" alt="生成的Mask">
53
+ </div>
54
+ </div>
55
+ {% endif %}
56
+ </body>
57
+ </html>
58
+ """
59
+
60
+ @app.route("/", methods=["GET"])
61
+ def index():
62
+ # 显示一个简单的上传表单
63
+ return render_template_string(HTML_TEMPLATE)
64
+
65
+ @app.route("/upload", methods=["POST"])
66
+ def upload_image():
67
+ try:
68
+ if 'file' not in request.files:
69
+ return jsonify({"error": "No file uploaded."}), 400
70
+
71
+ file = request.files['file']
72
+ if file.filename == '':
73
+ return jsonify({"error": "Empty filename."}), 400
74
+
75
+ # 检查是否为图像文件
76
+ try:
77
+ img = Image.open(file.stream)
78
+ img.verify() # 验证图像
79
+ file.stream.seek(0) # 重置文件流位置
80
+ except:
81
+ return jsonify({"error": "Invalid image file."}), 400
82
+
83
+ # 保存上传的文件
84
+ filename = f"{uuid.uuid4().hex}.png"
85
+ file_path = os.path.join(UPLOAD_FOLDER, filename)
86
+ file.save(file_path)
87
+
88
+ print(f"Saved uploaded image to {file_path}")
89
+
90
+ # 处理图像生成mask
91
+ try:
92
+ mask_filename = generate_mask(file_path)
93
+ mask_path = os.path.join(MASK_FOLDER, mask_filename)
94
+ print(f"Generated mask saved to {mask_path}")
95
+
96
+ # 判断是API请求还是网页请求
97
+ if request.headers.get('Accept') == 'application/json':
98
+ return jsonify({
99
+ "status": "success",
100
+ "original_image": f"/uploads/{filename}",
101
+ "mask_image": f"/masks/{mask_filename}"
102
+ })
103
+ else:
104
+ # 返回HTML页面展示结果
105
+ return render_template_string(
106
+ HTML_TEMPLATE,
107
+ original_image=f"/uploads/{filename}",
108
+ mask_image=f"/masks/{mask_filename}"
109
+ )
110
+
111
+ except Exception as e:
112
+ print(f"Error generating mask: {str(e)}")
113
+ return jsonify({"error": f"Failed to generate mask: {str(e)}"}), 500
114
+
115
+ except Exception as e:
116
+ print(f"Unexpected error: {str(e)}")
117
+ return jsonify({"error": f"Server error: {str(e)}"}), 500
118
+
119
+ @app.route("/uploads/<filename>")
120
+ def serve_upload(filename):
121
+ return send_file(os.path.join(UPLOAD_FOLDER, filename))
122
+
123
+ @app.route("/masks/<filename>")
124
+ def serve_mask(filename):
125
+ return send_file(os.path.join(MASK_FOLDER, filename))
126
+
127
+ def generate_mask(image_path):
128
+ print(f"Generating mask for {image_path}")
129
+
130
+ # 读取图片
131
+ image = Image.open(image_path).convert("RGB")
132
+ image_np = np.array(image)
133
+
134
+ height, width = image_np.shape[:2]
135
+ print(f"Image dimensions: {width}x{height}")
136
+
137
+ # 生成分割mask
138
+ results = mp_selfie_segmentation.process(image_np)
139
+
140
+ if results.segmentation_mask is None:
141
+ raise ValueError("Segmentation failed. No mask generated.")
142
+
143
+ mask = results.segmentation_mask
144
+ print(f"Mask shape: {mask.shape}")
145
+
146
+ # 根据分割结果生成二值mask(背景白,前景黑)
147
+ binary_mask = (mask < 0.5).astype(np.uint8) * 255
148
+
149
+ # 确保mask是RGB模式
150
+ mask_img = Image.fromarray(binary_mask).convert("RGB")
151
+
152
+ # 保存mask
153
+ mask_filename = f"mask_{os.path.basename(image_path)}"
154
+ mask_path = os.path.join(MASK_FOLDER, mask_filename)
155
+ mask_img.save(mask_path)
156
+
157
+ print(f"Mask saved to {mask_path}")
158
+ return mask_filename
159
+
160
+
161
+ if __name__ == "__main__":
162
+ print("Starting mask generation server on http://127.0.0.1:5555")
163
+ app.run(host="0.0.0.0", port=5555, debug=True)