Spaces:
Runtime error
Runtime error
sanchit-gandhi
commited on
Commit
·
fc4914b
1
Parent(s):
ff9897a
Update asr_diarizer.py
Browse files- asr_diarizer.py +35 -16
asr_diarizer.py
CHANGED
@@ -16,14 +16,15 @@ class ASRDiarizationPipeline:
|
|
16 |
diarization_pipeline,
|
17 |
):
|
18 |
self.asr_pipeline = asr_pipeline
|
19 |
-
self.
|
20 |
|
21 |
-
self.
|
22 |
|
23 |
@classmethod
|
24 |
def from_pretrained(
|
25 |
cls,
|
26 |
-
asr_model: Optional[str] = "openai/whisper-
|
|
|
27 |
diarizer_model: Optional[str] = "pyannote/speaker-diarization",
|
28 |
chunk_length_s: Optional[int] = 30,
|
29 |
use_auth_token: Optional[Union[str, bool]] = True,
|
@@ -37,7 +38,7 @@ class ASRDiarizationPipeline:
|
|
37 |
**kwargs,
|
38 |
)
|
39 |
diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
|
40 |
-
cls(asr_pipeline, diarization_pipeline)
|
41 |
|
42 |
def __call__(
|
43 |
self,
|
@@ -46,7 +47,13 @@ class ASRDiarizationPipeline:
|
|
46 |
**kwargs,
|
47 |
):
|
48 |
"""
|
49 |
-
Transcribe the audio sequence(s) given as inputs to text.
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
Args:
|
52 |
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
|
@@ -62,15 +69,16 @@ class ASRDiarizationPipeline:
|
|
62 |
np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
|
63 |
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
64 |
inference to provide more context to the model). Only use `stride` with CTC models.
|
|
|
|
|
|
|
65 |
|
66 |
Return:
|
67 |
-
|
|
|
68 |
- **text** (`str` ) -- The recognized text.
|
69 |
-
- **
|
70 |
-
|
71 |
-
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text":
|
72 |
-
"there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
73 |
-
`"".join(chunk["text"] for chunk in output["chunks"])`.
|
74 |
"""
|
75 |
inputs, diarizer_inputs = self.preprocess(inputs)
|
76 |
|
@@ -81,13 +89,17 @@ class ASRDiarizationPipeline:
|
|
81 |
|
82 |
segments = diarization.for_json()["content"]
|
83 |
|
|
|
|
|
84 |
new_segments = []
|
85 |
prev_segment = cur_segment = segments[0]
|
86 |
|
87 |
for i in range(1, len(segments)):
|
88 |
cur_segment = segments[i]
|
89 |
|
|
|
90 |
if cur_segment["label"] != prev_segment["label"] and i < len(segments):
|
|
|
91 |
new_segments.append(
|
92 |
{
|
93 |
"segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
|
@@ -96,6 +108,7 @@ class ASRDiarizationPipeline:
|
|
96 |
)
|
97 |
prev_segment = segments[i]
|
98 |
|
|
|
99 |
new_segments.append(
|
100 |
{
|
101 |
"segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
|
@@ -110,11 +123,15 @@ class ASRDiarizationPipeline:
|
|
110 |
)
|
111 |
transcript = asr_out["chunks"]
|
112 |
|
|
|
113 |
end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
|
114 |
segmented_preds = []
|
115 |
|
|
|
116 |
for segment in new_segments:
|
|
|
117 |
end_time = segment["segment"]["end"]
|
|
|
118 |
upto_idx = np.argmin(np.abs(end_timestamps - end_time))
|
119 |
|
120 |
if group_by_speaker:
|
@@ -122,21 +139,21 @@ class ASRDiarizationPipeline:
|
|
122 |
{
|
123 |
"speaker": segment["speaker"],
|
124 |
"text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
|
125 |
-
"timestamp":
|
126 |
-
"start": transcript[0]["timestamp"][0],
|
127 |
-
"end": transcript[upto_idx]["timestamp"][1],
|
128 |
-
},
|
129 |
}
|
130 |
)
|
131 |
else:
|
132 |
for i in range(upto_idx + 1):
|
133 |
segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
|
134 |
|
|
|
135 |
transcript = transcript[upto_idx + 1 :]
|
136 |
end_timestamps = end_timestamps[upto_idx + 1 :]
|
137 |
|
138 |
return segmented_preds
|
139 |
|
|
|
|
|
140 |
def preprocess(self, inputs):
|
141 |
if isinstance(inputs, str):
|
142 |
if inputs.startswith("http://") or inputs.startswith("https://"):
|
@@ -174,6 +191,8 @@ class ASRDiarizationPipeline:
|
|
174 |
if len(inputs.shape) != 1:
|
175 |
raise ValueError("We expect a single channel audio input for ASRDiarizePipeline")
|
176 |
|
177 |
-
|
|
|
|
|
178 |
|
179 |
return inputs, diarizer_inputs
|
|
|
16 |
diarization_pipeline,
|
17 |
):
|
18 |
self.asr_pipeline = asr_pipeline
|
19 |
+
self.sampling_rate = asr_pipeline.feature_extractor.sampling_rate
|
20 |
|
21 |
+
self.diarization_pipeline = diarization_pipeline
|
22 |
|
23 |
@classmethod
|
24 |
def from_pretrained(
|
25 |
cls,
|
26 |
+
asr_model: Optional[str] = "openai/whisper-medium",
|
27 |
+
*,
|
28 |
diarizer_model: Optional[str] = "pyannote/speaker-diarization",
|
29 |
chunk_length_s: Optional[int] = 30,
|
30 |
use_auth_token: Optional[Union[str, bool]] = True,
|
|
|
38 |
**kwargs,
|
39 |
)
|
40 |
diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
|
41 |
+
return cls(asr_pipeline, diarization_pipeline)
|
42 |
|
43 |
def __call__(
|
44 |
self,
|
|
|
47 |
**kwargs,
|
48 |
):
|
49 |
"""
|
50 |
+
Transcribe the audio sequence(s) given as inputs to text and label with speaker information. The input audio
|
51 |
+
is first passed to the speaker diarization pipeline, which returns timestamps for 'who spoke when'. The audio
|
52 |
+
is then passed to the ASR pipeline, which returns utterance-level transcriptions and their corresponding
|
53 |
+
timestamps. The speaker diarizer timestamps are aligned with the ASR transcription timestamps to give
|
54 |
+
speaker-labelled transcriptions. We cannot use the speaker diarization timestamps alone to partition the
|
55 |
+
transcriptions, as these timestamps may straddle across transcribed utterances from the ASR output. Thus, we
|
56 |
+
find the diarizer timestamps that are closest to the ASR timestamps and partition here.
|
57 |
|
58 |
Args:
|
59 |
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
|
|
|
69 |
np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
|
70 |
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
71 |
inference to provide more context to the model). Only use `stride` with CTC models.
|
72 |
+
group_by_speaker (`bool`):
|
73 |
+
Whether to group consecutive utterances by one speaker into a single segment. If False, will return
|
74 |
+
transcriptions on a chunk-by-chunk basis.
|
75 |
|
76 |
Return:
|
77 |
+
A list of transcriptions. Each list item corresponds to one chunk / segment of transcription, and is a
|
78 |
+
dictionary with the following keys:
|
79 |
- **text** (`str` ) -- The recognized text.
|
80 |
+
- **speaker** (`str`) -- The associated speaker.
|
81 |
+
- **timestamps** (`tuple`) -- The start and end time for the chunk / segment.
|
|
|
|
|
|
|
82 |
"""
|
83 |
inputs, diarizer_inputs = self.preprocess(inputs)
|
84 |
|
|
|
89 |
|
90 |
segments = diarization.for_json()["content"]
|
91 |
|
92 |
+
# diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
|
93 |
+
# we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
|
94 |
new_segments = []
|
95 |
prev_segment = cur_segment = segments[0]
|
96 |
|
97 |
for i in range(1, len(segments)):
|
98 |
cur_segment = segments[i]
|
99 |
|
100 |
+
# check if we have changed speaker ("label")
|
101 |
if cur_segment["label"] != prev_segment["label"] and i < len(segments):
|
102 |
+
# add the start/end times for the super-segment to the new list
|
103 |
new_segments.append(
|
104 |
{
|
105 |
"segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
|
|
|
108 |
)
|
109 |
prev_segment = segments[i]
|
110 |
|
111 |
+
# add the last segment(s) if there was no speaker change
|
112 |
new_segments.append(
|
113 |
{
|
114 |
"segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
|
|
|
123 |
)
|
124 |
transcript = asr_out["chunks"]
|
125 |
|
126 |
+
# get the end timestamps for each chunk from the ASR output
|
127 |
end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
|
128 |
segmented_preds = []
|
129 |
|
130 |
+
# align the diarizer timestamps and the ASR timestamps
|
131 |
for segment in new_segments:
|
132 |
+
# get the diarizer end timestamp
|
133 |
end_time = segment["segment"]["end"]
|
134 |
+
# find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
|
135 |
upto_idx = np.argmin(np.abs(end_timestamps - end_time))
|
136 |
|
137 |
if group_by_speaker:
|
|
|
139 |
{
|
140 |
"speaker": segment["speaker"],
|
141 |
"text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
|
142 |
+
"timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]),
|
|
|
|
|
|
|
143 |
}
|
144 |
)
|
145 |
else:
|
146 |
for i in range(upto_idx + 1):
|
147 |
segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
|
148 |
|
149 |
+
# crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
|
150 |
transcript = transcript[upto_idx + 1 :]
|
151 |
end_timestamps = end_timestamps[upto_idx + 1 :]
|
152 |
|
153 |
return segmented_preds
|
154 |
|
155 |
+
# Adapted from transformers.pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline.preprocess
|
156 |
+
# (see https://github.com/huggingface/transformers/blob/238449414f88d94ded35e80459bb6412d8ab42cf/src/transformers/pipelines/automatic_speech_recognition.py#L417)
|
157 |
def preprocess(self, inputs):
|
158 |
if isinstance(inputs, str):
|
159 |
if inputs.startswith("http://") or inputs.startswith("https://"):
|
|
|
191 |
if len(inputs.shape) != 1:
|
192 |
raise ValueError("We expect a single channel audio input for ASRDiarizePipeline")
|
193 |
|
194 |
+
# diarization model expects float32 torch tensor of shape `(channels, seq_len)`
|
195 |
+
diarizer_inputs = torch.from_numpy(inputs).float()
|
196 |
+
diarizer_inputs = diarizer_inputs.unsqueeze(0)
|
197 |
|
198 |
return inputs, diarizer_inputs
|