xref: /aosp_15_r20/external/executorch/test/models/linear_model.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport torch
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker
class LinearModel(torch.nn.Module):
    """Tiny test model computing the element-wise affine map ``a * x + b``.

    ``a`` is a 2x2 float tensor filled with 3.0 and ``b`` a 2x2 float tensor
    filled with 2.0. Both are stored as plain tensor attributes (not
    ``nn.Parameter`` and not registered buffers), so they do not appear in
    ``state_dict()`` and are not moved by ``Module.to(...)``.
    """

    def __init__(self):
        super().__init__()
        # Multiplicative coefficient: every element is 3.0.
        self.a = torch.full((2, 2), 3.0, dtype=torch.float)
        # Additive offset: every element is 2.0.
        self.b = torch.full((2, 2), 2.0, dtype=torch.float)

    def forward(self, x: torch.Tensor):
        """Return ``self.a * x + self.b``.

        ``x`` is expected to be broadcastable against a (2, 2) tensor.
        The explicit ``torch.mul`` / ``torch.add`` calls keep the traced
        graph to exactly one mul followed by one add.
        """
        scaled = torch.mul(self.a, x)
        return torch.add(scaled, self.b)
20