Update eval_wrapper/eval.py
Browse files- eval_wrapper/eval.py +1 -1
eval_wrapper/eval.py
CHANGED
@@ -23,7 +23,7 @@ from eval_wrapper.eval_utils import filter_all_masks
|
|
23 |
from huggingface_hub import hf_hub_download
|
24 |
|
25 |
class EvalWrapper(torch.nn.Module):
|
26 |
-
def __init__(self,checkpoint_path,distributed=False,device="
|
27 |
super().__init__()
|
28 |
checkpoint = torch.load(checkpoint_path, map_location='cpu',weights_only=False)
|
29 |
model_string = checkpoint['args'].model
|
|
|
23 |
from huggingface_hub import hf_hub_download
|
24 |
|
25 |
class EvalWrapper(torch.nn.Module):
|
26 |
+
def __init__(self,checkpoint_path,distributed=False,device="cpu",dtype=torch.float32,**kwargs):
|
27 |
super().__init__()
|
28 |
checkpoint = torch.load(checkpoint_path, map_location='cpu',weights_only=False)
|
29 |
model_string = checkpoint['args'].model
|