# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.library

namespace = "et_vk"
lib = torch.library.Library(namespace, "DEF")

#############
## prepack ##
#############


def prepack_impl(x: torch.Tensor):
    return x


name = "prepack"
lib.define(f"{name}(Tensor x) -> Tensor")
lib.impl(name, prepack_impl, "CompositeExplicitAutograd")
prepack_op = getattr(getattr(torch.ops, namespace), name)

#####################
## conv_with_clamp ##
#####################


def conv_with_clamp_impl(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    transposed=False,
    output_padding=0,
    groups=1,
    output_min=-float("inf"),
    output_max=float("inf"),
):
    return torch.clamp(
        torch.convolution(
            input,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ),
        output_min,
        output_max,
    )


name = "conv_with_clamp"
lib.define(
    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max) -> Tensor"
)
lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd")
conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name)

#########################
## conv_with_clamp.out ##
#########################


def conv_with_clamp_out_impl(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    transposed=False,
    output_padding=0,
    groups=1,
    output_min=-float("inf"),
    output_max=float("inf"),
    out=None,
):
    out = conv_with_clamp_impl(
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        output_min,
        output_max,
    )
    return out


name = "conv_with_clamp.out"
lib.define(
    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.impl(name, conv_with_clamp_out_impl, "CompositeExplicitAutograd")

#################
## grid_priors ##
#################


# x must have at least 2 dimensions; the grid is built from its last two (height, width)
def grid_priors_impl(
    x,
    stride,
    offset,
):
    height, width = x.shape[-2:]
    # Specify the device of torch.arange explicitly to avoid an ExecuTorch export error
    shift_x = (torch.arange(0, width, device=x.device) + offset) * stride
    shift_y = (torch.arange(0, height, device=x.device) + offset) * stride
    # Specify the indexing parameter explicitly ('ij' is the default value) to avoid an ExecuTorch export error
    shift_xx, shift_yy = torch.meshgrid([shift_y, shift_x], indexing="ij")
    shift_xx = shift_xx.reshape(-1)
    shift_yy = shift_yy.reshape(-1)
    shifts = torch.stack((shift_yy, shift_xx), dim=-1)
    return shifts


name = "grid_priors"
lib.define(f"{name}(Tensor self, int stride, float offset) -> Tensor")
lib.impl(name, grid_priors_impl, "CompositeExplicitAutograd")
grid_priors_op = getattr(getattr(torch.ops, namespace), name)
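
# Illustrative usage sketch (comments only; nothing here executes). Once this
# module is imported, the ops registered above are callable through
# torch.ops.et_vk. The tensor shapes and clamp bounds below are assumptions
# chosen purely for demonstration.
#
#   x = torch.randn(1, 8, 16, 16)
#   w = torch.randn(4, 8, 3, 3)
#   # 3x3 convolution followed by a clamp to [0.0, 6.0]:
#   y = torch.ops.et_vk.conv_with_clamp(
#       x, w, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, 0.0, 6.0
#   )
#   # One (x, y) shift per spatial location of the 16x16 feature map:
#   priors = torch.ops.et_vk.grid_priors(x, 8, 0.5)  # shape: (16 * 16, 2)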


# When lowering to ExecuTorch, ops are converted from the default variant to the out variant. Hence, custom ops define both variants.
def grid_priors_out_impl(
    x,
    stride,
    offset,
    out,
):
    out = grid_priors_impl(x, stride, offset)
    return out


name = "grid_priors_out"
lib.define(
    f"{name}(Tensor self, int stride, float offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.impl(name, grid_priors_out_impl, "CompositeExplicitAutograd")

########################
## linear_weight_int4 ##
########################


def linear_weight_int4_impl(
    x: torch.Tensor,
    weights_4x8: torch.Tensor,
    groupsize: int,
    scales_and_zeros: torch.Tensor,
    inner_k_tiles: int,
):
    original_x_size = x.size()
    out_features = weights_4x8.size(0)
    x = x.reshape(-1, original_x_size[-1])
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
        weights_4x8, inner_k_tiles
    )
    out = torch.ops.aten._weight_int4pack_mm(
        x, weight_int4pack, groupsize, scales_and_zeros
    )
    out_shape = original_x_size[:-1] + (out_features,)
    return out.reshape(out_shape)


name = "linear_weight_int4"
lib.define(
    f"{name}(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros, int inner_k_tiles) -> Tensor"
)
lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)

######################
## apply_rotary_emb ##
######################


# This implementation is copied from executorch.examples.models.llama.rope rather
# than imported, to avoid introducing a dependency on the llama code.
def apply_rotary_emb_impl(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
):
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        ndim = x.ndim
        freqs_cis_ndim = freqs_cis.ndim
        if freqs_cis_ndim == 3:
            # freqs_cis: (seq_len, n_heads, head_dim // 2)
            assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
            shape = [
                d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            # freqs_cis: (seq_len, head_dim // 2)
            assert freqs_cis.shape == (x.shape[1], x.shape[-1])
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(shape)

    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)


name = "apply_rotary_emb"
lib.define(
    f"{name}(Tensor xq, Tensor xk, Tensor freqs_cos, Tensor freqs_sin) -> (Tensor, Tensor)"
)
lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
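
##################################
## example usage (illustrative) ##
##################################


# Minimal smoke-test sketch. It only runs when this file is executed directly
# (never on import), and every shape and frequency value below is an arbitrary
# assumption chosen to exercise the registered ops, not a required configuration.
if __name__ == "__main__":
    bsz, seq_len, n_heads, head_dim = 1, 4, 2, 8
    xq = torch.randn(bsz, seq_len, n_heads, head_dim)
    xk = torch.randn(bsz, seq_len, n_heads, head_dim)
    # freqs_cos / freqs_sin are expected with shape (seq_len, head_dim // 2).
    angles = torch.outer(torch.arange(seq_len).float(), torch.ones(head_dim // 2))
    xq_out, xk_out = torch.ops.et_vk.apply_rotary_emb(
        xq, xk, torch.cos(angles), torch.sin(angles)
    )
    assert xq_out.shape == xq.shape and xk_out.shape == xk.shape

    # grid_priors: one (x, y) shift per spatial location of a 4x4 feature map.
    priors = torch.ops.et_vk.grid_priors(torch.randn(1, 8, 4, 4), 8, 0.5)
    assert priors.shape == (16, 2)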