shenyunhang committed
Commit 5569e06 · 1 Parent(s): cd526de
__init__.py ADDED
@@ -0,0 +1,12 @@
+ from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
+
+ from .modeling_qwen2 import Qwen2MTPSenseVoiceForCausalLM
+ from .configuration_qwen2 import Qwen2MTPSenseVoiceConfig
+
+ AutoConfig.register("qwen2_mtp_sensevoice", Qwen2MTPSenseVoiceConfig)
+ AutoModelForCausalLM.register(Qwen2MTPSenseVoiceConfig, Qwen2MTPSenseVoiceForCausalLM)
+ # AutoTokenizer.register(Qwen2MTPSenseVoiceConfig, Qwen2MTPSenseVoiceTokenizer)
+
+ Qwen2MTPSenseVoiceConfig.register_for_auto_class()
+ # Qwen2MTPSenseVoiceModel.register_for_auto_class("AutoModel")
+ Qwen2MTPSenseVoiceForCausalLM.register_for_auto_class("AutoModelForCausalLM")
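This `__init__.py` registers the custom config and model classes with the Auto* factories, and `config.json` below also provides an `auto_map`, so the checkpoint can be loaded with `trust_remote_code=True`. A minimal loading sketch (the repository id is a placeholder, not part of this commit):

```python
# Minimal loading sketch; "user/Qwen2-MTP-SenseVoice" is a placeholder repo id.
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo_id = "user/Qwen2-MTP-SenseVoice"  # placeholder
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)
```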
added_tokens.json ADDED
The diff for this file is too large to render. See raw diff
 
am.mvn ADDED
@@ -0,0 +1,8 @@
+ <Nnet>
+ <Splice> 560 560
+ [ 0 ]
+ <AddShift> 560 560
+ <LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 
-13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
+ <Rescale> 560 560
+ <LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 
0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
+ </Nnet>
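`am.mvn` is a Kaldi-style CMVN file consumed by the `WavFrontend`: `<AddShift>` stores the negated global mean and `<Rescale>` the inverse standard deviation for the 560-dimensional LFR features (80 mel bins × 7 stacked frames, matching `n_mels` and `lfr_m` in `config.yaml`). A minimal sketch of how such statistics are applied, assuming `shift` and `scale` have already been parsed from the file:

```python
# Sketch: applying Kaldi-style CMVN as stored in am.mvn.
# Assumes shift (= -mean) and scale (= 1/std) were parsed from the <AddShift>/<Rescale> blocks.
import numpy as np

def apply_cmvn(lfr_feats: np.ndarray, shift: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """lfr_feats: (num_frames, 560); shift, scale: (560,)."""
    return (lfr_feats + shift) * scale
```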
chn_jpn_yue_eng_ko_spectok.bpe.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aa87f86064c3730d799ddf7af3c04659151102cba548bce325cf06ba4da4e6a8
+ size 377341
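The BPE model is stored via Git LFS, so only the pointer file is shown here. A quick way to inspect it once downloaded, assuming the `sentencepiece` package is installed:

```python
# Sketch: inspecting the SentencePiece model referenced by tokenizer_conf.bpemodel in config.yaml.
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="chn_jpn_yue_eng_ko_spectok.bpe.model")
print(sp.vocab_size())
print(sp.encode("hello world", out_type=str))
```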
config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "architectures": [
+     "Qwen2MTPSenseVoiceForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_qwen2.Qwen2MTPSenseVoiceConfig",
+     "AutoModelForCausalLM": "modeling_qwen2.Qwen2MTPSenseVoiceForCausalLM"
+   },
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "hidden_act": "silu",
+   "hidden_size": 3584,
+   "initializer_range": 0.02,
+   "intermediate_size": 18944,
+   "max_position_embeddings": 32768,
+   "max_window_layers": 28,
+   "model_type": "qwen2_mtp_sensevoice",
+   "mtp_loss_weight": 1.0,
+   "num_attention_heads": 28,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 4,
+   "num_nextn_predict_layers": 0,
+   "rms_norm_eps": 1e-06,
+   "rope_scaling": null,
+   "rope_theta": 1000000.0,
+   "sliding_window": null,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.48.3",
+   "use_cache": false,
+   "use_sliding_window": false,
+   "vocab_size": 168072
+ }
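The attention geometry these values imply, as a worked example (`head_dim` is not stored explicitly, so the Qwen2 default `hidden_size // num_attention_heads` applies):

```python
hidden_size = 3584
num_attention_heads = 28
num_key_value_heads = 4

head_dim = hidden_size // num_attention_heads            # 128
kv_groups = num_attention_heads // num_key_value_heads   # 7 query heads share each KV head (GQA)
```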
config.yaml ADDED
@@ -0,0 +1,98 @@
+ encoder: SenseVoiceEncoderSmall
+ encoder_conf:
+     output_size: 512
+     attention_heads: 4
+     linear_units: 2048
+     num_blocks: 50
+     tp_blocks: 20
+     dropout_rate: 0.1
+     positional_dropout_rate: 0.1
+     attention_dropout_rate: 0.1
+     input_layer: pe
+     pos_enc_class: SinusoidalPositionEncoder
+     normalize_before: true
+     kernel_size: 11
+     sanm_shfit: 0
+     selfattention_layer_type: sanm
+
+
+ model: SenseVoiceSmall
+ model_conf:
+     length_normalized_loss: true
+     sos: 1
+     eos: 2
+     ignore_id: -1
+
+ tokenizer: SentencepiecesTokenizer
+ tokenizer_conf:
+     bpemodel: null
+     unk_symbol: <unk>
+     split_with_space: true
+
+ frontend: WavFrontend
+ frontend_conf:
+     fs: 16000
+     window: hamming
+     n_mels: 80
+     frame_length: 25
+     frame_shift: 10
+     lfr_m: 7
+     lfr_n: 6
+     cmvn_file: null
+
+
+ dataset: SenseVoiceCTCDataset
+ dataset_conf:
+     index_ds: IndexDSJsonl
+     batch_sampler: EspnetStyleBatchSampler
+     data_split_num: 32
+     batch_type: token
+     batch_size: 14000
+     max_token_length: 2000
+     min_token_length: 60
+     max_source_length: 2000
+     min_source_length: 60
+     max_target_length: 200
+     min_target_length: 0
+     shuffle: true
+     num_workers: 4
+     sos: ${model_conf.sos}
+     eos: ${model_conf.eos}
+     IndexDSJsonl: IndexDSJsonl
+     retry: 20
+
+ train_conf:
+     accum_grad: 1
+     grad_clip: 5
+     max_epoch: 20
+     keep_nbest_models: 10
+     avg_nbest_model: 10
+     log_interval: 100
+     resume: true
+     validate_interval: 10000
+     save_checkpoint_interval: 10000
+
+ optim: adamw
+ optim_conf:
+     lr: 0.00002
+ scheduler: warmuplr
+ scheduler_conf:
+     warmup_steps: 25000
+
+ specaug: SpecAugLFR
+ specaug_conf:
+     apply_time_warp: false
+     time_warp_window: 5
+     time_warp_mode: bicubic
+     apply_freq_mask: true
+     freq_mask_width_range:
+     - 0
+     - 30
+     lfr_rate: 6
+     num_freq_mask: 1
+     apply_time_mask: true
+     time_mask_width_range:
+     - 0
+     - 12
+     num_time_mask: 1
+
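With `n_mels: 80`, `lfr_m: 7`, and `lfr_n: 6`, the frontend stacks 7 consecutive 80-dimensional fbank frames with a hop of 6, producing the 560-dimensional vectors that `am.mvn` normalizes. A simplified sketch of that low-frame-rate (LFR) stacking; the padding details of the actual FunASR implementation may differ:

```python
import numpy as np

def lfr_stack(fbank: np.ndarray, lfr_m: int = 7, lfr_n: int = 6) -> np.ndarray:
    """fbank: (T, 80) -> (ceil(T / lfr_n), 80 * lfr_m); tail windows padded by repeating the last frame."""
    t, _ = fbank.shape
    out = []
    for start in range(0, t, lfr_n):
        window = fbank[start:start + lfr_m]
        if window.shape[0] < lfr_m:
            pad = np.repeat(window[-1:], lfr_m - window.shape[0], axis=0)
            window = np.concatenate([window, pad], axis=0)
        out.append(window.reshape(-1))
    return np.stack(out)
```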
configuration.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "framework": "pytorch",
+   "task": "auto-speech-recognition",
+   "model": {"type": "funasr"},
+   "pipeline": {"type": "funasr-pipeline"},
+   "model_name_in_hub": {
+     "ms": "",
+     "hf": ""},
+   "file_path_metas": {
+     "config": "config.yaml",
+     "tokenizer_conf": {"bpemodel": "chn_jpn_yue_eng_ko_spectok.bpe.model"},
+     "frontend_conf": {"cmvn_file": "am.mvn"}}
+ }
+
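`configuration.json` wires the repository into the ModelScope/FunASR pipeline machinery (`funasr-pipeline`), pointing it at `config.yaml`, the BPE model, and `am.mvn`. A hedged sketch of that path, assuming FunASR is installed and the checkpoint has been downloaded locally; whether this combined Qwen2+SenseVoice checkpoint runs through FunASR alone is not established by this commit:

```python
# Hedged sketch only; "/path/to/this/repo" and "example.wav" are placeholders.
from funasr import AutoModel

asr = AutoModel(model="/path/to/this/repo")
print(asr.generate(input="example.wav"))
```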
configuration_qwen2.py ADDED
@@ -0,0 +1,201 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen2 model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class Qwen2MTPSenseVoiceConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
28
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of
30
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 151936):
38
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`Qwen2Model`]
40
+ hidden_size (`int`, *optional*, defaults to 4096):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 22016):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer encoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer encoder.
48
+ num_key_value_heads (`int`, *optional*, defaults to 32):
49
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
50
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
51
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
52
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
53
+ by meanpooling all the original heads within that group. For more details checkout [this
54
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
55
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
56
+ The non-linear activation function (function or string) in the decoder.
57
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
58
+ The maximum sequence length that this model might ever be used with.
59
+ initializer_range (`float`, *optional*, defaults to 0.02):
60
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
61
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
62
+ The epsilon used by the rms normalization layers.
63
+ use_cache (`bool`, *optional*, defaults to `True`):
64
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
65
+ relevant if `config.is_decoder=True`.
66
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
67
+ Whether the model's input and output word embeddings should be tied.
68
+ rope_theta (`float`, *optional*, defaults to 10000.0):
69
+ The base period of the RoPE embeddings.
70
+ rope_scaling (`Dict`, *optional*):
71
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
72
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
73
+ accordingly.
74
+ Expected contents:
75
+ `rope_type` (`str`):
76
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
77
+ 'llama3'], with 'default' being the original RoPE implementation.
78
+ `factor` (`float`, *optional*):
79
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
80
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
81
+ original maximum pre-trained length.
82
+ `original_max_position_embeddings` (`int`, *optional*):
83
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
84
+ pretraining.
85
+ `attention_factor` (`float`, *optional*):
86
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
87
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
88
+ `factor` field to infer the suggested value.
89
+ `beta_fast` (`float`, *optional*):
90
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
91
+ ramp function. If unspecified, it defaults to 32.
92
+ `beta_slow` (`float`, *optional*):
93
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
94
+ ramp function. If unspecified, it defaults to 1.
95
+ `short_factor` (`List[float]`, *optional*):
96
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
97
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
98
+ size divided by the number of attention heads divided by 2
99
+ `long_factor` (`List[float]`, *optional*):
100
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
101
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
102
+ size divided by the number of attention heads divided by 2
103
+ `low_freq_factor` (`float`, *optional*):
104
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
105
+ `high_freq_factor` (`float`, *optional*):
106
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
107
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
108
+ Whether to use sliding window attention.
109
+ sliding_window (`int`, *optional*, defaults to 4096):
110
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
111
+ max_window_layers (`int`, *optional*, defaults to 28):
112
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
113
+ attention_dropout (`float`, *optional*, defaults to 0.0):
114
+ The dropout ratio for the attention probabilities.
115
+
116
+ ```python
117
+ >>> from transformers import Qwen2Model, Qwen2Config
118
+
119
+ >>> # Initializing a Qwen2 style configuration
120
+ >>> configuration = Qwen2Config()
121
+
122
+ >>> # Initializing a model from the Qwen2-7B style configuration
123
+ >>> model = Qwen2Model(configuration)
124
+
125
+ >>> # Accessing the model configuration
126
+ >>> configuration = model.config
127
+ ```"""
128
+
129
+ model_type = "qwen2_mtp_sensevoice"
130
+ keys_to_ignore_at_inference = ["past_key_values"]
131
+
132
+ # Default tensor parallel plan for base model `Qwen2`
133
+ base_model_tp_plan = {
134
+ "layers.*.self_attn.q_proj": "colwise",
135
+ "layers.*.self_attn.k_proj": "colwise",
136
+ "layers.*.self_attn.v_proj": "colwise",
137
+ "layers.*.self_attn.o_proj": "rowwise",
138
+ "layers.*.mlp.gate_proj": "colwise",
139
+ "layers.*.mlp.up_proj": "colwise",
140
+ "layers.*.mlp.down_proj": "rowwise",
141
+ }
142
+
143
+ def __init__(
144
+ self,
145
+ vocab_size=151936,
146
+ hidden_size=4096,
147
+ intermediate_size=22016,
148
+ num_hidden_layers=32,
149
+ num_attention_heads=32,
150
+ num_key_value_heads=32,
151
+ hidden_act="silu",
152
+ max_position_embeddings=32768,
153
+ initializer_range=0.02,
154
+ rms_norm_eps=1e-6,
155
+ use_cache=True,
156
+ tie_word_embeddings=False,
157
+ rope_theta=10000.0,
158
+ rope_scaling=None,
159
+ use_sliding_window=False,
160
+ sliding_window=4096,
161
+ max_window_layers=28,
162
+ attention_dropout=0.0,
163
+ num_nextn_predict_layers=1,
164
+ mtp_loss_weight=1.0,
165
+ **kwargs,
166
+ ):
167
+ self.vocab_size = vocab_size
168
+ self.max_position_embeddings = max_position_embeddings
169
+ self.hidden_size = hidden_size
170
+ self.intermediate_size = intermediate_size
171
+ self.num_hidden_layers = num_hidden_layers
172
+ self.num_attention_heads = num_attention_heads
173
+ self.use_sliding_window = use_sliding_window
174
+ self.sliding_window = sliding_window if use_sliding_window else None
175
+ self.max_window_layers = max_window_layers
176
+
177
+ # for backward compatibility
178
+ if num_key_value_heads is None:
179
+ num_key_value_heads = num_attention_heads
180
+
181
+ self.num_key_value_heads = num_key_value_heads
182
+ self.hidden_act = hidden_act
183
+ self.initializer_range = initializer_range
184
+ self.rms_norm_eps = rms_norm_eps
185
+ self.use_cache = use_cache
186
+ self.rope_theta = rope_theta
187
+ self.rope_scaling = rope_scaling
188
+ self.attention_dropout = attention_dropout
189
+ # Validate the correctness of rotary position embeddings parameters
190
+ # BC: if there is a 'type' field, move it to 'rope_type'.
191
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
192
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
193
+ rope_config_validation(self)
194
+
195
+ self.num_nextn_predict_layers = num_nextn_predict_layers
196
+ self.mtp_loss_weight = mtp_loss_weight
197
+
198
+ super().__init__(
199
+ tie_word_embeddings=tie_word_embeddings,
200
+ **kwargs,
201
+ )
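A short sketch instantiating the custom configuration with the values this repository's `config.json` actually uses, including the two MTP-specific fields added on top of the stock Qwen2 arguments (assumes the repository directory is on `sys.path` so `configuration_qwen2` is importable):

```python
from configuration_qwen2 import Qwen2MTPSenseVoiceConfig

cfg = Qwen2MTPSenseVoiceConfig(
    vocab_size=168072,
    hidden_size=3584,
    intermediate_size=18944,
    num_hidden_layers=28,
    num_attention_heads=28,
    num_key_value_heads=4,
    rope_theta=1000000.0,
    num_nextn_predict_layers=0,  # MTP prediction layers disabled in config.json
    mtp_loss_weight=1.0,
)
print(cfg.model_type)  # qwen2_mtp_sensevoice
```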
generation_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "bos_token_id": 151643,
+   "do_sample": true,
+   "eos_token_id": [
+     151645,
+     151643
+   ],
+   "mtp_inference_mode": [
+     1,
+     10,
+     4,
+     10
+   ],
+   "pad_token_id": 151643,
+   "repetition_penalty": 1.05,
+   "temperature": 0.7,
+   "top_k": 20,
+   "top_p": 0.8,
+   "transformers_version": "4.48.3"
+ }
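These defaults are picked up automatically by `model.generate()`; passing them explicitly is equivalent. A short sketch, reusing `model` and `tokenizer` from the loading sketch above (`mtp_inference_mode` is custom to this model and is not interpreted by stock `transformers`):

```python
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
output_ids = model.generate(
    **inputs,
    do_sample=True, temperature=0.7, top_p=0.8, top_k=20,
    repetition_penalty=1.05,
    max_new_tokens=64,  # not set in generation_config.json; chosen here for illustration
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```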
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7695f8f60a254f9a7912228332531428b43217bdc508331a260aa6e958c4ea33
+ size 4992406120
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:71c8ab2e41a2c3546d08422eb62e9064229b3f89fd2cba3f25eb25d58223d6cd
+ size 4932751008
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3deb727ce1ba3d52006eb3602b971d27718423db80f087861c8df64ebcb7d183
+ size 4828352366
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:622ce871af5c52644ec74e208aa767acd0ef98955ddafd42b919bf7d96ca24bc
+ size 1204740224
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_qwen2.py ADDED
@@ -0,0 +1,1641 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/qwen2/modular_qwen2.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_qwen2.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from transformers.activations import ACT2FN
13
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
14
+ from transformers.generation import GenerationMixin
15
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
16
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPast,
19
+ CausalLMOutputWithPast,
20
+ QuestionAnsweringModelOutput,
21
+ SequenceClassifierOutputWithPast,
22
+ TokenClassifierOutput,
23
+ )
24
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
25
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
26
+ from transformers.processing_utils import Unpack
27
+ from transformers.utils import (
28
+ LossKwargs,
29
+ add_code_sample_docstrings,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from .configuration_qwen2 import Qwen2MTPSenseVoiceConfig as Qwen2Config
36
+
37
+ from .modeling_sensevoice import AudioEncoder
38
+ from .resampler_projector import ResamplerProjector
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+ logger.setLevel(logging.INFO)
43
+
44
+ _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
45
+ _CONFIG_FOR_DOC = "Qwen2Config"
46
+
47
+
48
+ def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
49
+ reduction = "sum" if num_items_in_batch is not None else "mean"
50
+ loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
51
+ if reduction == "sum":
52
+ loss = loss / num_items_in_batch
53
+ return loss
54
+
55
+
56
+ def ForCausalLMLoss(
57
+ logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
58
+ ):
59
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
60
+ # logits = logits.float()
61
+ labels = labels.to(logits.device)
62
+ # Shift so that tokens < n predict n
63
+ shift_logits = logits[..., :-1, :].contiguous()
64
+ shift_labels = labels[..., 1:].contiguous()
65
+
66
+ # Flatten the tokens
67
+ shift_logits = shift_logits.view(-1, vocab_size)
68
+ shift_labels = shift_labels.view(-1)
69
+ # Enable model parallelism
70
+ shift_labels = shift_labels.to(shift_logits.device)
71
+ loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
72
+ return loss
73
+
74
+
75
+ def compute_kl_loss(logits, labels):
76
+ # import pdb;pdb.set_trace()
77
+ *_, vocab_size = logits.shape
78
+ # Convert logits to log probabilities
79
+ log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
80
+ # Convert labels to probabilities
81
+ target_probs = torch.nn.functional.softmax(labels, dim=-1)
82
+ # Define the KL Divergence loss function
83
+ loss_fct = nn.KLDivLoss(reduction='batchmean')
84
+ # Compute the loss
85
+ loss = loss_fct(log_probs.view(-1, vocab_size), target_probs.view(-1, vocab_size))
86
+ return loss
87
+
88
+
89
+ class Qwen2MLP(nn.Module):
90
+ def __init__(self, config):
91
+ super().__init__()
92
+ self.config = config
93
+ self.hidden_size = config.hidden_size
94
+ self.intermediate_size = config.intermediate_size
95
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
96
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
97
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
98
+ self.act_fn = ACT2FN[config.hidden_act]
99
+
100
+ def forward(self, x):
101
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
102
+ return down_proj
103
+
104
+
105
+ def rotate_half(x):
106
+ """Rotates half the hidden dims of the input."""
107
+ x1 = x[..., : x.shape[-1] // 2]
108
+ x2 = x[..., x.shape[-1] // 2 :]
109
+ return torch.cat((-x2, x1), dim=-1)
110
+
111
+
112
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
113
+ """Applies Rotary Position Embedding to the query and key tensors.
114
+
115
+ Args:
116
+ q (`torch.Tensor`): The query tensor.
117
+ k (`torch.Tensor`): The key tensor.
118
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
119
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
120
+ position_ids (`torch.Tensor`, *optional*):
121
+ Deprecated and unused.
122
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
123
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
124
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
125
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
126
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
127
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
128
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
129
+ Returns:
130
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
131
+ """
132
+ cos = cos.unsqueeze(unsqueeze_dim)
133
+ sin = sin.unsqueeze(unsqueeze_dim)
134
+ q_embed = (q * cos) + (rotate_half(q) * sin)
135
+ k_embed = (k * cos) + (rotate_half(k) * sin)
136
+ return q_embed, k_embed
137
+
138
+
139
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
140
+ """
141
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
142
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
143
+ """
144
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
145
+ if n_rep == 1:
146
+ return hidden_states
147
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
148
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
149
+
150
+
151
+ def eager_attention_forward(
152
+ module: nn.Module,
153
+ query: torch.Tensor,
154
+ key: torch.Tensor,
155
+ value: torch.Tensor,
156
+ attention_mask: Optional[torch.Tensor],
157
+ scaling: float,
158
+ dropout: float = 0.0,
159
+ **kwargs,
160
+ ):
161
+ key_states = repeat_kv(key, module.num_key_value_groups)
162
+ value_states = repeat_kv(value, module.num_key_value_groups)
163
+
164
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
165
+ if attention_mask is not None:
166
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
167
+ attn_weights = attn_weights + causal_mask
168
+
169
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
170
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
171
+ attn_output = torch.matmul(attn_weights, value_states)
172
+ attn_output = attn_output.transpose(1, 2).contiguous()
173
+
174
+ return attn_output, attn_weights
175
+
176
+
177
+ class Qwen2Attention(nn.Module):
178
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
179
+
180
+ def __init__(self, config: Qwen2Config, layer_idx: int):
181
+ super().__init__()
182
+ self.config = config
183
+ self.layer_idx = layer_idx
184
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
185
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
186
+ self.scaling = self.head_dim**-0.5
187
+ self.attention_dropout = config.attention_dropout
188
+ self.is_causal = True
189
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
190
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
191
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
192
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
193
+
194
+ def forward(
195
+ self,
196
+ hidden_states: torch.Tensor,
197
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
198
+ attention_mask: Optional[torch.Tensor],
199
+ past_key_value: Optional[Cache] = None,
200
+ cache_position: Optional[torch.LongTensor] = None,
201
+ **kwargs: Unpack[FlashAttentionKwargs],
202
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
203
+ input_shape = hidden_states.shape[:-1]
204
+ hidden_shape = (*input_shape, -1, self.head_dim)
205
+
206
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
207
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
208
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
209
+
210
+ cos, sin = position_embeddings
211
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
212
+
213
+ if past_key_value is not None:
214
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
215
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
216
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
217
+
218
+ sliding_window = None
219
+ if (
220
+ self.config.use_sliding_window
221
+ and getattr(self.config, "sliding_window", None) is not None
222
+ and self.layer_idx >= self.config.max_window_layers
223
+ ):
224
+ sliding_window = self.config.sliding_window
225
+
226
+ attention_interface: Callable = eager_attention_forward
227
+ if self.config._attn_implementation != "eager":
228
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
229
+ logger.warning_once(
230
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
231
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
232
+ )
233
+ else:
234
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
235
+
236
+ attn_output, attn_weights = attention_interface(
237
+ self,
238
+ query_states,
239
+ key_states,
240
+ value_states,
241
+ attention_mask,
242
+ dropout=0.0 if not self.training else self.attention_dropout,
243
+ scaling=self.scaling,
244
+ sliding_window=sliding_window, # main diff with Llama
245
+ **kwargs,
246
+ )
247
+
248
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
249
+ attn_output = self.o_proj(attn_output)
250
+ return attn_output, attn_weights
251
+
252
+
253
+ class Qwen2RMSNorm(nn.Module):
254
+ def __init__(self, hidden_size, eps=1e-6):
255
+ """
256
+ Qwen2RMSNorm is equivalent to T5LayerNorm
257
+ """
258
+ super().__init__()
259
+ self.weight = nn.Parameter(torch.ones(hidden_size))
260
+ self.variance_epsilon = eps
261
+
262
+ def forward(self, hidden_states):
263
+ input_dtype = hidden_states.dtype
264
+ hidden_states = hidden_states.to(torch.float32)
265
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
266
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
267
+ return self.weight * hidden_states.to(input_dtype)
268
+
269
+ def extra_repr(self):
270
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
271
+
272
+
273
+ class Qwen2DecoderLayer(nn.Module):
274
+ def __init__(self, config: Qwen2Config, layer_idx: int):
275
+ super().__init__()
276
+ self.hidden_size = config.hidden_size
277
+ self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
278
+ self.mlp = Qwen2MLP(config)
279
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
280
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
281
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
282
+ logger.warning_once(
283
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
284
+ "unexpected results may be encountered."
285
+ )
286
+
287
+ def forward(
288
+ self,
289
+ hidden_states: torch.Tensor,
290
+ attention_mask: Optional[torch.Tensor] = None,
291
+ position_ids: Optional[torch.LongTensor] = None,
292
+ past_key_value: Optional[Cache] = None,
293
+ output_attentions: Optional[bool] = False,
294
+ use_cache: Optional[bool] = False,
295
+ cache_position: Optional[torch.LongTensor] = None,
296
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
297
+ **kwargs: Unpack[FlashAttentionKwargs],
298
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
299
+ residual = hidden_states
300
+
301
+ hidden_states = self.input_layernorm(hidden_states)
302
+
303
+ # Self Attention
304
+ hidden_states, self_attn_weights = self.self_attn(
305
+ hidden_states=hidden_states,
306
+ attention_mask=attention_mask,
307
+ position_ids=position_ids,
308
+ past_key_value=past_key_value,
309
+ output_attentions=output_attentions,
310
+ use_cache=use_cache,
311
+ cache_position=cache_position,
312
+ position_embeddings=position_embeddings,
313
+ **kwargs,
314
+ )
315
+ hidden_states = residual + hidden_states
316
+
317
+ # Fully Connected
318
+ residual = hidden_states
319
+ hidden_states = self.post_attention_layernorm(hidden_states)
320
+ hidden_states = self.mlp(hidden_states)
321
+ hidden_states = residual + hidden_states
322
+
323
+ outputs = (hidden_states,)
324
+ if output_attentions:
325
+ outputs += (self_attn_weights,)
326
+
327
+ return outputs
328
+
329
+
330
+ class Qwen2RotaryEmbedding(nn.Module):
331
+ def __init__(self, config: Qwen2Config, device=None):
332
+ super().__init__()
333
+ # BC: "rope_type" was originally "type"
334
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
335
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
336
+ else:
337
+ self.rope_type = "default"
338
+ self.max_seq_len_cached = config.max_position_embeddings
339
+ self.original_max_seq_len = config.max_position_embeddings
340
+
341
+ self.config = config
342
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
343
+
344
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
345
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
346
+ self.original_inv_freq = self.inv_freq
347
+
348
+ def _dynamic_frequency_update(self, position_ids, device):
349
+ """
350
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
351
+ 1 - growing beyond the cached sequence length (allow scaling)
352
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
353
+ """
354
+ seq_len = torch.max(position_ids) + 1
355
+ if seq_len > self.max_seq_len_cached: # growth
356
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
357
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
358
+ self.max_seq_len_cached = seq_len
359
+
360
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
361
+ # This .to() is needed if the model has been moved to a device after being initialized (because
362
+ # the buffer is automatically moved, but not the original copy)
363
+ self.original_inv_freq = self.original_inv_freq.to(device)
364
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
365
+ self.max_seq_len_cached = self.original_max_seq_len
366
+
367
+ @torch.no_grad()
368
+ def forward(self, x, position_ids):
369
+ if "dynamic" in self.rope_type:
370
+ self._dynamic_frequency_update(position_ids, device=x.device)
371
+
372
+ # Core RoPE block
373
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
374
+ position_ids_expanded = position_ids[:, None, :].float()
375
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
376
+ device_type = x.device.type
377
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
378
+ with torch.autocast(device_type=device_type, enabled=False):
379
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
380
+ emb = torch.cat((freqs, freqs), dim=-1)
381
+ cos = emb.cos()
382
+ sin = emb.sin()
383
+
384
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
385
+ cos = cos * self.attention_scaling
386
+ sin = sin * self.attention_scaling
387
+
388
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
389
+
390
+
391
+ QWEN2_START_DOCSTRING = r"""
392
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
393
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
394
+ etc.)
395
+
396
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
397
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
398
+ and behavior.
399
+
400
+ Parameters:
401
+ config ([`Qwen2Config`]):
402
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
403
+ load the weights associated with the model, only the configuration. Check out the
404
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
405
+ """
406
+
407
+
408
+ @add_start_docstrings(
409
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
410
+ QWEN2_START_DOCSTRING,
411
+ )
412
+ class Qwen2PreTrainedModel(PreTrainedModel):
413
+ config_class = Qwen2Config
414
+ base_model_prefix = "model"
415
+ supports_gradient_checkpointing = True
416
+ _no_split_modules = ["Qwen2DecoderLayer"]
417
+ _skip_keys_device_placement = ["past_key_values"]
418
+ _supports_flash_attn_2 = True
419
+ _supports_sdpa = True
420
+ _supports_flex_attn = True
421
+ _supports_cache_class = True
422
+ _supports_quantized_cache = True
423
+ _supports_static_cache = True
424
+
425
+ def _init_weights(self, module):
426
+ std = self.config.initializer_range
427
+ if isinstance(module, nn.Linear):
428
+ module.weight.data.normal_(mean=0.0, std=std)
429
+ if module.bias is not None:
430
+ module.bias.data.zero_()
431
+ elif isinstance(module, nn.Embedding):
432
+ module.weight.data.normal_(mean=0.0, std=std)
433
+ if module.padding_idx is not None:
434
+ module.weight.data[module.padding_idx].zero_()
435
+
436
+
437
+ QWEN2_INPUTS_DOCSTRING = r"""
438
+ Args:
439
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
440
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
441
+ it.
442
+
443
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
444
+ [`PreTrainedTokenizer.__call__`] for details.
445
+
446
+ [What are input IDs?](../glossary#input-ids)
447
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
448
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
449
+
450
+ - 1 for tokens that are **not masked**,
451
+ - 0 for tokens that are **masked**.
452
+
453
+ [What are attention masks?](../glossary#attention-mask)
454
+
455
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
456
+ [`PreTrainedTokenizer.__call__`] for details.
457
+
458
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
459
+ `past_key_values`).
460
+
461
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
462
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
463
+ information on the default strategy.
464
+
465
+ - 1 indicates the head is **not masked**,
466
+ - 0 indicates the head is **masked**.
467
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
468
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
469
+ config.n_positions - 1]`.
470
+
471
+ [What are position IDs?](../glossary#position-ids)
472
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
473
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
474
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
475
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
476
+
477
+ Two formats are allowed:
478
+ - a [`~cache_utils.Cache`] instance, see our
479
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
480
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
481
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
482
+ cache format.
483
+
484
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
485
+ legacy cache format will be returned.
486
+
487
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
488
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
489
+ of shape `(batch_size, sequence_length)`.
490
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
491
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
492
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
493
+ model's internal embedding lookup matrix.
494
+ use_cache (`bool`, *optional*):
495
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
496
+ `past_key_values`).
497
+ output_attentions (`bool`, *optional*):
498
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
499
+ tensors for more detail.
500
+ output_hidden_states (`bool`, *optional*):
501
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
502
+ more detail.
503
+ return_dict (`bool`, *optional*):
504
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
505
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
506
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
507
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
508
+ the complete sequence length.
509
+ """
510
+
511
+
512
+ @add_start_docstrings(
513
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
514
+ QWEN2_START_DOCSTRING,
515
+ )
516
+ class Qwen2Model(Qwen2PreTrainedModel):
517
+ """
518
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
519
+
520
+ Args:
521
+ config: Qwen2Config
522
+ """
523
+
524
+ def __init__(self, config: Qwen2Config):
525
+ super().__init__(config)
526
+ self.padding_idx = config.pad_token_id
527
+ self.vocab_size = config.vocab_size
528
+
529
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
530
+ self.layers = nn.ModuleList(
531
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
532
+ )
533
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
534
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
535
+ self.gradient_checkpointing = False
536
+
537
+ self.audio_model = AudioEncoder(config)
538
+ self.audio_projection = ResamplerProjector(512, config.hidden_size)
539
+
540
+ # Initialize weights and apply final processing
541
+ self.post_init()
542
+
543
+ def get_input_embeddings(self):
544
+ return self.embed_tokens
545
+
546
+ def set_input_embeddings(self, value):
547
+ self.embed_tokens = value
548
+
549
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
550
+ def forward(
551
+ self,
552
+ input_ids: torch.LongTensor = None,
553
+ attention_mask: Optional[torch.Tensor] = None,
554
+ audios: Optional[torch.FloatTensor] = None,
555
+ audio_indices: Optional[torch.LongTensor] = None,
556
+ position_ids: Optional[torch.LongTensor] = None,
557
+ past_key_values: Optional[Cache] = None,
558
+ inputs_embeds: Optional[torch.FloatTensor] = None,
559
+ use_cache: Optional[bool] = None,
560
+ output_attentions: Optional[bool] = None,
561
+ output_hidden_states: Optional[bool] = None,
562
+ return_dict: Optional[bool] = None,
563
+ cache_position: Optional[torch.LongTensor] = None,
564
+ layer_idxs = None,
565
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
566
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
567
+ if (past_key_values is None or len(past_key_values) == 0) and audios is not None:
568
+ audio_embeds, audio_lengths = self.audio_model(audios)
569
+ # if torch.distributed.get_rank() == 0:
570
+ # print(f"audio_embeds {audio_embeds.size()}")
571
+ assert audio_embeds.shape[0] == len(audios)
572
+ fake_audios = None
573
+
574
+ audio_embeds = self.audio_projection(audio_embeds)
575
+
576
+ # torch.set_printoptions(threshold=100_000)
577
+ # if torch.distributed.get_rank() == 0:
578
+ # print(f"audio_embeds {audio_embeds.size()}")
579
+ # print(f"audio_embeds {audio_embeds.sum()}")
580
+ # print(f"audios {[x.size() for x in audios]}")
581
+ # print(f"audios {[x.sum() for x in audios]}")
582
+ # print(f"input_ids {input_ids.size()}")
583
+ # print(f"input_ids {input_ids.sum()}")
584
+ # # print(f"input_ids {input_ids}")
585
+ # print(f"audio_indices {[x.size() for x in audio_indices]}")
586
+ # print(f"audio_indices {[x.sum() for x in audio_indices]}")
587
+ # # print(f"audio_indices {audio_indices}")
588
+
589
+ elif self.training:
590
+ device = self.get_input_embeddings().weight.data.device
591
+ dtype = self.get_input_embeddings().weight.data.dtype
592
+ fake_audios = torch.ones((1, 1, 560), dtype=dtype, device=device)
593
+ audio_embeds, audio_lengths = self.audio_model(fake_audios)
594
+ audio_embeds = self.audio_projection(audio_embeds)
595
+
596
+ else:
597
+ fake_audios = None
598
+ audio_embeds = None
599
+
600
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
601
+ output_hidden_states = (
602
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
603
+ )
604
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
605
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
606
+
607
+ if (input_ids is None) ^ (inputs_embeds is not None):
608
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
609
+
610
+ if self.gradient_checkpointing and self.training and use_cache:
611
+ logger.warning_once(
612
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
613
+ )
614
+ use_cache = False
615
+
616
+ if inputs_embeds is None:
617
+ inputs_embeds = self.embed_tokens(input_ids)
618
+
619
+ if fake_audios is not None:
620
+ inputs_embeds = inputs_embeds + audio_embeds.mean() * 0.0
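+ # NOTE (assumption): the zero-scaled term above contributes nothing numerically; it seems intended only to keep the audio tower in the autograd graph when a fake audio is fed during training.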
621
+ elif audio_embeds is not None:
622
+ inputs_embeds = inputs_embeds.clone()
623
+ for audio_embeds_, audio_lengths_, audio_indices_ in zip(audio_embeds, audio_lengths, audio_indices,):
624
+ # print(f"{audio_embeds_.size()=} {audio_lengths_=} {audio_indices_.size()=}")
625
+ audio_embeds_ = audio_embeds_[:audio_lengths_, ...]
626
+ audio_embeds_ = audio_embeds_.to(inputs_embeds.device)
627
+ indices_b, indices_s = audio_indices_.to(inputs_embeds.device).unbind(dim=0)
628
+ inputs_embeds[indices_b.view(-1), indices_s.view(-1)] = audio_embeds_.view(-1, audio_embeds_.shape[-1])
629
+ # inputs_embeds = inputs_embeds + audio_embeds.mean() * 0.0
630
+
631
+ if use_cache and past_key_values is None:
632
+ past_key_values = DynamicCache()
633
+
634
+ if cache_position is None:
635
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
636
+ cache_position = torch.arange(
637
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
638
+ )
639
+
640
+ if position_ids is None:
641
+ position_ids = cache_position.unsqueeze(0)
642
+
643
+ causal_mask = self._update_causal_mask(
644
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
645
+ )
646
+
647
+ hidden_states = inputs_embeds
648
+
649
+ # create position embeddings to be shared across the decoder layers
650
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
651
+
652
+ # decoder layers
653
+ all_hidden_states = () if output_hidden_states else None
654
+ all_self_attns = () if output_attentions else None
655
+
656
+ if layer_idxs is None:
657
+ layer_idxs = list(range(self.config.num_hidden_layers))
658
+ layers = [self.layers[layer_idx] for layer_idx in layer_idxs]
659
+
660
+ for decoder_layer in layers:
661
+ if output_hidden_states:
662
+ all_hidden_states += (hidden_states,)
663
+
664
+ if self.gradient_checkpointing and self.training:
665
+ layer_outputs = self._gradient_checkpointing_func(
666
+ decoder_layer.__call__,
667
+ hidden_states,
668
+ causal_mask,
669
+ position_ids,
670
+ past_key_values,
671
+ output_attentions,
672
+ use_cache,
673
+ cache_position,
674
+ position_embeddings,
675
+ **flash_attn_kwargs,
676
+ )
677
+ else:
678
+ layer_outputs = decoder_layer(
679
+ hidden_states,
680
+ attention_mask=causal_mask,
681
+ position_ids=position_ids,
682
+ past_key_value=past_key_values,
683
+ output_attentions=output_attentions,
684
+ use_cache=use_cache,
685
+ cache_position=cache_position,
686
+ position_embeddings=position_embeddings,
687
+ **flash_attn_kwargs,
688
+ )
689
+
690
+ hidden_states = layer_outputs[0]
691
+
692
+ if output_attentions:
693
+ all_self_attns += (layer_outputs[1],)
694
+
695
+ hidden_states = self.norm(hidden_states)
696
+
697
+ # add hidden states from the last decoder layer
698
+ if output_hidden_states:
699
+ all_hidden_states += (hidden_states,)
700
+
701
+ output = BaseModelOutputWithPast(
702
+ last_hidden_state=hidden_states,
703
+ past_key_values=past_key_values if use_cache else None,
704
+ hidden_states=all_hidden_states,
705
+ attentions=all_self_attns,
706
+ )
707
+ return output if return_dict else output.to_tuple()
708
+
709
+ def _update_causal_mask(
710
+ self,
711
+ attention_mask: torch.Tensor,
712
+ input_tensor: torch.Tensor,
713
+ cache_position: torch.Tensor,
714
+ past_key_values: Cache,
715
+ output_attentions: bool,
716
+ ):
717
+ if self.config._attn_implementation == "flash_attention_2":
718
+ if attention_mask is not None and (attention_mask == 0.0).any():
719
+ return attention_mask
720
+ return None
721
+
722
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
723
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
724
+ # to infer the attention mask.
725
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
726
+ using_static_cache = isinstance(past_key_values, StaticCache)
727
+
728
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
729
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
730
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
731
+ attention_mask,
732
+ inputs_embeds=input_tensor,
733
+ past_key_values_length=past_seen_tokens,
734
+ is_training=self.training,
735
+ ):
736
+ return None
737
+
738
+ dtype, device = input_tensor.dtype, input_tensor.device
739
+ sequence_length = input_tensor.shape[1]
740
+ if using_static_cache:
741
+ target_length = past_key_values.get_max_cache_shape()
742
+ else:
743
+ target_length = (
744
+ attention_mask.shape[-1]
745
+ if isinstance(attention_mask, torch.Tensor)
746
+ else past_seen_tokens + sequence_length + 1
747
+ )
748
+
749
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
750
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
751
+ attention_mask,
752
+ sequence_length=sequence_length,
753
+ target_length=target_length,
754
+ dtype=dtype,
755
+ device=device,
756
+ cache_position=cache_position,
757
+ batch_size=input_tensor.shape[0],
758
+ )
759
+
760
+ if (
761
+ self.config._attn_implementation == "sdpa"
762
+ and attention_mask is not None
763
+ and attention_mask.device.type == "cuda"
764
+ and not output_attentions
765
+ ):
766
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
767
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
768
+ # Details: https://github.com/pytorch/pytorch/issues/110213
769
+ min_dtype = torch.finfo(dtype).min
770
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
771
+
772
+ return causal_mask
773
+
774
+ @staticmethod
775
+ def _prepare_4d_causal_attention_mask_with_cache_position(
776
+ attention_mask: torch.Tensor,
777
+ sequence_length: int,
778
+ target_length: int,
779
+ dtype: torch.dtype,
780
+ device: torch.device,
781
+ cache_position: torch.Tensor,
782
+ batch_size: int,
783
+ **kwargs,
784
+ ):
785
+ """
786
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
787
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
788
+
789
+ Args:
790
+ attention_mask (`torch.Tensor`):
791
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
792
+ `(batch_size, 1, query_length, key_value_length)`.
793
+ sequence_length (`int`):
794
+ The sequence length being processed.
795
+ target_length (`int`):
796
+ The target length: when generating with static cache, the mask should be as long as the static cache,
797
+ to account for the zero padding (the part of the cache that is not filled yet).
798
+ dtype (`torch.dtype`):
799
+ The dtype to use for the 4D attention mask.
800
+ device (`torch.device`):
801
+ The device to place the 4D attention mask on.
802
+ cache_position (`torch.Tensor`):
803
+ Indices depicting the position of the input sequence tokens in the sequence.
804
+ batch_size (`int`):
805
+ Batch size.
806
+ """
807
+ if attention_mask is not None and attention_mask.dim() == 4:
808
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
809
+ causal_mask = attention_mask
810
+ else:
811
+ min_dtype = torch.finfo(dtype).min
812
+ causal_mask = torch.full(
813
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
814
+ )
815
+ if sequence_length != 1:
816
+ causal_mask = torch.triu(causal_mask, diagonal=1)
817
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
818
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
819
+ if attention_mask is not None:
820
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
821
+ mask_length = attention_mask.shape[-1]
822
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
823
+ padding_mask = padding_mask == 0
824
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
825
+ padding_mask, min_dtype
826
+ )
827
+
828
+ return causal_mask
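To make the construction above concrete, here is a small standalone sketch (editorial, with made-up sizes) of the `triu` + `cache_position` trick that yields the additive mask before it is expanded to 4D:

```python
import torch

# Toy setting: 3 new tokens arriving at absolute positions 2..4,
# attending over a target (cache) length of 5.
seq_len, target_length = 3, 5
cache_position = torch.arange(2, 5)
min_dtype = torch.finfo(torch.float32).min

mask = torch.full((seq_len, target_length), min_dtype)
mask = torch.triu(mask, diagonal=1)
mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)

# Row i is 0 (attend) for every position <= cache_position[i] and a large
# negative value (masked) elsewhere; expand(batch, 1, -1, -1) gives the 4D mask.
print((mask == 0).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```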
829
+
830
+
831
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
832
+
833
+
834
+ class Qwen2MTPSenseVoiceForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
835
+ _tied_weights_keys = ["lm_head.weight"]
836
+ _tp_plan = {"lm_head": "colwise_rep"}
837
+
838
+ def __init__(self, config):
839
+ super().__init__(config)
840
+ self.model = Qwen2Model(config)
841
+ self.vocab_size = config.vocab_size
842
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
843
+
844
+ self.mtp_projs = nn.ModuleList(
845
+ [nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) for _ in range(self.config.num_nextn_predict_layers)]
846
+ )
847
+
848
+ self.mtp_embed_norms = nn.ModuleList([Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(self.config.num_nextn_predict_layers)])
849
+ self.mtp_hidden_norms = nn.ModuleList([Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(self.config.num_nextn_predict_layers)])
850
+
851
+ # Initialize weights and apply final processing
852
+ self.post_init()
853
+
854
+ def get_input_embeddings(self):
855
+ return self.model.embed_tokens
856
+
857
+ def set_input_embeddings(self, value):
858
+ self.model.embed_tokens = value
859
+
860
+ def get_output_embeddings(self):
861
+ return self.lm_head
862
+
863
+ def set_output_embeddings(self, new_embeddings):
864
+ self.lm_head = new_embeddings
865
+
866
+ def set_decoder(self, decoder):
867
+ self.model = decoder
868
+
869
+ def get_decoder(self):
870
+ return self.model
871
+
872
+ def mtp_forward(
873
+ self,
874
+ mtp_idx,
875
+ input_ids: torch.LongTensor = None,
876
+ hidden_states: torch.Tensor = None,
877
+ attention_mask: Optional[torch.Tensor] = None,
878
+ position_ids: Optional[torch.LongTensor] = None,
879
+ past_key_values: Optional[Cache] = None,
880
+ inputs_embeds: Optional[torch.FloatTensor] = None,
881
+ labels: Optional[torch.LongTensor] = None,
882
+ kl_labels: Optional[torch.Tensor] = None,
883
+ use_cache: Optional[bool] = None,
884
+ output_attentions: Optional[bool] = None,
885
+ output_hidden_states: Optional[bool] = None,
886
+ return_dict: Optional[bool] = None,
887
+ cache_position: Optional[torch.LongTensor] = None,
888
+ num_logits_to_keep: int = 0,
889
+ **kwargs: Unpack[KwargsForCausalLM],
890
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
891
+
892
+ if (input_ids is None) ^ (inputs_embeds is not None):
893
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
894
+
895
+ if inputs_embeds is None:
896
+ inputs_embeds = self.model.embed_tokens(input_ids)
897
+ # inputs_embeds = inputs_embeds.to(hidden_states.device)
898
+
899
+ inputs_embeds = torch.cat(
900
+ (
901
+ self.mtp_embed_norms[mtp_idx](inputs_embeds),
902
+ self.mtp_hidden_norms[mtp_idx](hidden_states),
903
+ ),
904
+ dim=-1,
905
+ )
906
+
907
+ inputs_embeds = self.mtp_projs[mtp_idx](inputs_embeds)
908
+
909
+ outputs = self.model(
910
+ input_ids=None,
911
+ attention_mask=attention_mask,
912
+ position_ids=position_ids,
913
+ past_key_values=past_key_values,
914
+ inputs_embeds=inputs_embeds,
915
+ use_cache=use_cache,
916
+ output_attentions=output_attentions,
917
+ output_hidden_states=output_hidden_states,
918
+ return_dict=return_dict,
919
+ cache_position=cache_position,
920
+ layer_idxs=[self.config.num_hidden_layers - self.config.num_nextn_predict_layers + mtp_idx],
921
+ **kwargs,
922
+ )
923
+
924
+ hidden_states = outputs[0]
925
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
926
+
927
+ if labels is not None:
928
+ loss = []
929
+ # ce_loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
930
+ ce_loss = ForCausalLMLoss(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
931
+
932
+ loss += [ce_loss]
933
+
934
+ if False:
935
+ kl_logits = logits.contiguous()
936
+ kl_labels = kl_labels.contiguous()
937
+ kl_loss = compute_kl_loss(kl_logits, kl_labels)
938
+
939
+ kl_loss_weight = 1
940
+ loss += [kl_loss_weight * kl_loss]
941
+
942
+ # if self.training and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
943
+ # with torch.no_grad():
944
+ # logger.info(f"\tMTP {mtp_idx=} {loss=}")
945
+ else:
946
+ loss = None
947
+
948
+ return outputs, logits, loss
949
+
950
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
951
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
952
+ def forward(
953
+ self,
954
+ input_ids: torch.LongTensor = None,
955
+ attention_mask: Optional[torch.Tensor] = None,
956
+ audios: Optional[torch.FloatTensor] = None,
957
+ audio_indices: Optional[torch.LongTensor] = None,
958
+ position_ids: Optional[torch.LongTensor] = None,
959
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
960
+ inputs_embeds: Optional[torch.FloatTensor] = None,
961
+ labels: Optional[torch.LongTensor] = None,
962
+ use_cache: Optional[bool] = None,
963
+ output_attentions: Optional[bool] = None,
964
+ output_hidden_states: Optional[bool] = None,
965
+ return_dict: Optional[bool] = None,
966
+ cache_position: Optional[torch.LongTensor] = None,
967
+ num_logits_to_keep: int = 0,
968
+ **kwargs: Unpack[KwargsForCausalLM],
969
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
970
+ r"""
971
+ Args:
972
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
973
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
974
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
975
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
976
+
977
+ num_logits_to_keep (`int`, *optional*):
978
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
979
+ `input_ids` (special case). Only the last token's logits are needed for generation, and computing them only
980
+ for that token saves memory, which becomes significant for long sequences or large vocabulary sizes.
981
+
982
+ Returns:
983
+
984
+ Example:
985
+
986
+ ```python
987
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
988
+
989
+ >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
990
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf")
991
+
992
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
993
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
994
+
995
+ >>> # Generate
996
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
997
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
998
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
999
+ ```"""
1000
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1001
+ output_hidden_states = (
1002
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1003
+ )
1004
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1005
+
1006
+ # ===============================================================================================
1007
+ if not self.training:
1008
+ if input_ids is not None:
1009
+ num_input_tokens = input_ids.size(1)
1010
+ if inputs_embeds is not None:
1011
+ num_input_tokens = inputs_embeds.size(1)
1012
+
1013
+ if use_cache:
1014
+ if self.input_ids is None and self.inputs_embeds is None:
1015
+ if input_ids is not None:
1016
+ self.input_ids = input_ids
1017
+ if inputs_embeds is not None:
1018
+ self.inputs_embeds = inputs_embeds
1019
+ if position_ids is not None:
1020
+ self.position_ids = position_ids
1021
+
1022
+ else:
1023
+ if input_ids is not None:
1024
+ self.input_ids = torch.cat([self.input_ids, input_ids], dim=1)
1025
+ if inputs_embeds is not None:
1026
+ self.inputs_embeds = torch.cat([self.inputs_embeds, inputs_embeds], dim=1)
1027
+ if position_ids is not None:
1028
+ self.position_ids = torch.cat([self.position_ids, position_ids], dim=1)
1029
+
1030
+ else:
1031
+ self.input_ids = input_ids
1032
+ self.inputs_embeds = inputs_embeds
1033
+ self.position_ids = position_ids
1034
+
1035
+ self.attention_mask = attention_mask
1036
+
1037
+ if self.num_prefill_tokens < 0:
1038
+ self.num_prefill_tokens = self.input_ids.size(1)
1039
+ num_decode_tokens = self.input_ids.size(1) - self.num_prefill_tokens
1040
+
1041
+ if self.mtp_inference_mode[num_decode_tokens] == "M":
1042
+ self.mtp_idx = -1
1043
+ elif self.mtp_inference_mode[num_decode_tokens] == "m":
1044
+ if self.mtp_inference_mode[num_decode_tokens - 1] == "M":
1045
+ self.mtp_idx = 0
1046
+ else:
1047
+ pass
1048
+
1049
+ # if True:
1050
+ if False:
1051
+ print("=" * 100)
1052
+ print(f"{self.mtp_idx=}")
1053
+ print(f"{self.num_prefill_tokens=}")
1054
+ print(f"{num_decode_tokens=}")
1055
+ print(f"{self.mtp_inference_mode=}")
1056
+ if self.input_ids is not None:
1057
+ print(f"{self.input_ids.size()=}")
1058
+ if self.inputs_embeds is not None:
1059
+ print(f"{self.inputs_embeds.size()=}")
1060
+ if self.hidden_states[self.mtp_idx] is not None:
1061
+ print(f"{self.hidden_states[self.mtp_idx].size()=}")
1062
+
1063
+
1064
+ if self.mtp_idx > -1 and self.mtp_idx < self.config.num_nextn_predict_layers and num_input_tokens == 1:
1065
+ layer_idx = self.config.num_hidden_layers - self.config.num_nextn_predict_layers + self.mtp_idx
1066
+
1067
+ if use_cache:
1068
+ if len(past_key_values.key_cache) > layer_idx:
1069
+ num_seen_tokens = past_key_values.key_cache[layer_idx].size(2)
1070
+ else:
1071
+ num_seen_tokens = 0
1072
+ else:
1073
+ num_seen_tokens = 0
1074
+
1075
+ hidden_states = self.hidden_states[self.mtp_idx][:, num_seen_tokens:, :]
1076
+
1077
+ if self.input_ids is not None:
1078
+ input_ids = self.input_ids[:, num_seen_tokens + self.mtp_idx + 1:]
1079
+ if self.inputs_embeds is not None:
1080
+ inputs_embeds = self.inputs_embeds[:, num_seen_tokens + self.mtp_idx + 1:, :]
1081
+ if self.position_ids is not None:
1082
+ position_ids = self.position_ids[:, num_seen_tokens + self.mtp_idx + 1:]
1083
+ attention_mask = self.attention_mask[:, num_seen_tokens + self.mtp_idx + 1:]
1084
+
1085
+ if False:
1086
+ # if True:
1087
+ print("=" * 100)
1088
+ print(f"{self.mtp_idx=}")
1089
+ print(f"{layer_idx=}")
1090
+ if input_ids is not None:
1091
+ print(f"{input_ids.size()=} {input_ids=}")
1092
+ if inputs_embeds is not None:
1093
+ print(f"{inputs_embeds.size()=} {inputs_embeds=}")
1094
+ print(f"{hidden_states.size()=} {hidden_states=}")
1095
+ if attention_mask is not None:
1096
+ print(f"{attention_mask.size()=} {attention_mask=}")
1097
+ if position_ids is not None:
1098
+ print(f"{position_ids.size()=} {position_ids=}")
1099
+ if use_cache and len(past_key_values.key_cache) > layer_idx:
1100
+ print(f"{past_key_values.key_cache[layer_idx].size()=}")
1101
+ print(f"{use_cache=}")
1102
+ print(f"{num_logits_to_keep=}")
1103
+ print(f"{output_attentions=}")
1104
+ print(f"{output_hidden_states=}")
1105
+ print(f"{cache_position=}")
1106
+
1107
+ mtp_outputs, logits, _ = self.mtp_forward(
1108
+ self.mtp_idx,
1109
+ input_ids=input_ids,
1110
+ hidden_states=hidden_states,
1111
+ attention_mask=attention_mask,
1112
+ position_ids=position_ids,
1113
+ past_key_values=past_key_values,
1114
+ inputs_embeds=inputs_embeds,
1115
+ labels=None,
1116
+ kl_labels=None,
1117
+ use_cache=use_cache,
1118
+ output_attentions=output_attentions,
1119
+ output_hidden_states=output_hidden_states,
1120
+ return_dict=return_dict,
1121
+ cache_position=cache_position,
1122
+ num_logits_to_keep=num_logits_to_keep,
1123
+ **kwargs,
1124
+ )
1125
+ hidden_states = mtp_outputs.last_hidden_state
1126
+
1127
+ self.mtp_idx += 1
1128
+ if use_cache:
1129
+ if self.hidden_states[self.mtp_idx] is None:
1130
+ self.hidden_states[self.mtp_idx] = hidden_states
1131
+ else:
1132
+ self.hidden_states[self.mtp_idx] = torch.cat([self.hidden_states[self.mtp_idx], hidden_states], dim=1)
1133
+
1134
+ else:
1135
+ self.hidden_states[self.mtp_idx] = hidden_states
1136
+
1137
+ return CausalLMOutputWithPast(
1138
+ loss=None,
1139
+ logits=logits,
1140
+ past_key_values=past_key_values,
1141
+ hidden_states=mtp_outputs.hidden_states,
1142
+ attentions=mtp_outputs.attentions,
1143
+ )
1144
+
1145
+ if use_cache and past_key_values is not None:
1146
+ if len(past_key_values.key_cache) > 0:
1147
+ # print(f"{past_key_values.key_cache[0].size()=}")
1148
+ num_seen_tokens = past_key_values.key_cache[0].size(2)
1149
+ else:
1150
+ num_seen_tokens = 0
1151
+ else:
1152
+ num_seen_tokens = 0
1153
+
1154
+ if self.input_ids is not None:
1155
+ input_ids = self.input_ids[:, num_seen_tokens:]
1156
+ if self.inputs_embeds is not None:
1157
+ inputs_embeds = self.inputs_embeds[:, num_seen_tokens:, :]
1158
+ if self.position_ids is not None:
1159
+ position_ids = self.position_ids[:, num_seen_tokens:]
1160
+ attention_mask = attention_mask
1161
+
1162
+ # ===============================================================================================
1163
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1164
+ outputs = self.model(
1165
+ input_ids=input_ids,
1166
+ attention_mask=attention_mask,
1167
+ audios=audios,
1168
+ audio_indices=audio_indices,
1169
+ position_ids=position_ids,
1170
+ past_key_values=past_key_values,
1171
+ inputs_embeds=inputs_embeds,
1172
+ use_cache=use_cache,
1173
+ output_attentions=output_attentions,
1174
+ output_hidden_states=output_hidden_states,
1175
+ return_dict=return_dict,
1176
+ cache_position=cache_position,
1177
+ layer_idxs=list(range(self.config.num_hidden_layers - self.config.num_nextn_predict_layers)),
1178
+ **kwargs,
1179
+ )
1180
+
1181
+ hidden_states = outputs[0]
1182
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1183
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1184
+
1185
+ loss = None
1186
+ if labels is not None:
1187
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1188
+ # loss = ForCausalLMLoss(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
1189
+ # if self.training and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
1190
+ # with torch.no_grad():
1191
+ # logger.info(f"STP {loss=}")
1192
+
1193
+ # ===============================================================================================
1194
+ if labels is not None and self.config.num_nextn_predict_layers > 0:
1195
+
1196
+ if self.lm_head.weight.requires_grad and False:
1197
+ if inputs_embeds is None:
1198
+ inputs_embeds = self.model.embed_tokens(input_ids)
1199
+
1200
+ inputs_embeds = inputs_embeds
1201
+ hidden_states = hidden_states
1202
+ kl_labels = logits
1203
+
1204
+ else:
1205
+ with torch.no_grad():
1206
+ if inputs_embeds is None:
1207
+ inputs_embeds = self.model.embed_tokens(input_ids)
1208
+
1209
+ inputs_embeds = inputs_embeds.detach()
1210
+ hidden_states = hidden_states.detach()
1211
+ kl_labels = logits.detach()
1212
+
1213
+ if self.lm_head.weight.requires_grad:
1214
+ pass
1215
+ else:
1216
+ loss = 0.0
1217
+
1218
+ for mtp_idx in range(self.config.num_nextn_predict_layers):
1219
+
1220
+ # SFT with data packing
1221
+ if True:
1222
+ mtp_mask = position_ids > mtp_idx
1223
+ # input_ids = input_ids[mtp_mask].unsqueeze(0)
1224
+ inputs_embeds = inputs_embeds[mtp_mask].unsqueeze(0)
1225
+ if attention_mask is not None:
1226
+ attention_mask = attention_mask[mtp_mask].unsqueeze(0)
1227
+ if position_ids is not None:
1228
+ position_ids = position_ids[mtp_mask].unsqueeze(0)
1229
+ labels = labels[mtp_mask].unsqueeze(0)
1230
+ kl_labels = kl_labels[mtp_mask].unsqueeze(0)
1231
+
1232
+ mtp_mask = torch.cat((mtp_mask[:, 1:], mtp_mask[:, :1]), dim=1)
1233
+ hidden_states = hidden_states[mtp_mask].unsqueeze(0)
1234
+
1235
+ cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa2_from_position_ids_for_mtp(position_ids, mtp_idx)
1236
+ # kwargs["cu_seq_lens_q"] = cu_seq_lens_q
1237
+ # kwargs["cu_seq_lens_k"] = cu_seq_lens_k
1238
+ # kwargs["max_length_q"] = max_length_q
1239
+ # kwargs["max_length_k"] = max_length_k
1240
+
1241
+ # print(f"{cu_seq_lens_q}")
1242
+ # print(f"{cu_seq_lens_k}")
1243
+ # print(f"{max_length_q}")
1244
+ # print(f"{max_length_k}")
1245
+
1246
+ mtp_outputs, _, mtp_loss = self.mtp_forward(
1247
+ mtp_idx,
1248
+ input_ids=None,
1249
+ hidden_states=hidden_states,
1250
+ attention_mask=attention_mask,
1251
+ position_ids=position_ids,
1252
+ past_key_values=past_key_values,
1253
+ inputs_embeds=inputs_embeds,
1254
+ labels=labels,
1255
+ kl_labels=kl_labels,
1256
+ use_cache=use_cache,
1257
+ output_attentions=output_attentions,
1258
+ output_hidden_states=output_hidden_states,
1259
+ return_dict=return_dict,
1260
+ cache_position=cache_position,
1261
+ num_logits_to_keep=num_logits_to_keep,
1262
+ cu_seq_lens_q=cu_seq_lens_q,
1263
+ cu_seq_lens_k=cu_seq_lens_k,
1264
+ max_length_q=max_length_q,
1265
+ max_length_k=max_length_k,
1266
+ **kwargs,
1267
+ )
1268
+
1269
+ loss += sum(mtp_loss) / self.config.num_nextn_predict_layers * self.config.mtp_loss_weight
1270
+
1271
+ hidden_states = mtp_outputs.last_hidden_state
1272
+
1273
+ if not self.training:
1274
+ self.mtp_idx = 0
1275
+
1276
+ if use_cache:
1277
+ if self.hidden_states[self.mtp_idx] is None:
1278
+ self.hidden_states[self.mtp_idx] = hidden_states
1279
+
1280
+ else:
1281
+ self.hidden_states[self.mtp_idx] = torch.cat([self.hidden_states[self.mtp_idx], hidden_states], dim=1)
1282
+
1283
+ else:
1284
+ self.hidden_states[self.mtp_idx] = hidden_states
1285
+
1286
+ # ===============================================================================================
1287
+
1288
+ if not return_dict:
1289
+ output = (logits,) + outputs[1:]
1290
+ return (loss,) + output if loss is not None else output
1291
+
1292
+ return CausalLMOutputWithPast(
1293
+ loss=loss,
1294
+ logits=logits,
1295
+ past_key_values=outputs.past_key_values,
1296
+ hidden_states=outputs.hidden_states,
1297
+ attentions=outputs.attentions,
1298
+ )
1299
+
1300
+ def _prepare_mtp_for_generation(
1301
+ self,
1302
+ mtp_inference_mode,
1303
+ max_new_tokens,
1304
+ ):
1305
+
1306
+ self.input_ids = None
1307
+ self.inputs_embeds = None
1308
+ self.hidden_states = [None] * (self.config.num_nextn_predict_layers + 1)
1309
+ self.position_ids = None
1310
+ self.attention_mask = None
1311
+
1312
+ self.mtp_idx = -1
1313
+ self.num_prefill_tokens = -1
1314
+
1315
+ assert isinstance(mtp_inference_mode, list)
1316
+ assert len(mtp_inference_mode) >= 2
1317
+ assert len(mtp_inference_mode) % 2 == 0
1318
+
1319
+ main_nums = mtp_inference_mode[::2]
1320
+ mtp_nums = mtp_inference_mode[1::2]
1321
+
1322
+ mtp_inference_mode = []
1323
+ while len(mtp_inference_mode) < max_new_tokens:
1324
+
1325
+ if len(mtp_nums) > 1:
1326
+ mtp_num = mtp_nums.pop(0)
1327
+ else:
1328
+ mtp_num = mtp_nums[0]
1329
+
1330
+ if len(main_nums) > 1:
1331
+ main_num = main_nums.pop(0)
1332
+ else:
1333
+ main_num = main_nums[0]
1334
+
1335
+ mtp_inference_mode += "M" * main_num + "m" * mtp_num
1336
+
1337
+ self.mtp_inference_mode = mtp_inference_mode
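The decode schedule built above alternates full-model steps (`'M'`) with MTP-head steps (`'m'`); note that `list += str` appends one character per step. A minimal standalone re-implementation of the same expansion logic (editorial sketch) shows how the default `[1, num_nextn_predict_layers]` setting unrolls:

```python
def expand_schedule(mtp_inference_mode, max_new_tokens):
    # mtp_inference_mode alternates (main-model steps, mtp-head steps), e.g. [1, 2]
    main_nums, mtp_nums = list(mtp_inference_mode[::2]), list(mtp_inference_mode[1::2])
    schedule = []
    while len(schedule) < max_new_tokens:
        mtp_num = mtp_nums.pop(0) if len(mtp_nums) > 1 else mtp_nums[0]
        main_num = main_nums.pop(0) if len(main_nums) > 1 else main_nums[0]
        schedule += "M" * main_num + "m" * mtp_num   # list += str appends characters
    return "".join(schedule)

print(expand_schedule([1, 2], 8))  # MmmMmmMmm -> 1 full step, then 2 MTP steps, repeated
```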
1338
+
1339
+ def _prepare_cache_for_generation(self, *args, **kwargs):
1340
+
1341
+ generation_config = args[0]
1342
+ mtp_inference_mode = getattr(generation_config, "mtp_inference_mode", [1, self.config.num_nextn_predict_layers])
1343
+ max_new_tokens = generation_config.max_new_tokens
1344
+
1345
+ self._prepare_mtp_for_generation(mtp_inference_mode, max_new_tokens)
1346
+
1347
+ return super()._prepare_cache_for_generation(*args, **kwargs)
1348
+
1349
+
1350
+ @add_start_docstrings(
1351
+ """
1352
+ The Qwen2 Model transformer with a sequence classification head on top (linear layer).
1353
+
1354
+ [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1355
+ (e.g. GPT-2) do.
1356
+
1357
+ Since it does classification on the last token, it needs to know the position of the last token. If a
1358
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1359
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1360
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1361
+ each row of the batch).
1362
+ """,
1363
+ QWEN2_START_DOCSTRING,
1364
+ )
1365
+ class Qwen2ForSequenceClassification(Qwen2PreTrainedModel):
1366
+ def __init__(self, config):
1367
+ super().__init__(config)
1368
+ self.num_labels = config.num_labels
1369
+ self.model = Qwen2Model(config)
1370
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1371
+
1372
+ # Initialize weights and apply final processing
1373
+ self.post_init()
1374
+
1375
+ def get_input_embeddings(self):
1376
+ return self.model.embed_tokens
1377
+
1378
+ def set_input_embeddings(self, value):
1379
+ self.model.embed_tokens = value
1380
+
1381
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1382
+ def forward(
1383
+ self,
1384
+ input_ids: Optional[torch.LongTensor] = None,
1385
+ attention_mask: Optional[torch.Tensor] = None,
1386
+ position_ids: Optional[torch.LongTensor] = None,
1387
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1388
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1389
+ labels: Optional[torch.LongTensor] = None,
1390
+ use_cache: Optional[bool] = None,
1391
+ output_attentions: Optional[bool] = None,
1392
+ output_hidden_states: Optional[bool] = None,
1393
+ return_dict: Optional[bool] = None,
1394
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1395
+ r"""
1396
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1397
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1398
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1399
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1400
+ """
1401
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1402
+
1403
+ transformer_outputs = self.model(
1404
+ input_ids,
1405
+ attention_mask=attention_mask,
1406
+ position_ids=position_ids,
1407
+ past_key_values=past_key_values,
1408
+ inputs_embeds=inputs_embeds,
1409
+ use_cache=use_cache,
1410
+ output_attentions=output_attentions,
1411
+ output_hidden_states=output_hidden_states,
1412
+ return_dict=return_dict,
1413
+ )
1414
+ hidden_states = transformer_outputs[0]
1415
+ logits = self.score(hidden_states)
1416
+
1417
+ if input_ids is not None:
1418
+ batch_size = input_ids.shape[0]
1419
+ else:
1420
+ batch_size = inputs_embeds.shape[0]
1421
+
1422
+ if self.config.pad_token_id is None and batch_size != 1:
1423
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1424
+ if self.config.pad_token_id is None:
1425
+ sequence_lengths = -1
1426
+ else:
1427
+ if input_ids is not None:
1428
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1429
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1430
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1431
+ sequence_lengths = sequence_lengths.to(logits.device)
1432
+ else:
1433
+ sequence_lengths = -1
1434
+
1435
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1436
+
1437
+ loss = None
1438
+ if labels is not None:
1439
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1440
+
1441
+ if not return_dict:
1442
+ output = (pooled_logits,) + transformer_outputs[1:]
1443
+ return ((loss,) + output) if loss is not None else output
1444
+
1445
+ return SequenceClassifierOutputWithPast(
1446
+ loss=loss,
1447
+ logits=pooled_logits,
1448
+ past_key_values=transformer_outputs.past_key_values,
1449
+ hidden_states=transformer_outputs.hidden_states,
1450
+ attentions=transformer_outputs.attentions,
1451
+ )
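The last-token pooling described in the class docstring relies on an `argmax` trick to find the first pad position in each row, falling back to the final position when a row contains no padding. A small standalone check (editorial example; `pad_token_id=0` is assumed):

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],    # padded row -> last real token at index 2
                          [8, 9, 1, 2, 3]])   # unpadded row -> falls back to index 4

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]  # maps -1 to the last index
print(sequence_lengths)  # tensor([2, 4])
```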
1452
+
1453
+
1454
+ @add_start_docstrings(
1455
+ """
1456
+ The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1457
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1458
+ """,
1459
+ QWEN2_START_DOCSTRING,
1460
+ )
1461
+ class Qwen2ForTokenClassification(Qwen2PreTrainedModel):
1462
+ def __init__(self, config):
1463
+ super().__init__(config)
1464
+ self.num_labels = config.num_labels
1465
+ self.model = Qwen2Model(config)
1466
+ if getattr(config, "classifier_dropout", None) is not None:
1467
+ classifier_dropout = config.classifier_dropout
1468
+ elif getattr(config, "hidden_dropout", None) is not None:
1469
+ classifier_dropout = config.hidden_dropout
1470
+ else:
1471
+ classifier_dropout = 0.1
1472
+ self.dropout = nn.Dropout(classifier_dropout)
1473
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1474
+
1475
+ # Initialize weights and apply final processing
1476
+ self.post_init()
1477
+
1478
+ def get_input_embeddings(self):
1479
+ return self.model.embed_tokens
1480
+
1481
+ def set_input_embeddings(self, value):
1482
+ self.model.embed_tokens = value
1483
+
1484
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1485
+ @add_code_sample_docstrings(
1486
+ checkpoint=_CHECKPOINT_FOR_DOC,
1487
+ output_type=TokenClassifierOutput,
1488
+ config_class=_CONFIG_FOR_DOC,
1489
+ )
1490
+ def forward(
1491
+ self,
1492
+ input_ids: Optional[torch.LongTensor] = None,
1493
+ attention_mask: Optional[torch.Tensor] = None,
1494
+ position_ids: Optional[torch.LongTensor] = None,
1495
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1496
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1497
+ labels: Optional[torch.LongTensor] = None,
1498
+ use_cache: Optional[bool] = None,
1499
+ output_attentions: Optional[bool] = None,
1500
+ output_hidden_states: Optional[bool] = None,
1501
+ return_dict: Optional[bool] = None,
1502
+ ) -> Union[Tuple, TokenClassifierOutput]:
1503
+ r"""
1504
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1505
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1506
+ config.num_labels - 1]`. Tokens with indices set to `-100` are ignored (masked); the loss is only
1507
+ computed for tokens with labels in `[0, ..., config.num_labels - 1]`.
1508
+ """
1509
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1510
+
1511
+ outputs = self.model(
1512
+ input_ids,
1513
+ attention_mask=attention_mask,
1514
+ position_ids=position_ids,
1515
+ past_key_values=past_key_values,
1516
+ inputs_embeds=inputs_embeds,
1517
+ use_cache=use_cache,
1518
+ output_attentions=output_attentions,
1519
+ output_hidden_states=output_hidden_states,
1520
+ return_dict=return_dict,
1521
+ )
1522
+ sequence_output = outputs[0]
1523
+ sequence_output = self.dropout(sequence_output)
1524
+ logits = self.score(sequence_output)
1525
+
1526
+ loss = None
1527
+ if labels is not None:
1528
+ loss = self.loss_function(logits, labels, self.config)
1529
+
1530
+ if not return_dict:
1531
+ output = (logits,) + outputs[2:]
1532
+ return ((loss,) + output) if loss is not None else output
1533
+
1534
+ return TokenClassifierOutput(
1535
+ loss=loss,
1536
+ logits=logits,
1537
+ hidden_states=outputs.hidden_states,
1538
+ attentions=outputs.attentions,
1539
+ )
1540
+
1541
+
1542
+ @add_start_docstrings(
1543
+ """
1544
+ The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like
1545
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1546
+ """,
1547
+ QWEN2_START_DOCSTRING,
1548
+ )
1549
+ class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel):
1550
+ base_model_prefix = "transformer"
1551
+
1552
+ def __init__(self, config):
1553
+ super().__init__(config)
1554
+ self.transformer = Qwen2Model(config)
1555
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1556
+
1557
+ # Initialize weights and apply final processing
1558
+ self.post_init()
1559
+
1560
+ def get_input_embeddings(self):
1561
+ return self.transformer.embed_tokens
1562
+
1563
+ def set_input_embeddings(self, value):
1564
+ self.transformer.embed_tokens = value
1565
+
1566
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
1567
+ def forward(
1568
+ self,
1569
+ input_ids: Optional[torch.LongTensor] = None,
1570
+ attention_mask: Optional[torch.FloatTensor] = None,
1571
+ position_ids: Optional[torch.LongTensor] = None,
1572
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1573
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1574
+ start_positions: Optional[torch.LongTensor] = None,
1575
+ end_positions: Optional[torch.LongTensor] = None,
1576
+ output_attentions: Optional[bool] = None,
1577
+ output_hidden_states: Optional[bool] = None,
1578
+ return_dict: Optional[bool] = None,
1579
+ **kwargs,
1580
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1581
+ r"""
1582
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1583
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1584
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1585
+ are not taken into account for computing the loss.
1586
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1587
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1588
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1589
+ are not taken into account for computing the loss.
1590
+ """
1591
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1592
+
1593
+ outputs = self.transformer(
1594
+ input_ids,
1595
+ attention_mask=attention_mask,
1596
+ position_ids=position_ids,
1597
+ past_key_values=past_key_values,
1598
+ inputs_embeds=inputs_embeds,
1599
+ output_attentions=output_attentions,
1600
+ output_hidden_states=output_hidden_states,
1601
+ return_dict=return_dict,
1602
+ )
1603
+
1604
+ sequence_output = outputs[0]
1605
+
1606
+ logits = self.qa_outputs(sequence_output)
1607
+ start_logits, end_logits = logits.split(1, dim=-1)
1608
+ start_logits = start_logits.squeeze(-1).contiguous()
1609
+ end_logits = end_logits.squeeze(-1).contiguous()
1610
+
1611
+ loss = None
1612
+ if start_positions is not None and end_positions is not None:
1613
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1614
+
1615
+ if not return_dict:
1616
+ output = (start_logits, end_logits) + outputs[2:]
1617
+ return ((loss,) + output) if loss is not None else output
1618
+
1619
+ return QuestionAnsweringModelOutput(
1620
+ loss=loss,
1621
+ start_logits=start_logits,
1622
+ end_logits=end_logits,
1623
+ hidden_states=outputs.hidden_states,
1624
+ attentions=outputs.attentions,
1625
+ )
1626
+
1627
+
1628
+ def prepare_fa2_from_position_ids_for_mtp(position_ids, mtp_idx):
1629
+ position_ids = position_ids.flatten()
1630
+ indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
1631
+
1632
+ cu_seq_lens = torch.cat(
1633
+ (
1634
+ indices_q[position_ids == mtp_idx + 1],
1635
+ torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
1636
+ )
1637
+ )
1638
+
1639
+ max_length = position_ids.max() + 1 - 1 - mtp_idx
1640
+
1641
+ return cu_seq_lens, cu_seq_lens, max_length, max_length
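For reference, a toy invocation of the helper above (editorial example; the input mimics two packed sequences after the first `mtp_idx + 1` tokens of each have been dropped, as done in the MTP training path):

```python
import torch

# Two packed sequences whose remaining per-token positions are [1, 2, 3] and [1, 2].
position_ids = torch.tensor([[1, 2, 3, 1, 2]])

cu_q, cu_k, max_q, max_k = prepare_fa2_from_position_ids_for_mtp(position_ids, mtp_idx=0)
print(cu_q)   # tensor([0, 3, 5], dtype=torch.int32) -- varlen boundaries for flash-attn
print(max_q)  # tensor(3) -- longest remaining segment
```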
modeling_sensevoice.py ADDED
@@ -0,0 +1,1249 @@
1
+
2
+ import time
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from typing import Iterable, Optional
7
+
8
+ from funasr.register import tables
9
+ from funasr.models.ctc.ctc import CTC
10
+ from funasr.utils.datadir_writer import DatadirWriter
11
+ from funasr.models.paraformer.search import Hypothesis
12
+ from funasr.train_utils.device_funcs import force_gatherable
13
+ from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
14
+ from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
15
+ from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
16
+ # from utils.ctc_alignment import ctc_forced_align
17
+
18
+ def ctc_forced_align(
19
+ log_probs: torch.Tensor,
20
+ targets: torch.Tensor,
21
+ input_lengths: torch.Tensor,
22
+ target_lengths: torch.Tensor,
23
+ blank: int = 0,
24
+ ignore_id: int = -1,
25
+ ) -> torch.Tensor:
26
+ """Align a CTC label sequence to an emission.
27
+
28
+ Args:
29
+ log_probs (Tensor): log probability of CTC emission output.
30
+ Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
31
+ `C` is the number of characters in alphabet including blank.
32
+ targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
33
+ where `L` is the target length.
34
+ input_lengths (Tensor):
35
+ Lengths of the inputs (each value must be <= `T`). 1-D Tensor of shape `(B,)`.
36
+ target_lengths (Tensor):
37
+ Lengths of the targets. 1-D Tensor of shape `(B,)`.
38
+ blank (int, optional): The index of the blank symbol in the CTC emission. (Default: 0)
40
+ ignore_id (int, optional): Target values equal to this index are treated as blank. (Default: -1)
40
+ """
41
+ targets[targets == ignore_id] = blank
42
+
43
+ batch_size, input_time_size, _ = log_probs.size()
44
+ bsz_indices = torch.arange(batch_size, device=input_lengths.device)
45
+
46
+ _t_a_r_g_e_t_s_ = torch.cat(
47
+ (
48
+ torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1),
49
+ torch.full_like(targets[:, :1], blank),
50
+ ),
51
+ dim=-1,
52
+ )
53
+ diff_labels = torch.cat(
54
+ (
55
+ torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1),
56
+ _t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2],
57
+ ),
58
+ dim=1,
59
+ )
60
+
61
+ neg_inf = torch.tensor(float("-inf"), device=log_probs.device, dtype=log_probs.dtype)
62
+ padding_num = 2
63
+ padded_t = padding_num + _t_a_r_g_e_t_s_.size(-1)
64
+ best_score = torch.full((batch_size, padded_t), neg_inf, device=log_probs.device, dtype=log_probs.dtype)
65
+ best_score[:, padding_num + 0] = log_probs[:, 0, blank]
66
+ best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, _t_a_r_g_e_t_s_[:, 1]]
67
+
68
+ backpointers = torch.zeros((batch_size, input_time_size, padded_t), device=log_probs.device, dtype=targets.dtype)
69
+
70
+ for t in range(1, input_time_size):
71
+ prev = torch.stack(
72
+ (best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf))
73
+ )
74
+ prev_max_value, prev_max_idx = prev.max(dim=0)
75
+ best_score[:, padding_num:] = log_probs[:, t].gather(-1, _t_a_r_g_e_t_s_) + prev_max_value
76
+ backpointers[:, t, padding_num:] = prev_max_idx
77
+
78
+ l1l2 = best_score.gather(
79
+ -1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1)
80
+ )
81
+
82
+ path = torch.zeros((batch_size, input_time_size), device=best_score.device, dtype=torch.long)
83
+ path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + l1l2.argmax(dim=-1)
84
+
85
+ for t in range(input_time_size - 1, 0, -1):
86
+ target_indices = path[:, t]
87
+ prev_max_idx = backpointers[bsz_indices, t, target_indices]
88
+ path[:, t - 1] += target_indices - prev_max_idx
89
+
90
+ alignments = _t_a_r_g_e_t_s_.gather(dim=-1, index=(path - padding_num).clamp(min=0))
91
+ return alignments
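A minimal usage sketch for the forced-alignment helper above (editorial; the shapes and target values are illustrative):

```python
import torch

B, T, C = 1, 6, 5                      # batch, frames, vocab size (blank at index 0)
log_probs = torch.randn(B, T, C).log_softmax(dim=-1)
targets = torch.tensor([[2, 4]])       # a 2-token target sequence
input_lengths = torch.tensor([T])
target_lengths = torch.tensor([2])

alignment = ctc_forced_align(log_probs, targets, input_lengths, target_lengths, blank=0)
print(alignment.shape)  # torch.Size([1, 6]) -- one blank-or-target label per frame
```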
92
+
93
+ class SinusoidalPositionEncoder(torch.nn.Module):
94
+ """ """
95
+
96
+ def __init__(self, d_model=80, dropout_rate=0.1):
97
+ super().__init__()
98
+
99
+ def encode(
100
+ self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
101
+ ):
102
+ batch_size = positions.size(0)
103
+ positions = positions.type(dtype)
104
+ device = positions.device
105
+ log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
106
+ depth / 2 - 1
107
+ )
108
+ inv_timescales = torch.exp(
109
+ torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
110
+ )
111
+ inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
112
+ scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
113
+ inv_timescales, [1, 1, -1]
114
+ )
115
+ encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
116
+ return encoding.type(dtype)
117
+
118
+ def forward(self, x):
119
+ batch_size, timesteps, input_dim = x.size()
120
+ positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
121
+ position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
122
+
123
+ return x + position_encoding
124
+
125
+
126
+ class PositionwiseFeedForward(torch.nn.Module):
127
+ """Positionwise feed forward layer.
128
+
129
+ Args:
130
+ idim (int): Input dimension.
131
+ hidden_units (int): The number of hidden units.
132
+ dropout_rate (float): Dropout rate.
133
+
134
+ """
135
+
136
+ def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
137
+ """Construct an PositionwiseFeedForward object."""
138
+ super(PositionwiseFeedForward, self).__init__()
139
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
140
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
141
+ self.dropout = torch.nn.Dropout(dropout_rate)
142
+ self.activation = activation
143
+
144
+ def forward(self, x):
145
+ """Forward function."""
146
+ return self.w_2(self.dropout(self.activation(self.w_1(x))))
147
+
148
+
149
+ class MultiHeadedAttentionSANM(nn.Module):
150
+ """Multi-Head Attention layer.
151
+
152
+ Args:
153
+ n_head (int): The number of heads.
154
+ n_feat (int): The number of features.
155
+ dropout_rate (float): Dropout rate.
156
+
157
+ """
158
+
159
+ def __init__(
160
+ self,
161
+ n_head,
162
+ in_feat,
163
+ n_feat,
164
+ dropout_rate,
165
+ kernel_size,
166
+ sanm_shfit=0,
167
+ lora_list=None,
168
+ lora_rank=8,
169
+ lora_alpha=16,
170
+ lora_dropout=0.1,
171
+ ):
172
+ """Construct an MultiHeadedAttention object."""
173
+ super().__init__()
174
+ assert n_feat % n_head == 0
175
+ # We assume d_v always equals d_k
176
+ self.d_k = n_feat // n_head
177
+ self.h = n_head
178
+ # self.linear_q = nn.Linear(n_feat, n_feat)
179
+ # self.linear_k = nn.Linear(n_feat, n_feat)
180
+ # self.linear_v = nn.Linear(n_feat, n_feat)
181
+
182
+ self.linear_out = nn.Linear(n_feat, n_feat)
183
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
184
+ self.attn = None
185
+ self.dropout = nn.Dropout(p=dropout_rate)
186
+
187
+ self.fsmn_block = nn.Conv1d(
188
+ n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
189
+ )
190
+ # padding
191
+ left_padding = (kernel_size - 1) // 2
192
+ if sanm_shfit > 0:
193
+ left_padding = left_padding + sanm_shfit
194
+ right_padding = kernel_size - 1 - left_padding
195
+ self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
196
+
197
+ def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
198
+ b, t, d = inputs.size()
199
+ if mask is not None:
200
+ mask = torch.reshape(mask, (b, -1, 1))
201
+ if mask_shfit_chunk is not None:
202
+ mask = mask * mask_shfit_chunk
203
+ inputs = inputs * mask
204
+
205
+ x = inputs.transpose(1, 2)
206
+ x = self.pad_fn(x)
207
+ x = self.fsmn_block(x)
208
+ x = x.transpose(1, 2)
209
+ x += inputs
210
+ x = self.dropout(x)
211
+ if mask is not None:
212
+ x = x * mask
213
+ return x
214
+
215
+ def forward_qkv(self, x):
216
+ """Transform query, key and value.
217
+
218
+ Args:
219
+ query (torch.Tensor): Query tensor (#batch, time1, size).
220
+ key (torch.Tensor): Key tensor (#batch, time2, size).
221
+ value (torch.Tensor): Value tensor (#batch, time2, size).
222
+
223
+ Returns:
224
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
225
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
226
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
227
+
228
+ """
229
+ b, t, d = x.size()
230
+ q_k_v = self.linear_q_k_v(x)
231
+ q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
232
+ q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
233
+ 1, 2
234
+ ) # (batch, head, time1, d_k)
235
+ k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
236
+ 1, 2
237
+ ) # (batch, head, time2, d_k)
238
+ v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
239
+ 1, 2
240
+ ) # (batch, head, time2, d_k)
241
+
242
+ return q_h, k_h, v_h, v
243
+
244
+ def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
245
+ """Compute attention context vector.
246
+
247
+ Args:
248
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
249
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
250
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
251
+
252
+ Returns:
253
+ torch.Tensor: Transformed value (#batch, time1, d_model)
254
+ weighted by the attention score (#batch, time1, time2).
255
+
256
+ """
257
+ n_batch = value.size(0)
258
+ if mask is not None:
259
+ if mask_att_chunk_encoder is not None:
260
+ mask = mask * mask_att_chunk_encoder
261
+
262
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
263
+
264
+ min_value = -float(
265
+ "inf"
266
+ ) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
267
+ scores = scores.masked_fill(mask, min_value)
268
+ attn = torch.softmax(scores, dim=-1).masked_fill(
269
+ mask, 0.0
270
+ ) # (batch, head, time1, time2)
271
+ else:
272
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
273
+
274
+ p_attn = self.dropout(attn)
275
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
276
+ x = (
277
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
278
+ ) # (batch, time1, d_model)
279
+
280
+ return self.linear_out(x) # (batch, time1, d_model)
281
+
282
+ def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
283
+ """Compute scaled dot product attention.
284
+
285
+ Args:
286
+ x (torch.Tensor): Input tensor (#batch, time, size); query, key and value
287
+ are all derived from it internally (see forward_qkv).
289
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
290
+ (#batch, time1, time2).
291
+
292
+ Returns:
293
+ torch.Tensor: Output tensor (#batch, time1, d_model).
294
+
295
+ """
296
+ q_h, k_h, v_h, v = self.forward_qkv(x)
297
+ fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
298
+ q_h = q_h * self.d_k ** (-0.5)
299
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
300
+ att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
301
+ return att_outs + fsmn_memory
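+ # Shape note (added for clarity, not in the original): att_outs and fsmn_memory are both
+ # (batch, time, n_feat), so the output fuses global self-attention with the local memory
+ # produced by the depthwise Conv1d FSMN block.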
302
+
303
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
304
+ """Compute scaled dot product attention.
305
+
306
+ Args:
307
+ x (torch.Tensor): Input tensor (#batch, time, size).
308
+ cache (dict): Cached "k"/"v" tensors carried over from previous chunks.
309
+ chunk_size (list): Streaming chunk configuration used to update the cache.
310
+ look_back (int): Number of previous chunks kept in the cache (-1 keeps all).
312
+
313
+ Returns:
314
+ torch.Tensor: Output tensor (#batch, time1, d_model).
315
+
316
+ """
317
+ q_h, k_h, v_h, v = self.forward_qkv(x)
318
+ if chunk_size is not None and (look_back > 0 or look_back == -1):
319
+ if cache is not None:
320
+ k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
321
+ v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
322
+ k_h = torch.cat((cache["k"], k_h), dim=2)
323
+ v_h = torch.cat((cache["v"], v_h), dim=2)
324
+
325
+ cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
326
+ cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
327
+ if look_back != -1:
328
+ cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
329
+ cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
330
+ else:
331
+ cache_tmp = {
332
+ "k": k_h[:, :, : -(chunk_size[2]), :],
333
+ "v": v_h[:, :, : -(chunk_size[2]), :],
334
+ }
335
+ cache = cache_tmp
336
+ fsmn_memory = self.forward_fsmn(v, None)
337
+ q_h = q_h * self.d_k ** (-0.5)
338
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
339
+ att_outs = self.forward_attention(v_h, scores, None)
340
+ return att_outs + fsmn_memory, cache
341
+
342
+
343
+ class LayerNorm(nn.LayerNorm):
344
+ def __init__(self, *args, **kwargs):
345
+ super().__init__(*args, **kwargs)
346
+
347
+ def forward(self, input):
348
+ output = F.layer_norm(
349
+ input.float(),
350
+ self.normalized_shape,
351
+ self.weight.float() if self.weight is not None else None,
352
+ self.bias.float() if self.bias is not None else None,
353
+ self.eps,
354
+ )
355
+ return output.type_as(input)
356
+
357
+
358
+ def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
359
+ if maxlen is None:
360
+ maxlen = lengths.max()
361
+ row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
362
+ matrix = torch.unsqueeze(lengths, dim=-1)
363
+ mask = row_vector < matrix
364
+ mask = mask.detach()
365
+
366
+ return mask.to(dtype).to(device) if device is not None else mask.to(dtype)
367
+ # return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
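+ # Illustrative example (not in the original): sequence_mask(torch.tensor([2, 4])) returns
+ # [[1., 1., 0., 0.],
+ #  [1., 1., 1., 1.]] with maxlen inferred as 4.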
368
+
369
+
370
+ class EncoderLayerSANM(nn.Module):
371
+ def __init__(
372
+ self,
373
+ in_size,
374
+ size,
375
+ self_attn,
376
+ feed_forward,
377
+ dropout_rate,
378
+ normalize_before=True,
379
+ concat_after=False,
380
+ stochastic_depth_rate=0.0,
381
+ ):
382
+ """Construct an EncoderLayer object."""
383
+ super(EncoderLayerSANM, self).__init__()
384
+ self.self_attn = self_attn
385
+ self.feed_forward = feed_forward
386
+ self.norm1 = LayerNorm(in_size)
387
+ self.norm2 = LayerNorm(size)
388
+ self.dropout = nn.Dropout(dropout_rate)
389
+ self.in_size = in_size
390
+ self.size = size
391
+ self.normalize_before = normalize_before
392
+ self.concat_after = concat_after
393
+ if self.concat_after:
394
+ self.concat_linear = nn.Linear(size + size, size)
395
+ self.stochastic_depth_rate = stochastic_depth_rate
396
+ self.dropout_rate = dropout_rate
397
+
398
+ def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
399
+ """Compute encoded features.
400
+
401
+ Args:
402
+ x_input (torch.Tensor): Input tensor (#batch, time, size).
403
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
404
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
405
+
406
+ Returns:
407
+ torch.Tensor: Output tensor (#batch, time, size).
408
+ torch.Tensor: Mask tensor (#batch, time).
409
+
410
+ """
411
+ skip_layer = False
412
+ # with stochastic depth, residual connection `x + f(x)` becomes
413
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
414
+ stoch_layer_coeff = 1.0
415
+ if self.training and self.stochastic_depth_rate > 0:
416
+ skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
417
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
418
+
419
+ if skip_layer:
420
+ if cache is not None:
421
+ x = torch.cat([cache, x], dim=1)
422
+ return x, mask
423
+
424
+ residual = x
425
+ if self.normalize_before:
426
+ x = self.norm1(x)
427
+
428
+ if self.concat_after:
429
+ x_concat = torch.cat(
430
+ (
431
+ x,
432
+ self.self_attn(
433
+ x,
434
+ mask,
435
+ mask_shfit_chunk=mask_shfit_chunk,
436
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
437
+ ),
438
+ ),
439
+ dim=-1,
440
+ )
441
+ if self.in_size == self.size:
442
+ x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
443
+ else:
444
+ x = stoch_layer_coeff * self.concat_linear(x_concat)
445
+ else:
446
+ if self.in_size == self.size:
447
+ x = residual + stoch_layer_coeff * self.dropout(
448
+ self.self_attn(
449
+ x,
450
+ mask,
451
+ mask_shfit_chunk=mask_shfit_chunk,
452
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
453
+ )
454
+ )
455
+ else:
456
+ x = stoch_layer_coeff * self.dropout(
457
+ self.self_attn(
458
+ x,
459
+ mask,
460
+ mask_shfit_chunk=mask_shfit_chunk,
461
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
462
+ )
463
+ )
464
+ if not self.normalize_before:
465
+ x = self.norm1(x)
466
+
467
+ residual = x
468
+ if self.normalize_before:
469
+ x = self.norm2(x)
470
+ x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
471
+ if not self.normalize_before:
472
+ x = self.norm2(x)
473
+
474
+ return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
475
+
476
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
477
+ """Compute encoded features.
478
+
479
+ Args:
480
+ x_input (torch.Tensor): Input tensor (#batch, time, size).
481
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
482
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
483
+
484
+ Returns:
485
+ torch.Tensor: Output tensor (#batch, time, size).
486
+ torch.Tensor: Mask tensor (#batch, time).
487
+
488
+ """
489
+
490
+ residual = x
491
+ if self.normalize_before:
492
+ x = self.norm1(x)
493
+
494
+ if self.in_size == self.size:
495
+ attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
496
+ x = residual + attn
497
+ else:
498
+ x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
499
+
500
+ if not self.normalize_before:
501
+ x = self.norm1(x)
502
+
503
+ residual = x
504
+ if self.normalize_before:
505
+ x = self.norm2(x)
506
+ x = residual + self.feed_forward(x)
507
+ if not self.normalize_before:
508
+ x = self.norm2(x)
509
+
510
+ return x, cache
511
+
512
+
513
+ @tables.register("encoder_classes", "SenseVoiceEncoderSmall")
514
+ class SenseVoiceEncoderSmall(nn.Module):
515
+ """
516
+ Author: Speech Lab of DAMO Academy, Alibaba Group
517
+ SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
518
+ https://arxiv.org/abs/2006.01713
519
+ """
520
+
521
+ def __init__(
522
+ self,
523
+ input_size: int,
524
+ output_size: int = 256,
525
+ attention_heads: int = 4,
526
+ linear_units: int = 2048,
527
+ num_blocks: int = 6,
528
+ tp_blocks: int = 0,
529
+ dropout_rate: float = 0.1,
530
+ positional_dropout_rate: float = 0.1,
531
+ attention_dropout_rate: float = 0.0,
532
+ stochastic_depth_rate: float = 0.0,
533
+ input_layer: Optional[str] = "conv2d",
534
+ pos_enc_class=SinusoidalPositionEncoder,
535
+ normalize_before: bool = True,
536
+ concat_after: bool = False,
537
+ positionwise_layer_type: str = "linear",
538
+ positionwise_conv_kernel_size: int = 1,
539
+ padding_idx: int = -1,
540
+ kernel_size: int = 11,
541
+ sanm_shfit: int = 0,
542
+ selfattention_layer_type: str = "sanm",
543
+ **kwargs,
544
+ ):
545
+ super().__init__()
546
+ self._output_size = output_size
547
+
548
+ self.embed = SinusoidalPositionEncoder()
549
+
550
+ self.normalize_before = normalize_before
551
+
552
+ positionwise_layer = PositionwiseFeedForward
553
+ positionwise_layer_args = (
554
+ output_size,
555
+ linear_units,
556
+ dropout_rate,
557
+ )
558
+
559
+ encoder_selfattn_layer = MultiHeadedAttentionSANM
560
+ encoder_selfattn_layer_args0 = (
561
+ attention_heads,
562
+ input_size,
563
+ output_size,
564
+ attention_dropout_rate,
565
+ kernel_size,
566
+ sanm_shfit,
567
+ )
568
+ encoder_selfattn_layer_args = (
569
+ attention_heads,
570
+ output_size,
571
+ output_size,
572
+ attention_dropout_rate,
573
+ kernel_size,
574
+ sanm_shfit,
575
+ )
576
+
577
+ self.encoders0 = nn.ModuleList(
578
+ [
579
+ EncoderLayerSANM(
580
+ input_size,
581
+ output_size,
582
+ encoder_selfattn_layer(*encoder_selfattn_layer_args0),
583
+ positionwise_layer(*positionwise_layer_args),
584
+ dropout_rate,
585
+ )
586
+ for i in range(1)
587
+ ]
588
+ )
589
+ self.encoders = nn.ModuleList(
590
+ [
591
+ EncoderLayerSANM(
592
+ output_size,
593
+ output_size,
594
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
595
+ positionwise_layer(*positionwise_layer_args),
596
+ dropout_rate,
597
+ )
598
+ for i in range(num_blocks - 1)
599
+ ]
600
+ )
601
+
602
+ self.tp_encoders = nn.ModuleList(
603
+ [
604
+ EncoderLayerSANM(
605
+ output_size,
606
+ output_size,
607
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
608
+ positionwise_layer(*positionwise_layer_args),
609
+ dropout_rate,
610
+ )
611
+ for i in range(tp_blocks)
612
+ ]
613
+ )
614
+
615
+ self.after_norm = LayerNorm(output_size)
616
+
617
+ self.tp_norm = LayerNorm(output_size)
618
+
619
+ def output_size(self) -> int:
620
+ return self._output_size
621
+
622
+ def forward(
623
+ self,
624
+ xs_pad: torch.Tensor,
625
+ ilens: torch.Tensor,
626
+ ):
627
+ """Embed positions in tensor."""
628
+ masks = sequence_mask(ilens, dtype=torch.bfloat16, device=ilens.device)[:, None, :]
629
+ # print(f"{masks=}")
630
+ # print(f"{ilens=}")
631
+ # print(f"{(masks>0.5).squeeze(1).sum(1).int()=}")
632
+
633
+ xs_pad *= self.output_size() ** 0.5
634
+
635
+ xs_pad = self.embed(xs_pad)
636
+
637
+ # forward encoder1
638
+ for layer_idx, encoder_layer in enumerate(self.encoders0):
639
+ encoder_outs = encoder_layer(xs_pad, masks)
640
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
641
+
642
+ for layer_idx, encoder_layer in enumerate(self.encoders):
643
+ encoder_outs = encoder_layer(xs_pad, masks)
644
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
645
+
646
+ xs_pad = self.after_norm(xs_pad)
647
+
648
+ # forward encoder2
649
+ # olens = masks.squeeze(1).sum(1).int()
650
+ olens = (masks > 0.5).squeeze(1).sum(1).int()
651
+
652
+ for layer_idx, encoder_layer in enumerate(self.tp_encoders):
653
+ encoder_outs = encoder_layer(xs_pad, masks)
654
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
655
+
656
+ xs_pad = self.tp_norm(xs_pad)
657
+ return xs_pad, olens
658
+
659
+
660
+ @tables.register("model_classes", "SenseVoiceSmall")
661
+ class SenseVoiceSmall(nn.Module):
662
+ """CTC-attention hybrid Encoder-Decoder model"""
663
+
664
+ def __init__(
665
+ self,
666
+ specaug: str = None,
667
+ specaug_conf: dict = None,
668
+ normalize: str = None,
669
+ normalize_conf: dict = None,
670
+ encoder: str = None,
671
+ encoder_conf: dict = None,
672
+ ctc_conf: dict = None,
673
+ input_size: int = 80,
674
+ vocab_size: int = -1,
675
+ ignore_id: int = -1,
676
+ blank_id: int = 0,
677
+ sos: int = 1,
678
+ eos: int = 2,
679
+ length_normalized_loss: bool = False,
680
+ **kwargs,
681
+ ):
682
+
683
+ super().__init__()
684
+
685
+ if specaug is not None:
686
+ specaug_class = tables.specaug_classes.get(specaug)
687
+ specaug = specaug_class(**specaug_conf)
688
+ if normalize is not None:
689
+ normalize_class = tables.normalize_classes.get(normalize)
690
+ normalize = normalize_class(**normalize_conf)
691
+ encoder_class = tables.encoder_classes.get(encoder)
692
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
693
+ encoder_output_size = encoder.output_size()
694
+
695
+ if ctc_conf is None:
696
+ ctc_conf = {}
697
+ ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
698
+
699
+ self.blank_id = blank_id
700
+ self.sos = sos if sos is not None else vocab_size - 1
701
+ self.eos = eos if eos is not None else vocab_size - 1
702
+ self.vocab_size = vocab_size
703
+ self.ignore_id = ignore_id
704
+ self.specaug = specaug
705
+ self.normalize = normalize
706
+ self.encoder = encoder
707
+ self.error_calculator = None
708
+
709
+ self.ctc = ctc
710
+
711
+ self.length_normalized_loss = length_normalized_loss
712
+ self.encoder_output_size = encoder_output_size
713
+
714
+ self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
715
+ self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
716
+ self.textnorm_dict = {"withitn": 14, "woitn": 15}
717
+ self.textnorm_int_dict = {25016: 14, 25017: 15}
718
+ self.embed = torch.nn.Embedding(7 + len(self.lid_dict) + len(self.textnorm_dict), input_size)
719
+ self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
720
+
721
+ self.criterion_att = LabelSmoothingLoss(
722
+ size=self.vocab_size,
723
+ padding_idx=self.ignore_id,
724
+ smoothing=kwargs.get("lsm_weight", 0.0),
725
+ normalize_length=self.length_normalized_loss,
726
+ )
727
+
728
+ @staticmethod
729
+ def from_pretrained(model: str = None, **kwargs):
730
+ from funasr import AutoModel
731
+ model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
732
+
733
+ return model, kwargs
734
+
735
+ def forward(
736
+ self,
737
+ speech: torch.Tensor,
738
+ speech_lengths: torch.Tensor,
739
+ text: torch.Tensor,
740
+ text_lengths: torch.Tensor,
741
+ **kwargs,
742
+ ):
743
+ """Encoder + Decoder + Calc loss
744
+ Args:
745
+ speech: (Batch, Length, ...)
746
+ speech_lengths: (Batch, )
747
+ text: (Batch, Length)
748
+ text_lengths: (Batch,)
749
+ """
750
+ # import pdb;
751
+ # pdb.set_trace()
752
+ if len(text_lengths.size()) > 1:
753
+ text_lengths = text_lengths[:, 0]
754
+ if len(speech_lengths.size()) > 1:
755
+ speech_lengths = speech_lengths[:, 0]
756
+
757
+ batch_size = speech.shape[0]
758
+
759
+ # 1. Encoder
760
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)
761
+
762
+ loss_ctc, cer_ctc = None, None
763
+ loss_rich, acc_rich = None, None
764
+ stats = dict()
765
+
766
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
767
+ encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4
768
+ )
769
+
770
+ loss_rich, acc_rich = self._calc_rich_ce_loss(
771
+ encoder_out[:, :4, :], text[:, :4]
772
+ )
773
+
774
+ loss = loss_ctc + loss_rich
775
+ # Collect total loss stats
776
+ stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
777
+ stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
778
+ stats["loss"] = torch.clone(loss.detach()) if loss is not None else None
779
+ stats["acc_rich"] = acc_rich
780
+
781
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
782
+ if self.length_normalized_loss:
783
+ batch_size = int((text_lengths + 1).sum())
784
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
785
+ return loss, stats, weight
786
+
787
+ def encode(
788
+ self,
789
+ speech: torch.Tensor,
790
+ speech_lengths: torch.Tensor,
791
+ text: torch.Tensor,
792
+ **kwargs,
793
+ ):
794
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
795
+ Args:
796
+ speech: (Batch, Length, ...)
797
+ speech_lengths: (Batch, )
798
+ ind: int
799
+ """
800
+
801
+ # Data augmentation
802
+ if self.specaug is not None and self.training:
803
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
804
+
805
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
806
+ if self.normalize is not None:
807
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
808
+
809
+
810
+ lids = torch.LongTensor([[self.lid_int_dict[int(lid)] if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict else 0 ] for lid in text[:, 0]]).to(speech.device)
811
+ language_query = self.embed(lids)
812
+
813
+ styles = torch.LongTensor([[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]).to(speech.device)
814
+ style_query = self.embed(styles)
815
+ speech = torch.cat((style_query, speech), dim=1)
816
+ speech_lengths += 1
817
+
818
+ event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
819
+ input_query = torch.cat((language_query, event_emo_query), dim=1)
820
+ speech = torch.cat((input_query, speech), dim=1)
821
+ speech_lengths += 3
822
+
823
+ encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
824
+
825
+ return encoder_out, encoder_out_lens
826
+
827
+ def _calc_ctc_loss(
828
+ self,
829
+ encoder_out: torch.Tensor,
830
+ encoder_out_lens: torch.Tensor,
831
+ ys_pad: torch.Tensor,
832
+ ys_pad_lens: torch.Tensor,
833
+ ):
834
+ # Calc CTC loss
835
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
836
+
837
+ # Calc CER using CTC
838
+ cer_ctc = None
839
+ if not self.training and self.error_calculator is not None:
840
+ ys_hat = self.ctc.argmax(encoder_out).data
841
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
842
+ return loss_ctc, cer_ctc
843
+
844
+ def _calc_rich_ce_loss(
845
+ self,
846
+ encoder_out: torch.Tensor,
847
+ ys_pad: torch.Tensor,
848
+ ):
849
+ decoder_out = self.ctc.ctc_lo(encoder_out)
850
+ # 2. Compute attention loss
851
+ loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous())
852
+ acc_rich = th_accuracy(
853
+ decoder_out.view(-1, self.vocab_size),
854
+ ys_pad.contiguous(),
855
+ ignore_label=self.ignore_id,
856
+ )
857
+
858
+ return loss_rich, acc_rich
859
+
860
+
861
+ def inference(
862
+ self,
863
+ data_in,
864
+ data_lengths=None,
865
+ key: list = ["wav_file_tmp_name"],
866
+ tokenizer=None,
867
+ frontend=None,
868
+ **kwargs,
869
+ ):
870
+
871
+
872
+ meta_data = {}
873
+ if (
874
+ isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
875
+ ): # fbank
876
+ speech, speech_lengths = data_in, data_lengths
877
+ if len(speech.shape) < 3:
878
+ speech = speech[None, :, :]
879
+ if speech_lengths is None:
880
+ speech_lengths = speech.shape[1]
881
+ else:
882
+ # extract fbank feats
883
+ time1 = time.perf_counter()
884
+ audio_sample_list = load_audio_text_image_video(
885
+ data_in,
886
+ fs=frontend.fs,
887
+ audio_fs=kwargs.get("fs", 16000),
888
+ data_type=kwargs.get("data_type", "sound"),
889
+ tokenizer=tokenizer,
890
+ )
891
+ time2 = time.perf_counter()
892
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
893
+ speech, speech_lengths = extract_fbank(
894
+ audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
895
+ )
896
+ time3 = time.perf_counter()
897
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
898
+ meta_data["batch_data_time"] = (
899
+ speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
900
+ )
901
+
902
+ speech = speech.to(device=kwargs["device"])
903
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
904
+
905
+ language = kwargs.get("language", "auto")
906
+ language_query = self.embed(
907
+ torch.LongTensor(
908
+ [[self.lid_dict[language] if language in self.lid_dict else 0]]
909
+ ).to(speech.device)
910
+ ).repeat(speech.size(0), 1, 1)
911
+
912
+ use_itn = kwargs.get("use_itn", False)
913
+ output_timestamp = kwargs.get("output_timestamp", False)
914
+
915
+ textnorm = kwargs.get("text_norm", None)
916
+ if textnorm is None:
917
+ textnorm = "withitn" if use_itn else "woitn"
918
+ textnorm_query = self.embed(
919
+ torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
920
+ ).repeat(speech.size(0), 1, 1)
921
+ speech = torch.cat((textnorm_query, speech), dim=1)
922
+ speech_lengths += 1
923
+
924
+ event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
925
+ speech.size(0), 1, 1
926
+ )
927
+ input_query = torch.cat((language_query, event_emo_query), dim=1)
928
+ speech = torch.cat((input_query, speech), dim=1)
929
+ speech_lengths += 3
930
+
931
+ # Encoder
932
+ encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
933
+ if isinstance(encoder_out, tuple):
934
+ encoder_out = encoder_out[0]
935
+
936
+ # c. Passed the encoder result and the beam search
937
+ ctc_logits = self.ctc.log_softmax(encoder_out)
938
+ if kwargs.get("ban_emo_unk", False):
939
+ ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf")
940
+
941
+ results = []
942
+ b, n, d = encoder_out.size()
943
+ if isinstance(key[0], (list, tuple)):
944
+ key = key[0]
945
+ if len(key) < b:
946
+ key = key * b
947
+ for i in range(b):
948
+ x = ctc_logits[i, : encoder_out_lens[i].item(), :]
949
+ yseq = x.argmax(dim=-1)
950
+ yseq = torch.unique_consecutive(yseq, dim=-1)
951
+
952
+ ibest_writer = None
953
+ if kwargs.get("output_dir") is not None:
954
+ if not hasattr(self, "writer"):
955
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
956
+ ibest_writer = self.writer["1best_recog"]
957
+
958
+ mask = yseq != self.blank_id
959
+ token_int = yseq[mask].tolist()
960
+
961
+ # Change integer-ids to tokens
962
+ text = tokenizer.decode(token_int)
963
+ if ibest_writer is not None:
964
+ ibest_writer["text"][key[i]] = text
965
+
966
+ if output_timestamp:
967
+ from itertools import groupby
968
+ timestamp = []
969
+ tokens = tokenizer.text2tokens(text)[4:]
970
+
971
+ logits_speech = self.ctc.softmax(encoder_out)[i, 4:encoder_out_lens[i].item(), :]
972
+
973
+ pred = logits_speech.argmax(-1).cpu()
974
+ logits_speech[pred==self.blank_id, self.blank_id] = 0
975
+
976
+ align = ctc_forced_align(
977
+ logits_speech.unsqueeze(0).float(),
978
+ torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device),
979
+ (encoder_out_lens-4).long(),
980
+ torch.tensor(len(token_int)-4).unsqueeze(0).long().to(logits_speech.device),
981
+ ignore_id=self.ignore_id,
982
+ )
983
+
984
+ pred = groupby(align[0, :encoder_out_lens[0]])
985
+ _start = 0
986
+ token_id = 0
987
+ ts_max = encoder_out_lens[i] - 4
988
+ for pred_token, pred_frame in pred:
989
+ _end = _start + len(list(pred_frame))
990
+ if pred_token != 0:
991
+ ts_left = max((_start*60-30)/1000, 0)
992
+ ts_right = min((_end*60-30)/1000, (ts_max*60-30)/1000)
993
+ timestamp.append([tokens[token_id], ts_left, ts_right])
994
+ token_id += 1
995
+ _start = _end
996
+
997
+ result_i = {"key": key[i], "text": text, "timestamp": timestamp}
998
+ results.append(result_i)
999
+ else:
1000
+ result_i = {"key": key[i], "text": text}
1001
+ results.append(result_i)
1002
+ return results, meta_data
1003
+
1004
+
1005
+ def inference_encode(
1006
+ self,
1007
+ data_in,
1008
+ data_lengths=None,
1009
+ key: list = ["wav_file_tmp_name"],
1010
+ **kwargs,
1011
+ ):
1012
+
1013
+ # fbank
1014
+ speech, speech_lengths = data_in, data_lengths
1015
+ if len(speech.shape) < 3:
1016
+ speech = speech[None, :, :]
1017
+ if speech_lengths is None:
1018
+ speech_lengths = speech.shape[1]
1019
+
1020
+ speech = speech.to(device=kwargs["device"])
1021
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
1022
+
1023
+ language = kwargs.get("language", "auto")
1024
+ language_query = self.embed(
1025
+ torch.LongTensor(
1026
+ [[self.lid_dict[language] if language in self.lid_dict else 0]]
1027
+ ).to(speech.device)
1028
+ ).repeat(speech.size(0), 1, 1)
1029
+
1030
+ use_itn = kwargs.get("use_itn", False)
1031
+ output_timestamp = kwargs.get("output_timestamp", False)
1032
+
1033
+ textnorm = kwargs.get("text_norm", None)
1034
+ if textnorm is None:
1035
+ textnorm = "withitn" if use_itn else "woitn"
1036
+ textnorm_query = self.embed(
1037
+ torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
1038
+ ).repeat(speech.size(0), 1, 1)
1039
+ speech = torch.cat((textnorm_query, speech), dim=1)
1040
+ speech_lengths += 1
1041
+
1042
+ event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
1043
+ speech.size(0), 1, 1
1044
+ )
1045
+ input_query = torch.cat((language_query, event_emo_query), dim=1)
1046
+ speech = torch.cat((input_query, speech), dim=1)
1047
+ speech_lengths += 3
1048
+
1049
+ # Encoder
1050
+ encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
1051
+ if isinstance(encoder_out, tuple):
1052
+ encoder_out = encoder_out[0]
1053
+
1054
+ return encoder_out, encoder_out_lens
1055
+
1056
+ def export(self, **kwargs):
1057
+ from export_meta import export_rebuild_model
1058
+
1059
+ if "max_seq_len" not in kwargs:
1060
+ kwargs["max_seq_len"] = 512
1061
+ models = export_rebuild_model(model=self, **kwargs)
1062
+ return models
1063
+
1064
+
1065
+ class AudioEncoder(nn.Module):
1066
+
1067
+ def __init__(
1068
+ self,
1069
+ config,
1070
+ ):
1071
+ super().__init__()
1072
+
1073
+ # TODO
1074
+ # model_dir = "/data/models/FunAudioLLM/SenseVoiceSmall/"
1075
+
1076
+ if "_name_or_path" in config:
1077
+ model_dir = config._name_or_path
1078
+ else:
1079
+ import os
1080
+ model_file = os.path.abspath(__file__)
1081
+ model_dir = os.path.dirname(model_file)
1082
+
1083
+ # self.model, self.kwargs = SenseVoiceSmall.from_pretrained(model_dir, device="cpu")
1084
+ self.model, self.kwargs = self.build_model(model=model_dir, trust_remote_code=False,)
1085
+
1086
+
1087
+ def forward(
1088
+ self,
1089
+ audios,
1090
+ ):
1091
+
1092
+ from torch.nn.utils.rnn import pad_sequence
1093
+ feats_pad = pad_sequence(audios, batch_first=True, padding_value=0.0)
1094
+ # feats_lens = torch.as_tensor([len(x) + 4 for x in audios])
1095
+ feats_lens = torch.as_tensor([len(x) for x in audios])
1096
+
1097
+ feats_pad = feats_pad.to(torch.bfloat16)
1098
+
1099
+ encoder_out, encoder_out_lens = self.model.inference_encode(
1100
+ feats_pad,
1101
+ data_lengths=feats_lens,
1102
+ language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
1103
+ use_itn=False,
1104
+ ban_emo_unk=False,
1105
+ **self.kwargs,
1106
+ )
1107
+
1108
+ return encoder_out, encoder_out_lens
1109
+
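+ # NOTE (added for clarity): the statements below are unreachable because of the return
+ # above; only (encoder_out, encoder_out_lens) is ever produced by this forward.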
1110
+ audio_embeds = []
1111
+ for x, y in zip(encoder_out, encoder_out_lens):
1112
+ audio_embeds.append(x[:y, ...])
1113
+
1114
+ audio_embeds = torch.stack(audio_embeds, dim=0)
1115
+
1116
+ return audio_embeds
1117
+
1118
+ # https://github.com/modelscope/FunASR/blob/main/funasr/auto/auto_model.py
1119
+ @staticmethod
1120
+ def build_model(**kwargs):
1121
+ from omegaconf import DictConfig, ListConfig
1122
+ import os
1123
+
1124
+ from funasr.download.download_model_from_hub import download_model
1125
+ from funasr.train_utils.set_all_random_seed import set_all_random_seed
1126
+ from funasr.register import tables
1127
+ from funasr.train_utils.load_pretrained_model import load_pretrained_model
1128
+ from funasr.utils.misc import deep_update
1129
+
1130
+ import logging
1131
+
1132
+ assert "model" in kwargs
1133
+ if "model_conf" not in kwargs:
1134
+ logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
1135
+ kwargs = download_model(**kwargs)
1136
+
1137
+ set_all_random_seed(kwargs.get("seed", 0))
1138
+
1139
+ device = kwargs.get("device", "cuda")
1140
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
1141
+ device = "cpu"
1142
+ kwargs["batch_size"] = 1
1143
+ kwargs["device"] = device
1144
+
1145
+ torch.set_num_threads(kwargs.get("ncpu", 4))
1146
+
1147
+ # build tokenizer
1148
+ tokenizer = kwargs.get("tokenizer", None)
1149
+ kwargs["tokenizer"] = tokenizer
1150
+ kwargs["vocab_size"] = -1
1151
+
1152
+ if tokenizer is not None:
1153
+ tokenizers = (
1154
+ tokenizer.split(",") if isinstance(tokenizer, str) else tokenizer
1155
+ ) # type of tokenizers is list!!!
1156
+ tokenizers_conf = kwargs.get("tokenizer_conf", {})
1157
+ tokenizers_build = []
1158
+ vocab_sizes = []
1159
+ token_lists = []
1160
+
1161
+ ### === only for kws ===
1162
+ token_list_files = kwargs.get("token_lists", [])
1163
+ seg_dicts = kwargs.get("seg_dicts", [])
1164
+ ### === only for kws ===
1165
+
1166
+ if not isinstance(tokenizers_conf, (list, tuple, ListConfig)):
1167
+ tokenizers_conf = [tokenizers_conf] * len(tokenizers)
1168
+
1169
+ for i, tokenizer in enumerate(tokenizers):
1170
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
1171
+ tokenizer_conf = tokenizers_conf[i]
1172
+
1173
+ ### === only for kws ===
1174
+ if len(token_list_files) > 1:
1175
+ tokenizer_conf["token_list"] = token_list_files[i]
1176
+ if len(seg_dicts) > 1:
1177
+ tokenizer_conf["seg_dict"] = seg_dicts[i]
1178
+ ### === only for kws ===
1179
+
1180
+ tokenizer = tokenizer_class(**tokenizer_conf)
1181
+ tokenizers_build.append(tokenizer)
1182
+ token_list = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
1183
+ token_list = (
1184
+ tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else token_list
1185
+ )
1186
+ vocab_size = -1
1187
+ if token_list is not None:
1188
+ vocab_size = len(token_list)
1189
+
1190
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
1191
+ vocab_size = tokenizer.get_vocab_size()
1192
+ token_lists.append(token_list)
1193
+ vocab_sizes.append(vocab_size)
1194
+
1195
+ if len(tokenizers_build) <= 1:
1196
+ tokenizers_build = tokenizers_build[0]
1197
+ token_lists = token_lists[0]
1198
+ vocab_sizes = vocab_sizes[0]
1199
+
1200
+ kwargs["tokenizer"] = tokenizers_build
1201
+ kwargs["vocab_size"] = vocab_sizes
1202
+ kwargs["token_list"] = token_lists
1203
+
1204
+ # build frontend
1205
+ frontend = kwargs.get("frontend", None)
1206
+ kwargs["input_size"] = None
1207
+ if frontend is not None:
1208
+ frontend_class = tables.frontend_classes.get(frontend)
1209
+ frontend = frontend_class(**kwargs.get("frontend_conf", {}))
1210
+ kwargs["input_size"] = (
1211
+ frontend.output_size() if hasattr(frontend, "output_size") else None
1212
+ )
1213
+ kwargs["frontend"] = frontend
1214
+ # build model
1215
+ model_class = tables.model_classes.get(kwargs["model"])
1216
+ assert model_class is not None, f'{kwargs["model"]} is not registered'
1217
+ model_conf = {}
1218
+ deep_update(model_conf, kwargs.get("model_conf", {}))
1219
+ deep_update(model_conf, kwargs)
1220
+ model = model_class(**model_conf)
1221
+
1222
+ # init_param
1223
+ init_param = kwargs.get("init_param", None)
1224
+ if init_param is not None:
1225
+ if os.path.exists(init_param):
1226
+ logging.info(f"Loading pretrained params from {init_param}")
1227
+ load_pretrained_model(
1228
+ model=model,
1229
+ path=init_param,
1230
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
1231
+ oss_bucket=kwargs.get("oss_bucket", None),
1232
+ scope_map=kwargs.get("scope_map", []),
1233
+ excludes=kwargs.get("excludes", None),
1234
+ )
1235
+ else:
1236
+ print(f"error, init_param does not exist!: {init_param}")
1237
+
1238
+ # fp16
1239
+ if kwargs.get("fp16", False):
1240
+ model.to(torch.float16)
1241
+ elif kwargs.get("bf16", False):
1242
+ model.to(torch.bfloat16)
1243
+ # model.to(device)
1244
+
1245
+ if not kwargs.get("disable_log", True):
1246
+ tables.print()
1247
+
1248
+ return model, kwargs
1249
+
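A minimal usage sketch for the SenseVoiceSmall model defined above (assumptions: the funasr package is installed and the classes above are importable; the checkpoint id and wav path are placeholders, not part of this commit):

```python
import torch

# Placeholder checkpoint id and audio path -- substitute real ones.
model, kwargs = SenseVoiceSmall.from_pretrained(model="iic/SenseVoiceSmall", device="cpu")
model.eval()

with torch.no_grad():
    results, meta_data = model.inference(
        data_in="example.wav",
        language="auto",   # or "zh", "en", "yue", "ja", "ko", "nospeech"
        use_itn=False,
        ban_emo_unk=False,
        **kwargs,          # carries device, tokenizer and frontend from build_model
    )
print(results[0]["text"])
```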
modular_qwen2.py ADDED
@@ -0,0 +1,134 @@
1
+ from typing import Callable, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+
7
+ from transformers.cache_utils import Cache
8
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
9
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
10
+ from transformers.processing_utils import Unpack
11
+ from transformers.utils import logging
12
+ from transformers.models.llama.modeling_llama import (
13
+ LlamaAttention,
14
+ LlamaDecoderLayer,
15
+ LlamaForCausalLM,
16
+ LlamaForQuestionAnswering,
17
+ LlamaForSequenceClassification,
18
+ LlamaForTokenClassification,
19
+ LlamaMLP,
20
+ LlamaModel,
21
+ apply_rotary_pos_emb,
22
+ eager_attention_forward,
23
+ )
24
+ from .configuration_qwen2 import Qwen2Config
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class Qwen2MLP(LlamaMLP):
31
+ def __init__(self, config):
32
+ super().__init__(config)
33
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
34
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
35
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
36
+
37
+
38
+ class Qwen2Attention(LlamaAttention):
39
+ def __init__(self, config: Qwen2Config, layer_idx: int):
40
+ super().__init__(config, layer_idx)
41
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True)
42
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
43
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True)
44
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
45
+
46
+ def forward(
47
+ self,
48
+ hidden_states: torch.Tensor,
49
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
50
+ attention_mask: Optional[torch.Tensor],
51
+ past_key_value: Optional[Cache] = None,
52
+ cache_position: Optional[torch.LongTensor] = None,
53
+ **kwargs: Unpack[FlashAttentionKwargs],
54
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
55
+ input_shape = hidden_states.shape[:-1]
56
+ hidden_shape = (*input_shape, -1, self.head_dim)
57
+
58
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
59
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
60
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
61
+
62
+ cos, sin = position_embeddings
63
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
64
+
65
+ if past_key_value is not None:
66
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
67
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
68
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
69
+
70
+ sliding_window = None
71
+ if (
72
+ self.config.use_sliding_window
73
+ and getattr(self.config, "sliding_window", None) is not None
74
+ and self.layer_idx >= self.config.max_window_layers
75
+ ):
76
+ sliding_window = self.config.sliding_window
77
+
78
+ attention_interface: Callable = eager_attention_forward
79
+ if self.config._attn_implementation != "eager":
80
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
81
+ logger.warning_once(
82
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
83
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
84
+ )
85
+ else:
86
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
87
+
88
+ attn_output, attn_weights = attention_interface(
89
+ self,
90
+ query_states,
91
+ key_states,
92
+ value_states,
93
+ attention_mask,
94
+ dropout=0.0 if not self.training else self.attention_dropout,
95
+ scaling=self.scaling,
96
+ sliding_window=sliding_window, # main diff with Llama
97
+ **kwargs,
98
+ )
99
+
100
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
101
+ attn_output = self.o_proj(attn_output)
102
+ return attn_output, attn_weights
103
+
104
+
105
+ class Qwen2DecoderLayer(LlamaDecoderLayer):
106
+ def __init__(self, config: Qwen2Config, layer_idx: int):
107
+ super().__init__()
108
+ self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
109
+ self.mlp = Qwen2MLP(config)
110
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
111
+ logger.warning_once(
112
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
113
+ "unexpected results may be encountered."
114
+ )
115
+
116
+
117
+ class Qwen2Model(LlamaModel):
118
+ pass
119
+
120
+
121
+ class Qwen2ForCausalLM(LlamaForCausalLM):
122
+ pass
123
+
124
+
125
+ class Qwen2ForSequenceClassification(LlamaForSequenceClassification):
126
+ pass
127
+
128
+
129
+ class Qwen2ForTokenClassification(LlamaForTokenClassification):
130
+ pass
131
+
132
+
133
+ class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering):
134
+ pass
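A small sketch of how the sliding-window gating in Qwen2Attention.forward above behaves (assumption: the stock transformers Qwen2Config; the numbers are illustrative):

```python
from transformers import Qwen2Config

cfg = Qwen2Config(use_sliding_window=True, sliding_window=4096, max_window_layers=21)

# Mirrors the condition in Qwen2Attention.forward: layers at or above
# max_window_layers pass sliding_window to the attention kernel; lower layers do not.
for layer_idx in (0, 20, 21, 27):
    use_swa = (
        cfg.use_sliding_window
        and getattr(cfg, "sliding_window", None) is not None
        and layer_idx >= cfg.max_window_layers
    )
    print(layer_idx, cfg.sliding_window if use_swa else None)
```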
resampler_projector.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import math
5
+
6
+
7
+
8
+ class ResamplerProjector(nn.Module):
9
+ def __init__(self, proj_input_size, hidden_size):
10
+ super().__init__()
11
+
12
+ self.pre_proj_layernorm = torch.nn.LayerNorm(proj_input_size)
13
+
14
+ self.mlp = nn.Sequential(
15
+ nn.Linear(proj_input_size, hidden_size, bias=False),
16
+ nn.GELU(),
17
+ nn.Linear(hidden_size, hidden_size, bias=False),
18
+ )
19
+ self.mlp.apply(init_weights)
20
+ self.pre_proj_layernorm.apply(init_weights)
21
+
22
+ def forward(self, x, *args, **kwargs):
23
+ x = x.reshape(x.shape[0], -1, x.shape[-1])
24
+ x = self.pre_proj_layernorm(x)
25
+ x = self.mlp(x)
26
+ # print(torch.distributed.get_rank(), {name: [param, param.grad] for name, param in self.pre_proj_layernorm.named_parameters()})
27
+ # print(torch.distributed.get_rank(), {name: [param, param.grad] for name, param in self.mlp.named_parameters()})
28
+ return x
29
+
30
+ def init_weights(m):
31
+ if isinstance(m, nn.Linear):
32
+ torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
33
+ if m.bias is not None:
34
+ torch.nn.init.zeros_(m.bias)
35
+
36
+ if isinstance(m, nn.LayerNorm):
37
+ torch.nn.init.ones_(m.weight)
38
+ torch.nn.init.zeros_(m.bias)
39
+
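A minimal shape sketch for ResamplerProjector (assumption: the class above is importable; the 512/3584 dimensions are illustrative, not taken from this commit):

```python
import torch

proj = ResamplerProjector(proj_input_size=512, hidden_size=3584)
feats = torch.randn(2, 80, 512)   # (batch, frames, encoder_dim), e.g. audio encoder output
out = proj(feats)                 # LayerNorm followed by a 2-layer GELU MLP
print(out.shape)                  # torch.Size([2, 80, 3584])
```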
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenization_qwen2.py ADDED
@@ -0,0 +1,341 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Qwen2."""
16
+
17
+ import json
18
+ import os
19
+ import unicodedata
20
+ from functools import lru_cache
21
+ from typing import Optional, Tuple
22
+
23
+ import regex as re
24
+
25
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+ VOCAB_FILES_NAMES = {
32
+ "vocab_file": "vocab.json",
33
+ "merges_file": "merges.txt",
34
+ }
35
+
36
+
37
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
38
+
39
+ PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
40
+
41
+
42
+ @lru_cache()
43
+ # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
44
+ def bytes_to_unicode():
45
+ """
46
+ Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
47
+ characters the bpe code barfs on.
48
+
49
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
50
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
51
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
52
+ tables between utf-8 bytes and unicode strings.
53
+ """
54
+ bs = (
55
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
56
+ )
57
+ cs = bs[:]
58
+ n = 0
59
+ for b in range(2**8):
60
+ if b not in bs:
61
+ bs.append(b)
62
+ cs.append(2**8 + n)
63
+ n += 1
64
+ cs = [chr(n) for n in cs]
65
+ return dict(zip(bs, cs))
66
+
67
+
68
+ # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
69
+ def get_pairs(word):
70
+ """
71
+ Return set of symbol pairs in a word.
72
+
73
+ Word is represented as tuple of symbols (symbols being variable-length strings).
74
+ """
75
+ pairs = set()
76
+ prev_char = word[0]
77
+ for char in word[1:]:
78
+ pairs.add((prev_char, char))
79
+ prev_char = char
80
+ return pairs
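+ # Illustrative example (not in the original): get_pairs(("l", "o", "w")) -> {("l", "o"), ("o", "w")}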
81
+
82
+
83
+ class Qwen2Tokenizer(PreTrainedTokenizer):
84
+ """
85
+ Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
86
+
87
+ Same with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
88
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
89
+
90
+ ```python
91
+ >>> from transformers import Qwen2Tokenizer
92
+
93
+ >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
94
+ >>> tokenizer("Hello world")["input_ids"]
95
+ [9707, 1879]
96
+
97
+ >>> tokenizer(" Hello world")["input_ids"]
98
+ [21927, 1879]
99
+ ```
100
+ This is expected.
101
+
102
+ You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
103
+
104
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
105
+ this superclass for more information regarding those methods.
106
+
107
+ Args:
108
+ vocab_file (`str`):
109
+ Path to the vocabulary file.
110
+ merges_file (`str`):
111
+ Path to the merges file.
112
+ errors (`str`, *optional*, defaults to `"replace"`):
113
+ Paradigm to follow when decoding bytes to UTF-8. See
114
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
115
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
116
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
117
+ token instead.
118
+ bos_token (`str`, *optional*):
119
+ The beginning of sequence token. Not applicable for this tokenizer.
120
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
121
+ The end of sequence token.
122
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
123
+ The token used for padding, for example when batching sequences of different lengths.
124
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
125
+ Whether or not the model should cleanup the spaces that were added when splitting the input text during the
126
+ tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
127
+ split_special_tokens (`bool`, *optional*, defaults to `False`):
128
+ Whether or not the special tokens should be split during the tokenization process. The default behavior is
129
+ to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
130
+ ['<|endoftext|>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will be give `['<',
131
+ '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
132
+ """
133
+
134
+ vocab_files_names = VOCAB_FILES_NAMES
135
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
136
+ max_model_input_sizes = MAX_MODEL_INPUT_SIZES
137
+ model_input_names = ["input_ids", "attention_mask"]
138
+
139
+ def __init__(
140
+ self,
141
+ vocab_file,
142
+ merges_file,
143
+ errors="replace",
144
+ unk_token="<|endoftext|>",
145
+ bos_token=None,
146
+ eos_token="<|endoftext|>",
147
+ pad_token="<|endoftext|>",
148
+ clean_up_tokenization_spaces=False,
149
+ split_special_tokens=False,
150
+ **kwargs,
151
+ ):
152
+ # Qwen vocab does not contain control tokens; added tokens need to be special
153
+ bos_token = (
154
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
155
+ if isinstance(bos_token, str)
156
+ else bos_token
157
+ )
158
+ eos_token = (
159
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
160
+ if isinstance(eos_token, str)
161
+ else eos_token
162
+ )
163
+ unk_token = (
164
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
165
+ if isinstance(unk_token, str)
166
+ else unk_token
167
+ )
168
+ pad_token = (
169
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
170
+ if isinstance(pad_token, str)
171
+ else pad_token
172
+ )
173
+
174
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
175
+ self.encoder = json.load(vocab_handle)
176
+ self.decoder = {v: k for k, v in self.encoder.items()}
177
+ self.errors = errors # how to handle errors in decoding
178
+ self.byte_encoder = bytes_to_unicode()
179
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
180
+ bpe_merges = []
181
+ with open(merges_file, encoding="utf-8") as merges_handle:
182
+ for line in merges_handle:
183
+ line = line.strip()
184
+ if not line or line.startswith("#"):
185
+ continue
186
+ bpe_merges.append(tuple(line.split()))
187
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
188
+ # NOTE: the cache can grow without bound and will get really large for long running processes
189
+ # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
190
+ # not a memory leak but appears as one.
191
+ # GPT2Tokenizer has the same problem, so let's be consistent.
192
+ self.cache = {}
193
+
194
+ self.pat = re.compile(PRETOKENIZE_REGEX)
195
+
196
+ if kwargs.get("add_prefix_space", False):
197
+ logger.warning_once(
198
+ f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
199
+ )
200
+
201
+ super().__init__(
202
+ errors=errors,
203
+ bos_token=bos_token,
204
+ eos_token=eos_token,
205
+ pad_token=pad_token,
206
+ unk_token=unk_token,
207
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
208
+ split_special_tokens=split_special_tokens,
209
+ **kwargs,
210
+ )
211
+
212
+ @property
213
+ def vocab_size(self) -> int:
214
+ return len(self.encoder)
215
+
216
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
217
+ def get_vocab(self):
218
+ return dict(self.encoder, **self.added_tokens_encoder)
219
+
220
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
221
+ def bpe(self, token):
222
+ if token in self.cache:
223
+ return self.cache[token]
224
+ word = tuple(token)
225
+ pairs = get_pairs(word)
226
+
227
+ if not pairs:
228
+ return token
229
+
230
+ while True:
231
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
232
+ if bigram not in self.bpe_ranks:
233
+ break
234
+ first, second = bigram
235
+ new_word = []
236
+ i = 0
237
+ while i < len(word):
238
+ try:
239
+ j = word.index(first, i)
240
+ except ValueError:
241
+ new_word.extend(word[i:])
242
+ break
243
+ else:
244
+ new_word.extend(word[i:j])
245
+ i = j
246
+
247
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
248
+ new_word.append(first + second)
249
+ i += 2
250
+ else:
251
+ new_word.append(word[i])
252
+ i += 1
253
+ new_word = tuple(new_word)
254
+ word = new_word
255
+ if len(word) == 1:
256
+ break
257
+ else:
258
+ pairs = get_pairs(word)
259
+ word = " ".join(word)
260
+ self.cache[token] = word
261
+ return word
262
+
263
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
264
+ def _tokenize(self, text):
265
+ """Tokenize a string."""
266
+ bpe_tokens = []
267
+ for token in re.findall(self.pat, text):
268
+ token = "".join(
269
+ self.byte_encoder[b] for b in token.encode("utf-8")
270
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
271
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
272
+ return bpe_tokens
273
+
274
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
275
+ def _convert_token_to_id(self, token):
276
+ """Converts a token (str) in an id using the vocab."""
277
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
278
+
279
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
280
+ def _convert_id_to_token(self, index):
281
+ """Converts an index (integer) in a token (str) using the vocab."""
282
+ return self.decoder.get(index)
283
+
284
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
285
+ def convert_tokens_to_string(self, tokens):
286
+ """Converts a sequence of tokens (string) in a single string."""
287
+ text = "".join(tokens)
288
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
289
+ return text
290
+
291
+ def decode(
292
+ self,
293
+ token_ids,
294
+ skip_special_tokens: bool = False,
295
+ clean_up_tokenization_spaces: Optional[bool] = False,
296
+ spaces_between_special_tokens: bool = False,
297
+ **kwargs,
298
+ ) -> str:
299
+ # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
300
+ # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
301
+ return super().decode(
302
+ token_ids,
303
+ skip_special_tokens=skip_special_tokens,
304
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
305
+ spaces_between_special_tokens=spaces_between_special_tokens,
306
+ **kwargs,
307
+ )
308
+
309
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
310
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
311
+ if not os.path.isdir(save_directory):
312
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
313
+ return
314
+ vocab_file = os.path.join(
315
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
316
+ )
317
+ merge_file = os.path.join(
318
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
319
+ )
320
+
321
+ with open(vocab_file, "w", encoding="utf-8") as f:
322
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
323
+
324
+ index = 0
325
+ with open(merge_file, "w", encoding="utf-8") as writer:
326
+ writer.write("#version: 0.2\n")
327
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
328
+ if index != token_index:
329
+ logger.warning(
330
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
331
+ " Please check that the tokenizer is not corrupted!"
332
+ )
333
+ index = token_index
334
+ writer.write(" ".join(bpe_tokens) + "\n")
335
+ index += 1
336
+
337
+ return vocab_file, merge_file
338
+
339
+ def prepare_for_tokenization(self, text, **kwargs):
340
+ text = unicodedata.normalize("NFC", text)
341
+ return (text, kwargs)
tokenization_qwen2_fast.py ADDED
@@ -0,0 +1,134 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Qwen2."""
16
+
17
+ from typing import Optional, Tuple
18
+
19
+ from transformers.tokenization_utils import AddedToken
20
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
21
+ from transformers.utils import logging
22
+ from .tokenization_qwen2 import Qwen2Tokenizer
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {
28
+ "vocab_file": "vocab.json",
29
+ "merges_file": "merges.txt",
30
+ "tokenizer_file": "tokenizer.json",
31
+ }
32
+
33
+
34
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
35
+
36
+
37
+ class Qwen2TokenizerFast(PreTrainedTokenizerFast):
38
+ """
39
+ Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
40
+ Byte-Pair-Encoding.
41
+
42
+ As with GPT2Tokenizer, this tokenizer has been trained to treat spaces as parts of the tokens, so a word will
43
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
44
+
45
+ ```python
46
+ >>> from transformers import Qwen2TokenizerFast
47
+
48
+ >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
49
+ >>> tokenizer("Hello world")["input_ids"]
50
+ [9707, 1879]
51
+
52
+ >>> tokenizer(" Hello world")["input_ids"]
53
+ [21927, 1879]
54
+ ```
55
+ This is expected.
56
+
57
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
58
+ refer to this superclass for more information regarding those methods.
59
+
60
+ Args:
61
+ vocab_file (`str`, *optional*):
62
+ Path to the vocabulary file.
63
+ merges_file (`str`, *optional*):
64
+ Path to the merges file.
65
+ tokenizer_file (`str`, *optional*):
66
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
67
+ contains everything needed to load the tokenizer.
68
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
69
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
70
+ token instead. Not applicable to this tokenizer.
71
+ bos_token (`str`, *optional*):
72
+ The beginning of sequence token. Not applicable for this tokenizer.
73
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
74
+ The end of sequence token.
75
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
76
+ The token used for padding, for example when batching sequences of different lengths.
77
+ """
78
+
79
+ vocab_files_names = VOCAB_FILES_NAMES
80
+ model_input_names = ["input_ids", "attention_mask"]
81
+ slow_tokenizer_class = Qwen2Tokenizer
82
+
83
+ def __init__(
84
+ self,
85
+ vocab_file=None,
86
+ merges_file=None,
87
+ tokenizer_file=None,
88
+ unk_token="<|endoftext|>",
89
+ bos_token=None,
90
+ eos_token="<|endoftext|>",
91
+ pad_token="<|endoftext|>",
92
+ **kwargs,
93
+ ):
94
+ # We need to at least pass vocab_file and merges_file to base class
95
+ # in case a slow tokenizer needs to be initialized; other can be
96
+ # configured through files.
97
+ # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
98
+
99
+ bos_token = (
100
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
101
+ if isinstance(bos_token, str)
102
+ else bos_token
103
+ )
104
+ eos_token = (
105
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
106
+ if isinstance(eos_token, str)
107
+ else eos_token
108
+ )
109
+ unk_token = (
110
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
111
+ if isinstance(unk_token, str)
112
+ else unk_token
113
+ )
114
+ pad_token = (
115
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
116
+ if isinstance(pad_token, str)
117
+ else pad_token
118
+ )
119
+
120
+ super().__init__(
121
+ vocab_file=vocab_file,
122
+ merges_file=merges_file,
123
+ tokenizer_file=tokenizer_file,
124
+ unk_token=unk_token,
125
+ bos_token=bos_token,
126
+ eos_token=eos_token,
127
+ pad_token=pad_token,
128
+ **kwargs,
129
+ )
130
+
131
+ # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
132
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
133
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
134
+ return tuple(files)
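A brief sketch of loading the tokenizer through the auto classes (the path is a placeholder; assumes the vocab/merges/tokenizer files shipped alongside this repo):

```python
from transformers import AutoTokenizer

# "path/to/this/repo" stands in for a local clone or the hub repo id.
tok = AutoTokenizer.from_pretrained("path/to/this/repo", trust_remote_code=True)
ids = tok("Hello world")["input_ids"]
print(ids)              # e.g. [9707, 1879], per the docstring examples above
print(tok.decode(ids))  # "Hello world"
```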
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
vocab.json ADDED
The diff for this file is too large to render. See raw diff