Add cross validation
- app/cli.py +24 -9
- app/model.py +34 -5
- notebook.ipynb +11 -3
app/cli.py
CHANGED

@@ -90,6 +90,13 @@ def predict(model_path: Path, text: list[str]) -> None:
     show_default=True,
     type=click.IntRange(1, None),
 )
+@click.option(
+    "--cv",
+    default=5,
+    help="Number of cross-validation folds",
+    show_default=True,
+    type=click.IntRange(1, 50),
+)
 @click.option(
     "--seed",
     default=42,
@@ -97,19 +104,26 @@ def predict(model_path: Path, text: list[str]) -> None:
     show_default=True,
     type=click.IntRange(-1, None),
 )
+@click.option(
+    "--force",
+    is_flag=True,
+    help="Overwrite the model file if it already exists",
+)
 def train(
     dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
     max_features: int,
+    cv: int,
     seed: int,
+    force: bool,
 ) -> None:
     """Train the model on the provided dataset"""
     import joblib
 
     from app.constants import MODELS_DIR
-    from app.model import create_model, load_data, train_model
+    from app.model import create_model, evaluate_model, load_data, train_model
 
     model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
-    if model_path.exists():
+    if model_path.exists() and not force:
         click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
 
     click.echo("Preprocessing dataset... ", nl=False)
@@ -122,16 +136,17 @@ def train(
 
     # click.echo("Training model... ", nl=False)
     click.echo("Training model... ")
-    accuracy = train_model(model, text_data, label_data)
-    joblib.dump(model, model_path)
-    click.echo("Model saved to: ", nl=False)
-    click.secho(str(model_path), fg="blue")
-
+    accuracy, text_test, text_label = train_model(model, text_data, label_data)
     click.echo("Model accuracy: ", nl=False)
     click.secho(f"{accuracy:.2%}", fg="blue")
 
-
-
+    click.echo("Model saved to: ", nl=False)
+    joblib.dump(model, model_path)
+    click.secho(str(model_path), fg="blue")
+
+    click.echo("Evaluating model... ", nl=False)
+    acc_mean, acc_std = evaluate_model(model, text_test, text_label, cv=cv)
+    click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
 
 
 def cli_wrapper() -> None:
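The two options added here follow standard click patterns: --cv is a bounded integer and --force is a boolean flag that short-circuits the overwrite prompt. Below is a minimal, self-contained sketch of just that behaviour; the command name, placeholder path, and stubbed body are illustrative, not the app's real train command.

from pathlib import Path

import click


@click.command()
@click.option(
    "--cv",
    default=5,
    help="Number of cross-validation folds",
    show_default=True,
    type=click.IntRange(1, 50),
)
@click.option(
    "--force",
    is_flag=True,
    help="Overwrite the model file if it already exists",
)
def train(cv: int, force: bool) -> None:
    # Placeholder path for the sketch; the real command derives it from dataset/max-features.
    model_path = Path("models/example.pkl")
    if model_path.exists() and not force:
        # Same guard as the diff: prompt unless --force was passed; abort=True exits on "no".
        click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
    click.echo(f"Would train, save, and then cross-validate with cv={cv}")


if __name__ == "__main__":
    train()

The diff also reorders the command body so the model is scored and saved before the cross-validation step, which is the slow part.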
app/model.py
CHANGED

@@ -13,7 +13,7 @@ from nltk.stem import WordNetLemmatizer
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
 from sklearn.linear_model import LogisticRegression
-from sklearn.model_selection import train_test_split
+from sklearn.model_selection import cross_val_score, train_test_split
 from sklearn.pipeline import Pipeline
 
 from app.constants import (
@@ -28,7 +28,7 @@ from app.constants import (
     URL_REGEX,
 )
 
-__all__ = ["load_data", "create_model", "train_model"]
+__all__ = ["load_data", "create_model", "train_model", "evaluate_model"]
 
 
 class TextCleaner(BaseEstimator, TransformerMixin):
@@ -293,7 +293,7 @@ def train_model(
     text_data: list[str],
     label_data: list[int],
     seed: int = 42,
-) -> float:
+) -> tuple[float, list[str], list[int]]:
     """Train the sentiment analysis model.
 
     Args:
@@ -303,7 +303,7 @@ def train_model(
         seed: Random seed (None for random seed)
 
     Returns:
-        Model accuracy
+        Model accuracy and test data
     """
     text_train, text_test, label_train, label_test = train_test_split(
         text_data,
@@ -316,4 +316,33 @@ def train_model(
         warnings.simplefilter("ignore")
     model.fit(text_train, label_train)
 
-    return model.score(text_test, label_test)
+    return model.score(text_test, label_test), text_test, label_test
+
+
+def evaluate_model(
+    model: Pipeline,
+    text_test: list[str],
+    label_test: list[int],
+    cv: int = 5,
+) -> tuple[float, float]:
+    """Evaluate the model using cross-validation.
+
+    Args:
+        model: Trained model
+        text_test: Text data
+        label_test: Label data
+        seed: Random seed (None for random seed)
+        cv: Number of cross-validation folds
+
+    Returns:
+        Mean accuracy and standard deviation
+    """
+    scores = cross_val_score(
+        model,
+        text_test,
+        label_test,
+        cv=cv,
+        scoring="accuracy",
+        n_jobs=-1,
+    )
+    return scores.mean(), scores.std()
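The new evaluate_model is a thin wrapper around scikit-learn's cross_val_score, which clones and refits the pipeline on each fold, so the earlier model.fit inside train_model does not leak into these scores. A small stand-alone sketch of the same call shape follows; the toy corpus and pipeline are invented for illustration, and only the cross_val_score arguments mirror the diff.

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline

# Toy labelled corpus standing in for the held-out test split passed to evaluate_model.
texts = ["great movie", "awful film", "loved it", "hated it", "really great", "truly awful"] * 10
labels = [1, 0, 1, 0, 1, 0] * 10

# Simplified stand-in for the app's TF-IDF + classifier pipeline.
model = Pipeline([("tfidf", TfidfVectorizer()), ("clf", LogisticRegression())])

# Same call shape as evaluate_model: k-fold accuracy, folds run in parallel (n_jobs=-1).
scores = cross_val_score(model, texts, labels, cv=5, scoring="accuracy", n_jobs=-1)
print(f"{scores.mean():.2%} ± {scores.std():.2%}")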
notebook.ipynb
CHANGED

@@ -668,9 +668,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 24,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Fitting 3 folds for each of 10 candidates, totalling 30 fits\n"
+     ]
+    }
+   ],
    "source": [
     "# SVM\n",
     "svm_clf = SVC(random_state=SEED)\n",
@@ -680,7 +688,7 @@
     "    svm_clf,\n",
     "    {\n",
     "        \"C\": np.logspace(-4, 4, 20),\n",
-    "        \"kernel\": [\"linear\", \"poly\", \"rbf\"
+    "        \"kernel\": [\"linear\", \"poly\", \"rbf\"],\n",
     "        \"degree\": [2, 3, 4],\n",
     "    },\n",
     ")\n",
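Only the parameter dictionary for the SVC search is visible in this hunk; the search object wrapping it sits outside the diff. The recorded output ("Fitting 3 folds for each of 10 candidates, totalling 30 fits") is consistent with a randomized search sampling 10 candidates with 3-fold CV, so the sketch below assumes RandomizedSearchCV; SEED and the toy data are likewise stand-ins, not taken from the notebook.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import RandomizedSearchCV
from sklearn.svm import SVC

SEED = 42  # assumed value; the notebook defines its own SEED

# Toy data so the sketch runs on its own; the notebook searches over its real features.
X, y = make_classification(n_samples=200, n_features=20, random_state=SEED)

svm_clf = SVC(random_state=SEED)
search = RandomizedSearchCV(
    svm_clf,
    {
        "C": np.logspace(-4, 4, 20),
        "kernel": ["linear", "poly", "rbf"],
        "degree": [2, 3, 4],
    },
    n_iter=10,  # 10 sampled candidates -> "Fitting 3 folds for each of 10 candidates"
    cv=3,
    random_state=SEED,
    verbose=1,  # prints the "Fitting ... folds" line captured in the notebook output
)
search.fit(X, y)
print(search.best_params_)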