from abc import abstractmethod
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from typing import Protocol, TypeAlias

import torch
import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.ft import FTManager
from torchtitan.components.loss import LossFunction
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.metrics import MetricsProcessor
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig


DeviceType = int | str | torch.device


@dataclass
class BaseModelArgs:
    """All ModelArgs should inherit from this class.

    For now this class is used only for type checking, but it gives us a
    single place to add arguments common to all models in the future.
    """

    _enforced: str = "This field is used to enforce all fields have defaults."

    @abstractmethod
    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
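        """Update these model args from the job config and tokenizer,
        e.g., to pick up the vocabulary size or sequence length.
        """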
        pass

    @abstractmethod
    def get_nparams_and_flops(
        self, model: nn.Module, seq_len: int
    ) -> tuple[int, float]:
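        """Return the model's parameter count and an estimate of the
        floating-point operations for a sequence of length ``seq_len``.
        """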
        pass


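# A minimal sketch of a concrete subclass (illustrative only; the name, fields,
# and the ``tokenizer.n_words`` attribute below are hypothetical):
#
#     @dataclass
#     class MyTransformerArgs(BaseModelArgs):
#         dim: int = 512
#         n_layers: int = 8
#         vocab_size: int = -1  # filled in from the tokenizer
#
#         def update_from_config(self, job_config, tokenizer):
#             self.vocab_size = tokenizer.n_words
#
#         def get_nparams_and_flops(self, model, seq_len):
#             nparams = sum(p.numel() for p in model.parameters())
#             # ~6 FLOPs per parameter per token is a common training estimate
#             return nparams, 6 * nparams * seq_len

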
class ModelProtocol(Protocol):
    """Defines the interface for a model class.

    This is used to enforce that all model classes provide the methods
    required by the TorchTitan trainer.
    """

    @classmethod
    def from_model_args(cls, args: BaseModelArgs) -> nn.Module:
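        """Construct a model instance from the given model args."""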
        ...


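# Builder callables carried by a TrainSpec. Most are typed ``Callable[..., X]``
# because different models and components accept different arguments.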
ParallelizeFunction: TypeAlias = Callable[..., nn.Module]
PipeliningFunction: TypeAlias = Callable[
    ..., tuple[_PipelineSchedule, list[nn.Module], bool, bool]
]
DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader]
TokenizerBuilder: TypeAlias = Callable[..., Tokenizer]
MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor]
OptimizersBuilder: TypeAlias = Callable[
    [list[nn.Module], JobConfig, FTManager], OptimizersContainer
]
LRSchedulersBuilder: TypeAlias = Callable[
    [OptimizersContainer, JobConfig], LRSchedulersContainer
]
LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]


@dataclass
class TrainSpec:
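    """Everything the trainer needs to build and train one model family:
    the model class, its named configs, a parallelization (and optional
    pipelining) function, and builders for the optimizer, LR scheduler,
    dataloader, tokenizer, loss, and metrics components.
    """
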
    name: str
    cls: type[nn.Module]
    config: Mapping[str, BaseModelArgs]
    parallelize_fn: ParallelizeFunction
    pipelining_fn: PipeliningFunction | None
    build_optimizers_fn: OptimizersBuilder
    build_lr_schedulers_fn: LRSchedulersBuilder
    build_dataloader_fn: DataLoaderBuilder
    build_tokenizer_fn: TokenizerBuilder | None
    build_loss_fn: LossFunctionBuilder
    build_metrics_processor_fn: MetricsProcessorBuilder | None = None


_train_specs: dict[str, TrainSpec] = {}


def register_train_spec(train_spec: TrainSpec) -> None:
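    """Register a TrainSpec under its name; raise if the name is already taken."""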
    if train_spec.name in _train_specs:
        raise ValueError(f"Model {train_spec.name} is already registered.")

    _train_specs[train_spec.name] = train_spec


def get_train_spec(name: str) -> TrainSpec:
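    """Look up a previously registered TrainSpec by model name."""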
    if name not in _train_specs:
        raise ValueError(f"Model {name} is not registered.")
    return _train_specs[name]


def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None:
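    """Rewrite every registered TrainSpec in place by applying ``func`` to it."""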
    for name, train_spec in _train_specs.items():
        _train_specs[name] = func(train_spec)
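

# A minimal usage sketch of the registry (illustrative only; the "my_model"
# name, ``MyTransformer``, and the builder functions are hypothetical
# stand-ins, not actual torchtitan registrations):
#
#     register_train_spec(
#         TrainSpec(
#             name="my_model",
#             cls=MyTransformer,
#             config={"debug": MyTransformerArgs()},
#             parallelize_fn=parallelize_my_model,
#             pipelining_fn=None,
#             build_optimizers_fn=build_optimizers,
#             build_lr_schedulers_fn=build_lr_schedulers,
#             build_dataloader_fn=build_dataloader,
#             build_tokenizer_fn=build_tokenizer,
#             build_loss_fn=build_cross_entropy_loss,
#         )
#     )
#     spec = get_train_spec("my_model")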