Spaces:
Running
Running
Referencing local models
Browse files- app.py +6 -12
- models/flake_classifier_5layer.pth +3 -0
- models/uark_detector_v3.pt +3 -0
app.py
CHANGED
@@ -66,24 +66,18 @@ def calibration(source_img, target_img):
|
|
66 |
return corrected_img.astype(np.uint8)
|
67 |
|
68 |
|
69 |
-
device = torch.device("cuda
|
70 |
print(f"Using device: {device}")
|
71 |
|
72 |
# Load YOLO detector
|
73 |
#yolo = YOLO("/home/sankalp/flake_classification/models/best.pt")
|
74 |
#yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo11n_synthetic_runs/exp1/weights/best.pt")
|
75 |
#yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo_runs/yolo11l_flake_runs/weights/best.pt")
|
76 |
-
|
77 |
-
yolo = YOLO(
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
classifier_path = hf_hub_download(
|
82 |
-
repo_id="sanpdy/flake-classifier",
|
83 |
-
filename="flake_classifier_5layer.pth",
|
84 |
-
token=False
|
85 |
-
)
|
86 |
-
ckpt = torch.load(classifier_path, map_location=device)
|
87 |
|
88 |
num_classes = len(ckpt["class_to_idx"])
|
89 |
classifier = FlakeLayerClassifier(
|
|
|
66 |
return corrected_img.astype(np.uint8)
|
67 |
|
68 |
|
69 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
70 |
print(f"Using device: {device}")
|
71 |
|
72 |
# Load YOLO detector
|
73 |
#yolo = YOLO("/home/sankalp/flake_classification/models/best.pt")
|
74 |
#yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo11n_synthetic_runs/exp1/weights/best.pt")
|
75 |
#yolo = YOLO("/home/sankalp/yolo_flake_detection/yolo_runs/yolo11l_flake_runs/weights/best.pt")
|
76 |
+
|
77 |
+
yolo = YOLO("models/uark_detector_v3.pt")
|
78 |
+
torch.load("models/flake_classifier_5layer.pth")
|
79 |
+
ckpt_path = "models/flake_classifier_5layer.pth"
|
80 |
+
ckpt = torch.load(ckpt_path, map_location=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
num_classes = len(ckpt["class_to_idx"])
|
83 |
classifier = FlakeLayerClassifier(
|
models/flake_classifier_5layer.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e5d8873b2708e53cdd0b603918728c75c86c4a219e92b4895d5cfc5c39cc275
|
3 |
+
size 47208000
|
models/uark_detector_v3.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ccfd0d79fc4e721474669054bb0eb4d2f6fb6e19f34e60a74d597947d9f66163
|
3 |
+
size 6132509
|