Spaces:
Running
Running
Use stopwords from NLTK and download NLTK data
Browse files
- app/cli.py +6 -4
- app/model.py +16 -1
app/cli.py
CHANGED
@@ -117,15 +117,17 @@ def train(
|
|
117 |
click.echo(DONE_STR)
|
118 |
|
119 |
click.echo("Creating model... ", nl=False)
|
120 |
-
model = create_model(max_features, seed=None if seed == -1 else seed)
|
121 |
click.echo(DONE_STR)
|
122 |
|
123 |
-
click.echo("Training model... ", nl=False)
|
|
|
124 |
accuracy = train_model(model, text_data, label_data)
|
125 |
joblib.dump(model, model_path)
|
126 |
-
click.echo(
|
|
|
127 |
|
128 |
-
click.echo("Model accuracy: ")
|
129 |
click.secho(f"{accuracy:.2%}", fg="blue")
|
130 |
|
131 |
# TODO: Add hyperparameter options
|
|
|
117 |
click.echo(DONE_STR)
|
118 |
|
119 |
click.echo("Creating model... ", nl=False)
|
120 |
+
model = create_model(max_features, seed=None if seed == -1 else seed, verbose=True)
|
121 |
click.echo(DONE_STR)
|
122 |
|
123 |
+
# click.echo("Training model... ", nl=False)
|
124 |
+
click.echo("Training model... ")
|
125 |
accuracy = train_model(model, text_data, label_data)
|
126 |
joblib.dump(model, model_path)
|
127 |
+
click.echo("Model saved to: ", nl=False)
|
128 |
+
click.secho(str(model_path), fg="blue")
|
129 |
|
130 |
+
click.echo("Model accuracy: ", nl=False)
|
131 |
click.secho(f"{accuracy:.2%}", fg="blue")
|
132 |
|
133 |
# TODO: Add hyperparameter options
|
app/model.py
CHANGED
@@ -5,8 +5,10 @@ import re
|
|
5 |
import warnings
|
6 |
from typing import Literal
|
7 |
|
|
|
8 |
import pandas as pd
|
9 |
from joblib import Memory
|
|
|
10 |
from nltk.stem import WordNetLemmatizer
|
11 |
from sklearn.base import BaseEstimator, TransformerMixin
|
12 |
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
|
@@ -248,28 +250,41 @@ def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> t
|
|
248 |
def create_model(
|
249 |
max_features: int,
|
250 |
seed: int | None = None,
|
|
|
251 |
) -> Pipeline:
|
252 |
"""Create a sentiment analysis model.
|
253 |
|
254 |
Args:
|
255 |
max_features: Maximum number of features
|
256 |
seed: Random seed (None for random seed)
|
|
|
257 |
|
258 |
Returns:
|
259 |
Untrained model
|
260 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
return Pipeline(
|
262 |
[
|
263 |
# Text preprocessing
|
264 |
("clean", TextCleaner()),
|
265 |
("lemma", TextLemmatizer()),
|
266 |
# Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity)
|
267 |
-
(
|
|
|
|
|
|
|
268 |
("tfidf", TfidfTransformer()),
|
269 |
# Classifier
|
270 |
("clf", LogisticRegression(max_iter=1000, random_state=seed)),
|
271 |
],
|
272 |
memory=Memory(CACHE_DIR, verbose=0),
|
|
|
273 |
)
|
274 |
|
275 |
|
|
|
5 |
import warnings
|
6 |
from typing import Literal
|
7 |
|
8 |
+
import nltk
|
9 |
import pandas as pd
|
10 |
from joblib import Memory
|
11 |
+
from nltk.corpus import stopwords
|
12 |
from nltk.stem import WordNetLemmatizer
|
13 |
from sklearn.base import BaseEstimator, TransformerMixin
|
14 |
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
|
|
|
250 |
def create_model(
|
251 |
max_features: int,
|
252 |
seed: int | None = None,
|
253 |
+
verbose: bool = False,
|
254 |
) -> Pipeline:
|
255 |
"""Create a sentiment analysis model.
|
256 |
|
257 |
Args:
|
258 |
max_features: Maximum number of features
|
259 |
seed: Random seed (None for random seed)
|
260 |
+
verbose: Whether to log progress during training
|
261 |
|
262 |
Returns:
|
263 |
Untrained model
|
264 |
"""
|
265 |
+
# Download NLTK data if not already downloaded
|
266 |
+
nltk.download("wordnet", quiet=True)
|
267 |
+
nltk.download("stopwords", quiet=True)
|
268 |
+
|
269 |
+
# Load English stopwords
|
270 |
+
stopwords_en = set(stopwords.words("english"))
|
271 |
+
|
272 |
return Pipeline(
|
273 |
[
|
274 |
# Text preprocessing
|
275 |
("clean", TextCleaner()),
|
276 |
("lemma", TextLemmatizer()),
|
277 |
# Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity)
|
278 |
+
(
|
279 |
+
"vectorize",
|
280 |
+
CountVectorizer(stop_words=stopwords_en, ngram_range=(1, 2), max_features=max_features),
|
281 |
+
),
|
282 |
("tfidf", TfidfTransformer()),
|
283 |
# Classifier
|
284 |
("clf", LogisticRegression(max_iter=1000, random_state=seed)),
|
285 |
],
|
286 |
memory=Memory(CACHE_DIR, verbose=0),
|
287 |
+
verbose=verbose,
|
288 |
)
|
289 |
|
290 |
|