minh9972t12 commited on
Commit
84b458c
1 Parent(s): dbf7605

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +285 -283
main.py CHANGED
@@ -1,283 +1,285 @@
1
- import io
2
- import shutil
3
-
4
- import uvicorn
5
- import numpy as np
6
- import uuid
7
- from datetime import datetime
8
- from pathlib import Path
9
- from fastapi import FastAPI, UploadFile, File, HTTPException
10
- from fastapi.responses import JSONResponse, FileResponse
11
- from fastapi.middleware.cors import CORSMiddleware
12
- from fastapi.staticfiles import StaticFiles
13
- from PIL import Image
14
- import cv2
15
- from src.detection import YOLOv11Detector
16
- from src.comparison import DamageComparator
17
- from src.visualization import DamageVisualizer
18
-
19
-
20
- app = FastAPI(
21
- title="Car Damage Detection API",
22
- description="YOLOv11-based car damage detection and comparison system",
23
- version="1.0.0"
24
- )
25
-
26
- # Add CORS middleware
27
- app.add_middleware(
28
- CORSMiddleware,
29
- allow_origins=["*"],
30
- allow_credentials=True,
31
- allow_methods=["*"],
32
- allow_headers=["*"],
33
- )
34
-
35
- # Initialize components
36
- detector = YOLOv11Detector()
37
- comparator = DamageComparator()
38
- visualizer = DamageVisualizer()
39
-
40
- # Create necessary directories
41
- Path("uploads").mkdir(exist_ok=True)
42
- Path("results").mkdir(exist_ok=True)
43
-
44
- # Mount static files directory
45
- app.mount("/uploads", StaticFiles(directory="uploads"), name="uploads")
46
-
47
- @app.get("/")
48
- async def root():
49
- """Root endpoint"""
50
- return {
51
- "message": "Car Damage Detection API with YOLOv11",
52
- "endpoints": {
53
- "/docs": "API documentation",
54
- "/detect": "Single image detection",
55
- "/compare": "Compare before/after images (6 pairs)",
56
- "/uploads/{filename}": "Access saved visualization images",
57
- "/health": "Health check"
58
- }
59
- }
60
-
61
-
62
- def save_temp_file(upload_file: UploadFile) -> str:
63
- """Save an uploaded file into /tmp and return the temp file path"""
64
- tmp_dir = Path("/tmp")
65
- tmp_dir.mkdir(exist_ok=True)
66
-
67
- temp_path = tmp_dir / upload_file.filename
68
-
69
- with open(temp_path, "wb") as buffer:
70
- shutil.copyfileobj(upload_file.file, buffer)
71
-
72
- return str(temp_path)
73
- @app.get("/health")
74
- async def health_check():
75
- """Health check endpoint"""
76
- return {"status": "healthy", "model": "YOLOv11"}
77
-
78
- @app.post("/detect")
79
- async def detect_single_image(file: UploadFile = File(...)):
80
- """
81
- Detect damages in a single image
82
-
83
- Args:
84
- file: Image file
85
-
86
- Returns:
87
- Detection results with bounding boxes and path to visualized image
88
- """
89
- try:
90
- # Read and process image
91
- contents = await file.read()
92
- image = Image.open(io.BytesIO(contents)).convert("RGB")
93
- image_np = np.array(image)
94
- image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
95
-
96
- # Perform detection
97
- detections = detector.detect(image_bgr)
98
-
99
- # Create visualization
100
- visualized = visualizer.draw_detections(image_bgr, detections, 'new_damage')
101
-
102
-
103
- filename = f"detection_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}.jpg"
104
- output_path = Path("/tmp") / filename
105
- cv2.imwrite(str(output_path), visualized)
106
-
107
- return JSONResponse({
108
- "status": "success",
109
- "detections": detections,
110
- "statistics": {
111
- "total_damages": len(detections['boxes']),
112
- "damage_types": list(set(detections['classes']))
113
- },
114
- "visualized_image_path": f"/tmp/{filename}",
115
- })
116
-
117
- except Exception as e:
118
- raise HTTPException(status_code=500, detail=f"Detection failed: {str(e)}")
119
-
120
- @app.post("/compare")
121
- async def compare_vehicle_damages(
122
- # Before delivery images (6 positions)
123
- before_1: UploadFile = File(..., description="Before - Position 1"),
124
- before_2: UploadFile = File(..., description="Before - Position 2"),
125
- before_3: UploadFile = File(..., description="Before - Position 3"),
126
- before_4: UploadFile = File(..., description="Before - Position 4"),
127
- before_5: UploadFile = File(..., description="Before - Position 5"),
128
- before_6: UploadFile = File(..., description="Before - Position 6"),
129
- # After delivery images (6 positions)
130
- after_1: UploadFile = File(..., description="After - Position 1"),
131
- after_2: UploadFile = File(..., description="After - Position 2"),
132
- after_3: UploadFile = File(..., description="After - Position 3"),
133
- after_4: UploadFile = File(..., description="After - Position 4"),
134
- after_5: UploadFile = File(..., description="After - Position 5"),
135
- after_6: UploadFile = File(..., description="After - Position 6"),
136
- ):
137
- """
138
- Compare vehicle damages before and after delivery
139
-
140
- Analyzes 6 pairs of images (before/after) from different positions
141
- and determines the damage status according to 3 cases:
142
- - Case 1: Existing damages (from before) -> Delivery completed
143
- - Case 2: New damages detected -> Error during delivery
144
- - Case 3: No damages -> Successful delivery
145
-
146
- Returns:
147
- Detailed comparison results for each position and overall status
148
- """
149
- try:
150
- before_images = [before_1, before_2, before_3, before_4, before_5, before_6]
151
- after_images = [after_1, after_2, after_3, after_4, after_5, after_6]
152
-
153
- position_results = []
154
- all_visualizations = []
155
- image_pairs = []
156
-
157
- # Overall statistics
158
- total_new_damages = 0
159
- total_existing_damages = 0
160
- total_matched_damages = 0
161
-
162
- # Generate unique session ID for this comparison
163
- session_id = str(uuid.uuid4())[:8]
164
- timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
165
-
166
- # Process each position pair
167
- for i in range(6):
168
- # Read images
169
- before_contents = await before_images[i].read()
170
- after_contents = await after_images[i].read()
171
-
172
- before_img = Image.open(io.BytesIO(before_contents)).convert("RGB")
173
- after_img = Image.open(io.BytesIO(after_contents)).convert("RGB")
174
-
175
- before_np = np.array(before_img)
176
- after_np = np.array(after_img)
177
-
178
- before_bgr = cv2.cvtColor(before_np, cv2.COLOR_RGB2BGR)
179
- after_bgr = cv2.cvtColor(after_np, cv2.COLOR_RGB2BGR)
180
-
181
- # Store image pairs for grid visualization
182
- image_pairs.append((before_bgr, after_bgr))
183
-
184
- # Detect damages
185
- before_detections = detector.detect(before_bgr)
186
- after_detections = detector.detect(after_bgr)
187
-
188
- # Compare damages
189
- comparison = comparator.analyze_damage_status(before_detections, after_detections)
190
-
191
- # Update overall statistics
192
- total_new_damages += len(comparison['new_damages'])
193
- total_existing_damages += len(comparison['repaired_damages'])
194
- total_matched_damages += len(comparison['matched_damages'])
195
-
196
- # Create visualization for this position
197
- vis_img = visualizer.create_comparison_visualization(
198
- before_bgr, after_bgr,
199
- before_detections, after_detections,
200
- comparison
201
- )
202
-
203
- # Save visualization image with unique filename
204
- vis_filename = f"comparison_{timestamp_str}_{session_id}_pos{i+1}.jpg"
205
- vis_path = Path("/tmp") / vis_filename
206
- cv2.imwrite(str(vis_path), vis_img)
207
-
208
- vis_url = f"http://localhost:8000/uploads/{vis_filename}"
209
- all_visualizations.append(vis_url)
210
-
211
- # Store position result
212
- position_results.append({
213
- f"position_{i+1}": {
214
- "case": comparison['case'],
215
- "message": comparison['message'],
216
- "statistics": comparison['statistics'],
217
- "new_damages": comparison['new_damages'],
218
- "matched_damages": comparison['matched_damages'],
219
- "repaired_damages": comparison['repaired_damages'],
220
- "visualization_path": f"/tmp/{vis_filename}",
221
- "visualization_url": vis_url
222
- }
223
- })
224
-
225
- # Determine overall case
226
- overall_case = "CASE_3_SUCCESS"
227
- overall_message = "Successful delivery - No damage detected"
228
-
229
- if total_new_damages > 0:
230
- overall_case = "CASE_2_NEW_DAMAGE"
231
- overall_message = f"Error during vehicle delivery - Detection {total_new_damages} new damage"
232
- elif total_matched_damages > 0 and total_new_damages == 0:
233
- overall_case = "CASE_1_EXISTING"
234
- overall_message = "Error from the beginning, not during the delivery process -> Delivery completed"
235
-
236
- # Create summary grid visualization
237
- grid_results = [res[f"position_{i+1}"] for i, res in enumerate(position_results)]
238
- grid_img = visualizer.create_summary_grid(grid_results, image_pairs)
239
-
240
- # Save grid summary image
241
- grid_filename = f"summary_grid_{timestamp_str}_{session_id}.jpg"
242
- grid_path = Path("uploads") / grid_filename
243
- cv2.imwrite(str(grid_path), grid_img)
244
- grid_url = f"http://localhost:8000/uploads/{grid_filename}"
245
-
246
- # Generate timestamp for tracking
247
- timestamp = datetime.now().isoformat()
248
-
249
- return JSONResponse({
250
- "status": "success",
251
- "session_id": session_id,
252
- "timestamp": timestamp,
253
- "overall_result": {
254
- "case": overall_case,
255
- "message": overall_message,
256
- "statistics": {
257
- "total_new_damages": total_new_damages,
258
- "total_matched_damages": total_matched_damages,
259
- "total_repaired_damages": total_existing_damages
260
- }
261
- },
262
- "position_results": position_results,
263
- "summary_visualization_path": f"/uploads/{grid_filename}",
264
- "summary_visualization_url": grid_url,
265
- "recommendations": {
266
- "action_required": total_new_damages > 0,
267
- "suggested_action": "Investigate delivery process" if total_new_damages > 0 else "Proceed with delivery completion"
268
- }
269
- })
270
-
271
- except Exception as e:
272
- raise HTTPException(status_code=500, detail=f"Comparison failed: {str(e)}")
273
-
274
-
275
-
276
- if __name__ == "__main__":
277
- uvicorn.run(
278
- "main:app",
279
- host="0.0.0.0",
280
- port=8000,
281
- reload=True,
282
- log_level="info"
283
- )
 
 
 
1
+ import io
2
+ import shutil
3
+
4
+ import uvicorn
5
+ import numpy as np
6
+ import uuid
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from fastapi import FastAPI, UploadFile, File, HTTPException
10
+ from fastapi.responses import JSONResponse, FileResponse
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from fastapi.staticfiles import StaticFiles
13
+ from PIL import Image
14
+ import cv2
15
+ from src.detection import YOLOv11Detector
16
+ from src.comparison import DamageComparator
17
+ from src.visualization import DamageVisualizer
18
+
19
+
20
+ app = FastAPI(
21
+ title="Car Damage Detection API",
22
+ description="YOLOv11-based car damage detection and comparison system",
23
+ version="1.0.0"
24
+ )
25
+
26
+ # Add CORS middleware
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"],
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
+
35
+ # Initialize components
36
+ detector = YOLOv11Detector()
37
+ comparator = DamageComparator()
38
+ visualizer = DamageVisualizer()
39
+
40
+ # Create necessary directories
41
+ Path("uploads").mkdir(exist_ok=True)
42
+ Path("results").mkdir(exist_ok=True)
43
+
44
+ # Mount static files directory
45
+ app.mount("/uploads", StaticFiles(directory="uploads"), name="uploads")
46
+
47
+ @app.get("/")
48
+ async def root():
49
+ """Root endpoint"""
50
+ return {
51
+ "message": "Car Damage Detection API with YOLOv11",
52
+ "endpoints": {
53
+ "/docs": "API documentation",
54
+ "/detect": "Single image detection",
55
+ "/compare": "Compare before/after images (6 pairs)",
56
+ "/uploads/{filename}": "Access saved visualization images",
57
+ "/health": "Health check"
58
+ }
59
+ }
60
+
61
+
62
+ def save_temp_file(upload_file: UploadFile) -> str:
63
+ """Save an uploaded file into /tmp and return the temp file path"""
64
+ tmp_dir = Path("/tmp")
65
+ tmp_dir.mkdir(exist_ok=True)
66
+
67
+ temp_path = tmp_dir / upload_file.filename
68
+
69
+ with open(temp_path, "wb") as buffer:
70
+ shutil.copyfileobj(upload_file.file, buffer)
71
+ return str(temp_path)
72
+
73
+
74
+ @app.get("/health")
75
+ async def health_check():
76
+ """Health check endpoint"""
77
+ return {"status": "healthy", "model": "YOLOv11"}
78
+
79
+ @app.post("/detect")
80
+ async def detect_single_image(file: UploadFile = File(...)):
81
+ """
82
+ Detect damages in a single image
83
+
84
+ Args:
85
+ file: Image file
86
+
87
+ Returns:
88
+ Detection results with bounding boxes and path to visualized image
89
+ """
90
+ try:
91
+ # Read and process image
92
+ contents = await file.read()
93
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
94
+ image_np = np.array(image)
95
+ image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
96
+
97
+ # Perform detection
98
+ detections = detector.detect(image_bgr)
99
+
100
+ # Create visualization
101
+ visualized = visualizer.draw_detections(image_bgr, detections, 'new_damage')
102
+
103
+
104
+ filename = f"detection_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}.jpg"
105
+ output_path = Path("/tmp") / filename
106
+ cv2.imwrite(str(output_path), visualized)
107
+
108
+ return JSONResponse({
109
+ "status": "success",
110
+ "detections": detections,
111
+ "statistics": {
112
+ "total_damages": len(detections['boxes']),
113
+ "damage_types": list(set(detections['classes']))
114
+ },
115
+ "visualized_image_path": f"/tmp/{filename}",
116
+ })
117
+
118
+ except Exception as e:
119
+ raise HTTPException(status_code=500, detail=f"Detection failed: {str(e)}")
120
+
121
+ @app.post("/compare")
122
+ async def compare_vehicle_damages(
123
+ # Before delivery images (6 positions)
124
+ before_1: UploadFile = File(..., description="Before - Position 1"),
125
+ before_2: UploadFile = File(..., description="Before - Position 2"),
126
+ before_3: UploadFile = File(..., description="Before - Position 3"),
127
+ before_4: UploadFile = File(..., description="Before - Position 4"),
128
+ before_5: UploadFile = File(..., description="Before - Position 5"),
129
+ before_6: UploadFile = File(..., description="Before - Position 6"),
130
+ # After delivery images (6 positions)
131
+ after_1: UploadFile = File(..., description="After - Position 1"),
132
+ after_2: UploadFile = File(..., description="After - Position 2"),
133
+ after_3: UploadFile = File(..., description="After - Position 3"),
134
+ after_4: UploadFile = File(..., description="After - Position 4"),
135
+ after_5: UploadFile = File(..., description="After - Position 5"),
136
+ after_6: UploadFile = File(..., description="After - Position 6"),
137
+ ):
138
+ """
139
+ Compare vehicle damages before and after delivery
140
+
141
+ Analyzes 6 pairs of images (before/after) from different positions
142
+ and determines the damage status according to 3 cases:
143
+ - Case 1: Existing damages (from before) -> Delivery completed
144
+ - Case 2: New damages detected -> Error during delivery
145
+ - Case 3: No damages -> Successful delivery
146
+
147
+ Returns:
148
+ Detailed comparison results for each position and overall status
149
+ """
150
+ try:
151
+ before_images = [before_1, before_2, before_3, before_4, before_5, before_6]
152
+ after_images = [after_1, after_2, after_3, after_4, after_5, after_6]
153
+
154
+ position_results = []
155
+ all_visualizations = []
156
+ image_pairs = []
157
+
158
+ # Overall statistics
159
+ total_new_damages = 0
160
+ total_existing_damages = 0
161
+ total_matched_damages = 0
162
+
163
+ # Generate unique session ID for this comparison
164
+ session_id = str(uuid.uuid4())[:8]
165
+ timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
166
+
167
+ # Process each position pair
168
+ for i in range(6):
169
+ # Read images
170
+ before_contents = await before_images[i].read()
171
+ after_contents = await after_images[i].read()
172
+
173
+ before_img = Image.open(io.BytesIO(before_contents)).convert("RGB")
174
+ after_img = Image.open(io.BytesIO(after_contents)).convert("RGB")
175
+
176
+ before_np = np.array(before_img)
177
+ after_np = np.array(after_img)
178
+
179
+ before_bgr = cv2.cvtColor(before_np, cv2.COLOR_RGB2BGR)
180
+ after_bgr = cv2.cvtColor(after_np, cv2.COLOR_RGB2BGR)
181
+
182
+ # Store image pairs for grid visualization
183
+ image_pairs.append((before_bgr, after_bgr))
184
+
185
+ # Detect damages
186
+ before_detections = detector.detect(before_bgr)
187
+ after_detections = detector.detect(after_bgr)
188
+
189
+ # Compare damages
190
+ comparison = comparator.analyze_damage_status(before_detections, after_detections)
191
+
192
+ # Update overall statistics
193
+ total_new_damages += len(comparison['new_damages'])
194
+ total_existing_damages += len(comparison['repaired_damages'])
195
+ total_matched_damages += len(comparison['matched_damages'])
196
+
197
+ # Create visualization for this position
198
+ vis_img = visualizer.create_comparison_visualization(
199
+ before_bgr, after_bgr,
200
+ before_detections, after_detections,
201
+ comparison
202
+ )
203
+
204
+ # Save visualization image with unique filename
205
+ vis_filename = f"comparison_{timestamp_str}_{session_id}_pos{i+1}.jpg"
206
+ vis_path = Path("/tmp") / vis_filename
207
+ cv2.imwrite(str(vis_path), vis_img)
208
+
209
+ vis_url = f"http://localhost:8000/uploads/{vis_filename}"
210
+ all_visualizations.append(vis_url)
211
+
212
+ # Store position result
213
+ position_results.append({
214
+ f"position_{i+1}": {
215
+ "case": comparison['case'],
216
+ "message": comparison['message'],
217
+ "statistics": comparison['statistics'],
218
+ "new_damages": comparison['new_damages'],
219
+ "matched_damages": comparison['matched_damages'],
220
+ "repaired_damages": comparison['repaired_damages'],
221
+ "visualization_path": f"/tmp/{vis_filename}",
222
+ "visualization_url": vis_url
223
+ }
224
+ })
225
+
226
+ # Determine overall case
227
+ overall_case = "CASE_3_SUCCESS"
228
+ overall_message = "Successful delivery - No damage detected"
229
+
230
+ if total_new_damages > 0:
231
+ overall_case = "CASE_2_NEW_DAMAGE"
232
+ overall_message = f"Error during vehicle delivery - Detection {total_new_damages} new damage"
233
+ elif total_matched_damages > 0 and total_new_damages == 0:
234
+ overall_case = "CASE_1_EXISTING"
235
+ overall_message = "Error from the beginning, not during the delivery process -> Delivery completed"
236
+
237
+ # Create summary grid visualization
238
+ grid_results = [res[f"position_{i+1}"] for i, res in enumerate(position_results)]
239
+ grid_img = visualizer.create_summary_grid(grid_results, image_pairs)
240
+
241
+ # Save grid summary image
242
+ grid_filename = f"summary_grid_{timestamp_str}_{session_id}.jpg"
243
+ grid_path = Path("uploads") / grid_filename
244
+ cv2.imwrite(str(grid_path), grid_img)
245
+ grid_url = f"http://localhost:8000/uploads/{grid_filename}"
246
+
247
+ # Generate timestamp for tracking
248
+ timestamp = datetime.now().isoformat()
249
+
250
+ return JSONResponse({
251
+ "status": "success",
252
+ "session_id": session_id,
253
+ "timestamp": timestamp,
254
+ "overall_result": {
255
+ "case": overall_case,
256
+ "message": overall_message,
257
+ "statistics": {
258
+ "total_new_damages": total_new_damages,
259
+ "total_matched_damages": total_matched_damages,
260
+ "total_repaired_damages": total_existing_damages
261
+ }
262
+ },
263
+ "position_results": position_results,
264
+ "summary_visualization_path": f"/uploads/{grid_filename}",
265
+ "summary_visualization_url": grid_url,
266
+ "recommendations": {
267
+ "action_required": total_new_damages > 0,
268
+ "suggested_action": "Investigate delivery process" if total_new_damages > 0 else "Proceed with delivery completion"
269
+ }
270
+ })
271
+
272
+ except Exception as e:
273
+ raise HTTPException(status_code=500, detail=f"Comparison failed: {str(e)}")
274
+
275
+
276
+
277
+ if __name__ == "__main__":
278
+ import os
279
+ uvicorn.run(
280
+ "main:app",
281
+ host="0.0.0.0",
282
+ port=int(os.environ.get("PORT", 7860)), # L岷 port do HF set
283
+ reload=False, # Production kh么ng c岷 reload
284
+ log_level="info"
285
+ )