Spaces: Runtime error

Commit 402c662
Parent(s): 7d77ab3

first upload code and model

Browse files:
- app.py +66 -0
- config/llama_7b.json +21 -0
- generate.py +144 -0
- model_file/chatllama_7b.bin +3 -0
- model_file/tokenizer.model +3 -0
- models/llama.py +197 -0
- models/norm.py +16 -0
- models/rope.py +30 -0
- models/tokenize.py +40 -0
- requirements.txt +4 -0
- utils.py +143 -0
app.py
ADDED
@@ -0,0 +1,66 @@
import torch

import gradio as gr
import argparse
from utils import load_hyperparam, load_model
from models.tokenize import Tokenizer
from models.llama import *
from generate import LmGeneration


args = None
lm_generation = None


def init_args():
    global args
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args = parser.parse_args()
    args.load_model_path = './model_file/chatllama_7b.bin'
    args.config_path = './config/llama_7b.json'
    args.spm_model_path = './model_file/tokenizer.model'
    args.batch_size = 1
    args.seq_length = 1024
    args.world_size = 1
    args.use_int8 = False
    args.top_p = 0
    args.repetition_penalty_range = 1024
    args.repetition_penalty_slope = 0
    args.repetition_penalty = 1.15

    args = load_hyperparam(args)

    args.tokenizer = Tokenizer(model_path=args.spm_model_path)
    args.vocab_size = args.tokenizer.sp_model.vocab_size()


def init_model():
    global lm_generation
    torch.set_default_tensor_type(torch.HalfTensor)
    model = LLaMa(args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model = load_model(model, args.load_model_path)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    lm_generation = LmGeneration(model, args.tokenizer)


def chat(prompt, top_k, temperature):
    args.top_k = int(top_k)
    args.temperature = temperature
    response = lm_generation.generate(args, [prompt])
    return response[0]


if __name__ == '__main__':
    init_args()
    init_model()
    demo = gr.Interface(
        fn=chat,
        inputs=["text", gr.Slider(1, 60, value=40, step=1), gr.Slider(0.1, 2.0, value=1.2, step=0.1)],
        outputs="text",
    )
    demo.launch()
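For reference, a minimal sketch (not part of the commit) of driving the same pipeline without the Gradio UI, assuming the checkpoint, config, and tokenizer files referenced in init_args() are present locally:

# Hypothetical driver script; it only reuses app.py's own functions.
from app import init_args, init_model, chat

init_args()
init_model()
print(chat("Hello, who are you?", top_k=40, temperature=1.2))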
config/llama_7b.json
ADDED
@@ -0,0 +1,21 @@
{
    "emb_size": 4096,
    "feedforward_size": 11008,
    "hidden_size": 4096,
    "hidden_act": "silu",
    "heads_num": 32,
    "layers_num": 32,
    "dropout": 0.1,
    "data_processor": "lm",
    "max_seq_length": 2048,
    "embedding": ["word"],
    "remove_transformer_bias": true,
    "remove_embedding_layernorm": true,
    "rotary_position_embedding": true,
    "encoder": "transformer",
    "feed_forward": "gated",
    "mask": "causal",
    "layernorm_positioning": "pre",
    "layernorm": "rms",
    "target": ["lm"]
}
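These dimensions follow the 7B LLaMA layout. A small illustration (not part of the commit) of the per-head attention size that TransformerLayer in models/llama.py derives from them when attention_head_size is not set explicitly:

# attention_head_size defaults to hidden_size // heads_num
hidden_size, heads_num = 4096, 32
print(hidden_size // heads_num)  # 128 dimensions per attention head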
generate.py
ADDED
@@ -0,0 +1,144 @@
import torch
import torch.nn.functional as F


def apply_temperature(scores, tempt):
    if tempt > 0:
        scores = scores / tempt
    return scores


def apply_top_p(scores, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
    if top_p > 0 and top_p < 1:
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        scores = scores.masked_fill(indices_to_remove, filter_value)
    return scores


def apply_top_k(logits, top_k):
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits.float(), top_k)[0][..., -1, None]
        logits[indices_to_remove] = -float("Inf")

    return logits


def apply_advanced_repetition_penalty(
    input_ids, scores, penalty_range, penalty_slope, penalty
):
    penalty_range = int(penalty_range)
    clipped_penalty_range = min(input_ids.shape[-1], penalty_range)

    if penalty != 1.0:
        if penalty_range > 0:
            if clipped_penalty_range < input_ids.shape[1]:
                input_ids = input_ids[..., -clipped_penalty_range:]

            if penalty_slope != 0:
                _penalty = (
                    torch.arange(
                        penalty_range, dtype=scores.dtype, device=scores.device
                    )
                    / (penalty_range - 1)
                ) * 2.0 - 1
                _penalty = (penalty_slope * _penalty) / (
                    1 + torch.abs(_penalty) * (penalty_slope - 1)
                )
                _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
                penalty = _penalty[..., -clipped_penalty_range:]

        score = torch.gather(scores, 1, input_ids)
        score = torch.where(score <= 0, score * penalty, score / penalty)
        scores.scatter_(1, input_ids, score)

    return scores


class LmGeneration:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, args, prompts, cut_off=None, cut_off_times=1):
        if cut_off is not None:
            cut_off_times = [cut_off_times for i in range(len(prompts))]
        batch = len(prompts)
        assert batch <= args.batch_size

        prompt_tokens = [args.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_len = min([len(x) for x in prompt_tokens])
        # max_prompt_len = max([len(x) for x in prompt_tokens])

        total_len = args.seq_length

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokens = torch.full((batch, total_len), self.tokenizer.pad_id).to(device).long()
        for idx, t in enumerate(prompt_tokens):
            tokens[idx, : len(t)] = torch.tensor(t).long()
        mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_len
        prev_pos = 0
        continue_exsample = [i for i in range(batch)]
        with torch.no_grad():
            for cur_pos in range(start_pos, total_len):
                print(cur_pos)
                logits = self.model.forward(tokens[continue_exsample, prev_pos:cur_pos], prev_pos, continue_exsample).float()
                next_token_scores = apply_top_k(logits, top_k=args.top_k)
                next_token_scores = apply_top_p(next_token_scores, args.top_p)
                next_token_scores = apply_temperature(next_token_scores, args.temperature)
                next_token_scores = apply_advanced_repetition_penalty(
                    tokens[continue_exsample, :cur_pos],
                    next_token_scores,
                    args.repetition_penalty_range,
                    args.repetition_penalty_slope,
                    args.repetition_penalty
                )
                scores = F.softmax(next_token_scores, dim=-1)
                next_token = torch.multinomial(scores, num_samples=1).squeeze(1)
                next_token = next_token.reshape(-1)
                next_token = torch.where(
                    mask[continue_exsample, cur_pos], tokens[continue_exsample, cur_pos], next_token
                )
                tokens[continue_exsample, cur_pos] = next_token
                prev_pos = cur_pos
                # remove eos examples.
                continue_exsample = []
                for i, t in enumerate(tokens.tolist()):
                    try:
                        t.index(self.tokenizer.eos_id)
                    except ValueError:
                        if cut_off is not None:
                            if cut_off == self.tokenizer.decode(t[:cur_pos + 1])[-len(cut_off):]:
                                if cut_off_times[i] == 1:
                                    continue
                                else:
                                    cut_off_times[i] -= 1
                        continue_exsample.append(i)
                if len(continue_exsample) == 0:
                    break

        decoder = []
        for i, t in enumerate(tokens.tolist()):
            t = t[: args.seq_length]
            try:
                t = t[: t.index(self.tokenizer.pad_id)]
                t = t[: t.index(self.tokenizer.eos_id)]
            except ValueError:
                pass
            decoder.append(self.tokenizer.decode(t))

        return decoder
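The generation loop above chains top-k, top-p, temperature, and repetition-penalty filtering before sampling the next token. A minimal sketch (not part of the commit) of that chain applied to a dummy logits tensor, with illustrative parameter values:

# Illustration only: fake vocabulary scores instead of real model output.
import torch
import torch.nn.functional as F
from generate import apply_top_k, apply_top_p, apply_temperature

logits = torch.randn(1, 32000)                        # (batch, vocab_size)
scores = apply_top_k(logits, top_k=40)                # keep the 40 highest-scoring tokens
scores = apply_top_p(scores, top_p=0.95)              # then nucleus filtering
scores = apply_temperature(scores, 1.2)               # then rescale by temperature
probs = F.softmax(scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # sample one token id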
model_file/chatllama_7b.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5bb1fb1ddf737e7f1fcbe0284ecd384dbe8f243d843b82fcdf59fd00e9b3c61
size 13476956615
model_file/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
size 499723
models/llama.py
ADDED
@@ -0,0 +1,197 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.norm import RMSNorm
from models.rope import precompute_freqs_cis, apply_rotary_emb
import bitsandbytes as bnb
import math


class NormalLinear(nn.Linear):
    def reset_parameters(self) -> None:
        pass


class BnbInt8Linear(bnb.nn.Linear8bitLt):
    def __init__(self, *args, **kwargs):
        super().__init__(has_fp16_weights=False, threshold=6.0, *args, **kwargs)

    def reset_parameters(self) -> None:
        pass


def get_linear_layer(use_int8):
    if use_int8:
        return BnbInt8Linear
    return NormalLinear


class WordEmbedding(nn.Module):
    def __init__(self, args):
        super(WordEmbedding, self).__init__()
        self.embedding = nn.Embedding(args.vocab_size, args.emb_size)

    def forward(self, src):
        emb = self.embedding(src)
        return emb


class MultiHeadedAttention(nn.Module):
    def __init__(self, args, hidden_size, heads_num, attention_head_size, has_bias=True, use_int8=True):
        super(MultiHeadedAttention, self).__init__()
        self.heads_num = heads_num

        self.per_head_size = attention_head_size
        self.inner_hidden_size = heads_num * attention_head_size

        Linear = get_linear_layer(use_int8)
        self.linear_layers = nn.ModuleList(
            [Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)]
        )

        self.final_linear = Linear(self.inner_hidden_size, hidden_size, bias=has_bias)

        # key/value cache to avoid recomputing past positions.
        self.cache_k = torch.zeros(
            (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
        )
        self.cache_v = torch.zeros(
            (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
        )

    def forward(self, key, value, query, start_pos, continue_exsample, mask, freqs_cis):
        batch_size, seq_length, _ = query.size()
        heads_num = self.heads_num
        per_head_size = self.per_head_size
        query, key, value = [l(x).view(batch_size, -1, heads_num, per_head_size)
                             for l, x in zip(self.linear_layers, (query, key, value))]
        query, key = apply_rotary_emb(query, key, freqs_cis=freqs_cis)
        if self.cache_k.device != key.device:
            self.cache_k = self.cache_k.to(key)
        if self.cache_v.device != value.device:
            self.cache_v = self.cache_v.to(value)

        self.cache_k[continue_exsample, start_pos: start_pos + seq_length] = key
        self.cache_v[continue_exsample, start_pos: start_pos + seq_length] = value

        key = self.cache_k[continue_exsample, : start_pos + seq_length]
        value = self.cache_v[continue_exsample, : start_pos + seq_length]

        query, key, value = [x.transpose(1, 2) for x in (query, key, value)]

        scores = torch.matmul(query, key.transpose(-2, -1))
        scores = scores / math.sqrt(float(per_head_size))
        if mask is not None:
            scores += mask
        # probs = nn.Softmax(dim=-1)(scores)
        probs = F.softmax(scores.float(), dim=-1).type_as(query)
        output = torch.matmul(probs, value).transpose(1, 2).\
            contiguous().view(batch_size, seq_length, -1)
        return self.final_linear(output)


class GatedFeedForward(nn.Module):
    def __init__(self, hidden_size, feedforward_size, has_bias=True, use_int8=True):
        super(GatedFeedForward, self).__init__()
        Linear = get_linear_layer(use_int8)
        self.linear_gate = Linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_1 = Linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_2 = Linear(feedforward_size, hidden_size, bias=has_bias)
        self.act = F.silu

    def forward(self, x):
        # gate = self.act(self.linear_gate(x))
        gate = self.act(self.linear_gate(x)).type_as(x)
        inter_linear = self.linear_1(x)
        inter = gate * inter_linear
        output = self.linear_2(inter)
        return output


class TransformerLayer(nn.Module):
    def __init__(self, args):
        super(TransformerLayer, self).__init__()

        if hasattr(args, "attention_head_size"):
            attention_head_size = args.attention_head_size
        else:
            attention_head_size = args.hidden_size // args.heads_num

        has_bias = bool(1 - args.remove_transformer_bias)
        # Multi-head Attention
        self.self_attn = MultiHeadedAttention(
            args, args.hidden_size, args.heads_num, attention_head_size, has_bias=has_bias,
            use_int8=args.use_int8
        )

        # FFN
        self.feed_forward = GatedFeedForward(
            args.hidden_size, args.feedforward_size, has_bias, use_int8=args.use_int8
        )

        self.layer_norm_1 = RMSNorm(args.hidden_size)
        self.layer_norm_2 = RMSNorm(args.hidden_size)

    def forward(self, hidden, start_pos, continue_exsample, mask, freqs_cis=None):
        inter = self.layer_norm_1(hidden)
        inter = self.self_attn(inter, inter, inter, start_pos, continue_exsample, mask, freqs_cis)
        hidden = hidden + inter
        output = self.layer_norm_2(hidden)
        output = self.feed_forward(output) + hidden
        return output


class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super(TransformerEncoder, self).__init__()
        self.mask = args.mask
        self.layers_num = args.layers_num

        self.transformer = nn.ModuleList(
            [TransformerLayer(args) for _ in range(self.layers_num)]
        )

        self.layer_norm = RMSNorm(args.hidden_size)
        self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)

    def forward(self, emb, start_pos, continue_exsample):
        batch_size, seq_length, _ = emb.size()
        mask = None
        if seq_length > 1:
            mask = torch.ones(seq_length, seq_length, device=emb.device)
            mask = torch.tril(mask)
            mask = (1.0 - mask) * -10000
            mask = mask.repeat(batch_size, 1, 1, 1)

        hidden = emb
        freqs_cis = self.freqs_cis[start_pos: start_pos + seq_length].to(hidden.device)

        for i in range(self.layers_num):
            hidden = self.transformer[i](hidden, start_pos, continue_exsample, mask, freqs_cis=freqs_cis)
        return self.layer_norm(hidden)


class LmOutput(nn.Module):
    def __init__(self, args):
        super(LmOutput, self).__init__()
        # update: lm output does not use int8
        Linear = get_linear_layer(False)
        self.lm = Linear(args.hidden_size, args.vocab_size, bias=False)

    def forward(self, x):
        return self.lm(x[:, -1, :])


class LLaMa(nn.Module):
    def __init__(self, args):
        super(LLaMa, self).__init__()
        self.embedding = WordEmbedding(args)
        self.encoder = TransformerEncoder(args)
        self.target = LmOutput(args)

    # @torch.inference_mode()
    def forward(self, src, start_pos, continue_exsample):
        emb = self.embedding(src)
        output = self.encoder(emb, start_pos, continue_exsample)
        output = self.target(output)
        return output
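A minimal sketch (not part of the commit) of a toy-sized forward pass through LLaMa. The dimensions in the Namespace below are made up purely for illustration; the real values come from config/llama_7b.json and app.py:

import torch
from argparse import Namespace
from models.llama import LLaMa

toy = Namespace(
    vocab_size=100, emb_size=32, hidden_size=32, feedforward_size=64,
    heads_num=4, layers_num=2, max_seq_length=16,
    batch_size=1, seq_length=16, use_int8=False,
    remove_transformer_bias=True, mask="causal",
)
model = LLaMa(toy).eval()
tokens = torch.randint(0, 100, (1, 5))                # (batch, prompt_len) token ids
with torch.no_grad():
    logits = model(tokens, start_pos=0, continue_exsample=[0])
print(logits.shape)                                    # (1, vocab_size): scores for the next token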
models/norm.py
ADDED
@@ -0,0 +1,16 @@
from torch import nn
import torch


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
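A minimal sketch (not part of the commit) of what RMSNorm computes: each vector is divided by the root-mean-square of its last dimension, then scaled by the learned weight:

import torch
from models.norm import RMSNorm

norm = RMSNorm(hidden_size=4)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
rms = x.pow(2).mean(-1, keepdim=True).add(1e-6).sqrt()
print(norm(x))   # equals x / rms, since the weight starts at all-ones
print(x / rms)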
models/rope.py
ADDED
@@ -0,0 +1,30 @@
import torch
from typing import Tuple


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
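A minimal sketch (not part of the commit) of applying the precomputed rotary embeddings to dummy query/key tensors shaped (batch, seq_len, heads, head_dim), matching how MultiHeadedAttention calls it:

import torch
from models.rope import precompute_freqs_cis, apply_rotary_emb

head_dim, seq_len = 8, 4
freqs_cis = precompute_freqs_cis(head_dim, seq_len)   # (seq_len, head_dim // 2), complex64
q = torch.randn(1, seq_len, 2, head_dim)              # (batch, seq, heads, head_dim)
k = torch.randn(1, seq_len, 2, head_dim)
q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
print(q_rot.shape, k_rot.shape)                        # shapes unchanged, values rotated by position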
models/tokenize.py
ADDED
@@ -0,0 +1,40 @@
# copy from
# https://github.com/tloen/llama-int8/blob/ce74669c767e42b5082391dd0cfcb621ba40c7f9/llama/tokenizer.py

from sentencepiece import SentencePieceProcessor
from logging import getLogger
from typing import List
import os


logger = getLogger()


class Tokenizer:
    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        logger.info(f"Reloaded SentencePiece model from {model_path}")

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.pad_id()
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)
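A minimal sketch (not part of the commit) of round-tripping a string through this tokenizer, using the SentencePiece model shipped under model_file/:

from models.tokenize import Tokenizer

tok = Tokenizer(model_path="./model_file/tokenizer.model")
ids = tok.encode("Hello world", bos=True, eos=False)   # [bos_id, ...subword ids]
print(ids)
print(tok.decode(ids))                                  # back to the original text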
requirements.txt
ADDED
@@ -0,0 +1,4 @@
torch==1.9.0
bitsandbytes==0.37.2
sentencepiece
argparse
utils.py
ADDED
@@ -0,0 +1,143 @@
import json
import sys
from argparse import Namespace
import torch
import os


def load_hyperparam(default_args):
    """
    Load arguments from argparse and config file
    Priority: default options < config file < command line args
    """
    with open(default_args.config_path, mode="r", encoding="utf-8") as f:
        config_args_dict = json.load(f)

    default_args_dict = vars(default_args)

    command_line_args_dict = {k: default_args_dict[k] for k in [
        a[2:] for a in sys.argv if (a[:2] == "--" and "local_rank" not in a)
    ]}
    default_args_dict.update(config_args_dict)
    default_args_dict.update(command_line_args_dict)
    args = Namespace(**default_args_dict)

    return args


def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""):
    # Convert old format to new format if needed from a PyTorch state_dict

    # copy state_dict so _load_from_state_dict can modify it
    state_dict = torch.load(model_path, map_location="cpu")
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    state_dict['target.lm.weight'] = state_dict['target.lm.output_layer.weight']
    del state_dict['target.lm.output_layer.weight']
    state_dict['embedding.embedding.weight'] = state_dict['embedding.word.embedding.weight']
    del state_dict['embedding.word.embedding.weight']

    if metadata is not None:
        metadata['embedding.embedding'] = metadata['embedding.word.embedding']
        metadata['target.lm'] = metadata['target.lm.output_layer']
        if metadata.get('embedding.dropout', None) is not None:
            del metadata['embedding.dropout']
        del metadata['embedding.word']
        del metadata['embedding.word.embedding']
        del metadata['target.lm.output_layer']
        del metadata['target.lm.softmax']
        del metadata['target.lm.criterion']
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module, state_dict, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            import deepspeed
            # In sharded models, each shard has only part of the full state_dict, so only gather
            # parameters that are in the current state_dict.
            named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
            params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
            if len(params_to_gather) > 0:
                # because zero3 puts placeholders in model params, this context
                # manager gathers (unpartitions) the params of the current layer, then loads from
                # the state dict and then re-partitions them again
                with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, state_dict, prefix + name + ".")

    load(model_to_load, state_dict, prefix=start_prefix)
    # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
    # it's safe to delete it.
    del state_dict

    return model_to_load


def convert_normal_parameter_to_int8(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
    import bitsandbytes as bnb
    modules_to_not_convert = ["lm"] if modules_to_not_convert is None else modules_to_not_convert
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if len(list(module.children())) > 0:
            convert_normal_parameter_to_int8(module, threshold, modules_to_not_convert, current_key_name)

        if isinstance(module, bnb.nn.Linear8bitLt) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                model._modules[name].weight = bnb.nn.Int8Params(
                    module.weight.data,
                    requires_grad=False,
                    has_fp16_weights=False
                )
                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model


def load_model(model, model_path):
    if os.path.isdir(model_path):
        index_filename = os.path.join(model_path, 'pytorch_model.bin.index.json')
        with open(index_filename, "r") as f:
            index = json.loads(f.read())
        shard_filenames = sorted(set(index["weight_map"].values()))
        shard_filenames = [os.path.join(model_path, f) for f in shard_filenames]
        for shard_file in shard_filenames:
            shard_checkpoint = torch.load(shard_file, map_location='cpu')
            for name, parameter in model.named_parameters():
                if shard_checkpoint.get(name, None) is not None:
                    if 'target' in name:
                        parameter.data = shard_checkpoint['target.lm.output_layer.weight']
                    elif 'embedding' in name:
                        parameter.data = shard_checkpoint['embedding.word.embedding.weight']
                    else:
                        parameter.data = shard_checkpoint[name]
                    parameter.requires_grad = False
            del shard_checkpoint
    else:
        checkpoint = torch.load(model_path, map_location='cpu')
        for parameter_name, parameter in model.named_parameters():
            if 'target' in parameter_name:
                parameter.data = checkpoint['target.lm.output_layer.weight']
            elif 'embedding' in parameter_name:
                parameter.data = checkpoint['embedding.word.embedding.weight']
            else:
                parameter.data = checkpoint[parameter_name]
            parameter.requires_grad = False
        del checkpoint
    return model
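A minimal sketch (not part of the commit) of how load_hyperparam layers the JSON config over argparse defaults, mirroring what init_args() in app.py does:

import argparse
from utils import load_hyperparam

parser = argparse.ArgumentParser()
args = parser.parse_args([])                  # no options defined; attributes are set manually, as in app.py
args.config_path = './config/llama_7b.json'
args = load_hyperparam(args)
print(args.hidden_size, args.layers_num)      # 4096 32, pulled from the config file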