"""Model inference utilities: device selection, model loading, and CTC transcription."""

import sys

import torch
from transformers import AutoProcessor, AutoModelForCTC

# Prefer CUDA, then Apple Silicon MPS, then fall back to CPU.
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# phonemizer cannot always locate the espeak shared library on macOS, so point
# it at the Homebrew install explicitly. The path is version-specific and may
# need updating when espeak is upgraded.
if sys.platform == "darwin":
    from phonemizer.backend.espeak.wrapper import EspeakWrapper

    _ESPEAK_LIBRARY = "/opt/homebrew/Cellar/espeak/1.48.04_1/lib/libespeak.1.1.48.dylib"
    EspeakWrapper.set_library(_ESPEAK_LIBRARY)


def clear_cache():
    """Release cached memory on whichever accelerator is present."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()


def load_model(model_id, device=DEVICE):
    """Load a CTC model and its processor from the Hugging Face Hub."""
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForCTC.from_pretrained(model_id).to(device)
    return model, processor


def transcribe(audio, model, processor) -> str:
    """Greedily decode a single raw waveform into a transcription string."""
    # The caller must supply audio at the sampling rate the processor expects;
    # the processor handles padding and feature normalization.
    input_values = (
        processor(
            [audio],
            sampling_rate=processor.feature_extractor.sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        .input_values.type(torch.float32)
        .to(model.device)
    )
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: take the most likely token per frame, then let the
    # tokenizer collapse repeated tokens and strip blanks.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])
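

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions not in the original module: the
    # checkpoint name is illustrative (an espeak-phoneme CTC model fits the
    # setup above, but any CTC checkpoint works), "sample.wav" is a
    # hypothetical input file, and the soundfile package is installed.
    import soundfile as sf

    model, processor = load_model("facebook/wav2vec2-lv-60-espeak-cv-ft")
    audio, sample_rate = sf.read("sample.wav")
    assert sample_rate == processor.feature_extractor.sampling_rate
    print(transcribe(audio, model, processor))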