Spaces:
Running
Running
Updated classifier
Browse files- app.py +3 -3
- models/flake_monolayer_classifier.pth +3 -0
app.py
CHANGED
@@ -13,7 +13,7 @@ from huggingface_hub import hf_hub_download
|
|
13 |
|
14 |
|
15 |
class FlakeLayerClassifier(nn.Module):
|
16 |
-
def __init__(self, num_materials, material_dim, num_classes=
|
17 |
super().__init__()
|
18 |
self.cnn = ResNetModel.from_pretrained("microsoft/resnet-18")
|
19 |
if freeze_cnn:
|
@@ -72,8 +72,8 @@ print(f"Using device: {device}")
|
|
72 |
# Load YOLO detector
|
73 |
yolo = YOLO("models/uark_detector_v3.pt")
|
74 |
|
75 |
-
# Load classifier model checkpoint
|
76 |
-
ckpt_path = "models/
|
77 |
ckpt = torch.load(ckpt_path, map_location=device)
|
78 |
|
79 |
num_classes = len(ckpt["class_to_idx"])
|
|
|
13 |
|
14 |
|
15 |
class FlakeLayerClassifier(nn.Module):
|
16 |
+
def __init__(self, num_materials, material_dim, num_classes=2, dropout_prob=0.1, freeze_cnn=False):
|
17 |
super().__init__()
|
18 |
self.cnn = ResNetModel.from_pretrained("microsoft/resnet-18")
|
19 |
if freeze_cnn:
|
|
|
72 |
# Load YOLO detector
|
73 |
yolo = YOLO("models/uark_detector_v3.pt")
|
74 |
|
75 |
+
# Load classifier model checkpoint
|
76 |
+
ckpt_path = "models/flake_monolayer_classifier.pth"
|
77 |
ckpt = torch.load(ckpt_path, map_location=device)
|
78 |
|
79 |
num_classes = len(ckpt["class_to_idx"])
|
models/flake_monolayer_classifier.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cc4e98bf4dd3127970ca7c68b633d29aa523a0a66a2d6481bf346d7662dbe7b8
|
3 |
+
size 47191055
|