# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


class LinearModel(torch.nn.Module):
    """Elementwise affine map ``a * x + b`` with fixed 2x2 float constants.

    ``a`` is a 2x2 float32 tensor filled with 3 and ``b`` is a 2x2 float32
    tensor filled with 2. The model has no trainable parameters.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): these are plain tensor attributes, not registered
        # buffers or parameters, so they are absent from ``state_dict()`` and
        # are not moved by ``.to(device)``. Kept as-is to preserve behavior.
        self.a = torch.full((2, 2), 3.0, dtype=torch.float)
        self.b = torch.full((2, 2), 2.0, dtype=torch.float)

    def forward(self, x: torch.Tensor):
        """Return ``b + a * x`` computed elementwise.

        Args:
            x: Input tensor. It must be broadcast-compatible with shape
                (2, 2), because ``torch.mul`` and ``torch.add`` broadcast
                their operands.

        Returns:
            The tensor ``torch.add(torch.mul(a, x), b)``.
        """
        return torch.add(torch.mul(self.a, x), self.b)