Napuh
commited on
Warn users to login to HuggingFace (#645)
Browse files* added warning if user is not logged in HF
* updated doc to suggest logging in to HF
- README.md +5 -0
- scripts/finetune.py +2 -0
- src/axolotl/cli/__init__.py +15 -0
- src/axolotl/cli/train.py +2 -0
README.md
CHANGED
|
@@ -124,6 +124,11 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|
| 124 |
pip3 install packaging
|
| 125 |
pip3 install -e '.[flash-attn,deepspeed]'
|
| 126 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
- LambdaLabs
|
| 129 |
<details>
|
|
|
|
| 124 |
pip3 install packaging
|
| 125 |
pip3 install -e '.[flash-attn,deepspeed]'
|
| 126 |
```
|
| 127 |
+
4. (Optional) Login to Huggingface to use gated models/datasets.
|
| 128 |
+
```bash
|
| 129 |
+
huggingface-cli login
|
| 130 |
+
```
|
| 131 |
+
Get the token at huggingface.co/settings/tokens
|
| 132 |
|
| 133 |
- LambdaLabs
|
| 134 |
<details>
|
scripts/finetune.py
CHANGED
|
@@ -7,6 +7,7 @@ import transformers
|
|
| 7 |
|
| 8 |
from axolotl.cli import (
|
| 9 |
check_accelerate_default_config,
|
|
|
|
| 10 |
do_inference,
|
| 11 |
do_merge_lora,
|
| 12 |
load_cfg,
|
|
@@ -31,6 +32,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|
| 31 |
)
|
| 32 |
parsed_cfg = load_cfg(config, **kwargs)
|
| 33 |
check_accelerate_default_config()
|
|
|
|
| 34 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
| 35 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
| 36 |
return_remaining_strings=True
|
|
|
|
| 7 |
|
| 8 |
from axolotl.cli import (
|
| 9 |
check_accelerate_default_config,
|
| 10 |
+
check_user_token,
|
| 11 |
do_inference,
|
| 12 |
do_merge_lora,
|
| 13 |
load_cfg,
|
|
|
|
| 32 |
)
|
| 33 |
parsed_cfg = load_cfg(config, **kwargs)
|
| 34 |
check_accelerate_default_config()
|
| 35 |
+
check_user_token()
|
| 36 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
| 37 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
| 38 |
return_remaining_strings=True
|
src/axolotl/cli/__init__.py
CHANGED
|
@@ -14,6 +14,8 @@ import yaml
|
|
| 14 |
# add src to the pythonpath so we don't need to pip install this
|
| 15 |
from accelerate.commands.config import config_args
|
| 16 |
from art import text2art
|
|
|
|
|
|
|
| 17 |
from transformers import GenerationConfig, TextStreamer
|
| 18 |
|
| 19 |
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
|
@@ -247,3 +249,16 @@ def check_accelerate_default_config():
|
|
| 247 |
LOG.warning(
|
| 248 |
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
| 249 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
# add src to the pythonpath so we don't need to pip install this
|
| 15 |
from accelerate.commands.config import config_args
|
| 16 |
from art import text2art
|
| 17 |
+
from huggingface_hub import HfApi
|
| 18 |
+
from huggingface_hub.utils import LocalTokenNotFoundError
|
| 19 |
from transformers import GenerationConfig, TextStreamer
|
| 20 |
|
| 21 |
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
|
|
|
| 249 |
LOG.warning(
|
| 250 |
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
| 251 |
)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def check_user_token():
|
| 255 |
+
# Verify if token is valid
|
| 256 |
+
api = HfApi()
|
| 257 |
+
try:
|
| 258 |
+
user_info = api.whoami()
|
| 259 |
+
return bool(user_info)
|
| 260 |
+
except LocalTokenNotFoundError:
|
| 261 |
+
LOG.warning(
|
| 262 |
+
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
| 263 |
+
)
|
| 264 |
+
return False
|
src/axolotl/cli/train.py
CHANGED
|
@@ -8,6 +8,7 @@ import transformers
|
|
| 8 |
|
| 9 |
from axolotl.cli import (
|
| 10 |
check_accelerate_default_config,
|
|
|
|
| 11 |
load_cfg,
|
| 12 |
load_datasets,
|
| 13 |
print_axolotl_text_art,
|
|
@@ -21,6 +22,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|
| 21 |
print_axolotl_text_art()
|
| 22 |
parsed_cfg = load_cfg(config, **kwargs)
|
| 23 |
check_accelerate_default_config()
|
|
|
|
| 24 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
| 25 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
| 26 |
return_remaining_strings=True
|
|
|
|
| 8 |
|
| 9 |
from axolotl.cli import (
|
| 10 |
check_accelerate_default_config,
|
| 11 |
+
check_user_token,
|
| 12 |
load_cfg,
|
| 13 |
load_datasets,
|
| 14 |
print_axolotl_text_art,
|
|
|
|
| 22 |
print_axolotl_text_art()
|
| 23 |
parsed_cfg = load_cfg(config, **kwargs)
|
| 24 |
check_accelerate_default_config()
|
| 25 |
+
check_user_token()
|
| 26 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
| 27 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
| 28 |
return_remaining_strings=True
|