from dataclasses import dataclass, field
from datetime import timedelta
from io import BytesIO
from typing import Any, Dict, List

import torch
from torch.distributed.checkpoint.stateful import Stateful


@dataclass
class TrainState(Stateful):
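    """Bookkeeping for a training run (step/token counters, elapsed wall-clock
    time, and logged loss history), made checkpointable via the Stateful
    protocol from torch.distributed.checkpoint."""
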
    step: int = 0
    skipped_step: int = 0
    token: int = 0
    elapsed: timedelta = timedelta(0)
    global_avg_losses: List[float] = field(default_factory=list)
    global_max_losses: List[float] = field(default_factory=list)
    log_steps: List[int] = field(default_factory=list)

    def state_dict(self) -> Dict[str, Any]:
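        # Serialize the list fields into in-memory buffers with torch.save so the
        # checkpoint stores each list as a single opaque byte blob.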
        global_avg_losses_bytes = BytesIO()
        torch.save(self.global_avg_losses, global_avg_losses_bytes)
        global_max_losses_bytes = BytesIO()
        torch.save(self.global_max_losses, global_max_losses_bytes)
        log_steps_bytes = BytesIO()
        torch.save(self.log_steps, log_steps_bytes)
        return {
            "step": torch.tensor(self.step, dtype=torch.int32),
            "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32),
            "token": torch.tensor(self.token, dtype=torch.int64),
            "elapsed": self.elapsed,
            "global_avg_losses": global_avg_losses_bytes,
            "global_max_losses": global_max_losses_bytes,
            "log_steps": log_steps_bytes,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = state_dict["step"].item()
        # Older checkpoints may not contain "skipped_step"; default to a zero
        # tensor so .item() is always valid (a plain int default has no .item()).
        self.skipped_step = state_dict.get(
            "skipped_step", torch.tensor(0, dtype=torch.int32)
        ).item()
        self.token = state_dict["token"].item()
        self.elapsed = state_dict["elapsed"]
        state_dict["global_avg_losses"].seek(0)
        self.global_avg_losses = torch.load(
            state_dict["global_avg_losses"], weights_only=False
        )
        state_dict["global_max_losses"].seek(0)
        self.global_max_losses = torch.load(
            state_dict["global_max_losses"], weights_only=False
        )
        state_dict["log_steps"].seek(0)
        self.log_steps = torch.load(state_dict["log_steps"], weights_only=False)
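

# Usage sketch (illustrative, not part of this module): assuming checkpoints are
# driven by torch.distributed.checkpoint, which calls state_dict() on save and
# load_state_dict() on load for Stateful objects, a TrainState would typically be
# registered as part of the application state:
#
#     import torch.distributed.checkpoint as dcp
#
#     train_state = TrainState()
#     app_state = {"train_state": train_state}
#     dcp.save(app_state, checkpoint_id="outputs/checkpoint/step-1000")
#     ...
#     dcp.load(app_state, checkpoint_id="outputs/checkpoint/step-1000")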