Support CLI in faster-whisper
- app.py +4 -1
- cli.py +6 -2
- config.json5 +3 -1
- src/config.py +2 -3
- src/whisper/abstractWhisperContainer.py +12 -3
- src/whisper/fasterWhisperContainer.py +41 -8
- src/whisper/whisperContainer.py +14 -4
- src/whisper/whisperFactory.py +4 -3
    	
app.py
CHANGED

@@ -126,7 +126,8 @@ class WhisperTranscriber:
             selectedModel = modelName if modelName is not None else "base"
 
             model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation, 
-                                             model_name=selectedModel, cache=self.model_cache, models=self.app_config.models)
+                                             model_name=selectedModel, compute_type=self.app_config.compute_type, 
+                                             cache=self.model_cache, models=self.app_config.models)
 
             # Result
             download = []
@@ -518,6 +519,8 @@ if __name__ == '__main__':
                         help="directory to save the outputs")
     parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],\
                         help="the Whisper implementation to use")
+    parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["int8", "int8_float16", "int16", "float16"], \
+                        help="the compute type to use for inference")
 
     args = parser.parse_args().__dict__
 
cli.py
CHANGED

@@ -80,6 +80,8 @@ def cli():
                         help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
     parser.add_argument("--fp16", type=str2bool, default=app_config.fp16, \
                         help="whether to perform inference in fp16; True by default")
+    parser.add_argument("--compute_type", type=str, default=app_config.compute_type, choices=["int8", "int8_float16", "int16", "float16"], \
+                        help="the compute type to use for inference")
 
     parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=app_config.temperature_increment_on_fallback, \
                         help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
@@ -119,12 +121,14 @@ def cli():
     vad_cpu_cores = args.pop("vad_cpu_cores")
     auto_parallel = args.pop("auto_parallel")
 
+    compute_type = args.pop("compute_type")
+
     transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
     transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
     transcriber.set_auto_parallel(auto_parallel)
 
-    model = create_whisper_container(whisper_implementation=whisper_implementation, 
-                                     device=device, download_root=model_dir, models=app_config.models)
+    model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name, 
+                                     device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
 
     if (transcriber._has_parallel_devices()):
        print("Using parallel devices:", transcriber.parallel_device_list)
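With both flags wired through cli.py, the implementation and compute type can be chosen per run, for example (the positional audio argument is assumed to follow the existing cli.py interface and is not part of this change):

    python cli.py --whisper_implementation faster-whisper --compute_type int8 audio.mp3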
    	
config.json5
CHANGED

@@ -104,7 +104,7 @@
     // Number of beams in beam search, only applicable when temperature is zero
     "beam_size": 5,
     // Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
-    "patience": null,
+    "patience": 1,
     // Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
     "length_penalty": null,
     // Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
@@ -115,6 +115,8 @@
     "condition_on_previous_text": true,
     // Whether to perform inference in fp16; True by default
     "fp16": true,
+    // The compute type used by faster-whisper. Can be "int8", "int16" or "float16".
+    "compute_type": "float16",
     // Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
     "temperature_increment_on_fallback": 0.2,
     // If the gzip compression ratio is higher than this value, treat the decoding as failed
src/config.py
CHANGED

@@ -39,12 +39,10 @@ class ApplicationConfig:
                  patience: float = None, length_penalty: float = None,
                  suppress_tokens: str = "-1", initial_prompt: str = None,
                  condition_on_previous_text: bool = True, fp16: bool = True,
+                 compute_type: str = "float16", 
                  temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
                  logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6):
 
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-
         self.models = models
 
         # WebUI settings
@@ -82,6 +80,7 @@ class ApplicationConfig:
         self.initial_prompt = initial_prompt
         self.condition_on_previous_text = condition_on_previous_text
         self.fp16 = fp16
+        self.compute_type = compute_type
         self.temperature_increment_on_fallback = temperature_increment_on_fallback
         self.compression_ratio_threshold = compression_ratio_threshold
         self.logprob_threshold = logprob_threshold
src/whisper/abstractWhisperContainer.py
CHANGED

@@ -33,10 +33,12 @@ class AbstractWhisperCallback:
             return prompt1 + " " + prompt2
 
 class AbstractWhisperContainer:
-    def __init__(self, model_name: str, device: str = None, download_root: str = None,
-                 cache: ModelCache = None, models: List[ModelConfig] = []):
+    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
+                 download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
         self.model_name = model_name
         self.device = device
+        self.compute_type = compute_type
         self.download_root = download_root
         self.cache = cache
 
@@ -87,13 +89,20 @@ class AbstractWhisperContainer:
 
     # This is required for multiprocessing
     def __getstate__(self):
-        return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root, "models": self.models }
+        return { 
+            "model_name": self.model_name, 
+            "device": self.device, 
+            "download_root": self.download_root, 
+            "models": self.models, 
+            "compute_type": self.compute_type 
+        }
 
     def __setstate__(self, state):
         self.model_name = state["model_name"]
         self.device = state["device"]
         self.download_root = state["download_root"]
         self.models = state["models"]
+        self.compute_type = state["compute_type"]
         self.model = None
         # Depickled objects must use the global cache
         self.cache = GLOBAL_MODEL_CACHE
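The explicit __getstate__/__setstate__ pair is what lets a container be pickled into worker processes when parallel VAD devices are used: only the lightweight settings travel, and the model itself is reloaded on the other side. A minimal sketch of that round trip, assuming the base class also initializes the models and model attributes referenced above (the "tiny" model name and int8 compute type are illustrative):

    import pickle

    from src.whisper.fasterWhisperContainer import FasterWhisperContainer

    # Illustrative container; real code passes the model list from config.json5.
    container = FasterWhisperContainer("tiny", device="auto", compute_type="int8")

    # __getstate__ keeps model_name, device, compute_type, download_root and models;
    # the loaded model and the per-process cache are rebuilt on the other side.
    restored = pickle.loads(pickle.dumps(container))

    assert restored.compute_type == "int8"
    assert restored.model is None  # reloaded lazily via get_model() in the worker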
    	
src/whisper/fasterWhisperContainer.py
CHANGED

@@ -1,5 +1,5 @@
 import os
-from typing import List
+from typing import List, Union
 
 from faster_whisper import WhisperModel, download_model
 from src.config import ModelConfig
@@ -8,10 +8,10 @@ from src.modelCache import ModelCache
 from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
 
 class FasterWhisperContainer(AbstractWhisperContainer):
-    def __init__(self, model_name: str, device: str = None, download_root: str = None,
-                 cache: ModelCache = None,
-                 models: List[ModelConfig] = []):
-        super().__init__(model_name, device, download_root, cache, models)
+    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
+                 download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
+        super().__init__(model_name, device, compute_type, download_root, cache, models)
 
     def ensure_downloaded(self):
         """
@@ -35,7 +35,7 @@ class FasterWhisperContainer(AbstractWhisperContainer):
         return None
 
     def _create_model(self):
-        print("Loading faster whisper model " + self.model_name)
+        print("Loading faster whisper model " + self.model_name + " for device " + str(self.device))
         model_config = self._get_model_config()
 
         if model_config.type == "whisper" and model_config.url not in ["tiny", "base", "small", "medium", "large", "large-v2"]:
@@ -46,7 +46,7 @@ class FasterWhisperContainer(AbstractWhisperContainer):
         if (device is None):
             device = "auto"
 
-        model = WhisperModel(model_config.url, device=device, compute_type="float16")
+        model = WhisperModel(model_config.url, device=device, compute_type=self.compute_type)
         return model
 
     def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
@@ -96,10 +96,33 @@ class FasterWhisperCallback(AbstractWhisperCallback):
         model: WhisperModel = self.model_container.get_model()
         language_code = self._lookup_language_code(self.language) if self.language else None
 
+        # Copy decode options and remove options that are not supported by faster-whisper
+        decodeOptions = self.decodeOptions.copy()
+        verbose = decodeOptions.pop("verbose", None)
+
+        logprob_threshold = decodeOptions.pop("logprob_threshold", None)
+
+        patience = decodeOptions.pop("patience", None)
+        length_penalty = decodeOptions.pop("length_penalty", None)
+        suppress_tokens = decodeOptions.pop("suppress_tokens", None)
+
+        if (decodeOptions.pop("fp16", None) is not None):
+            print("WARNING: fp16 option is ignored by faster-whisper - use compute_type instead.")
+
+        # Fix up decode options
+        if (logprob_threshold is not None):
+            decodeOptions["log_prob_threshold"] = logprob_threshold
+
+        decodeOptions["patience"] = float(patience) if patience is not None else 1.0
+        decodeOptions["length_penalty"] = float(length_penalty) if length_penalty is not None else 1.0
+
+        # See if suppress_tokens is a string - if so, convert it to a list of ints
+        decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
+
         segments_generator, info = model.transcribe(audio, \
             language=language_code if language_code else detected_language, task=self.task, \
             initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
-            **self.decodeOptions
+            **decodeOptions
         )
 
         segments = []
@@ -109,6 +132,8 @@ class FasterWhisperCallback(AbstractWhisperCallback):
 
             if progress_listener is not None:
                 progress_listener.on_progress(segment.end, info.duration)
+            if verbose:
+                print(segment.text)
 
         text = " ".join([segment.text for segment in segments])
 
@@ -141,6 +166,14 @@ class FasterWhisperCallback(AbstractWhisperCallback):
             progress_listener.on_finished()
         return result
 
+    def _split_suppress_tokens(self, suppress_tokens: Union[str, List[int]]):
+        if (suppress_tokens is None):
+            return None
+        if (isinstance(suppress_tokens, list)):
+            return suppress_tokens
+
+        return [int(token) for token in suppress_tokens.split(",")]
+
     def _lookup_language_code(self, language: str):
         lookup = {
             "english": "en", "chinese": "zh-cn", "german": "de", "spanish": "es", "russian": "ru", "korean": "ko",
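Most of the new callback code is option translation: the webui-level decode options are renamed or converted to what faster-whisper's transcribe() expects (logprob_threshold becomes log_prob_threshold, suppress_tokens may arrive as a comma-separated string, and fp16 is dropped in favour of compute_type). A standalone sketch of that mapping, mirroring the diff above rather than any public faster-whisper API:

    # Sketch of the option translation performed in FasterWhisperCallback above.
    def translate_decode_options(options: dict) -> dict:
        result = dict(options)
        result.pop("verbose", None)                      # handled by printing segments
        result.pop("fp16", None)                         # superseded by compute_type
        logprob = result.pop("logprob_threshold", None)  # renamed by faster-whisper
        if logprob is not None:
            result["log_prob_threshold"] = logprob
        result["patience"] = float(result.pop("patience", None) or 1.0)
        result["length_penalty"] = float(result.pop("length_penalty", None) or 1.0)
        tokens = result.pop("suppress_tokens", None)
        if isinstance(tokens, str):
            tokens = [int(t) for t in tokens.split(",")]
        result["suppress_tokens"] = tokens
        return result

    # Example with the webui defaults from config.json5:
    print(translate_decode_options({"logprob_threshold": -1.0, "suppress_tokens": "-1", "fp16": True}))
    # -> {'log_prob_threshold': -1.0, 'patience': 1.0, 'length_penalty': 1.0, 'suppress_tokens': [-1]}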
    	
src/whisper/whisperContainer.py
CHANGED

@@ -4,6 +4,7 @@ import os
 import sys
 from typing import List
 from urllib.parse import urlparse
+import torch
 import urllib3
 from src.hooks.progressListener import ProgressListener
 
@@ -18,9 +19,12 @@ from src.utils import download_file
 from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
 
 class WhisperContainer(AbstractWhisperContainer):
-    def __init__(self, model_name: str, device: str = None, download_root: str = None,
-                 cache: ModelCache = None, models: List[ModelConfig] = []):
-        super().__init__(model_name, device, download_root, cache, models)
+    def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
+                 download_root: str = None,
+                 cache: ModelCache = None, models: List[ModelConfig] = []):
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        super().__init__(model_name, device, compute_type, download_root, cache, models)
 
     def ensure_downloaded(self):
         """
@@ -184,8 +188,14 @@ class WhisperCallback(AbstractWhisperCallback):
             return self._transcribe(model, audio, segment_index, prompt, detected_language)
 
     def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
+        decodeOptions = self.decodeOptions.copy()
+
+        # Add fp16
+        if self.model_container.compute_type in ["fp16", "float16"]:
+            decodeOptions["fp16"] = True
+
         return model.transcribe(audio, \
             language=self.language if self.language else detected_language, task=self.task, \
             initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
-            **self.decodeOptions
+            **decodeOptions
         )
src/whisper/whisperFactory.py
CHANGED

@@ -4,15 +4,16 @@ from src.config import ModelConfig
 from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
 
 def create_whisper_container(whisper_implementation: str, 
-                             model_name: str, device: str = None, download_root: str = None,
+                             model_name: str, device: str = None, compute_type: str = "float16",
+                             download_root: str = None,
                              cache: modelCache = None, models: List[ModelConfig] = []) -> AbstractWhisperContainer:
     print("Creating whisper container for " + whisper_implementation)
 
     if (whisper_implementation == "whisper"):
         from src.whisper.whisperContainer import WhisperContainer
-        return WhisperContainer(model_name, device, download_root, cache, models)
+        return WhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
     elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
         from src.whisper.fasterWhisperContainer import FasterWhisperContainer
-        return FasterWhisperContainer(model_name, device, download_root, cache, models)
+        return FasterWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
     else:
         raise ValueError("Unknown Whisper implementation: " + whisper_implementation)
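Both entry points now go through this factory in the same way (see app.py and cli.py above). For reference, a minimal call might look like the following; the ModelConfig keyword arguments are an assumption about its constructor, and the model name is illustrative:

    from src.config import ModelConfig
    from src.whisper.whisperFactory import create_whisper_container

    # Assumed ModelConfig fields; name/url/type are referenced elsewhere in the repo.
    models = [ModelConfig(name="medium", url="medium", type="whisper")]

    container = create_whisper_container(whisper_implementation="faster-whisper",
                                         model_name="medium", device=None,
                                         compute_type="float16", models=models)

    # Remaining decode options (e.g. beam_size) are forwarded via **decodeOptions.
    callback = container.create_callback(language="english", task="transcribe", beam_size=5)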
