create a model card with axolotl badge (#624)
Browse files
- src/axolotl/train.py (+7 -2)
src/axolotl/train.py
CHANGED
|
@@ -9,8 +9,7 @@ from pathlib import Path  (old version)
  9    from typing import Optional
 10
 11    import torch
 12  -
 13  - # add src to the pythonpath so we don't need to pip install this
 14    from datasets import Dataset
 15    from optimum.bettertransformer import BetterTransformer
 16
|
|
@@ -103,6 +102,9 @@ def train(  (old version)
 103        signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
 104    )
 105
 106    LOG.info("Starting trainer...")
 107    if cfg.group_by_length:
 108        LOG.info("hang tight... sorting dataset for group_by_length")
|
@@ -138,4 +140,7 @@ def train(  (old version)
 138
 139    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 140
 141    return model, tokenizer
|
|
|
(new version)
  9    from typing import Optional
 10
 11    import torch
 12  + import transformers.modelcard
 13    from datasets import Dataset
 14    from optimum.bettertransformer import BetterTransformer
 15
|
|
|
|
(new version)
 102        signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
 103    )
 104
 105  +     badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
 106  +     transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
 107  +
 108    LOG.info("Starting trainer...")
 109    if cfg.group_by_length:
 110        LOG.info("hang tight... sorting dataset for group_by_length")
|
|
|
|
(new version)
 140
 141    model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
 142
 143  +     if not cfg.hub_model_id:
 144  +         trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
 145  +
 146    return model, tokenizer
|