Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import warnings | |
| import numpy as np | |
| import spacy | |
| from joblib import Memory | |
| from sklearn.base import BaseEstimator, TransformerMixin | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split | |
| from sklearn.pipeline import Pipeline | |
| from tqdm import tqdm | |
| from app.constants import CACHE_DIR | |
__all__ = ["create_model", "train_model", "evaluate_model"]

# spaCy pipeline components this module does not need (sentence structure and
# named entities are never read by TextTokenizer).
# NOTE: "tok2vec" must stay enabled. In the en_core_web_sm v3 pipelines the
# tagger listens to the shared tok2vec component, and the rule-based
# lemmatizer relies on the tagger's POS tags. Disabling tok2vec leaves the
# tagger without features and degrades the `lemma_` values used downstream.
_SPACY_DISABLE = ["parser", "ner"]

try:
    nlp = spacy.load("en_core_web_sm", disable=_SPACY_DISABLE)
except OSError:
    # The model package is not installed: download it once, then load it.
    print("Downloading spaCy model...")
    from spacy.cli import download as spacy_download

    spacy_download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm", disable=_SPACY_DISABLE)
class TextTokenizer(BaseEstimator, TransformerMixin):
    """Turn raw texts into lists of normalized tokens using the module's spaCy ``nlp``.

    For every token of every document:

    - stop words and punctuation are dropped;
    - e-mail addresses, URLs and number-like tokens are dropped;
    - the lemma is lowercased and stripped, and a leading ``#`` is removed;
    - the result is kept only if it is alphanumeric and at least
      ``character_threshold`` characters long.

    The transformer is stateless: ``fit`` learns nothing.

    Args:
        character_threshold: Minimum length of a kept token
        batch_size: Number of texts per ``nlp.pipe`` batch
        n_jobs: Passed to ``nlp.pipe`` as ``n_process``
        progress: Whether to show a tqdm progress bar in ``transform``
    """

    def __init__(
        self,
        *,
        character_threshold: int = 2,
        batch_size: int = 1024,
        n_jobs: int = 8,
        progress: bool = True,
    ) -> None:
        # Store constructor arguments unchanged (scikit-learn estimator convention,
        # relied on by get_params/clone).
        self.character_threshold = character_threshold
        self.batch_size = batch_size
        self.n_jobs = n_jobs
        self.progress = progress

    def fit(self, _data: list[str], _labels: list[int] | None = None) -> TextTokenizer:
        """Do nothing and return ``self`` so the tokenizer can be used in a Pipeline."""
        return self

    def _normalize(self, token) -> str | None:
        """Return the normalized text of a spaCy token, or None if it is filtered out."""
        # Ignore stop words and punctuation
        if token.is_stop or token.is_punct:
            return None
        # Ignore emails, URLs and numbers
        if token.like_email or token.like_url or token.like_num:
            return None
        # Lemmatize and lowercase
        tok = token.lemma_.lower().strip()
        # Format hashtags: "#topic" -> "topic"
        if tok.startswith("#"):
            tok = tok[1:]
        # Ignore short and non-alphanumeric tokens
        if len(tok) < self.character_threshold or not tok.isalnum():
            return None
        # TODO: Emoticons and emojis
        # TODO: Spelling correction
        return tok

    def transform(self, data: list[str]) -> list[list[str]]:
        """Tokenize each text in ``data``.

        Args:
            data: Raw texts

        Returns:
            One list of normalized tokens per input text, in input order
        """
        docs = nlp.pipe(data, batch_size=self.batch_size, n_process=self.n_jobs)
        tokenized = []
        for doc in tqdm(docs, total=len(data), disable=not self.progress):
            normalized = (self._normalize(token) for token in doc)
            tokenized.append([tok for tok in normalized if tok is not None])
        return tokenized
def identity(x: list[str]) -> list[str]:
    """Return *x* itself, untouched.

    Plugged into ``TfidfVectorizer`` as both ``tokenizer`` and ``preprocessor``
    so that already-tokenized input bypasses the vectorizer's own text
    processing. It is a named module-level function rather than a lambda,
    presumably so the vectorizer stays picklable (e.g. for joblib caching).

    Args:
        x: Input data

    Returns:
        The very same object that was passed in
    """
    passthrough = x
    return passthrough
def create_model(
    max_features: int,
    seed: int | None = None,
    verbose: bool = False,
) -> Pipeline:
    """Build an untrained three-stage sentiment-analysis pipeline.

    Stages, in order:

    1. ``tokenizer``: ``TextTokenizer`` turning raw text into token lists.
    2. ``vectorizer``: TF-IDF over unigrams and bigrams, limited to
       ``max_features`` terms. Its own preprocessing and tokenization are
       bypassed because its input is already tokenized.
    3. ``classifier``: ``LogisticRegression`` with ``max_iter=1000``.

    The pipeline is given a joblib ``Memory`` rooted at ``CACHE_DIR`` so fitted
    transformers can be reused between fits.

    Args:
        max_features: Maximum number of TF-IDF features
        seed: Random seed for the classifier (None for random seed)
        verbose: Passed to ``Pipeline(verbose=...)`` to report step fitting

    Returns:
        Untrained model
    """
    # NOTE(review): the tokenizer progress bar is always on, independent of `verbose`.
    tokenizer = TextTokenizer(progress=True)
    vectorizer = TfidfVectorizer(
        max_features=max_features,
        ngram_range=(1, 2),
        # The tokenizer stage already produced token lists: pass them through.
        tokenizer=identity,
        preprocessor=identity,
        lowercase=False,
        token_pattern=None,
    )
    classifier = LogisticRegression(max_iter=1000, random_state=seed)

    steps = [
        ("tokenizer", tokenizer),
        ("vectorizer", vectorizer),
        ("classifier", classifier),
    ]
    return Pipeline(steps, memory=Memory(CACHE_DIR, verbose=0), verbose=verbose)
def train_model(
    model: BaseEstimator,
    text_data: list[str],
    label_data: list[int],
    seed: int = 42,
) -> tuple[BaseEstimator, float]:
    """Tune and train the sentiment analysis model.

    The data is split 80/20 into a training and a held-out test set. A
    randomized search (10 candidates, 5-fold CV, accuracy) over the
    classifier's regularization strength ``C`` and penalty type runs on the
    training split. The best configuration, refit on the whole training split,
    is returned together with its accuracy on the held-out split.

    Args:
        model: Untrained model (a pipeline with a ``classifier`` step)
        text_data: Text data
        label_data: Label data
        seed: Random seed for the split and the search (None for random seed)

    Returns:
        Best trained model and its accuracy on the held-out test split
    """
    text_train, text_test, label_train, label_test = train_test_split(
        text_data,
        label_data,
        test_size=0.2,
        random_state=seed,
    )

    c_values = np.logspace(-4, 4, 20)
    # LogisticRegression's default "lbfgs" solver does not support the l1
    # penalty, so l1 candidates would fail to fit (and the failure would be
    # hidden by the warnings filter below). Pair l1 with "saga", which
    # supports it. l2 candidates keep the estimator's configured solver.
    param_distributions = [
        {
            "classifier__C": c_values,
            "classifier__penalty": ["l2"],
        },
        {
            "classifier__C": c_values,
            "classifier__penalty": ["l1"],
            "classifier__solver": ["saga"],
        },
    ]
    search = RandomizedSearchCV(
        model,
        param_distributions,
        n_iter=10,
        cv=5,
        scoring="accuracy",
        random_state=seed,
        n_jobs=-1,
    )

    # Silence convergence and other sklearn warnings raised during the search.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        search.fit(text_train, label_train)

    best_model = search.best_estimator_
    return best_model, best_model.score(text_test, label_test)
def evaluate_model(
    model: Pipeline,
    text_data: list[str],
    label_data: list[int],
    folds: int = 5,
) -> tuple[float, float]:
    """Estimate the model's accuracy with k-fold cross-validation.

    ``cross_val_score`` fits a fresh clone of ``model`` on each training fold.
    What is evaluated is therefore the model's configuration (its
    hyperparameters), not any state it was previously fitted with.

    Args:
        model: Model (pipeline) whose configuration is evaluated
        text_data: Text data
        label_data: Label data
        folds: Number of cross-validation folds

    Returns:
        Mean and standard deviation of the per-fold accuracy
    """
    fold_accuracies = cross_val_score(
        model,
        text_data,
        label_data,
        scoring="accuracy",
        cv=folds,
    )
    mean_accuracy = fold_accuracies.mean()
    spread = fold_accuracies.std()
    return mean_accuracy, spread