zhoukz committed
Commit 5bf16a7 · verified · 1 Parent(s): acdea41

Upload folder using huggingface_hub

model-00001-of-00007.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:324d47a69b306b736f9c1ed9c3ac6b2f08dd25f3238e0995ca03d1f628d14d3f
- size 4962055488
+ oid sha256:f1580ce43afa6d023bb89c1d68ed86e155119885fbf2f6dafe98e6696593f7b1
+ size 4961987424
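
Only the LFS pointer changes here: the shard shrinks by 68,064 bytes (4,962,055,488 to 4,961,987,424), consistent with the two small front-end buffers dropped from the index below. A minimal sketch for checking a downloaded copy of the shard against the new pointer; the local filename is an assumption, and the Git LFS oid is simply the SHA-256 of the file contents:

import hashlib

def sha256_of(path: str) -> str:
    # Stream the ~5 GB shard in 1 MiB chunks to keep memory use flat.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Expected to match the oid in the updated pointer above.
assert sha256_of("model-00001-of-00007.safetensors") == (
    "f1580ce43afa6d023bb89c1d68ed86e155119885fbf2f6dafe98e6696593f7b1"
)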
model.safetensors.index.json CHANGED
@@ -1,6 +1,6 @@
  {
    "metadata": {
-     "total_size": 33127027980
+     "total_size": 33126960136
    },
    "weight_map": {
      "audio_encoder.blocks.0.attn.proj.bias": "model-00001-of-00007.safetensors",
@@ -388,8 +388,6 @@
      "audio_encoder.blocks.9.norm2.bias": "model-00001-of-00007.safetensors",
      "audio_encoder.blocks.9.norm2.weight": "model-00001-of-00007.safetensors",
      "audio_encoder.freq_pos_embed": "model-00001-of-00007.safetensors",
-     "audio_encoder.front_end.0.mel_scale.fb": "model-00001-of-00007.safetensors",
-     "audio_encoder.front_end.0.spectrogram.window": "model-00001-of-00007.safetensors",
      "audio_encoder.init_bn.bias": "model-00001-of-00007.safetensors",
      "audio_encoder.init_bn.num_batches_tracked": "model-00001-of-00007.safetensors",
      "audio_encoder.init_bn.running_mean": "model-00001-of-00007.safetensors",
modeling_midashenglm.py CHANGED
@@ -5,7 +5,7 @@ from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Uni
 
  import torch
  import torch.nn as nn
- import torchaudio.transforms as audio_transforms
+ import torchaudio.functional as F
  from torch import Tensor
  from transformers import GenerationMixin, PreTrainedModel
  from transformers.cache_utils import Cache
@@ -217,6 +217,59 @@ class DashengBlock(nn.Module):
          return x
 
 
+ class DashengFrontend(nn.Module):
+     def __init__(self, config: DashengConfig):
+         super().__init__()
+         self.config = config
+
+         spectrogram_window = torch.hann_window(self.config.win_length)
+         self.register_buffer(
+             "spectrogram_window",
+             spectrogram_window,
+             persistent=False,
+         )
+         self.spectrogram_window: torch.Tensor
+
+         melscale_fbanks = F.melscale_fbanks(
+             n_freqs=self.config.n_fft // 2 + 1,
+             f_min=self.config.f_min,
+             f_max=self.config.f_max,
+             n_mels=self.config.n_mels,
+             sample_rate=self.config.sample_rate,
+         )
+         self.register_buffer("melscale_fbanks", melscale_fbanks, persistent=False)
+         self.melscale_fbanks: torch.Tensor
+
+     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+         spectrogram = F.spectrogram(
+             waveform=waveform.to(torch.float32),
+             pad=0,
+             window=self.spectrogram_window,
+             n_fft=self.config.n_fft,
+             hop_length=self.config.hop_length,
+             win_length=self.config.win_length,
+             power=2,
+             normalized=False,
+             center=self.config.center,
+         )
+         mel_spectrogram = (spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
+         # x has shape [batch, freq, time].
+         # F.amplitude_to_DB accepts inputs shaped as:
+         #   - [freq, time]
+         #   - [channel, freq, time]
+         #   - [..., channel, freq, time]
+         # Here we insert a channel dimension of size 1 before calling it,
+         # then remove that extra dimension afterward.
+         log_mel_spectrogram = F.amplitude_to_DB(
+             mel_spectrogram.unsqueeze(1),
+             multiplier=10,
+             amin=1e-10,
+             db_multiplier=0,
+             top_db=120,
+         ).squeeze(1)
+         return log_mel_spectrogram.to(waveform.dtype)
+
+
  class DashengAudioTransformer(PreTrainedModel):
      config_class = DashengConfig
      supports_gradient_checkpointing = True
@@ -229,19 +282,7 @@ class DashengAudioTransformer(PreTrainedModel):
          self.hop_length = config.hop_length
          self.gradient_checkpointing = False
 
-         self.front_end = nn.Sequential(
-             audio_transforms.MelSpectrogram(
-                 f_min=config.f_min,
-                 f_max=config.f_max,
-                 center=config.center,
-                 win_length=config.win_length,
-                 hop_length=config.hop_length,
-                 sample_rate=config.sample_rate,
-                 n_fft=config.n_fft,
-                 n_mels=config.n_mels,
-             ),
-             audio_transforms.AmplitudeToDB(top_db=120),
-         )
+         self.front_end = DashengFrontend(config)
 
          self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
 
@@ -272,7 +313,7 @@ class DashengAudioTransformer(PreTrainedModel):
                  drop=config.drop_rate,
                  attn_drop=config.attn_drop_rate,
              )
-             for i in range(config.depth)
+             for _ in range(config.depth)
          )
          self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
 
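
The new DashengFrontend is intended to reproduce the removed nn.Sequential(MelSpectrogram(...), AmplitudeToDB(top_db=120)) pipeline using torchaudio.functional calls. A rough equivalence check, using made-up parameter values in place of the real DashengConfig fields:

import torch
import torchaudio.functional as F
import torchaudio.transforms as T

# Illustrative values only; the real ones come from DashengConfig.
sample_rate, n_fft, win_length, hop_length = 16000, 512, 512, 160
n_mels, f_min, f_max, center = 64, 0.0, 8000.0, True

# A single waveform keeps the top_db reference maximum identical in both paths.
waveform = torch.randn(1, 16000)

# Old path: the transforms removed by this commit.
old = torch.nn.Sequential(
    T.MelSpectrogram(
        sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
        hop_length=hop_length, f_min=f_min, f_max=f_max,
        n_mels=n_mels, center=center,
    ),
    T.AmplitudeToDB(top_db=120),
)(waveform)

# New path: the functional calls used by DashengFrontend.forward.
window = torch.hann_window(win_length)
fbanks = F.melscale_fbanks(
    n_freqs=n_fft // 2 + 1, f_min=f_min, f_max=f_max,
    n_mels=n_mels, sample_rate=sample_rate,
)
spec = F.spectrogram(
    waveform=waveform, pad=0, window=window, n_fft=n_fft,
    hop_length=hop_length, win_length=win_length,
    power=2, normalized=False, center=center,
)
mel = (spec.mT @ fbanks).mT
new = F.amplitude_to_DB(
    mel.unsqueeze(1), multiplier=10, amin=1e-10, db_multiplier=0, top_db=120
).squeeze(1)

print(torch.allclose(old, new, atol=1e-5))  # expected: True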