bwshen-mi committed
Commit 2ab7cb1 · verified · 1 Parent(s): 1c580f2

Upload folder using huggingface_hub

Files changed (2)
  1. README.md +1 -1
  2. modeling_mimo.py +75 -0
README.md CHANGED
@@ -53,7 +53,7 @@ library_name: transformers
   <tr>
   <td colspan="3"><strong>Mathematics</strong></td>
   <p align="center">
-  <td rowspan="11"><img width="80%" src="https://github.com/XiaomiMiMo/MiMo/raw/main/figures/length.jpg?raw=true"></td>
+  <td rowspan="11"><img width="80%" src="https://github.com/XiaomiMiMo/MiMo-test/raw/main/figures/length.jpg?raw=true"></td>
   </p>
   </tr>
   <tr><td>MATH500<br/>(Pass@1)</td><td>95.8</td><td>97.2</td></tr>
modeling_mimo.py ADDED
@@ -0,0 +1,75 @@
+ from typing import Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers.cache_utils import Cache
+ from transformers.models.qwen2.modeling_qwen2 import (Qwen2Attention,
+                                                        Qwen2ForCausalLM,
+                                                        Qwen2MLP, Qwen2Model,
+                                                        Qwen2RMSNorm)
+
+ from .configuration_mimo import MiMoConfig
+
+
+ class MiMoMTPLayers(nn.Module):
+     """Multi-Token Prediction (MTP) layer: fuses the current token embedding
+     with the previous hidden state, then applies one Qwen2-style
+     attention + MLP block."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.token_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.hidden_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         # Projects the concatenated [previous hidden state; token embedding] back to hidden_size.
+         self.input_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)
+         self.final_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.self_attn = Qwen2Attention(config, layer_idx=0)
+         self.mlp = Qwen2MLP(config)
+
+     def forward(self, input_embeds,
+                 hidden_states,
+                 attention_mask,
+                 position_ids,
+                 past_key_values: Optional[Cache] = None,
+                 output_attentions: Optional[bool] = False,
+                 use_cache: Optional[bool] = False,
+                 position_embedding: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+                 cache_position=None,
+                 **kwargs):
+         # Normalize the token embeddings and the incoming hidden states separately,
+         # then merge them into a single hidden_size-dim representation.
+         input_embeds = self.token_layernorm(input_embeds)
+         previous_hidden_states = self.hidden_layernorm(hidden_states)
+         hidden_states = self.input_proj(torch.cat([previous_hidden_states, input_embeds], dim=-1))
+         # Pre-norm attention block with a residual connection.
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states, _ = self.self_attn(hidden_states,
+                                           attention_mask=attention_mask,
+                                           position_ids=position_ids,
+                                           past_key_values=past_key_values,
+                                           output_attentions=output_attentions,
+                                           use_cache=use_cache,
+                                           cache_position=cache_position,
+                                           position_embedding=position_embedding,
+                                           **kwargs)
+         hidden_states = residual + hidden_states
+         # Pre-norm MLP block with a residual connection, followed by a final norm.
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+         hidden_states = self.final_layernorm(hidden_states)
+         return hidden_states
+
+
+ class MiMoModel(Qwen2Model):
+     config_class = MiMoConfig
+
+     def __init__(self, config: MiMoConfig):
+         super().__init__(config)
+         # Qwen2 backbone plus a stack of MTP layers used for multi-token prediction.
+         self.mtp_layers = nn.ModuleList([MiMoMTPLayers(config) for _ in range(config.num_nextn_predict_layers)])
+
+
+ class MiMoForCausalLM(Qwen2ForCausalLM):
+     config_class = MiMoConfig
+
+     def __init__(self, config: MiMoConfig):
+         # Skip Qwen2ForCausalLM.__init__ so the MiMoModel backbone is used instead.
+         super(Qwen2ForCausalLM, self).__init__(config)
+         self.model = MiMoModel(config)
+         self.vocab_size = config.vocab_size
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         self.post_init()
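
For reference, a minimal loading sketch for a checkpoint that ships this custom code. The repo id below is illustrative, not taken from this commit; trust_remote_code=True is what makes transformers import MiMoForCausalLM from the uploaded modeling_mimo.py instead of the stock Qwen2 classes.

# Minimal usage sketch, assuming an illustrative repo id and that the checkpoint
# ships configuration_mimo.py and modeling_mimo.py as remote code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "XiaomiMiMo/MiMo-7B-Base"  # illustrative repo id, not from this commit

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True loads MiMoForCausalLM from modeling_mimo.py in the repo.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))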