matt HOFFNER
add voice chat
88cc829
raw
history blame
No virus
4.99 kB
"use client"
/* eslint-disable camelcase */
import { pipeline, env } from "@xenova/transformers";
// Disable local models
env.allowLocalModels = false;
// Define model factories
// Ensures only one model is created of each type
class PipelineFactory {
static task = null;
static model = null;
static quantized = null;
static instance = null;
constructor(tokenizer, model, quantized) {
this.tokenizer = tokenizer;
this.model = model;
this.quantized = quantized;
}
static async getInstance(progress_callback = null) {
if (this.instance === null) {
this.instance = pipeline(this.task, this.model, {
quantized: this.quantized,
progress_callback,
// For medium models, we need to load the `no_attentions` revision to avoid running out of memory
revision: this.model.includes("/whisper-medium") ? "no_attentions" : "main"
});
}
return this.instance;
}
}
self.addEventListener("message", async (event) => {
const message = event.data;
// Do some work...
// TODO use message data
let transcript = await transcribe(
message.audio,
message.model,
message.multilingual,
message.quantized,
message.subtask,
message.language,
);
if (transcript === null) return;
// Send the result back to the main thread
self.postMessage({
status: "complete",
task: "automatic-speech-recognition",
data: transcript,
});
});
class AutomaticSpeechRecognitionPipelineFactory extends PipelineFactory {
static task = "automatic-speech-recognition";
static model = null;
static quantized = null;
}
const transcribe = async (
audio,
model,
multilingual,
quantized,
subtask,
language,
) => {
const isDistilWhisper = model.startsWith("distil-whisper/");
let modelName = model;
if (!isDistilWhisper && !multilingual) {
modelName += ".en"
}
const p = AutomaticSpeechRecognitionPipelineFactory;
if (p.model !== modelName || p.quantized !== quantized) {
// Invalidate model if different
p.model = modelName;
p.quantized = quantized;
if (p.instance !== null) {
(await p.getInstance()).dispose();
p.instance = null;
}
}
// Load transcriber model
let transcriber = await p.getInstance((data) => {
self.postMessage(data);
});
const time_precision =
transcriber.processor.feature_extractor.config.chunk_length /
transcriber.model.config.max_source_positions;
// Storage for chunks to be processed. Initialise with an empty chunk.
let chunks_to_process = [
{
tokens: [],
finalised: false,
},
];
// TODO: Storage for fully-processed and merged chunks
// let decoded_chunks = [];
function chunk_callback(chunk) {
let last = chunks_to_process[chunks_to_process.length - 1];
// Overwrite last chunk with new info
Object.assign(last, chunk);
last.finalised = true;
// Create an empty chunk after, if it not the last chunk
if (!chunk.is_last) {
chunks_to_process.push({
tokens: [],
finalised: false,
});
}
}
// Inject custom callback function to handle merging of chunks
function callback_function(item) {
let last = chunks_to_process[chunks_to_process.length - 1];
// Update tokens of last chunk
last.tokens = [...item[0].output_token_ids];
// Merge text chunks
// TODO optimise so we don't have to decode all chunks every time
let data = transcriber.tokenizer._decode_asr(chunks_to_process, {
time_precision: time_precision,
return_timestamps: true,
force_full_sequences: false,
});
self.postMessage({
status: "update",
task: "automatic-speech-recognition",
data: data,
});
}
// Actually run transcription
let output = await transcriber(audio, {
// Greedy
top_k: 0,
do_sample: false,
// Sliding window
chunk_length_s: isDistilWhisper ? 20 : 30,
stride_length_s: isDistilWhisper ? 3 : 5,
// Language and task
language: language,
task: subtask,
// Return timestamps
return_timestamps: true,
force_full_sequences: false,
// Callback functions
callback_function: callback_function, // after each generation step
chunk_callback: chunk_callback, // after each chunk is processed
}).catch((error) => {
self.postMessage({
status: "error",
task: "automatic-speech-recognition",
data: error,
});
return null;
});
return output;
};