Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
t0-0
commited on
Commit
·
bf7bdee
1
Parent(s):
4c9c17b
Add auto/fp32 option and set auto as the default for submission
Browse files- app.py +1 -1
- src/display/utils.py +6 -0
app.py
CHANGED
|
@@ -582,7 +582,7 @@ with gr.Blocks() as demo_submission:
|
|
| 582 |
label="Precision",
|
| 583 |
choices=[i.value.name for i in Precision],
|
| 584 |
multiselect=False,
|
| 585 |
-
value="
|
| 586 |
)
|
| 587 |
add_special_tokens = gr.Dropdown(
|
| 588 |
label="AddSpecialTokens",
|
|
|
|
| 582 |
label="Precision",
|
| 583 |
choices=[i.value.name for i in Precision],
|
| 584 |
multiselect=False,
|
| 585 |
+
value="auto",
|
| 586 |
)
|
| 587 |
add_special_tokens = gr.Dropdown(
|
| 588 |
label="AddSpecialTokens",
|
src/display/utils.py
CHANGED
|
@@ -129,13 +129,19 @@ class WeightType(Enum):
|
|
| 129 |
|
| 130 |
|
| 131 |
class Precision(Enum):
|
|
|
|
| 132 |
float16 = ModelDetails("float16")
|
|
|
|
| 133 |
bfloat16 = ModelDetails("bfloat16")
|
| 134 |
|
| 135 |
@staticmethod
|
| 136 |
def from_str(precision: str) -> "Precision":
|
|
|
|
|
|
|
| 137 |
if precision in ["torch.float16", "float16"]:
|
| 138 |
return Precision.float16
|
|
|
|
|
|
|
| 139 |
if precision in ["torch.bfloat16", "bfloat16"]:
|
| 140 |
return Precision.bfloat16
|
| 141 |
raise ValueError(f"Unsupported precision type: {precision}")
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
class Precision(Enum):
|
| 132 |
+
auto = ModelDetails("auto")
|
| 133 |
float16 = ModelDetails("float16")
|
| 134 |
+
float32 = ModelDetails("float32")
|
| 135 |
bfloat16 = ModelDetails("bfloat16")
|
| 136 |
|
| 137 |
@staticmethod
|
| 138 |
def from_str(precision: str) -> "Precision":
|
| 139 |
+
if precision == "auto":
|
| 140 |
+
return Precision.auto
|
| 141 |
if precision in ["torch.float16", "float16"]:
|
| 142 |
return Precision.float16
|
| 143 |
+
if precision in ["torch.float32", "float32"]:
|
| 144 |
+
return Precision.float32
|
| 145 |
if precision in ["torch.bfloat16", "bfloat16"]:
|
| 146 |
return Precision.bfloat16
|
| 147 |
raise ValueError(f"Unsupported precision type: {precision}")
|