# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


class LinearModel(torch.nn.Module):
    """Tiny affine model computing ``a * x + b`` with fixed 2x2 constants.

    ``a`` is a 2x2 float32 tensor filled with 3.0 and ``b`` is a 2x2 float32
    tensor filled with 2.0.
    """

    def __init__(self):
        super().__init__()
        # NOTE: these are plain tensor attributes, not parameters or registered
        # buffers, so they are not part of state_dict() and are not moved by
        # .to(device). Kept that way deliberately to preserve existing behavior.
        self.a = torch.full((2, 2), 3.0, dtype=torch.float)
        self.b = torch.full((2, 2), 2.0, dtype=torch.float)

    def forward(self, x: torch.Tensor):
        """Return ``self.a * x + self.b`` (elementwise multiply, then add).

        Presumably ``x`` is expected to be broadcast-compatible with shape
        (2, 2); TODO confirm against callers.
        """
        return torch.add(torch.mul(self.a, x), self.b)