---
license: apache-2.0
language:
- de
tags:
- sign-language
- whisper
- german
- safetensors
library_name: transformers
model-index:
- name: whisper-large-v3-turbo-german
  results:
  - task:
      type: automatic-speech-recognition
      name: Speech Recognition
    dataset:
      name: German ASR Data-Mix
      type: flozi00/asr-german-mixed
    metrics:
    - type: wer
      value: TBD
datasets:
- flozi00/asr-german-mixed
base_model:
- primeline/whisper-large-v3-german
---

### Summary
Whisper is a speech recognition model developed by OpenAI. This model adapts it to convert sign language input features into German text.

### Applications
The model is based on `primeline/whisper-large-v3-german` and is used, in combination with Google MediaPipe, to translate videos of German Sign Language into text. It decodes a sequence of input features, where each input feature represents keypoints extracted from the video (hands, upper body and face), into text.

We keep the decoder frozen while training the encoder.
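
Feature extraction for this model is still marked TBD, so the exact landmark selection and normalization are not documented. As a rough sketch, assuming each video frame has already been reduced to a flat vector of 128 keypoint values (the `in_channels` of the first layer in `conv_preprocessing_layers` in the training example), the frames could be stacked, padded or truncated to a fixed length, and transposed into the `channels x frames` layout the encoder expects. The helper name `frames_to_input_features` and the `max_frames` parameter are hypothetical.

```python
import numpy as np


def frames_to_input_features(frames, n_channels=128, max_frames=None):
    """Hypothetical helper: stack per-frame keypoint vectors into a (channels, frames) array.

    frames: sequence of 1-D arrays, one per video frame, each holding n_channels values
            (e.g. flattened MediaPipe landmark coordinates; the selection used for this
            model is an assumption, not documented).
    max_frames: if given, truncate or zero-pad the time axis to this length.
    """
    x = np.asarray(frames, dtype=np.float32)  # shape: (num_frames, n_channels)
    if x.ndim != 2 or x.shape[1] != n_channels:
        raise ValueError(f"expected shape (num_frames, {n_channels}), got {x.shape}")
    if max_frames is not None:
        x = x[:max_frames]  # truncate long videos
        pad = max_frames - x.shape[0]
        if pad > 0:
            x = np.pad(x, ((0, pad), (0, 0)))  # zero-pad short videos at the end
    return x.T  # (n_channels, num_frames): one sequence of a b x 128 x seq batch
```

Stacking several results with `np.stack` gives a `b x 128 x seq` batch, which would still need to be converted to a tensor (e.g. with `torch.from_numpy`) before being passed to the model.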

## Evaluations - Word error rate
TBD
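
Results are still pending. For reference, the word error rate counts the word substitutions S, deletions D and insertions I needed to turn a hypothesis into the reference, divided by the number of reference words N: WER = (S + D + I) / N. A minimal pure-Python sketch based on word-level Levenshtein distance (in practice, libraries such as `jiwer` or the `evaluate` WER metric are the usual choice):

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """WER = (substitutions + deletions + insertions) / number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # Rolling row of the word-level edit-distance table
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion (reference word missing)
                curr[j - 1] + 1,     # insertion (extra hypothesis word)
                prev[j - 1] + cost,  # match or substitution
            )
        prev = curr
    return prev[len(hyp)] / len(ref)


print(word_error_rate("ich gehe nach hause", "ich gehe heute nach hause"))  # 0.25
```

Scores are sensitive to casing and punctuation, so reference and hypothesis are usually normalized the same way before comparison.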

### Training data
TBD

#### Training process
!!! Make sure to install Transformers 4.46.0 !!!
```python
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset  # optional, for loading a dataset from the Hub

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# First load the config using AutoConfig
# See custom config in model.py for configuration options.
config = AutoConfig.from_pretrained(
    "mrprimenotes/sign-whisper-german",
    trust_remote_code=True,
    use_first_embeddings=True,
    #embedding_stride=2,
    #conv_dropout=0.1,
    skip_connections=True,
    conv_preprocessing_layers=[ 
                { # When changing conv_preprocessing_layers make sure their final output has the shape b x 1280 x seq.
                    "in_channels": 128,
                    "out_channels": 1280,
                    "kernel_size": 3,
                    "stride": 1,
                    "padding": 1,
                    "activation": "gelu",
                    "bias": True
                },
                {
                    "in_channels": 1280,
                    "out_channels": 1280,
                    "kernel_size": 3,
                    "stride": 1,
                    "padding": 1,
                    "activation": "gelu",
                    "bias": True
                }
            ]
)

tokenizer = AutoTokenizer.from_pretrained("mrprimenotes/sign-whisper-german")

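model = AutoModel.from_pretrained(
    pretrained_model_name_or_path="mrprimenotes/sign-whisper-german",
    config=config,
    use_safetensors=True,
    trust_remote_code=True,
    ignore_mismatched_sizes=True,
    # Keep fp32 weights for training; fp16=True below enables mixed precision
    # (fp16 weights make the gradient scaler fail with "Attempting to unscale FP16 gradients")
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(device)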

# You can inspect raw model outputs as follows:
# output = model(input_features, labels=labels)
# e.g.
# output.loss
# output.logits.shape --> b x seq_len x vocab_size

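# Load your dataset (e.g. mrprimenotes/sign-whisper-german-example).
# YourSignDataset is a placeholder for your own dataset class; each item should
# provide the model inputs, e.g. "input_features" and "labels".
train_dataset = YourSignDataset(...)
val_dataset = YourSignDataset(...)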

# Freeze the decoder for our purpose
model.freeze_decoder()

# Define training arguments
training_args = TrainingArguments(
    output_dir="sign-whisper-german_trained",  # local directory for checkpoints (required argument)
    hub_model_id="mrprimenotes/sign-whisper-german_trained",
    push_to_hub=True,

    num_train_epochs=2,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=386,

    learning_rate=2e-5,
    warmup_steps=200,
    weight_decay=0.01,

    # Logging settings
    logging_steps=500,
    logging_strategy="steps",

    # Evaluation
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    eval_strategy="steps",  # replaces the deprecated evaluation_strategy
    eval_steps=1000,

    # Saving
    save_strategy="steps",
    save_steps=2000,
    save_total_limit=4,
    resume_from_checkpoint=True,  # not read by Trainer itself; pass resume_from_checkpoint to trainer.train() to resume

    load_best_model_at_end=True,
    fp16=torch.cuda.is_available(),
)

# Initialize trainer with tokenizer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

# Train the model
trainer.train()
```
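
The comment inside `conv_preprocessing_layers` requires the stack's final output to have shape `b x 1280 x seq` (1280 is the hidden size of Whisper large). Each layer maps a sequence of length L to floor((L + 2·padding − kernel_size) / stride) + 1 steps (the `torch.nn.Conv1d` output-length formula with dilation 1), so `kernel_size=3`, `stride=1`, `padding=1` leaves the length unchanged. The helper below is a hypothetical sketch that only reads the dictionary keys used in the config; it checks that channels chain correctly and reports the resulting shape before a long training run:

```python
def check_conv_stack(layers, seq_len, in_channels=128, d_model=1280):
    """Hypothetical sanity check for a conv_preprocessing_layers list.

    Returns (channels, seq_len) after the stack, assuming Conv1d semantics with dilation 1.
    """
    channels = in_channels
    for idx, layer in enumerate(layers):
        if layer["in_channels"] != channels:
            raise ValueError(
                f"layer {idx}: in_channels={layer['in_channels']}, previous output has {channels}"
            )
        k, s, p = layer["kernel_size"], layer["stride"], layer["padding"]
        seq_len = (seq_len + 2 * p - k) // s + 1  # Conv1d output length
        channels = layer["out_channels"]
    if channels != d_model:
        raise ValueError(f"final out_channels={channels}, expected {d_model}")
    return channels, seq_len


layers = [
    {"in_channels": 128, "out_channels": 1280, "kernel_size": 3, "stride": 1, "padding": 1},
    {"in_channels": 1280, "out_channels": 1280, "kernel_size": 3, "stride": 1, "padding": 1},
]
print(check_conv_stack(layers, seq_len=300))  # (1280, 300)
```

Setting `stride=2` in the first layer would halve the sequence (300 frames become 150), which is one way to trade temporal resolution for shorter encoder inputs.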

### Use model for inference (with generate)
!!! Make sure to install Transformers 4.46.0 !!!
```python
from transformers import TextStreamer

# Assumes `model` and `tokenizer` are loaded as in the training example above.
streamer = TextStreamer(tokenizer, skip_special_tokens=False)  # only needed for streaming

# Input preprocessing / feature extraction (TBD)
# input_features = ...  # expected shape b x 128 x seq, on the model's device

# Generate
generated_ids = model.generate(
    input_features,
    max_new_tokens=128,
    return_timestamps=False,  # timestamps are not supported
    streamer=streamer,  # only needed for streaming
)

transcriptions = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
print(transcriptions)
```
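
With `skip_special_tokens=False`, the decoded strings still contain Whisper's control tokens such as `<|startoftranscript|>`, `<|de|>`, `<|transcribe|>`, `<|notimestamps|>` and `<|endoftext|>`. Passing `skip_special_tokens=True` to `batch_decode` is the simplest way to drop them. If you want to keep the raw output for debugging and clean it afterwards, a small regex-based sketch (the helper name is hypothetical) could look like this:

```python
import re

# Whisper control tokens all have the form <|...|>
SPECIAL_TOKEN = re.compile(r"<\|[^|]*\|>")


def strip_special_tokens(text: str) -> str:
    """Remove <|...|> control tokens and collapse the remaining whitespace."""
    return " ".join(SPECIAL_TOKEN.sub(" ", text).split())


raw = "<|startoftranscript|><|de|><|transcribe|><|notimestamps|> Hallo Welt<|endoftext|>"
print(strip_special_tokens(raw))  # Hallo Welt
```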