Upload base_model.py with huggingface_hub
- base_model.py +439 -0
base_model.py
ADDED
@@ -0,0 +1,439 @@
# -*- encoding: utf-8 -*-
'''
@File    :   base_model.py
@Time    :   2021/10/01 22:40:33
@Author  :   Ming Ding
@Contact :   [email protected]
'''

# here put the import lib
from functools import partial
import os
import sys
import math
import random
import torch
import inspect
import warnings
import argparse
from sat.model.registry import model_registry, MetaModel

from sat.model.transformer import BaseTransformer, standard_attention
from sat.arguments import update_args_with_file, overwrite_args_by_dict, set_random_seed
from sat.training.model_io import load_checkpoint
from sat.helpers import print_rank0

from sat.transformer_defaults import HOOKS_DEFAULT, ARGS_DEFAULT
from sat.resources import auto_create
from sat.mpu.initialize import get_node_rank, get_model_parallel_rank, destroy_model_parallel, initialize_model_parallel
from sat.mpu.operation import mp_split_model_rank0, mp_split_model_receive, mp_merge_model_rank0, mp_merge_model_send
from sat.arguments import reset_random_seed

def non_conflict(func):
    '''mark a hook function as non-conflict,
    so that it can be compatible with any already defined hooks.
    e.g. PrefixTuningMixin.attention_fn
    '''
    func.non_conflict = True
    return func

def replacable(func):
    '''mark a hook function as replacable,
    so that it can be replaced by mixins added after it.
    e.g. FP32AttentionMixin.attention_fn
    '''
    func.replacable = True
    return func

class BaseMixin(torch.nn.Module):
    non_conflict = non_conflict
    replacable = replacable
    def __init__(self):
        super(BaseMixin, self).__init__()
        # define new params

    def reinit(self, parent_model=None):
        # reload the initial params from previously trained modules
        # you can also get access to other mixins through parent_model.get_mixin().
        pass

    # Hook functions can be defined here.
    # A hook, if default or replacable, can be overridden by mixins added after it.
    # A hook can be augmented by non_conflict hooks added after it.
    # default -> 0~n replacable -> 0~n non_conflict
    # ...

    # If the hook is just a pre- or post- transformation,
    # you can use @non_conflict to mark it,
    # and run `old_impl` to make it compatible with other mixins.
    # E.g.,
    #
    # @non_conflict
    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
    #     new_q, new_k, new_v = pre_hack(q, k, v)
    #     attn_result = old_impl(q, k, v, mask, dropout_fn, **kw_args)
    #     attn_result = post_hack(attn_result)
    #     return attn_result
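
# For illustration only, a custom mixin might look like the following sketch;
# the class name and the `scale` parameter are hypothetical, not part of sat.
#
# class ScaledAttentionMixin(BaseMixin):
#     def __init__(self, scale=1.0):
#         super().__init__()
#         self.scale = scale  # plain attribute; nn.Parameter / submodules would be auto-registered
#
#     @non_conflict
#     def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
#         # pre-scale the queries, then defer to the previously registered implementation
#         return old_impl(q * self.scale, k, v, mask, dropout_fn, **kw_args)
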
class BaseModel(torch.nn.Module, metaclass=MetaModel):
    def __init__(self, args, transformer=None, params_dtype=torch.float, **kwargs):
        super(BaseModel, self).__init__()
        self.mixins = torch.nn.ModuleDict()
        self.collect_hooks_()
        if transformer is not None:
            self.transformer = transformer
        else:
            # check if model-only mode
            from sat.arguments import _simple_init
            success = _simple_init(model_parallel_size=args.model_parallel_size, seed=args.seed if hasattr(args, 'seed') else 1234)

            args_dict = {k: (getattr(args, v[0]) if hasattr(args, v[0]) else v[1]) for k, v in ARGS_DEFAULT.items()}

            self.transformer = BaseTransformer(
                num_layers=args.num_layers,
                vocab_size=args.vocab_size,
                hidden_size=args.hidden_size,
                num_attention_heads=args.num_attention_heads,
                max_sequence_length=args.max_sequence_length,
                layernorm_order=args.layernorm_order,
                **args_dict,
                hooks=self.hooks,
                params_dtype=params_dtype,
                skip_init=args.skip_init,
                device=torch.cuda.current_device() if hasattr(args, 'use_gpu_initialization') and args.use_gpu_initialization else torch.device('cpu'),
                **kwargs
            )

    def reinit(self, mixin_names=None):  # will be called when loading model, None means all
        # if some mixins are loaded, override this function
        for k, m in self.mixins.items():
            if mixin_names is None or k in mixin_names:
                m.reinit(self)

    def add_mixin(self, name, new_mixin, reinit=False):
        assert name not in self.mixins
        assert isinstance(new_mixin, BaseMixin)

        self.mixins[name] = new_mixin  # will auto-register parameters
        object.__setattr__(new_mixin, 'transformer', self.transformer)  # cannot use pytorch set_attr

        self.collect_hooks_()
        if reinit:
            new_mixin.reinit(self)  # also pass current mixins

    def del_mixin(self, name):
        assert name in self.mixins
        del self.mixins[name]
        self.collect_hooks_()

    def get_mixin(self, name):
        return self.mixins[name]

    def forward(self, *args, **kwargs):
        # update hooks as the current model (overridden forwards)
        # Attention! the transformer might be shared by multiple models
        self.transformer.hooks.clear()
        self.transformer.hooks.update(self.hooks)
        return self.transformer(*args, **kwargs)

    def collect_hooks_(self):
        names = list(HOOKS_DEFAULT.keys())
        hooks = {}
        hook_origins = {}
        for name in names:
            if hasattr(self, name):
                hooks[name] = getattr(self, name)
                hook_origins[name] = 'model'

            for mixin_name, m in self.mixins.items():
                if hasattr(m, name):
                    if hasattr(getattr(m, name), 'non_conflict'):
                        # check getattr(m, name), which must accept old_impl as an argument
                        signature = inspect.signature(getattr(m, name))
                        if 'old_impl' not in signature.parameters:
                            raise ValueError(f'Hook {name} at {mixin_name} must accept old_impl as an argument.')
                        # -------------
                        if name in hooks:
                            old_impl = hooks[name]
                        elif name == 'attention_fn':  # the only hook without self
                            old_impl = HOOKS_DEFAULT[name]
                        else:
                            old_impl = partial(HOOKS_DEFAULT[name], self)  # relax! `partial` does not affect the signature
                        old_origin = hook_origins.get(name, 'default')
                        hooks[name] = partial(getattr(m, name), old_impl=old_impl)
                        hook_origins[name] = mixin_name + ' -> ' + old_origin
                    elif name in hooks and not hasattr(hooks[name], 'replacable'):  # if this hook name is already registered
                        raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
                    else:  # new hook
                        if name in hooks and hasattr(hooks[name], 'replacable'):
                            warnings.warn(f'Hook {name} at {mixin_name} replaces {hook_origins[name]}.')
                        hooks[name] = getattr(m, name)
                        hook_origins[name] = mixin_name

        self.hooks = hooks
        self.hook_origins = hook_origins
        return hooks

    def disable_untrainable_params(self):
        pass

    @classmethod
    def add_model_specific_args(cls, parser):
        # recorded in arguments.py: add_model_config_args
        return parser

    @classmethod
    def from_pretrained_base(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, overwrite_args={}, **kwargs):
        '''Load a pretrained checkpoint of the current model.
        Args:
            name: The identifier of the pretrained model.
            args: Namespace. The loaded args will be added into it. None will create a new model-only one with defaults.
            home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
            url: the url of the model. Default: SAT_URL.
            prefix: the prefix of the checkpoint. Default: ''.
        Returns:
            model: the loaded model.
            args: the loaded args.
        '''
        if os.path.exists(name) and os.path.isdir(name):
            model_path = name
        else:
            model_path = auto_create(name, path=home_path, url=url)
        # create a new args if not provided
        if args is None:
            args = cls.get_args()
        args = update_args_with_file(args, path=os.path.join(model_path, 'model_config.json'))
        args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
        specific_iteration = kwargs.pop('specific_iteration', None)
        model = get_model(args, cls, **kwargs)
        if not build_only:
            load_checkpoint(model, args, load_path=model_path, prefix=prefix, specific_iteration=specific_iteration)
        return model, args
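
    # Usage sketch (MyModel and 'my-checkpoint' are placeholders for a concrete
    # subclass and a local or downloadable checkpoint folder):
    #
    # model, args = MyModel.from_pretrained_base('my-checkpoint')
    # resolves the folder (via auto_create if needed), merges model_config.json
    # into args, builds the model with get_model, and loads the weights.
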
    @classmethod
    def from_pretrained(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, use_node_group=True, overwrite_args={}, **kwargs):
        if build_only or 'model_parallel_size' not in overwrite_args:
            return cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=build_only, overwrite_args=overwrite_args, **kwargs)
        else:
            new_model_parallel_size = overwrite_args['model_parallel_size']
            if new_model_parallel_size != 1 or new_model_parallel_size == 1 and args.model_parallel_size == 1:
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                local_rank = get_node_rank() if use_node_group else get_model_parallel_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size % new_model_parallel_size == 0, "world size should be a multiple of the new model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if local_rank == 0:
                    args.skip_init = True
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args.pop('model_parallel_size')
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                    if args_.model_parallel_size != 1:
                        raise Exception("We do not support overwriting model_parallel_size when the original model_parallel_size != 1. Try merging the model using `from_pretrained(xxx, overwrite_args={'model_parallel_size': 1})` first if you still want to change model_parallel_size!")
                if hasattr(args, 'mode') and args.mode == 'inference':  # For multi-node inference, we should prevent rank 0 from eagerly printing some info.
                    torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(new_model_parallel_size)
                if local_rank == 0:
                    mp_split_model_rank0(model, model_full, use_node_group=use_node_group)
                    del model_full
                else:
                    mp_split_model_receive(model, use_node_group=use_node_group)
                reset_random_seed(6)
            else:
                overwrite_args.pop('model_parallel_size')
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size == model_args.model_parallel_size, "world size should be equal to model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if rank == 0:
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args['model_parallel_size'] = 1
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(model_args.model_parallel_size)
                if rank == 0:
                    mp_merge_model_rank0(model, model_full)
                    model, model_args = model_full, args_
                else:
                    mp_merge_model_send(model)
                    model_args.model_parallel_size = 1
                destroy_model_parallel()
                initialize_model_parallel(1)
            return model, model_args

    @classmethod
    def list_avail_args(cls, print=True):
        '''List all available args of the current model.'''
        parser = argparse.ArgumentParser()
        from sat.arguments import add_model_config_args
        add_model_config_args(parser)
        # add args of the current model
        if hasattr(cls, 'add_model_specific_args'):
            cls.add_model_specific_args(parser)
        if print:
            from sat.helpers import print_parser
            print_parser(parser)
        return parser

    @classmethod
    def get_args(cls, **kwargs):
        '''Get the parsed args of the current model.
        Args:
            **kwargs: will override the default args.
        Returns:
            args: the parsed args.
        '''
        parser = cls.list_avail_args(print=False)
        # use parser to parse kwargs
        args = parser.parse_args([])
        for k, v in kwargs.items():
            if hasattr(args, k) or k in ['fp16']:  # non-arch args that still affect building models
                setattr(args, k, v)
            else:
                print_rank0(f'warning: Unknown arg {k} for class {cls.__name__}.', level='DEBUG')
                setattr(args, k, v)
        return args
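
# Usage sketch for model-only mode (all hyper-parameter values are placeholders,
# SomeMixin stands for any BaseMixin subclass, and depending on the sat version
# further fields such as skip_init may need to be passed as well):
#
# args = BaseModel.get_args(num_layers=2, hidden_size=128, num_attention_heads=4,
#                           vocab_size=1000, max_sequence_length=256)
# model = BaseModel(args)
# model.add_mixin('example', SomeMixin(), reinit=True)
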
class AutoModel():
    @classmethod
    def from_pretrained_base(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, overwrite_args={}, **kwargs):
        '''Automatically find the class and instantiate it. Auto-download.
        Args:
            name: The identifier of the pretrained model.
            args: Namespace. The loaded args will be added into it.
            home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
            url: manually specified url for the `name` model.
        '''
        if os.path.exists(name) and os.path.isdir(name):
            model_path = name
        else:
            model_path = auto_create(name, path=home_path, url=url)
        if args is None:
            args = argparse.Namespace()  # null, fill later
            null_args = True
        else:
            null_args = False
        args = update_args_with_file(args, path=os.path.join(model_path, 'model_config.json'))
        args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
        if not hasattr(args, 'model_class'):
            raise ValueError('model_config.json must have key "model_class" for AutoModel.from_pretrained.')
        model_cls = model_registry.get(args.model_class)
        if null_args:
            # fill args with default values, if not provided
            model_default_args = model_cls.get_args()
            for k, v in model_default_args.__dict__.items():
                if not hasattr(args, k):
                    setattr(args, k, v)
        model = get_model(args, model_cls, **kwargs)
        if not build_only:
            load_checkpoint(model, args, load_path=model_path, prefix=prefix)
        return model, args

    @classmethod
    def from_pretrained(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, use_node_group=True, overwrite_args={}, **kwargs):
        if build_only or 'model_parallel_size' not in overwrite_args:
            return cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=build_only, overwrite_args=overwrite_args, **kwargs)
        else:
            new_model_parallel_size = overwrite_args['model_parallel_size']
            if new_model_parallel_size != 1 or new_model_parallel_size == 1 and args.model_parallel_size == 1:
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                local_rank = get_node_rank() if use_node_group else get_model_parallel_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size % new_model_parallel_size == 0, "world size should be a multiple of the new model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if local_rank == 0:
                    args.skip_init = True
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args.pop('model_parallel_size')
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                    if args_.model_parallel_size != 1:
                        raise Exception("We do not support overwriting model_parallel_size when the original model_parallel_size != 1. Try merging the model using `from_pretrained(xxx, overwrite_args={'model_parallel_size': 1})` first if you still want to change model_parallel_size!")
                if hasattr(args, 'mode') and args.mode == 'inference':  # For multi-node inference, we should prevent rank 0 from eagerly printing some info.
                    torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(new_model_parallel_size)
                if local_rank == 0:
                    mp_split_model_rank0(model, model_full, use_node_group=use_node_group)
                    del model_full
                else:
                    mp_split_model_receive(model, use_node_group=use_node_group)
                reset_random_seed(6)
            else:
                overwrite_args.pop('model_parallel_size')
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size == model_args.model_parallel_size, "world size should be equal to model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if rank == 0:
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args['model_parallel_size'] = 1
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(model_args.model_parallel_size)
                if rank == 0:
                    mp_merge_model_rank0(model, model_full)
                    model, model_args = model_full, args_
                else:
                    mp_merge_model_send(model)
                    model_args.model_parallel_size = 1
                destroy_model_parallel()
                initialize_model_parallel(1)
            return model, model_args

def get_model(args, model_cls, **kwargs):
    """Build the model."""
    import torch
    from sat.helpers import print_rank0, print_all
    from sat import mpu

    print_rank0(f'building {model_cls.__name__} model ...')
    if 'params_dtype' not in kwargs:
        if hasattr(args, 'fp16') and args.fp16:
            params_dtype = torch.half
        elif hasattr(args, 'bf16') and args.bf16:
            params_dtype = torch.bfloat16
        else:
            params_dtype = torch.float32
    else:
        # pop params_dtype from kwargs
        params_dtype = kwargs.pop('params_dtype')

    from sat.helpers import check_if_zero3
    if check_if_zero3(args):
        import deepspeed
        with deepspeed.zero.Init():
            model = model_cls(args, params_dtype=params_dtype, **kwargs)
    else:
        model = model_cls(args, params_dtype=params_dtype, **kwargs)

    if mpu.get_data_parallel_rank() == 0:
        print_all(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    if hasattr(args, 'fp16') and args.fp16:
        model.half()
    elif hasattr(args, 'bf16') and args.bf16:
        model.bfloat16()

    try:  # TODO: is this useful?
        if not hasattr(args, 'device'):
            args.device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        model = model.to(args.device)
    except Exception as e:
        print_all(e)

    return model
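
# Usage sketch (the checkpoint name and path below are placeholders):
#
# model, args = AutoModel.from_pretrained('my-pretrained-model', home_path='/path/to/checkpoints')
# reads model_config.json, looks up its `model_class` in model_registry,
# builds the model with get_model, and loads the checkpoint weights.
#
# With torch.distributed and sat's model-parallel groups already initialized, an
# mp=1 checkpoint can be re-partitioned on the fly, e.g.:
# model, args = AutoModel.from_pretrained('my-pretrained-model', overwrite_args={'model_parallel_size': 2})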