|
import gc
import os
from copy import deepcopy
from typing import Dict, Union, Optional

import torch
from torch.nn import Module
from transformers import AutoModel, PreTrainedModel, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessor
|
|
def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    # ChatGLM3-6B has 28 transformer layers. The word embeddings take roughly
    # one layer's worth of memory, and the final layernorm plus output layer
    # take roughly another, so a total of 30 "layers" is spread across
    # num_gpus cards. The embedding, rotary position embedding and output
    # modules are pinned to GPU 0.
    num_trans_layers = 28
    per_gpu_layers = 30 / num_gpus

    device_map = {
        'transformer.embedding.word_embeddings': 0,
        'transformer.encoder.final_layernorm': 0,
        'transformer.output_layer': 0,
        'transformer.rotary_pos_emb': 0,
        'lm_head': 0
    }

    # GPU 0 already holds the equivalent of two layers, so start counting at 2.
    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'transformer.encoder.layers.{i}'] = gpu_target
        used += 1

    return device_map
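
# Illustrative sketch (not part of the original module): with num_gpus=2 the
# helper pins the embedding/output modules to GPU 0 and places 13 of the 28
# encoder layers on GPU 0 and the remaining 15 on GPU 1, e.g.:
#
#   device_map = auto_configure_device_map(2)
#   device_map['transformer.encoder.layers.0']   # -> 0
#   device_map['transformer.encoder.layers.27']  # -> 1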
|
|
def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
                       device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
    if num_gpus < 2 and device_map is None:
        # Single-GPU case: load the model in half precision directly onto CUDA.
        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
    else:
        from accelerate import dispatch_model

        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()

        if device_map is None:
            device_map = auto_configure_device_map(num_gpus)

        # Place each submodule on the GPU assigned to it in device_map.
        model = dispatch_model(model, device_map=device_map)

    return model
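
# Usage sketch (illustrative; the checkpoint identifier is an example and can
# also be a local path):
#
#   model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
#   model = model.eval()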
|
|
class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # If the model produces NaN/Inf logits, replace them with a
        # distribution that puts almost all mass on a single token so that
        # sampling does not crash.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores
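
# This processor is passed to generation below via
# gen_kwargs["logits_processor"]; a standalone sketch of the same idea:
#
#   from transformers import LogitsProcessorList
#   processors = LogitsProcessorList([InvalidScoreLogitsProcessor()])
#   scores = processors(input_ids, scores)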
|
|
def process_response(output, history):
    content = ""
    history = deepcopy(history)
    # The raw output may contain several assistant turns separated by the
    # <|assistant|> special token; each turn starts with a metadata line
    # (the tool name for tool calls, empty for plain replies).
    for response in output.split("<|assistant|>"):
        if "\n" in response:
            metadata, content = response.split("\n", maxsplit=1)
        else:
            metadata, content = "", response
        if not metadata.strip():
            # Plain chat reply.
            content = content.strip()
            history.append(
                {
                    "role": "assistant",
                    "metadata": metadata,
                    "content": content
                }
            )
            content = content.replace("[[训练时间]]", "2023年")
        else:
            # Tool call: the metadata line carries the tool name.
            history.append(
                {
                    "role": "assistant",
                    "metadata": metadata,
                    "content": content
                }
            )
            if history[0]["role"] == "system" and "tools" in history[0]:
                # Strip the surrounding code-fence lines, keep the call itself.
                content = "\n".join(content.split("\n")[1:-1])

                def tool_call(**kwargs):
                    return kwargs

                # The model emits a tool_call(...) expression; evaluating it
                # against the helper above recovers the keyword arguments.
                parameters = eval(content)
                content = {
                    "name": metadata.strip(),
                    "parameters": parameters
                }
            else:
                content = {
                    "name": metadata.strip(),
                    "content": content
                }
    return content, history
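
# Illustrative sketch of the tool-call path (the tool name, arguments and the
# exact output format below are assumptions made up for the example):
#
#   out = "get_weather\n```python\ntool_call(city='Beijing')\n```"
#   content, new_history = process_response(out, history=[
#       {"role": "system", "content": "...", "tools": [{"name": "get_weather"}]},
#   ])
#   # content == {"name": "get_weather", "parameters": {"city": "Beijing"}}
#
# For a plain reply (empty metadata line), content is just the stripped text.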
|
|
@torch.inference_mode()
def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    messages = params["messages"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    echo = params.get("echo", True)

    # The last message is the current query; everything before it is history.
    query, role = messages[-1].content, messages[-1].role
    history = [m.dict(exclude_none=True) for m in messages[:-1]]

    inputs = tokenizer.build_chat_input(query, history=history, role=role)
    inputs = inputs.to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    if input_echo_len >= model.config.seq_length:
        raise ValueError(
            f"Input length {input_echo_len} exceeds model sequence length {model.config.seq_length}"
        )

    # Stop on EOS or when the model hands the turn back to the user/observation.
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
        tokenizer.get_command("<|observation|>")
    ]

    gen_kwargs = {
        "max_length": max_new_tokens + input_echo_len,
        "do_sample": True if temperature > 1e-5 else False,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [InvalidScoreLogitsProcessor()],
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    history.append(
        {
            "role": role,
            "content": query
        }
    )

    total_len = 0
    for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        if echo:
            output_ids = total_ids[:-1]
        else:
            output_ids = total_ids[input_echo_len:-1]

        response = tokenizer.decode(output_ids)
        # Skip chunks that end in an incomplete multi-byte character.
        if response and response[-1] != "�":
            yield {
                "text": response,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
                "finish_reason": None,
            }

    # Final chunk with the complete response and a "stop" finish reason.
    ret = {
        "text": response,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
        "finish_reason": "stop",
    }
    yield ret

    gc.collect()
    torch.cuda.empty_cache()
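
# Consumption sketch (illustrative): params["messages"] is expected to hold
# pydantic-style message objects exposing .role, .content and .dict(), e.g.
# the chat message models of the accompanying API server.
#
#   params = {"messages": messages, "temperature": 0.8, "max_tokens": 512, "echo": False}
#   for chunk in generate_stream_chatglm3(model, tokenizer, params):
#       print(chunk["text"], chunk["finish_reason"])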
|
|
def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    # Drain the streaming generator and return only its final chunk, which
    # carries the full response text and usage statistics.
    response = None
    for response in generate_stream_chatglm3(model, tokenizer, params):
        pass
    return response
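
# Non-streaming usage sketch (illustrative; assumes a CUDA-capable machine,
# the example "THUDM/chatglm3-6b" checkpoint, and messages built as in the
# streaming sketch above):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
#   model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
#   result = generate_chatglm3(model, tokenizer, {"messages": messages, "max_tokens": 128})
#   print(result["text"])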
|
|