USAD-Base / modeling_usad.py
vectominist's picture
upload model and code
b038b10
raw
history blame contribute delete
498 Bytes
# modeling_usad.py
from transformers import PreTrainedModel
from .configuration_usad import USADConfig
from .usad_model import UsadModel as model
class USADModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the USAD network.

    The real network is ``UsadModel`` from ``.usad_model``, which this module
    imports under the alias ``model``. This class only does three things:
    - builds that network from a ``USADConfig``;
    - stores it as ``self.model``;
    - delegates ``forward`` and ``load_audio`` to it.
    """

    # Configuration class that transformers associates with this model.
    config_class = USADConfig

    def __init__(self, config: USADConfig) -> None:
        """Initialise the wrapper and construct the inner USAD network.

        Args:
            config: Configuration object. It is passed first to
                ``PreTrainedModel.__init__`` and then, unchanged, to the
                inner ``UsadModel`` constructor.
        """
        # PreTrainedModel (an nn.Module) must be initialised before any
        # sub-module is assigned as an attribute.
        super(USADModel, self).__init__(config)
        backbone = model(config)
        self.model = backbone

    def forward(self, *args, **kwargs):
        """Run the wrapped USAD network.

        All positional and keyword arguments are passed through untouched.
        Whatever the inner model returns is returned as-is.
        """
        # Call the inner object itself rather than its ``forward`` method,
        # so that its normal ``__call__`` path is preserved.
        backbone = self.model
        return backbone(*args, **kwargs)

    def load_audio(self, audio_path):
        """Load audio through the inner model's ``load_audio`` helper.

        Args:
            audio_path: Forwarded unchanged to ``UsadModel.load_audio``.
                Presumably a filesystem path; confirm against that method.

        Returns:
            Whatever ``UsadModel.load_audio`` returns.
        """
        audio_loader = self.model.load_audio
        return audio_loader(audio_path)