Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
6be3e80
1
Parent(s):
20fe0e9
Improve device assignment
Browse files
app.py
CHANGED
|
@@ -20,6 +20,9 @@ def str2bool(v):
|
|
| 20 |
else:
|
| 21 |
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
import argparse
|
| 24 |
parser = argparse.ArgumentParser()
|
| 25 |
parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
|
|
@@ -68,8 +71,6 @@ base_model_path = model_style_type2base_model_path[args.model_style_type]
|
|
| 68 |
|
| 69 |
# global variable
|
| 70 |
MAX_SEED = np.iinfo(np.int32).max
|
| 71 |
-
device = "cuda" if args.gpu is None else f"cuda:{args.gpu}"
|
| 72 |
-
print(f"Device: {device}")
|
| 73 |
|
| 74 |
global adaface
|
| 75 |
adaface = None
|
|
@@ -113,6 +114,16 @@ def generate_image(image_paths, image_paths2, guidance_scale, perturb_std,
|
|
| 113 |
|
| 114 |
global adaface
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
adaface.to(device)
|
| 117 |
|
| 118 |
if image_paths is None or len(image_paths) == 0:
|
|
|
|
| 20 |
else:
|
| 21 |
raise argparse.ArgumentTypeError("Boolean value expected.")
|
| 22 |
|
| 23 |
+
def is_running_on_spaces():
|
| 24 |
+
return os.getenv("SPACE_ID") is not None
|
| 25 |
+
|
| 26 |
import argparse
|
| 27 |
parser = argparse.ArgumentParser()
|
| 28 |
parser.add_argument("--adaface_encoder_types", type=str, nargs="+", default=["consistentID", "arc2face"],
|
|
|
|
| 71 |
|
| 72 |
# global variable
|
| 73 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
|
|
|
| 74 |
|
| 75 |
global adaface
|
| 76 |
adaface = None
|
|
|
|
| 114 |
|
| 115 |
global adaface
|
| 116 |
|
| 117 |
+
if is_running_on_spaces():
|
| 118 |
+
device = 'cuda:0'
|
| 119 |
+
else:
|
| 120 |
+
if args.gpu is None:
|
| 121 |
+
device = "cuda"
|
| 122 |
+
else:
|
| 123 |
+
device = f"cuda:{args.gpu}"
|
| 124 |
+
|
| 125 |
+
print(f"Device: {device}")
|
| 126 |
+
|
| 127 |
adaface.to(device)
|
| 128 |
|
| 129 |
if image_paths is None or len(image_paths) == 0:
|