zjuJish committed on
Commit 2d42aa2 · verified · 1 Parent(s): 5655eae

Upload base_model.py with huggingface_hub

Files changed (1):
  1. base_model.py +439 -0
base_model.py ADDED
@@ -0,0 +1,439 @@
# -*- encoding: utf-8 -*-
'''
@File : base_model.py
@Time : 2021/10/01 22:40:33
@Author : Ming Ding
@Contact : [email protected]
'''

# here put the import lib
from functools import partial
import os
import sys
import math
import random
import torch
import inspect
import warnings
import argparse
from sat.model.registry import model_registry, MetaModel

from sat.model.transformer import BaseTransformer, standard_attention
from sat.arguments import update_args_with_file, overwrite_args_by_dict, set_random_seed
from sat.training.model_io import load_checkpoint
from sat.helpers import print_rank0

from sat.transformer_defaults import HOOKS_DEFAULT, ARGS_DEFAULT
from sat.resources import auto_create
from sat.mpu.initialize import get_node_rank, get_model_parallel_rank, destroy_model_parallel, initialize_model_parallel
from sat.mpu.operation import mp_split_model_rank0, mp_split_model_receive, mp_merge_model_rank0, mp_merge_model_send
from sat.arguments import reset_random_seed

def non_conflict(func):
    '''Mark a hook function as non-conflict,
    so that it can be compatible with any already defined hooks.
    e.g. PrefixTuningMixin.attention_fn
    '''
    func.non_conflict = True
    return func

def replacable(func):
    '''Mark a hook function as replacable,
    so that it can be replaced by mixins added after it.
    e.g. FP32AttentionMixin.attention_fn
    '''
    func.replacable = True
    return func

class BaseMixin(torch.nn.Module):
    non_conflict = non_conflict
    replacable = replacable
    def __init__(self):
        super(BaseMixin, self).__init__()
        # define new params

    def reinit(self, parent_model=None):
        # reload the initial params from previously trained modules
        # you can also get access to other mixins through parent_model.get_mixin().
        pass

    # can define hook-functions here
    # a hook, if default or replacable, can be overridden by mixins added after it.
    # a hook can be augmented by non_conflict hooks added after it.
    # default -> 0~n replacable -> 0~n non_conflict
    # ...

    # If the hook is just a pre- or post- transformation,
    # you can use @non_conflict to mark it,
    # and run `old_impl` to make it compatible with other mixins.
    # E.g.,
    #
    # @non_conflict
    # def attention_fn(q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
    #     new_q, new_k, new_v = pre_hack(q, k, v)
    #     attn_result = old_impl(new_q, new_k, new_v, mask, dropout_fn, **kw_args)
    #     attn_result = post_hack(attn_result)
    #     return attn_result

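# Illustrative sketch of a custom mixin; `pre_hack` and `post_hack` are hypothetical
# placeholders for whatever pre-/post-processing the mixin applies:
#
#     class ExampleAttentionMixin(BaseMixin):
#         @non_conflict
#         def attention_fn(self, q, k, v, mask, dropout_fn, old_impl=standard_attention, **kw_args):
#             q, k, v = pre_hack(q, k, v)
#             out = old_impl(q, k, v, mask, dropout_fn, **kw_args)
#             return post_hack(out)
#
#     model.add_mixin('example-attn', ExampleAttentionMixin())
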
class BaseModel(torch.nn.Module, metaclass=MetaModel):
    def __init__(self, args, transformer=None, params_dtype=torch.float, **kwargs):
        super(BaseModel, self).__init__()
        self.mixins = torch.nn.ModuleDict()
        self.collect_hooks_()
        if transformer is not None:
            self.transformer = transformer
        else:
            # check if model-only mode
            from sat.arguments import _simple_init
            success = _simple_init(model_parallel_size=args.model_parallel_size, seed=args.seed if hasattr(args, 'seed') else 1234)

            args_dict = {k: (getattr(args, v[0]) if hasattr(args, v[0]) else v[1]) for k, v in ARGS_DEFAULT.items()}

            self.transformer = BaseTransformer(
                num_layers=args.num_layers,
                vocab_size=args.vocab_size,
                hidden_size=args.hidden_size,
                num_attention_heads=args.num_attention_heads,
                max_sequence_length=args.max_sequence_length,
                layernorm_order=args.layernorm_order,
                **args_dict,
                hooks=self.hooks,
                params_dtype=params_dtype,
                skip_init=args.skip_init,
                device=torch.cuda.current_device() if hasattr(args, 'use_gpu_initialization') and args.use_gpu_initialization else torch.device('cpu'),
                **kwargs
            )

    def reinit(self, mixin_names=None):  # will be called when loading the model; None means all
        # if some mixins are loaded, override this function
        for k, m in self.mixins.items():
            if mixin_names is None or k in mixin_names:
                m.reinit(self)

    def add_mixin(self, name, new_mixin, reinit=False):
        assert name not in self.mixins
        assert isinstance(new_mixin, BaseMixin)

        self.mixins[name] = new_mixin  # will auto-register parameters
        object.__setattr__(new_mixin, 'transformer', self.transformer)  # cannot use pytorch set_attr

        self.collect_hooks_()
        if reinit:
            new_mixin.reinit(self)  # also pass current mixins

    def del_mixin(self, name):
        assert name in self.mixins
        del self.mixins[name]
        self.collect_hooks_()

    def get_mixin(self, name):
        return self.mixins[name]

    def forward(self, *args, **kwargs):
        # update hooks to those of the current model (overridden forwards)
        # Attention! the transformer might be shared by multiple models
        self.transformer.hooks.clear()
        self.transformer.hooks.update(self.hooks)
        return self.transformer(*args, **kwargs)

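    # Hook resolution performed by collect_hooks_ (summary of the code below):
    #   1. a hook defined on the model itself wins over the library default;
    #   2. a mixin hook marked @non_conflict wraps whatever was registered before
    #      it (passed in as `old_impl`), so several of them can stack;
    #   3. an unmarked mixin hook replaces a @replacable hook (with a warning) and
    #      raises a ValueError if it collides with any other existing hook.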
    def collect_hooks_(self):
        names = list(HOOKS_DEFAULT.keys())
        hooks = {}
        hook_origins = {}
        for name in names:
            if hasattr(self, name):
                hooks[name] = getattr(self, name)
                hook_origins[name] = 'model'

            for mixin_name, m in self.mixins.items():
                if hasattr(m, name):
                    if hasattr(getattr(m, name), 'non_conflict'):
                        # check getattr(m, name), which must accept old_impl as an argument
                        signature = inspect.signature(getattr(m, name))
                        if 'old_impl' not in signature.parameters:
                            raise ValueError(f'Hook {name} at {mixin_name} must accept old_impl as an argument.')
                        # -------------
                        if name in hooks:
                            old_impl = hooks[name]
                        elif name == 'attention_fn':  # the only hook without self
                            old_impl = HOOKS_DEFAULT[name]
                        else:
                            old_impl = partial(HOOKS_DEFAULT[name], self)  # relax! `partial` does not affect the signature
                        old_origin = hook_origins.get(name, 'default')
                        hooks[name] = partial(getattr(m, name), old_impl=old_impl)
                        hook_origins[name] = mixin_name + ' -> ' + old_origin
                    elif name in hooks and not hasattr(hooks[name], 'replacable'):  # if this hook name is already registered
                        raise ValueError(f'Hook {name} conflicts at {mixin_name} and {hook_origins[name]}.')
                    else:  # new hook
                        if name in hooks and hasattr(hooks[name], 'replacable'):
                            warnings.warn(f'Hook {name} at {mixin_name} replaces {hook_origins[name]}.')
                        hooks[name] = getattr(m, name)
                        hook_origins[name] = mixin_name

        self.hooks = hooks
        self.hook_origins = hook_origins
        return hooks

    def disable_untrainable_params(self):
        pass

    @classmethod
    def add_model_specific_args(cls, parser):
        # recorded in arguments.py: add_model_config_args
        return parser

    @classmethod
    def from_pretrained_base(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, overwrite_args={}, **kwargs):
        '''Load a pretrained checkpoint of the current model.
        Args:
            name: The identifier of the pretrained model.
            args: Namespace. The loaded args will be added into it; None creates a new model-only namespace with defaults.
            home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
            url: the url of the model. Default: SAT_URL.
            prefix: the prefix of the checkpoint. Default: ''.
        Returns:
            model: the loaded model.
            args: the loaded args.
        '''
        if os.path.exists(name) and os.path.isdir(name):
            model_path = name
        else:
            model_path = auto_create(name, path=home_path, url=url)
        # create a new args if not provided
        if args is None:
            args = cls.get_args()
        args = update_args_with_file(args, path=os.path.join(model_path, 'model_config.json'))
        args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
        specific_iteration = kwargs.pop('specific_iteration', None)
        model = get_model(args, cls, **kwargs)
        if not build_only:
            load_checkpoint(model, args, load_path=model_path, prefix=prefix, specific_iteration=specific_iteration)
        return model, args

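    # from_pretrained extends from_pretrained_base with on-the-fly model-parallel
    # repartition when overwrite_args contains 'model_parallel_size':
    #   * splitting: rank 0 of each group loads the full mp=1 checkpoint on CPU and
    #     scatters the shards via mp_split_model_rank0, while the other ranks
    #     receive them via mp_split_model_receive;
    #   * merging (target size 1, checkpoint trained with mp > 1): every rank loads
    #     its own shard and rank 0 gathers them via mp_merge_model_rank0 /
    #     mp_merge_model_send.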
    @classmethod
    def from_pretrained(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, use_node_group=True, overwrite_args={}, **kwargs):
        if build_only or 'model_parallel_size' not in overwrite_args:
            return cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=build_only, overwrite_args=overwrite_args, **kwargs)
        else:
            new_model_parallel_size = overwrite_args['model_parallel_size']
            if new_model_parallel_size != 1 or (new_model_parallel_size == 1 and args.model_parallel_size == 1):
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                local_rank = get_node_rank() if use_node_group else get_model_parallel_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size % new_model_parallel_size == 0, "world size should be a multiple of the new model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if local_rank == 0:
                    args.skip_init = True
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args.pop('model_parallel_size')
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                    if args_.model_parallel_size != 1:
                        raise Exception("We do not support overwriting model_parallel_size when the original model_parallel_size != 1. Try merging the model using `from_pretrained(xxx, overwrite_args={'model_parallel_size': 1})` first if you still want to change model_parallel_size!")
                if hasattr(args, 'mode') and args.mode == 'inference':  # For multi-node inference, prevent rank 0 from eagerly printing some info.
                    torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(new_model_parallel_size)
                if local_rank == 0:
                    mp_split_model_rank0(model, model_full, use_node_group=use_node_group)
                    del model_full
                else:
                    mp_split_model_receive(model, use_node_group=use_node_group)
                reset_random_seed(6)
            else:
                overwrite_args.pop('model_parallel_size')
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size == model_args.model_parallel_size, "world size should be equal to model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if rank == 0:
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args['model_parallel_size'] = 1
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(model_args.model_parallel_size)
                if rank == 0:
                    mp_merge_model_rank0(model, model_full)
                    model, model_args = model_full, args_
                else:
                    mp_merge_model_send(model)
                    model_args.model_parallel_size = 1
                destroy_model_parallel()
                initialize_model_parallel(1)
            return model, model_args

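    # list_avail_args / get_args build an argparse parser from the generic model
    # config args plus cls.add_model_specific_args, so model-only code can obtain a
    # fully populated args namespace without a command line.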
    @classmethod
    def list_avail_args(cls, print=True):
        '''List all available args of the current model.'''
        parser = argparse.ArgumentParser()
        from sat.arguments import add_model_config_args
        add_model_config_args(parser)
        # add args of the current model
        if hasattr(cls, 'add_model_specific_args'):
            cls.add_model_specific_args(parser)
        if print:
            from sat.helpers import print_parser
            print_parser(parser)
        return parser

    @classmethod
    def get_args(cls, **kwargs):
        '''Get the parsed args of the current model.
        Args:
            **kwargs: will override the default args.
        Returns:
            args: the parsed args.
        '''
        parser = cls.list_avail_args(print=False)
        # use the parser to parse kwargs
        args = parser.parse_args([])
        for k, v in kwargs.items():
            if hasattr(args, k) or k in ['fp16']:  # non-arch args that still affect building models
                setattr(args, k, v)
            else:
                print_rank0(f'warning: Unknown arg {k} for class {cls.__name__}.', level='DEBUG')
                setattr(args, k, v)
        return args

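# AutoModel mirrors BaseModel.from_pretrained(_base) but does not fix the model
# class in advance: it reads the "model_class" field of model_config.json and looks
# the class up in model_registry before building and loading the model.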
class AutoModel():
    @classmethod
    def from_pretrained_base(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, overwrite_args={}, **kwargs):
        '''Automatically find the class and instantiate it. Auto-download.
        Args:
            name: The identifier of the pretrained model.
            args: Namespace. The loaded args will be added into it.
            home_path: the parent folder of the existing `name` model. Default: SAT_HOME.
            url: manually specified url for the `name` model.
        '''
        if os.path.exists(name) and os.path.isdir(name):
            model_path = name
        else:
            model_path = auto_create(name, path=home_path, url=url)
        if args is None:
            args = argparse.Namespace()  # null, fill later
            null_args = True
        else:
            null_args = False
        args = update_args_with_file(args, path=os.path.join(model_path, 'model_config.json'))
        args = overwrite_args_by_dict(args, overwrite_args=overwrite_args)
        if not hasattr(args, 'model_class'):
            raise ValueError('model_config.json must have key "model_class" for AutoModel.from_pretrained.')
        model_cls = model_registry.get(args.model_class)
        if null_args:
            # fill args with default values, if not provided
            model_default_args = model_cls.get_args()
            for k, v in model_default_args.__dict__.items():
                if not hasattr(args, k):
                    setattr(args, k, v)
        model = get_model(args, model_cls, **kwargs)
        if not build_only:
            load_checkpoint(model, args, load_path=model_path, prefix=prefix)
        return model, args

    @classmethod
    def from_pretrained(cls, name, args=None, *, home_path=None, url=None, prefix='', build_only=False, use_node_group=True, overwrite_args={}, **kwargs):
        if build_only or 'model_parallel_size' not in overwrite_args:
            return cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=build_only, overwrite_args=overwrite_args, **kwargs)
        else:
            new_model_parallel_size = overwrite_args['model_parallel_size']
            if new_model_parallel_size != 1 or (new_model_parallel_size == 1 and args.model_parallel_size == 1):
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                local_rank = get_node_rank() if use_node_group else get_model_parallel_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size % new_model_parallel_size == 0, "world size should be a multiple of the new model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if local_rank == 0:
                    args.skip_init = True
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args.pop('model_parallel_size')
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                    if args_.model_parallel_size != 1:
                        raise Exception("We do not support overwriting model_parallel_size when the original model_parallel_size != 1. Try merging the model using `from_pretrained(xxx, overwrite_args={'model_parallel_size': 1})` first if you still want to change model_parallel_size!")
                if hasattr(args, 'mode') and args.mode == 'inference':  # For multi-node inference, prevent rank 0 from eagerly printing some info.
                    torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(new_model_parallel_size)
                if local_rank == 0:
                    mp_split_model_rank0(model, model_full, use_node_group=use_node_group)
                    del model_full
                else:
                    mp_split_model_receive(model, use_node_group=use_node_group)
                reset_random_seed(6)
            else:
                overwrite_args.pop('model_parallel_size')
                model, model_args = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=False, overwrite_args=overwrite_args, **kwargs)
                rank = torch.distributed.get_rank()
                world_size = torch.distributed.get_world_size()
                assert world_size == model_args.model_parallel_size, "world size should be equal to model_parallel_size."
                destroy_model_parallel()
                initialize_model_parallel(1)
                if rank == 0:
                    args.use_gpu_initialization = False
                    args.device = 'cpu'
                    overwrite_args['model_parallel_size'] = 1
                    model_full, args_ = cls.from_pretrained_base(name, args=args, home_path=home_path, url=url, prefix=prefix, build_only=True, overwrite_args=overwrite_args, **kwargs)
                torch.distributed.barrier()
                destroy_model_parallel()
                initialize_model_parallel(model_args.model_parallel_size)
                if rank == 0:
                    mp_merge_model_rank0(model, model_full)
                    model, model_args = model_full, args_
                else:
                    mp_merge_model_send(model)
                    model_args.model_parallel_size = 1
                destroy_model_parallel()
                initialize_model_parallel(1)
            return model, model_args

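# get_model builds an instance of model_cls with the parameter dtype implied by
# args (fp16 -> torch.half, bf16 -> torch.bfloat16, otherwise float32), wraps the
# construction in deepspeed.zero.Init() when ZeRO-3 is detected, casts the module
# accordingly, and finally tries to move it to args.device.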
def get_model(args, model_cls, **kwargs):
    """Build the model."""
    import torch
    from sat.helpers import print_rank0, print_all
    from sat import mpu

    print_rank0(f'building {model_cls.__name__} model ...')
    if 'params_dtype' not in kwargs:
        if hasattr(args, 'fp16') and args.fp16:
            params_dtype = torch.half
        elif hasattr(args, 'bf16') and args.bf16:
            params_dtype = torch.bfloat16
        else:
            params_dtype = torch.float32
    else:
        # pop params_dtype from kwargs
        params_dtype = kwargs.pop('params_dtype')

    from sat.helpers import check_if_zero3
    if check_if_zero3(args):
        import deepspeed
        with deepspeed.zero.Init():
            model = model_cls(args, params_dtype=params_dtype, **kwargs)
    else:
        model = model_cls(args, params_dtype=params_dtype, **kwargs)

    if mpu.get_data_parallel_rank() == 0:
        print_all(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    if hasattr(args, 'fp16') and args.fp16:
        model.half()
    elif hasattr(args, 'bf16') and args.bf16:
        model.bfloat16()

    try:  # TODO: is this useful?
        if not hasattr(args, 'device'):
            args.device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
        model = model.to(args.device)
    except Exception as e:
        print_all(e)

    return model
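
# Minimal usage sketch (illustrative; `some-pretrained-model` is a hypothetical
# checkpoint name resolvable under SAT_HOME or downloadable via SAT_URL, and the
# import path assumes this file lives at sat/model/base_model.py):
#
#     from sat.model.base_model import AutoModel
#     model, args = AutoModel.from_pretrained('some-pretrained-model')
#     model.eval()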