# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from math import prod
from typing import Optional, Tuple

import torch
from executorch.exir.scalar_type import ScalarType
from torch.library import Library, register_fake

from .utils import get_conv1d_output_size, get_conv2d_output_size

lib = Library("cadence", "DEF")

lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)
lib.define(
    "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)
lib.define(
    "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
)
lib.define(
    "quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
)
lib.define(
    "quantized_layer_norm.per_tensor(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
)
lib.define(
    "quantized_layer_norm.per_tensor_out(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
)

lib.define(
    "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)
lib.define(
    "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, "
    "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor"
)

lib.define(
    "quantized_relu(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Y)"
)
lib.define(
    "quantized_relu.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor (a!)"
)
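# Note on the convention used throughout this file: each functional op has a
# companion ".out" overload whose trailing "Tensor(a!) out" argument is a
# caller-provided buffer that the kernel mutates and aliases in its return
# value. A purely illustrative sketch of invoking one of these ops once the
# schemas are registered (the literal values below are made up):
#
#   x = torch.randn(4, 8)
#   q = torch.ops.cadence.quantize_per_tensor(
#       x, 0.1, 0, -128, 127, torch.int8
#   )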
lib.define(
    "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
)
lib.define(
    "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)"
)
lib.define(
    "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
)
lib.define(
    "quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define(
    "convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
    "int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)"
)
lib.define(
    "transposed_convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
    "int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False) -> (Tensor Y)"
)
lib.define("dequantize(Tensor X, Tensor X_scale, Tensor X_zero_point) -> (Tensor Y)")

# cadence::quantized_relu is defined in OSS
lib.define(
    "quantized_add(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
    "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
)
lib.define(
    "quantized_mul(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
    "Tensor Y_zero_point, float out_scale, int out_zero_point) -> (Tensor Z)"
)
lib.define(
    "quantized_add_Scalar(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, "
    "float out_scale, int out_zero_point) -> (Tensor Z)"
)
lib.define(
    "quantized_mul_Scalar(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, "
    "float out_scale, int out_zero_point) -> (Tensor Z)"
)
lib.define(
    "quantized_embedding_byte(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, "
    "Tensor indices, bool pruned_weights=False) -> (Tensor X)"
)

# cadence::quantized_layer_norm is defined in OSS
# cadence::quantized_conv is defined in OSS
lib.define(
    "quantized_transposed_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, "
    "int[] dilation, SymInt[] output_padding, int groups, int input_zero_point, Tensor weight_zero_point, "
    "Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor out)"
)
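# Several of the quantized schemas above carry an (out_multiplier, out_shift)
# pair instead of a float output scale: the int32 accumulator is requantized
# with a fixed-point multiply followed by a shift. A minimal sketch of that
# arithmetic, assuming gemmlowp-style conventions (illustrative values; this
# is not the DSP kernels' exact code):
#
#   acc = 12345                   # int32 accumulator
#   out_multiplier = 1431655765   # ~= 2/3 in Q31 fixed point
#   out_shift = -2                # negative means right-shift
#   scaled = (acc * out_multiplier) >> 31
#   result = scaled >> -out_shift if out_shift < 0 else scaled << out_shift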
"avg_pool2d(Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, " "bool count_include_pad=True, int? divisor_override=None, Tensor? in_zero_point=None, bool channel_last=False) -> (Tensor out)" ) lib.define( "im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, " "Tensor in_zero_point, bool channel_last=False) -> (Tensor out)" ) lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)") lib.define( "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, " "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)" ) lib.define( "requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, " "Tensor out_zero_point, ScalarType out_dtype) -> (Tensor Y)" ) lib.define( "fully_connected(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor out)" ) lib.define( "quantized_fully_connected(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)" ) # ------------------------------------ # # Migrated from custom_ops.ymal # # ------------------------------------ # # Migrated from the custom_ops.yaml files containing different operator variants (e.g., .out, .tensor_out) lib.define( "convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, " "int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "transposed_convolution.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, " "int[] dilation, SymInt[] output_padding, int groups, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)" ) # cadence::quantized_relu.out is defined in OSS lib.define( "quantized_relu.per_tensor(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift) -> Tensor" ) lib.define( "quantized_relu.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, " "int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "quantized_add.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, " "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, " "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "quantized_add_Scalar.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, " "float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "quantized_mul_Scalar.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Scalar Y, " "float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)") lib.define( "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" ) lib.define( "quantized_embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, " "Tensor indices, bool pruned_weights=False, *, Tensor(a!) 
lib.define(
    "quantized_transposed_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, "
    "SymInt[] padding, int[] dilation, SymInt[] output_padding, int groups, int input_zero_point, "
    "Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, "
    "Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "avg_pool2d.out(Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, "
    "bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, "
    "Tensor? in_zero_point=None, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
    "Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "transposed_im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, "
    "int[2] stride, int[2] output_padding, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "requantize.out(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, "
    "Tensor out_zero_point, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)"
)

# Custom ops in the aten namespace. The lib var must use the FRAGMENT type,
# as the aten library is already defined.
aten_lib = Library("aten", "FRAGMENT")
aten_lib.define(
    "chunk.out(Tensor self, int chunks, int dim=0, *, Tensor(a!)[] out) -> ()"
)
aten_lib.define(
    "contiguous.out(Tensor self, *, MemoryFormat memory_format=contiguous_format, "
    "Tensor(a!) out) -> Tensor(a!)"
)
aten_lib.define(
    "tensor_split.sections_out(Tensor self, int sections, int dim=0, *, Tensor(a!)[] out) -> ()"
)
aten_lib.define(
    "_slice_copy_nop(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, "
    "SymInt step=1) -> Tensor(a!)"
)
aten_lib.define(
    "_select_copy_nop.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)"
)
aten_lib.define(
    "_slice_copy_nop.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, "
    "SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)"
)
aten_lib.define("_cat_nop(Tensor[] tensors, int dim=0) -> Tensor(a!)")
aten_lib.define(
    "_cat_nop.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)"
)

# Custom ops in the jarvis_nn_ops namespace
jarvis_nn_lib = Library("jarvis_nn_ops", "DEF")
jarvis_nn_lib.define(
    "attention_mask.out(Tensor input, Tensor start, Tensor stop, *, Tensor(a!) out) -> Tensor(a!)"
)
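# The registrations below supply fake (meta) kernels for the ops declared
# above: they compute only output shapes and dtypes, which is all that
# tracing and export need. A minimal sketch of how they are exercised
# (illustrative; assumes this module has been imported so the registrations
# are active):
#
#   from torch._subclasses.fake_tensor import FakeTensorMode
#   with FakeTensorMode():
#       x = torch.empty(2, 16)
#       y = torch.ops.cadence.quantize_per_tensor(
#           x, 0.05, 0, -128, 127, torch.int8
#       )
#   # y is a FakeTensor with shape (2, 16) and dtype torch.int8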
m = Library("cadence", "IMPL", "Meta")


@register_fake("cadence::quantize_per_tensor")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=dtype)


@register_fake("cadence::dequantize_per_tensor")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=torch.float)


@register_fake("cadence::quantized_linear")
def quantized_linear_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
    out_zero_point: int,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    # src has shape [leading_dims, in_dim];
    # weight has shape [out_dim, in_dim];
    # the output is an empty tensor of shape [leading_dims, out_dim]
    out_size = list(src.size())
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]
    return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_linear.per_tensor")
def quantized_linear_per_tensor_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: torch.SymInt,
    weight_zero_point: torch.SymInt,
    out_multiplier: torch.SymInt,
    out_shift: torch.SymInt,
    out_zero_point: torch.SymInt,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    # src has shape [leading_dims, in_dim];
    # weight has shape [out_dim, in_dim];
    # the output is an empty tensor of shape [leading_dims, out_dim]
    out_size = list(src.size())
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]
    return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_conv")
def quantized_conv_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    if channel_last:
        out_channels, *kernel_size, _ = weight.shape
    else:
        out_channels, _, *kernel_size = weight.shape

    in_size = input.shape
    # Assert that the input tensor has at least 3 dimensions, and at most 5
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Compute the output tensor size
    output_size = (
        get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            channel_last,
        )
        if len(in_size) == 3
        else get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
        )
    )

    return input.new_empty(output_size, dtype=input.dtype)
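# Worked example for the shape logic above, assuming get_conv2d_output_size
# implements the standard convolution arithmetic (illustrative numbers): an
# NCHW input of shape (1, 3, 32, 32) with weight (8, 3, 5, 5), stride (1, 1),
# padding (2, 2) and dilation (1, 1) gives a spatial extent of
# (32 + 2 * 2 - (1 * (5 - 1) + 1)) // 1 + 1 = 32, so the fake kernel returns
# an empty tensor of shape (1, 8, 32, 32).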
@register_fake("cadence::quantized_conv.per_tensor")
def quantized_conv_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
    channel_last: bool = False,
) -> torch.Tensor:
    if channel_last:
        out_channels, *kernel_size, _ = weight.shape
    else:
        out_channels, _, *kernel_size = weight.shape

    in_size = input.shape
    # Assert that the input tensor has at least 3 dimensions, and at most 5
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Compute the output tensor size
    output_size = (
        get_conv1d_output_size(
            in_size,
            out_channels,
            stride[1],
            padding[1],
            dilation[1],
            kernel_size[0],
            channel_last,
        )
        if len(in_size) == 3
        else get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
        )
    )

    return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_layer_norm")
def quantized_layer_norm_meta(
    input: torch.Tensor,
    X_scale: torch.Tensor,
    X_zero_point: torch.Tensor,
    normalized_shape: Tuple[int],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    output_scale: float,
    output_zero_point: int,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=input.dtype)


@register_fake("cadence::quantized_layer_norm.per_tensor")
def quantized_layer_norm_per_tensor_meta(
    input: torch.Tensor,
    X_scale: float,
    X_zero_point: int,
    normalized_shape: Tuple[int],
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    output_scale: float,
    output_zero_point: int,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=input.dtype)


@register_fake("cadence::quantized_relu")
def quantized_relu_meta(
    X: torch.Tensor,
    X_zero_point: torch.Tensor,
    out_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    return X.new_empty(X.size(), dtype=X.dtype)


@register_fake("cadence::quantized_matmul")
def quantized_matmul_meta(
    X: torch.Tensor,
    X_zero_point: int,
    Y: torch.Tensor,
    Y_zero_point: int,
    bias: Optional[torch.Tensor],
    out_multiplier: int,
    out_shift: int,
    out_zero_point: int,
    transposed: bool = False,
) -> torch.Tensor:
    X_size = list(X.size())
    Y_size = list(Y.size())

    # Get the batch dimensions for both tensors
    X_batch_dims = X_size[:-2]
    Y_batch_dims = Y_size[:-2]

    # If they don't match, check that they're compatible
    if X_batch_dims != Y_batch_dims:
        assert prod(X_batch_dims) == prod(
            Y_batch_dims
        ), f"Batch dimensions of X and Y do not match: {X_batch_dims} vs {Y_batch_dims}"

    # Get the matmul output size
    if transposed:
        assert X_size[-1] == Y_size[-1], "matrices cannot be multiplied"
        mat_size = [X_size[-2], Y_size[-2]]
    else:
        assert X_size[-1] == Y_size[-2], "matrices cannot be multiplied"
        mat_size = [X_size[-2], Y_size[-1]]

    # Combine the larger batch dimensions with the matmul output size
    out_size = (
        X_batch_dims + mat_size
        if len(X_batch_dims) > len(Y_batch_dims)
        else Y_batch_dims + mat_size
    )

    return X.new_empty(out_size, dtype=X.dtype)


@register_fake("cadence::im2row")
def im2row_meta(
    input: torch.Tensor,
    kernel_size: Tuple[int],
    dilation: Tuple[int],
    padding: Tuple[int],
    stride: Tuple[int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    if len(input.shape) == 3:
        height_dim = 1 if channel_last else 2
        input = input.unsqueeze(height_dim)

    batch_size = input.shape[0]
    n_input_plane = input.shape[3] if channel_last else input.shape[1]
    input_height = input.shape[1] if channel_last else input.shape[2]
    input_width = input.shape[2] if channel_last else input.shape[3]
    output_height = (
        input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)
    ) // stride[0] + 1
    output_width = (
        input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)
    ) // stride[1] + 1
    n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
    output_size = torch.Size(
        (batch_size, output_height * output_width, n_output_plane)
    )
    return input.new_empty(output_size, dtype=input.dtype)
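# Worked example for im2row_meta (illustrative numbers): an NCHW input of
# shape (1, 2, 5, 5) with kernel_size (3, 3), dilation (1, 1), padding (0, 0)
# and stride (1, 1) gives output_height = output_width =
# (5 + 0 - (1 * (3 - 1) + 1)) // 1 + 1 = 3, so the returned tensor has shape
# (1, 3 * 3, 2 * 3 * 3) = (1, 9, 18): one row per output position, one
# column per unrolled patch element.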
# Define the abstract implementations of the operators as required
@register_fake("cadence::linalg_vector_norm")
def linalg_vector_norm_meta(
    X: torch.Tensor,
) -> torch.Tensor:
    # The output of the norm is a scalar, so we return a [] tensor
    return X.new_empty([], dtype=X.dtype)


@register_fake("cadence::requantize")
def requantize_meta(
    input: torch.Tensor,
    in_scale: torch.Tensor,
    in_zero_point: torch.Tensor,
    out_scale: torch.Tensor,
    out_zero_point: torch.Tensor,
    dtype: ScalarType,
) -> torch.Tensor:
    return input.new_empty(
        input.size(),
        # pyre-ignore[6]: Incompatible type
        dtype=dtype,
    )


@register_fake("cadence::quantized_relu.per_tensor")
def quantized_relu_per_tensor_meta(
    input: torch.Tensor,
    in_zero_point: int,
    out_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    return input.new_empty(input.size(), dtype=torch.uint8)


@register_fake("cadence::fully_connected")
def fully_connected_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    # src has shape [leading_dims, in_dim];
    # weight has shape [out_dim, in_dim];
    # the output is an empty tensor of shape [leading_dims, out_dim]
    out_size = list(src.size())
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]
    return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_fully_connected")
def quantized_fully_connected_meta(
    src: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    out_multiplier: int,
    out_shift: int,
    out_zero_point: int,
    offset: Optional[torch.Tensor],
) -> torch.Tensor:
    # src has shape [leading_dims, in_dim];
    # weight has shape [out_dim, in_dim];
    # the output is an empty tensor of shape [leading_dims, out_dim]
    out_size = list(src.size())
    weight_size = list(weight.size())
    assert len(weight_size) == 2
    out_size[-1] = weight_size[0]
    return src.new_empty(out_size, dtype=torch.uint8)


@register_fake("cadence::convolution")
def convolution_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    channel_last: bool = False,
) -> torch.Tensor:
    if channel_last:
        out_channels, *kernel_size, _ = weight.shape
    else:
        out_channels, _, *kernel_size = weight.shape

    in_size = input.shape
    # Assert that the input tensor has at least 3 dimensions, and at most 5
    assert len(in_size) > 2
    assert len(in_size) < 6

    # Compute the output tensor size
    output_size = (
        get_conv1d_output_size(
            in_size,
            out_channels,
            stride[0],
            padding[0],
            dilation[0],
            kernel_size[0],
            channel_last,
        )
        if len(in_size) == 3
        else get_conv2d_output_size(
            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
        )
    )

    return input.new_empty(output_size, dtype=input.dtype)
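# Worked example for the [leading_dims, out_dim] rule used by the linear and
# fully_connected fake kernels above (illustrative numbers): src of shape
# (2, 7, 16) and weight of shape (32, 16) produce an output of shape
# (2, 7, 32); only the last dimension changes, and every leading dimension
# is carried through unchanged.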
@register_fake("cadence::transposed_convolution")
def transposed_convolution_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    output_padding: Tuple[int],
    groups: int,
    channel_last: bool = False,
) -> torch.Tensor:
    # The native definition of torch transposed conv has a weight shape of
    # (in_channels, out_channels/groups, *kernel_size).
    # However, the two channel positions are flipped in the Jarvis pass that
    # replaces it with cadence::transposed_convolution here:
    # https://fburl.com/code/d2s7pkyy
    out_channels, _input_channels, *kernel_size = weight.shape
    out_channels *= groups
    in_size = input.shape

    # Get the output size of a transposed 1D convolution given the input size and parameters
    def get_conv_transpose1d_output_size(
        in_size: torch.Size,
        kernel_size: list[int],
        out_channels: int,
        stride: Tuple[int],
        padding: Tuple[int],
        dilation: Tuple[int],
        output_padding: Tuple[int],
        channel_last: bool = False,
    ) -> torch.Size:
        assert len(in_size) == 3
        if channel_last:
            N, L, C = in_size
        else:
            N, C, L = in_size

        # Reference: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
        lout = (
            (L - 1) * stride[0]
            - 2 * padding[0]
            + dilation[0] * (kernel_size[0] - 1)
            + output_padding[0]
            + 1
        )

        if channel_last:
            return torch.Size((in_size[0], lout, out_channels))
        else:
            return torch.Size((in_size[0], out_channels, lout))

    def get_conv_transpose2d_output_size(
        in_size: torch.Size,
        kernel_size: list[int],
        out_channels: int,
        stride: Tuple[int],
        padding: Tuple[int],
        dilation: Tuple[int],
        output_padding: Tuple[int],
        channel_last: bool = False,
    ) -> torch.Size:
        assert len(in_size) == 4
        if channel_last:
            N, H, W, C = in_size
        else:
            N, C, H, W = in_size

        # Reference: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
        hout = (
            (H - 1) * stride[0]
            - 2 * padding[0]
            + dilation[0] * (kernel_size[0] - 1)
            + output_padding[0]
            + 1
        )
        wout = (
            (W - 1) * stride[1]
            - 2 * padding[1]
            + dilation[1] * (kernel_size[1] - 1)
            + output_padding[1]
            + 1
        )

        if channel_last:
            return torch.Size((in_size[0], hout, wout, out_channels))
        else:
            return torch.Size((in_size[0], out_channels, hout, wout))

    # Compute the output tensor size
    if len(in_size) == 3:
        output_size = get_conv_transpose1d_output_size(
            in_size,
            kernel_size,
            out_channels,
            stride,
            padding,
            dilation,
            output_padding,
            channel_last,
        )
    elif len(in_size) == 4:
        output_size = get_conv_transpose2d_output_size(
            in_size,
            kernel_size,
            out_channels,
            stride,
            padding,
            dilation,
            output_padding,
            channel_last,
        )
    else:
        raise NotImplementedError(
            f"transposed_convolution meta is not implemented for input tensor with {len(in_size)} dimensions"
        )

    return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::avg_pool2d")
def avg_pool2d_meta(
    input: torch.Tensor,
    kernel_size: Tuple[int],
    stride: Tuple[int],
    padding: Tuple[int],
    ceil_mode: bool,
    count_include_pad: Optional[bool] = True,
    divisor_override: Optional[int] = None,
    in_zero_point: Optional[int] = None,
    channel_last: bool = False,
) -> torch.Tensor:
    # Use torch native meta kernels when operator semantics are similar
    return torch._meta_registrations.meta_avg_pool2d(
        input,
        kernel_size,
        stride,
        padding,
        ceil_mode,
        count_include_pad,
        divisor_override,
    )
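# Worked example of the transposed-conv output arithmetic shared by
# transposed_convolution above and transposed_im2row below (illustrative
# numbers): H = 5, stride 2, padding 1, dilation 1, kernel 3 and
# output_padding 1 give (5 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 10.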
@register_fake("cadence::transposed_im2row")
def transposed_im2row_meta(
    input: torch.Tensor,
    kernel_size: Tuple[int],
    dilation: Tuple[int],
    padding: Tuple[int],
    stride: Tuple[int],
    output_padding: Tuple[int],
    in_zero_point: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    if len(input.shape) == 3:
        height_dim = 1 if channel_last else 2
        input = input.unsqueeze(height_dim)

    batch_size = input.shape[0]
    n_input_plane = input.shape[3] if channel_last else input.shape[1]
    input_height = input.shape[1] if channel_last else input.shape[2]
    input_width = input.shape[2] if channel_last else input.shape[3]
    output_height = (
        (input_height - 1) * stride[0]
        - 2 * padding[0]
        + dilation[0] * (kernel_size[0] - 1)
        + output_padding[0]
        + 1
    )
    output_width = (
        (input_width - 1) * stride[1]
        - 2 * padding[1]
        + dilation[1] * (kernel_size[1] - 1)
        + output_padding[1]
        + 1
    )
    n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
    output_length = output_height * output_width
    output_size = torch.Size((batch_size, output_length, n_output_plane))

    return input.new_empty(output_size, dtype=input.dtype)
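# Worked example for transposed_im2row_meta (illustrative numbers): an NCHW
# input of shape (1, 2, 4, 4) with kernel_size (3, 3), dilation (1, 1),
# padding (1, 1), stride (2, 2) and output_padding (0, 0) gives
# output_height = output_width = (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 0 + 1 = 7,
# so the returned tensor has shape (1, 7 * 7, 2 * 3 * 3) = (1, 49, 18).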