| import tempfile | |
| import pathlib | |
| import torch | |
class ATensor(torch.Tensor):
    """Trivial ``torch.Tensor`` subclass with no added behavior.

    It is instantiated via ``torch.Tensor._make_subclass`` in
    ``test_lazy_load_subclass`` to check that lazily loaded checkpoints
    reproduce tensor subclasses.
    """
def test_lazy_load_basic(lit_llama):
    """A ``Linear`` state dict read through ``lazy_load`` restores the layer exactly.

    While the ``lazy_load`` context is open, the values are expected to still
    be ``NotYetLoadedTensor`` placeholders. Loading them into a fresh layer
    must reproduce the original layer's outputs.
    """
    import lit_llama.utils

    with tempfile.TemporaryDirectory() as tmpdirname:
        original = torch.nn.Linear(5, 3)
        checkpoint = str(pathlib.Path(tmpdirname) / "test.pt")
        torch.save(original.state_dict(), checkpoint)

        with lit_llama.utils.lazy_load(checkpoint) as sd_lazy:
            first_value = next(iter(sd_lazy.values()))
            # The value should still be a lazy placeholder, judged by its str().
            assert "NotYetLoadedTensor" in str(first_value)
            restored = torch.nn.Linear(5, 3)
            restored.load_state_dict(sd_lazy)

        inputs = torch.randn(2, 5)
        # The restored layer is evaluated first, then the original,
        # matching the (actual, expected) argument order.
        torch.testing.assert_close(restored(inputs), original(inputs))
def test_lazy_load_subclass(lit_llama):
    """``lazy_load`` materializes plain tensors, ``Parameter``s and tensor subclasses.

    ``torch.testing.assert_close`` accepts directly related types by default
    (``allow_subclasses=True``). On its own it would therefore not notice a
    ``Parameter`` or ``ATensor`` that came back as a plain ``torch.Tensor``.
    The exact type of every loaded value is asserted separately.
    """
    import lit_llama.utils

    with tempfile.TemporaryDirectory() as tmpdirname:
        path = pathlib.Path(tmpdirname)
        fn = str(path / "test.pt")
        # Dropping the first column gives a non-contiguous view
        # (storage offset 1, strides (3, 1)), so the loaded values must get
        # offset and stride right as well as the data.
        t = torch.randn(2, 3)[:, 1:]
        sd = {
            1: t,
            2: torch.nn.Parameter(t),
            3: torch.Tensor._make_subclass(ATensor, t),
        }
        torch.save(sd, fn)
        with lit_llama.utils.lazy_load(fn) as sd_lazy:
            for k, expected in sd.items():
                loaded = sd_lazy[k]._load_tensor()
                # assert_close alone would accept a degraded base-class tensor.
                assert type(loaded) is type(expected), (k, type(loaded), type(expected))
                torch.testing.assert_close(loaded, expected)
def test_incremental_write(tmp_path, lit_llama):
    """Tensors written through ``incremental_save`` read back identically.

    Entries "0" and "2" are handed to ``store_early``, and ``sd`` holds
    whatever it returns in their place. Entry "1" is only written by the
    final ``save``. The resulting file must load with plain ``torch.load``
    and match the original values.
    """
    import lit_llama.utils

    sd = {str(k): torch.randn(5, 10) for k in range(3)}
    # Cloned up front so the reference values are independent of whatever
    # store_early / save do with the original tensors.
    sd_expected = {k: v.clone() for k, v in sd.items()}
    fn = str(tmp_path / "test.pt")
    with lit_llama.utils.incremental_save(fn) as f:
        sd["0"] = f.store_early(sd["0"])
        sd["2"] = f.store_early(sd["2"])
        f.save(sd)

    sd_actual = torch.load(fn)
    assert sd_actual.keys() == sd_expected.keys()
    for k, v_expected in sd_expected.items():
        # assert_close's signature is (actual, expected). The order matters for
        # the failure message and for rtol, which is scaled by `expected`.
        torch.testing.assert_close(sd_actual[k], v_expected)
def test_find_multiple(lit_llama):
    """The cases pin ``find_multiple(n, k)`` rounding ``n`` up to a multiple of ``k``."""
    from lit_llama.utils import find_multiple

    cases = [
        (17, 5, 20),
        (30, 7, 35),
        (10, 2, 10),  # already a multiple: unchanged
        (5, 10, 10),  # n smaller than k
    ]
    for n, k, expected in cases:
        assert find_multiple(n, k) == expected