Felix Konrad committed on
Commit 55226e5 · 1 Parent(s): 41c94d8

Please please work.

Files changed (1)
  1. app.py +51 -25
app.py CHANGED
@@ -78,24 +78,48 @@ def load_model(repo_id: str, revision: str = None):
     Works with any public repo_id.
     """
     try:
-        # Clean up revision input (handle empty strings)
+        # Clean up inputs
+        repo_id = repo_id.strip()
+        if not repo_id:
+            return "Please enter a model repo ID"
+
         if revision and revision.strip() == "":
             revision = None
-
-        # Load model and processor directly (they handle caching automatically)
-        model = AutoModel.from_pretrained(
-            repo_id,
-            revision=revision,
-            cache_dir="./model_cache",
-            trust_remote_code=True  # Some models might need this
-        )
 
-        processor = AutoImageProcessor.from_pretrained(
-            repo_id,
-            revision=revision,
-            cache_dir="./model_cache",
-            trust_remote_code=True
-        )
+        # First try without cache_dir to avoid permission issues
+        try:
+            model = AutoModel.from_pretrained(
+                repo_id,
+                revision=revision,
+                trust_remote_code=True,
+                use_auth_token=False  # Explicitly no auth for public models
+            )
+
+            processor = AutoImageProcessor.from_pretrained(
+                repo_id,
+                revision=revision,
+                trust_remote_code=True,
+                use_auth_token=False
+            )
+        except Exception as e1:
+            # If that fails, try with explicit cache directory
+            model = AutoModel.from_pretrained(
+                repo_id,
+                revision=revision,
+                cache_dir="/tmp/model_cache",  # Use /tmp for better permissions
+                trust_remote_code=True,
+                use_auth_token=False,
+                local_files_only=False  # Ensure we can download
+            )
+
+            processor = AutoImageProcessor.from_pretrained(
+                repo_id,
+                revision=revision,
+                cache_dir="/tmp/model_cache",
+                trust_remote_code=True,
+                use_auth_token=False,
+                local_files_only=False
+            )
 
         # Move to appropriate device
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -104,7 +128,7 @@ def load_model(repo_id: str, revision: str = None):
 
         # Validate it's a Vision Transformer
        if not hasattr(model.config, 'patch_size'):
-            return f"❌ Model '{repo_id}' doesn't appear to be a Vision Transformer (no patch_size in config)"
+            return f"Model '{repo_id}' doesn't appear to be a Vision Transformer (no patch_size in config)"
 
         # Update global state
         state["model"] = model
@@ -112,17 +136,19 @@ def load_model(repo_id: str, revision: str = None):
         state["repo_id"] = repo_id
         state["model_type"] = "custom"
 
-        return f"✅ Successfully loaded model '{repo_id}' on {device}"
+        patch_size = model.config.patch_size
+        return f"Successfully loaded ViT model '{repo_id}' (patch size: {patch_size}) on {device}"
 
-    except OSError as e:
-        if "Repository not found" in str(e):
-            return f"❌ Repository '{repo_id}' not found. Please check the repo ID."
-        elif "offline" in str(e).lower():
-            return f"❌ Network error. Please check your internet connection."
-        else:
-            return f"❌ Error accessing model: {str(e)}"
     except Exception as e:
-        return f"❌ Error loading model: {str(e)}"
+        error_str = str(e).lower()
+        if "repository not found" in error_str or "404" in error_str:
+            return f"Repository '{repo_id}' not found. Please check the repo ID."
+        elif "connection" in error_str or "network" in error_str or "offline" in error_str:
+            return f"Network error: {str(e)}"
+        elif "permission" in error_str or "forbidden" in error_str:
+            return f"Permission denied. This might be a private repository."
+        else:
+            return f"Error loading model: {str(e)}"
 
 def display_image(image: Image):
     """