	Cache label data along with tokenized text data
Files changed:
- app/cli.py (+26 -20)
- app/utils.py (+2 -3)
app/cli.py  CHANGED

@@ -146,32 +146,35 @@ def evaluate(
     from app.model import evaluate_model
     from app.utils import deserialize, serialize
 
-
+    token_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
+    label_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_labels.pkl"
     use_cached_data = False
 
-    if
+    if token_cache_path.exists():
         use_cached_data = force_cache or click.confirm(
             f"Found existing tokenized data for '{dataset}'. Use it?",
             default=True,
         )
 
-    click.echo("Loading dataset... ", nl=False)
-    text_data, label_data = load_data(dataset)
-    click.echo(DONE_STR)
-
     if use_cached_data:
         click.echo("Loading cached data... ", nl=False)
-        token_data = pd.Series(deserialize(
+        token_data = pd.Series(deserialize(token_cache_path))
+        label_data = joblib.load(label_cache_path)
         click.echo(DONE_STR)
     else:
+        click.echo("Loading dataset... ", nl=False)
+        text_data, label_data = load_data(dataset)
+        click.echo(DONE_STR)
+
         click.echo("Tokenizing data... ")
         token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)
 
         click.echo("Caching tokenized data... ")
-        serialize(token_data,
+        serialize(token_data, token_cache_path, show_progress=True)
+        joblib.dump(label_data, label_cache_path, compress=3)
 
-
-
+        del text_data
+        gc.collect()
 
     click.echo("Size of vocabulary: ", nl=False)
     vocab = token_data.explode().value_counts()

@@ -281,32 +284,35 @@ def train(
     if model_path.exists() and not overwrite:
         click.confirm(f"Model file '{model_path}' already exists. Overwrite?", abort=True)
 
-
+    token_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_tokenized.pkl"
+    label_cache_path = TOKENIZER_CACHE_PATH / f"{dataset}_labels.pkl"
     use_cached_data = False
 
-    if
+    if token_cache_path.exists():
         use_cached_data = force_cache or click.confirm(
             f"Found existing tokenized data for '{dataset}'. Use it?",
             default=True,
         )
 
-    click.echo("Loading dataset... ", nl=False)
-    text_data, label_data = load_data(dataset)
-    click.echo(DONE_STR)
-
     if use_cached_data:
         click.echo("Loading cached data... ", nl=False)
-        token_data = pd.Series(deserialize(
+        token_data = pd.Series(deserialize(token_cache_path))
+        label_data = joblib.load(label_cache_path)
         click.echo(DONE_STR)
     else:
+        click.echo("Loading dataset... ", nl=False)
+        text_data, label_data = load_data(dataset)
+        click.echo(DONE_STR)
+
         click.echo("Tokenizing data... ")
         token_data = tokenize(text_data, batch_size=token_batch_size, n_jobs=token_jobs, show_progress=True)
 
         click.echo("Caching tokenized data... ")
-        serialize(token_data,
+        serialize(token_data, token_cache_path, show_progress=True)
+        joblib.dump(label_data, label_cache_path, compress=3)
 
-
-
+        del text_data
+        gc.collect()
 
     click.echo("Size of vocabulary: ", nl=False)
     vocab = token_data.explode().value_counts()
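The label cache added above is a plain joblib round-trip rather than the chunked serialize/deserialize helpers used for the token data. Below is a minimal sketch of that pattern; the path and the toy pandas Series are hypothetical stand-ins (the CLI builds the path from TOKENIZER_CACHE_PATH and the dataset name, and the real label type depends on load_data).

from pathlib import Path

import joblib
import pandas as pd

# Hypothetical stand-ins for the values the CLI derives from TOKENIZER_CACHE_PATH and dataset.
label_cache_path = Path("cache/example_labels.pkl")
label_data = pd.Series([1, 0, 1], name="label")

# Cache miss: persist the labels next to the tokenized-text cache.
label_cache_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(label_data, label_cache_path, compress=3)

# Cache hit: reload the labels without re-reading the raw dataset.
restored = joblib.load(label_cache_path)
assert restored.equals(label_data)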
    	
app/utils.py  CHANGED

@@ -11,7 +11,7 @@ if TYPE_CHECKING:
 __all__ = ["serialize", "deserialize"]
 
 
-def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_progress: bool = False) -> None:
+def serialize(data: Sequence[str | int], path: Path, max_size: int = 100000, show_progress: bool = False) -> None:
     """Serialize data to a file
 
     Args:
@@ -20,7 +20,6 @@ def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_prog
         max_size: The maximum size a chunk can be (in elements)
         show_progress: Whether to show a progress bar
     """
-    # first file is path, next chunks have ".1", ".2", etc. appended
     for i, chunk in enumerate(
         tqdm(
             [data[i : i + max_size] for i in range(0, len(data), max_size)],
@@ -33,7 +32,7 @@ def serialize(data: Sequence[str], path: Path, max_size: int = 100000, show_prog
             joblib.dump(chunk, f, compress=3)
 
 
-def deserialize(path: Path) -> Sequence[str]:
+def deserialize(path: Path) -> Sequence[str | int]:
     """Deserialize data from a file
 
     Args:
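For context, the comment removed above described the chunk-file convention: the first chunk is written to path, and later chunks get ".1", ".2", etc. appended. The sketch below is a rough, self-contained illustration of such a chunked joblib round-trip under that assumed convention; chunked_dump and chunked_load are hypothetical names, not the repository's actual serialize/deserialize (whose deserialize body and progress handling are not shown here).

from pathlib import Path
from typing import Sequence

import joblib


def chunked_dump(data: Sequence, path: Path, max_size: int = 100_000) -> None:
    # Split the sequence into chunks of at most max_size elements; the first
    # chunk is written to `path`, later chunks to `<path>.1`, `<path>.2`, ...
    chunks = [data[i : i + max_size] for i in range(0, len(data), max_size)]
    for i, chunk in enumerate(chunks):
        target = path if i == 0 else Path(f"{path}.{i}")
        joblib.dump(chunk, target, compress=3)


def chunked_load(path: Path) -> list:
    # Read the chunk files back in order and concatenate them.
    data: list = []
    i = 0
    target = path
    while target.exists():
        data.extend(joblib.load(target))
        i += 1
        target = Path(f"{path}.{i}")
    return data

With the labels stored as a single joblib file by this commit, only the token data goes through the chunked path.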