Tymec committed on
Commit a092d54
1 Parent(s): 9a96b6b

Use spacy instead of nltk and move data functions to separate module

Files changed (2)
  1. app/data.py +171 -0
  2. app/model.py +94 -247
app/data.py ADDED
@@ -0,0 +1,171 @@
+ from __future__ import annotations
+
+ import bz2
+ from typing import Literal
+
+ import pandas as pd
+
+ from app.constants import (
+     AMAZONREVIEWS_PATH,
+     AMAZONREVIEWS_URL,
+     IMDB50K_PATH,
+     IMDB50K_URL,
+     SENTIMENT140_PATH,
+     SENTIMENT140_URL,
+ )
+
+ __all__ = ["load_data"]
+
+
+ def load_sentiment140(include_neutral: bool = False) -> tuple[list[str], list[int]]:
+     """Load the sentiment140 dataset and make it suitable for use.
+
+     Args:
+         include_neutral: Whether to include neutral sentiment
+
+     Returns:
+         Text and label data
+
+     Raises:
+         FileNotFoundError: If the dataset is not found
+     """
+     # Check if the dataset exists
+     if not SENTIMENT140_PATH.exists():
+         msg = (
+             f"Sentiment140 dataset not found at: '{SENTIMENT140_PATH}'\n"
+             "Please download the dataset from:\n"
+             f"{SENTIMENT140_URL}"
+         )
+         raise FileNotFoundError(msg)
+
+     # Load the dataset
+     data = pd.read_csv(
+         SENTIMENT140_PATH,
+         encoding="ISO-8859-1",
+         names=[
+             "target",  # 0 = negative, 2 = neutral, 4 = positive
+             "id",  # The id of the tweet
+             "date",  # The date of the tweet
+             "flag",  # The query, NO_QUERY if not present
+             "user",  # The user that tweeted
+             "text",  # The text of the tweet
+         ],
+     )
+
+     # Ignore rows with neutral sentiment
+     if not include_neutral:
+         data = data[data["target"] != 2]
+
+     # Map sentiment values
+     data["sentiment"] = data["target"].map(
+         {
+             0: 0,  # Negative
+             4: 1,  # Positive
+             2: 2,  # Neutral
+         },
+     )
+
+     # Return as lists
+     return data["text"].tolist(), data["sentiment"].tolist()
+
+
+ def load_amazonreviews(merge: bool = True) -> tuple[list[str], list[int]]:
+     """Load the amazonreviews dataset and make it suitable for use.
+
+     Args:
+         merge: Whether to merge the test and train datasets (otherwise ignore test)
+
+     Returns:
+         Text and label data
+
+     Raises:
+         FileNotFoundError: If the dataset is not found
+     """
+     # Check if the dataset exists
+     test_exists = AMAZONREVIEWS_PATH[0].exists() or not merge
+     train_exists = AMAZONREVIEWS_PATH[1].exists()
+     if not (test_exists and train_exists):
+         msg = (
+             f"Amazonreviews dataset not found at: '{AMAZONREVIEWS_PATH[0]}' and '{AMAZONREVIEWS_PATH[1]}'\n"
+             "Please download the dataset from:\n"
+             f"{AMAZONREVIEWS_URL}"
+         )
+         raise FileNotFoundError(msg)
+
+     # Load the datasets
+     with bz2.BZ2File(AMAZONREVIEWS_PATH[1]) as train_file:
+         train_data = [line.decode("utf-8") for line in train_file]
+
+     test_data = []
+     if merge:
+         with bz2.BZ2File(AMAZONREVIEWS_PATH[0]) as test_file:
+             test_data = [line.decode("utf-8") for line in test_file]
+
+     # Merge the datasets
+     data = train_data + test_data
+
+     # Split the data into labels and text
+     labels, texts = zip(*(line.split(" ", 1) for line in data))
+
+     # Map sentiment values
+     sentiments = [int(label.split("__label__")[1]) - 1 for label in labels]
+
+     # Return as lists
+     return texts, sentiments
+
+
+ def load_imdb50k() -> tuple[list[str], list[int]]:
+     """Load the imdb50k dataset and make it suitable for use.
+
+     Returns:
+         Text and label data
+
+     Raises:
+         FileNotFoundError: If the dataset is not found
+     """
+     # Check if the dataset exists
+     if not IMDB50K_PATH.exists():
+         msg = (
+             f"IMDB50K dataset not found at: '{IMDB50K_PATH}'\n"
+             "Please download the dataset from:\n"
+             f"{IMDB50K_URL}"
+         )  # fmt: off
+         raise FileNotFoundError(msg)
+
+     # Load the dataset
+     data = pd.read_csv(IMDB50K_PATH)
+
+     # Map sentiment values
+     data["sentiment"] = data["sentiment"].map(
+         {
+             "positive": 1,
+             "negative": 0,
+         },
+     )
+
+     # Return as lists
+     return data["review"].tolist(), data["sentiment"].tolist()
+
+
+ def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> tuple[list[str], list[int]]:
+     """Load and preprocess the specified dataset.
+
+     Args:
+         dataset: Dataset to load
+
+     Returns:
+         Text and label data
+
+     Raises:
+         ValueError: If the dataset is not recognized
+     """
+     match dataset:
+         case "sentiment140":
+             return load_sentiment140(include_neutral=False)
+         case "amazonreviews":
+             return load_amazonreviews(merge=True)
+         case "imdb50k":
+             return load_imdb50k()
+         case _:
+             msg = f"Unknown dataset: {dataset}"
+             raise ValueError(msg)
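A minimal usage sketch of the new module, assuming the chosen dataset has already been downloaded to the path configured in app.constants (the loaders raise FileNotFoundError otherwise):

from app.data import load_data

# "imdb50k" is one of the three accepted dataset names; the CSV must already
# exist at IMDB50K_PATH for this call to succeed.
texts, labels = load_data("imdb50k")
print(len(texts), len(labels))  # parallel lists; labels are 1 (positive) or 0 (negative)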
app/model.py CHANGED
@@ -1,250 +1,86 @@
  from __future__ import annotations

- import bz2
- import re
  import warnings
- from typing import Literal

- import nltk
- import pandas as pd
+ import numpy as np
+ import spacy
  from joblib import Memory
- from nltk.corpus import stopwords
- from nltk.stem import WordNetLemmatizer
  from sklearn.base import BaseEstimator, TransformerMixin
- from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
+ from sklearn.feature_extraction.text import TfidfVectorizer
  from sklearn.linear_model import LogisticRegression
- from sklearn.model_selection import cross_val_score, train_test_split
+ from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split
  from sklearn.pipeline import Pipeline
+ from tqdm import tqdm

- from app.constants import (
-     AMAZONREVIEWS_PATH,
-     AMAZONREVIEWS_URL,
-     CACHE_DIR,
-     EMOTICON_MAP,
-     IMDB50K_PATH,
-     IMDB50K_URL,
-     SENTIMENT140_PATH,
-     SENTIMENT140_URL,
-     URL_REGEX,
- )
+ from app.constants import CACHE_DIR

- __all__ = ["load_data", "create_model", "train_model", "evaluate_model"]
+ __all__ = ["create_model", "train_model", "evaluate_model"]

+ nlp = spacy.load("en_core_web_sm", disable=["tok2vec", "parser", "ner"])

- class TextCleaner(BaseEstimator, TransformerMixin):
+
+ class TextTokenizer(BaseEstimator, TransformerMixin):
      def __init__(
          self,
          *,
-         replace_url: bool = True,
-         replace_hashtag: bool = True,
-         replace_emoticon: bool = True,
-         replace_emoji: bool = True,
-         lowercase: bool = True,
          character_threshold: int = 2,
-         remove_special_characters: bool = True,
-         remove_extra_spaces: bool = True,
-     ):
-         self.replace_url = replace_url
-         self.replace_hashtag = replace_hashtag
-         self.replace_emoticon = replace_emoticon
-         self.replace_emoji = replace_emoji
-         self.lowercase = lowercase
+         batch_size: int = 1024,
+         n_jobs: int = 8,
+         progress: bool = True,
+     ) -> None:
          self.character_threshold = character_threshold
-         self.remove_special_characters = remove_special_characters
-         self.remove_extra_spaces = remove_extra_spaces
-
-     def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextCleaner:
-         return self
+         self.batch_size = batch_size
+         self.n_jobs = n_jobs
+         self.progress = progress

-     def transform(self, data: list[str], _labels: list[int] | None = None) -> list[str]:
-         # Replace URLs, hashtags, emoticons, and emojis
-         data = [re.sub(URL_REGEX, "URL", text) for text in data] if self.replace_url else data
-         data = [re.sub(r"#\w+", "HASHTAG", text) for text in data] if self.replace_hashtag else data
-
-         # Replace emoticons
-         if self.replace_emoticon:
-             for word, emoticons in EMOTICON_MAP.items():
-                 for emoticon in emoticons:
-                     data = [text.replace(emoticon, f"EMOTE_{word}") for text in data]
-
-         # Basic text cleaning
-         data = [text.lower() for text in data] if self.lowercase else data  # Lowercase
-         threshold_pattern = re.compile(rf"\b\w{{1,{self.character_threshold}}}\b")
-         data = (
-             [re.sub(threshold_pattern, "", text) for text in data] if self.character_threshold > 0 else data
-         )  # Remove short words
-         data = (
-             [re.sub(r"[^a-zA-Z0-9\s]", "", text) for text in data] if self.remove_special_characters else data
-         )  # Remove special characters
-         data = [re.sub(r"\s+", " ", text) for text in data] if self.remove_extra_spaces else data  # Remove extra spaces
-
-         # Remove leading and trailing whitespace
-         return [text.strip() for text in data]
-
-
- class TextLemmatizer(BaseEstimator, TransformerMixin):
-     def __init__(self):
-         self.lemmatizer = WordNetLemmatizer()
-
-     def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextLemmatizer:
+     def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextTokenizer:
          return self

-     def transform(self, data: list[str], _labels: list[int] | None = None) -> list[str]:
-         return [self.lemmatizer.lemmatize(text) for text in data]
-
-
- def load_sentiment140(include_neutral: bool = False) -> tuple[list[str], list[int]]:
-     """Load the sentiment140 dataset and make it suitable for use.
-
-     Args:
-         include_neutral: Whether to include neutral sentiment
-
-     Returns:
-         Text and label data
-
-     Raises:
-         FileNotFoundError: If the dataset is not found
-     """
-     # Check if the dataset exists
-     if not SENTIMENT140_PATH.exists():
-         msg = (
-             f"Sentiment140 dataset not found at: '{SENTIMENT140_PATH}'\n"
-             "Please download the dataset from:\n"
-             f"{SENTIMENT140_URL}"
-         )
-         raise FileNotFoundError(msg)
-
-     # Load the dataset
-     data = pd.read_csv(
-         SENTIMENT140_PATH,
-         encoding="ISO-8859-1",
-         names=[
-             "target",  # 0 = negative, 2 = neutral, 4 = positive
-             "id",  # The id of the tweet
-             "date",  # The date of the tweet
-             "flag",  # The query, NO_QUERY if not present
-             "user",  # The user that tweeted
-             "text",  # The text of the tweet
-         ],
-     )
-
-     # Ignore rows with neutral sentiment
-     if not include_neutral:
-         data = data[data["target"] != 2]
-
-     # Map sentiment values
-     data["sentiment"] = data["target"].map(
-         {
-             0: 0,  # Negative
-             4: 1,  # Positive
-             2: 2,  # Neutral
-         },
-     )
-
-     # Return as lists
-     return data["text"].tolist(), data["sentiment"].tolist()
-
-
- def load_amazonreviews(merge: bool = True) -> tuple[list[str], list[int]]:
-     """Load the amazonreviews dataset and make it suitable for use.
-
-     Args:
-         merge: Whether to merge the test and train datasets (otherwise ignore test)
-
-     Returns:
-         Text and label data
-
-     Raises:
-         FileNotFoundError: If the dataset is not found
-     """
-     # Check if the dataset exists
-     test_exists = AMAZONREVIEWS_PATH[0].exists() or not merge
-     train_exists = AMAZONREVIEWS_PATH[1].exists()
-     if not (test_exists and train_exists):
-         msg = (
-             f"Amazonreviews dataset not found at: '{AMAZONREVIEWS_PATH[0]}' and '{AMAZONREVIEWS_PATH[1]}'\n"
-             "Please download the dataset from:\n"
-             f"{AMAZONREVIEWS_URL}"
-         )
-         raise FileNotFoundError(msg)
-
-     # Load the datasets
-     with bz2.BZ2File(AMAZONREVIEWS_PATH[1]) as train_file:
-         train_data = [line.decode("utf-8") for line in train_file]
-
-     test_data = []
-     if merge:
-         with bz2.BZ2File(AMAZONREVIEWS_PATH[0]) as test_file:
-             test_data = [line.decode("utf-8") for line in test_file]
-
-     # Merge the datasets
-     data = train_data + test_data
-
-     # Split the data into labels and text
-     labels, texts = zip(*(line.split(" ", 1) for line in data))
-
-     # Map sentiment values
-     sentiments = [int(label.split("__label__")[1]) - 1 for label in labels]
-
-     # Return as lists
-     return texts, sentiments
-
-
- def load_imdb50k() -> tuple[list[str], list[int]]:
-     """Load the imdb50k dataset and make it suitable for use.
-
-     Returns:
-         Text and label data
-
-     Raises:
-         FileNotFoundError: If the dataset is not found
-     """
-     # Check if the dataset exists
-     if not IMDB50K_PATH.exists():
-         msg = (
-             f"IMDB50K dataset not found at: '{IMDB50K_PATH}'\n"
-             "Please download the dataset from:\n"
-             f"{IMDB50K_URL}"
-         )  # fmt: off
-         raise FileNotFoundError(msg)
-
-     # Load the dataset
-     data = pd.read_csv(IMDB50K_PATH)
-
-     # Map sentiment values
-     data["sentiment"] = data["sentiment"].map(
-         {
-             "positive": 1,
-             "negative": 0,
-         },
-     )
-
-     # Return as lists
-     return data["review"].tolist(), data["sentiment"].tolist()
-
-
- def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> tuple[list[str], list[int]]:
-     """Load and preprocess the specified dataset.
+     def transform(self, data: list[str]) -> list[list[str]]:
+         tokenized = []
+         for doc in tqdm(
+             nlp.pipe(data, batch_size=self.batch_size, n_process=self.n_jobs),
+             total=len(data),
+             disable=not self.progress,
+         ):
+             tokens = []
+             for token in doc:
+                 # Ignore stop words and punctuation
+                 if token.is_stop or token.is_punct:
+                     continue
+                 # Ignore emails, URLs and numbers
+                 if token.like_email or token.like_url or token.like_num:
+                     continue
+
+                 # Lemmatize and lowercase
+                 tok = token.lemma_.lower().strip()
+
+                 # Format hashtags
+                 if tok.startswith("#"):
+                     tok = tok[1:]
+
+                 # Ignore short and non-alphanumeric tokens
+                 if len(tok) < self.character_threshold or not tok.isalnum():
+                     continue
+
+                 # TODO: Emoticons and emojis
+                 # TODO: Spelling correction
+
+                 tokens.append(tok)
+             tokenized.append(tokens)
+         return tokenized
+
+
+ def identity(x: list[str]) -> list[str]:
+     """Identity function for use in TfidfVectorizer.

      Args:
-         dataset: Dataset to load
+         x: Input data

      Returns:
-         Text and label data
-
-     Raises:
-         ValueError: If the dataset is not recognized
+         Unchanged input data
      """
-     match dataset:
-         case "sentiment140":
-             return load_sentiment140(include_neutral=False)
-         case "amazonreviews":
-             return load_amazonreviews(merge=True)
-         case "imdb50k":
-             return load_imdb50k()
-         case _:
-             msg = f"Unknown dataset: {dataset}"
-             raise ValueError(msg)
+     return x


  def create_model(
@@ -262,26 +98,22 @@ def create_model(
      Returns:
          Untrained model
      """
-     # Download NLTK data if not already downloaded
-     nltk.download("wordnet", quiet=True)
-     nltk.download("stopwords", quiet=True)
-
-     # Load English stopwords
-     stopwords_en = set(stopwords.words("english"))
-
      return Pipeline(
          [
-             # Text preprocessing
-             ("clean", TextCleaner()),
-             ("lemma", TextLemmatizer()),
-             # Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity)
+             ("tokenizer", TextTokenizer(progress=True)),
              (
-                 "vectorize",
-                 CountVectorizer(stop_words=stopwords_en, ngram_range=(1, 2), max_features=max_features),
+                 "vectorizer",
+                 TfidfVectorizer(
+                     max_features=max_features,
+                     ngram_range=(1, 2),
+                     # disable text processing
+                     tokenizer=identity,
+                     preprocessor=identity,
+                     lowercase=False,
+                     token_pattern=None,
+                 ),
              ),
-             ("tfidf", TfidfTransformer()),
-             # Classifier
-             ("clf", LogisticRegression(max_iter=1000, random_state=seed)),
+             ("classifier", LogisticRegression(max_iter=1000, C=1.0, random_state=seed)),
          ],
          memory=Memory(CACHE_DIR, verbose=0),
          verbose=verbose,
@@ -289,11 +121,11 @@


  def train_model(
-     model: Pipeline,
+     model: BaseEstimator,
      text_data: list[str],
      label_data: list[int],
      seed: int = 42,
- ) -> tuple[float, list[str], list[int]]:
+ ) -> tuple[BaseEstimator, float]:
      """Train the sentiment analysis model.

      Args:
@@ -303,7 +135,7 @@ def train_model(
          seed: Random seed (None for random seed)

      Returns:
-         Model accuracy and test data
+         Trained model and accuracy
      """
      text_train, text_test, label_train, label_test = train_test_split(
          text_data,
@@ -312,37 +144,52 @@
          random_state=seed,
      )

+     param_distributions = {
+         "classifier__C": np.logspace(-4, 4, 20),
+         "classifier__penalty": ["l1", "l2"],
+     }
+
+     search = RandomizedSearchCV(
+         model,
+         param_distributions,
+         n_iter=10,
+         cv=5,
+         scoring="accuracy",
+         random_state=seed,
+         n_jobs=-1,
+     )
+
      with warnings.catch_warnings():
          warnings.simplefilter("ignore")
-         model.fit(text_train, label_train)
+         # model.fit(text_train, label_train)
+         search.fit(text_train, label_train)

-     return model.score(text_test, label_test), text_test, label_test
+     best_model = search.best_estimator_
+     return best_model, best_model.score(text_test, label_test)


  def evaluate_model(
      model: Pipeline,
-     text_test: list[str],
-     label_test: list[int],
-     cv: int = 5,
+     text_data: list[str],
+     label_data: list[int],
+     folds: int = 5,
  ) -> tuple[float, float]:
      """Evaluate the model using cross-validation.

      Args:
          model: Trained model
-         text_test: Text data
-         label_test: Label data
-         seed: Random seed (None for random seed)
-         cv: Number of cross-validation folds
+         text_data: Text data
+         label_data: Label data
+         folds: Number of cross-validation folds

      Returns:
          Mean accuracy and standard deviation
      """
      scores = cross_val_score(
          model,
-         text_test,
-         label_test,
-         cv=cv,
+         text_data,
+         label_data,
+         cv=folds,
          scoring="accuracy",
-         n_jobs=-1,
      )
      return scores.mean(), scores.std()
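With this change, tokenization and vectorization live inside the pipeline, so callers only pass raw strings. A minimal end-to-end sketch, assuming the spaCy model en_core_web_sm is installed (python -m spacy download en_core_web_sm) and that create_model keeps max_features and seed as keyword parameters (only its body is visible in this hunk):

from app.data import load_data
from app.model import create_model, train_model, evaluate_model

texts, labels = load_data("imdb50k")  # raw review strings and 0/1 labels
model = create_model(max_features=20000, seed=42)  # spaCy tokenizer + TF-IDF + LogisticRegression
best_model, test_accuracy = train_model(model, texts, labels)  # RandomizedSearchCV over C and penalty
cv_mean, cv_std = evaluate_model(best_model, texts, labels, folds=5)
print(f"test accuracy: {test_accuracy:.3f}, cv: {cv_mean:.3f} +/- {cv_std:.3f}")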