Update eval_wrapper/eval.py
Browse files- eval_wrapper/eval.py +1 -1
eval_wrapper/eval.py
CHANGED
|
@@ -23,7 +23,7 @@ from eval_wrapper.eval_utils import filter_all_masks
|
|
| 23 |
from huggingface_hub import hf_hub_download
|
| 24 |
|
| 25 |
class EvalWrapper(torch.nn.Module):
|
| 26 |
-
def __init__(self,checkpoint_path,distributed=False,device="
|
| 27 |
super().__init__()
|
| 28 |
checkpoint = torch.load(checkpoint_path, map_location='cpu',weights_only=False)
|
| 29 |
model_string = checkpoint['args'].model
|
|
|
|
| 23 |
from huggingface_hub import hf_hub_download
|
| 24 |
|
| 25 |
class EvalWrapper(torch.nn.Module):
|
| 26 |
+
def __init__(self,checkpoint_path,distributed=False,device="cpu",dtype=torch.float32,**kwargs):
|
| 27 |
super().__init__()
|
| 28 |
checkpoint = torch.load(checkpoint_path, map_location='cpu',weights_only=False)
|
| 29 |
model_string = checkpoint['args'].model
|