Spaces:
Runtime error
Runtime error
John Waidhofer
committed on
Commit
·
4078103
1
Parent(s):
6193575
updated to gtzan
Browse files
- models/wav2vec2.py +5 -6
models/wav2vec2.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Any
|
|
3 |
import pytorch_lightning as pl
|
4 |
from torch.utils.data import random_split
|
5 |
from transformers import AutoFeatureExtractor
|
6 |
-
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
|
7 |
|
8 |
from preprocessing.dataset import (
|
9 |
HuggingFaceDatasetWrapper,
|
@@ -13,20 +13,18 @@ from preprocessing.pipelines import WaveformTrainingPipeline
|
|
13 |
|
14 |
from .utils import get_id_label_mapping, compute_hf_metrics
|
15 |
|
16 |
-
MODEL_CHECKPOINT = "
|
17 |
|
18 |
|
19 |
class Wav2VecFeatureExtractor:
|
20 |
def __init__(self) -> None:
|
21 |
self.waveform_pipeline = WaveformTrainingPipeline()
|
22 |
-
self.feature_extractor =
|
23 |
-
MODEL_CHECKPOINT,
|
24 |
-
)
|
25 |
|
26 |
def __call__(self, waveform) -> Any:
|
27 |
waveform = self.waveform_pipeline(waveform)
|
28 |
return self.feature_extractor(
|
29 |
-
waveform.squeeze(0), sampling_rate=
|
30 |
)
|
31 |
|
32 |
def __getattr__(self, attr):
|
@@ -64,6 +62,7 @@ def train_huggingface(config: dict):
|
|
64 |
learning_rate=3e-5,
|
65 |
per_device_train_batch_size=batch_size,
|
66 |
gradient_accumulation_steps=5,
|
|
|
67 |
per_device_eval_batch_size=batch_size,
|
68 |
num_train_epochs=epochs,
|
69 |
warmup_ratio=0.1,
|
|
|
3 |
import pytorch_lightning as pl
|
4 |
from torch.utils.data import random_split
|
5 |
from transformers import AutoFeatureExtractor
|
6 |
+
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer, AutoProcessor
|
7 |
|
8 |
from preprocessing.dataset import (
|
9 |
HuggingFaceDatasetWrapper,
|
|
|
13 |
|
14 |
from .utils import get_id_label_mapping, compute_hf_metrics
|
15 |
|
16 |
+
MODEL_CHECKPOINT = "yuval6967/wav2vec2-base-finetuned-gtzan"
|
17 |
|
18 |
|
19 |
class Wav2VecFeatureExtractor:
|
20 |
def __init__(self) -> None:
|
21 |
self.waveform_pipeline = WaveformTrainingPipeline()
|
22 |
+
self.feature_extractor = AutoProcessor.from_pretrained("facebook/wav2vec2-base")
|
|
|
|
|
23 |
|
24 |
def __call__(self, waveform) -> Any:
|
25 |
waveform = self.waveform_pipeline(waveform)
|
26 |
return self.feature_extractor(
|
27 |
+
waveform.squeeze(0), sampling_rate=16000
|
28 |
)
|
29 |
|
30 |
def __getattr__(self, attr):
|
|
|
62 |
learning_rate=3e-5,
|
63 |
per_device_train_batch_size=batch_size,
|
64 |
gradient_accumulation_steps=5,
|
65 |
+
gradient_checkpointing=True,
|
66 |
per_device_eval_batch_size=batch_size,
|
67 |
num_train_epochs=epochs,
|
68 |
warmup_ratio=0.1,
|