# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from torchtitan.models.llama3 import Transformer, TransformerModelArgs from .simple_fsdp import disable_data_parallel class SimpleFSDPTransformer(Transformer): def __init__(self, model_args: TransformerModelArgs): super().__init__(model_args) self.init_weights() def init_weights(self, *args, **kwargs): with disable_data_parallel(): super().init_weights(*args, **kwargs)