[debug] zeroGPU
- rosetta/baseline/multi_stage.py +6 -10
- rosetta/utils/evaluate.py +8 -8
rosetta/baseline/multi_stage.py
CHANGED
@@ -59,24 +59,20 @@ class TwoStageInference:
         if context_path == "google/gemma-3-1b-it":
             torch._dynamo.config.cache_size_limit = 64
             self.context_model = AutoModelForCausalLM.from_pretrained(
-                context_path, torch_dtype=torch.bfloat16, sliding_window=4096
-
-            ).to(self.device)
+                context_path, torch_dtype=torch.bfloat16, device_map={"": self.device}, sliding_window=4096
+            )
         else:
             self.context_model = AutoModelForCausalLM.from_pretrained(
-                context_path,
-
-                # device_map={"": self.device}
-            ).to(self.device)
+                context_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
+            )
         # Apply generation config to context model
         apply_generation_config(self.context_model, self.generation_config)

         # Load answer LLM
         self.answer_tokenizer = AutoTokenizer.from_pretrained(answer_path)
         self.answer_model = AutoModelForCausalLM.from_pretrained(
-            answer_path, torch_dtype=torch.bfloat16,
-
-        ).to(self.device)
+            answer_path, torch_dtype=torch.bfloat16, device_map={"": self.device}
+        )
         # Apply generation config to answer model
         apply_generation_config(self.answer_model, self.generation_config)

|
rosetta/utils/evaluate.py
CHANGED
@@ -313,8 +313,8 @@ def load_hf_model(model_name: str, device: torch.device, generation_config: Opti
     model = AutoModelForCausalLM.from_pretrained(
         str(model_name),
         torch_dtype=torch.bfloat16,
-
-    ).eval()
+        device_map={"": device}
+    ).eval()

     # Apply generation config
     apply_generation_config(model, generation_config)

@@ -352,8 +352,8 @@ def load_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any]
     slm_model = AutoModelForCausalLM.from_pretrained(
         str(slm_model_path),
         torch_dtype=torch.bfloat16,
-
-    ).eval()
+        device_map={"": device}
+    ).eval()

     # Apply generation config to SLM
     apply_generation_config(slm_model, generation_config)

@@ -362,15 +362,15 @@ def load_rosetta_model(model_config: Dict[str, Any], eval_config: Dict[str, Any]
         llm_model = AutoModelForCausalLM.from_pretrained(
             str(llm_model_path),
             torch_dtype=torch.bfloat16,
-
+            device_map={"": device},
             sliding_window=4096
-        ).eval()
+        ).eval()
     else:
         llm_model = AutoModelForCausalLM.from_pretrained(
             str(llm_model_path),
             torch_dtype=torch.bfloat16,
-
-        ).eval()
+            device_map={"": device}
+        ).eval()

     # Apply generation config to LLM
     apply_generation_config(llm_model, generation_config)
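For context, every hunk in this commit converges on the same loading pattern: pass `device_map={"": device}` to `from_pretrained` so the weights are placed on the target device at load time, instead of loading first and calling `.to(device)` afterwards, presumably to play better with the ZeroGPU setup named in the commit message. Below is a minimal standalone sketch of that pattern, not part of the commit itself; it assumes `transformers` and `accelerate` are installed, and the Gemma checkpoint id is simply the one that appears in the diff.

```python
# Minimal sketch of the device_map loading pattern used throughout this commit.
# Assumption: transformers + accelerate installed; any causal-LM checkpoint works here.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",
    torch_dtype=torch.bfloat16,
    device_map={"": device},  # map the root module ("") onto one device at load time
).eval()

# ...instead of loading on CPU and moving afterwards:
# model = AutoModelForCausalLM.from_pretrained(
#     "google/gemma-3-1b-it", torch_dtype=torch.bfloat16
# ).to(device).eval()
```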