'''
@File    :   base_model.py
@Time    :   2021/10/01 22:40:33
@Author  :   Ming Ding
@Contact :   [email protected]
'''

from functools import partial
import os
import sys
import math
import random
import torch
import inspect
import warnings
import argparse
from sat.model.registry import model_registry, MetaModel

from sat.model.transformer import BaseTransformer, standard_attention
from sat.arguments import update_args_with_file, overwrite_args_by_dict, set_random_seed
from sat.training.model_io import load_checkpoint
from sat.helpers import print_rank0

from sat.transformer_defaults import HOOKS_DEFAULT, ARGS_DEFAULT
from sat.resources import auto_create
from sat.mpu.initialize import get_node_rank, get_model_parallel_rank, destroy_model_parallel, initialize_model_parallel
from sat.mpu.operation import mp_split_model_rank0, mp_split_model_receive, mp_merge_model_rank0, mp_merge_model_send
from sat.arguments import reset_random_seed


def non_conflict(func):
    '''Mark a hook function as non-conflict,
    so that it can be compatible with any already defined hooks.
    e.g. PrefixTuningMixin.attention_fn
    '''
    func.non_conflict = True
    return func


def replacable(func):
    '''Mark a hook function as replacable,
    so that it can be replaced by mixins added after it.
    e.g. FP32AttentionMixin.attention_fn
    '''
    func.replacable = True
    return func


class BaseMixin(torch.nn.Module):
    non_conflict = non_conflict
    replacable = replacable

    def __init__(self):
        super(BaseMixin, self).__init__()
        # define new params introduced by this mixin here

    def reinit(self, parent_model=None):
        # Reload the initial params of self.modules() here if needed.
        # Other mixins are reachable through parent_model.get_mixin(name).
        pass

    # Hook functions (names taken from HOOKS_DEFAULT, e.g. attention_fn) can also be
    # defined here; BaseModel.collect_hooks_() picks them up and uses them to override
    # the transformer defaults.
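
# A minimal sketch (not part of the library) of a mixin that uses the decorators above.
# The hook name `attention_fn` and the fallback `standard_attention` come from this file;
# `ScaleAttentionMixin` and `scale` are made-up illustration names. Because the hook is
# marked @non_conflict and accepts `old_impl`, collect_hooks_() chains it with whatever
# attention_fn was registered before it instead of raising a conflict.
#
#     class ScaleAttentionMixin(BaseMixin):
#         def __init__(self, scale=1.0):
#             super().__init__()
#             self.scale = scale
#
#         @non_conflict
#         def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
#             # pre-scale the queries, then defer to the previous implementation
#             return old_impl(q * self.scale, k, v, mask, dropout_fn, **kw_args)
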

class BaseModel(torch.nn.Module, metaclass=MetaModel):
    def __init__(self, args, transformer=None, params_dtype=torch.float, **kwargs):
        super(BaseModel, self).__init__()
        self.mixins = torch.nn.ModuleDict()
        self.collect_hooks_()
        if transformer is not None:
            self.transformer = transformer
        else:
            # make sure model parallel is initialized (size 1 by default) before building
            from sat.arguments import _simple_init
            success = _simple_init(model_parallel_size=args.model_parallel_size, seed=args.seed if hasattr(args, 'seed') else 1234)
            # gather the optional transformer args, falling back to their defaults
            args_dict = {k: (getattr(args, v[0]) if hasattr(args, v[0]) else v[1]) for k, v in ARGS_DEFAULT.items()}
            self.transformer = BaseTransformer(
                num_layers=args.num_layers,
                vocab_size=args.vocab_size,
                hidden_size=args.hidden_size,
                num_attention_heads=args.num_attention_heads,
                max_sequence_length=args.max_sequence_length,
                layernorm_order=args.layernorm_order,
                **args_dict,
                hooks=self.hooks,
                params_dtype=params_dtype,
                skip_init=args.skip_init,
                device=torch.cuda.current_device() if hasattr(args, 'use_gpu_initialization') and args.use_gpu_initialization else torch.device('cpu'),
                **kwargs
            )

    def reinit(self, mixin_names=None):
        # Re-initialize the listed mixins (all mixins if mixin_names is None),
        # e.g. after loading a pretrained checkpoint.
        for k, m in self.mixins.items():
            if mixin_names is None or k in mixin_names:
                m.reinit(self)

    def add_mixin(self, name, new_mixin, reinit=False):
        assert name not in self.mixins
        assert isinstance(new_mixin, BaseMixin)

        self.mixins[name] = new_mixin
        # use object.__setattr__ so the transformer is not registered as a submodule of the mixin
        object.__setattr__(new_mixin, 'transformer', self.transformer)

        self.collect_hooks_()
        if reinit:
            new_mixin.reinit(self)

    def del_mixin(self, name):
        assert name in self.mixins
        del self.mixins[name]
        self.collect_hooks_()

    def get_mixin(self, name):
        return self.mixins[name]
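
    # A minimal sketch (hypothetical class names) of the intended usage: subclass
    # BaseModel and attach mixins; their hook functions then override the transformer
    # defaults collected by collect_hooks_(). ScaleAttentionMixin refers to the made-up
    # example mixin sketched above this class.
    #
    #     class MyModel(BaseModel):
    #         def __init__(self, args, transformer=None, **kwargs):
    #             super().__init__(args, transformer=transformer, **kwargs)
    #             self.add_mixin('scale-attn', ScaleAttentionMixin(scale=0.5))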

    def forward(self, *args, **kwargs):
        # Hooks may have changed since the transformer was built (e.g. via add_mixin),
        # so refresh the transformer's hook table before every forward pass.
        self.transformer.hooks.clear()
        self.transformer.hooks.update(self.hooks)
        return self.transformer(*args, **kwargs)

    def collect_hooks_(self):
        names = list(HOOKS_DEFAULT.keys())
        hooks = {}
        hook_origins = {}
        for name in names:
            if hasattr(self, name):
                hooks[name] = getattr(self, name)
                hook_origins[name] = 'model'
            for mixin_name, m in self.mixins.items():
                if hasattr(m, name):
                    if hasattr(getattr(m, name), 'non_conflict'):
                        # a non-conflict hook must take the previous implementation as `old_impl`
                        signature = inspect.signature(getattr(m, name))
                        if 'old_impl' not in signature.parameters:
                            raise ValueError(f'Hook {name} at {mixin_name} must accept old_impl as an argument.')
                        # chain with the previously registered hook, or fall back to the default
                        if name in hooks:
                            old_impl = hooks[name]
                        elif name == 'attention_fn':  # the default attention_fn is a plain function, not a method
                            old_impl = HOOKS_DEFAULT[name]
                        else:
                            old_impl = partial(HOOKS_DEFAULT[name], self)
                        old_origin = hook_origins.get(name, 'default')
                        hooks[name] = partial(getattr(m, name), old_impl=old_impl)
                        hook_origins[name] = mixin_name + ' -> ' + old_origin
                    elif name in hooks and not hasattr(hooks[name], 'replacable'):
                        raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
                    else:
                        if name in hooks and hasattr(hooks[name], 'replacable'):
                            warnings.warn(f'Hook {name} at {mixin_name} replaces {hook_origins[name]}.')
                        hooks[name] = getattr(m, name)
                        hook_origins[name] = mixin_name

        self.hooks = hooks
        self.hook_origins = hook_origins
        return hooks
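
    # For example (a sketch, not executed here): if the model itself defines attention_fn
    # and a mixin registered as 'prefix-tuning' adds a @non_conflict attention_fn, then
    #     self.hook_origins['attention_fn'] == 'prefix-tuning -> model'
    # i.e. the mixin's hook wraps the model's own hook through `old_impl`.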

    def disable_untrainable_params(self):
        # Override in subclasses to freeze parameters that should not be trained.
        pass

    @classmethod
    def add_model_specific_args(cls, parser):
        # Override in subclasses to register model-specific command-line arguments.
        return parser

    @classmethod
    def from_pretrained_base(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, overwrite_args={}, **kwargs):
        '''Load a pretrained checkpoint of the current model class.
        Args:
            name: The identifier of the pretrained model, or a local model directory.
            args: argparse.Namespace. The loaded args are merged into it.
                If None, a new model-only namespace is created from the defaults.
            home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
            url: the url of the model. Default: SAT_URL.
            prefix: the prefix of the checkpoint keys. Default: ''.
        Returns:
            model: the loaded model.
            args: the loaded args.
        '''
        if os.path.exists(name) and os.path.isdir(name):
            model_path = name
        else:
            model_path = auto_create(name, path=home_path, url=url)

        if args is None:
            args = cls.get_args()
        args = update_args_with_file(args, path=os.path.join(model_path, 'model_config.json'))
        args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
        specific_iteration = kwargs.pop('specific_iteration', None)
        model = get_model(args, cls, **kwargs)
        if not build_only:
            load_checkpoint(model, args, load_path=model_path, prefix=prefix, specific_iteration=specific_iteration)
        return model, args

    @classmethod
    def from_pretrained(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, use_node_group=True, overwrite_args={}, **kwargs):
        if build_only or 'model_parallel_size' not in overwrite_args:
            return cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=build_only, overwrite_args=overwrite_args, **kwargs)
        else:
            new_model_parallel_size = overwrite_args['model_parallel_size']
            if new_model_parallel_size != 1 or (new_model_parallel_size == 1 and args.model_parallel_size == 1):
                # Split an mp_size=1 checkpoint into new_model_parallel_size parts:
                # build the sharded model first, then rank 0 (of each node group or of
                # each model-parallel group) loads the full model on CPU and scatters the weights.
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                local_rank = get_node_rank() if use_node_group else get_model_parallel_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size % new_model_parallel_size == 0, "world size should be a multiple of the new model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if local_rank == 0:
                    args.skip_init = True
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args.pop('model_parallel_size')
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                    if args_.model_parallel_size != 1:
                        raise Exception("We do not support overwriting model_parallel_size when the original model_parallel_size != 1. Try merging the model first with `from_pretrained(xxx, overwrite_args={'model_parallel_size': 1})` if you still want to change model_parallel_size!")
                if hasattr(args, 'mode') and args.mode == 'inference':
                    torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(new_model_parallel_size)
                if local_rank == 0:
                    mp_split_model_rank0(model, model_full, use_node_group=use_node_group)
                    del model_full
                else:
                    mp_split_model_receive(model, use_node_group=use_node_group)
                reset_random_seed(6)
            else:
                # Merge an mp_size>1 checkpoint back into a single full model (mp_size=1):
                # every rank loads its own shard, then rank 0 gathers the weights into model_full.
                overwrite_args.pop('model_parallel_size')
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size == model_args.model_parallel_size, "world size should be equal to model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if rank == 0:
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args['model_parallel_size'] = 1
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(model_args.model_parallel_size)
                if rank == 0:
                    mp_merge_model_rank0(model, model_full)
                    model, model_args = model_full, args_
                else:
                    mp_merge_model_send(model)
                    model_args.model_parallel_size = 1
                destroy_model_parallel()
                initialize_model_parallel(1)
            return model, model_args
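
    # A minimal usage sketch (the model name and class are only examples): load a
    # pretrained checkpoint and reshard it from model_parallel_size 1 to 2 across the
    # launched ranks.
    #
    #     model, args = MyModel.from_pretrained(
    #         'my-pretrained-model',
    #         overwrite_args={'model_parallel_size': 2},
    #     )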

    @classmethod
    def list_avail_args(cls, print=True):
        '''List all available args of the current model.'''
        parser = argparse.ArgumentParser()
        from sat.arguments import add_model_config_args
        add_model_config_args(parser)

        if hasattr(cls, 'add_model_specific_args'):
            cls.add_model_specific_args(parser)
        if print:
            from sat.helpers import print_parser
            print_parser(parser)
        return parser

    @classmethod
    def get_args(cls, **kwargs):
        '''Get the default args of the current model.
        Args:
            **kwargs: override the default args.
        Returns:
            args: the parsed args.
        '''
        parser = cls.list_avail_args(print=False)
        # parse an empty command line to get the defaults, then apply the overrides
        args = parser.parse_args([])
        for k, v in kwargs.items():
            if hasattr(args, k) or k in ['fp16']:
                setattr(args, k, v)
            else:
                print_rank0(f'warning: Unknown arg {k} for class {cls.__name__}.', level='DEBUG')
                setattr(args, k, v)
        return args
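
    # A minimal sketch: build the default args and override a few fields for a tiny test
    # model (the class name and values here are only an illustration).
    #
    #     args = MyModel.get_args(num_layers=2, hidden_size=128,
    #                             num_attention_heads=4, vocab_size=1000)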


class AutoModel():
    @classmethod
    def from_pretrained_base(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, overwrite_args={}, **kwargs):
        '''Automatically find the model class and instantiate it. Auto-downloads the checkpoint if needed.
        Args:
            name: The identifier of the pretrained model, or a local model directory.
            args: argparse.Namespace. The loaded args are merged into it.
            home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
            url: manually specified url for the `name` model.
        '''
        if os.path.exists(name) and os.path.isdir(name):
            model_path = name
        else:
            model_path = auto_create(name, path=home_path, url=url)
        if args is None:
            args = argparse.Namespace()
            null_args = True
        else:
            null_args = False
        args = update_args_with_file(args, path=os.path.join(model_path, 'model_config.json'))
        args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
        if not hasattr(args, 'model_class'):
            raise ValueError('model_config.json must have key "model_class" for AutoModel.from_pretrained.')
        model_cls = model_registry.get(args.model_class)
        if null_args:
            # fill in the model class's default args that are not specified in model_config.json
            model_default_args = model_cls.get_args()
            for k, v in model_default_args.__dict__.items():
                if not hasattr(args, k):
                    setattr(args, k, v)
        model = get_model(args, model_cls, **kwargs)
        if not build_only:
            load_checkpoint(model, args, load_path=model_path, prefix=prefix)
        return model, args

    @classmethod
    def from_pretrained(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, use_node_group=True, overwrite_args={}, **kwargs):
        if build_only or 'model_parallel_size' not in overwrite_args:
            return cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=build_only, overwrite_args=overwrite_args, **kwargs)
        else:
            new_model_parallel_size = overwrite_args['model_parallel_size']
            if new_model_parallel_size != 1 or (new_model_parallel_size == 1 and args.model_parallel_size == 1):
                # Split an mp_size=1 checkpoint into new_model_parallel_size parts:
                # build the sharded model first, then rank 0 (of each node group or of
                # each model-parallel group) loads the full model on CPU and scatters the weights.
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                local_rank = get_node_rank() if use_node_group else get_model_parallel_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size % new_model_parallel_size == 0, "world size should be a multiple of the new model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if local_rank == 0:
                    args.skip_init = True
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args.pop('model_parallel_size')
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                    if args_.model_parallel_size != 1:
                        raise Exception("We do not support overwriting model_parallel_size when the original model_parallel_size != 1. Try merging the model first with `from_pretrained(xxx, overwrite_args={'model_parallel_size': 1})` if you still want to change model_parallel_size!")
                if hasattr(args, 'mode') and args.mode == 'inference':
                    torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(new_model_parallel_size)
                if local_rank == 0:
                    mp_split_model_rank0(model, model_full, use_node_group=use_node_group)
                    del model_full
                else:
                    mp_split_model_receive(model, use_node_group=use_node_group)
                reset_random_seed(6)
            else:
                # Merge an mp_size>1 checkpoint back into a single full model (mp_size=1):
                # every rank loads its own shard, then rank 0 gathers the weights into model_full.
                overwrite_args.pop('model_parallel_size')
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size == model_args.model_parallel_size, "world size should be equal to model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if rank == 0:
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args['model_parallel_size'] = 1
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(model_args.model_parallel_size)
                if rank == 0:
                    mp_merge_model_rank0(model, model_full)
                    model, model_args = model_full, args_
                else:
                    mp_merge_model_send(model)
                    model_args.model_parallel_size = 1
                destroy_model_parallel()
                initialize_model_parallel(1)
            return model, model_args

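
# A minimal usage sketch (the checkpoint name is only an example): AutoModel reads
# `model_class` from model_config.json, looks the class up in model_registry, and
# builds and loads it without knowing the concrete class in advance.
#
#     model, args = AutoModel.from_pretrained('bert-base-uncased')
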
def get_model(args, model_cls, **kwargs):
    """Build the model."""
    import torch
    from sat.helpers import print_rank0, print_all
    from sat import mpu

    print_rank0(f'building {model_cls.__name__} model ...')
    if 'params_dtype' not in kwargs:
        # infer the parameter dtype from the args
        if hasattr(args, 'fp16') and args.fp16:
            params_dtype = torch.half
        elif hasattr(args, 'bf16') and args.bf16:
            params_dtype = torch.bfloat16
        else:
            params_dtype = torch.float32
    else:
        params_dtype = kwargs.pop('params_dtype')

    from sat.helpers import check_if_zero3
    if check_if_zero3(args):
        import deepspeed
        # ZeRO-3 partitions parameters at construction time
        with deepspeed.zero.Init():
            model = model_cls(args, params_dtype=params_dtype, **kwargs)
    else:
        model = model_cls(args, params_dtype=params_dtype, **kwargs)

    if mpu.get_data_parallel_rank() == 0:
        print_all(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    if hasattr(args, 'fp16') and args.fp16:
        model.half()
    elif hasattr(args, 'bf16') and args.bf16:
        model.bfloat16()

    try:
        if not hasattr(args, 'device'):
            args.device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        model = model.to(args.device)
    except Exception as e:
        print_all(e)

    return model