Spaces:
Running
Running
Add more vectorizers, classifiers and CLI options
Browse files- app/cli.py +46 -33
- app/constants.py +3 -3
- app/model.py +149 -73
app/cli.py
CHANGED
@@ -104,29 +104,36 @@ def predict(model_path: Path, text: list[str]) -> None:
|
|
104 |
type=click.IntRange(1, 50),
|
105 |
)
|
106 |
@click.option(
|
107 |
-
"--batch-size",
|
108 |
default=512,
|
109 |
help="Size of the batches used in tokenization",
|
110 |
show_default=True,
|
111 |
)
|
112 |
@click.option(
|
113 |
-
"--
|
114 |
default=4,
|
115 |
-
help="Number of parallel jobs to run",
|
116 |
show_default=True,
|
117 |
)
|
118 |
@click.option(
|
119 |
-
"--
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
is_flag=True,
|
121 |
-
help="
|
122 |
)
|
123 |
def evaluate(
|
124 |
dataset: Literal["test", "sentiment140", "amazonreviews", "imdb50k"],
|
125 |
model_path: Path,
|
126 |
cv: int,
|
127 |
-
|
128 |
-
|
129 |
-
|
|
|
130 |
) -> None:
|
131 |
"""Evaluate the model on the specified dataset"""
|
132 |
import gc
|
@@ -141,7 +148,10 @@ def evaluate(
|
|
141 |
cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
|
142 |
use_cached_data = False
|
143 |
if cached_data_path.exists():
|
144 |
-
use_cached_data = click.confirm(
|
|
|
|
|
|
|
145 |
|
146 |
click.echo("Loading dataset... ", nl=False)
|
147 |
text_data, label_data = load_data(dataset)
|
@@ -149,16 +159,14 @@ def evaluate(
|
|
149 |
|
150 |
if use_cached_data:
|
151 |
click.echo("Loading cached data... ", nl=False)
|
152 |
-
# token_data = joblib.load(cached_data_path)
|
153 |
token_data = deserialize(cached_data_path)
|
154 |
click.echo(DONE_STR)
|
155 |
else:
|
156 |
click.echo("Tokenizing data... ", nl=False)
|
157 |
-
token_data = tokenize(text_data, batch_size=
|
158 |
click.echo(DONE_STR)
|
159 |
|
160 |
click.echo("Caching tokenized data... ", nl=False)
|
161 |
-
# joblib.dump(token_data, cached_data_path, compress=3)
|
162 |
serialize(token_data, cached_data_path)
|
163 |
click.echo(DONE_STR)
|
164 |
|
@@ -175,8 +183,7 @@ def evaluate(
|
|
175 |
token_data,
|
176 |
label_data,
|
177 |
folds=cv,
|
178 |
-
n_jobs=
|
179 |
-
verbose=verbose,
|
180 |
)
|
181 |
click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
|
182 |
|
@@ -188,10 +195,16 @@ def evaluate(
|
|
188 |
help="Dataset to train the model on",
|
189 |
type=click.Choice(["sentiment140", "amazonreviews", "imdb50k"]),
|
190 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
@click.option(
|
192 |
"--max-features",
|
193 |
default=20000,
|
194 |
-
help="Maximum number of features",
|
195 |
show_default=True,
|
196 |
type=click.IntRange(1, None),
|
197 |
)
|
@@ -203,15 +216,21 @@ def evaluate(
|
|
203 |
type=click.IntRange(1, 50),
|
204 |
)
|
205 |
@click.option(
|
206 |
-
"--batch-size",
|
207 |
default=512,
|
208 |
help="Size of the batches used in tokenization",
|
209 |
show_default=True,
|
210 |
)
|
211 |
@click.option(
|
212 |
-
"--
|
213 |
default=4,
|
214 |
-
help="Number of parallel jobs to run",
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
show_default=True,
|
216 |
)
|
217 |
@click.option(
|
@@ -231,33 +250,29 @@ def evaluate(
|
|
231 |
is_flag=True,
|
232 |
help="Always use the cached tokenized data (if available)",
|
233 |
)
|
234 |
-
@click.option(
|
235 |
-
"--verbose",
|
236 |
-
is_flag=True,
|
237 |
-
help="Show verbose output",
|
238 |
-
)
|
239 |
def train(
|
240 |
dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
|
|
|
241 |
max_features: int,
|
242 |
cv: int,
|
243 |
-
|
244 |
-
|
|
|
245 |
seed: int,
|
246 |
overwrite: bool,
|
247 |
force_cache: bool,
|
248 |
-
verbose: bool,
|
249 |
) -> None:
|
250 |
"""Train the model on the provided dataset"""
|
251 |
import gc
|
252 |
|
253 |
import joblib
|
254 |
|
255 |
-
from app.constants import CACHE_DIR,
|
256 |
from app.data import load_data, tokenize
|
257 |
from app.model import train_model
|
258 |
from app.utils import deserialize, serialize
|
259 |
|
260 |
-
model_path =
|
261 |
if model_path.exists() and not overwrite:
|
262 |
click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
|
263 |
|
@@ -276,16 +291,14 @@ def train(
|
|
276 |
|
277 |
if use_cached_data:
|
278 |
click.echo("Loading cached data... ", nl=False)
|
279 |
-
# token_data = joblib.load(cached_data_path)
|
280 |
token_data = deserialize(cached_data_path)
|
281 |
click.echo(DONE_STR)
|
282 |
else:
|
283 |
click.echo("Tokenizing data... ", nl=False)
|
284 |
-
token_data = tokenize(text_data, batch_size=
|
285 |
click.echo(DONE_STR)
|
286 |
|
287 |
click.echo("Caching tokenized data... ", nl=False)
|
288 |
-
# joblib.dump(token_data, cached_data_path, compress=3)
|
289 |
serialize(token_data, cached_data_path)
|
290 |
click.echo(DONE_STR)
|
291 |
|
@@ -296,11 +309,11 @@ def train(
|
|
296 |
model, accuracy = train_model(
|
297 |
token_data,
|
298 |
label_data,
|
|
|
299 |
max_features=max_features,
|
300 |
folds=cv,
|
301 |
-
n_jobs=
|
302 |
seed=seed,
|
303 |
-
verbose=verbose,
|
304 |
)
|
305 |
click.echo("Model accuracy: ", nl=False)
|
306 |
click.secho(f"{accuracy:.2%}", fg="blue")
|
|
|
104 |
type=click.IntRange(1, 50),
|
105 |
)
|
106 |
@click.option(
|
107 |
+
"--token-batch-size",
|
108 |
default=512,
|
109 |
help="Size of the batches used in tokenization",
|
110 |
show_default=True,
|
111 |
)
|
112 |
@click.option(
|
113 |
+
"--token-jobs",
|
114 |
default=4,
|
115 |
+
help="Number of parallel jobs to run for tokenization",
|
116 |
show_default=True,
|
117 |
)
|
118 |
@click.option(
|
119 |
+
"--eval-jobs",
|
120 |
+
default=1,
|
121 |
+
help="Number of parallel jobs to run for evaluation",
|
122 |
+
show_default=True,
|
123 |
+
)
|
124 |
+
@click.option(
|
125 |
+
"--force-cache",
|
126 |
is_flag=True,
|
127 |
+
help="Always use the cached tokenized data (if available)",
|
128 |
)
|
129 |
def evaluate(
|
130 |
dataset: Literal["test", "sentiment140", "amazonreviews", "imdb50k"],
|
131 |
model_path: Path,
|
132 |
cv: int,
|
133 |
+
token_batch_size: int,
|
134 |
+
token_jobs: int,
|
135 |
+
eval_jobs: int,
|
136 |
+
force_cache: bool,
|
137 |
) -> None:
|
138 |
"""Evaluate the model on the specified dataset"""
|
139 |
import gc
|
|
|
148 |
cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
|
149 |
use_cached_data = False
|
150 |
if cached_data_path.exists():
|
151 |
+
use_cached_data = force_cache or click.confirm(
|
152 |
+
f"Found existing tokenized data for '{dataset}'. Use it?",
|
153 |
+
default=True,
|
154 |
+
)
|
155 |
|
156 |
click.echo("Loading dataset... ", nl=False)
|
157 |
text_data, label_data = load_data(dataset)
|
|
|
159 |
|
160 |
if use_cached_data:
|
161 |
click.echo("Loading cached data... ", nl=False)
|
|
|
162 |
token_data = deserialize(cached_data_path)
|
163 |
click.echo(DONE_STR)
|
164 |
else:
|
165 |
click.echo("Tokenizing data... ", nl=False)
|
166 |
+
token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)
|
167 |
click.echo(DONE_STR)
|
168 |
|
169 |
click.echo("Caching tokenized data... ", nl=False)
|
|
|
170 |
serialize(token_data, cached_data_path)
|
171 |
click.echo(DONE_STR)
|
172 |
|
|
|
183 |
token_data,
|
184 |
label_data,
|
185 |
folds=cv,
|
186 |
+
n_jobs=eval_jobs,
|
|
|
187 |
)
|
188 |
click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
|
189 |
|
|
|
195 |
help="Dataset to train the model on",
|
196 |
type=click.Choice(["sentiment140", "amazonreviews", "imdb50k"]),
|
197 |
)
|
198 |
+
@click.option(
|
199 |
+
"--vectorizer",
|
200 |
+
default="tfidf",
|
201 |
+
help="Vectorizer to use",
|
202 |
+
type=click.Choice(["tfidf", "count", "hashing"]),
|
203 |
+
)
|
204 |
@click.option(
|
205 |
"--max-features",
|
206 |
default=20000,
|
207 |
+
help="Maximum number of features (should be greater than 2^15 when using hashing vectorizer)",
|
208 |
show_default=True,
|
209 |
type=click.IntRange(1, None),
|
210 |
)
|
|
|
216 |
type=click.IntRange(1, 50),
|
217 |
)
|
218 |
@click.option(
|
219 |
+
"--token-batch-size",
|
220 |
default=512,
|
221 |
help="Size of the batches used in tokenization",
|
222 |
show_default=True,
|
223 |
)
|
224 |
@click.option(
|
225 |
+
"--token-jobs",
|
226 |
default=4,
|
227 |
+
help="Number of parallel jobs to run for tokenization",
|
228 |
+
show_default=True,
|
229 |
+
)
|
230 |
+
@click.option(
|
231 |
+
"--train-jobs",
|
232 |
+
default=1,
|
233 |
+
help="Number of parallel jobs to run for training",
|
234 |
show_default=True,
|
235 |
)
|
236 |
@click.option(
|
|
|
250 |
is_flag=True,
|
251 |
help="Always use the cached tokenized data (if available)",
|
252 |
)
|
|
|
|
|
|
|
|
|
|
|
253 |
def train(
|
254 |
dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
|
255 |
+
vectorizer: Literal["tfidf", "count", "hashing"],
|
256 |
max_features: int,
|
257 |
cv: int,
|
258 |
+
token_batch_size: int,
|
259 |
+
token_jobs: int,
|
260 |
+
train_jobs: int,
|
261 |
seed: int,
|
262 |
overwrite: bool,
|
263 |
force_cache: bool,
|
|
|
264 |
) -> None:
|
265 |
"""Train the model on the provided dataset"""
|
266 |
import gc
|
267 |
|
268 |
import joblib
|
269 |
|
270 |
+
from app.constants import CACHE_DIR, MODEL_DIR
|
271 |
from app.data import load_data, tokenize
|
272 |
from app.model import train_model
|
273 |
from app.utils import deserialize, serialize
|
274 |
|
275 |
+
model_path = MODEL_DIR / f"{dataset}_{vectorizer}_ft{max_features}.pkl"
|
276 |
if model_path.exists() and not overwrite:
|
277 |
click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
|
278 |
|
|
|
291 |
|
292 |
if use_cached_data:
|
293 |
click.echo("Loading cached data... ", nl=False)
|
|
|
294 |
token_data = deserialize(cached_data_path)
|
295 |
click.echo(DONE_STR)
|
296 |
else:
|
297 |
click.echo("Tokenizing data... ", nl=False)
|
298 |
+
token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)
|
299 |
click.echo(DONE_STR)
|
300 |
|
301 |
click.echo("Caching tokenized data... ", nl=False)
|
|
|
302 |
serialize(token_data, cached_data_path)
|
303 |
click.echo(DONE_STR)
|
304 |
|
|
|
309 |
model, accuracy = train_model(
|
310 |
token_data,
|
311 |
label_data,
|
312 |
+
vectorizer=vectorizer,
|
313 |
max_features=max_features,
|
314 |
folds=cv,
|
315 |
+
n_jobs=train_jobs,
|
316 |
seed=seed,
|
|
|
317 |
)
|
318 |
click.echo("Model accuracy: ", nl=False)
|
319 |
click.secho(f"{accuracy:.2%}", fg="blue")
|
app/constants.py
CHANGED
@@ -5,12 +5,12 @@ from pathlib import Path
|
|
5 |
|
6 |
CACHE_DIR = Path(os.getenv("CACHE_DIR", ".cache"))
|
7 |
DATA_DIR = Path(os.getenv("DATA_DIR", "data"))
|
8 |
-
|
9 |
|
10 |
SENTIMENT140_PATH = DATA_DIR / "sentiment140.csv"
|
11 |
SENTIMENT140_URL = "https://www.kaggle.com/datasets/kazanova/sentiment140"
|
12 |
|
13 |
-
AMAZONREVIEWS_PATH = DATA_DIR / "amazonreviews.
|
14 |
AMAZONREVIEWS_URL = "https://www.kaggle.com/datasets/bittlingmayer/amazonreviews"
|
15 |
|
16 |
IMDB50K_PATH = DATA_DIR / "imdb50k.csv"
|
@@ -21,4 +21,4 @@ TEST_DATASET_URL = "https://huggingface.co/datasets/Sp1786/multiclass-sentiment-
|
|
21 |
|
22 |
CACHE_DIR.mkdir(exist_ok=True, parents=True)
|
23 |
DATA_DIR.mkdir(exist_ok=True, parents=True)
|
24 |
-
|
|
|
5 |
|
6 |
CACHE_DIR = Path(os.getenv("CACHE_DIR", ".cache"))
|
7 |
DATA_DIR = Path(os.getenv("DATA_DIR", "data"))
|
8 |
+
MODEL_DIR = Path(os.getenv("MODEL_DIR", "models"))
|
9 |
|
10 |
SENTIMENT140_PATH = DATA_DIR / "sentiment140.csv"
|
11 |
SENTIMENT140_URL = "https://www.kaggle.com/datasets/kazanova/sentiment140"
|
12 |
|
13 |
+
AMAZONREVIEWS_PATH = DATA_DIR / "amazonreviews.txt.bz2"
|
14 |
AMAZONREVIEWS_URL = "https://www.kaggle.com/datasets/bittlingmayer/amazonreviews"
|
15 |
|
16 |
IMDB50K_PATH = DATA_DIR / "imdb50k.csv"
|
|
|
21 |
|
22 |
CACHE_DIR.mkdir(exist_ok=True, parents=True)
|
23 |
DATA_DIR.mkdir(exist_ok=True, parents=True)
|
24 |
+
MODEL_DIR.mkdir(exist_ok=True, parents=True)
|
app/model.py
CHANGED
@@ -1,20 +1,23 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
-
import
|
4 |
-
from typing import TYPE_CHECKING
|
5 |
|
6 |
import numpy as np
|
7 |
from joblib import Memory
|
8 |
-
from sklearn.
|
|
|
9 |
from sklearn.linear_model import LogisticRegression
|
10 |
from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split
|
11 |
from sklearn.pipeline import Pipeline
|
|
|
|
|
12 |
|
13 |
from app.constants import CACHE_DIR
|
14 |
from app.data import tokenize
|
15 |
|
16 |
if TYPE_CHECKING:
|
17 |
-
from sklearn.base import BaseEstimator
|
18 |
|
19 |
__all__ = ["train_model", "evaluate_model", "infer_model"]
|
20 |
|
@@ -31,96 +34,170 @@ def _identity(x: list[str]) -> list[str]:
|
|
31 |
return x
|
32 |
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
def train_model(
|
35 |
-
token_data:
|
36 |
label_data: list[int],
|
|
|
37 |
max_features: int,
|
38 |
folds: int = 5,
|
39 |
n_jobs: int = 4,
|
40 |
seed: int = 42,
|
41 |
-
verbose: bool = False,
|
42 |
) -> tuple[BaseEstimator, float]:
|
43 |
"""Train the sentiment analysis model.
|
44 |
|
45 |
Args:
|
46 |
-
model: Untrained model
|
47 |
token_data: Tokenized text data
|
48 |
label_data: Label data
|
|
|
49 |
max_features: Maximum number of features
|
50 |
folds: Number of cross-validation folds
|
51 |
n_jobs: Number of parallel jobs
|
52 |
seed: Random seed (None for random seed)
|
53 |
-
verbose: Whether to output additional information
|
54 |
|
55 |
Returns:
|
56 |
Trained model and accuracy
|
|
|
|
|
|
|
57 |
"""
|
|
|
|
|
58 |
text_train, text_test, label_train, label_test = train_test_split(
|
59 |
token_data,
|
60 |
label_data,
|
61 |
test_size=0.2,
|
62 |
-
random_state=
|
63 |
)
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
|
117 |
def evaluate_model(
|
118 |
model: BaseEstimator,
|
119 |
-
token_data:
|
120 |
label_data: list[int],
|
121 |
folds: int = 5,
|
122 |
n_jobs: int = 4,
|
123 |
-
verbose: bool = False,
|
124 |
) -> tuple[float, float]:
|
125 |
"""Evaluate the model using cross-validation.
|
126 |
|
@@ -130,22 +207,21 @@ def evaluate_model(
|
|
130 |
label_data: Label data
|
131 |
folds: Number of cross-validation folds
|
132 |
n_jobs: Number of parallel jobs
|
133 |
-
verbose: Whether to output additional information
|
134 |
|
135 |
Returns:
|
136 |
Mean accuracy and standard deviation
|
137 |
"""
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
return scores.mean(), scores.std()
|
150 |
|
151 |
|
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
+
import warnings
|
4 |
+
from typing import TYPE_CHECKING, Literal, Sequence
|
5 |
|
6 |
import numpy as np
|
7 |
from joblib import Memory
|
8 |
+
from sklearn.exceptions import ConvergenceWarning
|
9 |
+
from sklearn.feature_extraction.text import CountVectorizer, HashingVectorizer, TfidfVectorizer
|
10 |
from sklearn.linear_model import LogisticRegression
|
11 |
from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split
|
12 |
from sklearn.pipeline import Pipeline
|
13 |
+
from sklearn.svm import LinearSVC
|
14 |
+
from tqdm import tqdm
|
15 |
|
16 |
from app.constants import CACHE_DIR
|
17 |
from app.data import tokenize
|
18 |
|
19 |
if TYPE_CHECKING:
|
20 |
+
from sklearn.base import BaseEstimator, TransformerMixin
|
21 |
|
22 |
__all__ = ["train_model", "evaluate_model", "infer_model"]
|
23 |
|
|
|
34 |
return x
|
35 |
|
36 |
|
37 |
+
def _get_vectorizer(
|
38 |
+
name: Literal["tfidf", "count", "hashing"],
|
39 |
+
n_features: int,
|
40 |
+
df: tuple[float, float] = (0.1, 0.9),
|
41 |
+
ngram: tuple[int, int] = (1, 2),
|
42 |
+
) -> TransformerMixin:
|
43 |
+
"""Get the appropriate vectorizer.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
name: Type of vectorizer
|
47 |
+
n_features: Maximum number of features
|
48 |
+
df: Document frequency range [min_df, max_df] (ignored for HashingVectorizer)
|
49 |
+
ngram: N-gram range [min_n, max_n]
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
Vectorizer instance
|
53 |
+
|
54 |
+
Raises:
|
55 |
+
ValueError: If the vectorizer is not recognized
|
56 |
+
"""
|
57 |
+
shared_params = {
|
58 |
+
"ngram_range": ngram,
|
59 |
+
# disable text processing
|
60 |
+
"tokenizer": _identity,
|
61 |
+
"preprocessor": _identity,
|
62 |
+
"lowercase": False,
|
63 |
+
"token_pattern": None,
|
64 |
+
}
|
65 |
+
|
66 |
+
match name:
|
67 |
+
case "tfidf":
|
68 |
+
return TfidfVectorizer(
|
69 |
+
max_features=n_features,
|
70 |
+
min_df=df[0],
|
71 |
+
max_df=df[1],
|
72 |
+
**shared_params,
|
73 |
+
)
|
74 |
+
case "count":
|
75 |
+
return CountVectorizer(
|
76 |
+
max_features=n_features,
|
77 |
+
min_df=df[0],
|
78 |
+
max_df=df[1],
|
79 |
+
**shared_params,
|
80 |
+
)
|
81 |
+
case "hashing":
|
82 |
+
if n_features < 2**15:
|
83 |
+
warnings.warn(
|
84 |
+
"HashingVectorizer may perform poorly with small n_features, default is 2^20.",
|
85 |
+
stacklevel=2,
|
86 |
+
)
|
87 |
+
|
88 |
+
return HashingVectorizer(
|
89 |
+
n_features=n_features,
|
90 |
+
**shared_params,
|
91 |
+
)
|
92 |
+
case _:
|
93 |
+
msg = f"Unknown vectorizer: {name}"
|
94 |
+
raise ValueError(msg)
|
95 |
+
|
96 |
+
|
97 |
def train_model(
|
98 |
+
token_data: Sequence[Sequence[str]],
|
99 |
label_data: list[int],
|
100 |
+
vectorizer: Literal["tfidf", "count", "hashing"],
|
101 |
max_features: int,
|
102 |
folds: int = 5,
|
103 |
n_jobs: int = 4,
|
104 |
seed: int = 42,
|
|
|
105 |
) -> tuple[BaseEstimator, float]:
|
106 |
"""Train the sentiment analysis model.
|
107 |
|
108 |
Args:
|
|
|
109 |
token_data: Tokenized text data
|
110 |
label_data: Label data
|
111 |
+
vectorizer: Which vectorizer to use
|
112 |
max_features: Maximum number of features
|
113 |
folds: Number of cross-validation folds
|
114 |
n_jobs: Number of parallel jobs
|
115 |
seed: Random seed (-1 for random seed)
|
|
|
116 |
|
117 |
Returns:
|
118 |
Trained model and accuracy
|
119 |
+
|
120 |
+
Raises:
|
121 |
+
ValueError: If the vectorizer is not recognized
|
122 |
"""
|
123 |
+
rs = None if seed == -1 else seed
|
124 |
+
|
125 |
text_train, text_test, label_train, label_test = train_test_split(
|
126 |
token_data,
|
127 |
label_data,
|
128 |
test_size=0.2,
|
129 |
+
random_state=rs,
|
130 |
)
|
131 |
|
132 |
+
vectorizer = _get_vectorizer(vectorizer, max_features)
|
133 |
+
classifiers = [
|
134 |
+
(LogisticRegression(max_iter=1000, random_state=rs), {"C": np.logspace(-4, 4, 20)}),
|
135 |
+
(LinearSVC(max_iter=10000, random_state=rs), {"C": np.logspace(-4, 4, 20)}),
|
136 |
+
# (KNeighborsClassifier(), {"n_neighbors": np.arange(1, 10)}),
|
137 |
+
# (RandomForestClassifier(random_state=rs), {"n_estimators": np.arange(50, 500, 50)}),
|
138 |
+
# (
|
139 |
+
# VotingClassifier(
|
140 |
+
# estimators=[
|
141 |
+
# ("lr", LogisticRegression(max_iter=1000, random_state=rs)),
|
142 |
+
# ("knn", KNeighborsClassifier()),
|
143 |
+
# ("rf", RandomForestClassifier(random_state=rs)),
|
144 |
+
# ],
|
145 |
+
# ),
|
146 |
+
# {
|
147 |
+
# "lr__C": np.logspace(-4, 4, 20),
|
148 |
+
# "knn__n_neighbors": np.arange(1, 10),
|
149 |
+
# "rf__n_estimators": np.arange(50, 500, 50),
|
150 |
+
# },
|
151 |
+
# ),
|
152 |
+
]
|
153 |
+
|
154 |
+
models = []
|
155 |
+
for clf, param_dist in (pbar := tqdm(classifiers, unit="clf")):
|
156 |
+
param_dist = {f"classifier__{k}": v for k, v in param_dist.items()}
|
157 |
+
|
158 |
+
model = Pipeline(
|
159 |
+
[("vectorizer", vectorizer), ("classifier", clf)],
|
160 |
+
memory=Memory(CACHE_DIR, verbose=0),
|
161 |
+
)
|
162 |
+
|
163 |
+
search = RandomizedSearchCV(
|
164 |
+
model,
|
165 |
+
param_dist,
|
166 |
+
cv=folds,
|
167 |
+
random_state=rs,
|
168 |
+
n_jobs=n_jobs,
|
169 |
+
# verbose=2,
|
170 |
+
scoring="accuracy",
|
171 |
+
n_iter=7,
|
172 |
+
)
|
173 |
+
|
174 |
+
pbar.set_description(f"Searching for {clf.__class__.__name__}")
|
175 |
+
|
176 |
+
with warnings.catch_warnings():
|
177 |
+
warnings.filterwarnings("once", category=ConvergenceWarning)
|
178 |
+
warnings.filterwarnings("ignore", category=UserWarning, message="Persisting input arguments took")
|
179 |
+
|
180 |
+
search.fit(text_train, label_train)
|
181 |
+
|
182 |
+
best_model = search.best_estimator_
|
183 |
+
acc = best_model.score(text_test, label_test)
|
184 |
+
models.append((best_model, acc))
|
185 |
+
|
186 |
+
print("Final results:")
|
187 |
+
print("--------------")
|
188 |
+
print("\n".join(f"{model.named_steps['classifier'].__class__.__name__}: {acc:.2%}" for model, acc in models))
|
189 |
+
|
190 |
+
best_model, best_acc = max(models, key=lambda x: x[1])
|
191 |
+
print(f"Settled on {best_model.named_steps['classifier'].__class__.__name__}")
|
192 |
+
return best_model, best_acc
|
193 |
|
194 |
|
195 |
def evaluate_model(
|
196 |
model: BaseEstimator,
|
197 |
+
token_data: Sequence[Sequence[str]],
|
198 |
label_data: list[int],
|
199 |
folds: int = 5,
|
200 |
n_jobs: int = 4,
|
|
|
201 |
) -> tuple[float, float]:
|
202 |
"""Evaluate the model using cross-validation.
|
203 |
|
|
|
207 |
label_data: Label data
|
208 |
folds: Number of cross-validation folds
|
209 |
n_jobs: Number of parallel jobs
|
|
|
210 |
|
211 |
Returns:
|
212 |
Mean accuracy and standard deviation
|
213 |
"""
|
214 |
+
with warnings.catch_warnings():
|
215 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
216 |
+
scores = cross_val_score(
|
217 |
+
model,
|
218 |
+
token_data,
|
219 |
+
label_data,
|
220 |
+
cv=folds,
|
221 |
+
scoring="accuracy",
|
222 |
+
n_jobs=n_jobs,
|
223 |
+
verbose=2,
|
224 |
+
)
|
225 |
return scores.mean(), scores.std()
|
226 |
|
227 |
|