nagasurendra commited on
Commit
1b0039c
·
verified ·
1 Parent(s): 1724c7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -52
app.py CHANGED
@@ -14,7 +14,7 @@ from ultralytics import YOLO
14
  import ultralytics
15
  import time
16
  import piexif
17
- import shutil
18
 
19
  # Set YOLO config directory
20
  os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"
@@ -44,9 +44,9 @@ detected_issues: List[str] = []
44
  gps_coordinates: List[List[float]] = []
45
  last_metrics: Dict[str, Any] = {}
46
  frame_count: int = 0
47
- SAVE_IMAGE_INTERVAL = 1 # Save every frame with detections
48
 
49
- # Detection classes (aligned with model classes, excluding 'Crocodile')
50
  DETECTION_CLASSES = ["Longitudinal", "Pothole", "Transverse"]
51
 
52
  # Debug: Check environment
@@ -60,9 +60,24 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
60
  print(f"Using device: {device}")
61
  model = YOLO('./data/best.pt').to(device)
62
  if device == "cuda":
63
- model.half() # Use half-precision (FP16)
64
  print(f"Model classes: {model.names}")
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def generate_map(gps_coords: List[List[float]], items: List[Dict[str, Any]]) -> str:
67
  map_path = os.path.join(OUTPUT_DIR, "map_temp.png")
68
  plt.figure(figsize=(4, 4))
@@ -111,9 +126,9 @@ def write_flight_log(frame_count: int, gps_coord: List[float], timestamp: str) -
111
  def check_image_quality(frame: np.ndarray, input_resolution: int) -> bool:
112
  height, width, _ = frame.shape
113
  frame_resolution = width * height
114
- if frame_resolution < 12_000_000: # NHAI requires 12 MP
115
  log_entries.append(f"Frame {frame_count}: Resolution {width}x{height} ({frame_resolution/1e6:.2f}MP) below 12MP, non-compliant")
116
- if frame_resolution < input_resolution: # Ensure output is not below input
117
  log_entries.append(f"Frame {frame_count}: Output resolution {width}x{height} below input resolution")
118
  return False
119
  return True
@@ -141,10 +156,6 @@ def generate_line_chart() -> Optional[str]:
141
  plt.close()
142
  return chart_path
143
 
144
- def generate_download_zip():
145
- shutil.make_archive("outputs_bundle", 'zip', OUTPUT_DIR)
146
- return "outputs_bundle.zip"
147
-
148
  def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
149
  global frame_count, last_metrics, detected_counts, detected_issues, gps_coordinates, log_entries
150
  frame_count = 0
@@ -157,14 +168,14 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
157
  if video is None:
158
  log_entries.append("Error: No video uploaded")
159
  logging.error("No video uploaded")
160
- return "processed_output.mp4", json.dumps({"error": "No video uploaded"}, indent=2), "\n".join(log_entries), [], None, None, None
161
 
162
  start_time = time.time()
163
  cap = cv2.VideoCapture(video)
164
  if not cap.isOpened():
165
  log_entries.append("Error: Could not open video file")
166
  logging.error("Could not open video file")
167
- return "processed_output.mp4", json.dumps({"error": "Could not open video file"}, indent=2), "\n".join(log_entries), [], None, None, None
168
 
169
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
170
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
@@ -172,36 +183,25 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
172
  fps = cap.get(cv2.CAP_PROP_FPS)
173
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
174
  expected_duration = total_frames / fps if fps > 0 else 0
175
- log_entries.append(f"Input video: {frame_width}x{frame_height} ({input_resolution/1e6:.2f}MP), {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
176
- logging.info(f"Input video: {frame_width}x{frame_height} ({input_resolution/1e6:.2f}MP), {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
177
- print(f"Input video: {frame_width}x{frame_height} ({input_resolution/1e6:.2f}MP), {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
178
 
179
  out_width, out_height = resize_width, resize_height
180
  output_path = os.path.join(OUTPUT_DIR, "processed_output.mp4")
181
- codecs = [('mp4v', '.mp4'), ('XVID', '.avi'), ('MJPG', '.avi')] # Prioritize mp4v
182
- out = None
183
- for codec, ext in codecs:
184
- fourcc = cv2.VideoWriter_fourcc(*codec)
185
- temp_output_path = os.path.join(OUTPUT_DIR, f"processed_output{ext}")
186
- out = cv2.VideoWriter(temp_output_path, fourcc, fps, (out_width, out_height))
187
- if out.isOpened():
188
- output_path = temp_output_path
189
- log_entries.append(f"Using codec: {codec}, output: {output_path}")
190
- logging.info(f"Using codec: {codec}, output: {output_path}")
191
- break
192
- else:
193
- log_entries.append(f"Failed to initialize codec: {codec}")
194
- logging.warning(f"Failed to initialize codec: {codec}")
195
-
196
- if not out or not out.isOpened():
197
- log_entries.append("Error: All codecs failed to initialize video writer")
198
- logging.error("All codecs failed to initialize video writer")
199
  cap.release()
200
- return "processed_output.mp4", json.dumps({"error": "All codecs failed"}, indent=2), "\n".join(log_entries), [], None, None, None
201
 
202
  processed_frames = 0
203
  all_detections = []
204
  frame_times = []
 
 
 
205
  detection_frame_count = 0
206
  output_frame_count = 0
207
  last_annotated_frame = None
@@ -222,13 +222,20 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
222
  processed_frames += 1
223
  frame_start = time.time()
224
 
 
 
225
  frame = cv2.resize(frame, (out_width, out_height))
 
 
226
  if not check_image_quality(frame, input_resolution):
227
  log_entries.append(f"Frame {frame_count}: Skipped due to low resolution")
228
  continue
229
 
 
 
230
  results = model(frame, verbose=False, conf=0.5, iou=0.7)
231
  annotated_frame = results[0].plot()
 
232
 
233
  frame_timestamp = frame_count / fps if fps > 0 else 0
234
  timestamp_str = f"{int(frame_timestamp // 60)}:{int(frame_timestamp % 60):02d}"
@@ -236,6 +243,7 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
236
  gps_coord = [17.385044 + (frame_count * 0.0001), 78.486671 + (frame_count * 0.0001)]
237
  gps_coordinates.append(gps_coord)
238
 
 
239
  frame_detections = []
240
  for detection in results[0].boxes:
241
  cls = int(detection.cls)
@@ -282,6 +290,8 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
282
  "frame": frame_count
283
  })
284
 
 
 
285
  out.write(annotated_frame)
286
  output_frame_count += 1
287
  last_annotated_frame = annotated_frame
@@ -295,20 +305,15 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
295
 
296
  frame_time = (time.time() - frame_start) * 1000
297
  frame_times.append(frame_time)
298
-
299
- detection_summary = {
300
- "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
301
- "video_timestamp": timestamp_str,
302
- "frame": frame_count,
303
- "gps": gps_coord,
304
- "processing_time_ms": frame_time,
305
- "detections": {label: sum(1 for det in frame_detections if det["label"] == label) for label in DETECTION_CLASSES}
306
- }
307
- data_lake_submission["analytics"].append(detection_summary)
308
- log_entries.append(json.dumps(detection_summary, indent=2))
309
  if len(log_entries) > 50:
310
  log_entries.pop(0)
311
 
 
 
 
 
 
312
  while output_frame_count < total_frames and last_annotated_frame is not None:
313
  out.write(last_annotated_frame)
314
  output_frame_count += 1
@@ -339,16 +344,23 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
339
 
340
  total_time = time.time() - start_time
341
  avg_frame_time = sum(frame_times) / len(frame_times) if frame_times else 0
 
 
 
342
  log_entries.append(f"Output video: {output_frames} frames, {output_fps:.2f} FPS, {output_duration:.2f} seconds")
343
  logging.info(f"Output video: {output_frames} frames, {output_fps:.2f} FPS, {output_duration:.2f} seconds")
344
- log_entries.append(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms, Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
345
- logging.info(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms, Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
346
  print(f"Output video: {output_frames} frames, {output_fps:.2f} FPS, {output_duration:.2f} seconds")
347
  print(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms, Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
348
 
349
  chart_path = generate_line_chart()
350
  map_path = generate_map(gps_coordinates[-5:], all_detections)
351
 
 
 
 
 
352
  return (
353
  output_path,
354
  json.dumps(last_metrics, indent=2),
@@ -356,7 +368,10 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
356
  detected_issues,
357
  chart_path,
358
  map_path,
359
- generate_download_zip() # Provide the zip link for all outputs
 
 
 
360
  )
361
 
362
  # Gradio interface
@@ -379,13 +394,30 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange")) as iface:
379
  map_output = gr.Image(label="Issue Locations Map")
380
  with gr.Row():
381
  logs_output = gr.Textbox(label="Logs", lines=5, interactive=False)
382
- zip_download = gr.File(label="Download All Outputs (ZIP)")
 
 
 
 
 
 
383
 
384
  process_btn.click(
385
- process_video,
386
  inputs=[video_input, width_slider, height_slider, skip_slider],
387
- outputs=[video_output, metrics_output, logs_output, issue_gallery, chart_output, map_output, zip_download]
 
 
 
 
 
 
 
 
 
 
 
388
  )
389
 
390
  if __name__ == "__main__":
391
- iface.launch()
 
14
  import ultralytics
15
  import time
16
  import piexif
17
+ import zipfile
18
 
19
  # Set YOLO config directory
20
  os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"
 
44
  gps_coordinates: List[List[float]] = []
45
  last_metrics: Dict[str, Any] = {}
46
  frame_count: int = 0
47
+ SAVE_IMAGE_INTERVAL = 1
48
 
49
+ # Detection classes
50
  DETECTION_CLASSES = ["Longitudinal", "Pothole", "Transverse"]
51
 
52
  # Debug: Check environment
 
60
  print(f"Using device: {device}")
61
  model = YOLO('./data/best.pt').to(device)
62
  if device == "cuda":
63
+ model.half()
64
  print(f"Model classes: {model.names}")
65
 
66
+ def zip_directory(folder_path: str, zip_path: str) -> str:
67
+ """Zip all files in a directory."""
68
+ try:
69
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
70
+ for root, _, files in os.walk(folder_path):
71
+ for file in files:
72
+ file_path = os.path.join(root, file)
73
+ arcname = os.path.relpath(file_path, folder_path)
74
+ zipf.write(file_path, arcname)
75
+ return zip_path
76
+ except Exception as e:
77
+ logging.error(f"Failed to zip {folder_path}: {str(e)}")
78
+ log_entries.append(f"Error: Failed to zip {folder_path}: {str(e)}")
79
+ return ""
80
+
81
  def generate_map(gps_coords: List[List[float]], items: List[Dict[str, Any]]) -> str:
82
  map_path = os.path.join(OUTPUT_DIR, "map_temp.png")
83
  plt.figure(figsize=(4, 4))
 
126
  def check_image_quality(frame: np.ndarray, input_resolution: int) -> bool:
127
  height, width, _ = frame.shape
128
  frame_resolution = width * height
129
+ if frame_resolution < 12_000_000:
130
  log_entries.append(f"Frame {frame_count}: Resolution {width}x{height} ({frame_resolution/1e6:.2f}MP) below 12MP, non-compliant")
131
+ if frame_resolution < input_resolution:
132
  log_entries.append(f"Frame {frame_count}: Output resolution {width}x{height} below input resolution")
133
  return False
134
  return True
 
156
  plt.close()
157
  return chart_path
158
 
 
 
 
 
159
  def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
160
  global frame_count, last_metrics, detected_counts, detected_issues, gps_coordinates, log_entries
161
  frame_count = 0
 
168
  if video is None:
169
  log_entries.append("Error: No video uploaded")
170
  logging.error("No video uploaded")
171
+ return None, json.dumps({"error": "No video uploaded"}, indent=2), "\n".join(log_entries), [], None, None, None, None, None, None
172
 
173
  start_time = time.time()
174
  cap = cv2.VideoCapture(video)
175
  if not cap.isOpened():
176
  log_entries.append("Error: Could not open video file")
177
  logging.error("Could not open video file")
178
+ return None, json.dumps({"error": "Could not open video file"}, indent=2), "\n".join(log_entries), [], None, None, None, None, None, None
179
 
180
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
181
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
183
  fps = cap.get(cv2.CAP_PROP_FPS)
184
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
185
  expected_duration = total_frames / fps if fps > 0 else 0
186
+ log_entries.append(f"Input video: {frame_width}x{frame_height} ({input_resolution/1e6:.2f}MP), {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds, Frame skip: {frame_skip}")
187
+ logging.info(f"Input video: {frame_width}x{frame_height} ({input_resolution/1e6:.2f}MP), {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds, Frame skip: {frame_skip}")
188
+ print(f"Input video: {frame_width}x{frame_height} ({input_resolution/1e6:.2f}MP), {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds, Frame skip: {frame_skip}")
189
 
190
  out_width, out_height = resize_width, resize_height
191
  output_path = os.path.join(OUTPUT_DIR, "processed_output.mp4")
192
+ out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (out_width, out_height))
193
+ if not out.isOpened():
194
+ log_entries.append("Error: Failed to initialize mp4v codec")
195
+ logging.error("Failed to initialize mp4v codec")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  cap.release()
197
+ return None, json.dumps({"error": "mp4v codec failed"}, indent=2), "\n".join(log_entries), [], None, None, None, None, None, None
198
 
199
  processed_frames = 0
200
  all_detections = []
201
  frame_times = []
202
+ inference_times = []
203
+ resize_times = []
204
+ io_times = []
205
  detection_frame_count = 0
206
  output_frame_count = 0
207
  last_annotated_frame = None
 
222
  processed_frames += 1
223
  frame_start = time.time()
224
 
225
+ # Resize
226
+ resize_start = time.time()
227
  frame = cv2.resize(frame, (out_width, out_height))
228
+ resize_times.append((time.time() - resize_start) * 1000)
229
+
230
  if not check_image_quality(frame, input_resolution):
231
  log_entries.append(f"Frame {frame_count}: Skipped due to low resolution")
232
  continue
233
 
234
+ # Inference
235
+ inference_start = time.time()
236
  results = model(frame, verbose=False, conf=0.5, iou=0.7)
237
  annotated_frame = results[0].plot()
238
+ inference_times.append((time.time() - inference_start) * 1000)
239
 
240
  frame_timestamp = frame_count / fps if fps > 0 else 0
241
  timestamp_str = f"{int(frame_timestamp // 60)}:{int(frame_timestamp % 60):02d}"
 
243
  gps_coord = [17.385044 + (frame_count * 0.0001), 78.486671 + (frame_count * 0.0001)]
244
  gps_coordinates.append(gps_coord)
245
 
246
+ io_start = time.time()
247
  frame_detections = []
248
  for detection in results[0].boxes:
249
  cls = int(detection.cls)
 
290
  "frame": frame_count
291
  })
292
 
293
+ io_times.append((time.time() - io_start) * 1000)
294
+
295
  out.write(annotated_frame)
296
  output_frame_count += 1
297
  last_annotated_frame = annotated_frame
 
305
 
306
  frame_time = (time.time() - frame_start) * 1000
307
  frame_times.append(frame_time)
308
+ log_entries.append(f"Frame {frame_count}: Processed in {frame_time:.2f} ms (Resize: {resize_times[-1]:.2f} ms, Inference: {inference_times[-1]:.2f} ms, I/O: {io_times[-1]:.2f} ms)")
 
 
 
 
 
 
 
 
 
 
309
  if len(log_entries) > 50:
310
  log_entries.pop(0)
311
 
312
+ if time.time() - start_time > 600:
313
+ log_entries.append("Error: Processing timeout after 600 seconds")
314
+ logging.error("Processing timeout after 600 seconds")
315
+ break
316
+
317
  while output_frame_count < total_frames and last_annotated_frame is not None:
318
  out.write(last_annotated_frame)
319
  output_frame_count += 1
 
344
 
345
  total_time = time.time() - start_time
346
  avg_frame_time = sum(frame_times) / len(frame_times) if frame_times else 0
347
+ avg_resize_time = sum(resize_times) / len(resize_times) if resize_times else 0
348
+ avg_inference_time = sum(inference_times) / len(inference_times) if inference_times else 0
349
+ avg_io_time = sum(io_times) / len(io_times) if io_times else 0
350
  log_entries.append(f"Output video: {output_frames} frames, {output_fps:.2f} FPS, {output_duration:.2f} seconds")
351
  logging.info(f"Output video: {output_frames} frames, {output_fps:.2f} FPS, {output_duration:.2f} seconds")
352
+ log_entries.append(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms (Avg Resize: {avg_resize_time:.2f} ms, Avg Inference: {avg_inference_time:.2f} ms, Avg I/O: {avg_io_time:.2f} ms), Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
353
+ logging.info(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms (Avg Resize: {avg_resize_time:.2f} ms, Avg Inference: {avg_inference_time:.2f} ms, Avg I/O: {avg_io_time:.2f} ms), Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
354
  print(f"Output video: {output_frames} frames, {output_fps:.2f} FPS, {output_duration:.2f} seconds")
355
  print(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms, Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
356
 
357
  chart_path = generate_line_chart()
358
  map_path = generate_map(gps_coordinates[-5:], all_detections)
359
 
360
+ # Zip images and logs
361
+ images_zip = zip_directory(CAPTURED_FRAMES_DIR, os.path.join(OUTPUT_DIR, "captured_frames.zip"))
362
+ logs_zip = zip_directory(FLIGHT_LOG_DIR, os.path.join(OUTPUT_DIR, "flight_logs.zip"))
363
+
364
  return (
365
  output_path,
366
  json.dumps(last_metrics, indent=2),
 
368
  detected_issues,
369
  chart_path,
370
  map_path,
371
+ submission_json_path,
372
+ images_zip,
373
+ logs_zip,
374
+ output_path # For video download
375
  )
376
 
377
  # Gradio interface
 
394
  map_output = gr.Image(label="Issue Locations Map")
395
  with gr.Row():
396
  logs_output = gr.Textbox(label="Logs", lines=5, interactive=False)
397
+ with gr.Row():
398
+ gr.Markdown("## Download Results")
399
+ with gr.Row():
400
+ json_download = gr.File(label="Download Data Lake JSON")
401
+ images_zip_download = gr.File(label="Download Geotagged Images (ZIP)")
402
+ logs_zip_download = gr.File(label="Download Flight Logs (ZIP)")
403
+ video_download = gr.File(label="Download Processed Video")
404
 
405
  process_btn.click(
406
+ fn=process_video,
407
  inputs=[video_input, width_slider, height_slider, skip_slider],
408
+ outputs=[
409
+ video_output,
410
+ metrics_output,
411
+ logs_output,
412
+ issue_gallery,
413
+ chart_output,
414
+ map_output,
415
+ json_download,
416
+ images_zip_download,
417
+ logs_zip_download,
418
+ video_download
419
+ ]
420
  )
421
 
422
  if __name__ == "__main__":
423
+ iface.launch()