Tymec commited on
Commit
cdf1241
1 Parent(s): a092d54

Add evaluate command

Browse files
Files changed (1) hide show
  1. app/cli.py +50 -5
app/cli.py CHANGED
@@ -76,6 +76,51 @@ def predict(model_path: Path, text: list[str]) -> None:
76
  click.echo(sentiment)
77
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  @cli.command()
80
  @click.option(
81
  "--dataset",
@@ -120,13 +165,14 @@ def train(
120
  import joblib
121
 
122
  from app.constants import MODELS_DIR
123
- from app.model import create_model, evaluate_model, load_data, train_model
 
124
 
125
  model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
126
  if model_path.exists() and not force:
127
  click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
128
 
129
- click.echo("Preprocessing dataset... ", nl=False)
130
  text_data, label_data = load_data(dataset)
131
  click.echo(DONE_STR)
132
 
@@ -134,9 +180,8 @@ def train(
134
  model = create_model(max_features, seed=None if seed == -1 else seed, verbose=True)
135
  click.echo(DONE_STR)
136
 
137
- # click.echo("Training model... ", nl=False)
138
  click.echo("Training model... ")
139
- accuracy, text_test, text_label = train_model(model, text_data, label_data)
140
  click.echo("Model accuracy: ", nl=False)
141
  click.secho(f"{accuracy:.2%}", fg="blue")
142
 
@@ -145,7 +190,7 @@ def train(
145
  click.secho(str(model_path), fg="blue")
146
 
147
  click.echo("Evaluating model... ", nl=False)
148
- acc_mean, acc_std = evaluate_model(model, text_test, text_label, cv=cv)
149
  click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
150
 
151
 
 
76
  click.echo(sentiment)
77
 
78
 
79
+ @cli.command()
80
+ @click.option(
81
+ "--dataset",
82
+ required=True,
83
+ help="Dataset to train the model on",
84
+ type=click.Choice(["sentiment140", "amazonreviews", "imdb50k"]),
85
+ )
86
+ @click.option(
87
+ "--model",
88
+ "model_path",
89
+ required=True,
90
+ help="Path to the trained model",
91
+ type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
92
+ )
93
+ @click.option(
94
+ "--cv",
95
+ default=5,
96
+ help="Number of cross-validation folds",
97
+ show_default=True,
98
+ type=click.IntRange(1, 50),
99
+ )
100
+ def evaluate(
101
+ dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
102
+ model_path: Path,
103
+ cv: int,
104
+ ) -> None:
105
+ """Evaluate the model on the test dataset"""
106
+ import joblib
107
+
108
+ from app.data import load_data
109
+ from app.model import evaluate_model
110
+
111
+ click.echo("Loading dataset... ", nl=False)
112
+ text_data, label_data = load_data(dataset)
113
+ click.echo(DONE_STR)
114
+
115
+ click.echo("Loading model... ", nl=False)
116
+ model = joblib.load(model_path)
117
+ click.echo(DONE_STR)
118
+
119
+ click.echo("Evaluating model... ", nl=False)
120
+ acc_mean, acc_std = evaluate_model(model, text_data, label_data, folds=cv)
121
+ click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
122
+
123
+
124
  @cli.command()
125
  @click.option(
126
  "--dataset",
 
165
  import joblib
166
 
167
  from app.constants import MODELS_DIR
168
+ from app.data import load_data
169
+ from app.model import create_model, evaluate_model, train_model
170
 
171
  model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
172
  if model_path.exists() and not force:
173
  click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
174
 
175
+ click.echo("Loading dataset... ", nl=False)
176
  text_data, label_data = load_data(dataset)
177
  click.echo(DONE_STR)
178
 
 
180
  model = create_model(max_features, seed=None if seed == -1 else seed, verbose=True)
181
  click.echo(DONE_STR)
182
 
 
183
  click.echo("Training model... ")
184
+ accuracy = train_model(model, text_data, label_data)
185
  click.echo("Model accuracy: ", nl=False)
186
  click.secho(f"{accuracy:.2%}", fg="blue")
187
 
 
190
  click.secho(str(model_path), fg="blue")
191
 
192
  click.echo("Evaluating model... ", nl=False)
193
+ acc_mean, acc_std = evaluate_model(model, text_data, label_data, folds=cv)
194
  click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
195
 
196