File size: 765 Bytes
			
			| 903ea30 e0602a9 d2e7f27 ce34d64 d2e7f27 e0602a9 0f74464 903ea30 d2e7f27 903ea30 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 | """Module to load prompt strategies."""
import importlib
import logging

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
def load(strategy, tokenizer, cfg, ds_cfg):
    """Resolve a prompt strategy by name and invoke its loader function.

    ``strategy`` is a dotted path relative to ``axolotl.prompt_strategies``.
    When its last component starts with ``load_``, that component names the
    loader function and the remainder names the module; otherwise the whole
    string is the module name and the function called is ``load``.

    Args:
        strategy: Strategy identifier, e.g. ``"mod"`` or ``"mod.load_variant"``.
        tokenizer: Passed through as the first positional argument.
        cfg: Passed through as the second positional argument.
        ds_cfg: Mapping expanded into ``UserDefinedDatasetConfig`` and passed
            as ``ds_cfg=`` only when the resolved module is ``user_defined``;
            ignored otherwise.

    Returns:
        Whatever the resolved loader returns, or ``None`` if the strategy
        could not be loaded for any reason.
    """
    requested = strategy  # keep the caller's spelling for the log message
    try:
        module_name, _, last_part = strategy.rpartition(".")
        if last_part.startswith("load_"):
            load_fn = last_part
            strategy = module_name
        else:
            load_fn = "load"
        mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
        func = getattr(mod, load_fn)
        load_kwargs = {}
        if strategy == "user_defined":
            load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
        return func(tokenizer, cfg, **load_kwargs)
    except ModuleNotFoundError:
        # Not an importable strategy module: the quiet "no match" outcome.
        return None
    except Exception:  # pylint: disable=broad-exception-caught
        # Still best-effort (callers get None), but any other failure is a
        # real problem worth surfacing instead of silently discarding.
        logging.getLogger(__name__).warning(
            "Failed to load prompt strategy %r", requested, exc_info=True
        )
        return None
 |