Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -105,15 +105,19 @@ class HybridHeadModel(nn.Module):
|
|
| 105 |
|
| 106 |
# --- Model Catalog ---
|
| 107 |
MODEL_CATALOG = {
|
| 108 |
-
"AnatomyFlaws-v15.5 (DINOv3 7b 8-bit)": {
|
| 109 |
"repo_id": "Enferlain/lumi-classifier",
|
| 110 |
"config_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl.config.json",
|
| 111 |
-
"head_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl_s4K.safetensors"
|
|
|
|
|
|
|
| 112 |
},
|
| 113 |
"AnatomyFlaws-v14.7 (SigLIP naflex)": {
|
| 114 |
"repo_id": "Enferlain/lumi-classifier",
|
| 115 |
"config_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670.config.json",
|
| 116 |
-
"head_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670_s2K.safetensors"
|
|
|
|
|
|
|
| 117 |
},
|
| 118 |
}
|
| 119 |
|
|
@@ -128,32 +132,32 @@ class ModelManager:
|
|
| 128 |
if model_name == self.current_model_name: return
|
| 129 |
if model_name not in self.catalog: raise ValueError(f"Model '{model_name}' not found.")
|
| 130 |
print(f"Switching to model: {model_name}...")
|
|
|
|
| 131 |
model_info = self.catalog[model_name]
|
| 132 |
-
repo_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
try:
|
| 134 |
config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
|
| 135 |
with open(config_path, 'r', encoding='utf-8') as f: self.config = json.load(f)
|
| 136 |
|
| 137 |
-
|
| 138 |
-
print(f"Loading vision model: {base_vision_model_name}")
|
| 139 |
|
| 140 |
-
|
|
|
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
processor_name = base_vision_model_name if is_dinov3_8bit else self.config.get("base_vision_model")
|
| 146 |
-
self.hf_processor = AutoProcessor.from_pretrained(processor_name, trust_remote_code=True) # <-- THE ONLY CHANGE IS HERE
|
| 147 |
-
|
| 148 |
-
if is_dinov3_8bit:
|
| 149 |
-
self.vision_model = AutoModel.from_pretrained(
|
| 150 |
-
base_vision_model_name, load_in_8bit=True, trust_remote_code=True
|
| 151 |
-
).eval()
|
| 152 |
else:
|
| 153 |
self.vision_model = AutoModel.from_pretrained(
|
| 154 |
-
|
| 155 |
).to(DEVICE).eval()
|
| 156 |
|
|
|
|
| 157 |
head_model_path = hf_hub_download(repo_id=repo_id, filename=head_filename)
|
| 158 |
print(f"Loading head model: {head_filename}")
|
| 159 |
state_dict = load_file(head_model_path, device='cpu')
|
|
|
|
| 105 |
|
| 106 |
# --- Model Catalog ---
|
| 107 |
MODEL_CATALOG = {
|
| 108 |
+
"AnatomyFlaws-v15.5 (DINOv3 7b 8-bit)": {
|
| 109 |
"repo_id": "Enferlain/lumi-classifier",
|
| 110 |
"config_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl.config.json",
|
| 111 |
+
"head_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl_s4K.safetensors",
|
| 112 |
+
# Explicitly define the vision model repo ID to prevent errors
|
| 113 |
+
"vision_model_repo_id": "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-8bit"
|
| 114 |
},
|
| 115 |
"AnatomyFlaws-v14.7 (SigLIP naflex)": {
|
| 116 |
"repo_id": "Enferlain/lumi-classifier",
|
| 117 |
"config_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670.config.json",
|
| 118 |
+
"head_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670_s2K.safetensors",
|
| 119 |
+
# The base SigLIP model is not custom, so we use its official ID
|
| 120 |
+
"vision_model_repo_id": "google/siglip2-so400m-patch16-naflex"
|
| 121 |
},
|
| 122 |
}
|
| 123 |
|
|
|
|
| 132 |
if model_name == self.current_model_name: return
|
| 133 |
if model_name not in self.catalog: raise ValueError(f"Model '{model_name}' not found.")
|
| 134 |
print(f"Switching to model: {model_name}...")
|
| 135 |
+
|
| 136 |
model_info = self.catalog[model_name]
|
| 137 |
+
repo_id = model_info["repo_id"]
|
| 138 |
+
config_filename = model_info["config_filename"]
|
| 139 |
+
head_filename = model_info["head_filename"]
|
| 140 |
+
# --- NEW: Use the reliable repo ID from our catalog ---
|
| 141 |
+
vision_model_repo_id = model_info["vision_model_repo_id"]
|
| 142 |
+
|
| 143 |
try:
|
| 144 |
config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
|
| 145 |
with open(config_path, 'r', encoding='utf-8') as f: self.config = json.load(f)
|
| 146 |
|
| 147 |
+
print(f"Loading vision model: {vision_model_repo_id}")
|
|
|
|
| 148 |
|
| 149 |
+
# Load processor and model using our trusted repo ID
|
| 150 |
+
self.hf_processor = AutoProcessor.from_pretrained(vision_model_repo_id, trust_remote_code=True)
|
| 151 |
|
| 152 |
+
is_8bit_model = "8bit" in vision_model_repo_id
|
| 153 |
+
if is_8bit_model:
|
| 154 |
+
self.vision_model = AutoModel.from_pretrained(vision_model_repo_id, load_in_8bit=True, trust_remote_code=True).eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
else:
|
| 156 |
self.vision_model = AutoModel.from_pretrained(
|
| 157 |
+
vision_model_repo_id, torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
|
| 158 |
).to(DEVICE).eval()
|
| 159 |
|
| 160 |
+
# The rest of the function continues as before
|
| 161 |
head_model_path = hf_hub_download(repo_id=repo_id, filename=head_filename)
|
| 162 |
print(f"Loading head model: {head_filename}")
|
| 163 |
state_dict = load_file(head_model_path, device='cpu')
|