Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		Jean Garcia-Gathright
		
	committed on
		
		
					Commit 
							
							·
						
						a02c788
	
1
								Parent(s):
							
							4150cb0
								
added ernie files
Browse files- app.py +2 -2
- app.py~ +7 -0
- ernie/__init__.py +47 -0
- ernie/aggregation_strategies.py +70 -0
- ernie/ernie.py +397 -0
- ernie/helper.py +121 -0
- ernie/models.py +51 -0
- ernie/split_strategies.py +125 -0
    	
        app.py
    CHANGED
    
    | @@ -1,6 +1,6 @@ | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
            -
            import  | 
| 3 | 
            -
            import  | 
| 4 |  | 
| 5 | 
             
            def greet(name):
         | 
| 6 | 
             
                return "Hello " + name + "!!"
         | 
|  | |
| 1 | 
             
            import gradio as gr
         | 
| 2 | 
            +
            from ernie.ernie import SentenceClassifier
         | 
| 3 | 
            +
            from ernie import helper
         | 
| 4 |  | 
def greet(name):
    """Return the app's fixed greeting for *name*."""
    greeting = "Hello " + name + "!!"
    return greeting
    	
        app.py~
    ADDED
    
    | @@ -0,0 +1,7 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
import gradio as gr


def greet(name):
    """Return the demo greeting for *name*."""
    greeting = "Hello " + name + "!!"
    return greeting


# Wire the greeting function to a simple text-in / text-out Gradio UI.
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()
    	
        ernie/__init__.py
    ADDED
    
    | @@ -0,0 +1,47 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from .ernie import *  # noqa: F401, F403
         | 
| 5 | 
            +
            from tensorflow.python.client import device_lib
         | 
| 6 | 
            +
            import logging
         | 
| 7 | 
            +
             | 
__version__ = '1.0.1'

# Quieten everything below WARNING globally, and the chatty transformers
# tokenizer logger below ERROR, before installing our own log format.
logging.getLogger().setLevel(logging.WARNING)
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
logging.basicConfig(
    format='%(asctime)-15s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
| 16 | 
            +
             | 
| 17 | 
            +
             | 
def _get_cpu_name():
    """Describe the host CPU as "<brand>, <N> vCores" via py-cpuinfo."""
    # Imported lazily so the dependency is only required when queried.
    import cpuinfo
    info = cpuinfo.get_cpu_info()
    return f"{info['brand_raw']}, {info['count']} vCores"
| 23 | 
            +
             | 
| 24 | 
            +
             | 
def _get_gpu_name():
    """Return the name of the first visible GPU device.

    Scans TensorFlow's local devices for one whose device_type is 'GPU'
    instead of assuming the GPU sits at a fixed position (the original
    hard-coded index ``[3]``, which breaks on machines with a different
    device layout).

    Raises:
        IndexError: if no GPU device is present — the module-level probe
            below relies on catching IndexError to fall back to CPU.
    """
    gpus = [
        device for device in device_lib.list_local_devices()
        if device.device_type == 'GPU'
    ]
    # gpus[0] raises IndexError when empty, preserving the original contract.
    descriptor = gpus[0].physical_device_desc
    # physical_device_desc looks like "device: 0, name: <gpu>, pci bus id: ..."
    return descriptor.split(',')[1].split('name:')[1].strip()
| 34 | 
            +
             | 
| 35 | 
            +
             | 
# Default to the CPU; upgraded just below when a GPU is visible.
device_name = _get_cpu_name()
device_type = 'CPU'

try:
    device_name = _get_gpu_name()
    device_type = 'GPU'
except IndexError:
    # Detect TPU
    pass

logging.info(f'ernie v{__version__}')
logging.info(f'target device: [{device_type}] {device_name}\n')
    	
        ernie/aggregation_strategies.py
    ADDED
    
    | @@ -0,0 +1,70 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from statistics import mean
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
class AggregationStrategy:
    """Combines several per-sentence softmax tuples into a single tuple.

    When a long text is split into sentences, each sentence yields its own
    class-probability tuple; an AggregationStrategy reduces that list back
    to one tuple for the whole text.

    Attributes:
        method: callable reducing a list of floats to one float (e.g. mean).
        max_items: if set, aggregate over at most this many tuples.
        top_items: when trimming to max_items, keep the highest-ranked
            tuples if True, the lowest otherwise.
        sorting_class_index: class index whose probability ranks tuples.
    """

    def __init__(
        self,
        method,
        max_items=None,
        top_items=True,
        sorting_class_index=1
    ):
        self.method = method
        self.max_items = max_items
        self.top_items = top_items
        self.sorting_class_index = sorting_class_index

    def aggregate(self, softmax_tuples):
        """Reduce an iterable of equal-length probability tuples to one tuple.

        Returns:
            A tuple with one aggregated value per class. Empty input yields
            an empty tuple (the original crashed with IndexError here).
        """
        softmax_dicts = [
            dict(enumerate(softmax_tuple)) for softmax_tuple in softmax_tuples
        ]
        # Guard: nothing to aggregate.
        if not softmax_dicts:
            return ()

        if self.max_items is not None:
            softmax_dicts = sorted(
                softmax_dicts,
                key=lambda x: x[self.sorting_class_index],
                reverse=self.top_items
            )
            if self.max_items < len(softmax_dicts):
                softmax_dicts = softmax_dicts[:self.max_items]

        return tuple(
            self.method([probabilities[key] for probabilities in softmax_dicts])
            for key in softmax_dicts[0]
        )
| 43 | 
            +
             | 
| 44 | 
            +
             | 
class AggregationStrategies:
    """Ready-made aggregation presets for binary classification."""

    # Plain average over every sentence prediction.
    Mean = AggregationStrategy(method=mean)

    # Average over only the K sentences scoring highest on class 1.
    MeanTopFiveBinaryClassification = AggregationStrategy(
        method=mean, max_items=5, top_items=True, sorting_class_index=1)
    MeanTopTenBinaryClassification = AggregationStrategy(
        method=mean, max_items=10, top_items=True, sorting_class_index=1)
    MeanTopFifteenBinaryClassification = AggregationStrategy(
        method=mean, max_items=15, top_items=True, sorting_class_index=1)
    MeanTopTwentyBinaryClassification = AggregationStrategy(
        method=mean, max_items=20, top_items=True, sorting_class_index=1)
    	
        ernie/ernie.py
    ADDED
    
    | @@ -0,0 +1,397 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import pandas as pd
         | 
| 6 | 
            +
            from transformers import (
         | 
| 7 | 
            +
                AutoTokenizer,
         | 
| 8 | 
            +
                AutoModel,
         | 
| 9 | 
            +
                AutoConfig,
         | 
| 10 | 
            +
                TFAutoModelForSequenceClassification,
         | 
| 11 | 
            +
            )
         | 
| 12 | 
            +
            from tensorflow import keras
         | 
| 13 | 
            +
            from sklearn.model_selection import train_test_split
         | 
| 14 | 
            +
            import logging
         | 
| 15 | 
            +
            import time
         | 
| 16 | 
            +
            from .models import Models, ModelsByFamily  # noqa: F401
         | 
| 17 | 
            +
            from .split_strategies import (  # noqa: F401
         | 
| 18 | 
            +
                SplitStrategy,
         | 
| 19 | 
            +
                SplitStrategies,
         | 
| 20 | 
            +
                RegexExpressions
         | 
| 21 | 
            +
            )
         | 
| 22 | 
            +
            from .aggregation_strategies import (  # noqa: F401
         | 
| 23 | 
            +
                AggregationStrategy,
         | 
| 24 | 
            +
                AggregationStrategies
         | 
| 25 | 
            +
            )
         | 
| 26 | 
            +
            from .helper import (
         | 
| 27 | 
            +
                get_features,
         | 
| 28 | 
            +
                softmax,
         | 
| 29 | 
            +
                remove_dir,
         | 
| 30 | 
            +
                make_dir,
         | 
| 31 | 
            +
                copy_dir
         | 
| 32 | 
            +
            )
         | 
| 33 | 
            +
             | 
AUTOSAVE_PATH = './ernie-autosave/'


def clean_autosave():
    """Delete the directory where ernie autosaves fine-tuned models."""
    remove_dir(AUTOSAVE_PATH)
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            class SentenceClassifier:
         | 
| 42 | 
            +
                def __init__(self,
         | 
| 43 | 
            +
                             model_name=Models.BertBaseUncased,
         | 
| 44 | 
            +
                             model_path=None,
         | 
| 45 | 
            +
                             max_length=64,
         | 
| 46 | 
            +
                             labels_no=2,
         | 
| 47 | 
            +
                             tokenizer_kwargs=None,
         | 
| 48 | 
            +
                             model_kwargs=None):
         | 
| 49 | 
            +
                    self._loaded_data = False
         | 
| 50 | 
            +
                    self._model_path = None
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    if model_kwargs is None:
         | 
| 53 | 
            +
                        model_kwargs = {}
         | 
| 54 | 
            +
                    model_kwargs['num_labels'] = labels_no
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    if tokenizer_kwargs is None:
         | 
| 57 | 
            +
                        tokenizer_kwargs = {}
         | 
| 58 | 
            +
                    tokenizer_kwargs['max_len'] = max_length
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    if model_path is not None:
         | 
| 61 | 
            +
                        self._load_local_model(model_path)
         | 
| 62 | 
            +
                    else:
         | 
| 63 | 
            +
                        self._load_remote_model(model_name, tokenizer_kwargs, model_kwargs)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                @property
         | 
| 66 | 
            +
                def model(self):
         | 
| 67 | 
            +
                    return self._model
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                @property
         | 
| 70 | 
            +
                def tokenizer(self):
         | 
| 71 | 
            +
                    return self._tokenizer
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def load_dataset(self,
         | 
| 74 | 
            +
                                 dataframe=None,
         | 
| 75 | 
            +
                                 validation_split=0.1,
         | 
| 76 | 
            +
                                 random_state=None,
         | 
| 77 | 
            +
                                 stratify=None,
         | 
| 78 | 
            +
                                 csv_path=None,
         | 
| 79 | 
            +
                                 read_csv_kwargs=None):
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    if dataframe is None and csv_path is None:
         | 
| 82 | 
            +
                        raise ValueError
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    if csv_path is not None:
         | 
| 85 | 
            +
                        dataframe = pd.read_csv(csv_path, **read_csv_kwargs)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    sentences = list(dataframe[dataframe.columns[0]])
         | 
| 88 | 
            +
                    labels = dataframe[dataframe.columns[1]].values
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    (
         | 
| 91 | 
            +
                        training_sentences,
         | 
| 92 | 
            +
                        validation_sentences,
         | 
| 93 | 
            +
                        training_labels,
         | 
| 94 | 
            +
                        validation_labels
         | 
| 95 | 
            +
                    ) = train_test_split(
         | 
| 96 | 
            +
                        sentences,
         | 
| 97 | 
            +
                        labels,
         | 
| 98 | 
            +
                        test_size=validation_split,
         | 
| 99 | 
            +
                        shuffle=True,
         | 
| 100 | 
            +
                        random_state=random_state,
         | 
| 101 | 
            +
                        stratify=stratify
         | 
| 102 | 
            +
                    )
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    self._training_features = get_features(
         | 
| 105 | 
            +
                        self._tokenizer, training_sentences, training_labels)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    self._training_size = len(training_sentences)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    self._validation_features = get_features(
         | 
| 110 | 
            +
                        self._tokenizer,
         | 
| 111 | 
            +
                        validation_sentences,
         | 
| 112 | 
            +
                        validation_labels
         | 
| 113 | 
            +
                    )
         | 
| 114 | 
            +
                    self._validation_split = len(validation_sentences)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    logging.info(f'training_size: {self._training_size}')
         | 
| 117 | 
            +
                    logging.info(f'validation_split: {self._validation_split}')
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    self._loaded_data = True
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                def fine_tune(self,
         | 
| 122 | 
            +
                              epochs=4,
         | 
| 123 | 
            +
                              learning_rate=2e-5,
         | 
| 124 | 
            +
                              epsilon=1e-8,
         | 
| 125 | 
            +
                              clipnorm=1.0,
         | 
| 126 | 
            +
                              optimizer_function=keras.optimizers.Adam,
         | 
| 127 | 
            +
                              optimizer_kwargs=None,
         | 
| 128 | 
            +
                              loss_function=keras.losses.SparseCategoricalCrossentropy,
         | 
| 129 | 
            +
                              loss_kwargs=None,
         | 
| 130 | 
            +
                              accuracy_function=keras.metrics.SparseCategoricalAccuracy,
         | 
| 131 | 
            +
                              accuracy_kwargs=None,
         | 
| 132 | 
            +
                              training_batch_size=32,
         | 
| 133 | 
            +
                              validation_batch_size=64,
         | 
| 134 | 
            +
                              **kwargs):
         | 
| 135 | 
            +
                    if not self._loaded_data:
         | 
| 136 | 
            +
                        raise Exception('Data has not been loaded.')
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    if optimizer_kwargs is None:
         | 
| 139 | 
            +
                        optimizer_kwargs = {
         | 
| 140 | 
            +
                            'learning_rate': learning_rate,
         | 
| 141 | 
            +
                            'epsilon': epsilon,
         | 
| 142 | 
            +
                            'clipnorm': clipnorm
         | 
| 143 | 
            +
                        }
         | 
| 144 | 
            +
                    optimizer = optimizer_function(**optimizer_kwargs)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    if loss_kwargs is None:
         | 
| 147 | 
            +
                        loss_kwargs = {'from_logits': True}
         | 
| 148 | 
            +
                    loss = loss_function(**loss_kwargs)
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    if accuracy_kwargs is None:
         | 
| 151 | 
            +
                        accuracy_kwargs = {'name': 'accuracy'}
         | 
| 152 | 
            +
                    accuracy = accuracy_function(**accuracy_kwargs)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    self._model.compile(optimizer=optimizer, loss=loss, metrics=[accuracy])
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    training_features = self._training_features.shuffle(
         | 
| 157 | 
            +
                        self._training_size).batch(training_batch_size).repeat(-1)
         | 
| 158 | 
            +
                    validation_features = self._validation_features.batch(
         | 
| 159 | 
            +
                        validation_batch_size)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    training_steps = self._training_size // training_batch_size
         | 
| 162 | 
            +
                    if training_steps == 0:
         | 
| 163 | 
            +
                        training_steps = self._training_size
         | 
| 164 | 
            +
                    logging.info(f'training_steps: {training_steps}')
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    validation_steps = self._validation_split // validation_batch_size
         | 
| 167 | 
            +
                    if validation_steps == 0:
         | 
| 168 | 
            +
                        validation_steps = self._validation_split
         | 
| 169 | 
            +
                    logging.info(f'validation_steps: {validation_steps}')
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    for i in range(epochs):
         | 
| 172 | 
            +
                        self._model.fit(training_features,
         | 
| 173 | 
            +
                                        epochs=1,
         | 
| 174 | 
            +
                                        validation_data=validation_features,
         | 
| 175 | 
            +
                                        steps_per_epoch=training_steps,
         | 
| 176 | 
            +
                                        validation_steps=validation_steps,
         | 
| 177 | 
            +
                                        **kwargs)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    # The fine-tuned model does not have the same input interface
         | 
| 180 | 
            +
                    # after being exported and loaded again.
         | 
| 181 | 
            +
                    self._reload_model()
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                def predict_one(
         | 
| 184 | 
            +
                    self,
         | 
| 185 | 
            +
                    text,
         | 
| 186 | 
            +
                    split_strategy=None,
         | 
| 187 | 
            +
                    aggregation_strategy=None
         | 
| 188 | 
            +
                ):
         | 
| 189 | 
            +
                    return next(
         | 
| 190 | 
            +
                        self.predict([text],
         | 
| 191 | 
            +
                                     batch_size=1,
         | 
| 192 | 
            +
                                     split_strategy=split_strategy,
         | 
| 193 | 
            +
                                     aggregation_strategy=aggregation_strategy))
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                def predict(
         | 
| 196 | 
            +
                    self,
         | 
| 197 | 
            +
                    texts,
         | 
| 198 | 
            +
                    batch_size=32,
         | 
| 199 | 
            +
                    split_strategy=None,
         | 
| 200 | 
            +
                    aggregation_strategy=None
         | 
| 201 | 
            +
                ):
         | 
| 202 | 
            +
                    if split_strategy is None:
         | 
| 203 | 
            +
                        yield from self._predict_batch(texts, batch_size)
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                    else:
         | 
| 206 | 
            +
                        if aggregation_strategy is None:
         | 
| 207 | 
            +
                            aggregation_strategy = AggregationStrategies.Mean
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                        split_indexes = [0]
         | 
| 210 | 
            +
                        sentences = []
         | 
| 211 | 
            +
                        for text in texts:
         | 
| 212 | 
            +
                            new_sentences = split_strategy.split(text, self.tokenizer)
         | 
| 213 | 
            +
                            if not new_sentences:
         | 
| 214 | 
            +
                                continue
         | 
| 215 | 
            +
                            split_indexes.append(split_indexes[-1] + len(new_sentences))
         | 
| 216 | 
            +
                            sentences.extend(new_sentences)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                        predictions = list(self._predict_batch(sentences, batch_size))
         | 
| 219 | 
            +
                        for i, split_index in enumerate(split_indexes[:-1]):
         | 
| 220 | 
            +
                            stop_index = split_indexes[i + 1]
         | 
| 221 | 
            +
                            yield aggregation_strategy.aggregate(
         | 
| 222 | 
            +
                                predictions[split_index:stop_index]
         | 
| 223 | 
            +
                            )
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                def dump(self, path):
         | 
| 226 | 
            +
                    if self._model_path:
         | 
| 227 | 
            +
                        copy_dir(self._model_path, path)
         | 
| 228 | 
            +
                    else:
         | 
| 229 | 
            +
                        self._dump(path)
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                def _dump(self, path):
         | 
| 232 | 
            +
                    make_dir(path)
         | 
| 233 | 
            +
                    make_dir(path + '/tokenizer')
         | 
| 234 | 
            +
                    self._model.save_pretrained(path)
         | 
| 235 | 
            +
                    self._tokenizer.save_pretrained(path + '/tokenizer')
         | 
| 236 | 
            +
                    self._config.save_pretrained(path + '/tokenizer')
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                def _predict_batch(self, sentences: list, batch_size: int):
         | 
| 239 | 
            +
                    sentences_number = len(sentences)
         | 
| 240 | 
            +
                    if batch_size > sentences_number:
         | 
| 241 | 
            +
                        batch_size = sentences_number
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                    for i in range(0, sentences_number, batch_size):
         | 
| 244 | 
            +
                        input_ids_list = []
         | 
| 245 | 
            +
                        attention_mask_list = []
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                        stop_index = i + batch_size
         | 
| 248 | 
            +
                        stop_index = stop_index if stop_index < sentences_number \
         | 
| 249 | 
            +
                            else sentences_number
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                        for j in range(i, stop_index):
         | 
| 252 | 
            +
                            features = self._tokenizer.encode_plus(
         | 
| 253 | 
            +
                                sentences[j],
         | 
| 254 | 
            +
                                add_special_tokens=True,
         | 
| 255 | 
            +
                                max_length=self._tokenizer.model_max_length
         | 
| 256 | 
            +
                            )
         | 
| 257 | 
            +
                            input_ids, _, attention_mask = (
         | 
| 258 | 
            +
                                features['input_ids'],
         | 
| 259 | 
            +
                                features['token_type_ids'],
         | 
| 260 | 
            +
                                features['attention_mask']
         | 
| 261 | 
            +
                            )
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                            input_ids = self._list_to_padded_array(features['input_ids'])
         | 
| 264 | 
            +
                            attention_mask = self._list_to_padded_array(
         | 
| 265 | 
            +
                                features['attention_mask'])
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                            input_ids_list.append(input_ids)
         | 
| 268 | 
            +
                            attention_mask_list.append(attention_mask)
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                        input_dict = {
         | 
| 271 | 
            +
                            'input_ids': np.array(input_ids_list),
         | 
| 272 | 
            +
                            'attention_mask': np.array(attention_mask_list)
         | 
| 273 | 
            +
                        }
         | 
| 274 | 
            +
                        logit_predictions = self._model.predict_on_batch(input_dict)
         | 
| 275 | 
            +
                        yield from (
         | 
| 276 | 
            +
                            [softmax(logit_prediction)
         | 
| 277 | 
            +
                             for logit_prediction in logit_predictions[0]]
         | 
| 278 | 
            +
                        )
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                def _list_to_padded_array(self, items):
         | 
| 281 | 
            +
                    array = np.array(items)
         | 
| 282 | 
            +
                    padded_array = np.zeros(self._tokenizer.model_max_length, dtype=np.int)
         | 
| 283 | 
            +
                    padded_array[:array.shape[0]] = array
         | 
| 284 | 
            +
                    return padded_array
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                def _get_temporary_path(self, name=''):
         | 
| 287 | 
            +
                    return f'{AUTOSAVE_PATH}{name}/{int(round(time.time() * 1000))}'
         | 
| 288 | 
            +
             | 
| 289 | 
            +
    def _reload_model(self):
        """Round-trip the in-memory model through disk.

        Dumps the model to a fresh temporary path and loads it back, so the
        classifier is subsequently backed by a concrete on-disk checkpoint
        (``self._model_path`` becomes set, which ``dump`` relies on).
        """
        self._model_path = self._get_temporary_path(
            name=self._get_model_family())
        self._dump(self._model_path)
        self._load_local_model(self._model_path)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
    def _load_local_model(self, model_path):
        """Load tokenizer, config and TF weights from a local directory.

        Newer dumps keep tokenizer and config in a ``tokenizer``
        subdirectory (see ``_dump``); older dumps stored everything at the
        root, hence the ``OSError`` fallback.
        """
        try:
            self._tokenizer = AutoTokenizer.from_pretrained(
                model_path + '/tokenizer')
            self._config = AutoConfig.from_pretrained(
                model_path + '/tokenizer')

        # Old models didn't use to have a tokenizer folder
        except OSError:
            self._tokenizer = AutoTokenizer.from_pretrained(model_path)
            self._config = AutoConfig.from_pretrained(model_path)
        # Weights on disk are TensorFlow format (written by save_pretrained
        # in _dump), hence from_pt=False.
        self._model = TFAutoModelForSequenceClassification.from_pretrained(
            model_path,
            from_pt=False
        )
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                def _get_model_family(self):
         | 
| 312 | 
            +
                    model_family = ''.join(self._model.name[2:].split('_')[:2])
         | 
| 313 | 
            +
                    return model_family
         | 
| 314 | 
            +
             | 
| 315 | 
            +
    def _load_remote_model(self, model_name, tokenizer_kwargs, model_kwargs):
        """Download tokenizer, config and weights for ``model_name``.

        Weights are loaded as TensorFlow first, then via PyTorch conversion
        fallbacks; if ``model_kwargs`` disagrees with the pretrained config,
        the classification head is rebuilt from the base transformer.
        """
        # Uncased checkpoints need a lower-casing tokenizer.
        do_lower_case = False
        if 'uncased' in model_name.lower():
            do_lower_case = True
        tokenizer_kwargs.update({'do_lower_case': do_lower_case})

        self._tokenizer = AutoTokenizer.from_pretrained(
            model_name, **tokenizer_kwargs)
        self._config = AutoConfig.from_pretrained(model_name)

        # Scratch directory used for any PyTorch->TF conversion below;
        # removed at the end of this method.
        temporary_path = self._get_temporary_path()
        make_dir(temporary_path)

        # TensorFlow model
        try:
            self._model = TFAutoModelForSequenceClassification.from_pretrained(
                model_name,
                from_pt=False
            )

        # PyTorch model
        except TypeError:
            try:
                self._model = \
                    TFAutoModelForSequenceClassification.from_pretrained(
                        model_name,
                        from_pt=True
                    )

            # Loading a TF model from a PyTorch checkpoint is not supported
            # when using a model identifier name
            except OSError:
                model = AutoModel.from_pretrained(model_name)
                model.save_pretrained(temporary_path)
                self._model = \
                    TFAutoModelForSequenceClassification.from_pretrained(
                        temporary_path,
                        from_pt=True
                    )

        # Clean the model's last layer if the provided properties are different
        clean_last_layer = False
        for key, value in model_kwargs.items():
            # Missing attribute -> config mismatch; the hasattr guard also
            # keeps the getattr below from raising.
            if not hasattr(self._model.config, key):
                clean_last_layer = True
                break

            if getattr(self._model.config, key) != value:
                clean_last_layer = True
                break

        if clean_last_layer:
            try:
                # Save only the base transformer layer (named after the model
                # family), then rebuild the head with the requested kwargs.
                getattr(self._model, self._get_model_family()
                        ).save_pretrained(temporary_path)
                self._model = self._model.__class__.from_pretrained(
                    temporary_path,
                    from_pt=False,
                    **model_kwargs
                )

            # The model is itself the main layer
            except AttributeError:
                # TensorFlow model
                try:
                    self._model = self._model.__class__.from_pretrained(
                        model_name,
                        from_pt=False,
                        **model_kwargs
                    )

                # PyTorch Model
                except (OSError, TypeError):
                    model = AutoModel.from_pretrained(model_name)
                    model.save_pretrained(temporary_path)
                    self._model = self._model.__class__.from_pretrained(
                        temporary_path,
                        from_pt=True,
                        **model_kwargs
                    )

        remove_dir(temporary_path)
        assert self._tokenizer and self._model
         | 
    	
        ernie/helper.py
    ADDED
    
    | @@ -0,0 +1,121 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from tensorflow import data, TensorShape, int64, int32
         | 
| 5 | 
            +
            from math import exp
         | 
| 6 | 
            +
            from os import makedirs
         | 
| 7 | 
            +
            from shutil import rmtree, move, copytree
         | 
| 8 | 
            +
            from huggingface_hub import hf_hub_download
         | 
| 9 | 
            +
            import os
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
def get_features(tokenizer, sentences, labels):
    """Tokenize labelled sentences into a ``tf.data.Dataset``.

    Each sentence is encoded with special tokens, padded (on the
    tokenizer's configured side) to ``tokenizer.model_max_length``, and
    paired with its label as an int.

    Args:
        tokenizer: a Hugging Face tokenizer instance.
        sentences: iterable of texts; assumed to fit within
            ``model_max_length`` after encoding — TODO confirm, no
            truncation is requested here.
        labels: per-sentence labels, coercible to ``int``.

    Returns:
        A ``tf.data.Dataset`` yielding
        ``({'input_ids', 'attention_mask', 'token_type_ids'}, label)``.
    """
    features = []
    for i, sentence in enumerate(sentences):
        inputs = tokenizer.encode_plus(
            sentence,
            add_special_tokens=True,
            max_length=tokenizer.model_max_length
        )
        input_ids, token_type_ids = \
            inputs['input_ids'], inputs['token_type_ids']
        padding_length = tokenizer.model_max_length - len(input_ids)

        # attention_mask marks real tokens with 1; it is computed from the
        # pre-padding length, before input_ids is extended below.
        if tokenizer.padding_side == 'right':
            attention_mask = [1] * len(input_ids) + [0] * padding_length
            input_ids = input_ids + [tokenizer.pad_token_id] * padding_length
            token_type_ids = token_type_ids + \
                [tokenizer.pad_token_type_id] * padding_length
        else:
            attention_mask = [0] * padding_length + [1] * len(input_ids)
            input_ids = [tokenizer.pad_token_id] * padding_length + input_ids
            token_type_ids = \
                [tokenizer.pad_token_type_id] * padding_length + token_type_ids

        # All three sequences must be exactly model_max_length long.
        assert tokenizer.model_max_length \
            == len(attention_mask) \
            == len(input_ids) \
            == len(token_type_ids)

        feature = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'token_type_ids': token_type_ids,
            'label': int(labels[i])
        }

        features.append(feature)

    # Generator over the pre-built feature list; from_generator re-invokes
    # it on every dataset iteration.
    def gen():
        for feature in features:
            yield (
                {
                    'input_ids': feature['input_ids'],
                    'attention_mask': feature['attention_mask'],
                    'token_type_ids': feature['token_type_ids'],
                },
                feature['label'],
            )

    dataset = data.Dataset.from_generator(
        gen,
        ({
            'input_ids': int32,
            'attention_mask': int32,
            'token_type_ids': int32
        }, int64),
        (
            {
                'input_ids': TensorShape([None]),
                'attention_mask': TensorShape([None]),
                'token_type_ids': TensorShape([None]),
            },
            TensorShape([]),
        ),
    )

    return dataset
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            def softmax(values):
         | 
| 81 | 
            +
                exps = [exp(value) for value in values]
         | 
| 82 | 
            +
                exps_sum = sum(exp_value for exp_value in exps)
         | 
| 83 | 
            +
                return tuple(map(lambda x: x / exps_sum, exps))
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            def make_dir(path):
         | 
| 87 | 
            +
                try:
         | 
| 88 | 
            +
                    makedirs(path)
         | 
| 89 | 
            +
                except FileExistsError:
         | 
| 90 | 
            +
                    pass
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
def remove_dir(path):
    """Recursively delete the directory tree at ``path``.

    Raises ``FileNotFoundError`` if ``path`` does not exist
    (``shutil.rmtree`` semantics).
    """
    rmtree(path)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
def copy_dir(source_path, target_path):
    """Recursively copy ``source_path`` to ``target_path``.

    ``target_path`` must not already exist (``shutil.copytree`` semantics).
    """
    copytree(source_path, target_path)
         | 
| 99 | 
            +
             | 
| 100 | 
            +
             | 
| 101 | 
            +
def move_dir(source_path, target_path):
    """Move ``source_path`` to ``target_path`` (``shutil.move`` semantics)."""
    move(source_path, target_path)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
            def download_from_hub(repo_id, filename, revision=None, cache_dir=None):
         | 
| 105 | 
            +
                try:
         | 
| 106 | 
            +
                    hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, cache_dir=cache_dir)
         | 
| 107 | 
            +
                except Exception as exp:
         | 
| 108 | 
            +
                    raise exp
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                        
         | 
| 111 | 
            +
                if cache_dir is not None:
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    files = os.listdir(cache_dir)
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    for f in files:
         | 
| 116 | 
            +
                        if '.lock' in f:
         | 
| 117 | 
            +
                            name = f[0:-5]
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                            os.rename(cache_dir+name, cache_dir+filename)
         | 
| 120 | 
            +
                            os.remove(cache_dir+name+'.lock')
         | 
| 121 | 
            +
                            os.remove(cache_dir+name+'.json')
         | 
    	
        ernie/models.py
    ADDED
    
    | @@ -0,0 +1,51 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
class Models:
    """Hugging Face Hub identifiers of the supported pretrained checkpoints."""

    # BERT
    BertBaseUncased = 'bert-base-uncased'
    BertBaseCased = 'bert-base-cased'
    BertLargeUncased = 'bert-large-uncased'
    BertLargeCased = 'bert-large-cased'

    # RoBERTa
    RobertaBaseCased = 'roberta-base'
    RobertaLargeCased = 'roberta-large'

    # XLNet
    XLNetBaseCased = 'xlnet-base-cased'
    XLNetLargeCased = 'xlnet-large-cased'

    # DistilBERT
    DistilBertBaseUncased = 'distilbert-base-uncased'
    DistilBertBaseMultilingualCased = 'distilbert-base-multilingual-cased'

    # ALBERT v1
    AlbertBaseCased = 'albert-base-v1'
    AlbertLargeCased = 'albert-large-v1'
    AlbertXLargeCased = 'albert-xlarge-v1'
    AlbertXXLargeCased = 'albert-xxlarge-v1'

    # ALBERT v2
    AlbertBaseCased2 = 'albert-base-v2'
    AlbertLargeCased2 = 'albert-large-v2'
    AlbertXLargeCased2 = 'albert-xlarge-v2'
    AlbertXXLargeCased2 = 'albert-xxlarge-v2'
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class ModelsByFamily:
         | 
| 32 | 
            +
                Bert = set([Models.BertBaseUncased, Models.BertBaseCased,
         | 
| 33 | 
            +
                           Models.BertLargeUncased, Models.BertLargeCased])
         | 
| 34 | 
            +
                Roberta = set([Models.RobertaBaseCased, Models.RobertaLargeCased])
         | 
| 35 | 
            +
                XLNet = set([Models.XLNetBaseCased, Models.XLNetLargeCased])
         | 
| 36 | 
            +
                DistilBert = set([Models.DistilBertBaseUncased,
         | 
| 37 | 
            +
                                 Models.DistilBertBaseMultilingualCased])
         | 
| 38 | 
            +
                Albert = set([
         | 
| 39 | 
            +
                    Models.AlbertBaseCased,
         | 
| 40 | 
            +
                    Models.AlbertLargeCased,
         | 
| 41 | 
            +
                    Models.AlbertXLargeCased,
         | 
| 42 | 
            +
                    Models.AlbertXXLargeCased,
         | 
| 43 | 
            +
                    Models.AlbertBaseCased2,
         | 
| 44 | 
            +
                    Models.AlbertLargeCased2,
         | 
| 45 | 
            +
                    Models.AlbertXLargeCased2,
         | 
| 46 | 
            +
                    Models.AlbertXXLargeCased2
         | 
| 47 | 
            +
                ])
         | 
| 48 | 
            +
                Supported = set([
         | 
| 49 | 
            +
                    getattr(Models, model_type) for model_type
         | 
| 50 | 
            +
                    in filter(lambda x: x[:2] != '__', Models.__dict__.keys())
         | 
| 51 | 
            +
                ])
         | 
    	
        ernie/split_strategies.py
    ADDED
    
    | @@ -0,0 +1,125 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import re
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
class RegexExpressions:
    """Precompiled regular expressions used by the split strategies."""

    # Each split_by_* pattern matches a run of non-delimiter characters
    # followed optionally by the delimiter and trailing whitespace, so
    # re.findall partitions the text without dropping characters.
    split_by_dot = re.compile(r'[^.]+(?:\.\s*)?')
    split_by_semicolon = re.compile(r'[^;]+(?:\;\s*)?')
    split_by_colon = re.compile(r'[^:]+(?:\:\s*)?')
    split_by_comma = re.compile(r'[^,]+(?:\,\s*)?')

    # Full URLs (with scheme) and bare domain-like tokens; presumably used
    # as remove_patterns before splitting — verify against callers.
    url = re.compile(
        r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}'
        r'\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
    )
    domain = re.compile(r'\w+\.\w+')
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            class SplitStrategy:
         | 
| 21 | 
            +
                def __init__(
         | 
| 22 | 
            +
                    self,
         | 
| 23 | 
            +
                    split_patterns,
         | 
| 24 | 
            +
                    remove_patterns=None,
         | 
| 25 | 
            +
                    group_splits=True,
         | 
| 26 | 
            +
                    remove_too_short_groups=True
         | 
| 27 | 
            +
                ):
         | 
| 28 | 
            +
                    if not isinstance(split_patterns, list):
         | 
| 29 | 
            +
                        self.split_patterns = [split_patterns]
         | 
| 30 | 
            +
                    else:
         | 
| 31 | 
            +
                        self.split_patterns = split_patterns
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    if remove_patterns is not None \
         | 
| 34 | 
            +
                            and not isinstance(remove_patterns, list):
         | 
| 35 | 
            +
                        self.remove_patterns = [remove_patterns]
         | 
| 36 | 
            +
                    else:
         | 
| 37 | 
            +
                        self.remove_patterns = remove_patterns
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    self.group_splits = group_splits
         | 
| 40 | 
            +
                    self.remove_too_short_groups = remove_too_short_groups
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                def split(self, text, tokenizer, split_patterns=None):
         | 
| 43 | 
            +
                    if split_patterns is None:
         | 
| 44 | 
            +
                        if self.split_patterns is None:
         | 
| 45 | 
            +
                            return [text]
         | 
| 46 | 
            +
                        split_patterns = self.split_patterns
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    def len_in_tokens(text_):
         | 
| 49 | 
            +
                        no_tokens = len(tokenizer.encode(text_, add_special_tokens=False))
         | 
| 50 | 
            +
                        return no_tokens
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    no_special_tokens = len(tokenizer.encode('', add_special_tokens=True))
         | 
| 53 | 
            +
                    max_tokens = tokenizer.max_len - no_special_tokens
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    if self.remove_patterns is not None:
         | 
| 56 | 
            +
                        for remove_pattern in self.remove_patterns:
         | 
| 57 | 
            +
                            text = re.sub(remove_pattern, '', text).strip()
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    if len_in_tokens(text) <= max_tokens:
         | 
| 60 | 
            +
                        return [text]
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    selected_splits = []
         | 
| 63 | 
            +
                    splits = map(lambda x: x.strip(), re.findall(split_patterns[0], text))
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    aggregated_splits = ''
         | 
| 66 | 
            +
                    for split in splits:
         | 
| 67 | 
            +
                        if len_in_tokens(split) > max_tokens:
         | 
| 68 | 
            +
                            if len(split_patterns) > 1:
         | 
| 69 | 
            +
                                sub_splits = self.split(
         | 
| 70 | 
            +
                                    split, tokenizer, split_patterns[1:])
         | 
| 71 | 
            +
                                selected_splits.extend(sub_splits)
         | 
| 72 | 
            +
                            else:
         | 
| 73 | 
            +
                                selected_splits.append(split)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                        else:
         | 
| 76 | 
            +
                            if not self.group_splits:
         | 
| 77 | 
            +
                                selected_splits.append(split)
         | 
| 78 | 
            +
                            else:
         | 
| 79 | 
            +
                                new_aggregated_splits = \
         | 
| 80 | 
            +
                                    f'{aggregated_splits} {split}'.strip()
         | 
| 81 | 
            +
                                if len_in_tokens(new_aggregated_splits) <= max_tokens:
         | 
| 82 | 
            +
                                    aggregated_splits = new_aggregated_splits
         | 
| 83 | 
            +
                                else:
         | 
| 84 | 
            +
                                    selected_splits.append(aggregated_splits)
         | 
| 85 | 
            +
                                    aggregated_splits = split
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    if aggregated_splits:
         | 
| 88 | 
            +
                        selected_splits.append(aggregated_splits)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    remove_too_short_groups = len(selected_splits) > 1 \
         | 
| 91 | 
            +
                        and self.group_splits \
         | 
| 92 | 
            +
                        and self.remove_too_short_groups
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    if not remove_too_short_groups:
         | 
| 95 | 
            +
                        final_splits = selected_splits
         | 
| 96 | 
            +
                    else:
         | 
| 97 | 
            +
                        final_splits = []
         | 
| 98 | 
            +
                        min_length = tokenizer.max_len / 2
         | 
| 99 | 
            +
                        for split in selected_splits:
         | 
| 100 | 
            +
                            if len_in_tokens(split) >= min_length:
         | 
| 101 | 
            +
                                final_splits.append(split)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    return final_splits
         | 
| 104 | 
            +
             | 
| 105 | 
            +
             | 
| 106 | 
            +
            class SplitStrategies:
         | 
| 107 | 
            +
                SentencesWithoutUrls = SplitStrategy(split_patterns=[
         | 
| 108 | 
            +
                    RegexExpressions.split_by_dot,
         | 
| 109 | 
            +
                    RegexExpressions.split_by_semicolon,
         | 
| 110 | 
            +
                    RegexExpressions.split_by_colon,
         | 
| 111 | 
            +
                    RegexExpressions.split_by_comma
         | 
| 112 | 
            +
                ],
         | 
| 113 | 
            +
                    remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
         | 
| 114 | 
            +
                    remove_too_short_groups=False,
         | 
| 115 | 
            +
                    group_splits=False)
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                GroupedSentencesWithoutUrls = SplitStrategy(split_patterns=[
         | 
| 118 | 
            +
                    RegexExpressions.split_by_dot,
         | 
| 119 | 
            +
                    RegexExpressions.split_by_semicolon,
         | 
| 120 | 
            +
                    RegexExpressions.split_by_colon,
         | 
| 121 | 
            +
                    RegexExpressions.split_by_comma
         | 
| 122 | 
            +
                ],
         | 
| 123 | 
            +
                    remove_patterns=[RegexExpressions.url, RegexExpressions.domain],
         | 
| 124 | 
            +
                    remove_too_short_groups=True,
         | 
| 125 | 
            +
                    group_splits=True)
         |