xref: /aosp_15_r20/external/pytorch/test/linear.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport torch
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker
class LinearMod(torch.nn.Linear):
    """A ``torch.nn.Linear`` whose forward calls the low-level linear binding.

    Construction is inherited unchanged from ``torch.nn.Linear`` (same
    ``in_features``, ``out_features``, ``bias``, ... arguments), so no
    ``__init__`` override is needed.

    The only difference from the parent class is that ``forward`` invokes
    ``torch._C._nn.linear`` directly. NOTE(review): presumably this is so that
    tracing the module records that binding. Confirm the intent before
    switching to the public ``torch.nn.functional.linear``.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the affine map ``input @ weight.T + bias``.

        Args:
            input: tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor whose last dimension is ``out_features``. ``bias`` may be
            ``None`` (when constructed with ``bias=False``), in which case no
            bias is added.
        """
        # `input` mirrors the parameter name of torch.nn.Linear.forward so that
        # keyword callers keep working, even though it shadows the builtin.
        return torch._C._nn.linear(input, self.weight, self.bias)
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
# Trace a 20 -> 20 LinearMod on a random (20, 20) example input and print the
# resulting TorchScript graph. The module is built before the example input is
# drawn so that random-number consumption happens in the same order as a single
# nested call (weight init first, then torch.rand).
_traced_module = LinearMod(20, 20)
_example_input = torch.rand([20, 20])
print(torch.jit.trace(_traced_module, _example_input).graph)
13