zino36 commited on
Commit
7a6bafa
·
verified ·
1 Parent(s): e8dcc74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -1
app.py CHANGED
@@ -4,6 +4,8 @@ import matplotlib
4
  import numpy as np
5
  import os
6
  from PIL import Image
 
 
7
  import spaces
8
  import torch
9
  import tempfile
@@ -43,7 +45,31 @@ encoder = 'vitl'
43
  model_name = encoder2name[encoder]
44
  model = DepthAnythingV2(**model_configs[encoder])
45
  filepath = hf_hub_download(repo_id="depth-anything/Depth-Anything-V2-Metric-Indoor-Large-hf", filename="model.safetensors", repo_type="model")
46
- state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  model.load_state_dict(state_dict)
48
  model = model.to(DEVICE).eval()
49
 
 
4
  import numpy as np
5
  import os
6
  from PIL import Image
7
+ import mmap
8
+ import json
9
  import spaces
10
  import torch
11
  import tempfile
 
45
  model_name = encoder2name[encoder]
46
  model = DepthAnythingV2(**model_configs[encoder])
47
  filepath = hf_hub_download(repo_id="depth-anything/Depth-Anything-V2-Metric-Indoor-Large-hf", filename="model.safetensors", repo_type="model")
48
+
49
+ def create_tensor(storage, info, offset):
50
+ DTYPES = {"F32": torch.float32}
51
+ dtype = DTYPES[info["dtype"]]
52
+ shape = info["shape"]
53
+ start, stop = info["data_offsets"]
54
+ return torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8).view(dtype=dtype).reshape(shape)
55
+
56
+ def load_file(filename):
57
+ with open(filename, mode="r", encoding="utf8") as file_obj:
58
+ with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
59
+ header = m.read(8)
60
+ n = int.from_bytes(header, "little")
61
+ metadata_bytes = m.read(n)
62
+ metadata = json.loads(metadata_bytes)
63
+
64
+ size = os.stat(filename).st_size
65
+ storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
66
+ offset = n + 8
67
+ return {name: create_tensor(storage, info, offset) for name, info in metadata.items() if name != "__metadata__"}
68
+
69
+
70
+ #state_dict = torch.load(filepath, map_location="cpu", weights_only=True)
71
+ state_dict = load_file(filepath)
72
+
73
  model.load_state_dict(state_dict)
74
  model = model.to(DEVICE).eval()
75