|
def extend_instance(obj, mixin):
    """Graft ``mixin`` onto an already-constructed object.

    A new class is created on the fly that keeps the name of ``obj``'s
    current class and inherits from ``(mixin, current class)``; ``obj`` is
    then re-pointed at that class in place. Because the mixin comes first in
    the bases, its attributes take precedence over the original class's.

    Args:
        obj: The instance to modify (mutated in place).
        mixin: The class whose behavior should be layered on top.

    Returns:
        None.
    """
    original_cls = obj.__class__
    # Mixin goes first so that it wins method resolution over the original.
    extended_cls = type(original_cls.__name__, (mixin, original_cls), {})
    obj.__class__ = extended_cls
|
|
|
|
|
def getattr_recursive(obj, att):
    """Resolve a dotted attribute path against ``obj``.

    ``getattr_recursive(obj, "a.b.c")`` gives the same result as
    ``obj.a.b.c``. An empty path returns ``obj`` itself.

    Args:
        obj: The object to start from.
        att: Dot-separated attribute path, e.g. ``"a.b.c"``.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: If any component of the path does not exist.
    """
    current = obj
    remaining = att
    # Peel off one path component per iteration until nothing is left.
    while remaining != "":
        head, _, remaining = remaining.partition(".")
        current = getattr(current, head)
    return current
|
|
|
|
|
def setattr_recursive(obj, att, val):
    """Assign ``val`` to a dotted attribute path on ``obj``.

    ``setattr_recursive(obj, "a.b.c", val)`` has the same effect as
    ``obj.a.b.c = val``. A path without dots sets the attribute directly on
    ``obj``.

    Args:
        obj: The object to start from.
        att: Dot-separated attribute path; the last component is the
            attribute that gets assigned.
        val: The value to store.

    Returns:
        None.
    """
    parent_path, dot, leaf = att.rpartition(".")
    # Only walk down to the parent when the path actually contains a dot.
    target = getattr_recursive(obj, parent_path) if dot else obj
    setattr(target, leaf, val)
|
|
|
|
|
def apply_with_stopping_condition(
    module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
):
    """Walk a module tree top-down, applying ``apply_fn`` to selected nodes.

    The traversal is pre-order: ``module`` is handled before its children.
    If ``stopping_condition(module)`` is true, the function returns
    immediately, so neither ``module`` nor anything beneath it is touched.

    Args:
        module: Root of the tree. Must expose a ``children()`` method that
            yields its direct children (presumably a ``torch.nn.Module`` --
            confirm against callers).
        apply_fn: Callable invoked as ``apply_fn(module, **other_args)``.
        apply_condition: Optional predicate; ``apply_fn`` is only called on
            modules for which it returns a truthy value. ``None`` (the
            default) means apply to every visited module.
        stopping_condition: Optional predicate; a truthy result prunes the
            module and its whole subtree. ``None`` (the default) means never
            prune.
        **other_args: Extra keyword arguments forwarded to ``apply_fn``.

    Returns:
        None.
    """
    # The predicates default to None, so guard before calling them; calling
    # None would raise TypeError.
    if stopping_condition is not None and stopping_condition(module):
        return
    if apply_condition is None or apply_condition(module):
        apply_fn(module, **other_args)
    for child in module.children():
        apply_with_stopping_condition(
            child,
            apply_fn,
            apply_condition=apply_condition,
            stopping_condition=stopping_condition,
            **other_args,
        )
|
|
|
# Lookup table used by _infer_decoder_layers_attr_name: each key is a
# substring that is searched for (case-insensitively) in a model's class
# name, and each value is the dotted attribute path to the decoder's list of
# transformer block layers for that model family.
# NOTE: entries are checked in insertion order and the first match wins, so
# reordering this dict can change which path is returned.
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
    "gptneoxforcausallm": "gpt_neox.layers",
    "mpt": "transformer.blocks",
    "mosaicgpt": "transformer.blocks",
    "internlm2forcausallm": "model.layers",
}
|
|
|
def _infer_decoder_layers_attr_name(model):
    """Guess the dotted attribute path to ``model``'s decoder layers.

    The lowercased class name of ``model`` is searched for each key of
    ``__KNOWN_DECODER_LAYERS_ATTR_NAMES``; the value of the first key found
    as a substring is returned.

    Args:
        model: The language model instance whose class name is inspected.

    Returns:
        The attribute path string registered for the first matching key.

    Raises:
        ValueError: If no known key occurs in the model's class name.
    """
    # Loop-invariant: compute the lowercased class name once.
    model_cls_name = model.__class__.__name__.lower()
    for key, layers_attr_name in __KNOWN_DECODER_LAYERS_ATTR_NAMES.items():
        if key.lower() in model_cls_name:
            return layers_attr_name

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
    )
|
|