a-r-r-o-w 
posted an update 24 days ago
Did you know how simple it is to get started with your own custom compiler backend for torch.compile? What's stopping you from writing your own compiler?

import torch
from torch._functorch.partitioners import draw_graph

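def compiler(fx_module: torch.fx.GraphModule, _):
    # torch.compile passes the captured FX graph and the example inputs (unused here).
    # Dump the graph to a DOT file, then return the unmodified forward as the "compiled" callable.
    draw_graph(fx_module, "compile.dot")
    return fx_module.forward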

def capture(model, *inputs):
    compiled_model = torch.compile(model, backend=compiler)
    y = compiled_model(*inputs)
    y.sum().backward()

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.linear_1 = torch.nn.Linear(16, 32)
        self.linear_2 = torch.nn.Linear(32, 16)
    
    def forward(self, x):
        x = self.linear_1(x)
        x = torch.nn.functional.silu(x)
        x = self.linear_2(x)
        return x

if __name__ == '__main__':
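    # Fall back to CPU when Apple's MPS backend is not available.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model = MLP()
    model.to(device)
    x = torch.randn(4, 16, device=device, dtype=torch.float32)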

    capture(model, x)
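
If you would rather not write a DOT file, the same idea can be sketched with a backend that just records the nodes of the captured graph. This is a minimal variant under a few assumptions: it runs on CPU, and the names `inspecting_backend` and `captured_ops` are made up for illustration. The exact nodes you see can differ between PyTorch versions.

```python
import torch

# (op, target) pairs recorded by the hypothetical backend below.
captured_ops = []

def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Record every node of the captured FX graph instead of drawing it.
    for node in gm.graph.nodes:
        captured_ops.append((node.op, str(node.target)))
    # Returning the unmodified forward runs the graph as-is, with no optimization.
    return gm.forward

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = torch.nn.Linear(16, 32)
        self.linear_2 = torch.nn.Linear(32, 16)

    def forward(self, x):
        x = self.linear_1(x)
        x = torch.nn.functional.silu(x)
        x = self.linear_2(x)
        return x

model = MLP()
compiled_model = torch.compile(model, backend=inspecting_backend)

x = torch.randn(4, 16)
y = compiled_model(x)  # the backend is invoked lazily, on the first call

for op, target in captured_ops:
    print(op, target)
```

From here, a "real" backend would rewrite or lower `gm.graph` before returning a callable; this sketch only looks at it.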


--------------

Part of https://huggingface.co/posts/a-r-r-o-w/231008365980283

I don't know what the hell I'm doing in this community 😂

but I'm constantly learning about the platform.

Can you tell me what kind of code this even is, brother?