Tymec committed on
Commit
8471e78
1 Parent(s): 3854a1f

Ability to change number of parallel jobs for search

Browse files
Files changed (2) hide show
  1. app/cli.py +3 -2
  2. app/model.py +5 -3
app/cli.py CHANGED
@@ -194,8 +194,8 @@ def evaluate(
194
  )
195
  @click.option(
196
  "--processes",
197
- default=8,
198
- help="Number of parallel jobs during tokenization",
199
  show_default=True,
200
  )
201
  @click.option(
@@ -263,6 +263,7 @@ def train(
263
  label_data,
264
  max_features=max_features,
265
  folds=cv,
 
266
  seed=seed,
267
  verbose=verbose,
268
  )
 
194
  )
195
  @click.option(
196
  "--processes",
197
+ default=4,
198
+ help="Number of parallel jobs to run",
199
  show_default=True,
200
  )
201
  @click.option(
 
263
  label_data,
264
  max_features=max_features,
265
  folds=cv,
266
+ n_jobs=processes,
267
  seed=seed,
268
  verbose=verbose,
269
  )
app/model.py CHANGED
@@ -36,6 +36,7 @@ def train_model(
36
  label_data: list[int],
37
  max_features: int,
38
  folds: int = 5,
 
39
  seed: int = 42,
40
  verbose: bool = False,
41
  ) -> tuple[BaseEstimator, float]:
@@ -47,6 +48,7 @@ def train_model(
47
  label_data: Label data
48
  max_features: Maximum number of features
49
  folds: Number of cross-validation folds
 
50
  seed: Random seed (None for random seed)
51
  verbose: Whether to output additional information
52
 
@@ -94,12 +96,12 @@ def train_model(
94
  search = RandomizedSearchCV(
95
  model,
96
  param_distributions,
97
- n_iter=10,
98
  cv=folds,
99
- scoring="accuracy",
100
  random_state=seed,
101
- n_jobs=-1,
102
  verbose=verbose,
 
 
103
  )
104
 
105
  # os.environ["PYTHONWARNINGS"] = "ignore"
 
36
  label_data: list[int],
37
  max_features: int,
38
  folds: int = 5,
39
+ n_jobs: int = 4,
40
  seed: int = 42,
41
  verbose: bool = False,
42
  ) -> tuple[BaseEstimator, float]:
 
48
  label_data: Label data
49
  max_features: Maximum number of features
50
  folds: Number of cross-validation folds
51
+ n_jobs: Number of parallel jobs
52
  seed: Random seed (None for random seed)
53
  verbose: Whether to output additional information
54
 
 
96
  search = RandomizedSearchCV(
97
  model,
98
  param_distributions,
 
99
  cv=folds,
 
100
  random_state=seed,
101
+ n_jobs=n_jobs,
102
  verbose=verbose,
103
+ scoring="accuracy",
104
+ n_iter=10,
105
  )
106
 
107
  # os.environ["PYTHONWARNINGS"] = "ignore"