bartduis commited on
Commit
072d3b2
·
1 Parent(s): e35a023
Files changed (1) hide show
  1. eval_wrapper/eval.py +1 -1
eval_wrapper/eval.py CHANGED
@@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download
26
  class EvalWrapper(torch.nn.Module):
27
  def __init__(self,checkpoint_path,distributed=False,device="cuda",dtype=torch.float32,**kwargs):
28
  super().__init__()
29
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
30
  model_string = checkpoint['args'].model
31
 
32
  self.model = eval(model_string).to(device)
 
26
  class EvalWrapper(torch.nn.Module):
27
  def __init__(self,checkpoint_path,distributed=False,device="cuda",dtype=torch.float32,**kwargs):
28
  super().__init__()
29
+ checkpoint = torch.load(checkpoint_path, map_location='cpu',weights_only=False)
30
  model_string = checkpoint['args'].model
31
 
32
  self.model = eval(model_string).to(device)