Post
1266
Did you know how simple it was to get started with your own custom compiler backend with `torch.compile`? What's stopping you from writing your own compiler?
--------------
Part of https://huggingface.co/posts/a-r-r-o-w/231008365980283
import torch
from torch._functorch.partitioners import draw_graph
def compiler(fx_module: torch.fx.GraphModule, _, filename: str = "compile.dot"):
    """Minimal ``torch.compile`` backend: dump the captured FX graph, then run it unchanged.

    A ``torch.compile`` backend is called as ``backend(graph_module, example_inputs)``
    and must return a callable. This one performs no optimisation; it only
    writes a visualisation of the graph it was given.

    Args:
        fx_module: The FX graph captured for the compiled region.
        _: Example inputs passed by the compiler front end (unused here).
        filename: Path handed to ``draw_graph``. Defaults to ``"compile.dot"``
            (the previous hard-coded value). Every invocation writes to this
            same path, so if the backend is called more than once the last
            graph wins unless a different name is supplied.

    Returns:
        ``fx_module.forward``, so the captured graph executes as-is.
    """
    # NOTE(review): draw_graph appears to need pydot/Graphviz at runtime — confirm
    # they are installed before relying on this backend.
    # Plain string: the original f-string had no placeholders.
    draw_graph(fx_module, filename)
    return fx_module.forward
def capture(model, *inputs):
    """Compile ``model`` with the graph-dumping ``compiler`` backend and run one forward/backward pass.

    Calling the compiled model triggers graph capture, which hands the graph
    to ``compiler``. The subsequent ``backward()`` call then runs autograd
    through the compiled forward.

    Args:
        model: The module (or callable) to compile.
        *inputs: Positional arguments forwarded to the compiled model.

    Returns:
        The forward output ``y``, so callers can inspect it. The gradients are
        left accumulated on any parameters of ``model`` that require grad.
        (Previously the output was discarded and ``None`` was returned.)
    """
    compiled_model = torch.compile(model, backend=compiler)
    y = compiled_model(*inputs)
    # Reduce to a scalar so backward() needs no explicit gradient argument.
    y.sum().backward()
    return y
class MLP(torch.nn.Module):
    """Small two-layer perceptron: Linear(16 -> 32), SiLU, Linear(32 -> 16)."""

    def __init__(self):
        super().__init__()
        # The attribute names define the state_dict keys, so they stay fixed.
        # Creation order is unchanged, so random initialisation is identical.
        self.linear_1 = torch.nn.Linear(16, 32)
        self.linear_2 = torch.nn.Linear(32, 16)

    def forward(self, x):
        """Project ``x`` up to 32 features, apply SiLU, and project back to 16."""
        activated = torch.nn.functional.silu(self.linear_1(x))
        return self.linear_2(activated)
if __name__ == '__main__':
    # Use Apple's MPS backend when present, otherwise CUDA, otherwise CPU.
    # Hard-coding "mps" made this example fail on every non-Apple machine.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = MLP()
    model.to(device)
    x = torch.randn(4, 16, device=device, dtype=torch.float32)
    capture(model, x)
--------------
Part of https://huggingface.co/posts/a-r-r-o-w/231008365980283