fuvty committed
Commit 3cb5609 · Parent: 04539d0

[debug] zeroGPU

rosetta/baseline/multi_stage.py CHANGED

@@ -59,24 +59,20 @@ class TwoStageInference:
         if context_path == "google/gemma-3-1b-it":
             torch._dynamo.config.cache_size_limit = 64
             self.context_model = AutoModelForCausalLM.from_pretrained(
-                context_path, torch_dtype=torch.bfloat16, sliding_window=4096,
-                # device_map={"": self.device},
-            ).to(self.device)
+                context_path, torch_dtype=torch.bfloat16, device_map={"": self.device}, sliding_window=4096
+            )
         else:
             self.context_model = AutoModelForCausalLM.from_pretrained(
-                context_path,
-                torch_dtype=torch.bfloat16,
-                # device_map={"": self.device}
-            ).to(self.device)
+                context_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
+            )
         # Apply generation config to context model
         apply_generation_config(self.context_model, self.generation_config)

         # Load answer LLM
         self.answer_tokenizer = AutoTokenizer.from_pretrained(answer_path)
         self.answer_model = AutoModelForCausalLM.from_pretrained(
-            answer_path, torch_dtype=torch.bfloat16,
-            # device_map={"": self.device}
-        ).to(self.device)
+            answer_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
+        )
         # Apply generation config to answer model
         apply_generation_config(self.answer_model, self.generation_config)
 
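Both loads now pin the whole model to one device during from_pretrained() instead of moving it afterwards with .to(self.device). A minimal standalone sketch of the pattern, assuming a single-GPU setup (only the model id and kwargs shown in the diff come from the source; in the transformers/accelerate device_map convention, the empty-string key matches the root module, i.e. the entire model):

import torch
from transformers import AutoModelForCausalLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# device_map={"": device} places every submodule on `device` while the
# checkpoint loads, so no follow-up .to(device) call is needed.
context_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",
    torch_dtype=torch.bfloat16,
    device_map={"": device},
    sliding_window=4096,
)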
rosetta/utils/evaluate.py CHANGED

@@ -313,8 +313,8 @@ def load_hf_model(model_name: str, device: torch.device, generation_config: Opti
     model = AutoModelForCausalLM.from_pretrained(
         str(model_name),
         torch_dtype=torch.bfloat16,
-        # device_map={"": device}
-    ).eval().to(device)
+        device_map={"": device}
+    ).eval()

     # Apply generation config
     apply_generation_config(model, generation_config)

@@ -352,8 +352,8 @@ def load_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any]
     slm_model = AutoModelForCausalLM.from_pretrained(
         str(slm_model_path),
         torch_dtype=torch.bfloat16,
-        # device_map={"": device}
-    ).eval().to(device)
+        device_map={"": device}
+    ).eval()

     # Apply generation config to SLM
     apply_generation_config(slm_model, generation_config)

@@ -362,15 +362,15 @@ def load_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any]
         llm_model = AutoModelForCausalLM.from_pretrained(
             str(llm_model_path),
             torch_dtype=torch.bfloat16,
-            # device_map={"": device},
+            device_map={"": device},
             sliding_window=4096
-        ).eval().to(device)
+        ).eval()
     else:
         llm_model = AutoModelForCausalLM.from_pretrained(
             str(llm_model_path),
             torch_dtype=torch.bfloat16,
-            # device_map={"": device}
-        ).eval().to(device)
+            device_map={"": device}
+        ).eval()

     # Apply generation config to LLM
     apply_generation_config(llm_model, generation_config)
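
The evaluate.py hunks make the same swap in the shared loaders: the commented-out device_map is enabled and the trailing .to(device) is dropped, while .eval() is kept. A hedged sketch of what a loader in the spirit of load_hf_model reduces to after this commit (the wrapper name and trimmed signature are illustrative, not the repo's actual API; only the from_pretrained call mirrors the diff):

import torch
from transformers import AutoModelForCausalLM

def load_eval_model(model_name: str, device: torch.device):
    # Illustrative wrapper, not the repo's exact load_hf_model: placement
    # happens inside from_pretrained() via device_map, and .eval() still
    # disables dropout and other training-mode behavior.
    return AutoModelForCausalLM.from_pretrained(
        str(model_name),
        torch_dtype=torch.bfloat16,
        device_map={"": device},
    ).eval()

model = load_eval_model("google/gemma-3-1b-it", torch.device("cuda"))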