Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Use stopwords from NLTK and download NLTK data
Browse files
- app/cli.py +6 -4
- app/model.py +16 -1
    	
        app/cli.py
    CHANGED
    
    | @@ -117,15 +117,17 @@ def train( | |
| 117 | 
             
                click.echo(DONE_STR)
         | 
| 118 |  | 
| 119 | 
             
                click.echo("Creating model... ", nl=False)
         | 
| 120 | 
            -
                model = create_model(max_features, seed=None if seed == -1 else seed)
         | 
| 121 | 
             
                click.echo(DONE_STR)
         | 
| 122 |  | 
| 123 | 
            -
                click.echo("Training model... ", nl=False)
         | 
|  | |
| 124 | 
             
                accuracy = train_model(model, text_data, label_data)
         | 
| 125 | 
             
                joblib.dump(model, model_path)
         | 
| 126 | 
            -
                click.echo( | 
|  | |
| 127 |  | 
| 128 | 
            -
                click.echo("Model accuracy: ")
         | 
| 129 | 
             
                click.secho(f"{accuracy:.2%}", fg="blue")
         | 
| 130 |  | 
| 131 | 
             
                # TODO: Add hyperparameter options
         | 
|  | |
| 117 | 
             
                click.echo(DONE_STR)
         | 
| 118 |  | 
| 119 | 
             
                click.echo("Creating model... ", nl=False)
         | 
| 120 | 
            +
                model = create_model(max_features, seed=None if seed == -1 else seed, verbose=True)
         | 
| 121 | 
             
                click.echo(DONE_STR)
         | 
| 122 |  | 
| 123 | 
            +
                # click.echo("Training model... ", nl=False)
         | 
| 124 | 
            +
                click.echo("Training model... ")
         | 
| 125 | 
             
                accuracy = train_model(model, text_data, label_data)
         | 
| 126 | 
             
                joblib.dump(model, model_path)
         | 
| 127 | 
            +
                click.echo("Model saved to: ", nl=False)
         | 
| 128 | 
            +
                click.secho(str(model_path), fg="blue")
         | 
| 129 |  | 
| 130 | 
            +
                click.echo("Model accuracy: ", nl=False)
         | 
| 131 | 
             
                click.secho(f"{accuracy:.2%}", fg="blue")
         | 
| 132 |  | 
| 133 | 
             
                # TODO: Add hyperparameter options
         | 
    	
        app/model.py
    CHANGED
    
    | @@ -5,8 +5,10 @@ import re | |
| 5 | 
             
            import warnings
         | 
| 6 | 
             
            from typing import Literal
         | 
| 7 |  | 
|  | |
| 8 | 
             
            import pandas as pd
         | 
| 9 | 
             
            from joblib import Memory
         | 
|  | |
| 10 | 
             
            from nltk.stem import WordNetLemmatizer
         | 
| 11 | 
             
            from sklearn.base import BaseEstimator, TransformerMixin
         | 
| 12 | 
             
            from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
         | 
| @@ -248,28 +250,41 @@ def load_data(dataset: Literal["sentiment140", "amazonreviews", "imdb50k"]) -> t | |
| 248 | 
             
            def create_model(
         | 
| 249 | 
             
                max_features: int,
         | 
| 250 | 
             
                seed: int | None = None,
         | 
|  | |
| 251 | 
             
            ) -> Pipeline:
         | 
| 252 | 
             
                """Create a sentiment analysis model.
         | 
| 253 |  | 
| 254 | 
             
                Args:
         | 
| 255 | 
             
                    max_features: Maximum number of features
         | 
| 256 | 
             
                    seed: Random seed (None for random seed)
         | 
|  | |
| 257 |  | 
| 258 | 
             
                Returns:
         | 
| 259 | 
             
                    Untrained model
         | 
| 260 | 
             
                """
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 261 | 
             
                return Pipeline(
         | 
| 262 | 
             
                    [
         | 
| 263 | 
             
                        # Text preprocessing
         | 
| 264 | 
             
                        ("clean", TextCleaner()),
         | 
| 265 | 
             
                        ("lemma", TextLemmatizer()),
         | 
| 266 | 
             
                        # Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity)
         | 
| 267 | 
            -
                        ( | 
|  | |
|  | |
|  | |
| 268 | 
             
                        ("tfidf", TfidfTransformer()),
         | 
| 269 | 
             
                        # Classifier
         | 
| 270 | 
             
                        ("clf", LogisticRegression(max_iter=1000, random_state=seed)),
         | 
| 271 | 
             
                    ],
         | 
| 272 | 
             
                    memory=Memory(CACHE_DIR, verbose=0),
         | 
|  | |
| 273 | 
             
                )
         | 
| 274 |  | 
| 275 |  | 
|  | |
| 5 | 
             
            import warnings
         | 
| 6 | 
             
            from typing import Literal
         | 
| 7 |  | 
| 8 | 
            +
            import nltk
         | 
| 9 | 
             
            import pandas as pd
         | 
| 10 | 
             
            from joblib import Memory
         | 
| 11 | 
            +
            from nltk.corpus import stopwords
         | 
| 12 | 
             
            from nltk.stem import WordNetLemmatizer
         | 
| 13 | 
             
            from sklearn.base import BaseEstimator, TransformerMixin
         | 
| 14 | 
             
            from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
         | 
|  | |
| 250 | 
             
            def create_model(
         | 
| 251 | 
             
                max_features: int,
         | 
| 252 | 
             
                seed: int | None = None,
         | 
| 253 | 
            +
                verbose: bool = False,
         | 
| 254 | 
             
            ) -> Pipeline:
         | 
| 255 | 
             
                """Create a sentiment analysis model.
         | 
| 256 |  | 
| 257 | 
             
                Args:
         | 
| 258 | 
             
                    max_features: Maximum number of features
         | 
| 259 | 
             
                    seed: Random seed (None for random seed)
         | 
| 260 | 
            +
                    verbose: Whether to log progress during training
         | 
| 261 |  | 
| 262 | 
             
                Returns:
         | 
| 263 | 
             
                    Untrained model
         | 
| 264 | 
             
                """
         | 
| 265 | 
            +
                # Download NLTK data if not already downloaded
         | 
| 266 | 
            +
                nltk.download("wordnet", quiet=True)
         | 
| 267 | 
            +
                nltk.download("stopwords", quiet=True)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                # Load English stopwords
         | 
| 270 | 
            +
                stopwords_en = set(stopwords.words("english"))
         | 
| 271 | 
            +
             | 
| 272 | 
             
                return Pipeline(
         | 
| 273 | 
             
                    [
         | 
| 274 | 
             
                        # Text preprocessing
         | 
| 275 | 
             
                        ("clean", TextCleaner()),
         | 
| 276 | 
             
                        ("lemma", TextLemmatizer()),
         | 
| 277 | 
             
                        # Preprocess (NOTE: Can be replaced with TfidfVectorizer, but left for clarity)
         | 
| 278 | 
            +
                        (
         | 
| 279 | 
            +
                            "vectorize",
         | 
| 280 | 
            +
                            CountVectorizer(stop_words=stopwords_en, ngram_range=(1, 2), max_features=max_features),
         | 
| 281 | 
            +
                        ),
         | 
| 282 | 
             
                        ("tfidf", TfidfTransformer()),
         | 
| 283 | 
             
                        # Classifier
         | 
| 284 | 
             
                        ("clf", LogisticRegression(max_iter=1000, random_state=seed)),
         | 
| 285 | 
             
                    ],
         | 
| 286 | 
             
                    memory=Memory(CACHE_DIR, verbose=0),
         | 
| 287 | 
            +
                    verbose=verbose,
         | 
| 288 | 
             
                )
         | 
| 289 |  | 
| 290 |  |