csepartha commited on
Commit
dd9433f
·
verified ·
1 Parent(s): e3c9e71

Upload 3 files

Browse files
Files changed (3) hide show
  1. best.pt +3 -0
  2. main.py +89 -0
  3. requirements.txt +6 -0
best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9dd2707a610464e2b7c338bee1bf31cc68c76d429b52894a679078776b5c2380
3
+ size 3211515
main.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Partha Pratim Ray
2
+
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ from ultralytics import YOLO
7
+ from PIL import Image
8
+ import gradio as gr
9
+ import pandas as pd
10
+
11
+ # Paths
12
+ model_path = "best.pt" # Ensure the best.pt is in the local directory or provide full path
13
+
14
+ if not os.path.exists(model_path):
15
+ raise FileNotFoundError(f"Model file not found at {model_path}.")
16
+
17
+ # Load the YOLO model
18
+ model = YOLO(model_path)
19
+
20
+ ##################################
21
+ # Hardcoded metrics for classification
22
+ overall_top1_accuracy = 0.9142 # Replace with your Top-1 accuracy
23
+ overall_top5_accuracy = 0.9926 # Replace with your Top-5 accuracy
24
+
25
+ # Metrics DataFrame
26
+ metrics_data = [
27
+ ["Overall Top-1 Accuracy", f"{overall_top1_accuracy * 100:.2f}%"],
28
+ ["Overall Top-5 Accuracy", f"{overall_top5_accuracy * 100:.2f}%"]
29
+ ]
30
+ metrics_df = pd.DataFrame(metrics_data, columns=["Metric", "Value"])
31
+ ##################################
32
+
33
+ def run_inference(img: np.ndarray, model):
34
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
35
+ results = model.predict(img_rgb) # Run prediction
36
+ result_probs = results[0].probs
37
+
38
+ # Get top-1 and top-5 predictions
39
+ top1_class = result_probs.top1
40
+ top5_classes = result_probs.top5
41
+ top1_conf = result_probs.top1conf.item()
42
+ top5_conf = result_probs.top5conf
43
+
44
+ # Generate annotated image
45
+ annotated_img = results[0].plot()
46
+
47
+ # Format results
48
+ top1_result = f"Class: {model.names[top1_class]}, Confidence: {top1_conf:.2f}"
49
+ top5_results = [
50
+ f"{model.names[c]}: {conf:.2f}" for c, conf in zip(top5_classes, top5_conf)
51
+ ]
52
+
53
+ return annotated_img, top1_result, top5_results
54
+
55
+ def process_image(image):
56
+ img = np.array(image)
57
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
58
+
59
+ # Run classification inference
60
+ annotated_img, top1_result, top5_results = run_inference(img_bgr, model)
61
+
62
+ # Convert annotated image back to PIL format
63
+ annotated_img_pil = Image.fromarray(cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB))
64
+
65
+ # Return the annotated image, Top-1, and Top-5 predictions, along with metrics
66
+ return annotated_img_pil, f"Top-1: {top1_result}", "\n".join(top5_results), metrics_df
67
+
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("# YOLO Dog ImageWoof Classification Web App")
70
+ gr.Markdown("Upload an image, and the model will classify it and show precomputed validation metrics.")
71
+
72
+ with gr.Row():
73
+ with gr.Column():
74
+ input_image = gr.Image(type="pil", label="Upload Image")
75
+ submit_btn = gr.Button("Run Inference")
76
+ with gr.Column():
77
+ annotated_image = gr.Image(type="pil", label="Annotated Image") # Updated to show annotated image
78
+ top1_output = gr.Textbox(label="Top-1 Prediction")
79
+ top5_output = gr.Textbox(label="Top-5 Predictions")
80
+ metrics_table = gr.DataFrame(value=metrics_df, label="Validation Metrics")
81
+
82
+ submit_btn.click(
83
+ fn=process_image,
84
+ inputs=input_image,
85
+ outputs=[annotated_image, top1_output, top5_output, metrics_table]
86
+ )
87
+
88
+ demo.launch()
89
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ numpy
2
+ opencv-python
3
+ Pillow
4
+ gradio
5
+ pandas
6
+ ultralytics