Use extracted lightweight models
- app.py +3 -2
- model.py +60 -20
- requirements.txt +1 -0
app.py
CHANGED
@@ -63,9 +63,10 @@ with gr.Blocks(css='style.css') as demo:
             create_demo_scribble(model.process_scribble, max_images=MAX_IMAGES)
         with gr.TabItem('Scribble Interactive'):
             create_demo_scribble_interactive(
-                model.process_scribble_interactive,
+                model.process_scribble_interactive, max_images=MAX_IMAGES)
         with gr.TabItem('Fake Scribble'):
-            create_demo_fake_scribble(model.process_fake_scribble,
+            create_demo_fake_scribble(model.process_fake_scribble,
+                                      max_images=MAX_IMAGES)
         with gr.TabItem('Pose'):
             create_demo_pose(model.process_pose, max_images=MAX_IMAGES)
         with gr.TabItem('Segmentation'):
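For reference, these calls now follow the same convention as the other tabs: each create_demo_* factory receives the processing method plus max_images. The factories themselves live in files not touched by this commit, so the snippet below is only a hypothetical sketch of that calling convention; the component layout and the process signature are illustrative, not the repo's actual code.

import gradio as gr

def create_demo_fake_scribble(process, max_images: int = 12) -> gr.Blocks:
    # Hypothetical sketch: max_images caps the sample-count slider.
    with gr.Blocks() as demo:
        image = gr.Image(type='numpy', label='Input')
        prompt = gr.Textbox(label='Prompt')
        num_samples = gr.Slider(label='Number of images',
                                minimum=1,
                                maximum=max_images,
                                value=1,
                                step=1)
        run_button = gr.Button('Run')
        gallery = gr.Gallery(label='Output')
        run_button.click(fn=process,
                         inputs=[image, prompt, num_samples],
                         outputs=gallery)
    return demo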
model.py
CHANGED
@@ -28,22 +28,36 @@ from cldm.model import create_model, load_state_dict
 from ldm.models.diffusion.ddim import DDIMSampler
 from share import *
 
+ORIGINAL_MODEL_NAMES = {
+    'canny': 'control_sd15_canny.pth',
+    'hough': 'control_sd15_mlsd.pth',
+    'hed': 'control_sd15_hed.pth',
+    'scribble': 'control_sd15_scribble.pth',
+    'pose': 'control_sd15_openpose.pth',
+    'seg': 'control_sd15_seg.pth',
+    'depth': 'control_sd15_depth.pth',
+    'normal': 'control_sd15_normal.pth',
+}
+ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/ckpt/ControlNet/resolve/main/'
+
+LIGHTWEIGHT_MODEL_NAMES = {
+    'canny': 'control_canny-fp16.safetensors',
+    'hough': 'control_mlsd-fp16.safetensors',
+    'hed': 'control_hed-fp16.safetensors',
+    'scribble': 'control_scribble-fp16.safetensors',
+    'pose': 'control_openpose-fp16.safetensors',
+    'seg': 'control_seg-fp16.safetensors',
+    'depth': 'control_depth-fp16.safetensors',
+    'normal': 'control_normal-fp16.safetensors',
+}
+LIGHTWEIGHT_WEIGHT_ROOT = 'https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/'
 
-class Model:
-    WEIGHT_NAMES = {
-        'canny': 'control_sd15_canny.pth',
-        'hough': 'control_sd15_mlsd.pth',
-        'hed': 'control_sd15_hed.pth',
-        'scribble': 'control_sd15_scribble.pth',
-        'pose': 'control_sd15_openpose.pth',
-        'seg': 'control_sd15_seg.pth',
-        'depth': 'control_sd15_depth.pth',
-        'normal': 'control_sd15_normal.pth',
-    }
 
+class Model:
     def __init__(self,
                  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
-                 model_dir: str = 'models'):
+                 model_dir: str = 'models',
+                 use_lightweight: bool = True):
         self.device = torch.device(
             'cuda:0' if torch.cuda.is_available() else 'cpu')
         self.model = create_model(model_config_path).to(self.device)
@@ -51,31 +65,57 @@ class Model:
         self.task_name = ''
 
         self.model_dir = pathlib.Path(model_dir)
+
+        self.use_lightweight = use_lightweight
+        if use_lightweight:
+            self.model_names = LIGHTWEIGHT_MODEL_NAMES
+            self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
+            base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
+            self.download_base_model(base_model_url)
+            base_model_path = self.model_dir / base_model_url.split('/')[-1]
+            self.load_base_model(base_model_path)
+        else:
+            self.model_names = ORIGINAL_MODEL_NAMES
+            self.weight_root = ORIGINAL_WEIGHT_ROOT
         self.download_models()
 
+    def download_base_model(self, base_model_url: str) -> None:
+        model_name = base_model_url.split('/')[-1]
+        out_path = self.model_dir / model_name
+        if out_path.exists():
+            return
+        subprocess.run(shlex.split(f'wget {base_model_url} -O {out_path}'))
+
+    def load_base_model(self, model_path: pathlib.Path) -> None:
+        self.model.load_state_dict(load_state_dict(model_path,
+                                                   location=self.device.type),
+                                   strict=False)
+
     def load_weight(self, task_name: str) -> None:
         if task_name == self.task_name:
             return
         weight_path = self.get_weight_path(task_name)
-        self.model.load_state_dict(
-            load_state_dict(weight_path, location=self.device))
+        if not self.use_lightweight:
+            self.model.load_state_dict(
+                load_state_dict(weight_path, location=self.device))
+        else:
+            self.model.control_model.load_state_dict(
+                load_state_dict(weight_path, location=self.device.type))
         self.task_name = task_name
 
     def get_weight_path(self, task_name: str) -> str:
         if 'scribble' in task_name:
             task_name = 'scribble'
-        return f'{self.model_dir}/{self.WEIGHT_NAMES[task_name]}'
+        return f'{self.model_dir}/{self.model_names[task_name]}'
 
-    def download_models(self):
+    def download_models(self) -> None:
         self.model_dir.mkdir(exist_ok=True, parents=True)
-        for name in self.WEIGHT_NAMES.values():
+        for name in self.model_names.values():
             out_path = self.model_dir / name
             if out_path.exists():
                 continue
             subprocess.run(
-                shlex.split(
-                    f'wget https://huggingface.co/ckpt/ControlNet/resolve/main/{name} -O {out_path}'
-                ))
+                shlex.split(f'wget {self.weight_root}{name} -O {out_path}'))
 
     @torch.inference_mode()
     def process_canny(self, input_image, prompt, a_prompt, n_prompt,
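In the lightweight path, the SD 1.5 base weights are loaded once with strict=False (the base checkpoint has no ControlNet branch, so a strict load would fail), and load_weight then swaps only the control_model weights per task. A minimal usage sketch, assuming the ControlNet submodule is importable and the downloads succeed; the task names are illustrative only:

from model import Model

# First use downloads v1-5-pruned-emaonly.safetensors plus the fp16 control
# weights into models/, then loads the base weights with strict=False.
model = Model(use_lightweight=True)

# Only the ControlNet branch (model.model.control_model) is re-loaded here.
model.load_weight('canny')

# Any task name containing 'scribble' resolves to the same scribble checkpoint.
model.load_weight('scribble')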
requirements.txt
CHANGED
@@ -11,6 +11,7 @@ opencv-contrib-python==4.7.0.68
 opencv-python-headless==4.7.0.68
 prettytable==3.6.0
 pytorch-lightning==1.9.0
+safetensors==0.2.8
 timm==0.6.12
 torch==1.13.1
 torchvision==0.14.1
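The new safetensors pin presumably supports reading the *-fp16.safetensors checkpoints listed above; ControlNet's load_state_dict would need to dispatch on the file extension for that. A rough sketch of that kind of dispatch, as a hypothetical helper rather than the repo's actual function:

import pathlib

import torch
from safetensors.torch import load_file

def load_checkpoint(path: str, device: str = 'cpu') -> dict:
    # Hypothetical helper: .safetensors files go through the safetensors
    # reader, while .pth/.ckpt files still use torch.load.
    if pathlib.Path(path).suffix == '.safetensors':
        return load_file(path, device=device)
    return torch.load(path, map_location=device)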