Tymec commited on
Commit
667fe9d
1 Parent(s): 7d4eb47

Restructure project into package structure

Browse files
.vscode/settings.json CHANGED
@@ -2,7 +2,7 @@
2
  "notebook.formatOnSave.enabled": true,
3
  "notebook.codeActionsOnSave": {
4
  "notebook.source.fixAll": "explicit",
5
- "notebook.source.organizeImports": "explicit"
6
  },
7
  "[python]": {
8
  "editor.formatOnSave": true,
 
2
  "notebook.formatOnSave.enabled": true,
3
  "notebook.codeActionsOnSave": {
4
  "notebook.source.fixAll": "explicit",
5
+ "source.organizeImports": "explicit"
6
  },
7
  "[python]": {
8
  "editor.formatOnSave": true,
README.md CHANGED
@@ -6,3 +6,18 @@ Sentiment Analysis
6
  2. `cd` into the repository
7
  3. Run `just install` to install the dependencies
8
  4. Run `just run --help` to see the available commands
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  2. `cd` into the repository
7
  3. Run `just install` to install the dependencies
8
  4. Run `just run --help` to see the available commands
9
+
10
+
11
+ ### TODO
12
+ - [ ] CLI using `click` (commands: predict, train, evaluate) with settings set via flags or environment variables
13
+ - [ ] GUI using `gradio` (tabs: predict, train, evaluate, compare, settings)
14
+ - [ ] For the sklearn model, add more classifiers
15
+ - [ ] Use random search for hyperparameter tuning and grid search for fine-tuning
16
+ - [ ] Finish the text pre-processing transformer
17
+ - [ ] For vectorization, use custom stopwords
18
+ - [ ] Write own tokenizer/vectorizer
19
+ - [ ] Add more datasets
20
+ - [ ] Add more models (e.g. BERT)
21
+ - [ ] Write tests
22
+ - [ ] Use xgboost?
23
+ - [ ] Deploy to huggingface?
app/__init__.py ADDED
File without changes
app/constants.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ DEFAULT_SEED: int = 42
4
+ MAX_TOKENIZER_FEATURES: int = 500000
5
+ CLF_MAX_ITER: int = 1000
6
+
7
+ DATASET_PATH: Path = Path("data/training.1600000.processed.noemoticon.csv")
8
+ STOPWORDS_PATH: Path = Path("data/stopwords-en.txt")
9
+ MODELS_DIR: Path = Path("models")
10
+ CACHE_DIR: Path = Path("cache")
11
+ CHECKPOINT_PATH: Path = CACHE_DIR / "pipeline.pkl"
12
+
13
+
14
+ # Create directories if they don't exist
15
+ MODELS_DIR.mkdir(parents=True, exist_ok=True)
16
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
app/gui.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+
7
+ from constants import MODELS_DIR
8
+ from model import predict, tokenize
9
+
10
+ CSS_PATH = Path("style.css")
11
+ TOKENIZER_EXT = ".tokenizer.pkl"
12
+ MODEL_EXT = ".model.pkl"
13
+ POSITIVE_LABEL = "Positive 😊"
14
+ NEGATIVE_LABEL = "Negative 😤"
15
+ REFRESH_SYMBOL = "🔄"
16
+
17
+
18
+ def load_style() -> str:
19
+ if not CSS_PATH.is_file():
20
+ return ""
21
+
22
+ with Path.open(CSS_PATH) as f:
23
+ return f.read()
24
+
25
+
26
+ def predict_wrapper(text: str, tokenizer: str, model: str) -> str:
27
+ toks = tokenize(text, MODELS_DIR / f"{tokenizer}{TOKENIZER_EXT}")
28
+ pred = predict(toks, MODELS_DIR / f"{model}{MODEL_EXT}")
29
+ return POSITIVE_LABEL if pred else NEGATIVE_LABEL
30
+
31
+
32
+ def train_wrapper() -> None:
33
+ msg = "Training is not supported in the GUI."
34
+ raise NotImplementedError(msg)
35
+
36
+
37
+ def evaluate_wrapper() -> None:
38
+ msg = "Evaluation is not supported in the GUI."
39
+ raise NotImplementedError(msg)
40
+
41
+
42
+ with gr.Blocks(css=load_style()) as demo:
43
+ gr.Markdown("## Sentiment Analysis")
44
+
45
+ with gr.Row(equal_height=True):
46
+ textbox = gr.Textbox(
47
+ lines=10,
48
+ label="Enter text to analyze",
49
+ placeholder="Enter text here",
50
+ key="input-textbox",
51
+ )
52
+
53
+ with gr.Column():
54
+ output = gr.Label()
55
+
56
+ with gr.Row(elem_classes="justify-between"):
57
+ clear_btn = gr.ClearButton([textbox, output], value="Clear 🧹")
58
+ analyze_btn = gr.Button(
59
+ "Analyze 🔍",
60
+ variant="primary",
61
+ interactive=False,
62
+ )
63
+
64
+ with gr.Row():
65
+ tokenizer_selector = gr.Dropdown(
66
+ choices=[tkn.stem[: -len(".tokenizer")] for tkn in MODELS_DIR.glob(f"*{TOKENIZER_EXT}")],
67
+ label="Tokenizer",
68
+ key="tokenizer-selector",
69
+ )
70
+
71
+ model_selector = gr.Dropdown(
72
+ choices=[mdl.stem[: -len(".model")] for mdl in MODELS_DIR.glob(f"*{MODEL_EXT}")],
73
+ label="Model",
74
+ key="model-selector",
75
+ )
76
+
77
+ # TODO: Refresh button
78
+
79
+ # Event handlers
80
+ textbox.input(
81
+ fn=lambda text: gr.update(interactive=bool(text.strip())),
82
+ inputs=[textbox],
83
+ outputs=[analyze_btn],
84
+ )
85
+ analyze_btn.click(
86
+ fn=predict_wrapper,
87
+ inputs=[textbox, tokenizer_selector, model_selector],
88
+ outputs=[output],
89
+ )
90
+
91
+ demo.queue()
92
+ demo.launch()
app/model.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from functools import lru_cache
5
+ from typing import TYPE_CHECKING, Sequence
6
+
7
+ import joblib
8
+ from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
9
+ from sklearn.linear_model import LogisticRegression
10
+ from sklearn.pipeline import Pipeline
11
+
12
+ from constants import CLF_MAX_ITER, MAX_TOKENIZER_FEATURES
13
+ from utils import get_cache_memory, get_random_state
14
+
15
+ if TYPE_CHECKING:
16
+ from pathlib import Path
17
+
18
+ from numpy import ndarray
19
+ from numpy.random import RandomState
20
+
21
+
22
+ __all__ = ["predict", "tokenize"]
23
+
24
+
25
+ @lru_cache(maxsize=1)
26
+ def get_model(model_path: Path) -> Pipeline:
27
+ return joblib.load(model_path)
28
+
29
+
30
+ @lru_cache(maxsize=1)
31
+ def get_tokenizer(tokenizer_path: Path) -> Pipeline:
32
+ return joblib.load(tokenizer_path)
33
+
34
+
35
+ def export_to_file(pipeline: Pipeline, path: Path) -> None:
36
+ joblib.dump(pipeline, path)
37
+
38
+
39
+ def tokenize(text: str, tokenizer_path: Path) -> ndarray:
40
+ tokenizer = get_tokenizer(tokenizer_path)
41
+ return tokenizer.transform([text])[0]
42
+
43
+
44
+ def predict(tokens: ndarray, model_path: Path) -> bool:
45
+ model = get_model(model_path)
46
+ prediction = model.predict([tokens])
47
+ return prediction[0] == 1
48
+
49
+
50
+ def train_and_export(
51
+ steps: Sequence[tuple],
52
+ x: list[str],
53
+ y: list[int],
54
+ export_path: Path,
55
+ cache: joblib.Memory,
56
+ ) -> Pipeline:
57
+ pipeline = Pipeline(steps, memory=cache)
58
+
59
+ with warnings.catch_warnings():
60
+ warnings.simplefilter("ignore")
61
+ pipeline.fit(x, y)
62
+
63
+ export_to_file(pipeline, export_path)
64
+ return pipeline
65
+
66
+
67
+ def train_tokenizer_and_export(x: list[str], y: list[int], export_path: Path, cache: joblib.Memory) -> Pipeline:
68
+ return train_and_export(
69
+ [
70
+ (
71
+ "vectorize",
72
+ CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=MAX_TOKENIZER_FEATURES),
73
+ ),
74
+ ("tfidf", TfidfTransformer()),
75
+ ],
76
+ x,
77
+ y,
78
+ export_path,
79
+ cache,
80
+ )
81
+
82
+
83
+ def train_model_and_export(
84
+ x: ndarray,
85
+ y: list[int],
86
+ export_path: Path,
87
+ cache: joblib.Memory,
88
+ rs: RandomState,
89
+ ) -> Pipeline:
90
+ return train_and_export(
91
+ [("clf", LogisticRegression(max_iter=CLF_MAX_ITER, random_state=rs))],
92
+ x,
93
+ y,
94
+ export_path,
95
+ cache,
96
+ )
97
+
98
+
99
+ def train(x: list[str], y: list[int]) -> Pipeline:
100
+ cache = get_cache_memory()
101
+ rs = get_random_state()
102
+
103
+ tokenizer = train_tokenizer(x, y, cache)
104
+ x_tr = tokenizer.transform(x)
105
+
106
+ model = train_model(x_tr, y, cache, rs)
107
+
108
+ return Pipeline([("tokenizer", tokenizer), ("model", model)])
109
+
110
+
111
+ def train_tokenizer(x: list[str], y: list[int], cache: joblib.Memory) -> Pipeline:
112
+ # TODO: In the future, allow for different tokenizers
113
+ pipeline = Pipeline(
114
+ [
115
+ (
116
+ "vectorize",
117
+ CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=MAX_TOKENIZER_FEATURES),
118
+ ),
119
+ ("tfidf", TfidfTransformer()),
120
+ ],
121
+ memory=cache,
122
+ )
123
+
124
+ with warnings.catch_warnings():
125
+ warnings.simplefilter("ignore") # Ignore joblib warnings
126
+ pipeline.fit(x, y)
127
+
128
+ return pipeline
129
+
130
+
131
+ def train_model(x: list[str], y: list[int], cache: joblib.Memory, rs: RandomState) -> Pipeline:
132
+ # TODO: In the future, allow for different classifiers
133
+ pipeline = Pipeline(
134
+ [
135
+ ("clf", LogisticRegression(max_iter=CLF_MAX_ITER, random_state=rs)),
136
+ ],
137
+ memory=cache,
138
+ )
139
+
140
+ with warnings.catch_warnings():
141
+ warnings.simplefilter("ignore") # Ignore joblib warnings
142
+ pipeline.fit(x, y)
143
+
144
+ return pipeline
app/utils.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import itertools
6
+ import re
7
+ import warnings
8
+ from collections import deque
9
+ from enum import Enum
10
+ from functools import lru_cache
11
+ from threading import Event, Lock
12
+ from typing import Any
13
+
14
+ from joblib import Memory
15
+ from numpy.random import RandomState
16
+
17
+ from constants import CACHE_DIR, DEFAULT_SEED
18
+
19
+ __all__ = ["colorize", "wrap_queued_call", "get_random_state", "get_cache_memory"]
20
+
21
+
22
+ ANSI_RESET = 0
23
+
24
+
25
+ class Color(Enum):
26
+ """ANSI color codes."""
27
+
28
+ BLACK = 30
29
+ RED = 31
30
+ GREEN = 32
31
+ YELLOW = 33
32
+ BLUE = 34
33
+ MAGENTA = 35
34
+ CYAN = 36
35
+ WHITE = 37
36
+
37
+
38
+ class Style(Enum):
39
+ """ANSI style codes."""
40
+
41
+ BOLD = 1
42
+ DIM = 2
43
+ ITALIC = 3
44
+ UNDERLINE = 4
45
+ BLINK = 5
46
+ INVERTED = 7
47
+ HIDDEN = 8
48
+
49
+
50
+ # https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a
51
+ class FIFOLock:
52
+ def __init__(self):
53
+ self._lock = Lock()
54
+ self._inner_lock = Lock()
55
+ self._pending_threads = deque()
56
+
57
+ def acquire(self, blocking: bool = True) -> bool:
58
+ with self._inner_lock:
59
+ lock_acquired = self._lock.acquire(False)
60
+ if lock_acquired:
61
+ return True
62
+ if not blocking:
63
+ return False
64
+
65
+ release_event = Event()
66
+ self._pending_threads.append(release_event)
67
+
68
+ release_event.wait()
69
+ return self._lock.acquire()
70
+
71
+ def release(self) -> None:
72
+ with self._inner_lock:
73
+ if self._pending_threads:
74
+ release_event = self._pending_threads.popleft()
75
+ release_event.set()
76
+
77
+ self._lock.release()
78
+
79
+ __enter__ = acquire
80
+
81
+ def __exit__(self, _t, _v, _tb): # noqa: ANN001
82
+ self.release()
83
+
84
+
85
+ @lru_cache(maxsize=1)
86
+ def get_queue_lock() -> FIFOLock:
87
+ return FIFOLock()
88
+
89
+
90
+ @lru_cache(maxsize=1)
91
+ def get_random_state(seed: int = DEFAULT_SEED) -> RandomState:
92
+ return RandomState(seed)
93
+
94
+
95
+ @lru_cache(maxsize=1)
96
+ def get_cache_memory() -> Memory:
97
+ return Memory(CACHE_DIR, verbose=0)
98
+
99
+
100
+ def to_ansi(code: int) -> str:
101
+ """Convert an integer to an ANSI escape code."""
102
+ return f"\033[{code}m"
103
+
104
+
105
+ @lru_cache(maxsize=None)
106
+ def get_ansi_color(color: Color, bright: bool = False, background: bool = False) -> str:
107
+ """Get ANSI color code for the specified color, brightness and background."""
108
+ code = color.value
109
+ if bright:
110
+ code += 60
111
+ if background:
112
+ code += 10
113
+ return to_ansi(code)
114
+
115
+
116
+ def replace_color_tag(color: Color, text: str) -> None:
117
+ """Replace both dark and light color tags for background and foreground."""
118
+ for bright, bg in itertools.product([False, True], repeat=2):
119
+ tag = f"{'BG_' if bg else ''}{'BRIGHT_' if bright else ''}{color.name}"
120
+ text = text.replace(f"[{tag}]", get_ansi_color(color, bright=bright, background=bg))
121
+ text = text.replace(f"[/{tag}]", to_ansi(ANSI_RESET))
122
+
123
+ return text
124
+
125
+
126
+ @lru_cache(maxsize=256)
127
+ def colorize(text: str, strip: bool = True) -> str:
128
+ """Format text with ANSI color codes using tags [COLOR], [BG_COLOR] and [STYLE].
129
+ Reset color/style with [/TAG].
130
+ Escape with double brackets [[]]. Strip leading and trailing whitespace if strip=True.
131
+ """
132
+
133
+ # replace foreground and background color tags
134
+ for color in Color:
135
+ text = replace_color_tag(color, text)
136
+
137
+ # replace style tags
138
+ for style in Style:
139
+ text = text.replace(f"[{style.name}]", to_ansi(style.value)).replace(f"[/{style.name}]", to_ansi(ANSI_RESET))
140
+
141
+ # if there are any tags left, remove them and throw a warning
142
+ pat1 = re.compile(r"((?<!\[)\[)([^\[\]]*)(\](?!\]))")
143
+ for match in pat1.finditer(text):
144
+ color = match.group(1)
145
+ text = text.replace(match.group(0), "")
146
+ warnings.warn(f"Invalid color tag: {color!r}", UserWarning, stacklevel=2)
147
+
148
+ # escape double brackets
149
+ pat2 = re.compile(r"\[\[[^\[\]\v]+\]\]")
150
+ text = pat2.sub("", text)
151
+
152
+ # reset color/style at the end
153
+ text += to_ansi(ANSI_RESET)
154
+
155
+ return text.strip() if strip else text
156
+
157
+
158
+ # https://github.com/AUTOMATIC1111/stable-diffusion-webui/modules/call_queue.py
159
+ def wrap_queued_call(func: callable) -> callable:
160
+ def f(*args, **kwargs) -> Any: # noqa: ANN003, ANN002
161
+ with get_queue_lock():
162
+ return func(*args, **kwargs)
163
+
164
+ return f
deprecated/__init__.py ADDED
File without changes
deprecated/main.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import click
6
+ import joblib
7
+
8
+ from app.utils import colorize
9
+
10
+
11
+ @click.group()
12
+ def cli() -> None: ...
13
+
14
+
15
+ @cli.command("predict")
16
+ @click.option(
17
+ "-m",
18
+ "--model",
19
+ "model_path",
20
+ default="models/model.pkl",
21
+ help="Path to the model file.",
22
+ type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True, path_type=Path),
23
+ )
24
+ @click.argument("text", nargs=-1)
25
+ def predict(model_path: Path, text: list[str]) -> None:
26
+ input_text = " ".join(text).strip()
27
+ if not input_text:
28
+ click.echo("[RED]Error[/RED]: Input text is empty.")
29
+ return
30
+
31
+ # Load the model
32
+ click.echo("Loading model... ", nl=False)
33
+ model = joblib.load(model_path)
34
+ click.echo(colorize("[GREEN]DONE"))
35
+
36
+ # Run the model
37
+ click.echo("Performing sentiment analysis... ", nl=False)
38
+ prediction = model.predict([input_text])
39
+ sentiment = "[GREEN]POSITIVE" if prediction[0] == 1 else "[RED]NEGATIVE"
40
+ click.echo(colorize(sentiment))
41
+
42
+
43
+ if __name__ == "__main__":
44
+ cli()
deprecated/train.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING
6
+
7
+ import click
8
+ import joblib
9
+ import pandas as pd
10
+ from numpy.random import RandomState
11
+ from sklearn.feature_extraction.text import TfidfVectorizer
12
+ from sklearn.linear_model import LogisticRegression
13
+ from sklearn.metrics import accuracy_score, classification_report
14
+ from sklearn.model_selection import train_test_split
15
+ from sklearn.pipeline import Pipeline
16
+
17
+ if TYPE_CHECKING:
18
+ from sklearn.base import BaseEstimator
19
+
20
+ SEED = 42
21
+ DATASET_PATH = Path("data/training.1600000.processed.noemoticon.csv")
22
+ STOPWORDS_PATH = Path("data/stopwords-en.txt")
23
+ CHECKPOINT_PATH = Path("cache/pipeline.pkl")
24
+ MODELS_DIR = Path("models")
25
+ CACHE_DIR = Path("cache")
26
+ MAX_FEATURES = 10000 # 500000
27
+
28
+ # Make sure paths exist
29
+ MODELS_DIR.mkdir(parents=True, exist_ok=True)
30
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
31
+
32
+ # Memory cache for sklearn pipelines
33
+ mem = joblib.Memory(CACHE_DIR, verbose=0)
34
+
35
+ # TODO: use xgboost
36
+
37
+
38
+ def get_random_state(seed: int = SEED) -> RandomState:
39
+ return RandomState(seed)
40
+
41
+
42
+ def load_data() -> tuple[list[str], list[int]]:
43
+ """The model takes in a list of strings and a list of integers where 1 is positive sentiment and 0 is negative sentiment."""
44
+ data = pd.read_csv(
45
+ DATASET_PATH,
46
+ encoding="ISO-8859-1",
47
+ names=[
48
+ "target", # 0 = negative, 2 = neutral, 4 = positive
49
+ "id", # The id of the tweet
50
+ "date", # The date of the tweet
51
+ "flag", # The query, NO_QUERY if not present
52
+ "user", # The user that tweeted
53
+ "text", # The text of the tweet
54
+ ],
55
+ )
56
+
57
+ # Ignore rows with neutral sentiment
58
+ data = data[data["target"] != 2]
59
+
60
+ # Create new column called "sentiment" with 1 for positive and 0 for negative
61
+ data["sentiment"] = data["target"] == 4
62
+
63
+ # Drop the columns we don't need
64
+ # data = data.drop(columns=["target", "id", "date", "flag", "user"]) # NOTE: No need, since we return the columns we need
65
+
66
+ # Return as lists
67
+ return list(data["text"]), list(data["sentiment"])
68
+
69
+
70
+ def create_pipeline(clf: BaseEstimator) -> Pipeline:
71
+ return Pipeline(
72
+ [
73
+ # Preprocess
74
+ # ("vectorize", CountVectorizer(stop_words="english", ngram_range=(1, 2), max_features=MAX_FEATURES)),
75
+ # ("tfidf", TfidfTransformer()),
76
+ ("vectorize", TfidfVectorizer(ngram_range=(1, 2), max_features=MAX_FEATURES)),
77
+ # Classifier
78
+ ("clf", clf),
79
+ ],
80
+ memory=mem,
81
+ )
82
+
83
+
84
+ def evaluate_pipeline(pipeline: Pipeline, x: list[str], y: list[int]) -> float:
85
+ y_pred = pipeline.predict(x)
86
+ report = classification_report(y, y_pred)
87
+ click.echo(report)
88
+
89
+ # TODO: Confusion matrix
90
+
91
+ return accuracy_score(y, y_pred)
92
+
93
+
94
+ def export_pipeline(pipeline: Pipeline, name: str) -> None:
95
+ model_path = MODELS_DIR / f"{name}.pkl"
96
+ joblib.dump(pipeline, model_path)
97
+ click.echo(f"Model exported to {model_path!r}")
98
+
99
+
100
+ @click.command()
101
+ @click.option("--retrain", is_flag=True, help="Train the model even if a checkpoint exists.")
102
+ @click.option("--evaluate", is_flag=True, help="Evaluate the model.")
103
+ @click.option("--flush-cache", is_flag=True, help="Clear sklearn cache.")
104
+ @click.option("--seed", type=int, default=SEED, help="Random seed.")
105
+ def train(retrain: bool, evaluate: bool, flush_cache: bool, seed: int) -> None:
106
+ rng = get_random_state(seed)
107
+
108
+ # Clear sklearn cache
109
+ if flush_cache:
110
+ click.echo("Clearing cache... ", nl=False)
111
+ mem.clear(warn=False)
112
+ click.echo("DONE")
113
+
114
+ # Load and split data
115
+ click.echo("Loading data... ", nl=False)
116
+ x, y = load_data()
117
+ x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=rng)
118
+ click.echo("DONE")
119
+
120
+ # Train model
121
+ if retrain or not CHECKPOINT_PATH.exists():
122
+ click.echo("Training model... ", nl=False)
123
+ clf = LogisticRegression(max_iter=1000, random_state=rng)
124
+ model = create_pipeline(clf)
125
+ with warnings.catch_warnings():
126
+ warnings.simplefilter("ignore") # Ignore joblib warnings
127
+ model.fit(x_train, y_train)
128
+ joblib.dump(model, CHECKPOINT_PATH)
129
+ click.echo("DONE")
130
+ else:
131
+ click.echo("Loading model... ", nl=False)
132
+ model = joblib.load(CHECKPOINT_PATH)
133
+ click.echo("DONE")
134
+
135
+ # Evaluate model
136
+ if evaluate:
137
+ evaluate_pipeline(model, x_test, y_test)
138
+
139
+ # Quick test
140
+ test_text = ["I love this movie", "I hate this movie"]
141
+ click.echo("Quick test:")
142
+ for text in test_text:
143
+ click.echo(f"\t{'positive' if model.predict([text])[0] else 'negative'}: {text}")
144
+
145
+ # Export model
146
+ click.echo("Exporting model... ", nl=False)
147
+ export_pipeline(model, "logistic_regression")
148
+ click.echo("DONE")
149
+
150
+
151
+ if __name__ == "__main__":
152
+ train()
pyproject.toml CHANGED
@@ -108,6 +108,7 @@ ignore = [
108
  "PERF203", # ignore for now; investigate
109
  "T201", # print
110
  "ANN204", # missing-return-type-special-method
 
111
  ]
112
  select = ["ALL"]
113
  # Allow unused variables when underscore-prefixed
 
108
  "PERF203", # ignore for now; investigate
109
  "T201", # print
110
  "ANN204", # missing-return-type-special-method
111
+ "ERA001", # commented-out-code
112
  ]
113
  select = ["ALL"]
114
  # Allow unused variables when underscore-prefixed
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .justify-between {
2
+ justify-content: space-between;
3
+ }