Gradio configuration parameters (#1591)
Browse files
* Gradio Configuration Settings
* Making various Gradio variables configurable instead of hardcoded
* Remove overwriting behaviour of 'default tokens' that breaks the tokenizer for llama3
* Fix type of gradio_temperature
* Revert unnecessary change and lint
---------
Co-authored-by: Marijn Stollenga <[email protected]>
Co-authored-by: Marijn Stollenga <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/cli/__init__.py
CHANGED
|
@@ -264,8 +264,8 @@ def do_inference_gradio(
|
|
| 264 |
with torch.no_grad():
|
| 265 |
generation_config = GenerationConfig(
|
| 266 |
repetition_penalty=1.1,
|
| 267 |
-
max_new_tokens=1024,
|
| 268 |
-
temperature=0.9,
|
| 269 |
top_p=0.95,
|
| 270 |
top_k=40,
|
| 271 |
bos_token_id=tokenizer.bos_token_id,
|
|
@@ -300,7 +300,13 @@ def do_inference_gradio(
|
|
| 300 |
outputs="text",
|
| 301 |
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
| 302 |
)
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
|
| 306 |
def choose_config(path: Path):
|
|
|
|
| 264 |
with torch.no_grad():
|
| 265 |
generation_config = GenerationConfig(
|
| 266 |
repetition_penalty=1.1,
|
| 267 |
+
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
|
| 268 |
+
temperature=cfg.get("gradio_temperature", 0.9),
|
| 269 |
top_p=0.95,
|
| 270 |
top_k=40,
|
| 271 |
bos_token_id=tokenizer.bos_token_id,
|
|
|
|
| 300 |
outputs="text",
|
| 301 |
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
| 302 |
)
|
| 303 |
+
|
| 304 |
+
demo.queue().launch(
|
| 305 |
+
show_api=False,
|
| 306 |
+
share=cfg.get("gradio_share", True),
|
| 307 |
+
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
|
| 308 |
+
server_port=cfg.get("gradio_server_port", None),
|
| 309 |
+
)
|
| 310 |
|
| 311 |
|
| 312 |
def choose_config(path: Path):
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
|
@@ -409,6 +409,17 @@ class WandbConfig(BaseModel):
|
|
| 409 |
return data
|
| 410 |
|
| 411 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
# pylint: disable=too-many-public-methods,too-many-ancestors
|
| 413 |
class AxolotlInputConfig(
|
| 414 |
ModelInputConfig,
|
|
@@ -419,6 +430,7 @@ class AxolotlInputConfig(
|
|
| 419 |
WandbConfig,
|
| 420 |
MLFlowConfig,
|
| 421 |
LISAConfig,
|
|
|
|
| 422 |
RemappedParameters,
|
| 423 |
DeprecatedParameters,
|
| 424 |
BaseModel,
|
|
|
|
| 409 |
return data
|
| 410 |
|
| 411 |
|
| 412 |
+
class GradioConfig(BaseModel):
|
| 413 |
+
"""Gradio configuration subset"""
|
| 414 |
+
|
| 415 |
+
gradio_title: Optional[str] = None
|
| 416 |
+
gradio_share: Optional[bool] = None
|
| 417 |
+
gradio_server_name: Optional[str] = None
|
| 418 |
+
gradio_server_port: Optional[int] = None
|
| 419 |
+
gradio_max_new_tokens: Optional[int] = None
|
| 420 |
+
gradio_temperature: Optional[float] = None
|
| 421 |
+
|
| 422 |
+
|
| 423 |
# pylint: disable=too-many-public-methods,too-many-ancestors
|
| 424 |
class AxolotlInputConfig(
|
| 425 |
ModelInputConfig,
|
|
|
|
| 430 |
WandbConfig,
|
| 431 |
MLFlowConfig,
|
| 432 |
LISAConfig,
|
| 433 |
+
GradioConfig,
|
| 434 |
RemappedParameters,
|
| 435 |
DeprecatedParameters,
|
| 436 |
BaseModel,
|