alidenewade committed
Commit c3644ec · verified · 1 Parent(s): de6a098

Update app.py

Files changed (1): app.py (+55, -30)

app.py CHANGED
@@ -29,6 +29,11 @@ def get_quantization_config():
     Falls back gracefully if bitsandbytes is not available.
     """
     try:
+        # Only use quantization on CUDA
+        if not torch.cuda.is_available():
+            logger.info("CUDA not available, skipping quantization")
+            return None
+
         # 8-bit quantization configuration - good balance of speed and quality
         quantization_config = BitsAndBytesConfig(
             load_in_8bit=True,
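
The new guard matters because bitsandbytes' 8-bit kernels are CUDA-only, so requesting `load_in_8bit` on a CPU-only Space fails at model-load time instead of degrading gracefully. A minimal sketch of the guarded helper as it stands after this change (the `except` fallback is an assumption about the part of the function the diff does not show):

```python
import logging

import torch
from transformers import BitsAndBytesConfig

logger = logging.getLogger(__name__)

def get_quantization_config():
    """Return an 8-bit quantization config, or None on CPU or if bitsandbytes is unusable."""
    try:
        # Only use quantization on CUDA
        if not torch.cuda.is_available():
            logger.info("CUDA not available, skipping quantization")
            return None
        # 8-bit quantization configuration - good balance of speed and quality
        return BitsAndBytesConfig(load_in_8bit=True)
    except Exception as e:  # assumed fallback, per the docstring above
        logger.warning(f"Quantization unavailable, loading in full precision: {e}")
        return None
```
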
@@ -64,42 +69,43 @@ def load_optimized_models():
     # Model names
     model_name = "seyonec/PubChem10M_SMILES_BPE_450k"

-    # Load tokenizer (doesn't need quantization)
-    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
+    try:
+        # Load tokenizer (doesn't need quantization)
+        fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)

-    # Load model with quantization if available
-    model_kwargs = {
-        "torch_dtype": torch_dtype,
-    }
+        # Load model with quantization if available
+        model_kwargs = {
+            "torch_dtype": torch_dtype,
+        }

-    if quantization_config is not None and torch.cuda.is_available():  # Quantization typically for GPU
-        model_kwargs["quantization_config"] = quantization_config
-        # device_map="auto" is often used with bitsandbytes for automatic distribution
-        model_kwargs["device_map"] = "auto"
-    elif torch.cuda.is_available():
-        model_kwargs["device_map"] = "auto"  # For non-quantized GPU loading
-    else:
-        model_kwargs["device_map"] = None  # For CPU
+        if quantization_config is not None and torch.cuda.is_available():
+            model_kwargs["quantization_config"] = quantization_config
+            model_kwargs["device_map"] = "auto"
+        else:
+            # For CPU or non-quantized loading
+            model_kwargs["device_map"] = None

-    try:
         # Masked LM Model
         fill_mask_model = AutoModelForMaskedLM.from_pretrained(
             model_name,
             **model_kwargs
         )

+        # Move to device if not using device_map
+        if model_kwargs["device_map"] is None and torch.cuda.is_available():
+            fill_mask_model.to(device)
+
         # Set model to evaluation mode for inference
         fill_mask_model.eval()

-        # Create optimized pipeline
-        # Let pipeline infer device from model if possible, or set based on model's device
-        pipeline_device = fill_mask_model.device.index if hasattr(fill_mask_model.device, 'type') and fill_mask_model.device.type == "cuda" else -1
+        # Create pipeline with proper device handling
+        pipeline_device = 0 if torch.cuda.is_available() else -1

         fill_mask_pipeline = pipeline(
             'fill-mask',
             model=fill_mask_model,
             tokenizer=fill_mask_tokenizer,
-            device=pipeline_device,  # Use model's device
+            device=pipeline_device,
         )

         logger.info("Models loaded successfully with optimizations")
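
Two things are worth noting in this hunk. First, since `get_quantization_config()` now returns `None` off-GPU, the `torch.cuda.is_available()` re-check in the condition is redundant but harmless. Second, the hunk references names it never defines: `torch_dtype`, `quantization_config`, and (in the new `.to(device)` call) `device`. A plausible preamble for `load_optimized_models()` consistent with those references (an assumption; that part of app.py is not shown in this diff):

```python
import torch

# Assumed setup; the diff uses these names without showing their definitions.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
quantization_config = get_quantization_config()  # helper defined earlier in app.py
```
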
@@ -113,16 +119,31 @@ def load_optimized_models():

 def load_standard_models(model_name):
     """Fallback standard model loading without quantization."""
-    fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
-    fill_mask_model = AutoModelForMaskedLM.from_pretrained(model_name)
-    # Determine device for standard loading
-    device_idx = 0 if torch.cuda.is_available() else -1
-    fill_mask_pipeline = pipeline('fill-mask', model=fill_mask_model, tokenizer=fill_mask_tokenizer, device=device_idx)
-
-    if torch.cuda.is_available():
-        fill_mask_model.to("cuda")
+    try:
+        fill_mask_tokenizer = AutoTokenizer.from_pretrained(model_name)
+        fill_mask_model = AutoModelForMaskedLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float32
+        )
+
+        # Determine device for standard loading
+        device_idx = 0 if torch.cuda.is_available() else -1
+
+        if torch.cuda.is_available():
+            fill_mask_model.to("cuda")
+
+        fill_mask_pipeline = pipeline(
+            'fill-mask',
+            model=fill_mask_model,
+            tokenizer=fill_mask_tokenizer,
+            device=device_idx
+        )

-    return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
+        return fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline
+    except Exception as e:
+        logger.error(f"Failed to load models: {e}")
+        st.error(f"Failed to load models: {e}")
+        return None, None, None

 # --- Memory Management Utilities ---
 def clear_gpu_cache():
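
Pinning `torch_dtype=torch.float32` keeps the fallback path loadable on any hardware, and returning `(None, None, None)` turns load failures into a sentinel instead of an exception, so call sites must check before unpacking. A small usage sketch (hypothetical caller, run in app.py's context, mirroring the pattern the last hunk adds):

```python
models = load_standard_models("seyonec/PubChem10M_SMILES_BPE_450k")
if models[0] is None:
    # Loading failed; the function has already logged and surfaced the error.
    raise SystemExit("Model loading failed, see logs")
tokenizer, model, fill_mask = models
```
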
@@ -163,7 +184,7 @@ def get_image_with_highlight(mol, atomset=None, size=(300, 300)):
     if atomset:
         try:
             valid_atomset = [int(a) for a in atomset]
-        except ValueError:
+        except (ValueError, TypeError):
             logger.warning(f"Invalid atom in atomset: {atomset}. Proceeding without highlighting problematic atoms.")
             valid_atomset = [int(a) for a in atomset if str(a).isdigit()]  # Filter out non-integers

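
The broader handler is the right fix here: `int()` raises `ValueError` for malformed strings but `TypeError` for non-numeric objects such as `None`, which the old clause let propagate. A quick illustration:

```python
for bad in ["x7", None]:
    try:
        int(bad)
    except (ValueError, TypeError) as e:
        print(f"{bad!r}: {type(e).__name__}")  # 'x7': ValueError, None: TypeError
```
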
 
@@ -230,7 +251,11 @@ def predict_and_visualize_masked_smiles(smiles_mask, substructure_smarts_highlig
     """
     # Load models when needed
     try:
-        fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = load_optimized_models()
+        models = load_optimized_models()
+        if models[0] is None:  # Check if loading failed
+            st.error("Failed to load models. Please check the logs.")
+            return
+        fill_mask_tokenizer, fill_mask_model, fill_mask_pipeline = models
     except Exception as e:
         st.error(f"Error loading models: {str(e)}")
         return
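
With the sentinel check in place, the caller now covers both failure modes: the `None` triple from a graceful failure and a raised exception from an unexpected one. For completeness, a minimal end-to-end use of the loaded fill-mask pipeline (run in app.py's context; the SMILES input and `top_k` value are illustrative, not from this commit):

```python
models = load_optimized_models()
if models[0] is not None:
    tokenizer, model, fill_mask = models
    # ChemBERTa uses a RoBERTa-style vocabulary, so the mask token is <mask>.
    masked_smiles = "CC(=O)Oc1ccccc1C(=O)" + tokenizer.mask_token
    for pred in fill_mask(masked_smiles, top_k=3):
        print(pred["token_str"], round(pred["score"], 4))
```
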
 