Spaces:
Running
Running
from __future__ import annotations | |
from pathlib import Path | |
from typing import Literal | |
import click | |
# Public API of this module: only the console entry point is exported.
__all__ = [
    "cli_wrapper",
]
# Green "DONE" marker echoed after each step finishes successfully.
DONE_STR = click.style(text="DONE", fg="green")
@click.group()
def cli() -> None:
    """Sentiment analysis command-line interface."""
    # Must be a click group: cli_wrapper calls cli(max_content_width=120),
    # a keyword accepted by click's Command.__call__/main. On a plain
    # zero-argument function that call raised TypeError on every run.
    # NOTE(review): gui/predict/train show no @cli.command() decorators in
    # this view, so as shown the group exposes no subcommands. Confirm how
    # they are registered.
def gui(model_path: Path, share: bool) -> None:
    """Launch the Gradio GUI"""
    # The GUI module is imported only when this function actually runs.
    from app.gui import launch_gui as start_gui

    # Both arguments are forwarded positionally and unchanged to launch_gui.
    start_gui(model_path, share)
def predict(model_path: Path, text: list[str]) -> None:
    """Perform sentiment analysis on the provided text.
    Note: Piped input takes precedence over the text argument
    """
    # Imported lazily, presumably to keep startup cheap for other commands.
    import sys

    import joblib

    # The positional words are joined into a single string. When stdin is
    # not a terminal, its (stripped) contents replace them unless empty.
    joined_args = " ".join(text).strip()
    if sys.stdin.isatty():
        content = joined_args
    else:
        content = sys.stdin.read().strip() or joined_args

    if not content:
        raise click.UsageError("No text provided")

    click.echo("Loading model... ", nl=False)
    # NOTE(review): joblib.load unpickles the file, which can execute
    # arbitrary code. Only load model files from trusted sources.
    classifier = joblib.load(model_path)
    click.echo(DONE_STR)

    click.echo("Performing sentiment analysis... ", nl=False)
    label = classifier.predict([content])[0]
    # Label 0 -> negative, 1 -> positive; any other label is shown as neutral.
    if label == 0:
        name, colour = "NEGATIVE", "red"
    elif label == 1:
        name, colour = "POSITIVE", "green"
    else:
        name, colour = "NEUTRAL", "yellow"
    click.echo(click.style(name, fg=colour))
def train(
    dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
    max_features: int,
    cv: int,
    seed: int,
    force: bool,
) -> None:
    """Train the model on the provided dataset"""
    import joblib

    from app.constants import MODELS_DIR
    from app.model import create_model, evaluate_model, load_data, train_model

    # The file name encodes the dataset and max_features, so runs that differ
    # in either setting write to distinct files.
    model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
    if model_path.exists() and not force:
        # abort=True makes click abort the command if the user declines.
        click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
    # Create the output directory before training. Otherwise a missing
    # directory would only surface in joblib.dump, after the whole
    # (potentially long) training run had already been spent.
    model_path.parent.mkdir(parents=True, exist_ok=True)

    click.echo("Preprocessing dataset... ", nl=False)
    text_data, label_data = load_data(dataset)
    click.echo(DONE_STR)

    click.echo("Creating model... ", nl=False)
    # A seed of -1 is the sentinel for "no fixed seed" (passed on as None).
    model = create_model(max_features, seed=None if seed == -1 else seed, verbose=True)
    click.echo(DONE_STR)

    # Newline kept here: the model was created with verbose=True and
    # presumably prints its own progress while training.
    click.echo("Training model... ")
    # The last two values are presumably the held-out texts and labels.
    accuracy, text_test, label_test = train_model(model, text_data, label_data)
    click.echo("Model accuracy: ", nl=False)
    click.secho(f"{accuracy:.2%}", fg="blue")

    # Write the file first so "Model saved to:" is only printed once the
    # save has actually succeeded.
    joblib.dump(model, model_path)
    click.echo("Model saved to: ", nl=False)
    click.secho(str(model_path), fg="blue")

    click.echo("Evaluating model... ", nl=False)
    acc_mean, acc_std = evaluate_model(model, text_test, label_test, cv=cv)
    click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
def cli_wrapper() -> None:
    """Console entry point: invoke ``cli`` with max_content_width=120.

    In click, max_content_width raises the wrap limit used for help output.
    """
    help_width = 120
    cli(max_content_width=help_width)


if __name__ == "__main__":
    cli_wrapper()