import os compiled_mode = os.getenv("COMPILER_MODE") == "1" ci_env = os.getenv("CI_ENV") == "1" def get_abs_err(x, y): return (x.detach()-y.detach()).flatten().abs().max().item() def get_err_ratio(x, y): err = (x-y).flatten().square().mean().sqrt().item() base = (x).flatten().square().mean().sqrt().item() return err / (base + 1e-15) def assert_close(prefix, ref, tri, ratio, warning=False): msg = f"{prefix} diff: {get_abs_err(ref, tri):.6f} ratio: {get_err_ratio(ref, tri):.6f}" print(msg) error_rate = get_err_ratio(ref, tri) if warning or str(prefix).strip().lower() == "dh0" or (ci_env and error_rate < 0.01): if error_rate > ratio: import warnings warnings.warn(msg) else: assert error_rate < ratio, msg