sanchit-gandhi commited on
Commit
51c1a74
·
1 Parent(s): b263b21

Delete asr_diarizer.py

Browse files
Files changed (1) hide show
  1. asr_diarizer.py +0 -198
asr_diarizer.py DELETED
@@ -1,198 +0,0 @@
1
- from typing import List, Optional, Union
2
-
3
- import numpy as np
4
- import requests
5
- import torch
6
- from pyannote.audio import Pipeline
7
- from torchaudio import functional as F
8
- from transformers import pipeline
9
- from transformers.pipelines.audio_utils import ffmpeg_read
10
-
11
-
12
- class ASRDiarizationPipeline:
13
- def __init__(
14
- self,
15
- asr_pipeline,
16
- diarization_pipeline,
17
- ):
18
- self.asr_pipeline = asr_pipeline
19
- self.sampling_rate = asr_pipeline.feature_extractor.sampling_rate
20
-
21
- self.diarization_pipeline = diarization_pipeline
22
-
23
- @classmethod
24
- def from_pretrained(
25
- cls,
26
- asr_model: Optional[str] = "openai/whisper-medium",
27
- *,
28
- diarizer_model: Optional[str] = "pyannote/speaker-diarization",
29
- chunk_length_s: Optional[int] = 30,
30
- use_auth_token: Optional[Union[str, bool]] = True,
31
- **kwargs,
32
- ):
33
- asr_pipeline = pipeline(
34
- "automatic-speech-recognition",
35
- model=asr_model,
36
- chunk_length_s=chunk_length_s,
37
- use_auth_token=use_auth_token,
38
- **kwargs,
39
- )
40
- diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
41
- return cls(asr_pipeline, diarization_pipeline)
42
-
43
- def __call__(
44
- self,
45
- inputs: Union[np.ndarray, List[np.ndarray]],
46
- group_by_speaker: bool = True,
47
- **kwargs,
48
- ):
49
- """
50
- Transcribe the audio sequence(s) given as inputs to text and label with speaker information. The input audio
51
- is first passed to the speaker diarization pipeline, which returns timestamps for 'who spoke when'. The audio
52
- is then passed to the ASR pipeline, which returns utterance-level transcriptions and their corresponding
53
- timestamps. The speaker diarizer timestamps are aligned with the ASR transcription timestamps to give
54
- speaker-labelled transcriptions. We cannot use the speaker diarization timestamps alone to partition the
55
- transcriptions, as these timestamps may straddle across transcribed utterances from the ASR output. Thus, we
56
- find the diarizer timestamps that are closest to the ASR timestamps and partition here.
57
-
58
- Args:
59
- inputs (`np.ndarray` or `bytes` or `str` or `dict`):
60
- The inputs is either :
61
- - `str` that is the filename of the audio file, the file will be read at the correct sampling rate
62
- to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
63
- - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
64
- same way.
65
- - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
66
- Raw audio at the correct sampling rate (no further check will be done)
67
- - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
68
- pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
69
- np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
70
- treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
71
- inference to provide more context to the model). Only use `stride` with CTC models.
72
- group_by_speaker (`bool`):
73
- Whether to group consecutive utterances by one speaker into a single segment. If False, will return
74
- transcriptions on a chunk-by-chunk basis.
75
-
76
- Return:
77
- A list of transcriptions. Each list item corresponds to one chunk / segment of transcription, and is a
78
- dictionary with the following keys:
79
- - **text** (`str` ) -- The recognized text.
80
- - **speaker** (`str`) -- The associated speaker.
81
- - **timestamps** (`tuple`) -- The start and end time for the chunk / segment.
82
- """
83
- inputs, diarizer_inputs = self.preprocess(inputs)
84
-
85
- diarization = self.diarization_pipeline(
86
- {"waveform": diarizer_inputs, "sample_rate": self.sampling_rate},
87
- **kwargs,
88
- )
89
-
90
- segments = diarization.for_json()["content"]
91
-
92
- # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
93
- # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
94
- new_segments = []
95
- prev_segment = cur_segment = segments[0]
96
-
97
- for i in range(1, len(segments)):
98
- cur_segment = segments[i]
99
-
100
- # check if we have changed speaker ("label")
101
- if cur_segment["label"] != prev_segment["label"] and i < len(segments):
102
- # add the start/end times for the super-segment to the new list
103
- new_segments.append(
104
- {
105
- "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
106
- "speaker": prev_segment["label"],
107
- }
108
- )
109
- prev_segment = segments[i]
110
-
111
- # add the last segment(s) if there was no speaker change
112
- new_segments.append(
113
- {
114
- "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
115
- "speaker": prev_segment["label"],
116
- }
117
- )
118
-
119
- asr_out = self.asr_pipeline(
120
- {"array": inputs, "sampling_rate": self.sampling_rate},
121
- return_timestamps=True,
122
- **kwargs,
123
- )
124
- transcript = asr_out["chunks"]
125
-
126
- # get the end timestamps for each chunk from the ASR output
127
- end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
128
- segmented_preds = []
129
-
130
- # align the diarizer timestamps and the ASR timestamps
131
- for segment in new_segments:
132
- # get the diarizer end timestamp
133
- end_time = segment["segment"]["end"]
134
- # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
135
- upto_idx = np.argmin(np.abs(end_timestamps - end_time))
136
-
137
- if group_by_speaker:
138
- segmented_preds.append(
139
- {
140
- "speaker": segment["speaker"],
141
- "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
142
- "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]),
143
- }
144
- )
145
- else:
146
- for i in range(upto_idx + 1):
147
- segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
148
-
149
- # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
150
- transcript = transcript[upto_idx + 1 :]
151
- end_timestamps = end_timestamps[upto_idx + 1 :]
152
-
153
- return segmented_preds
154
-
155
- # Adapted from transformers.pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline.preprocess
156
- # (see https://github.com/huggingface/transformers/blob/238449414f88d94ded35e80459bb6412d8ab42cf/src/transformers/pipelines/automatic_speech_recognition.py#L417)
157
- def preprocess(self, inputs):
158
- if isinstance(inputs, str):
159
- if inputs.startswith("http://") or inputs.startswith("https://"):
160
- # We need to actually check for a real protocol, otherwise it's impossible to use a local file
161
- # like http_huggingface_co.png
162
- inputs = requests.get(inputs).content
163
- else:
164
- with open(inputs, "rb") as f:
165
- inputs = f.read()
166
-
167
- if isinstance(inputs, bytes):
168
- inputs = ffmpeg_read(inputs, self.sampling_rate)
169
-
170
- if isinstance(inputs, dict):
171
- # Accepting `"array"` which is the key defined in `datasets` for better integration
172
- if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
173
- raise ValueError(
174
- "When passing a dictionary to ASRDiarizePipeline, the dict needs to contain a "
175
- '"raw" key containing the numpy array representing the audio and a "sampling_rate" key, '
176
- "containing the sampling_rate associated with that array"
177
- )
178
-
179
- _inputs = inputs.pop("raw", None)
180
- if _inputs is None:
181
- # Remove path which will not be used from `datasets`.
182
- inputs.pop("path", None)
183
- _inputs = inputs.pop("array", None)
184
- in_sampling_rate = inputs.pop("sampling_rate")
185
- inputs = _inputs
186
- if in_sampling_rate != self.sampling_rate:
187
- inputs = F.resample(torch.from_numpy(inputs), in_sampling_rate, self.sampling_rate).numpy()
188
-
189
- if not isinstance(inputs, np.ndarray):
190
- raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
191
- if len(inputs.shape) != 1:
192
- raise ValueError("We expect a single channel audio input for ASRDiarizePipeline")
193
-
194
- # diarization model expects float32 torch tensor of shape `(channels, seq_len)`
195
- diarizer_inputs = torch.from_numpy(inputs).float()
196
- diarizer_inputs = diarizer_inputs.unsqueeze(0)
197
-
198
- return inputs, diarizer_inputs