Cache label data along with tokenized text data

- app/cli.py +26 -20
- app/utils.py +2 -3
app/cli.py

@@ -146,32 +146,35 @@ def evaluate(
     from app.model import evaluate_model
     from app.utils import deserialize, serialize
 
-    …
+    token_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
+    label_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_labels.pkl"
     use_cached_data = False
 
-    if …
+    if token_cache_path.exists():
         use_cached_data = force_cache or click.confirm(
             f"Found existing tokenized data for '{dataset}'. Use it?",
             default=True,
         )
 
-    click.echo("Loading dataset... ", nl=False)
-    text_data, label_data = load_data(dataset)
-    click.echo(DONE_STR)
-
     if use_cached_data:
         click.echo("Loading cached data... ", nl=False)
-        token_data = pd.Series(deserialize(…
+        token_data = pd.Series(deserialize(token_cache_path))
+        label_data = joblib.load(label_cache_path)
         click.echo(DONE_STR)
     else:
+        click.echo("Loading dataset... ", nl=False)
+        text_data, label_data = load_data(dataset)
+        click.echo(DONE_STR)
+
         click.echo("Tokenizing data... ")
         token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)
 
         click.echo("Caching tokenized data... ")
-        serialize(token_data, …
+        serialize(token_data, token_cache_path, show_progress=True)
+        joblib.dump(label_data, label_cache_path, compress=3)
 
-    del text_data
-    gc.collect()
+        del text_data
+        gc.collect()
 
     click.echo("Size of vocabulary: ", nl=False)
     vocab = token_data.explode().value_counts()

@@ -281,32 +284,35 @@ def train(
     if model_path.exists() and not overwrite:
         click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
 
-    …
+    token_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
+    label_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_labels.pkl"
     use_cached_data = False
 
-    if …
+    if token_cache_path.exists():
         use_cached_data = force_cache or click.confirm(
             f"Found existing tokenized data for '{dataset}'. Use it?",
             default=True,
         )
 
-    click.echo("Loading dataset... ", nl=False)
-    text_data, label_data = load_data(dataset)
-    click.echo(DONE_STR)
-
     if use_cached_data:
         click.echo("Loading cached data... ", nl=False)
-        token_data = pd.Series(deserialize(…
+        token_data = pd.Series(deserialize(token_cache_path))
+        label_data = joblib.load(label_cache_path)
         click.echo(DONE_STR)
     else:
+        click.echo("Loading dataset... ", nl=False)
+        text_data, label_data = load_data(dataset)
+        click.echo(DONE_STR)
+
         click.echo("Tokenizing data... ")
         token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)
 
         click.echo("Caching tokenized data... ")
-        serialize(token_data, …
+        serialize(token_data, token_cache_path, show_progress=True)
+        joblib.dump(label_data, label_cache_path, compress=3)
 
-    del text_data
-    gc.collect()
+        del text_data
+        gc.collect()
 
     click.echo("Size of vocabulary: ", nl=False)
     vocab = token_data.explode().value_counts()
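The net effect of the cli.py change: labels are cached next to the tokens, so a cached run never calls load_data or tokenize at all. Below is a minimal, self-contained sketch of that round-trip. The cache directory, the dataset name, the inline literals standing in for load_data()/tokenize(), and the plain joblib.dump call (standing in for the app's chunked serialize()) are all assumptions for illustration, not the app's actual configuration.

from pathlib import Path

import joblib
import pandas as pd

# Assumed cache location; in the app this comes from TOKENIZER_CACHE_PATH.
TOKENIZER_CACHE_PATH = Path(".cache")
TOKENIZER_CACHE_PATH.mkdir(exist_ok=True)

dataset = "example"  # hypothetical dataset name
token_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
label_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_labels.pkl"

# First run: tokenize, then cache tokens and labels side by side.
# (Literal stand-ins for load_data()/tokenize(); the app's serialize()
# additionally splits the token data into chunks.)
token_data = pd.Series([["a", "movie"], ["great", "film"]])
label_data = pd.Series([0, 1])
joblib.dump(list(token_data), token_cache_path, compress=3)
joblib.dump(label_data, label_cache_path, compress=3)

# Later runs: when the token cache exists, both tokens and labels come
# straight from disk and the load/tokenize steps are skipped entirely.
if token_cache_path.exists():
    token_data = pd.Series(joblib.load(token_cache_path))
    label_data = joblib.load(label_cache_path)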
app/utils.py

@@ -11,7 +11,7 @@ if TYPE_CHECKING:
 __all__ = ["serialize", "deserialize"]
 
 
-def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_progress: bool = False) -> None:
+def serialize(data: Sequence[str | int], path: Path, max_size: int = 100000, show_progress: bool = False) -> None:
     """Serialize data to a file
 
     Args:
@@ -20,7 +20,6 @@ def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_prog
         max_size: The maximum size a chunk can be (in elements)
         show_progress: Whether to show a progress bar
     """
-    # first file is path, next chunks have ".1", ".2", etc. appended
     for i, chunk in enumerate(
         tqdm(
             [data[i : i + max_size] for i in range(0, len(data), max_size)],
@@ -33,7 +32,7 @@ def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_prog
         joblib.dump(chunk, f, compress=3)
 
 
-def deserialize(path: Path) -> Sequence[str]:
+def deserialize(path: Path) -> Sequence[str | int]:
     """Deserialize data from a file
 
     Args:
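Per the comment the diff removes, serialize() splits data into max_size-element chunks, writing the first chunk to path and later chunks to path.1, path.2, and so on. The body of deserialize() is not shown in this diff, so the reader below is only a sketch of how that chunk layout could be walked back; read_chunks is a hypothetical name, not the app's function.

from pathlib import Path

import joblib


def read_chunks(path: Path) -> list:
    """Concatenate joblib chunks written as `path`, `path.1`, `path.2`, ..."""
    data: list = []
    chunk_path = path
    n = 0
    while chunk_path.exists():
        data.extend(joblib.load(chunk_path))
        n += 1
        # Next chunk keeps the full original name plus a numeric suffix.
        chunk_path = path.with_name(f"{path.name}.{n}")
    return data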