Tymec committed
Commit af84d9b
1 Parent(s): e3095cd

Cache label data along with tokenized text data

Files changed (2)
  1. app/cli.py +26 -20
  2. app/utils.py +2 -3
app/cli.py CHANGED
@@ -146,32 +146,35 @@ def evaluate(
     from app.model import evaluate_model
     from app.utils import deserialize, serialize
 
-    cached_data_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
+    token_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
+    label_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_labels.pkl"
     use_cached_data = False
 
-    if cached_data_path.exists():
+    if token_cache_path.exists():
         use_cached_data = force_cache or click.confirm(
             f"Found existing tokenized data for '{dataset}'. Use it?",
             default=True,
         )
 
-    click.echo("Loading dataset... ", nl=False)
-    text_data, label_data = load_data(dataset)
-    click.echo(DONE_STR)
-
     if use_cached_data:
         click.echo("Loading cached data... ", nl=False)
-        token_data = pd.Series(deserialize(cached_data_path))
+        token_data = pd.Series(deserialize(token_cache_path))
+        label_data = joblib.load(label_cache_path)
         click.echo(DONE_STR)
     else:
+        click.echo("Loading dataset... ", nl=False)
+        text_data, label_data = load_data(dataset)
+        click.echo(DONE_STR)
+
         click.echo("Tokenizing data... ")
         token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)
 
         click.echo("Caching tokenized data... ")
-        serialize(token_data, cached_data_path, show_progress=True)
+        serialize(token_data, token_cache_path, show_progress=True)
+        joblib.dump(label_data, label_cache_path, compress=3)
 
-    del text_data
-    gc.collect()
+        del text_data
+        gc.collect()
 
     click.echo("Size of vocabulary: ", nl=False)
     vocab = token_data.explode().value_counts()
@@ -281,32 +284,35 @@ def train(
     if model_path.exists() and not overwrite:
         click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
 
-    cached_data_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
+    token_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
+    label_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_labels.pkl"
     use_cached_data = False
 
-    if cached_data_path.exists():
+    if token_cache_path.exists():
         use_cached_data = force_cache or click.confirm(
             f"Found existing tokenized data for '{dataset}'. Use it?",
             default=True,
         )
 
-    click.echo("Loading dataset... ", nl=False)
-    text_data, label_data = load_data(dataset)
-    click.echo(DONE_STR)
-
     if use_cached_data:
         click.echo("Loading cached data... ", nl=False)
-        token_data = pd.Series(deserialize(cached_data_path))
+        token_data = pd.Series(deserialize(token_cache_path))
+        label_data = joblib.load(label_cache_path)
         click.echo(DONE_STR)
     else:
+        click.echo("Loading dataset... ", nl=False)
+        text_data, label_data = load_data(dataset)
+        click.echo(DONE_STR)
+
        click.echo("Tokenizing data... ")
        token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)
 
        click.echo("Caching tokenized data... ")
-        serialize(token_data, cached_data_path, show_progress=True)
+        serialize(token_data, token_cache_path, show_progress=True)
+        joblib.dump(label_data, label_cache_path, compress=3)
 
-    del text_data
-    gc.collect()
+        del text_data
+        gc.collect()
 
     click.echo("Size of vocabulary: ", nl=False)
     vocab = token_data.explode().value_counts()
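
In both `evaluate` and `train`, the commit moves dataset loading behind the cache check and stores labels next to the tokens, so a warm cache never calls `load_data` at all. The pattern is easier to see outside the CLI; the following is a minimal, self-contained sketch of it (the toy text, the whitespace tokenizer, and the `CACHE_DIR` name are placeholders, not the repo's `load_data`, `tokenize`, or `TOKENIZER_CACHE_PATH`):

from pathlib import Path

import joblib
import pandas as pd

CACHE_DIR = Path("cache")  # placeholder for the repo's TOKENIZER_CACHE_PATH
CACHE_DIR.mkdir(exist_ok=True)


def get_tokens_and_labels(dataset: str) -> tuple[pd.Series, list[int]]:
    token_cache = CACHE_DIR / f"{dataset}_tokenized.pkl"
    label_cache = CACHE_DIR / f"{dataset}_labels.pkl"

    if token_cache.exists() and label_cache.exists():
        # cache hit: restore both artifacts, never touch the raw dataset
        return pd.Series(joblib.load(token_cache)), joblib.load(label_cache)

    # cache miss: load and tokenize, then persist tokens *and* labels together
    text_data = ["good movie", "bad movie"]                 # toy stand-in for load_data(dataset)
    label_data = [1, 0]
    token_data = pd.Series([t.split() for t in text_data])  # toy stand-in for tokenize(text_data)

    joblib.dump(token_data.tolist(), token_cache, compress=3)
    joblib.dump(label_data, label_cache, compress=3)
    return token_data, label_data


tokens, labels = get_tokens_and_labels("demo")  # a second call loads both files from CACHE_DIR

Run twice, the second call comes straight from the two pickle files, which is the behaviour the commit adds for labels.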
app/utils.py CHANGED
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
 __all__ = ["serialize", "deserialize"]
 
 
-def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_progress: bool = False) -> None:
+def serialize(data: Sequence[str | int], path: Path, max_size: int = 100000, show_progress: bool = False) -> None:
     """Serialize data to a file
 
     Args:
@@ -20,7 +20,6 @@ def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_prog
         max_size: The maximum size a chunk can be (in elements)
         show_progress: Whether to show a progress bar
     """
-    # first file is path, next chunks have ".1", ".2", etc. appended
     for i, chunk in enumerate(
         tqdm(
             [data[i : i + max_size] for i in range(0, len(data), max_size)],
@@ -33,7 +32,7 @@ def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_prog
             joblib.dump(chunk, f, compress=3)
 
 
-def deserialize(path: Path) -> Sequence[str]:
+def deserialize(path: Path) -> Sequence[str | int]:
     """Deserialize data from a file
 
     Args:
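
The utils change only widens the accepted element type to `str | int`, but the chunking behaviour that the deleted comment documented (first chunk written to `path`, later chunks to `path.1`, `path.2`, ...) is worth keeping in mind when inspecting or clearing the cache by hand. Below is a standalone sketch of that naming scheme using joblib directly; it is an illustration under those assumptions, not the repo's actual `serialize`/`deserialize` bodies:

from pathlib import Path

import joblib


def dump_chunked(data: list[str | int], path: Path, max_size: int = 3) -> None:
    # chunk 0 is written to `path`, later chunks to `path.1`, `path.2`, ...
    for i, start in enumerate(range(0, len(data), max_size)):
        chunk_path = path if i == 0 else path.with_name(f"{path.name}.{i}")
        joblib.dump(data[start : start + max_size], chunk_path, compress=3)


def load_chunked(path: Path) -> list[str | int]:
    # read chunks back in order until the next numbered file is missing
    data: list[str | int] = []
    i = 0
    while True:
        chunk_path = path if i == 0 else path.with_name(f"{path.name}.{i}")
        if not chunk_path.exists():
            return data
        data.extend(joblib.load(chunk_path))
        i += 1


items = ["a", "b", "c", 1, 2, 3, "d"]  # mixed str/int, matching the widened type hints
dump_chunked(items, Path("demo.pkl"))
assert load_chunked(Path("demo.pkl")) == items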