Spaces:
Sleeping
Sleeping
Update
Browse files- .gitignore +0 -1
- .gitmodules +3 -0
- ELITE +1 -0
- README.md +3 -3
- model.py +4 -16
.gitignore
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
ELITE/
|
2 |
|
3 |
# Byte-compiled / optimized / DLL files
|
4 |
__pycache__/
|
|
|
|
|
1 |
|
2 |
# Byte-compiled / optimized / DLL files
|
3 |
__pycache__/
|
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "ELITE"]
|
2 |
+
path = ELITE
|
3 |
+
url = https://huggingface.co/ELITE-library/ELITE
|
ELITE
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit 9f563c699684b8b44358b0ab2f5dafd0a5af24b1
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: ELITE
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.20.1
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
title: ELITE
|
3 |
+
emoji: π
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: green
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.20.1
|
8 |
app_file: app.py
|
model.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
-
import os
|
4 |
import pathlib
|
5 |
import random
|
6 |
import sys
|
@@ -15,17 +14,11 @@ import torch.nn.functional as F
|
|
15 |
import torchvision.transforms as T
|
16 |
import tqdm.auto
|
17 |
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
|
18 |
-
from huggingface_hub import hf_hub_download
|
19 |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
20 |
|
21 |
-
HF_TOKEN = os.getenv('HF_TOKEN')
|
22 |
-
|
23 |
repo_dir = pathlib.Path(__file__).parent
|
24 |
submodule_dir = repo_dir / 'ELITE'
|
25 |
-
snapshot_download('ELITE-library/ELITE',
|
26 |
-
repo_type='model',
|
27 |
-
local_dir=submodule_dir.as_posix(),
|
28 |
-
token=HF_TOKEN)
|
29 |
sys.path.insert(0, submodule_dir.as_posix())
|
30 |
|
31 |
from train_local import (Mapper, MapperLocal, inj_forward_crossattention,
|
@@ -64,13 +57,11 @@ class Model:
|
|
64 |
global_mapper_path = hf_hub_download('ELITE-library/ELITE',
|
65 |
'global_mapper.pt',
|
66 |
subfolder='checkpoints',
|
67 |
-
repo_type='model'
|
68 |
-
token=HF_TOKEN)
|
69 |
local_mapper_path = hf_hub_download('ELITE-library/ELITE',
|
70 |
'local_mapper.pt',
|
71 |
subfolder='checkpoints',
|
72 |
-
repo_type='model'
|
73 |
-
token=HF_TOKEN)
|
74 |
return global_mapper_path, local_mapper_path
|
75 |
|
76 |
def load_model(
|
@@ -139,10 +130,7 @@ class Model:
|
|
139 |
mapper_local.add_module(f'{_name.replace(".", "_")}_to_k',
|
140 |
to_k_local)
|
141 |
|
142 |
-
|
143 |
-
global_mapper_path = submodule_dir / 'checkpoints/global_mapper.pt'
|
144 |
-
local_mapper_path = submodule_dir / 'checkpoints/local_mapper.pt'
|
145 |
-
|
146 |
mapper.load_state_dict(
|
147 |
torch.load(global_mapper_path, map_location='cpu'))
|
148 |
mapper.half()
|
|
|
1 |
from __future__ import annotations
|
2 |
|
|
|
3 |
import pathlib
|
4 |
import random
|
5 |
import sys
|
|
|
14 |
import torchvision.transforms as T
|
15 |
import tqdm.auto
|
16 |
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
|
17 |
+
from huggingface_hub import hf_hub_download
|
18 |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
19 |
|
|
|
|
|
20 |
repo_dir = pathlib.Path(__file__).parent
|
21 |
submodule_dir = repo_dir / 'ELITE'
|
|
|
|
|
|
|
|
|
22 |
sys.path.insert(0, submodule_dir.as_posix())
|
23 |
|
24 |
from train_local import (Mapper, MapperLocal, inj_forward_crossattention,
|
|
|
57 |
global_mapper_path = hf_hub_download('ELITE-library/ELITE',
|
58 |
'global_mapper.pt',
|
59 |
subfolder='checkpoints',
|
60 |
+
repo_type='model')
|
|
|
61 |
local_mapper_path = hf_hub_download('ELITE-library/ELITE',
|
62 |
'local_mapper.pt',
|
63 |
subfolder='checkpoints',
|
64 |
+
repo_type='model')
|
|
|
65 |
return global_mapper_path, local_mapper_path
|
66 |
|
67 |
def load_model(
|
|
|
130 |
mapper_local.add_module(f'{_name.replace(".", "_")}_to_k',
|
131 |
to_k_local)
|
132 |
|
133 |
+
global_mapper_path, local_mapper_path = self.download_mappers()
|
|
|
|
|
|
|
134 |
mapper.load_state_dict(
|
135 |
torch.load(global_mapper_path, map_location='cpu'))
|
136 |
mapper.half()
|