Tymec commited on
Commit
d8a44ef
1 Parent(s): 9b6760b

Move common logic to function

Browse files
Files changed (1) hide show
  1. app/cli.py +70 -95
app/cli.py CHANGED
@@ -1,15 +1,75 @@
 
 
1
  from __future__ import annotations
2
 
 
 
3
  from pathlib import Path
4
  from typing import Literal
5
 
6
  import click
 
 
 
 
7
 
8
  __all__ = ["cli_wrapper"]
9
 
10
  DONE_STR = click.style("DONE", fg="green")
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @click.group()
14
  def cli() -> None: ...
15
 
@@ -29,8 +89,6 @@ def cli() -> None: ...
29
  )
30
  def gui(model_path: Path, share: bool) -> None:
31
  """Launch the Gradio GUI"""
32
- import os
33
-
34
  from app.gui import launch_gui
35
 
36
  os.environ["MODEL_PATH"] = model_path.as_posix()
@@ -51,14 +109,12 @@ def predict(model_path: Path, text: list[str]) -> None:
51
 
52
  Note: Piped input takes precedence over the text argument
53
  """
54
- import sys
55
-
56
- import joblib
57
-
58
  from app.model import infer_model
59
 
 
60
  text = " ".join(text).strip()
61
  if not sys.stdin.isatty():
 
62
  piped_text = sys.stdin.read().strip()
63
  text = piped_text or text
64
 
@@ -72,7 +128,6 @@ def predict(model_path: Path, text: list[str]) -> None:
72
 
73
  click.echo("Performing sentiment analysis... ", nl=False)
74
  prediction = infer_model(model, [text])[0]
75
- # prediction = model.predict([text])[0]
76
  if prediction == 0:
77
  sentiment = click.style("NEGATIVE", fg="red")
78
  elif prediction == 1:
@@ -101,7 +156,7 @@ def predict(model_path: Path, text: list[str]) -> None:
101
  default=5,
102
  help="Number of cross-validation folds",
103
  show_default=True,
104
- type=click.IntRange(1, 50),
105
  )
106
  @click.option(
107
  "--token-batch-size",
@@ -136,64 +191,20 @@ def evaluate(
136
  force_cache: bool,
137
  ) -> None:
138
  """Evaluate the model on the the specified dataset"""
139
- import gc
140
-
141
- import joblib
142
- import pandas as pd
143
-
144
- from app.constants import TOKENIZER_CACHE_DIR
145
- from app.data import load_data, tokenize
146
  from app.model import evaluate_model
147
- from app.utils import deserialize, serialize
148
-
149
- token_cache_path = TOKENIZER_CACHE_DIR / f"{dataset}_tokenized.pkl"
150
- label_cache_path = TOKENIZER_CACHE_DIR / f"{dataset}_labels.pkl"
151
- use_cached_data = False
152
 
153
- if token_cache_path.exists():
154
- use_cached_data = force_cache or click.confirm(
155
- f"Found existing tokenized data for '{dataset}'. Use it?",
156
- default=True,
157
- )
158
-
159
- if use_cached_data:
160
- click.echo("Loading cached data... ", nl=False)
161
- token_data = pd.Series(deserialize(token_cache_path))
162
- label_data = joblib.load(label_cache_path)
163
- click.echo(DONE_STR)
164
- else:
165
- click.echo("Loading dataset... ", nl=False)
166
- text_data, label_data = load_data(dataset)
167
- click.echo(DONE_STR)
168
-
169
- click.echo("Tokenizing data... ")
170
- token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)
171
- serialize(token_data, token_cache_path, show_progress=True)
172
- joblib.dump(label_data, label_cache_path, compress=3)
173
-
174
- del text_data
175
- gc.collect()
176
-
177
- click.echo("Size of vocabulary: ", nl=False)
178
- vocab = token_data.explode().value_counts()
179
- click.secho(str(len(vocab)), fg="blue")
180
 
181
  click.echo("Loading model... ", nl=False)
182
  model = joblib.load(model_path)
183
  click.echo(DONE_STR)
184
 
185
- if cv == 1:
186
- click.echo("Evaluating model... ", nl=False)
187
- acc = model.score(token_data, label_data)
188
- click.secho(f"{acc:.2%}", fg="blue")
189
- return
190
-
191
  click.echo("Evaluating model... ")
192
  acc_mean, acc_std = evaluate_model(
193
  model,
194
  token_data,
195
  label_data,
196
- folds=cv,
197
  n_jobs=eval_jobs,
198
  )
199
  click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
@@ -230,7 +241,7 @@ def evaluate(
230
  default=5,
231
  help="Number of cross-validation folds",
232
  show_default=True,
233
- type=click.IntRange(1, 50),
234
  )
235
  @click.option(
236
  "--token-batch-size",
@@ -281,51 +292,14 @@ def train(
281
  force_cache: bool,
282
  ) -> None:
283
  """Train the model on the provided dataset"""
284
- import gc
285
-
286
- import joblib
287
- import pandas as pd
288
-
289
- from app.constants import MODEL_DIR, TOKENIZER_CACHE_DIR
290
- from app.data import load_data, tokenize
291
  from app.model import train_model
292
- from app.utils import deserialize, serialize
293
 
294
  model_path = MODEL_DIR / f"{dataset}_{vectorizer}_ft{max_features}.pkl"
295
  if model_path.exists() and not overwrite:
296
  click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
297
 
298
- token_cache_path = TOKENIZER_CACHE_DIR / f"{dataset}_tokenized.pkl"
299
- label_cache_path = TOKENIZER_CACHE_DIR / f"{dataset}_labels.pkl"
300
- use_cached_data = False
301
-
302
- if token_cache_path.exists():
303
- use_cached_data = force_cache or click.confirm(
304
- f"Found existing tokenized data for '{dataset}'. Use it?",
305
- default=True,
306
- )
307
-
308
- if use_cached_data:
309
- click.echo("Loading cached data... ", nl=False)
310
- token_data = pd.Series(deserialize(token_cache_path))
311
- label_data = joblib.load(label_cache_path)
312
- click.echo(DONE_STR)
313
- else:
314
- click.echo("Loading dataset... ", nl=False)
315
- text_data, label_data = load_data(dataset)
316
- click.echo(DONE_STR)
317
-
318
- click.echo("Tokenizing data... ")
319
- token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)
320
- serialize(token_data, token_cache_path, show_progress=True)
321
- joblib.dump(label_data, label_cache_path, compress=3)
322
-
323
- del text_data
324
- gc.collect()
325
-
326
- click.echo("Size of vocabulary: ", nl=False)
327
- vocab = token_data.explode().value_counts()
328
- click.secho(str(len(vocab)), fg="blue")
329
 
330
  click.echo("Training model... ")
331
  model, accuracy = train_model(
@@ -334,10 +308,11 @@ def train(
334
  vectorizer=vectorizer,
335
  max_features=max_features,
336
  min_df=min_df,
337
- folds=cv,
338
  n_jobs=train_jobs,
339
  seed=seed,
340
  )
 
341
  click.echo("Model accuracy: ", nl=False)
342
  click.secho(f"{accuracy:.2%}", fg="blue")
343
 
 
1
+ """CLI using Click."""
2
+
3
  from __future__ import annotations
4
 
5
+ import os
6
+ import sys
7
  from pathlib import Path
8
  from typing import Literal
9
 
10
  import click
11
+ import joblib
12
+ import pandas as pd
13
+
14
+ from app.constants import TOKENIZER_CACHE_DIR
15
 
16
  __all__ = ["cli_wrapper"]
17
 
18
  DONE_STR = click.style("DONE", fg="green")
19
 
20
 
21
+ def _load_dataset(
22
+ dataset: str,
23
+ batch_size: int = 512,
24
+ n_jobs: int = 4,
25
+ force_cache: bool = False,
26
+ ) -> tuple[pd.Series, pd.Series]:
27
+ """Helper function to load and tokenize the dataset or use cached data if available.
28
+
29
+ Args:
30
+ dataset: Name of the dataset
31
+ batch_size: Batch size for tokenization
32
+ n_jobs: Number of parallel jobs
33
+ force_cache: Whether to force using the cached data
34
+
35
+ Returns:
36
+ Tokenized text data and label data
37
+ """
38
+ from app.data import load_data, tokenize
39
+ from app.utils import deserialize, serialize
40
+
41
+ token_cache_path = TOKENIZER_CACHE_DIR / f"{dataset}_tokenized.pkl"
42
+ label_cache_path = TOKENIZER_CACHE_DIR / f"{dataset}_labels.pkl"
43
+ use_cached_data = False
44
+
45
+ if token_cache_path.exists() and label_cache_path.exists():
46
+ use_cached_data = force_cache or click.confirm(
47
+ f"Found existing tokenized data for '{dataset}'. Use it?",
48
+ default=True,
49
+ )
50
+
51
+ if use_cached_data:
52
+ click.echo("Loading cached data... ", nl=False)
53
+ token_data = pd.Series(deserialize(token_cache_path))
54
+ label_data = joblib.load(label_cache_path)
55
+ click.echo(DONE_STR)
56
+ else:
57
+ click.echo("Loading dataset... ", nl=False)
58
+ text_data, label_data = load_data(dataset)
59
+ click.echo(DONE_STR)
60
+
61
+ click.echo("Tokenizing data... ")
62
+ token_data = tokenize(text_data, batch_size=batch_size, n_jobs=n_jobs, show_progress=True)
63
+ serialize(token_data, token_cache_path, show_progress=True)
64
+ joblib.dump(label_data, label_cache_path, compress=3)
65
+
66
+ click.echo("Dataset vocabulary size: ", nl=False)
67
+ vocab = token_data.explode().value_counts()
68
+ click.secho(str(len(vocab)), fg="blue")
69
+
70
+ return token_data, label_data
71
+
72
+
73
  @click.group()
74
  def cli() -> None: ...
75
 
 
89
  )
90
  def gui(model_path: Path, share: bool) -> None:
91
  """Launch the Gradio GUI"""
 
 
92
  from app.gui import launch_gui
93
 
94
  os.environ["MODEL_PATH"] = model_path.as_posix()
 
109
 
110
  Note: Piped input takes precedence over the text argument
111
  """
 
 
 
 
112
  from app.model import infer_model
113
 
114
+ # Combine the text arguments into a single string
115
  text = " ".join(text).strip()
116
  if not sys.stdin.isatty():
117
+ # If there is piped input, read it
118
  piped_text = sys.stdin.read().strip()
119
  text = piped_text or text
120
 
 
128
 
129
  click.echo("Performing sentiment analysis... ", nl=False)
130
  prediction = infer_model(model, [text])[0]
 
131
  if prediction == 0:
132
  sentiment = click.style("NEGATIVE", fg="red")
133
  elif prediction == 1:
 
156
  default=5,
157
  help="Number of cross-validation folds",
158
  show_default=True,
159
+ type=click.IntRange(2, 50),
160
  )
161
  @click.option(
162
  "--token-batch-size",
 
191
  force_cache: bool,
192
  ) -> None:
193
  """Evaluate the model on the the specified dataset"""
 
 
 
 
 
 
 
194
  from app.model import evaluate_model
 
 
 
 
 
195
 
196
+ token_data, label_data = _load_dataset(dataset, token_batch_size, token_jobs, force_cache)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  click.echo("Loading model... ", nl=False)
199
  model = joblib.load(model_path)
200
  click.echo(DONE_STR)
201
 
 
 
 
 
 
 
202
  click.echo("Evaluating model... ")
203
  acc_mean, acc_std = evaluate_model(
204
  model,
205
  token_data,
206
  label_data,
207
+ cv=cv,
208
  n_jobs=eval_jobs,
209
  )
210
  click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
 
241
  default=5,
242
  help="Number of cross-validation folds",
243
  show_default=True,
244
+ type=click.IntRange(2, 50),
245
  )
246
  @click.option(
247
  "--token-batch-size",
 
292
  force_cache: bool,
293
  ) -> None:
294
  """Train the model on the provided dataset"""
295
+ from app.constants import MODEL_DIR
 
 
 
 
 
 
296
  from app.model import train_model
 
297
 
298
  model_path = MODEL_DIR / f"{dataset}_{vectorizer}_ft{max_features}.pkl"
299
  if model_path.exists() and not overwrite:
300
  click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
301
 
302
+ token_data, label_data = _load_dataset(dataset, token_batch_size, token_jobs, force_cache)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
 
304
  click.echo("Training model... ")
305
  model, accuracy = train_model(
 
308
  vectorizer=vectorizer,
309
  max_features=max_features,
310
  min_df=min_df,
311
+ cv=cv,
312
  n_jobs=train_jobs,
313
  seed=seed,
314
  )
315
+
316
  click.echo("Model accuracy: ", nl=False)
317
  click.secho(f"{accuracy:.2%}", fg="blue")
318