English
File size: 14,305 Bytes
606f02a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import torch
import torch.nn as nn
from loguru import logger as log
from torch.distributions import Categorical
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessorList, TopKLogitsWarper, TopPLogitsWarper, \
    TemperatureLogitsWarper


class DiscreteActor(nn.Module):
    """MLP policy that maps a state vector to a distribution over discrete actions.

    Architecture: LayerNorm -> Linear -> ReLU -> Linear -> ReLU -> Linear -> softmax.

    Args:
        state_dim: size of the last dimension of the state input.
        action_dim: number of discrete actions (size of the output distribution).
        hidden_dim: width of the first hidden layer; the second hidden layer
            has ``int(hidden_dim / 4)`` units.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=1024):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim

        # NOTE: the attribute names below are the keys of the state_dict, so
        # they must stay as they are for saved checkpoints to keep loading.
        bottleneck_dim = int(hidden_dim / 4)
        self.ln = nn.LayerNorm(state_dim)
        self.linear1 = nn.Linear(self.state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, bottleneck_dim)
        self.output_linear = nn.Linear(bottleneck_dim, self.action_dim)

    def forward(self, state):
        """Return action probabilities for ``state``.

        Args:
            state: tensor whose last dimension is ``state_dim``.

        Returns:
            Tensor with the same leading dimensions as ``state`` and a last
            dimension of ``action_dim`` that sums to 1.
        """
        state = self.ln(state)
        x = torch.relu(self.linear1(state))
        x = torch.relu(self.linear2(x))
        # Normalise over the action axis. dim=-1 equals the former dim=1 for
        # 2-D input and additionally accepts an unbatched 1-D state.
        return torch.softmax(self.output_linear(x), dim=-1)

    def sample(self, state):
        """Draw a stochastic action for ``state``.

        A leading batch dimension is added to ``state`` before the forward pass.

        Returns:
            Tuple ``(sample_action, prob, logprob, greedy_action)``:
            the sampled action index (detached, trailing dim of 1), the action
            probabilities, their element-wise log, and the argmax action
            (trailing dim of 1).
        """
        state = state.unsqueeze(0)
        prob = self.forward(state)
        # Pass the tensor straight through: re-wrapping it with the legacy
        # torch.Tensor(...) constructor only works for CPU tensors.
        distribution = Categorical(probs=prob)
        sample_action = distribution.sample().unsqueeze(-1).detach()
        # Add a tiny epsilon only where prob is exactly zero so log() stays finite.
        z = (prob == 0.0).float() * 1e-8
        logprob = torch.log(prob + z)
        greedy_action = torch.argmax(prob, dim=-1).unsqueeze(-1)  # 1d tensor
        return sample_action, prob, logprob, greedy_action

    def select_action(self, state):
        """Return the greedy (argmax) action for ``state`` as plain Python data.

        A leading batch dimension is added to ``state`` before the forward
        pass; the result is squeezed and converted with ``tolist()``, which
        yields an ``int`` for a single state.
        """
        # Pure inference: the result leaves torch, so no graph is needed.
        with torch.no_grad():
            prob = self.forward(state.unsqueeze(0))
            action = torch.argmax(prob, dim=-1).unsqueeze(-1)  # 1d tensor
        return action.squeeze().tolist()


class MARAGenerator():
    """Token-by-token text generation with a frozen causal LM and an accept/reject agent.

    At every decoding step the base model's next-token scores are filtered by
    top-k, top-p and temperature warpers. The surviving tokens, ordered from
    highest to lowest score, are the candidates for that step:

    * exactly one candidate: it is emitted directly, without consulting the agent;
    * several candidates: each is fed to the base model in score order, and the
      last-layer hidden state at the last position is passed to the agent. The
      first candidate for which the agent returns action ``1`` is emitted; if
      none is accepted, the highest-scoring candidate (index 0) is emitted.

    Args:
        agent_path: path of the saved ``DiscreteActor`` state dict.
        base_model_path: path/name passed to ``from_pretrained`` for both the
            model and the tokenizer.
        state_dim: input size of the agent.
        hidden_dim: hidden size of the agent.
        model_device: device string for the base model and the agent.
        max_new_token: maximum number of generated tokens.
        topk: ``top_k`` for candidate filtering and for ``get_raw_output`` sampling.
        topp: ``top_p`` for candidate filtering and for ``get_raw_output`` sampling.
        temperature: temperature for candidate filtering and for ``get_raw_output`` sampling.
    """

    def __init__(
            self,
            agent_path, base_model_path, state_dim=4096, hidden_dim=1024, model_device="cuda:0",
            max_new_token=2048, topk=40, topp=0.95, temperature=0.8
    ):
        # Text-generation model initialisation (frozen, eval mode).
        self.base_model_path = base_model_path
        self.model_device = torch.device(model_device)
        self.base_model = AutoModelForCausalLM.from_pretrained(self.base_model_path).to(
            self.model_device).eval().requires_grad_(False)
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_path)
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.topk = topk
        self.topp = topp
        self.max_new_token = max_new_token
        self.temperature = temperature

        # Candidate filter, applied in this order: top-k, top-p, temperature.
        self.logits_wraper = LogitsProcessorList(
            [TopKLogitsWarper(self.topk), TopPLogitsWarper(top_p=self.topp),
             TemperatureLogitsWarper(self.temperature)])

        # Accept/reject agent.
        self.mara_agent = self.get_mara_agent(agent_path, state_dim, hidden_dim).to(self.model_device)
        # log.info("mara_agent:{}".format(self.mara_agent))

        # Per-generation decoding state.
        self.instruction = ""  # input instruction
        self.generate_ids = []  # token ids generated so far
        self.new_token_cnt = 0  # number of tokens generated so far
        self.curr_input_ids = None  # input ids the next token is generated from
        self.curr_new_token_ids_list = None  # sampled new tokens
        self.sum_logprobs = None
        self.curr_new_outputs_list = None  # model outputs for curr_input_ids + each candidate token
        self.next_new_outputs_list = None  # next-token counterpart of curr_new_outputs_list
        self.is_new_token = True  # True: rank candidates; False: probe the next candidate
        self.chosen_indices = None  # candidate token ids of the current step
        self.cand_chosen_idx = 0  # index of the next candidate to probe
        self.last_outputs = None  # base-model outputs for the accepted prefix
        self.attention_mask = None
        self.position_id = None

        # Generation statistics; same schema as the one returned by get_proxy_output.
        self.gen_info = self._new_gen_info()

    @staticmethod
    def _new_gen_info():
        """Return a fresh, zeroed statistics dict for one generation run."""
        return {"gen_token_cnt": 0,  # total number of generated tokens
                "mara_token_cnt": 0,  # number of tokens that were decided by the agent
                "cand_token_dict": {},  # histogram: candidate count at a decision step -> occurrences
                "accept_index_dict": {}  # histogram: index of the emitted candidate -> occurrences
                }

    @staticmethod
    def _bump(counter, key):
        """Increment ``counter[key]``, starting from zero when the key is new."""
        counter[key] = counter.get(key, 0) + 1

    @staticmethod
    def _fork_cache(past_key_values):
        """Return a KV cache that can be extended without touching the original.

        Tuple caches (and ``None``) are returned unchanged because extending
        them produces new tensors. Any other cache object is deep-copied:
        newer transformers releases may return a cache object that is updated
        in place, in which case probing one candidate would otherwise leak its
        key/values into the next candidate's context.
        """
        if past_key_values is None or isinstance(past_key_values, tuple):
            return past_key_values
        import copy
        return copy.deepcopy(past_key_values)

    def _next_step_mask_and_position(self, pre_input_ids):
        """Build the attention mask and position id for one token appended to ``pre_input_ids``.

        Returns:
            ``(attention_mask, position_id)`` where ``attention_mask`` is all
            ones with one more column than ``pre_input_ids`` and ``position_id``
            holds the position of the appended token, on ``self.model_device``.
        """
        attention_mask = torch.ones_like(pre_input_ids)
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_id = position_ids[:, -1:].to(self.model_device)
        return attention_mask, position_id

    def _rank_candidates(self, input_ids, outputs):
        """Return the candidate token ids for the next step, best score first.

        The last-position logits of ``outputs`` are turned into log-probabilities,
        filtered by ``self.logits_wraper``, sorted in descending order, and every
        entry that was masked to ``-inf`` is dropped.
        """
        next_token_logits = outputs.logits[:, -1, :].clone()  # (batch_size, vocab_size)
        next_token_scores = nn.functional.log_softmax(next_token_logits, dim=-1)  # (batch_size, vocab_size)
        next_token_scores = self.logits_wraper(input_ids, next_token_scores)
        sorted_scores, sorted_indices = torch.sort(next_token_scores, descending=True)
        return torch.masked_select(sorted_indices, sorted_scores != -float("Inf")).tolist()

    def get_mara_agent(self, agent_path, state_dim, hidden_dim):
        """Build a two-action ``DiscreteActor`` and load its weights from ``agent_path``.

        Raises:
            Exception: any error raised while loading is logged and re-raised.
        """
        mara_agent = DiscreteActor(state_dim, 2, hidden_dim)
        log.info('Begin to load mara agent model from {}'.format(agent_path))
        try:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            model_state_dict = torch.load(agent_path, map_location=self.model_device)
            mara_agent.load_state_dict(model_state_dict)
        except Exception as e:
            log.error("load mara_agent occur error: {}".format(str(e)))
            raise
        return mara_agent

    def get_input_text(self, instruction):
        """Wrap ``instruction`` as a single user turn using the tokenizer's chat template."""
        messages = [{"role": "user", "content": instruction}]
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return input_text

    def get_raw_output(self, instruction, do_sample=True):
        """Generate a response with the base model alone (no agent).

        Args:
            instruction: user instruction text.
            do_sample: when True, sample with the configured top-k / top-p /
                temperature; otherwise call ``generate`` with ``do_sample=False``.

        Returns:
            ``{"answer": <decoded text after the prompt>}``.
        """
        if do_sample:
            generation_config = {"do_sample": True, "max_new_tokens": self.max_new_token, "top_k": self.topk,
                                 "top_p": self.topp, "temperature": self.temperature}
        else:
            generation_config = {"do_sample": False, "max_new_tokens": self.max_new_token}
        input_text = self.get_input_text(instruction)
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model_device)

        output_ids = self.base_model.generate(**inputs, **generation_config)[0]
        # Strip the prompt tokens so that only the newly generated part is decoded.
        response = self.tokenizer.decode(output_ids[len(inputs.input_ids[0]):], skip_special_tokens=True)
        return {"answer": response}

    @torch.no_grad()
    def get_proxy_output(self, instruction):
        """Generate a response in which the agent picks each token among the candidates.

        Generation stops when the end-of-sequence token is emitted or when
        ``max_new_token`` tokens have been generated.

        Returns:
            ``{"answer": <decoded text>, "detail": <statistics dict>}``; the
            statistics dict has the keys ``gen_token_cnt``, ``mara_token_cnt``,
            ``cand_token_dict`` and ``accept_index_dict``.
        """
        input_text = self.get_input_text(instruction)
        model_inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model_device)
        self.new_token_cnt = 0
        self.generate_ids = []
        self.gen_info = self._new_gen_info()

        input_ids = model_inputs.input_ids
        self.attention_mask = model_inputs.attention_mask
        self.curr_input_ids = input_ids
        self.last_outputs = self.base_model(input_ids=input_ids,
                                            attention_mask=self.attention_mask,
                                            output_hidden_states=True)

        # Advance the state one token at a time.
        self.is_new_token = True
        self.cand_chosen_idx = 0
        self.position_id = None
        end_of_generate = False
        while not end_of_generate:
            # Ranking mode: emits forced (single-candidate) tokens and returns
            # the candidates of the first step that needs a decision.
            end_of_generate, self.chosen_indices = self.rank_topk_ouputs_serial(self.curr_input_ids, self.last_outputs)
            self.is_new_token = False
            if end_of_generate:
                break

            # Probing mode: try candidates in score order until one is accepted.
            cand_cnt = len(self.chosen_indices)
            accept_idx = 0  # fallback when the agent accepts no candidate
            curr_new_outputs_list = []
            for cand_idx in range(cand_cnt):
                _, curr_new_outputs = self.rank_topk_ouputs_serial(self.curr_input_ids, self.last_outputs)
                curr_new_outputs_list.append(curr_new_outputs)
                # Agent state: last-layer hidden state at the last position.
                curr_state = curr_new_outputs['hidden_states'][-1][0, -1].to(self.model_device)

                if self.mara_agent.select_action(curr_state) == 1:
                    accept_idx = cand_idx
                    break
            self.is_new_token = True
            # log.info(
            #     "new_token_cnt:{}/{}, accept_idx: {}".format(self.new_token_cnt, self.max_new_token, accept_idx))
            accept_token_id = self.chosen_indices[accept_idx]
            self.generate_ids.append(accept_token_id)
            self.new_token_cnt += 1
            self.gen_info["gen_token_cnt"] += 1
            self.gen_info["mara_token_cnt"] += 1
            self._bump(self.gen_info["cand_token_dict"], cand_cnt)
            self._bump(self.gen_info["accept_index_dict"], accept_idx)

            accepted = torch.tensor([[accept_token_id]], dtype=torch.long, device=self.model_device)
            self.curr_input_ids = torch.cat((self.curr_input_ids, accepted), dim=-1)
            # Reuse the forward pass already computed for the emitted candidate.
            self.last_outputs = curr_new_outputs_list[accept_idx]

            if accept_token_id == self.tokenizer.eos_token_id or self.new_token_cnt >= self.max_new_token:
                end_of_generate = True
        completion = self.tokenizer.decode(self.generate_ids, skip_special_tokens=True)
        return {"answer": completion, "detail": self.gen_info}

    def one_step_transfer(self, pre_input_ids, past_key_values, new_token_id):
        """Feed one new token to the base model on top of a cached prefix.

        Args:
            pre_input_ids: ids of the prefix already covered by ``past_key_values``.
            past_key_values: KV cache of the prefix.
            new_token_id: tensor holding the token to append (it is concatenated
                to ``pre_input_ids`` along the last dimension).

        Returns:
            ``(new_input_ids, new_outputs)``: the extended ids and the base
            model outputs (with hidden states) for the appended token.
        """
        attention_mask, position_id = self._next_step_mask_and_position(pre_input_ids)
        new_token_id = new_token_id.to(self.model_device)

        new_outputs = self.base_model(input_ids=new_token_id,
                                      attention_mask=attention_mask.to(self.model_device),
                                      position_ids=position_id,
                                      past_key_values=past_key_values,
                                      output_hidden_states=True)
        new_input_ids = torch.cat((pre_input_ids, new_token_id), dim=-1)
        return new_input_ids, new_outputs

    def rank_topk_ouputs_serial(self, pre_input_ids, last_outputs):
        """Run one step of the decoding state machine; the mode depends on ``self.is_new_token``.

        * ``is_new_token`` True (ranking mode): rank the candidates after
          ``last_outputs``. While there is exactly one candidate it is emitted
          directly and the model is advanced. Returns ``(True, None)`` if such
          a forced token ends the generation, otherwise
          ``(False, candidate_ids)`` after storing the candidates, attention
          mask, position id and outputs on ``self``.
        * ``is_new_token`` False (probing mode): feed the next unprobed
          candidate (``self.cand_chosen_idx``) to the base model on top of the
          stored state and return ``(False, outputs)``. The arguments are not
          used in this mode.
        """
        if not self.is_new_token:
            candidate_id = torch.tensor([[self.chosen_indices[self.cand_chosen_idx]]],
                                        dtype=torch.long, device=self.model_device)
            curr_next_output = self.base_model(
                input_ids=candidate_id,
                attention_mask=self.attention_mask.to(self.model_device),
                position_ids=self.position_id.to(self.model_device),
                # Every candidate must start from the same prefix cache.
                past_key_values=self._fork_cache(self.last_outputs.past_key_values),
                output_hidden_states=True)
            self.cand_chosen_idx += 1
            return False, curr_next_output

        end_of_generate = False
        chosen_indices = self._rank_candidates(pre_input_ids, last_outputs)
        # With a single candidate there is nothing to decide: emit it and move
        # on, unless it is the end-of-sequence token or the length limit is hit.
        while len(chosen_indices) == 1:
            forced_token = chosen_indices[0]
            self.new_token_cnt += 1
            self.gen_info["gen_token_cnt"] += 1
            self.generate_ids.append(forced_token)
            if forced_token == self.tokenizer.eos_token_id or self.new_token_cnt >= self.max_new_token:
                end_of_generate = True
                return end_of_generate, None

            pre_input_ids, last_outputs = self.one_step_transfer(
                pre_input_ids,
                past_key_values=last_outputs.past_key_values,
                new_token_id=torch.tensor([[forced_token]], dtype=torch.long, device=self.model_device))
            self.curr_input_ids = pre_input_ids
            chosen_indices = self._rank_candidates(pre_input_ids, last_outputs)

        attention_mask, position_id = self._next_step_mask_and_position(pre_input_ids)
        # log.info("len(chosen_indices):{}".format(len(chosen_indices)))
        self.chosen_indices = chosen_indices
        self.cand_chosen_idx = 0
        self.attention_mask = attention_mask
        self.position_id = position_id
        self.last_outputs = last_outputs
        return end_of_generate, chosen_indices


if __name__ == "__main__":
    # Demo: answer one instruction with the plain base model (greedy decoding)
    # and then with the agent-guided decoder, printing both answers.
    demo_instruction = "Please introduce yourself."
    generator = MARAGenerator(
        "../proxy_rlhf/train_result/multi_reward/mistral_v3_2_1/run2/trained_model/actor_11000.pth",
        "/mnt/public/model/huggingface/Mistral-7B-Instruct-v0.3",
    )

    base_answer = generator.get_raw_output(demo_instruction, do_sample=False)["answer"]
    print("base model answer: ")
    print(base_answer)

    aligned_answer = generator.get_proxy_output(demo_instruction)["answer"]
    print("mara agent align answer: ")
    print(aligned_answer)