Spaces:
Running
Running
Tokenization rework
Browse files- app/cli.py +97 -26
- app/data.py +115 -6
- app/gui.py +4 -1
- app/model.py +62 -87
app/cli.py
CHANGED
@@ -55,6 +55,8 @@ def predict(model_path: Path, text: list[str]) -> None:
|
|
55 |
|
56 |
import joblib
|
57 |
|
|
|
|
|
58 |
text = " ".join(text).strip()
|
59 |
if not sys.stdin.isatty():
|
60 |
piped_text = sys.stdin.read().strip()
|
@@ -69,7 +71,8 @@ def predict(model_path: Path, text: list[str]) -> None:
|
|
69 |
click.echo(DONE_STR)
|
70 |
|
71 |
click.echo("Performing sentiment analysis... ", nl=False)
|
72 |
-
prediction = model
|
|
|
73 |
if prediction == 0:
|
74 |
sentiment = click.style("NEGATIVE", fg="red")
|
75 |
elif prediction == 1:
|
@@ -82,9 +85,9 @@ def predict(model_path: Path, text: list[str]) -> None:
|
|
82 |
@cli.command()
|
83 |
@click.option(
|
84 |
"--dataset",
|
85 |
-
|
86 |
-
help="Dataset to
|
87 |
-
type=click.Choice(["sentiment140", "amazonreviews", "imdb50k"]),
|
88 |
)
|
89 |
@click.option(
|
90 |
"--model",
|
@@ -100,27 +103,65 @@ def predict(model_path: Path, text: list[str]) -> None:
|
|
100 |
show_default=True,
|
101 |
type=click.IntRange(1, 50),
|
102 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
def evaluate(
|
104 |
-
dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
|
105 |
model_path: Path,
|
106 |
cv: int,
|
|
|
|
|
|
|
107 |
) -> None:
|
108 |
-
"""Evaluate the model on the
|
109 |
import joblib
|
110 |
|
111 |
-
from app.
|
|
|
112 |
from app.model import evaluate_model
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
click.echo("Loading model... ", nl=False)
|
119 |
model = joblib.load(model_path)
|
120 |
click.echo(DONE_STR)
|
121 |
|
122 |
click.echo("Evaluating model... ", nl=False)
|
123 |
-
acc_mean, acc_std = evaluate_model(model,
|
124 |
click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
|
125 |
|
126 |
|
@@ -145,6 +186,18 @@ def evaluate(
|
|
145 |
show_default=True,
|
146 |
type=click.IntRange(1, 50),
|
147 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
@click.option(
|
149 |
"--seed",
|
150 |
default=42,
|
@@ -157,45 +210,63 @@ def evaluate(
|
|
157 |
is_flag=True,
|
158 |
help="Overwrite the model file if it already exists",
|
159 |
)
|
|
|
|
|
|
|
|
|
|
|
160 |
def train(
|
161 |
dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
|
162 |
max_features: int,
|
163 |
cv: int,
|
|
|
|
|
164 |
seed: int,
|
165 |
force: bool,
|
|
|
166 |
) -> None:
|
167 |
"""Train the model on the provided dataset"""
|
168 |
import joblib
|
169 |
|
170 |
-
from app.constants import MODELS_DIR
|
171 |
-
from app.data import load_data
|
172 |
-
from app.model import create_model,
|
173 |
|
174 |
model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
|
175 |
if model_path.exists() and not force:
|
176 |
click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
|
177 |
|
178 |
-
|
179 |
-
|
180 |
-
|
|
|
181 |
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
|
186 |
click.echo("Training model... ")
|
187 |
-
|
|
|
188 |
click.echo("Model accuracy: ", nl=False)
|
189 |
click.secho(f"{accuracy:.2%}", fg="blue")
|
190 |
|
191 |
click.echo("Model saved to: ", nl=False)
|
192 |
-
joblib.dump(
|
193 |
click.secho(str(model_path), fg="blue")
|
194 |
|
195 |
-
click.echo("Evaluating model... ", nl=False)
|
196 |
-
acc_mean, acc_std = evaluate_model(model, text_data, label_data, folds=cv)
|
197 |
-
click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
|
198 |
-
|
199 |
|
200 |
def cli_wrapper() -> None:
|
201 |
cli(max_content_width=120)
|
|
|
55 |
|
56 |
import joblib
|
57 |
|
58 |
+
from app.model import infer_model
|
59 |
+
|
60 |
text = " ".join(text).strip()
|
61 |
if not sys.stdin.isatty():
|
62 |
piped_text = sys.stdin.read().strip()
|
|
|
71 |
click.echo(DONE_STR)
|
72 |
|
73 |
click.echo("Performing sentiment analysis... ", nl=False)
|
74 |
+
prediction = infer_model(model, [text])[0]
|
75 |
+
# prediction = model.predict([text])[0]
|
76 |
if prediction == 0:
|
77 |
sentiment = click.style("NEGATIVE", fg="red")
|
78 |
elif prediction == 1:
|
|
|
85 |
@cli.command()
|
86 |
@click.option(
|
87 |
"--dataset",
|
88 |
+
default="test",
|
89 |
+
help="Dataset to evaluate the model on",
|
90 |
+
type=click.Choice(["test", "sentiment140", "amazonreviews", "imdb50k"]),
|
91 |
)
|
92 |
@click.option(
|
93 |
"--model",
|
|
|
103 |
show_default=True,
|
104 |
type=click.IntRange(1, 50),
|
105 |
)
|
106 |
+
@click.option(
|
107 |
+
"--batch-size",
|
108 |
+
default=512,
|
109 |
+
help="Size of the batches used in tokenization",
|
110 |
+
show_default=True,
|
111 |
+
)
|
112 |
+
@click.option(
|
113 |
+
"--processes",
|
114 |
+
default=8,
|
115 |
+
help="Number of parallel jobs during tokenization",
|
116 |
+
show_default=True,
|
117 |
+
)
|
118 |
+
@click.option(
|
119 |
+
"--verbose",
|
120 |
+
is_flag=True,
|
121 |
+
help="Show verbose output",
|
122 |
+
)
|
123 |
def evaluate(
|
124 |
+
dataset: Literal["test", "sentiment140", "amazonreviews", "imdb50k"],
|
125 |
model_path: Path,
|
126 |
cv: int,
|
127 |
+
batch_size: int,
|
128 |
+
processes: int,
|
129 |
+
verbose: bool,
|
130 |
) -> None:
|
131 |
+
"""Evaluate the model on the the specified dataset"""
|
132 |
import joblib
|
133 |
|
134 |
+
from app.constants import CACHE_DIR
|
135 |
+
from app.data import load_data, tokenize
|
136 |
from app.model import evaluate_model
|
137 |
|
138 |
+
cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
|
139 |
+
use_cached_data = False
|
140 |
+
if cached_data_path.exists():
|
141 |
+
use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)
|
142 |
+
|
143 |
+
if use_cached_data:
|
144 |
+
click.echo("Loading cached data... ", nl=False)
|
145 |
+
token_data, label_data = joblib.load(cached_data_path)
|
146 |
+
click.echo(DONE_STR)
|
147 |
+
else:
|
148 |
+
click.echo("Loading dataset... ", nl=False)
|
149 |
+
text_data, label_data = load_data(dataset)
|
150 |
+
click.echo(DONE_STR)
|
151 |
+
|
152 |
+
click.echo("Tokenizing data... ", nl=False)
|
153 |
+
token_data = tokenize(text_data, batch_size=batch_size, n_jobs=processes, show_progress=True)
|
154 |
+
joblib.dump((token_data, label_data), cached_data_path, compress=3)
|
155 |
+
click.echo(DONE_STR)
|
156 |
+
|
157 |
+
del text_data
|
158 |
|
159 |
click.echo("Loading model... ", nl=False)
|
160 |
model = joblib.load(model_path)
|
161 |
click.echo(DONE_STR)
|
162 |
|
163 |
click.echo("Evaluating model... ", nl=False)
|
164 |
+
acc_mean, acc_std = evaluate_model(model, token_data, label_data, folds=cv, verbose=verbose)
|
165 |
click.secho(f"{acc_mean:.2%} ± {acc_std:.2%}", fg="blue")
|
166 |
|
167 |
|
|
|
186 |
show_default=True,
|
187 |
type=click.IntRange(1, 50),
|
188 |
)
|
189 |
+
@click.option(
|
190 |
+
"--batch-size",
|
191 |
+
default=512,
|
192 |
+
help="Size of the batches used in tokenization",
|
193 |
+
show_default=True,
|
194 |
+
)
|
195 |
+
@click.option(
|
196 |
+
"--processes",
|
197 |
+
default=8,
|
198 |
+
help="Number of parallel jobs during tokenization",
|
199 |
+
show_default=True,
|
200 |
+
)
|
201 |
@click.option(
|
202 |
"--seed",
|
203 |
default=42,
|
|
|
210 |
is_flag=True,
|
211 |
help="Overwrite the model file if it already exists",
|
212 |
)
|
213 |
+
@click.option(
|
214 |
+
"--verbose",
|
215 |
+
is_flag=True,
|
216 |
+
help="Show verbose output",
|
217 |
+
)
|
218 |
def train(
|
219 |
dataset: Literal["sentiment140", "amazonreviews", "imdb50k"],
|
220 |
max_features: int,
|
221 |
cv: int,
|
222 |
+
batch_size: int,
|
223 |
+
processes: int,
|
224 |
seed: int,
|
225 |
force: bool,
|
226 |
+
verbose: bool,
|
227 |
) -> None:
|
228 |
"""Train the model on the provided dataset"""
|
229 |
import joblib
|
230 |
|
231 |
+
from app.constants import CACHE_DIR, MODELS_DIR
|
232 |
+
from app.data import load_data, tokenize
|
233 |
+
from app.model import create_model, train_model
|
234 |
|
235 |
model_path = MODELS_DIR / f"{dataset}_tfidf_ft-{max_features}.pkl"
|
236 |
if model_path.exists() and not force:
|
237 |
click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
|
238 |
|
239 |
+
cached_data_path = CACHE_DIR / f"{dataset}_tokenized.pkl"
|
240 |
+
use_cached_data = False
|
241 |
+
if cached_data_path.exists():
|
242 |
+
use_cached_data = click.confirm(f"Found existing tokenized data for '{dataset}'. Use it?", default=True)
|
243 |
|
244 |
+
if use_cached_data:
|
245 |
+
click.echo("Loading cached data... ", nl=False)
|
246 |
+
token_data, label_data = joblib.load(cached_data_path)
|
247 |
+
click.echo(DONE_STR)
|
248 |
+
else:
|
249 |
+
click.echo("Loading dataset... ", nl=False)
|
250 |
+
text_data, label_data = load_data(dataset)
|
251 |
+
click.echo(DONE_STR)
|
252 |
+
|
253 |
+
click.echo("Tokenizing data... ", nl=False)
|
254 |
+
token_data = tokenize(text_data, batch_size=batch_size, n_jobs=processes, show_progress=True)
|
255 |
+
joblib.dump((token_data, label_data), cached_data_path, compress=3)
|
256 |
+
click.echo(DONE_STR)
|
257 |
+
|
258 |
+
del text_data
|
259 |
|
260 |
click.echo("Training model... ")
|
261 |
+
model = create_model(max_features, seed=None if seed == -1 else seed, verbose=verbose)
|
262 |
+
trained_model, accuracy = train_model(model, token_data, label_data, folds=cv, seed=seed, verbose=verbose)
|
263 |
click.echo("Model accuracy: ", nl=False)
|
264 |
click.secho(f"{accuracy:.2%}", fg="blue")
|
265 |
|
266 |
click.echo("Model saved to: ", nl=False)
|
267 |
+
joblib.dump(trained_model, model_path, compress=3)
|
268 |
click.secho(str(model_path), fg="blue")
|
269 |
|
|
|
|
|
|
|
|
|
270 |
|
271 |
def cli_wrapper() -> None:
|
272 |
cli(max_content_width=120)
|
app/data.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
import bz2
|
4 |
-
from typing import Literal
|
5 |
|
6 |
import pandas as pd
|
|
|
|
|
7 |
|
8 |
from app.constants import (
|
9 |
AMAZONREVIEWS_PATH,
|
@@ -12,9 +14,76 @@ from app.constants import (
|
|
12 |
IMDB50K_URL,
|
13 |
SENTIMENT140_PATH,
|
14 |
SENTIMENT140_URL,
|
|
|
|
|
15 |
)
|
16 |
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
def load_sentiment140(include_neutral: bool = False) -> tuple[list[str], list[int]]:
|
@@ -104,9 +173,6 @@ def load_amazonreviews(merge: bool = True) -> tuple[list[str], list[int]]:
|
|
104 |
# Split the data into labels and text
|
105 |
labels, texts = zip(*(line.split(" ", 1) for line in dataset)) # NOTE: Occasionally OOM
|
106 |
|
107 |
-
# Free up memory
|
108 |
-
del dataset
|
109 |
-
|
110 |
# Map sentiment values
|
111 |
sentiments = [int(label.split("__label__")[1]) - 1 for label in labels]
|
112 |
|
@@ -147,7 +213,48 @@ def load_imdb50k() -> tuple[list[str], list[int]]:
|
|
147 |
return data["review"].tolist(), data["sentiment"].tolist()
|
148 |
|
149 |
|
150 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
"""Load and preprocess the specified dataset.
|
152 |
|
153 |
Args:
|
@@ -166,6 +273,8 @@ def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> t
|
|
166 |
return load_amazonreviews(merge=True)
|
167 |
case "imdb50k":
|
168 |
return load_imdb50k()
|
|
|
|
|
169 |
case _:
|
170 |
msg = f"Unknown dataset: {dataset}"
|
171 |
raise ValueError(msg)
|
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
import bz2
|
4 |
+
from typing import TYPE_CHECKING, Literal
|
5 |
|
6 |
import pandas as pd
|
7 |
+
import spacy
|
8 |
+
from tqdm import tqdm
|
9 |
|
10 |
from app.constants import (
|
11 |
AMAZONREVIEWS_PATH,
|
|
|
14 |
IMDB50K_URL,
|
15 |
SENTIMENT140_PATH,
|
16 |
SENTIMENT140_URL,
|
17 |
+
TEST_DATASET_PATH,
|
18 |
+
TEST_DATASET_URL,
|
19 |
)
|
20 |
|
21 |
+
if TYPE_CHECKING:
|
22 |
+
from spacy.tokens import Doc
|
23 |
+
|
24 |
+
__all__ = ["load_data", "tokenize"]
|
25 |
+
|
26 |
+
|
27 |
+
try:
|
28 |
+
nlp = spacy.load("en_core_web_sm", disable=["tok2vec", "parser", "ner"])
|
29 |
+
except OSError:
|
30 |
+
print("Downloading spaCy model...")
|
31 |
+
|
32 |
+
from spacy.cli import download as spacy_download
|
33 |
+
|
34 |
+
spacy_download("en_core_web_sm")
|
35 |
+
nlp = spacy.load("en_core_web_sm", disable=["tok2vec", "parser", "ner"])
|
36 |
+
|
37 |
+
|
38 |
+
def _lemmatize(doc: Doc, threshold: int = 2) -> list[str]:
|
39 |
+
"""Lemmatize the provided text using spaCy.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
doc: spaCy document
|
43 |
+
threshold: Minimum character length of tokens
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
Lemmatized text
|
47 |
+
"""
|
48 |
+
return [
|
49 |
+
token.lemma_.lower().strip()
|
50 |
+
for token in doc
|
51 |
+
if not token.is_stop
|
52 |
+
and not token.is_punct
|
53 |
+
and not token.like_email
|
54 |
+
and not token.like_url
|
55 |
+
and not token.like_num
|
56 |
+
and not (len(token.lemma_) < threshold)
|
57 |
+
]
|
58 |
+
|
59 |
+
|
60 |
+
def tokenize(
|
61 |
+
text_data: list[str],
|
62 |
+
batch_size: int = 512,
|
63 |
+
n_jobs: int = 4,
|
64 |
+
character_threshold: int = 2,
|
65 |
+
show_progress: bool = True,
|
66 |
+
) -> list[list[str]]:
|
67 |
+
"""Tokenize the provided text using spaCy.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
text_data: Text data to tokenize
|
71 |
+
batch_size: Batch size for tokenization
|
72 |
+
n_jobs: Number of parallel jobs
|
73 |
+
character_threshold: Minimum character length of tokens
|
74 |
+
show_progress: Whether to show a progress bar
|
75 |
+
|
76 |
+
Returns:
|
77 |
+
Tokenized text data
|
78 |
+
"""
|
79 |
+
return [
|
80 |
+
_lemmatize(doc, character_threshold)
|
81 |
+
for doc in tqdm(
|
82 |
+
nlp.pipe(text_data, batch_size=batch_size, n_process=n_jobs),
|
83 |
+
total=len(text_data),
|
84 |
+
disable=not show_progress,
|
85 |
+
)
|
86 |
+
]
|
87 |
|
88 |
|
89 |
def load_sentiment140(include_neutral: bool = False) -> tuple[list[str], list[int]]:
|
|
|
173 |
# Split the data into labels and text
|
174 |
labels, texts = zip(*(line.split(" ", 1) for line in dataset)) # NOTE: Occasionally OOM
|
175 |
|
|
|
|
|
|
|
176 |
# Map sentiment values
|
177 |
sentiments = [int(label.split("__label__")[1]) - 1 for label in labels]
|
178 |
|
|
|
213 |
return data["review"].tolist(), data["sentiment"].tolist()
|
214 |
|
215 |
|
216 |
+
def load_test(include_neutral: bool = False) -> tuple[list[str], list[int]]:
|
217 |
+
"""Load the test dataset and make it suitable for use.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
include_neutral: Whether to include neutral sentiment
|
221 |
+
|
222 |
+
Returns:
|
223 |
+
Text and label data
|
224 |
+
|
225 |
+
Raises:
|
226 |
+
FileNotFoundError: If the dataset is not found
|
227 |
+
"""
|
228 |
+
# Check if the dataset exists
|
229 |
+
if not TEST_DATASET_PATH.exists():
|
230 |
+
msg = (
|
231 |
+
f"Test dataset not found at: '{TEST_DATASET_PATH}'\n"
|
232 |
+
"Please download the dataset from:\n"
|
233 |
+
f"{TEST_DATASET_URL}"
|
234 |
+
)
|
235 |
+
raise FileNotFoundError(msg)
|
236 |
+
|
237 |
+
# Load the dataset
|
238 |
+
data = pd.read_csv(TEST_DATASET_PATH)
|
239 |
+
|
240 |
+
# Ignore rows with neutral sentiment
|
241 |
+
if not include_neutral:
|
242 |
+
data = data[data["label"] != 1]
|
243 |
+
|
244 |
+
# Map sentiment values
|
245 |
+
data["label"] = data["label"].map(
|
246 |
+
{
|
247 |
+
0: 0, # Negative
|
248 |
+
1: 1, # Neutral
|
249 |
+
2: 2, # Positive
|
250 |
+
},
|
251 |
+
)
|
252 |
+
|
253 |
+
# Return as lists
|
254 |
+
return data["text"].tolist(), data["label"].tolist()
|
255 |
+
|
256 |
+
|
257 |
+
def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k", "test"]) -> tuple[list[str], list[int]]:
|
258 |
"""Load and preprocess the specified dataset.
|
259 |
|
260 |
Args:
|
|
|
273 |
return load_amazonreviews(merge=True)
|
274 |
case "imdb50k":
|
275 |
return load_imdb50k()
|
276 |
+
case "test":
|
277 |
+
return load_test(include_neutral=False)
|
278 |
case _:
|
279 |
msg = f"Unknown dataset: {dataset}"
|
280 |
raise ValueError(msg)
|
app/gui.py
CHANGED
@@ -7,6 +7,8 @@ from typing import TYPE_CHECKING
|
|
7 |
import gradio as gr
|
8 |
import joblib
|
9 |
|
|
|
|
|
10 |
if TYPE_CHECKING:
|
11 |
from sklearn.base import BaseEstimator
|
12 |
|
@@ -31,7 +33,7 @@ def load_model() -> BaseEstimator:
|
|
31 |
def sentiment_analysis(text: str) -> str:
|
32 |
"""Perform sentiment analysis on the provided text."""
|
33 |
model = load_model()
|
34 |
-
prediction = model
|
35 |
|
36 |
if prediction == 0:
|
37 |
return NEGATIVE_LABEL
|
@@ -52,6 +54,7 @@ demo = gr.Interface(
|
|
52 |
["The movie we watched was boring."],
|
53 |
["This website is amazing!"],
|
54 |
],
|
|
|
55 |
)
|
56 |
|
57 |
|
|
|
7 |
import gradio as gr
|
8 |
import joblib
|
9 |
|
10 |
+
from app.model import infer_model
|
11 |
+
|
12 |
if TYPE_CHECKING:
|
13 |
from sklearn.base import BaseEstimator
|
14 |
|
|
|
33 |
def sentiment_analysis(text: str) -> str:
|
34 |
"""Perform sentiment analysis on the provided text."""
|
35 |
model = load_model()
|
36 |
+
prediction = infer_model(model, [text])[0]
|
37 |
|
38 |
if prediction == 0:
|
39 |
return NEGATIVE_LABEL
|
|
|
54 |
["The movie we watched was boring."],
|
55 |
["This website is amazing!"],
|
56 |
],
|
57 |
+
allow_flagging=False,
|
58 |
)
|
59 |
|
60 |
|
app/model.py
CHANGED
@@ -1,85 +1,25 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
-
import
|
|
|
4 |
|
5 |
import numpy as np
|
6 |
-
import spacy
|
7 |
from joblib import Memory
|
8 |
-
from sklearn.base import BaseEstimator, TransformerMixin
|
9 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
10 |
from sklearn.linear_model import LogisticRegression
|
11 |
from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split
|
12 |
from sklearn.pipeline import Pipeline
|
13 |
-
from tqdm import tqdm
|
14 |
|
15 |
from app.constants import CACHE_DIR
|
|
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
from spacy.cli import download as spacy_download
|
25 |
-
|
26 |
-
spacy_download("en_core_web_sm")
|
27 |
-
nlp = spacy.load("en_core_web_sm", disable=["tok2vec", "parser", "ner"])
|
28 |
-
|
29 |
-
|
30 |
-
class TextTokenizer(BaseEstimator, TransformerMixin):
|
31 |
-
def __init__(
|
32 |
-
self,
|
33 |
-
*,
|
34 |
-
character_threshold: int = 2,
|
35 |
-
batch_size: int = 1024,
|
36 |
-
n_jobs: int = 8,
|
37 |
-
progress: bool = True,
|
38 |
-
) -> None:
|
39 |
-
self.character_threshold = character_threshold
|
40 |
-
self.batch_size = batch_size
|
41 |
-
self.n_jobs = n_jobs
|
42 |
-
self.progress = progress
|
43 |
-
|
44 |
-
def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextTokenizer:
|
45 |
-
return self
|
46 |
-
|
47 |
-
def transform(self, data: list[str]) -> list[list[str]]:
|
48 |
-
tokenized = []
|
49 |
-
for doc in tqdm(
|
50 |
-
nlp.pipe(data, batch_size=self.batch_size, n_process=self.n_jobs),
|
51 |
-
total=len(data),
|
52 |
-
disable=not self.progress,
|
53 |
-
):
|
54 |
-
tokens = []
|
55 |
-
for token in doc:
|
56 |
-
# Ignore stop words and punctuation
|
57 |
-
if token.is_stop or token.is_punct:
|
58 |
-
continue
|
59 |
-
# Ignore emails, URLs and numbers
|
60 |
-
if token.like_email or token.like_email or token.like_num:
|
61 |
-
continue
|
62 |
-
|
63 |
-
# Lemmatize and lowercase
|
64 |
-
tok = token.lemma_.lower().strip()
|
65 |
-
|
66 |
-
# Format hashtags
|
67 |
-
if tok.startswith("#"):
|
68 |
-
tok = tok[1:]
|
69 |
-
|
70 |
-
# Ignore short and non-alphanumeric tokens
|
71 |
-
if len(tok) < self.character_threshold or not tok.isalnum():
|
72 |
-
continue
|
73 |
-
|
74 |
-
# TODO: Emoticons and emojis
|
75 |
-
# TODO: Spelling correction
|
76 |
-
|
77 |
-
tokens.append(tok)
|
78 |
-
tokenized.append(tokens)
|
79 |
-
return tokenized
|
80 |
-
|
81 |
-
|
82 |
-
def identity(x: list[str]) -> list[str]:
|
83 |
"""Identity function for use in TfidfVectorizer.
|
84 |
|
85 |
Args:
|
@@ -101,22 +41,21 @@ def create_model(
|
|
101 |
Args:
|
102 |
max_features: Maximum number of features
|
103 |
seed: Random seed (None for random seed)
|
104 |
-
verbose: Whether to
|
105 |
|
106 |
Returns:
|
107 |
Untrained model
|
108 |
"""
|
109 |
return Pipeline(
|
110 |
[
|
111 |
-
("tokenizer", TextTokenizer(progress=True)),
|
112 |
(
|
113 |
"vectorizer",
|
114 |
TfidfVectorizer(
|
115 |
max_features=max_features,
|
116 |
ngram_range=(1, 2),
|
117 |
# disable text processing
|
118 |
-
tokenizer=
|
119 |
-
preprocessor=
|
120 |
lowercase=False,
|
121 |
token_pattern=None,
|
122 |
),
|
@@ -130,23 +69,27 @@ def create_model(
|
|
130 |
|
131 |
def train_model(
|
132 |
model: BaseEstimator,
|
133 |
-
|
134 |
label_data: list[int],
|
|
|
135 |
seed: int = 42,
|
|
|
136 |
) -> tuple[BaseEstimator, float]:
|
137 |
"""Train the sentiment analysis model.
|
138 |
|
139 |
Args:
|
140 |
model: Untrained model
|
141 |
-
|
142 |
label_data: Label data
|
|
|
143 |
seed: Random seed (None for random seed)
|
|
|
144 |
|
145 |
Returns:
|
146 |
Trained model and accuracy
|
147 |
"""
|
148 |
text_train, text_test, label_train, label_test = train_test_split(
|
149 |
-
|
150 |
label_data,
|
151 |
test_size=0.2,
|
152 |
random_state=seed,
|
@@ -154,50 +97,82 @@ def train_model(
|
|
154 |
|
155 |
param_distributions = {
|
156 |
"classifier__C": np.logspace(-4, 4, 20),
|
157 |
-
"
|
158 |
}
|
159 |
|
160 |
search = RandomizedSearchCV(
|
161 |
model,
|
162 |
param_distributions,
|
163 |
n_iter=10,
|
164 |
-
cv=
|
165 |
scoring="accuracy",
|
166 |
random_state=seed,
|
167 |
n_jobs=-1,
|
|
|
168 |
)
|
169 |
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
search.fit(text_train, label_train)
|
174 |
|
175 |
best_model = search.best_estimator_
|
176 |
return best_model, best_model.score(text_test, label_test)
|
177 |
|
178 |
|
179 |
def evaluate_model(
|
180 |
-
model:
|
181 |
-
|
182 |
label_data: list[int],
|
183 |
folds: int = 5,
|
|
|
184 |
) -> tuple[float, float]:
|
185 |
"""Evaluate the model using cross-validation.
|
186 |
|
187 |
Args:
|
188 |
model: Trained model
|
189 |
-
|
190 |
label_data: Label data
|
191 |
folds: Number of cross-validation folds
|
|
|
192 |
|
193 |
Returns:
|
194 |
Mean accuracy and standard deviation
|
195 |
"""
|
|
|
196 |
scores = cross_val_score(
|
197 |
model,
|
198 |
-
|
199 |
label_data,
|
200 |
cv=folds,
|
201 |
scoring="accuracy",
|
|
|
|
|
202 |
)
|
|
|
203 |
return scores.mean(), scores.std()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
+
import os
|
4 |
+
from typing import TYPE_CHECKING
|
5 |
|
6 |
import numpy as np
|
|
|
7 |
from joblib import Memory
|
|
|
8 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
9 |
from sklearn.linear_model import LogisticRegression
|
10 |
from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split
|
11 |
from sklearn.pipeline import Pipeline
|
|
|
12 |
|
13 |
from app.constants import CACHE_DIR
|
14 |
+
from app.data import tokenize
|
15 |
|
16 |
+
if TYPE_CHECKING:
|
17 |
+
from sklearn.base import BaseEstimator
|
18 |
+
|
19 |
+
__all__ = ["create_model", "train_model", "evaluate_model", "infer_model"]
|
20 |
+
|
21 |
+
|
22 |
+
def _identity(x: list[str]) -> list[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
"""Identity function for use in TfidfVectorizer.
|
24 |
|
25 |
Args:
|
|
|
41 |
Args:
|
42 |
max_features: Maximum number of features
|
43 |
seed: Random seed (None for random seed)
|
44 |
+
verbose: Whether to output additional information
|
45 |
|
46 |
Returns:
|
47 |
Untrained model
|
48 |
"""
|
49 |
return Pipeline(
|
50 |
[
|
|
|
51 |
(
|
52 |
"vectorizer",
|
53 |
TfidfVectorizer(
|
54 |
max_features=max_features,
|
55 |
ngram_range=(1, 2),
|
56 |
# disable text processing
|
57 |
+
tokenizer=_identity,
|
58 |
+
preprocessor=_identity,
|
59 |
lowercase=False,
|
60 |
token_pattern=None,
|
61 |
),
|
|
|
69 |
|
70 |
def train_model(
|
71 |
model: BaseEstimator,
|
72 |
+
token_data: list[str],
|
73 |
label_data: list[int],
|
74 |
+
folds: int = 5,
|
75 |
seed: int = 42,
|
76 |
+
verbose: bool = False,
|
77 |
) -> tuple[BaseEstimator, float]:
|
78 |
"""Train the sentiment analysis model.
|
79 |
|
80 |
Args:
|
81 |
model: Untrained model
|
82 |
+
token_data: Tokenized text data
|
83 |
label_data: Label data
|
84 |
+
folds: Number of cross-validation folds
|
85 |
seed: Random seed (None for random seed)
|
86 |
+
verbose: Whether to output additional information
|
87 |
|
88 |
Returns:
|
89 |
Trained model and accuracy
|
90 |
"""
|
91 |
text_train, text_test, label_train, label_test = train_test_split(
|
92 |
+
token_data,
|
93 |
label_data,
|
94 |
test_size=0.2,
|
95 |
random_state=seed,
|
|
|
97 |
|
98 |
param_distributions = {
|
99 |
"classifier__C": np.logspace(-4, 4, 20),
|
100 |
+
"classifier__solver": ["liblinear", "saga"],
|
101 |
}
|
102 |
|
103 |
search = RandomizedSearchCV(
|
104 |
model,
|
105 |
param_distributions,
|
106 |
n_iter=10,
|
107 |
+
cv=folds,
|
108 |
scoring="accuracy",
|
109 |
random_state=seed,
|
110 |
n_jobs=-1,
|
111 |
+
verbose=verbose,
|
112 |
)
|
113 |
|
114 |
+
os.environ["PYTHONWARNINGS"] = "ignore"
|
115 |
+
search.fit(text_train, label_train)
|
116 |
+
del os.environ["PYTHONWARNINGS"]
|
|
|
117 |
|
118 |
best_model = search.best_estimator_
|
119 |
return best_model, best_model.score(text_test, label_test)
|
120 |
|
121 |
|
122 |
def evaluate_model(
|
123 |
+
model: BaseEstimator,
|
124 |
+
token_data: list[str],
|
125 |
label_data: list[int],
|
126 |
folds: int = 5,
|
127 |
+
verbose: bool = False,
|
128 |
) -> tuple[float, float]:
|
129 |
"""Evaluate the model using cross-validation.
|
130 |
|
131 |
Args:
|
132 |
model: Trained model
|
133 |
+
token_data: Tokenized text data
|
134 |
label_data: Label data
|
135 |
folds: Number of cross-validation folds
|
136 |
+
verbose: Whether to output additional information
|
137 |
|
138 |
Returns:
|
139 |
Mean accuracy and standard deviation
|
140 |
"""
|
141 |
+
os.environ["PYTHONWARNINGS"] = "ignore"
|
142 |
scores = cross_val_score(
|
143 |
model,
|
144 |
+
token_data,
|
145 |
label_data,
|
146 |
cv=folds,
|
147 |
scoring="accuracy",
|
148 |
+
n_jobs=-1,
|
149 |
+
verbose=verbose,
|
150 |
)
|
151 |
+
del os.environ["PYTHONWARNINGS"]
|
152 |
return scores.mean(), scores.std()
|
153 |
+
|
154 |
+
|
155 |
+
def infer_model(
|
156 |
+
model: BaseEstimator,
|
157 |
+
text_data: list[str],
|
158 |
+
batch_size: int = 32,
|
159 |
+
n_jobs: int = 4,
|
160 |
+
) -> list[int]:
|
161 |
+
"""Predict the sentiment of the provided text documents.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
model: Trained model
|
165 |
+
text_data: Text data
|
166 |
+
batch_size: Batch size for tokenization
|
167 |
+
n_jobs: Number of parallel jobs
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
Predicted sentiments
|
171 |
+
"""
|
172 |
+
tokens = tokenize(
|
173 |
+
text_data,
|
174 |
+
batch_size=batch_size,
|
175 |
+
n_jobs=n_jobs,
|
176 |
+
show_progress=False,
|
177 |
+
)
|
178 |
+
return model.predict(tokens)
|