move list not in list logic to fn
Browse files- scripts/finetune.py +6 -2
scripts/finetune.py
CHANGED
|
@@ -5,7 +5,7 @@ import random
|
|
| 5 |
import signal
|
| 6 |
import sys
|
| 7 |
from pathlib import Path
|
| 8 |
-
from typing import Optional
|
| 9 |
|
| 10 |
import fire
|
| 11 |
import torch
|
|
@@ -117,6 +117,10 @@ def choose_config(path: Path):
|
|
| 117 |
return chosen_file
|
| 118 |
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
def train(
|
| 121 |
config: Path = Path("configs/"),
|
| 122 |
prepare_ds_only: bool = False,
|
|
@@ -169,7 +173,7 @@ def train(
|
|
| 169 |
cfg
|
| 170 |
)
|
| 171 |
|
| 172 |
-
if "inference"
|
| 173 |
train_dataset, eval_dataset = load_prepare_datasets(
|
| 174 |
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
| 175 |
)
|
|
|
|
| 5 |
import signal
|
| 6 |
import sys
|
| 7 |
from pathlib import Path
|
| 8 |
+
from typing import Optional, List, Dict, Any, Union
|
| 9 |
|
| 10 |
import fire
|
| 11 |
import torch
|
|
|
|
| 117 |
return chosen_file
|
| 118 |
|
| 119 |
|
| 120 |
+
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
|
| 121 |
+
return not any(el in list2 for el in list1)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
def train(
|
| 125 |
config: Path = Path("configs/"),
|
| 126 |
prepare_ds_only: bool = False,
|
|
|
|
| 173 |
cfg
|
| 174 |
)
|
| 175 |
|
| 176 |
+
if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
|
| 177 |
train_dataset, eval_dataset = load_prepare_datasets(
|
| 178 |
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
| 179 |
)
|