Spaces:
Runtime error
Runtime error
Upload glide_text2im/download.py
Browse files- glide_text2im/download.py +71 -0
glide_text2im/download.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from functools import lru_cache
|
3 |
+
from typing import Dict, Optional
|
4 |
+
|
5 |
+
import requests
|
6 |
+
import torch as th
|
7 |
+
from filelock import FileLock
|
8 |
+
from tqdm.auto import tqdm
|
9 |
+
|
10 |
+
MODEL_PATHS = {
|
11 |
+
"base": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt",
|
12 |
+
"upsample": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt",
|
13 |
+
"base-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base_inpaint.pt",
|
14 |
+
"upsample-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample_inpaint.pt",
|
15 |
+
"clip/image-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_image_enc.pt",
|
16 |
+
"clip/text-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_text_enc.pt",
|
17 |
+
}
|
18 |
+
|
19 |
+
|
20 |
+
@lru_cache()
|
21 |
+
def default_cache_dir() -> str:
|
22 |
+
return os.path.join(os.path.abspath(os.getcwd()), "glide_model_cache")
|
23 |
+
|
24 |
+
|
25 |
+
def fetch_file_cached(
|
26 |
+
url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
|
27 |
+
) -> str:
|
28 |
+
"""
|
29 |
+
Download the file at the given URL into a local file and return the path.
|
30 |
+
|
31 |
+
If cache_dir is specified, it will be used to download the files.
|
32 |
+
Otherwise, default_cache_dir() is used.
|
33 |
+
"""
|
34 |
+
if cache_dir is None:
|
35 |
+
cache_dir = default_cache_dir()
|
36 |
+
os.makedirs(cache_dir, exist_ok=True)
|
37 |
+
response = requests.get(url, stream=True)
|
38 |
+
size = int(response.headers.get("content-length", "0"))
|
39 |
+
local_path = os.path.join(cache_dir, url.split("/")[-1])
|
40 |
+
with FileLock(local_path + ".lock"):
|
41 |
+
if os.path.exists(local_path):
|
42 |
+
return local_path
|
43 |
+
if progress:
|
44 |
+
pbar = tqdm(total=size, unit="iB", unit_scale=True)
|
45 |
+
tmp_path = local_path + ".tmp"
|
46 |
+
with open(tmp_path, "wb") as f:
|
47 |
+
for chunk in response.iter_content(chunk_size):
|
48 |
+
if progress:
|
49 |
+
pbar.update(len(chunk))
|
50 |
+
f.write(chunk)
|
51 |
+
os.rename(tmp_path, local_path)
|
52 |
+
if progress:
|
53 |
+
pbar.close()
|
54 |
+
return local_path
|
55 |
+
|
56 |
+
|
57 |
+
def load_checkpoint(
|
58 |
+
checkpoint_name: str,
|
59 |
+
device: th.device,
|
60 |
+
progress: bool = True,
|
61 |
+
cache_dir: Optional[str] = None,
|
62 |
+
chunk_size: int = 4096,
|
63 |
+
) -> Dict[str, th.Tensor]:
|
64 |
+
if checkpoint_name not in MODEL_PATHS:
|
65 |
+
raise ValueError(
|
66 |
+
f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
|
67 |
+
)
|
68 |
+
path = fetch_file_cached(
|
69 |
+
MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
|
70 |
+
)
|
71 |
+
return th.load(path, map_location=device)
|