sanchit-gandhi commited on
Commit
fc4914b
·
1 Parent(s): ff9897a

Update asr_diarizer.py

Browse files
Files changed (1) hide show
  1. asr_diarizer.py +35 -16
asr_diarizer.py CHANGED
@@ -16,14 +16,15 @@ class ASRDiarizationPipeline:
16
  diarization_pipeline,
17
  ):
18
  self.asr_pipeline = asr_pipeline
19
- self.diarization_pipeline = diarization_pipeline
20
 
21
- self.sampling_rate = self.asr_pipeline.feature_extractor.sampling_rate
22
 
23
  @classmethod
24
  def from_pretrained(
25
  cls,
26
- asr_model: Optional[str] = "openai/whisper-small",
 
27
  diarizer_model: Optional[str] = "pyannote/speaker-diarization",
28
  chunk_length_s: Optional[int] = 30,
29
  use_auth_token: Optional[Union[str, bool]] = True,
@@ -37,7 +38,7 @@ class ASRDiarizationPipeline:
37
  **kwargs,
38
  )
39
  diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
40
- cls(asr_pipeline, diarization_pipeline)
41
 
42
  def __call__(
43
  self,
@@ -46,7 +47,13 @@ class ASRDiarizationPipeline:
46
  **kwargs,
47
  ):
48
  """
49
- Transcribe the audio sequence(s) given as inputs to text.
 
 
 
 
 
 
50
 
51
  Args:
52
  inputs (`np.ndarray` or `bytes` or `str` or `dict`):
@@ -62,15 +69,16 @@ class ASRDiarizationPipeline:
62
  np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
63
  treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
64
  inference to provide more context to the model). Only use `stride` with CTC models.
 
 
 
65
 
66
  Return:
67
- `Dict`: A dictionary with the following keys:
 
68
  - **text** (`str` ) -- The recognized text.
69
- - **chunks** (*optional(, `List[Dict]`)
70
- When using `return_timestamps`, the `chunks` will become a list containing all the various text
71
- chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text":
72
- "there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
73
- `"".join(chunk["text"] for chunk in output["chunks"])`.
74
  """
75
  inputs, diarizer_inputs = self.preprocess(inputs)
76
 
@@ -81,13 +89,17 @@ class ASRDiarizationPipeline:
81
 
82
  segments = diarization.for_json()["content"]
83
 
 
 
84
  new_segments = []
85
  prev_segment = cur_segment = segments[0]
86
 
87
  for i in range(1, len(segments)):
88
  cur_segment = segments[i]
89
 
 
90
  if cur_segment["label"] != prev_segment["label"] and i < len(segments):
 
91
  new_segments.append(
92
  {
93
  "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
@@ -96,6 +108,7 @@ class ASRDiarizationPipeline:
96
  )
97
  prev_segment = segments[i]
98
 
 
99
  new_segments.append(
100
  {
101
  "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
@@ -110,11 +123,15 @@ class ASRDiarizationPipeline:
110
  )
111
  transcript = asr_out["chunks"]
112
 
 
113
  end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
114
  segmented_preds = []
115
 
 
116
  for segment in new_segments:
 
117
  end_time = segment["segment"]["end"]
 
118
  upto_idx = np.argmin(np.abs(end_timestamps - end_time))
119
 
120
  if group_by_speaker:
@@ -122,21 +139,21 @@ class ASRDiarizationPipeline:
122
  {
123
  "speaker": segment["speaker"],
124
  "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
125
- "timestamp": {
126
- "start": transcript[0]["timestamp"][0],
127
- "end": transcript[upto_idx]["timestamp"][1],
128
- },
129
  }
130
  )
131
  else:
132
  for i in range(upto_idx + 1):
133
  segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
134
 
 
135
  transcript = transcript[upto_idx + 1 :]
136
  end_timestamps = end_timestamps[upto_idx + 1 :]
137
 
138
  return segmented_preds
139
 
 
 
140
  def preprocess(self, inputs):
141
  if isinstance(inputs, str):
142
  if inputs.startswith("http://") or inputs.startswith("https://"):
@@ -174,6 +191,8 @@ class ASRDiarizationPipeline:
174
  if len(inputs.shape) != 1:
175
  raise ValueError("We expect a single channel audio input for ASRDiarizePipeline")
176
 
177
- diarizer_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
 
 
178
 
179
  return inputs, diarizer_inputs
 
16
  diarization_pipeline,
17
  ):
18
  self.asr_pipeline = asr_pipeline
19
+ self.sampling_rate = asr_pipeline.feature_extractor.sampling_rate
20
 
21
+ self.diarization_pipeline = diarization_pipeline
22
 
23
  @classmethod
24
  def from_pretrained(
25
  cls,
26
+ asr_model: Optional[str] = "openai/whisper-medium",
27
+ *,
28
  diarizer_model: Optional[str] = "pyannote/speaker-diarization",
29
  chunk_length_s: Optional[int] = 30,
30
  use_auth_token: Optional[Union[str, bool]] = True,
 
38
  **kwargs,
39
  )
40
  diarization_pipeline = Pipeline.from_pretrained(diarizer_model, use_auth_token=use_auth_token)
41
+ return cls(asr_pipeline, diarization_pipeline)
42
 
43
  def __call__(
44
  self,
 
47
  **kwargs,
48
  ):
49
  """
50
+ Transcribe the audio sequence(s) given as inputs to text and label with speaker information. The input audio
51
+ is first passed to the speaker diarization pipeline, which returns timestamps for 'who spoke when'. The audio
52
+ is then passed to the ASR pipeline, which returns utterance-level transcriptions and their corresponding
53
+ timestamps. The speaker diarizer timestamps are aligned with the ASR transcription timestamps to give
54
+ speaker-labelled transcriptions. We cannot use the speaker diarization timestamps alone to partition the
55
+ transcriptions, as these timestamps may straddle across transcribed utterances from the ASR output. Thus, we
56
+ find the diarizer timestamps that are closest to the ASR timestamps and partition here.
57
 
58
  Args:
59
  inputs (`np.ndarray` or `bytes` or `str` or `dict`):
 
69
  np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
70
  treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
71
  inference to provide more context to the model). Only use `stride` with CTC models.
72
+ group_by_speaker (`bool`):
73
+ Whether to group consecutive utterances by one speaker into a single segment. If False, will return
74
+ transcriptions on a chunk-by-chunk basis.
75
 
76
  Return:
77
+ A list of transcriptions. Each list item corresponds to one chunk / segment of transcription, and is a
78
+ dictionary with the following keys:
79
  - **text** (`str` ) -- The recognized text.
80
+ - **speaker** (`str`) -- The associated speaker.
81
+ - **timestamps** (`tuple`) -- The start and end time for the chunk / segment.
 
 
 
82
  """
83
  inputs, diarizer_inputs = self.preprocess(inputs)
84
 
 
89
 
90
  segments = diarization.for_json()["content"]
91
 
92
+ # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
93
+ # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
94
  new_segments = []
95
  prev_segment = cur_segment = segments[0]
96
 
97
  for i in range(1, len(segments)):
98
  cur_segment = segments[i]
99
 
100
+ # check if we have changed speaker ("label")
101
  if cur_segment["label"] != prev_segment["label"] and i < len(segments):
102
+ # add the start/end times for the super-segment to the new list
103
  new_segments.append(
104
  {
105
  "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
 
108
  )
109
  prev_segment = segments[i]
110
 
111
+ # add the last segment(s) if there was no speaker change
112
  new_segments.append(
113
  {
114
  "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
 
123
  )
124
  transcript = asr_out["chunks"]
125
 
126
+ # get the end timestamps for each chunk from the ASR output
127
  end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
128
  segmented_preds = []
129
 
130
+ # align the diarizer timestamps and the ASR timestamps
131
  for segment in new_segments:
132
+ # get the diarizer end timestamp
133
  end_time = segment["segment"]["end"]
134
+ # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
135
  upto_idx = np.argmin(np.abs(end_timestamps - end_time))
136
 
137
  if group_by_speaker:
 
139
  {
140
  "speaker": segment["speaker"],
141
  "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
142
+ "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]),
 
 
 
143
  }
144
  )
145
  else:
146
  for i in range(upto_idx + 1):
147
  segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
148
 
149
+ # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
150
  transcript = transcript[upto_idx + 1 :]
151
  end_timestamps = end_timestamps[upto_idx + 1 :]
152
 
153
  return segmented_preds
154
 
155
+ # Adapted from transformers.pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline.preprocess
156
+ # (see https://github.com/huggingface/transformers/blob/238449414f88d94ded35e80459bb6412d8ab42cf/src/transformers/pipelines/automatic_speech_recognition.py#L417)
157
  def preprocess(self, inputs):
158
  if isinstance(inputs, str):
159
  if inputs.startswith("http://") or inputs.startswith("https://"):
 
191
  if len(inputs.shape) != 1:
192
  raise ValueError("We expect a single channel audio input for ASRDiarizePipeline")
193
 
194
+ # diarization model expects float32 torch tensor of shape `(channels, seq_len)`
195
+ diarizer_inputs = torch.from_numpy(inputs).float()
196
+ diarizer_inputs = diarizer_inputs.unsqueeze(0)
197
 
198
  return inputs, diarizer_inputs