# Copyright (c) Qualcomm Innovation Center, Inc. # All rights reserved # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import torch # module with related operator only class Add(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.add(x, y) class AddConstantFloat(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return 10.0 + x class AddConstantLong(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return 10 + x class Arange(torch.nn.Module): def __init__(self, x): super().__init__() self.x = x def forward(self, y): return torch.arange(self.x, dtype=torch.float32) + y class AvgPoolModule(torch.nn.Module): def __init__(self): super().__init__() self.avgPool = torch.nn.AvgPool2d( kernel_size=(2, 2), padding=(1, 1), stride=(1, 1), count_include_pad=False, ) def forward(self, x): return self.avgPool(x) class BatchNorm(torch.nn.Module): def __init__(self, n_features): super().__init__() self.native_batchnorm = torch.nn.BatchNorm2d(n_features) self.eval() def forward(self, x): return self.native_batchnorm(x) class Bmm(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.matmul(x, y) class Cast(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return x.type(torch.IntTensor) class Cat2(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.cat((x, y), axis=2) class Cat3(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.concat((y, y, x), axis=2) class Cat4(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.cat((y, y, x, x), axis=2) class Ceil(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.ceil(x) class Chunk(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.chunk(x, chunks=2, dim=-1) class ChunkAdd(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): c1, c2 = torch.chunk(x, chunks=2, dim=-1) return torch.add(c1, c2) class Clamp(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.clamp(x, max=0) class CompositeDelegateModule(torch.nn.Module): def __init__( self, compiler_specs, partitioner_type, capture_method, lowered_method, quantize_method=None, ) -> None: super().__init__() self.modules = [ Conv2dSequential(), Conv2dSequential(), Add(), Relu(), ] self.sample_inputs = [ (torch.randn([1, 1, 3, 3]),), (torch.randn([1, 1, 3, 3]),), (torch.randn([1, 2, 3, 3]), torch.randn([1, 2, 3, 3])), (torch.randn([1, 2, 3, 3]),), ] self.lowered_modules = [] for module, sample_input in zip(self.modules, self.sample_inputs): partitioner = partitioner_type(compiler_specs) if quantize_method: module = quantize_method(module, sample_input) edge_prog = capture_method(module, sample_input) edge_prog.exported_program = lowered_method( edge_prog.exported_program, partitioner ) self.lowered_modules.append( edge_prog.exported_program.graph_module._modules.get("lowered_module_0") ) def forward(self, x, y): x1 = self.lowered_modules[0](x) x2 = self.lowered_modules[1](y) x3 = self.lowered_modules[2](x1[0], x2[0]) x4 = self.lowered_modules[3](x3[0]) return x4[0] def get_random_input(self): return (torch.randn([1, 1, 3, 3]), torch.randn([1, 1, 3, 3])) def get_reference_module(self): class CompositeReferenceModule(torch.nn.Module): def __init__(self, modules): super().__init__() self.modules = modules def forward(self, x, y): x1 = self.modules[0](x) x2 = self.modules[1](y) x3 = self.modules[2](x1, x2) x4 = self.modules[3](x3) return x4 return CompositeReferenceModule(self.modules) class ContextBinaryExample(torch.nn.Module): def forward(self, x, y): x = torch.nn.functional.relu(x) y = torch.nn.functional.relu(y) return x, y def example_inputs(self): return { "x": torch.randn((1, 3, 3, 3)), "y": torch.randn((2, 1, 5, 5)), } class Conv1dSequential(torch.nn.Module): def __init__(self, bias=True): super().__init__() self.first = torch.nn.Conv1d( in_channels=1, out_channels=3, kernel_size=(3), padding=1, bias=bias, ) self.second = torch.nn.Conv1d( in_channels=3, out_channels=2, kernel_size=(3), padding=1, bias=bias, ) def forward(self, x): return self.second(self.first(x)) # small models class Conv1dReluLogSoftmax(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv1d( in_channels=2, out_channels=2, kernel_size=1, stride=1, padding=1 ) self.logsoftmax = torch.nn.LogSoftmax(dim=1) def forward(self, x): x = torch.nn.functional.relu(self.conv(x)) x = self.logsoftmax(x) return x class Conv2dAvgPool2d(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d( 3, 16, 7, bias=True, stride=2, padding=3, dilation=1 ) self.pool = torch.nn.AvgPool2d(3, stride=2, padding=1) def forward(self, x): return self.pool(self.conv(x)) class Conv2dBnHardtanhMean(torch.nn.Module): def __init__(self): super(Conv2dBnHardtanhMean, self).__init__() groups = 1 stride = [2, 2] padding = [1, 1] dilation = [1, 1] in_channels = 1 out_channels = 1 self.conv = torch.nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=padding, groups=groups, dilation=dilation, bias=True, ) self.conv.weight = torch.nn.Parameter(torch.randn(self.conv.weight.size())) self.native_batchnorm = torch.nn.BatchNorm2d(out_channels) self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6) self.eval() def forward(self, x): x1 = self.conv(x) x2 = self.native_batchnorm(x1) x3 = self.hardtanh(x2) x4 = torch.mean(x3, (1), keepdim=True) return x4 class Conv2dCat(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = torch.nn.Conv2d(3, 3, 3) self.conv2 = torch.nn.Conv2d(3, 3, 3) def forward(self, x, y): x = self.conv1(x) y = self.conv2(y) z = torch.cat([x, y], dim=1) return z class Conv2dMaxPool2d(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d( in_channels=2, out_channels=2, kernel_size=(1, 1), padding=1, bias=True, ) self.pool = torch.nn.MaxPool2d(1, 1) def forward(self, x): return self.pool(self.conv(x)) class Conv2dSequential(torch.nn.Module): def __init__(self, bias=True): super().__init__() self.first = torch.nn.Conv2d( in_channels=1, out_channels=3, kernel_size=(3, 3), padding=1, bias=bias, ) self.second = torch.nn.Conv2d( in_channels=3, out_channels=2, kernel_size=(3, 3), padding=1, bias=bias, ) def forward(self, x): return self.second(self.first(x)) class Conv2dSingle(torch.nn.Module): def __init__(self, bias=True): super().__init__() self.conv = torch.nn.Conv2d( in_channels=1, out_channels=3, kernel_size=(3, 3), padding=1, bias=bias, ) def forward(self, x): return self.conv(x) class ConvTranspose2dSingle(torch.nn.Module): def __init__(self, bias=True): super().__init__() self.conv_transpose = torch.nn.ConvTranspose2d( in_channels=1, out_channels=3, kernel_size=3, stride=2, padding=1, bias=bias, ) def forward(self, x): return self.conv_transpose(x) class Conv2dDownUpSample(torch.nn.Module): def __init__(self, bias=True): super().__init__() self.conv = torch.nn.Conv2d( in_channels=16, out_channels=16, kernel_size=3, stride=2, padding=1, bias=bias, ) self.conv_transpose = torch.nn.ConvTranspose2d( in_channels=16, out_channels=16, kernel_size=3, stride=2, padding=1, bias=bias, ) def forward(self, x): return self.conv_transpose(self.conv(x)) class Conv2dSumReduceDim(torch.nn.Module): def __init__(self): super().__init__() self.first = torch.nn.Conv2d( in_channels=1, out_channels=3, kernel_size=(3, 3), padding=1, bias=True, ) def forward(self, x): return torch.sum(self.first(x), dim=(2, 3), keepdim=False) class Conv2dTopK(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d(3, 16, 3) def forward(self, x): x = self.conv(x) topk_values, topk_indices = torch.topk(x, 5, dim=1) return topk_values class Div(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.divide(x, y) class DivConstantFloat(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return x / 10.0 class DivConstantLong(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return x / 10 class EinsumBilinear(torch.nn.Module): def __init__(self): super().__init__() def forward(self, bn, anm, bm): return torch.einsum("bn,anm,bm->ba", bn, anm, bm) class EinsumOuterProduct(torch.nn.Module): def __init__(self): super().__init__() def forward(self, i, j): return torch.einsum("i,j->ij", i, j) class EinsumOuterProductRelu(torch.nn.Module): def __init__(self): super().__init__() def forward(self, i, j): return torch.relu(torch.einsum("i,j->ij", i, j)) class Embedding(torch.nn.Module): def __init__(self): super().__init__() self.embedding = torch.nn.Embedding(10, 3) def forward(self, x): return self.embedding(x) class ExpandCopy(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return x.expand(3, 4) class Gelu(torch.nn.Module): def __init__(self): super().__init__() self.gelu = torch.nn.GELU() def forward(self, x): return self.gelu(x) class GroupNorm(torch.nn.Module): def __init__(self, bias=True): super().__init__() self.conv = torch.nn.Conv2d( 32, 256, kernel_size=3, stride=1, padding=1, bias=bias, ) self.norm = torch.nn.GroupNorm(32, 256) def forward(self, x): y = self.conv(x) return y, self.norm(y) class HardSigmoid(torch.nn.Module): def __init__(self): super().__init__() self.hardsigmoid = torch.nn.Hardsigmoid() def forward(self, x): return self.hardsigmoid(x) class HardSwish(torch.nn.Module): def __init__(self): super().__init__() self.hardswish = torch.nn.Hardswish() def forward(self, x): return self.hardswish(x) class HardTanh(torch.nn.Module): def __init__(self): super().__init__() self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6) def forward(self, x): return self.hardtanh(x) class Index(torch.nn.Module): def __init__(self): super().__init__() self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int32) self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32) def forward(self, x): return x[self.idx0] + x[self.idx1] class IndexPut(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer( "k_cache", torch.zeros((1, 1024, 12, 64), dtype=torch.float32), ) def forward(self, input_pos, k_val): k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val) return k_out class LayerNorm(torch.nn.Module): def __init__(self): super().__init__() self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6) self.linear = torch.nn.Linear(768, 196) def forward(self, x): return self.linear(self.layer_norm(x)) class LeakyReLUDefault(torch.nn.Module): def __init__(self): super().__init__() self.leaky_relu = torch.nn.LeakyReLU() def forward(self, x): return self.leaky_relu(x) class LeakyReLUCustom(torch.nn.Module): def __init__(self, coeff): super().__init__() self.leaky_relu = torch.nn.LeakyReLU(coeff) def forward(self, x): return self.leaky_relu(x) class Linear(torch.nn.Module): def __init__(self, use_bias: bool = True): super().__init__() self.linear = torch.nn.Linear(4, 5, use_bias).eval() def forward(self, x): return self.linear(x) class LogSoftmax(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.nn.functional.log_softmax(x, dim=-1) class MaxPool2d(torch.nn.Module): def __init__(self): super().__init__() self.max_pool2d = torch.nn.MaxPool2d( kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True, ) def forward(self, x): return self.max_pool2d(x) class MeanWKeppDim(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.mean(x, (-1, -2), keepdim=True) class MeanWOKeppDim(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.mean(x, (-1, -2)) class Mul(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.mul(x, y) class MulConstantFloat(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return 10.0 * x class MulConstantLong(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return 10 * x class MulScalar(torch.nn.Module): def __init__(self): super().__init__() self._scalar = 3.14 def forward(self, x): out1 = torch.ops.aten.mul.Scalar(x, self._scalar) return out1 class MultiheadAttention(torch.nn.Module): def __init__(self): super().__init__() self.multi_head_attention = torch.nn.MultiheadAttention( 96, 12, dropout=0.0, batch_first=True ) def forward(self, x): attn_output, _ = self.multi_head_attention(x, x, x, need_weights=False) return attn_output class Pad(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.nn.functional.pad( x[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0, mode="constant" ) class PixelShuffle(torch.nn.Module): def __init__(self, scale): super().__init__() self.pixel_shuffle = torch.nn.PixelShuffle(scale) def forward(self, x): return self.pixel_shuffle(x) class PixelUnshuffle(torch.nn.Module): def __init__(self, scale): super().__init__() self.pixel_unshuffle = torch.nn.PixelUnshuffle(scale) def forward(self, x): return self.pixel_unshuffle(x) class PixelUnshuffleMathEquivalent(torch.nn.Module): def __init__(self, scale): super().__init__() self.scale = scale def forward(self, x): b, c, hh, hw = x.size() out_channel = c * (self.scale**2) h = hh // self.scale w = hw // self.scale x_view = x.view(b, c, h, self.scale, w, self.scale) return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) class PowTensorScalar(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.pow(x, 2) class PReLUDefault(torch.nn.Module): def __init__(self): super().__init__() self.prelu = torch.nn.PReLU() def forward(self, x): return self.prelu(x) class PReLUPerChannel(torch.nn.Module): def __init__(self, channels): super().__init__() self.prelu = torch.nn.PReLU(channels) def forward(self, x): return self.prelu(x) class Relu(torch.nn.Module): def __init__(self): super().__init__() self.relu = torch.nn.ReLU() def forward(self, x): return self.relu(x) class Reshape(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return x.reshape(1, 12) class ResidualBlockModule(torch.nn.Module): def __init__(self): super(ResidualBlockModule, self).__init__() groups = 1 stride = [1, 1] padding = [1, 1] dilation = [1, 1] in_channels = 32 out_channels = 32 self.conv = torch.nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=padding, groups=groups, dilation=dilation, bias=True, ) self.native_batchnorm = torch.nn.BatchNorm2d(out_channels) self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6.0) self.eval() def forward(self, x): x1 = self.conv(x) x2 = self.native_batchnorm(x1) x3 = self.conv(x2) x4 = self.native_batchnorm(x3) x5 = self.hardtanh(x4) x6 = torch.add(x5, x2) return x6 class ResizeBilinear2D(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): output_shape = [dim * 2 for dim in x.shape[-2:]] return torch.nn.functional.interpolate( x, size=list(torch.randn(output_shape).shape), mode="bilinear", align_corners=False, ) class ResizeNearest2D(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): output_shape = [dim * 2 for dim in x.shape[-2:]] return torch.nn.functional.interpolate( x, size=list(torch.randn(output_shape).shape), mode="nearest", ) class RmsNorm(torch.nn.Module): def __init__(self): super().__init__() self.eps = 1e-5 self.rms = torch.nn.RMSNorm([4], 1e-5) def forward(self, x): return self.rms(x) class Rsqrt(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.rsqrt(x) class ScaledDotProductAttention(torch.nn.Module): def __init__(self): super().__init__() def forward(self, query_layer, key_layer, value_layer, attn_mask): attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, key_layer, value_layer, attn_mask ) return attn_output class SelectCopy(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d( in_channels=3, out_channels=2, kernel_size=(3, 3), padding=1, bias=True, ) def forward(self, x): return self.conv(x)[0, 1, 1:2] class Sigmoid(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.sigmoid(x) class SimpleModel(torch.nn.Module): def __init__(self): super().__init__() kernel_sz = 32 self.conv1 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=True) self.conv2 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=True) self.conv3 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=False) self.conv4 = torch.nn.Conv2d(kernel_sz, kernel_sz, 3, padding=1, bias=False) self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6) self.relu = torch.nn.ReLU() self.batch_norm = torch.nn.BatchNorm2d(kernel_sz) self.add = torch.add self.mean = torch.mean self.reshape = torch.reshape self.linear = torch.nn.Linear(4, 10) self.permute = torch.permute self.eval() def forward(self, x, y): x1 = self.conv1(x) x2 = self.batch_norm(x1) x3 = self.relu(x2) x4 = self.conv2(x3) x5 = self.relu(x4) y1 = self.conv3(y) y2 = self.batch_norm(y1) y3 = self.relu(y2) y4 = self.conv4(y3) y5 = self.relu(y4) z = self.add(x5, y5) z1 = self.permute(z, (0, 3, 2, 1)) z2 = torch.mean(z1, [1, 2], True) z3 = self.reshape(z2, (8, -1)) z4 = self.linear(z3) z5 = self.hardtanh(z4) return z5 class SliceCopy(torch.nn.Module): def __init__(self): super().__init__() self.position_ids = torch.randn([1, 512]) def forward(self, x, y): seq_length = y.size()[1] return x[:, :seq_length] + self.position_ids[:, :seq_length] class SliceCopyWithStep(torch.nn.Module): def __init__(self): super().__init__() self.position_ids = torch.randn([1, 512]) self.step = 2 def forward(self, x, y): seq_length = y.size()[1] return ( x[:, : seq_length : self.step] + self.position_ids[:, : seq_length : self.step] ) class Softmax(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.nn.functional.softmax(x, dim=-1) class Sqrt(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.sqrt(x) class SqrtConstant(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return x / torch.sqrt(torch.tensor([64.0])) class Squeeze(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return x.squeeze() class Stack(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.stack((x, y)) class Sub(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.sub(x, y) class SubConstantFloat(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return 10.0 - x class SubConstantLong(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return 10 - x class SumIntList(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.sum(x, dim=(2, 3), keepdim=True) class Tanh(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.tanh(x) class TopKandIndex(torch.nn.Module): def __init__(self): super().__init__() self.idx_source = torch.rand(10, 3) def forward(self, x): a, b = torch.topk(x, 3) return a + self.idx_source[b] class Unbind(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return torch.unbind(x) class Unsqueeze(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x): return x.unsqueeze(0) class View(torch.nn.Module): def __init__(self): super().__init__() self.first_size = 2 self.second_size = 256 def forward(self, x, y): new_shape = x.size()[:-1] + (self.first_size, self.second_size) return x.view(new_shape) class ViewPermuteMatMul(torch.nn.Module): def __init__(self): super().__init__() self.first_size = 2 self.second_size = 256 def forward(self, x, y): new_shape = x.size()[:-1] + (self.first_size, self.second_size) x = x.view(new_shape) x = x.permute(0, 2, 1, 3) return torch.matmul(x, y.transpose(-1, -2))