Tymec committed on
Commit d4ef46b · 1 Parent(s): 68bf0ed

Update docstrings and comments

Files changed (6)
  1. app/__main__.py +2 -0
  2. app/constants.py +2 -0
  3. app/data.py +2 -0
  4. app/gui.py +21 -5
  5. app/model.py +18 -13
  6. app/utils.py +4 -0
app/__main__.py CHANGED
@@ -1,3 +1,5 @@
+"""Entry point for the application."""
+
 from __future__ import annotations
 
 from app.cli import cli_wrapper as cli
app/constants.py CHANGED
@@ -1,3 +1,5 @@
+"""Constants used by the application."""
+
 from __future__ import annotations
 
 import os
app/data.py CHANGED
@@ -1,3 +1,5 @@
+"""Functions to load and preprocess text data."""
+
 from __future__ import annotations
 
 import bz2
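The change here is only the new module docstring, but the `bz2` import hints at how the corpus is stored: compressed text read line by line. For reference, a minimal standard-library sketch of that access pattern (`read_bz2_lines` is a hypothetical name, not a function from this repo):

```python
import bz2


def read_bz2_lines(path: str) -> list[str]:
    # Text mode ("rt") decompresses transparently and yields str lines
    with bz2.open(path, mode="rt", encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]
```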
app/gui.py CHANGED
@@ -1,3 +1,5 @@
+"""GUI using Gradio."""
+
 from __future__ import annotations
 
 import os
@@ -22,7 +24,11 @@ NEGATIVE_LABEL = "Negative 😤"
 
 @lru_cache(maxsize=1)
 def load_model() -> BaseEstimator:
-    """Load the trained model and cache it."""
+    """Load the trained model and cache it.
+
+    Returns:
+        Loaded model
+    """
     model_path = os.environ.get("MODEL_PATH", None)
     if model_path is None:
         msg = "MODEL_PATH environment variable not set"
@@ -31,9 +37,15 @@ def load_model() -> BaseEstimator:
 
 
 def sentiment_analysis(text: str) -> str:
-    """Perform sentiment analysis on the provided text."""
-    model = load_model()
-    prediction = infer_model(model, [text])[0]
+    """Perform sentiment analysis on the provided text.
+
+    Args:
+        text: Input text
+
+    Returns:
+        Predicted sentiment label
+    """
+    prediction = infer_model(load_model(), [text])[0]
 
     if prediction == 0:
         return NEGATIVE_LABEL
@@ -59,7 +71,11 @@ demo = gr.Interface(
 
 
 def launch_gui(share: bool) -> None:
-    """Launch the Gradio GUI."""
+    """Launch the Gradio GUI.
+
+    Args:
+        share: Whether to create a public link
+    """
     demo.launch(share=share)
 
 
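The expanded docstring on `load_model` documents the caching behaviour, which is worth spelling out: `@lru_cache(maxsize=1)` on a zero-argument loader means the model is deserialized once per process, and every later call (including each `sentiment_analysis` request) reuses the same object. A self-contained sketch of the pattern, assuming the model file can be read with `joblib.load` (the actual loader is not shown in this diff):

```python
import os
from functools import lru_cache

import joblib  # assumption: the model file is a joblib/pickle dump


@lru_cache(maxsize=1)
def load_model():
    # Only the first call reads the file; later calls return the cached object
    model_path = os.environ.get("MODEL_PATH", None)
    if model_path is None:
        msg = "MODEL_PATH environment variable not set"
        raise RuntimeError(msg)  # error type assumed; the diff only shows msg
    return joblib.load(model_path)


# Every call after the first is a cache hit:
# assert load_model() is load_model()
```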
 
app/model.py CHANGED
@@ -1,3 +1,5 @@
+"""Functions for model training, evaluation, and inference."""
+
 from __future__ import annotations
 
 import warnings
@@ -21,7 +23,7 @@ __all__ = ["train_model", "evaluate_model", "infer_model"]
 
 
 def _identity(x: list[str]) -> list[str]:
-    """Identity function for use in TfidfVectorizer.
+    """Identity function for use in vectorizers.
 
     Args:
         x: Input data
@@ -36,7 +38,6 @@ def _get_vectorizer(
     name: Literal["tfidf", "count", "hashing"],
     n_features: int,
     min_df: int = 5,
-    ngram: tuple[int, int] = (1, 2),
 ) -> TransformerMixin:
     """Get the appropriate vectorizer.
 
@@ -44,7 +45,6 @@
         name: Type of vectorizer
         n_features: Maximum number of features
         min_df: Minimum document frequency (ignored for hashing)
-        ngram: N-gram range [min_n, max_n]
 
     Returns:
         Vectorizer instance
@@ -53,7 +53,7 @@
         ValueError: If the vectorizer is not recognized
     """
     shared_params = {
-        "ngram_range": ngram,
+        "ngram_range": (1, 2),  # unigrams and bigrams
         # disable text processing
         "tokenizer": _identity,
         "preprocessor": _identity,
@@ -96,7 +96,7 @@ def train_model(
     vectorizer: Literal["tfidf", "count", "hashing"],
     max_features: int,
     min_df: int = 5,
-    folds: int = 5,
+    cv: int = 5,
     n_jobs: int = 4,
     seed: int = 42,
 ) -> tuple[BaseEstimator, float]:
@@ -108,7 +108,7 @@
         vectorizer: Which vectorizer to use
         max_features: Maximum number of features
         min_df: Minimum document frequency (ignored for hashing)
-        folds: Number of cross-validation folds
+        cv: Number of cross-validation folds
         n_jobs: Number of parallel jobs
         seed: Random seed (None for random seed)
 
@@ -120,6 +120,7 @@
     """
     rs = None if seed == -1 else seed
 
+    # Split the data into training and testing sets
    text_train, text_test, label_train, label_test = train_test_split(
         token_data,
         label_data,
@@ -127,24 +128,25 @@
         random_state=rs,
     )
 
+    # Create the model pipeline
     vectorizer = _get_vectorizer(vectorizer, max_features, min_df)
     classifier = LogisticRegression(max_iter=1000, random_state=rs)
-    param_dist = {"classifier__C": np.logspace(-4, 4, 20)}
-
     model = Pipeline(
         [("vectorizer", vectorizer), ("classifier", classifier)],
         memory=Memory(CACHE_DIR, verbose=0),
     )
+    param_dist = {"classifier__C": np.logspace(-4, 4, 20)}
 
+    # Perform randomized search for hyperparameter tuning
     search = RandomizedSearchCV(
         model,
         param_dist,
-        cv=folds,
+        cv=cv,
         random_state=rs,
         n_jobs=n_jobs,
-        verbose=2,
         scoring="accuracy",
         n_iter=10,
+        verbose=2,
     )
 
     with warnings.catch_warnings():
@@ -161,7 +163,7 @@ def evaluate_model(
     model: BaseEstimator,
     token_data: Sequence[Sequence[str]],
     label_data: list[int],
-    folds: int = 5,
+    cv: int = 5,
     n_jobs: int = 4,
 ) -> tuple[float, float]:
     """Evaluate the model using cross-validation.
@@ -170,7 +172,7 @@
         model: Trained model
         token_data: Tokenized text data
         label_data: Label data
-        folds: Number of cross-validation folds
+        cv: Number of cross-validation folds
         n_jobs: Number of parallel jobs
 
     Returns:
@@ -178,15 +180,18 @@
     """
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", category=UserWarning, message="Persisting input arguments took")
+
+        # Perform cross-validation to evaluate the model
         scores = cross_val_score(
             model,
             token_data,
             label_data,
+            cv=cv,
             scoring="accuracy",
             n_jobs=n_jobs,
             verbose=2,
         )
+
         return scores.mean(), scores.std()
 
 
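Besides the `folds` → `cv` rename, the hunks above show the overall training pattern: a `Pipeline` of vectorizer plus `LogisticRegression`, with the regularization strength `C` sampled from a log-spaced grid by `RandomizedSearchCV`. A runnable, toy-scale sketch of that pattern (plain strings and a default `TfidfVectorizer` stand in for the repo's pre-tokenized input and identity tokenizer):

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
from sklearn.pipeline import Pipeline

texts = ["good movie", "great film", "terrible plot", "awful acting"] * 10
labels = [1, 1, 0, 0] * 10

# Pipeline steps are addressed in param_dist as "<step name>__<parameter>"
model = Pipeline([
    ("vectorizer", TfidfVectorizer(ngram_range=(1, 2))),
    ("classifier", LogisticRegression(max_iter=1000, random_state=42)),
])
param_dist = {"classifier__C": np.logspace(-4, 4, 20)}

# Sample n_iter values of C and score each candidate with cv-fold CV
search = RandomizedSearchCV(model, param_dist, cv=3, n_iter=5,
                            scoring="accuracy", random_state=42)
search.fit(texts, labels)
print(search.best_params_, round(search.best_score_, 3))
```

Note that `RandomizedSearchCV` refits the best candidate on the full training data by default (`refit=True`), so `search.best_estimator_` is immediately usable afterwards.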
 
app/utils.py CHANGED
@@ -1,3 +1,5 @@
+"""Utility functions for the application."""
+
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
@@ -20,6 +22,7 @@ def serialize(data: Sequence[str | int], path: Path, max_size: int = 100_000, sh
         max_size: The maximum size a chunk can be (in elements)
         show_progress: Whether to show a progress bar
     """
+    # Split the data into chunks of size `max_size` and serialize each one
     for i, chunk in enumerate(
         tqdm(
             [data[i : i + max_size] for i in range(0, len(data), max_size)],
@@ -28,6 +31,7 @@
             disable=not show_progress,
         ),
     ):
+        # The first chunk is saved as `.pkl` and the rest as `.i.pkl`
        fd = path.with_suffix(f".{i}.pkl" if i else ".pkl")
         with fd.open("wb") as f:
             joblib.dump(chunk, f, compress=3)
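The new comment pins down the on-disk naming scheme (first chunk `.pkl`, then `.1.pkl`, `.2.pkl`, ...), which is enough to reconstruct the data. A hypothetical `deserialize` counterpart, not part of this commit, built only on that scheme and `joblib.load`:

```python
from pathlib import Path

import joblib


def deserialize(path: Path) -> list:
    """Load and concatenate the chunks written by serialize()."""
    data: list = []
    i = 0
    while True:
        # Mirror the writer's naming: ".pkl" for chunk 0, ".i.pkl" afterwards
        fd = path.with_suffix(f".{i}.pkl" if i else ".pkl")
        if not fd.exists():
            break
        with fd.open("rb") as f:
            data.extend(joblib.load(f))
        i += 1
    return data
```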