Tymec committed on
Commit afaacd1
1 Parent(s): 18cc46a

Chunked serialization

Files changed (3)
  1. Makefile +0 -20
  2. app/cli.py +37 -21
  3. app/utils.py +44 -0
Makefile DELETED
@@ -1,20 +0,0 @@
-#!/usr/bin/make -f
-
-default: install
-
-install:
-	@poetry install --only main
-	@poetry run spacy download en_core_web_sm
-
-install-dev:
-	@poetry self add poetry-plugin-export
-	@poetry install
-
-requirements:
-	@poetry export -f requirements.txt --output requirements.txt --without dev
-	@poetry export -f requirements.txt --output requirements-dev.txt
-
-lint:
-	@poetry run pre-commit run --all-files
-
-.PHONY: install install-dev requirements gradio lint run
app/cli.py CHANGED
@@ -136,28 +136,34 @@ def evaluate(
     from app.constants import CACHE_DIR
     from app.data import load_data, tokenize
     from app.model import evaluate_model
+    from app.utils import deserialize, serialize

     cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
     use_cached_data = False
     if cached_data_path.exists():
         use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)

+    click.echo("Loading dataset... ", nl=False)
+    text_data, label_data = load_data(dataset)
+    click.echo(DONE_STR)
+
     if use_cached_data:
         click.echo("Loading cached data... ", nl=False)
-        token_data, label_data = joblib.load(cached_data_path)
+        # token_data = joblib.load(cached_data_path)
+        token_data = deserialize(cached_data_path)
         click.echo(DONE_STR)
     else:
-        click.echo("Loading dataset... ", nl=False)
-        text_data, label_data = load_data(dataset)
-        click.echo(DONE_STR)
-
         click.echo("Tokenizing data... ", nl=False)
         token_data = tokenize(text_data, batch_size=batch_size, n_jobs=processes, show_progress=True)
-        joblib.dump((token_data, label_data), cached_data_path, compress=3)
         click.echo(DONE_STR)

-        del text_data
-        gc.collect()
+        click.echo("Caching tokenized data... ", nl=False)
+        # joblib.dump(token_data, cached_data_path, compress=3)
+        serialize(token_data, cached_data_path)
+        click.echo(DONE_STR)
+
+    del text_data
+    gc.collect()

     click.echo("Loading model... ", nl=False)
     model = joblib.load(model_path)
@@ -221,9 +227,9 @@ def evaluate(
     help="Overwrite the model file if it already exists",
 )
 @click.option(
-    "--skip-cache",
+    "--force-cache",
     is_flag=True,
-    help="Ignore cached tokenized data",
+    help="Always use the cached tokenized data (if available)",
 )
 @click.option(
     "--verbose",
@@ -238,7 +244,7 @@ def train(
     processes: int,
     seed: int,
     overwrite: bool,
-    skip_cache: bool,
+    force_cache: bool,
     verbose: bool,
 ) -> None:
     """Train the model on the provided dataset"""
@@ -249,6 +255,7 @@
     from app.constants import CACHE_DIR, MODELS_DIR
     from app.data import load_data, tokenize
     from app.model import train_model
+    from app.utils import deserialize, serialize

     model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
     if model_path.exists() and not overwrite:
@@ -256,25 +263,34 @@

     cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
     use_cached_data = False
-    if cached_data_path.exists() and not skip_cache:
-        use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)
+
+    if cached_data_path.exists():
+        use_cached_data = force_cache or click.confirm(
+            f"Found existing tokenized data for '{dataset}'. Use it?",
+            default=True,
+        )
+
+    click.echo("Loading dataset... ", nl=False)
+    text_data, label_data = load_data(dataset)
+    click.echo(DONE_STR)

     if use_cached_data:
         click.echo("Loading cached data... ", nl=False)
-        token_data, label_data = joblib.load(cached_data_path)
+        # token_data = joblib.load(cached_data_path)
+        token_data = deserialize(cached_data_path)
         click.echo(DONE_STR)
     else:
-        click.echo("Loading dataset... ", nl=False)
-        text_data, label_data = load_data(dataset)
-        click.echo(DONE_STR)
-
         click.echo("Tokenizing data... ", nl=False)
         token_data = tokenize(text_data, batch_size=batch_size, n_jobs=processes, show_progress=True)
-        joblib.dump((token_data, label_data), cached_data_path, compress=3)
         click.echo(DONE_STR)

-        del text_data
-        gc.collect()
+        click.echo("Caching tokenized data... ", nl=False)
+        # joblib.dump(token_data, cached_data_path, compress=3)
+        serialize(token_data, cached_data_path)
+        click.echo(DONE_STR)
+
+    del text_data
+    gc.collect()

     click.echo("Training model... ")
     model, accuracy = train_model(
app/utils.py ADDED
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import joblib
+from tqdm import tqdm
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+__all__ = ["serialize", "deserialize"]
+
+
+def serialize(data: list[list[str]], path: Path, max_size: int = 400) -> None:
+    """Serialize data to a file
+
+    Args:
+        data: The data to serialize
+        path: The path to save the serialized data
+        max_size: The maximum size a chunk can be (in elements)
+    """
+    # first file is path, next chunks have ".1", ".2", etc. appended
+    for i, chunk in enumerate(tqdm([data[i : i + max_size] for i in range(0, len(data), max_size)])):
+        fd = path.with_suffix(f".{i}.pkl" if i else ".pkl")
+        with fd.open("wb") as f:
+            joblib.dump(chunk, f, compress=3)
+
+
+def deserialize(path: Path) -> list[list[str]]:
+    """Deserialize data from a file
+
+    Args:
+        path: The path to the serialized data
+
+    Returns:
+        The deserialized data
+    """
+    data = []
+    i = 0
+    while (fd := path.with_suffix(f".{i}.pkl" if i else ".pkl")).exists():
+        with fd.open("rb") as f:
+            data.extend(joblib.load(f))
+        i += 1
+    return data
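
For context, a minimal round-trip sketch of the new helpers. It assumes the repository root is on the import path so app.utils resolves; the cache location, document count, and max_size value below are illustrative, not taken from this commit.

from pathlib import Path

from app.utils import deserialize, serialize

# Illustrative data: 1000 tokenized documents of 5 tokens each.
tokens = [[f"tok{j}" for j in range(5)] for _ in range(1000)]

cache_path = Path("cache/example_tokenized.pkl")  # hypothetical cache location
cache_path.parent.mkdir(parents=True, exist_ok=True)

# With max_size=400, the 1000 documents land in three chunk files:
# example_tokenized.pkl, example_tokenized.1.pkl, example_tokenized.2.pkl
serialize(tokens, cache_path, max_size=400)

# deserialize() walks the numbered chunks until one is missing and
# concatenates them back into a single list.
restored = deserialize(cache_path)
assert restored == tokens

One consequence of the naming scheme visible in serialize: it does not remove higher-numbered chunk files left over from an earlier, larger cache, so deserialize would append them; clearing the old cache entry before re-serializing a smaller dataset avoids that.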