lcipolina commited on
Commit
9c9ed77
1 Parent(s): 19cc654

Upload glide_text2im/download.py

Browse files
Files changed (1) hide show
  1. glide_text2im/download.py +71 -0
glide_text2im/download.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from functools import lru_cache
3
+ from typing import Dict, Optional
4
+
5
+ import requests
6
+ import torch as th
7
+ from filelock import FileLock
8
+ from tqdm.auto import tqdm
9
+
10
+ MODEL_PATHS = {
11
+ "base": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base.pt",
12
+ "upsample": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample.pt",
13
+ "base-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/base_inpaint.pt",
14
+ "upsample-inpaint": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/upsample_inpaint.pt",
15
+ "clip/image-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_image_enc.pt",
16
+ "clip/text-enc": "https://openaipublic.blob.core.windows.net/diffusion/dec-2021/clip_text_enc.pt",
17
+ }
18
+
19
+
20
+ @lru_cache()
21
+ def default_cache_dir() -> str:
22
+ return os.path.join(os.path.abspath(os.getcwd()), "glide_model_cache")
23
+
24
+
25
+ def fetch_file_cached(
26
+ url: str, progress: bool = True, cache_dir: Optional[str] = None, chunk_size: int = 4096
27
+ ) -> str:
28
+ """
29
+ Download the file at the given URL into a local file and return the path.
30
+
31
+ If cache_dir is specified, it will be used to download the files.
32
+ Otherwise, default_cache_dir() is used.
33
+ """
34
+ if cache_dir is None:
35
+ cache_dir = default_cache_dir()
36
+ os.makedirs(cache_dir, exist_ok=True)
37
+ response = requests.get(url, stream=True)
38
+ size = int(response.headers.get("content-length", "0"))
39
+ local_path = os.path.join(cache_dir, url.split("/")[-1])
40
+ with FileLock(local_path + ".lock"):
41
+ if os.path.exists(local_path):
42
+ return local_path
43
+ if progress:
44
+ pbar = tqdm(total=size, unit="iB", unit_scale=True)
45
+ tmp_path = local_path + ".tmp"
46
+ with open(tmp_path, "wb") as f:
47
+ for chunk in response.iter_content(chunk_size):
48
+ if progress:
49
+ pbar.update(len(chunk))
50
+ f.write(chunk)
51
+ os.rename(tmp_path, local_path)
52
+ if progress:
53
+ pbar.close()
54
+ return local_path
55
+
56
+
57
+ def load_checkpoint(
58
+ checkpoint_name: str,
59
+ device: th.device,
60
+ progress: bool = True,
61
+ cache_dir: Optional[str] = None,
62
+ chunk_size: int = 4096,
63
+ ) -> Dict[str, th.Tensor]:
64
+ if checkpoint_name not in MODEL_PATHS:
65
+ raise ValueError(
66
+ f"Unknown checkpoint name {checkpoint_name}. Known names are: {MODEL_PATHS.keys()}."
67
+ )
68
+ path = fetch_file_cached(
69
+ MODEL_PATHS[checkpoint_name], progress=progress, cache_dir=cache_dir, chunk_size=chunk_size
70
+ )
71
+ return th.load(path, map_location=device)