"use client"
/* eslint-disable camelcase */
import { pipeline, env } from "@xenova/transformers";
// Disable local models
env.allowLocalModels = false;
// Define model factories
// Ensures only one model is created of each type
class PipelineFactory {
static task = null;
static model = null;
static quantized = null;
static instance = null;
constructor(tokenizer, model, quantized) {
this.tokenizer = tokenizer;
this.model = model;
this.quantized = quantized;
}
static async getInstance(progress_callback = null) {
if (this.instance === null) {
this.instance = pipeline(this.task, this.model, {
quantized: this.quantized,
progress_callback,
// For medium models, we need to load the `no_attentions` revision to avoid running out of memory
revision: this.model.includes("/whisper-medium") ? "no_attentions" : "main"
});
}
return this.instance;
}
}

self.addEventListener("message", async (event) => {
    const message = event.data;

    // Do some work...
    // TODO use message data
    let transcript = await transcribe(
        message.audio,
        message.model,
        message.multilingual,
        message.quantized,
        message.subtask,
        message.language,
    );
    if (transcript === null) return;

    // Send the result back to the main thread
    self.postMessage({
        status: "complete",
        task: "automatic-speech-recognition",
        data: transcript,
    });
});
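
// For reference, the listener above expects the main thread to post a message
// with the fields destructured above. The concrete values below are illustrative
// assumptions, not values defined in this file:
//
//   worker.postMessage({
//     audio,                         // e.g. a Float32Array of 16 kHz mono samples
//     model: "Xenova/whisper-tiny",  // any Whisper checkpoint id
//     multilingual: false,           // false appends ".en" to the model name below
//     quantized: true,
//     subtask: "transcribe",         // or "translate"
//     language: "en",                // only relevant for multilingual models
//   });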

class AutomaticSpeechRecognitionPipelineFactory extends PipelineFactory {
    static task = "automatic-speech-recognition";
    static model = null;
    static quantized = null;
}

const transcribe = async (
    audio,
    model,
    multilingual,
    quantized,
    subtask,
    language,
) => {
    const isDistilWhisper = model.startsWith("distil-whisper/");

    let modelName = model;
    if (!isDistilWhisper && !multilingual) {
        modelName += ".en";
    }

    const p = AutomaticSpeechRecognitionPipelineFactory;
    if (p.model !== modelName || p.quantized !== quantized) {
        // Invalidate model if different
        p.model = modelName;
        p.quantized = quantized;

        if (p.instance !== null) {
            (await p.getInstance()).dispose();
            p.instance = null;
        }
    }

    // Load transcriber model
    let transcriber = await p.getInstance((data) => {
        self.postMessage(data);
    });

    // Seconds covered by each timestamp position (e.g. 30s / 1500 positions = 0.02s),
    // used below when decoding timestamp tokens into seconds
    const time_precision =
        transcriber.processor.feature_extractor.config.chunk_length /
        transcriber.model.config.max_source_positions;

    // Storage for chunks to be processed. Initialise with an empty chunk.
    let chunks_to_process = [
        {
            tokens: [],
            finalised: false,
        },
    ];

    // TODO: Storage for fully-processed and merged chunks
    // let decoded_chunks = [];

    function chunk_callback(chunk) {
        let last = chunks_to_process[chunks_to_process.length - 1];

        // Overwrite last chunk with new info
        Object.assign(last, chunk);
        last.finalised = true;

        // Create an empty chunk after, if it is not the last chunk
        if (!chunk.is_last) {
            chunks_to_process.push({
                tokens: [],
                finalised: false,
            });
        }
    }

    // Inject custom callback function to handle merging of chunks
    function callback_function(item) {
        let last = chunks_to_process[chunks_to_process.length - 1];

        // Update tokens of last chunk
        last.tokens = [...item[0].output_token_ids];

        // Merge text chunks
        // TODO optimise so we don't have to decode all chunks every time
        let data = transcriber.tokenizer._decode_asr(chunks_to_process, {
            time_precision: time_precision,
            return_timestamps: true,
            force_full_sequences: false,
        });

        self.postMessage({
            status: "update",
            task: "automatic-speech-recognition",
            data: data,
        });
    }

    // Actually run transcription
    let output = await transcriber(audio, {
        // Greedy
        top_k: 0,
        do_sample: false,

        // Sliding window
        chunk_length_s: isDistilWhisper ? 20 : 30,
        stride_length_s: isDistilWhisper ? 3 : 5,

        // Language and task
        language: language,
        task: subtask,

        // Return timestamps
        return_timestamps: true,
        force_full_sequences: false,

        // Callback functions
        callback_function: callback_function, // after each generation step
        chunk_callback: chunk_callback, // after each chunk is processed
    }).catch((error) => {
        self.postMessage({
            status: "error",
            task: "automatic-speech-recognition",
            data: error,
        });
        return null;
    });

    return output;
};
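
// A minimal sketch of how the main thread might drive this worker. The file
// path and handler bodies are assumptions for illustration, not part of this file:
//
//   const worker = new Worker(new URL("./worker.js", import.meta.url), {
//     type: "module",
//   });
//   worker.addEventListener("message", (event) => {
//     const { status, data } = event.data;
//     if (status === "update") {
//       // partial transcript with timestamps, posted by callback_function above
//     } else if (status === "complete") {
//       // final transcript, posted once transcribe() resolves
//     }
//   });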