csepartha commited on
Commit
bb945a5
·
verified ·
1 Parent(s): 4f9d127

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -13
app.py CHANGED
@@ -31,37 +31,52 @@ metrics_df = pd.DataFrame(metrics_data, columns=["Metric", "Value"])
31
  ##################################
32
 
33
  def run_inference(img: np.ndarray, model):
 
 
 
 
 
34
  img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
35
- results = model.predict(img_rgb) # Run prediction
 
 
 
 
36
  result_probs = results[0].probs
37
-
38
  # Get top-1 and top-5 predictions
39
  top1_class = result_probs.top1
40
  top5_classes = result_probs.top5
41
  top1_conf = result_probs.top1conf.item()
42
  top5_conf = result_probs.top5conf
43
-
44
- # Generate annotated image
45
- annotated_img = results[0].plot()
46
-
47
  # Format results
48
  top1_result = f"Class: {model.names[top1_class]}, Confidence: {top1_conf:.2f}"
49
  top5_results = [
50
  f"{model.names[c]}: {conf:.2f}" for c, conf in zip(top5_classes, top5_conf)
51
  ]
52
-
53
  return annotated_img, top1_result, top5_results
54
 
55
  def process_image(image):
 
 
 
 
56
  img = np.array(image)
 
 
57
  img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
58
-
59
  # Run classification inference
60
  annotated_img, top1_result, top5_results = run_inference(img_bgr, model)
61
-
62
- # Convert annotated image back to PIL format
63
- annotated_img_pil = Image.fromarray(cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB))
64
-
65
  # Return the annotated image, Top-1, and Top-5 predictions, along with metrics
66
  return annotated_img_pil, f"Top-1: {top1_result}", "\n".join(top5_results), metrics_df
67
 
@@ -74,7 +89,7 @@ with gr.Blocks() as demo:
74
  input_image = gr.Image(type="pil", label="Upload Image")
75
  submit_btn = gr.Button("Run Inference")
76
  with gr.Column():
77
- annotated_image = gr.Image(type="pil", label="Annotated Image") # Updated to show annotated image
78
  top1_output = gr.Textbox(label="Top-1 Prediction")
79
  top5_output = gr.Textbox(label="Top-5 Predictions")
80
  metrics_table = gr.DataFrame(value=metrics_df, label="Validation Metrics")
 
31
  ##################################
32
 
33
  def run_inference(img: np.ndarray, model):
34
+ """
35
+ Runs inference on the input image using the YOLO model.
36
+ Returns the annotated image, Top-1 prediction, and Top-5 predictions.
37
+ """
38
+ # Convert from BGR to RGB
39
  img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
40
+
41
+ # Run prediction
42
+ results = model.predict(img_rgb)
43
+
44
+ # Extract probabilities (if available)
45
  result_probs = results[0].probs
46
+
47
  # Get top-1 and top-5 predictions
48
  top1_class = result_probs.top1
49
  top5_classes = result_probs.top5
50
  top1_conf = result_probs.top1conf.item()
51
  top5_conf = result_probs.top5conf
52
+
53
+ # Generate annotated image (RGB format)
54
+ annotated_img = results[0].plot() # Assuming this returns RGB
55
+
56
  # Format results
57
  top1_result = f"Class: {model.names[top1_class]}, Confidence: {top1_conf:.2f}"
58
  top5_results = [
59
  f"{model.names[c]}: {conf:.2f}" for c, conf in zip(top5_classes, top5_conf)
60
  ]
61
+
62
  return annotated_img, top1_result, top5_results
63
 
64
  def process_image(image):
65
+ """
66
+ Processes the input image, runs inference, and prepares the outputs.
67
+ """
68
+ # Convert PIL Image to NumPy array
69
  img = np.array(image)
70
+
71
+ # Convert from RGB to BGR for OpenCV (if needed by YOLO)
72
  img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
73
+
74
  # Run classification inference
75
  annotated_img, top1_result, top5_results = run_inference(img_bgr, model)
76
+
77
+ # Convert annotated image back to PIL format without altering color channels
78
+ annotated_img_pil = Image.fromarray(annotated_img) # Assuming annotated_img is in RGB
79
+
80
  # Return the annotated image, Top-1, and Top-5 predictions, along with metrics
81
  return annotated_img_pil, f"Top-1: {top1_result}", "\n".join(top5_results), metrics_df
82
 
 
89
  input_image = gr.Image(type="pil", label="Upload Image")
90
  submit_btn = gr.Button("Run Inference")
91
  with gr.Column():
92
+ annotated_image = gr.Image(type="pil", label="Annotated Image") # Shows annotated image in RGB
93
  top1_output = gr.Textbox(label="Top-1 Prediction")
94
  top5_output = gr.Textbox(label="Top-5 Predictions")
95
  metrics_table = gr.DataFrame(value=metrics_df, label="Validation Metrics")