Enferlain committed on
Commit
04cf1bc
·
verified ·
1 Parent(s): df8c6ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -30
app.py CHANGED
@@ -110,8 +110,9 @@ MODEL_CATALOG = {
110
  "config_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl.config.json",
111
  "head_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl_s3K_best_val.safetensors",
112
  # Explicitly define the vision model repo ID to prevent errors
113
- # "vision_model_repo_id": "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-8bit"
114
- "vision_model_repo_id": "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-int4",
 
115
  },
116
  "AnatomyFlaws-v14.7 (SigLIP naflex)": {
117
  "repo_id": "Enferlain/lumi-classifier",
@@ -125,61 +126,89 @@ MODEL_CATALOG = {
125
  # --- Model Manager Class ---
126
  class ModelManager:
127
  def __init__(self, catalog: Dict[str, Dict[str, str]]):
128
- self.catalog = catalog; self.current_model_name: str = None; self.vision_model: nn.Module = None
129
- self.hf_processor: Any = None; self.head_model: HybridHeadModel = None
130
- self.labels: Dict[int, str] = None; self.config: Dict[str, Any] = None
 
 
 
 
131
 
132
  def load_model(self, model_name: str):
133
- if model_name == self.current_model_name: return
134
- if model_name not in self.catalog: raise ValueError(f"Model '{model_name}' not found.")
 
 
 
135
  print(f"Switching to model: {model_name}...")
136
-
137
  model_info = self.catalog[model_name]
138
  repo_id = model_info["repo_id"]
139
  config_filename = model_info["config_filename"]
140
  head_filename = model_info["head_filename"]
141
  vision_model_repo_id = model_info["vision_model_repo_id"]
142
-
143
  try:
144
  config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
145
- with open(config_path, 'r', encoding='utf-8') as f: self.config = json.load(f)
146
-
 
147
  print(f"Loading vision model: {vision_model_repo_id}")
148
-
149
  self.hf_processor = AutoProcessor.from_pretrained(vision_model_repo_id, trust_remote_code=True)
150
-
151
- # --- NEW: Correct loading logic for INT4 vs. standard models ---
152
- if "int4" in vision_model_repo_id.lower():
153
- print("INT4 model detected. Loading for CPU.")
154
- self.vision_model = AutoModel.from_pretrained(
155
- vision_model_repo_id,
156
- torch_dtype=torch.float32,
157
- device_map="cpu", # Force to CPU
158
- trust_remote_code=True
159
- ).eval()
160
- else: # Standard model loading (for SigLIP or GPU environments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  self.vision_model = AutoModel.from_pretrained(
162
  vision_model_repo_id,
163
  torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
164
  ).to(DEVICE).eval()
165
-
166
- # The rest of the function continues as before
167
  head_model_path = hf_hub_download(repo_id=repo_id, filename=head_filename)
168
  print(f"Loading head model: {head_filename}")
169
  state_dict = load_file(head_model_path, device='cpu')
170
  head_params = self.config.get("predictor_params", self.config)
171
  self.head_model = HybridHeadModel(
172
- features=head_params.get("features"), hidden_dim=head_params.get("hidden_dim"),
173
- num_classes=self.config.get("num_classes"), use_attention=head_params.get("use_attention"),
174
- num_attn_heads=head_params.get("num_attn_heads"), attn_dropout=head_params.get("attn_dropout"),
175
- num_res_blocks=head_params.get("num_res_blocks"), dropout_rate=head_params.get("dropout_rate"),
176
- output_mode=head_params.get("output_mode", "linear"))
 
 
 
 
 
177
  self.head_model.load_state_dict(state_dict, strict=True)
178
  self.head_model.to(DEVICE).eval()
 
179
  raw_labels = self.config.get("labels", {'0': 'Bad', '1': 'Good'})
180
  self.labels = {int(k): (v['name'] if isinstance(v, dict) else v) for k, v in raw_labels.items()}
181
  self.current_model_name = model_name
182
  print(f"Successfully loaded '{model_name}'.")
 
183
  except Exception as e:
184
  self.current_model_name = None
185
  raise RuntimeError(f"Failed to load model '{model_name}': {e}\n{traceback.format_exc()}")
 
110
  "config_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl.config.json",
111
  "head_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl_s3K_best_val.safetensors",
112
  # Explicitly define the vision model repo ID to prevent errors
113
+ # "vision_model_repo_id": "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-8bit" bnb 8bit
114
+ # "vision_model_repo_id": "Enferlain/dinov3-vit7b16-pretrain-lvd1689m-int4", int4
115
+ "vision_model_repo_id": "facebook/dinov3-vit7b16-pretrain-lvd1689m",
116
  },
117
  "AnatomyFlaws-v14.7 (SigLIP naflex)": {
118
  "repo_id": "Enferlain/lumi-classifier",
 
126
  # --- Model Manager Class ---
127
  class ModelManager:
128
  def __init__(self, catalog: Dict[str, Dict[str, str]]):
129
+ self.catalog = catalog
130
+ self.current_model_name: str = None
131
+ self.vision_model: nn.Module = None
132
+ self.hf_processor: Any = None
133
+ self.head_model: HybridHeadModel = None
134
+ self.labels: Dict[int, str] = None
135
+ self.config: Dict[str, Any] = None
136
 
137
  def load_model(self, model_name: str):
138
+ if model_name == self.current_model_name:
139
+ return
140
+ if model_name not in self.catalog:
141
+ raise ValueError(f"Model '{model_name}' not found.")
142
+
143
  print(f"Switching to model: {model_name}...")
144
+
145
  model_info = self.catalog[model_name]
146
  repo_id = model_info["repo_id"]
147
  config_filename = model_info["config_filename"]
148
  head_filename = model_info["head_filename"]
149
  vision_model_repo_id = model_info["vision_model_repo_id"]
150
+
151
  try:
152
  config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
153
+ with open(config_path, 'r', encoding='utf-8') as f:
154
+ self.config = json.load(f)
155
+
156
  print(f"Loading vision model: {vision_model_repo_id}")
 
157
  self.hf_processor = AutoProcessor.from_pretrained(vision_model_repo_id, trust_remote_code=True)
158
+
159
+ # --- UPDATED: CPU-compatible loading logic ---
160
+ if DEVICE == "cpu":
161
+ # For CPU, load unquantized model with BF16 (original format)
162
+ print("Loading unquantized model for CPU...")
163
+ try:
164
+ self.vision_model = AutoModel.from_pretrained(
165
+ vision_model_repo_id,
166
+ torch_dtype=torch.bfloat16, # Keep original BF16 format
167
+ device_map={"": "cpu"}, # Force CPU device mapping
168
+ trust_remote_code=True
169
+ ).eval()
170
+ print("Successfully loaded model in BF16 format.")
171
+ except Exception as bf16_error:
172
+ print(f"BF16 loading failed: {bf16_error}")
173
+ print("Falling back to FP32...")
174
+ self.vision_model = AutoModel.from_pretrained(
175
+ vision_model_repo_id,
176
+ torch_dtype=torch.float32, # Fallback to FP32
177
+ device_map={"": "cpu"},
178
+ trust_remote_code=True
179
+ ).eval()
180
+ print("Successfully loaded model in FP32 format.")
181
+ else:
182
+ # For GPU environments (unchanged)
183
  self.vision_model = AutoModel.from_pretrained(
184
  vision_model_repo_id,
185
  torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
186
  ).to(DEVICE).eval()
187
+
188
+ # Load classifier head (unchanged)
189
  head_model_path = hf_hub_download(repo_id=repo_id, filename=head_filename)
190
  print(f"Loading head model: {head_filename}")
191
  state_dict = load_file(head_model_path, device='cpu')
192
  head_params = self.config.get("predictor_params", self.config)
193
  self.head_model = HybridHeadModel(
194
+ features=head_params.get("features"),
195
+ hidden_dim=head_params.get("hidden_dim"),
196
+ num_classes=self.config.get("num_classes"),
197
+ use_attention=head_params.get("use_attention"),
198
+ num_attn_heads=head_params.get("num_attn_heads"),
199
+ attn_dropout=head_params.get("attn_dropout"),
200
+ num_res_blocks=head_params.get("num_res_blocks"),
201
+ dropout_rate=head_params.get("dropout_rate"),
202
+ output_mode=head_params.get("output_mode", "linear")
203
+ )
204
  self.head_model.load_state_dict(state_dict, strict=True)
205
  self.head_model.to(DEVICE).eval()
206
+
207
  raw_labels = self.config.get("labels", {'0': 'Bad', '1': 'Good'})
208
  self.labels = {int(k): (v['name'] if isinstance(v, dict) else v) for k, v in raw_labels.items()}
209
  self.current_model_name = model_name
210
  print(f"Successfully loaded '{model_name}'.")
211
+
212
  except Exception as e:
213
  self.current_model_name = None
214
  raise RuntimeError(f"Failed to load model '{model_name}': {e}\n{traceback.format_exc()}")