Tymec's picture
Swap test dataset
183f8cd
raw
history blame
9.65 kB
from __future__ import annotations
from pathlib import Path
from typing import Literal
import click
# Public API of this module: only the entry-point wrapper is exported.
__all__ = ["cli_wrapper"]
# Green "DONE" marker echoed by the commands after a long-running step completes.
DONE_STR = click.style("DONE", fg="green")
@click.group()
def cli() -> None:
    # Root command group: the subcommands below attach themselves to it with
    # ``@cli.command()``. The body is intentionally empty.
    # No docstring on purpose -- click would display it as the group's help text.
    pass
@cli.command()
@click.option(
    "--model",
    "model_path",
    type=click.Path(
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        resolve_path=True,
        path_type=Path,
    ),
    required=True,
    help="Path to the trained model",
)
@click.option(
    "--share/--no-share",
    help="Whether to create a shareable link",
    default=False,
)
def gui(model_path: Path, share: bool) -> None:
    """Launch the Gradio GUI"""
    # The docstring above doubles as the ``gui --help`` text, so developer notes
    # are kept in comments.
    #
    # model_path: existing, readable model file (validated by ``click.Path``).
    # share:      forwarded unchanged to ``launch_gui``.
    #
    # Imports are local so that other subcommands do not pay for them.
    import os

    from app.gui import launch_gui

    # The model location is handed over through the process environment.
    # NOTE(review): MODEL_PATH is exported *after* ``app.gui`` is imported; if
    # that module reads the variable at import time this ordering matters -- confirm.
    posix_model_path = model_path.as_posix()
    os.environ["MODEL_PATH"] = posix_model_path

    launch_gui(share)
@cli.command()
@click.option(
    "--model",
    "model_path",
    required=True,
    help="Path to the trained model",
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
)
@click.argument("text", nargs=-1)
def predict(model_path: Path, text: tuple[str, ...]) -> None:
    """Perform sentiment analysis on the provided text.

    Note: Piped input takes precedence over the text argument
    """
    # The docstring above is rendered by click as the ``predict --help`` text,
    # so developer documentation is kept in comments.
    #
    # model_path: existing, readable model file (validated by ``click.Path``).
    # text:       positional words; with ``nargs=-1`` click delivers them as a
    #             tuple, which is joined with single spaces below.
    # Raises:     click.UsageError when neither stdin nor the arguments
    #             provide any non-blank text.
    import sys

    import joblib

    from app.model import infer_model

    input_text = " ".join(text).strip()

    # ``sys.stdin`` can be None when the process has no standard input attached;
    # only read from it when it exists and is not an interactive terminal.
    stdin = sys.stdin
    if stdin is not None and not stdin.isatty():
        piped_text = stdin.read().strip()
        # Piped input wins, but an empty pipe falls back to the arguments.
        input_text = piped_text or input_text

    if not input_text:
        msg = "No text provided"
        raise click.UsageError(msg)

    click.echo("Loading model... ", nl=False)
    # SECURITY: joblib.load unpickles the file and can execute arbitrary code;
    # only point --model at files from a trusted source.
    model = joblib.load(model_path)
    click.echo(DONE_STR)

    click.echo("Performing sentiment analysis... ", nl=False)
    prediction = infer_model(model, [input_text])[0]

    # 0 and 1 map to negative/positive; any other value is reported as neutral.
    if prediction == 0:
        sentiment = click.style("NEGATIVE", fg="red")
    elif prediction == 1:
        sentiment = click.style("POSITIVE", fg="green")
    else:
        sentiment = click.style("NEUTRAL", fg="yellow")
    click.echo(sentiment)
@cli.command()
@click.option(
    "--dataset",
    default="test",
    help="Dataset to evaluate the model on",
    type=click.Choice(["test", "sentiment140", "amazonreviews", "imdb50k"]),
)
@click.option(
    "--model",
    "model_path",
    required=True,
    help="Path to the trained model",
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
)
@click.option(
    "--cv",
    default=5,
    help="Number of cross-validation folds",
    show_default=True,
    type=click.IntRange(1, 50),
)
@click.option(
    "--token-batch-size",
    default=512,
    help="Size of the batches used in tokenization",
    show_default=True,
)
@click.option(
    "--token-jobs",
    default=4,
    help="Number of parallel jobs to run for tokenization",
    show_default=True,
)
@click.option(
    "--eval-jobs",
    default=1,
    help="Number of parallel jobs to run for evaluation",
    show_default=True,
)
@click.option(
    "--force-cache",
    is_flag=True,
    help="Always use the cached tokenized data (if available)",
)
def evaluate(
    dataset: Literal["test", "sentiment140", "amazonreviews", "imdb50k"],
    model_path: Path,
    cv: int,
    token_batch_size: int,
    token_jobs: int,
    eval_jobs: int,
    force_cache: bool,
) -> None:
    """Evaluate the model on the specified dataset"""
    # The docstring above is rendered by click as the ``evaluate --help`` text,
    # so developer documentation is kept in comments.
    #
    # dataset:          name handed to ``load_data``; also keys the cache files.
    # model_path:       existing, readable model file (validated by ``click.Path``).
    # cv:               number of folds; 1 skips cross-validation and scores the
    #                   model once on the whole dataset.
    # token_batch_size, token_jobs: forwarded to ``tokenize``.
    # eval_jobs:        forwarded to ``evaluate_model`` as ``n_jobs``.
    # force_cache:      use cached tokenized data without prompting, if present.
    import gc

    import joblib
    import pandas as pd

    from app.constants import TOKENIZER_CACHE_DIR
    from app.data import load_data, tokenize
    from app.model import evaluate_model
    from app.utils import deserialize, serialize

    token_cache_path = TOKENIZER_CACHE_DIR / f"{dataset}_tokenized.pkl"
    label_cache_path = TOKENIZER_CACHE_DIR / f"{dataset}_labels.pkl"

    # Both cache files are loaded together, so the cache is only usable when
    # both exist; a lone token cache would otherwise crash on the label load.
    use_cached_data = False
    if token_cache_path.exists() and label_cache_path.exists():
        use_cached_data = force_cache or click.confirm(
            f"Found existing tokenized data for '{dataset}'. Use it?",
            default=True,
        )

    if use_cached_data:
        click.echo("Loading cached data... ", nl=False)
        token_data = pd.Series(deserialize(token_cache_path))
        label_data = joblib.load(label_cache_path)
        click.echo(DONE_STR)
    else:
        click.echo("Loading dataset... ", nl=False)
        text_data, label_data = load_data(dataset)
        click.echo(DONE_STR)

        click.echo("Tokenizing data... ")
        token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)

        # Persist tokens and labels so later runs can skip tokenization.
        serialize(token_data, token_cache_path, show_progress=True)
        joblib.dump(label_data, label_cache_path, compress=3)

        # The raw text is no longer needed once tokenized; release it early.
        del text_data
        gc.collect()

    click.echo("Size of vocabulary: ", nl=False)
    # Flatten the per-document entries and count the distinct values.
    vocab = token_data.explode().value_counts()
    click.secho(str(len(vocab)), fg="blue")

    click.echo("Loading model... ", nl=False)
    # SECURITY: joblib.load unpickles the file and can execute arbitrary code;
    # only point --model at files from a trusted source.
    model = joblib.load(model_path)
    click.echo(DONE_STR)

    if cv == 1:
        # Single pass: report the model's own score on the full dataset.
        click.echo("Evaluating model... ", nl=False)
        acc = model.score(token_data, label_data)
        click.secho(f"{acc:.2%}", fg="blue")
        return

    click.echo("Evaluating model... ")
    acc_mean, acc_std = evaluate_model(
        model,
        token_data,
        label_data,
        folds=cv,
        n_jobs=eval_jobs,
    )
    click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
@cli.command()
@click.option(
    "--dataset",
    required=True,
    help="Dataset to train the model on",
    type=click.Choice(["sentiment140", "amazonreviews", "imdb50k"]),
)
@click.option(
    "--vectorizer",
    default="tfidf",
    help="Vectorizer to use",
    type=click.Choice(["tfidf", "count", "hashing"]),
)
@click.option(
    "--max-features",
    default=20000,
    help="Maximum number of features (should be greater than 2^15 when using hashing vectorizer)",
    show_default=True,
    type=click.IntRange(1, None),
)
@click.option(
    "--min-df",
    default=5,
    help="Minimum document frequency for the features (ignored for hashing)",
    show_default=True,
)
@click.option(
    "--cv",
    default=5,
    help="Number of cross-validation folds",
    show_default=True,
    type=click.IntRange(1, 50),
)
@click.option(
    "--token-batch-size",
    default=512,
    help="Size of the batches used in tokenization",
    show_default=True,
)
@click.option(
    "--token-jobs",
    default=4,
    help="Number of parallel jobs to run for tokenization",
    show_default=True,
)
@click.option(
    "--train-jobs",
    default=1,
    help="Number of parallel jobs to run for training",
    show_default=True,
)
@click.option(
    "--seed",
    default=42,
    help="Random seed (-1 for random seed)",
    show_default=True,
    type=click.IntRange(-1, None),
)
@click.option(
    "--overwrite",
    is_flag=True,
    help="Overwrite the model file if it already exists",
)
@click.option(
    "--force-cache",
    is_flag=True,
    help="Always use the cached tokenized data (if available)",
)
def train(
    dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
    vectorizer: Literal["tfidf", "count", "hashing"],
    max_features: int,
    min_df: int,
    cv: int,
    token_batch_size: int,
    token_jobs: int,
    train_jobs: int,
    seed: int,
    overwrite: bool,
    force_cache: bool,
) -> None:
    """Train the model on the provided dataset"""
    # The docstring above is rendered by click as the ``train --help`` text,
    # so developer documentation is kept in comments.
    #
    # dataset:      name handed to ``load_data``; also keys the cache files and
    #               the output model file name.
    # vectorizer, max_features, min_df, seed: forwarded to ``train_model``.
    # cv:           forwarded to ``train_model`` as ``folds``.
    # token_batch_size, token_jobs: forwarded to ``tokenize``.
    # train_jobs:   forwarded to ``train_model`` as ``n_jobs``.
    # overwrite:    skip the confirmation prompt when the model file exists.
    # force_cache:  use cached tokenized data without prompting, if present.
    # Aborts (click.Abort) when the user declines to overwrite an existing model.
    import gc

    import joblib
    import pandas as pd

    from app.constants import MODEL_DIR, TOKENIZER_CACHE_DIR
    from app.data import load_data, tokenize
    from app.model import train_model
    from app.utils import deserialize, serialize

    model_path = MODEL_DIR / f"{dataset}_{vectorizer}_ft{max_features}.pkl"
    if model_path.exists() and not overwrite:
        click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)

    # Make sure the output directory exists up front, so a missing directory
    # cannot discard the trained model when it is finally written.
    model_path.parent.mkdir(parents=True, exist_ok=True)

    token_cache_path = TOKENIZER_CACHE_DIR / f"{dataset}_tokenized.pkl"
    label_cache_path = TOKENIZER_CACHE_DIR / f"{dataset}_labels.pkl"

    # Both cache files are loaded together, so the cache is only usable when
    # both exist; a lone token cache would otherwise crash on the label load.
    use_cached_data = False
    if token_cache_path.exists() and label_cache_path.exists():
        use_cached_data = force_cache or click.confirm(
            f"Found existing tokenized data for '{dataset}'. Use it?",
            default=True,
        )

    if use_cached_data:
        click.echo("Loading cached data... ", nl=False)
        token_data = pd.Series(deserialize(token_cache_path))
        label_data = joblib.load(label_cache_path)
        click.echo(DONE_STR)
    else:
        click.echo("Loading dataset... ", nl=False)
        text_data, label_data = load_data(dataset)
        click.echo(DONE_STR)

        click.echo("Tokenizing data... ")
        token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)

        # Persist tokens and labels so later runs can skip tokenization.
        serialize(token_data, token_cache_path, show_progress=True)
        joblib.dump(label_data, label_cache_path, compress=3)

        # The raw text is no longer needed once tokenized; release it early.
        del text_data
        gc.collect()

    click.echo("Size of vocabulary: ", nl=False)
    # Flatten the per-document entries and count the distinct values.
    vocab = token_data.explode().value_counts()
    click.secho(str(len(vocab)), fg="blue")

    click.echo("Training model... ")
    model, accuracy = train_model(
        token_data,
        label_data,
        vectorizer=vectorizer,
        max_features=max_features,
        min_df=min_df,
        folds=cv,
        n_jobs=train_jobs,
        seed=seed,
    )

    click.echo("Model accuracy: ", nl=False)
    click.secho(f"{accuracy:.2%}", fg="blue")

    # The label is printed first; the path follows once the dump has succeeded.
    click.echo("Model saved to: ", nl=False)
    joblib.dump(model, model_path, compress=3)
    click.secho(str(model_path), fg="blue")
def cli_wrapper() -> None:
    """Run the command-line interface.

    Invokes the ``cli`` click group with ``max_content_width=120``.
    """
    group_kwargs = {"max_content_width": 120}
    cli(**group_kwargs)
# Run the CLI when this file is executed directly as a script.
if __name__ == "__main__":
    cli_wrapper()