from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import librosa
import numpy as np
import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset

from fish_speech.utils import RankedLogger

logger = RankedLogger(__name__, rank_zero_only=False)


class VQGANDataset(Dataset):
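    """Raw-waveform dataset for VQGAN training.

    Reads a filelist (one audio path per non-empty line, resolved relative to
    the filelist's own directory), loads each file as mono at `sample_rate`,
    optionally takes a random slice of `slice_frames * hop_length` samples,
    and peak-normalizes any waveform that exceeds [-1, 1].
    """
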
    def __init__(
        self,
        filelist: str,
        sample_rate: int = 32000,
        hop_length: int = 640,
        slice_frames: Optional[int] = None,
    ):
        super().__init__()

        filelist = Path(filelist)
        root = filelist.parent

        # Keep non-empty lines only; paths are relative to the filelist's directory
        self.files = [
            root / line.strip()
            for line in filelist.read_text(encoding="utf-8").splitlines()
            if line.strip()
        ]
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.slice_frames = slice_frames

    def __len__(self):
        return len(self.files)

    def get_item(self, idx):
        file = self.files[idx]

        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)

        # Optionally take a random crop of `slice_frames` frames' worth of samples
        if (
            self.slice_frames is not None
            and audio.shape[0] > self.slice_frames * self.hop_length
        ):
            start = np.random.randint(
                0, audio.shape[0] - self.slice_frames * self.hop_length
            )
            audio = audio[start : start + self.slice_frames * self.hop_length]

        if len(audio) == 0:
            return None

        # Peak-normalize only if the waveform clips
        max_value = np.abs(audio).max()
        if max_value > 1.0:
            audio = audio / max_value

        return {
            "audio": torch.from_numpy(audio),
        }

    def __getitem__(self, idx):
        try:
            return self.get_item(idx)
        except Exception as e:
            import traceback

            traceback.print_exc()
            logger.error(f"Error loading {self.files[idx]}: {e}")
            return None


@dataclass
class VQGANCollator:
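    """Batches dataset items: drops failed (None) samples and right-pads every
    waveform to the longest in the batch, returning the padded audio stack
    alongside the original lengths."""
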
    def __call__(self, batch):
        # Drop samples that failed to load (the dataset returns None on error)
        batch = [x for x in batch if x is not None]

        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
        audio_maxlen = audio_lengths.max()

        # Right-pad every waveform to the longest one in the batch
        audios = []
        for x in batch:
            audios.append(
                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
            )

        return {
            "audios": torch.stack(audios),
            "audio_lengths": audio_lengths,
        }


class VQGANDataModule(LightningDataModule):
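    """LightningDataModule wiring VQGANDataset train/val splits to DataLoaders
    that share the VQGANCollator."""
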
    def __init__(
        self,
        train_dataset: VQGANDataset,
        val_dataset: VQGANDataset,
        batch_size: int = 32,
        num_workers: int = 4,
        val_batch_size: Optional[int] = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size
        # Fall back to the training batch size if none is given for validation
        self.val_batch_size = val_batch_size or batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            shuffle=True,
            persistent_workers=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            collate_fn=VQGANCollator(),
            num_workers=self.num_workers,
            persistent_workers=True,
        )


if __name__ == "__main__":
    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
    dataloader = DataLoader(
        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
    )

    for batch in dataloader:
        # The collator only yields "audios" and "audio_lengths"; the previous
        # prints of "features"/"feature_lengths" referenced keys it never
        # produces and would raise a KeyError.
        print(batch["audios"].shape)
        print(batch["audio_lengths"])
        break