from abc import ABC, abstractmethod
from typing import Any, Dict, Sequence, Union

import torch


class DiffCase(ABC):
    """One parity test case: a naive reference implementation vs. a CUDA one.

    Objects returned by make_naive/make_cuda must expose .parameters()
    (e.g. be nn.Module instances) so their gradients can be compared.
    """

    @abstractmethod
    def build_inputs(
        self, hidden: int, bs: int, sl: int, dtype: torch.dtype, eps: float
    ) -> Dict[str, Any]: ...

    @abstractmethod
    def make_naive(self, I: Dict[str, Any]) -> Any: ...

    @abstractmethod
    def make_cuda(self, I: Dict[str, Any]) -> Any: ...

    @abstractmethod
    def forward(
        self, obj: Any, I: Dict[str, Any]
    ) -> Union[torch.Tensor, Sequence[torch.Tensor]]: ...

    @abstractmethod
    def grad_inputs(self, I: Dict[str, Any]) -> Sequence[torch.Tensor]: ...


def _clone_payload(d: Dict[str, Any], device) -> Dict[str, Any]:
    """Deep-copy tensor entries onto `device`, preserving requires_grad.

    Non-tensor entries (eps, flags, ...) are passed through unchanged.
    """
    out: Dict[str, Any] = {}
    for k, v in d.items():
        if isinstance(v, torch.Tensor):
            t = v.detach().clone().to(device)
            t.requires_grad_(v.requires_grad)
            out[k] = t
        else:
            out[k] = v
    return out


def _unit_grad_like(y: torch.Tensor) -> torch.Tensor:
    """Random upstream gradient with unit norm, so backward tolerances are
    comparable across output shapes. The n == 0 branch only guards against
    the (practically impossible) all-zero draw, where normalizing would
    divide by zero."""
    g = torch.randn_like(y)
    n = g.norm()
    return g if n == 0 else g / n


def calculate_diff(
    case: DiffCase,
    *,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    dtype: torch.dtype = torch.bfloat16,
    eps: float = 1e-6,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    device: str = "cuda",
) -> None:
    """Check that the naive and CUDA implementations agree in forward outputs
    and in gradients w.r.t. both the declared inputs and module parameters."""
    base = case.build_inputs(hidden_size, batch_size, seq_len, dtype, eps)
    # Each side gets its own clone so neither graph aliases the other's leaves.
    I_n = _clone_payload(base, device)
    I_c = _clone_payload(base, device)

    obj_n = case.make_naive(I_n)
    obj_c = case.make_cuda(I_c)

    y_n = case.forward(obj_n, I_n)
    y_c = case.forward(obj_c, I_c)
    torch.testing.assert_close(y_n, y_c, atol=atol, rtol=rtol)

    # Gradients are compared positionally, so grad_inputs() and .parameters()
    # must enumerate tensors in the same order on both sides.
    gin_n = list(case.grad_inputs(I_n)) + list(obj_n.parameters())
    gin_c = list(case.grad_inputs(I_c)) + list(obj_c.parameters())

    # Draw the upstream gradient once and feed the same values to both graphs.
    if isinstance(y_n, torch.Tensor):
        g = [_unit_grad_like(y_n).to(device)]
    else:
        g = [_unit_grad_like(r).to(device) for r in y_n]

    ng = torch.autograd.grad(y_n, gin_n, g, retain_graph=False, create_graph=False, allow_unused=False)
    cg = torch.autograd.grad(y_c, gin_c, g, retain_graph=False, create_graph=False, allow_unused=False)
    torch.testing.assert_close(ng, cg, atol=atol, rtol=rtol)
    print("✅ forward + backward match")
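

# --- Illustrative usage (a minimal sketch, not part of the harness) ---
# The following shows one way a concrete DiffCase could look, comparing a
# hand-written LayerNorm reference (upcast to fp32 internally, as naive
# references commonly are) against torch.nn.LayerNorm, which dispatches to
# fused kernels on CUDA. `NaiveLayerNorm` and `LayerNormCase` are hypothetical
# names invented for this example; shared weights travel in the input payload
# so both sides are initialized identically.
if __name__ == "__main__":
    import torch.nn as nn

    class NaiveLayerNorm(nn.Module):
        """Unfused LayerNorm written out in plain PyTorch ops."""

        def __init__(self, weight: torch.Tensor, bias: torch.Tensor, eps: float):
            super().__init__()
            self.weight = nn.Parameter(weight.clone())
            self.bias = nn.Parameter(bias.clone())
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Compute statistics in fp32, then cast back to the input dtype.
            xf = x.float()
            mu = xf.mean(-1, keepdim=True)
            var = xf.var(-1, keepdim=True, unbiased=False)
            y = (xf - mu) / torch.sqrt(var + self.eps)
            return (y * self.weight.float() + self.bias.float()).to(x.dtype)

    class LayerNormCase(DiffCase):
        def build_inputs(self, hidden, bs, sl, dtype, eps):
            return {
                "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
                "weight": torch.randn(hidden, dtype=dtype),
                "bias": torch.randn(hidden, dtype=dtype),
                "eps": eps,
            }

        def make_naive(self, I):
            # The payload is already on the target device at this point.
            return NaiveLayerNorm(I["weight"], I["bias"], I["eps"])

        def make_cuda(self, I):
            m = nn.LayerNorm(I["weight"].numel(), eps=I["eps"], dtype=I["weight"].dtype)
            with torch.no_grad():
                m.weight.copy_(I["weight"])
                m.bias.copy_(I["bias"])
            return m.to(I["x"].device)

        def forward(self, obj, I):
            return obj(I["x"])

        def grad_inputs(self, I):
            return [I["x"]]

    if torch.cuda.is_available():
        calculate_diff(
            LayerNormCase(),
            batch_size=2,
            seq_len=128,
            hidden_size=512,
        )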