zino36 commited on
Commit
82fa24a
·
verified ·
1 Parent(s): 7a6bafa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -11,6 +11,7 @@ import torch
11
  import tempfile
12
  from gradio_imageslider import ImageSlider
13
  from huggingface_hub import hf_hub_download
 
14
 
15
  from depth_anything_v2.dpt import DepthAnythingV2
16
 
@@ -66,9 +67,19 @@ def load_file(filename):
66
  offset = n + 8
67
  return {name: create_tensor(storage, info, offset) for name, info in metadata.items() if name != "__metadata__"}
68
 
 
 
 
 
 
 
 
 
 
69
 
70
  #state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
71
- state_dict = load_file(filepath)
 
72
 
73
  model.load_state_dict(state_dict)
74
  model = model.to(DEVICE).eval()
 
11
  import tempfile
12
  from gradio_imageslider import ImageSlider
13
  from huggingface_hub import hf_hub_download
14
+ import safetensors
15
 
16
  from depth_anything_v2.dpt import DepthAnythingV2
17
 
 
67
  offset = n + 8
68
  return {name: create_tensor(storage, info, offset) for name, info in metadata.items() if name != "__metadata__"}
69
 
70
+ tensor_data = safetensors.load(filepath)
71
+
72
+ # Convert to PyTorch tensor
73
+ if isinstance(tensor_data, np.ndarray):
74
+ pytorch_tensor = torch.tensor(tensor_data)
75
+ elif isinstance(tensor_data, safetensors.Tensor):
76
+ pytorch_tensor = torch.tensor(tensor_data.numpy()) # Assuming safetensors Tensor has a .numpy() method
77
+ else:
78
+ raise TypeError("Unsupported data type from safetensors")
79
 
80
  #state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
81
+ #state_dict = load_file(filepath)
82
+ state_dict = pytorch_tensor
83
 
84
  model.load_state_dict(state_dict)
85
  model = model.to(DEVICE).eval()