sanpdy commited on
Commit
14abc40
·
1 Parent(s): d2ad9ac

Updated classifier

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. models/flake_monolayer_classifier.pth +3 -0
app.py CHANGED
@@ -13,7 +13,7 @@ from huggingface_hub import hf_hub_download
13
 
14
 
15
  class FlakeLayerClassifier(nn.Module):
16
- def __init__(self, num_materials, material_dim, num_classes=4, dropout_prob=0.1, freeze_cnn=False):
17
  super().__init__()
18
  self.cnn = ResNetModel.from_pretrained("microsoft/resnet-18")
19
  if freeze_cnn:
@@ -72,8 +72,8 @@ print(f"Using device: {device}")
72
  # Load YOLO detector
73
  yolo = YOLO("models/uark_detector_v3.pt")
74
 
75
- # Load classifier model checkpoint with proper device mapping
76
- ckpt_path = "models/flake_classifier_5layer.pth"
77
  ckpt = torch.load(ckpt_path, map_location=device)
78
 
79
  num_classes = len(ckpt["class_to_idx"])
 
13
 
14
 
15
  class FlakeLayerClassifier(nn.Module):
16
+ def __init__(self, num_materials, material_dim, num_classes=2, dropout_prob=0.1, freeze_cnn=False):
17
  super().__init__()
18
  self.cnn = ResNetModel.from_pretrained("microsoft/resnet-18")
19
  if freeze_cnn:
 
72
  # Load YOLO detector
73
  yolo = YOLO("models/uark_detector_v3.pt")
74
 
75
+ # Load classifier model checkpoint
76
+ ckpt_path = "models/flake_monolayer_classifier.pth"
77
  ckpt = torch.load(ckpt_path, map_location=device)
78
 
79
  num_classes = len(ckpt["class_to_idx"])
models/flake_monolayer_classifier.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc4e98bf4dd3127970ca7c68b633d29aa523a0a66a2d6481bf346d7662dbe7b8
3
+ size 47191055