"""Functions for model training, evaluation, and inference.""" from __future__ import annotations import warnings from typing import TYPE_CHECKING, Literal, Sequence import numpy as np from joblib import Memory from sklearn.exceptions import ConvergenceWarning from sklearn.feature_extraction.text import CountVectorizer, HashingVectorizer, TfidfVectorizer from sklearn.linear_model import LogisticRegression from sklearn.model_selection import RandomizedSearchCV, cross_val_score, train_test_split from sklearn.pipeline import Pipeline from app.constants import CACHE_DIR from app.data import tokenize if TYPE_CHECKING: from sklearn.base import BaseEstimator, TransformerMixin __all__ = ["train_model", "evaluate_model", "infer_model"] def _identity(x: list[str]) -> list[str]: """Identity function for use in vectorizers. Args: x: Input data Returns: Unchanged input data """ return x def _get_vectorizer( name: Literal["tfidf", "count", "hashing"], n_features: int, min_df: int = 5, ) -> TransformerMixin: """Get the appropriate vectorizer. Args: name: Type of vectorizer n_features: Maximum number of features min_df: Minimum document frequency (ignored for hashing) Returns: Vectorizer instance Raises: ValueError: If the vectorizer is not recognized """ shared_params = { "ngram_range": (1, 2), # unigrams and bigrams # disable text processing "tokenizer": _identity, "preprocessor": _identity, "lowercase": False, "token_pattern": None, } match name: case "tfidf": return TfidfVectorizer( max_features=n_features, min_df=min_df, **shared_params, ) case "count": return CountVectorizer( max_features=n_features, min_df=min_df, **shared_params, ) case "hashing": if n_features < 2**15: warnings.warn( "HashingVectorizer may perform poorly with small n_features, default is 2^20.", stacklevel=2, ) return HashingVectorizer( n_features=n_features, **shared_params, ) case _: msg = f"Unknown vectorizer: {name}" raise ValueError(msg) def train_model( token_data: Sequence[Sequence[str]], label_data: list[int], vectorizer: Literal["tfidf", "count", "hashing"], max_features: int, min_df: int = 5, cv: int = 5, n_jobs: int = 4, seed: int = 42, ) -> tuple[BaseEstimator, float]: """Train the sentiment analysis model. Args: token_data: Tokenized text data label_data: Label data vectorizer: Which vectorizer to use max_features: Maximum number of features min_df: Minimum document frequency (ignored for hashing) cv: Number of cross-validation folds n_jobs: Number of parallel jobs seed: Random seed (None for random seed) Returns: Trained model and accuracy Raises: ValueError: If the vectorizer is not recognized """ rs = None if seed == -1 else seed # Split the data into training and testing sets text_train, text_test, label_train, label_test = train_test_split( token_data, label_data, test_size=0.2, random_state=rs, ) # Create the model pipeline vectorizer = _get_vectorizer(vectorizer, max_features, min_df) classifier = LogisticRegression(max_iter=1000, random_state=rs) model = Pipeline( [("vectorizer", vectorizer), ("classifier", classifier)], memory=Memory(CACHE_DIR, verbose=0), ) param_dist = {"classifier__C": np.logspace(-4, 4, 20)} # Perform randomized search for hyperparameter tuning search = RandomizedSearchCV( model, param_dist, cv=cv, random_state=rs, n_jobs=n_jobs, scoring="accuracy", n_iter=10, verbose=2, ) with warnings.catch_warnings(): warnings.filterwarnings("once", category=ConvergenceWarning) warnings.filterwarnings("ignore", category=UserWarning, message="Persisting input arguments took") search.fit(text_train, label_train) final_model = search.best_estimator_ return final_model, final_model.score(text_test, label_test) def evaluate_model( model: BaseEstimator, token_data: Sequence[Sequence[str]], label_data: list[int], cv: int = 5, n_jobs: int = 4, ) -> tuple[float, float]: """Evaluate the model using cross-validation. Args: model: Trained model token_data: Tokenized text data label_data: Label data cv: Number of cross-validation folds n_jobs: Number of parallel jobs Returns: Mean accuracy and standard deviation """ with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning, message="Persisting input arguments took") # Perform cross-validation to evaluate the model scores = cross_val_score( model, token_data, label_data, cv=cv, scoring="accuracy", n_jobs=n_jobs, verbose=2, ) return scores.mean(), scores.std() def infer_model( model: BaseEstimator, text_data: list[str], batch_size: int = 32, n_jobs: int = 4, ) -> list[int]: """Predict the sentiment of the provided text documents. Args: model: Trained model text_data: Text data batch_size: Batch size for tokenization n_jobs: Number of parallel jobs Returns: Predicted sentiments """ tokens = tokenize( text_data, batch_size=batch_size, n_jobs=n_jobs, show_progress=False, ) return model.predict(tokens)