Spaces:
Runtime error
Runtime error
eliphatfs
commited on
Commit
·
d56acb1
1
Parent(s):
b58ab91
Add queue and docker cache.
Browse files- Dockerfile +2 -0
- app.py +39 -36
- download_checkpoints.py +19 -0
Dockerfile
CHANGED
|
@@ -35,4 +35,6 @@ WORKDIR $HOME/app
|
|
| 35 |
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
| 36 |
COPY --chown=user . $HOME/app
|
| 37 |
|
|
|
|
|
|
|
| 38 |
CMD ["streamlit", "run", "--server.enableXsrfProtection", "false", "app.py"]
|
|
|
|
| 35 |
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
| 36 |
COPY --chown=user . $HOME/app
|
| 37 |
|
| 38 |
+
RUN python3 download_checkpoints.py
|
| 39 |
+
|
| 40 |
CMD ["streamlit", "run", "--server.enableXsrfProtection", "false", "app.py"]
|
app.py
CHANGED
|
@@ -3,15 +3,13 @@ import sys
|
|
| 3 |
import numpy
|
| 4 |
import torch
|
| 5 |
import rembg
|
|
|
|
| 6 |
import urllib.request
|
| 7 |
from PIL import Image
|
| 8 |
import streamlit as st
|
| 9 |
import huggingface_hub
|
| 10 |
|
| 11 |
|
| 12 |
-
if 'HF_TOKEN' in os.environ:
|
| 13 |
-
huggingface_hub.login(os.environ['HF_TOKEN'])
|
| 14 |
-
|
| 15 |
img_example_counter = 0
|
| 16 |
iret_base = 'resources/examples'
|
| 17 |
iret = [
|
|
@@ -186,10 +184,12 @@ def check_dependencies():
|
|
| 186 |
|
| 187 |
@st.cache_resource
|
| 188 |
def load_zero123plus_pipeline():
|
|
|
|
|
|
|
| 189 |
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
|
| 190 |
pipeline = DiffusionPipeline.from_pretrained(
|
| 191 |
-
|
| 192 |
-
|
| 193 |
)
|
| 194 |
# Feel free to tune the scheduler
|
| 195 |
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
|
@@ -197,6 +197,7 @@ def load_zero123plus_pipeline():
|
|
| 197 |
)
|
| 198 |
if torch.cuda.is_available():
|
| 199 |
pipeline.to('cuda:0')
|
|
|
|
| 200 |
return pipeline
|
| 201 |
|
| 202 |
|
|
@@ -227,36 +228,38 @@ if sample_got:
|
|
| 227 |
pic = sample_got
|
| 228 |
with results_container:
|
| 229 |
if sample_got or submit:
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
st.
|
| 236 |
-
|
| 237 |
-
prog.progress(0.1, "Preparing Inputs")
|
| 238 |
-
if rem_input_bg:
|
| 239 |
-
with right:
|
| 240 |
-
img = segment_img(img)
|
| 241 |
st.image(img)
|
| 242 |
-
st.caption("Input
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
|
|
|
| 260 |
st.image(result)
|
| 261 |
-
st.caption("Result
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import numpy
|
| 4 |
import torch
|
| 5 |
import rembg
|
| 6 |
+
import threading
|
| 7 |
import urllib.request
|
| 8 |
from PIL import Image
|
| 9 |
import streamlit as st
|
| 10 |
import huggingface_hub
|
| 11 |
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
img_example_counter = 0
|
| 14 |
iret_base = 'resources/examples'
|
| 15 |
iret = [
|
|
|
|
| 184 |
|
| 185 |
@st.cache_resource
|
| 186 |
def load_zero123plus_pipeline():
|
| 187 |
+
if 'HF_TOKEN' in os.environ:
|
| 188 |
+
huggingface_hub.login(os.environ['HF_TOKEN'])
|
| 189 |
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
|
| 190 |
pipeline = DiffusionPipeline.from_pretrained(
|
| 191 |
+
"sudo-ai/zero123plus-v1.1", custom_pipeline="sudo-ai/zero123plus-pipeline",
|
| 192 |
+
torch_dtype=torch.float16
|
| 193 |
)
|
| 194 |
# Feel free to tune the scheduler
|
| 195 |
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
|
|
|
| 197 |
)
|
| 198 |
if torch.cuda.is_available():
|
| 199 |
pipeline.to('cuda:0')
|
| 200 |
+
sys.main_lock = threading.Lock()
|
| 201 |
return pipeline
|
| 202 |
|
| 203 |
|
|
|
|
| 228 |
pic = sample_got
|
| 229 |
with results_container:
|
| 230 |
if sample_got or submit:
|
| 231 |
+
prog.progress(0.03, "Waiting in Queue...")
|
| 232 |
+
with sys.main_lock:
|
| 233 |
+
seed = int(seed)
|
| 234 |
+
torch.manual_seed(seed)
|
| 235 |
+
img = Image.open(pic)
|
| 236 |
+
left, right = st.columns(2)
|
| 237 |
+
with left:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
st.image(img)
|
| 239 |
+
st.caption("Input Image")
|
| 240 |
+
prog.progress(0.1, "Preparing Inputs")
|
| 241 |
+
if rem_input_bg:
|
| 242 |
+
with right:
|
| 243 |
+
img = segment_img(img)
|
| 244 |
+
st.image(img)
|
| 245 |
+
st.caption("Input (Background Removed)")
|
| 246 |
+
img = expand2square(img, (127, 127, 127, 0))
|
| 247 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 248 |
+
result = pipeline(
|
| 249 |
+
img,
|
| 250 |
+
num_inference_steps=num_inference_steps,
|
| 251 |
+
guidance_scale=cfg_scale,
|
| 252 |
+
generator=torch.Generator(pipeline.device).manual_seed(seed),
|
| 253 |
+
callback=lambda i, t, latents: prog.progress(0.1 + 0.8 * i / num_inference_steps, "Diffusion Step %d" % i)
|
| 254 |
+
).images[0]
|
| 255 |
+
prog.progress(0.9, "Post Processing")
|
| 256 |
+
left, right = st.columns(2)
|
| 257 |
+
with left:
|
| 258 |
st.image(result)
|
| 259 |
+
st.caption("Result")
|
| 260 |
+
if rem_output_bg:
|
| 261 |
+
result = segment_6imgs(result)
|
| 262 |
+
with right:
|
| 263 |
+
st.image(result)
|
| 264 |
+
st.caption("Result (Background Removed)")
|
| 265 |
+
prog.progress(1.0, "Idle")
|
download_checkpoints.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import urllib.request
|
| 4 |
+
import huggingface_hub
|
| 5 |
+
from diffusers import DiffusionPipeline
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
if 'HF_TOKEN' in os.environ:
|
| 9 |
+
huggingface_hub.login(os.environ['HF_TOKEN'])
|
| 10 |
+
sam_checkpoint = "tmp/sam_vit_h_4b8939.pth"
|
| 11 |
+
os.makedirs('tmp', exist_ok=True)
|
| 12 |
+
urllib.request.urlretrieve(
|
| 13 |
+
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
| 14 |
+
sam_checkpoint
|
| 15 |
+
)
|
| 16 |
+
DiffusionPipeline.from_pretrained(
|
| 17 |
+
"sudo-ai/zero123plus-v1.1", custom_pipeline="sudo-ai/zero123plus-pipeline",
|
| 18 |
+
torch_dtype=torch.float16
|
| 19 |
+
)
|