File size: 2,703 Bytes
73c9c96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c9323c6
a04aea5
73c9c96
c9323c6
73c9c96
 
 
 
 
 
 
 
a04aea5
 
73c9c96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import gradio as gr
from pathlib import Path
from whistress import WhiStressInferenceClient

CURRENT_DIR = Path(__file__).parent

# Pick the compute device: default to CPU and switch to CUDA only when
# PyTorch reports that a GPU is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Build the inference client once at import time; get_whistress_predictions
# reuses this single instance for every request.
model = WhiStressInferenceClient(device=device)


def get_whistress_predictions(audio):
    """
    Get the transcription and emphasis scores for the given audio input.

    Args:
        audio (tuple[int, numpy.ndarray] | None): The ``(sample_rate, samples)``
            pair produced by a ``gr.Audio(type="numpy")`` component, or ``None``
            when the user submits without recording or uploading anything.

    Returns:
        List[Tuple[str, int]]: ``(word, emphasis_score)`` pairs as returned by
        ``model.predict(..., return_pairs=True)``; an empty list when no audio
        was provided.
    """
    # Gradio hands us None when the form is submitted with no audio; indexing
    # into it would raise a TypeError that surfaces as a generic UI error.
    if audio is None:
        return []
    sampling_rate, samples = audio[0], audio[1]
    # Re-pack into the dict layout the inference client consumes.
    payload = {
        "sampling_rate": sampling_rate,
        "array": samples,
    }
    return model.predict(audio=payload, transcription=None, return_pairs=True)


# App UI
with gr.Blocks() as demo:
    # Header row: project description on the left, architecture figure on the right.
    with gr.Row():
        with gr.Column(scale=2):
            # NOTE(review): the paper link below points at the arXiv landing page,
            # which looks like a placeholder -- update it once the paper ID is known.
            gr.Markdown(
                """
                # ***WhiStress***: Enriching Transcriptions with Sentence Stress Detection
                WhiStress allows you to detect emphasized words in your speech.

                Check out our paper: 📚 [***WhiStress***](https://arxiv.org/)

                ## Architecture
                The model is built on [Whisper](https://arxiv.org/abs/2212.04356) model,
                using `whisper-small.en` [model](https://huggingface.co/openai/whisper-small.en)
                as the backbone.
                WhiStress includes an additional decoder based classifier that predicts the stress label of each transcription token.

                ## Training Data
                WhiStress was trained using [***TinyStress-15K***](https://huggingface.co/datasets/slprl/TinyStress-15K),
                that is derived from the [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) dataset.

                ## Inference Demo
                Upload an audio file or record your own voice to transcribe the speech and emphasize the important words.

                For maximal performance, please speak clearly.
                """
            )
        with gr.Column(scale=1):
            # Static architecture diagram loaded from the app's assets folder;
            # built with pathlib so the separator is correct on every OS.
            gr.Image(
                str(CURRENT_DIR / "assets" / "whistress_model.svg"),
                label="Architecture",
            )

    # Inference widget: audio in (microphone or file upload) as a
    # (sample_rate, array) tuple, word-level highlights out.
    gr.Interface(
        get_whistress_predictions,
        gr.Audio(
            sources=["microphone", "upload"],
            label="Upload speech or record your own",
            type="numpy",
        ),
        gr.HighlightedText(),
        allow_flagging="never",
    )


def launch(**launch_kwargs):
    """
    Start the Gradio server for the WhiStress demo.

    Args:
        **launch_kwargs: Forwarded unchanged to ``demo.launch`` (for example
            ``share=True`` or ``server_name="0.0.0.0"``). With no arguments this
            is exactly a bare ``demo.launch()`` call.
    """
    demo.launch(**launch_kwargs)


if __name__ == "__main__":
    # Running this file directly starts the server. Importing the module still
    # loads the model and builds the UI at import time, but does not launch it.
    launch()