OzzyGT/mellon_examples
Updated
•
123
The first integration was really good because we didn't need to remember or look up what the numbers were, but to avoid complications with the UIs it was reverted to index numbers. You can find them here, but I'll also post them below:
0 -- openpose
1 -- depth
2 -- hed/pidi/scribble/ted
3 -- canny/lineart/anime_lineart/mlsd
4 -- normal
5 -- segment
6 -- tile
7 -- repaint
torch.compile? What's stopping you from writing your own compiler?
import torch
from torch._functorch.partitioners import draw_graph
def compiler(fx_module: torch.fx.GraphModule, _, *, fname: str = "compile.dot"):
    """Debug backend for ``torch.compile`` that dumps the captured FX graph.

    Dynamo invokes a backend as ``backend(graph_module, example_inputs)``. This
    one renders ``fx_module`` with ``draw_graph`` and then hands back the
    module's own ``forward``, so the graph executes unchanged (no optimization).

    Args:
        fx_module: the FX graph captured by Dynamo.
        _: example inputs supplied by Dynamo; unused here.
        fname: output path for the rendered graph (default "compile.dot").
            Keyword-only, so the backend call signature is unchanged; bind a
            different path with ``functools.partial(compiler, fname=...)``.
            Every call writes to this same path, so if several graphs are
            produced only the last one survives.

    Returns:
        ``fx_module.forward``, used by ``torch.compile`` as the compiled callable.
    """
    # Plain string: the original f-string had no placeholders (ruff F541).
    # NOTE: draw_graph relies on the optional ``pydot`` package.
    draw_graph(fx_module, fname)
    return fx_module.forward
def capture(model, *inputs):
    """Compile ``model`` with the ``compiler`` backend and run forward + backward once.

    ``torch.compile`` compiles lazily, so the first call below is what invokes
    ``compiler``. The output is then reduced to a scalar by summing it and
    back-propagated. Nothing is returned.
    """
    run = torch.compile(model, backend=compiler)
    loss = run(*inputs).sum()
    loss.backward()
class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> SiLU -> Linear.

    The default sizes (16 -> 32 -> 16) match the original example, so ``MLP()``
    builds the same layers under the same ``linear_1`` / ``linear_2`` names.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: width of the hidden layer.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features: int = 16, hidden_features: int = 32, out_features: int = 16):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply ``linear_1``, SiLU, then ``linear_2`` to ``x``."""
        x = self.linear_1(x)
        x = torch.nn.functional.silu(x)
        x = self.linear_2(x)
        return x
if __name__ == '__main__':
    # The original hard-coded "mps", which fails on machines without Apple's
    # Metal backend. Keep MPS as the first choice, then fall back to CUDA or CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model = MLP()
    model.to(device)
    # Input shape (4, 16): a batch of 4 rows matching MLP's default 16 input features.
    x = torch.randn(4, 16, device=device, dtype=torch.float32)
    capture(model, x)