# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch.library

namespace = "et_vk"
lib = torch.library.Library(namespace, "DEF")

#############
## prepack ##
#############


def prepack_impl(x: torch.Tensor):
    return x


name = "prepack"
lib.define(f"{name}(Tensor x) -> Tensor")
lib.impl(name, prepack_impl, "CompositeExplicitAutograd")
prepack_op = getattr(getattr(torch.ops, namespace), name)
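

# Illustrative usage sketch (not part of the upstream file): in eager mode the
# prepack op is an identity, so it simply returns its input. The tensor shape
# below is an arbitrary assumption chosen for demonstration.
def _example_prepack_usage():
    x = torch.randn(2, 3)
    y = prepack_op(x)
    assert torch.equal(x, y)  # the eager impl passes the tensor through unchanged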

#####################
## conv_with_clamp ##
#####################


def conv_with_clamp_impl(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    transposed=False,
    output_padding=0,
    groups=1,
    output_min=-float("inf"),
    output_max=float("inf"),
):
    return torch.clamp(
        torch.convolution(
            input,
            weight,
            bias,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        ),
        output_min,
        output_max,
    )


name = "conv_with_clamp"
lib.define(
    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max) -> Tensor"
)
lib.impl(name, conv_with_clamp_impl, "CompositeExplicitAutograd")
conv_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
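

# Illustrative usage sketch (not part of the upstream file): conv_with_clamp is
# a convolution followed by a clamp, so in eager mode it matches conv2d + clamp.
# All shapes and the clamp bounds below are arbitrary assumptions.
def _example_conv_with_clamp_usage():
    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)
    b = torch.randn(4)
    y = conv_with_clamp_op(
        x, w, b, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, 0.0, 6.0
    )
    ref = torch.clamp(
        torch.nn.functional.conv2d(x, w, b, stride=1, padding=1), 0.0, 6.0
    )
    assert torch.allclose(y, ref)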

#########################
## conv_with_clamp.out ##
#########################


def conv_with_clamp_out_impl(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0,
    dilation=1,
    transposed=False,
    output_padding=0,
    groups=1,
    output_min=-float("inf"),
    output_max=float("inf"),
    out=None,
):
    # This eager stub recomputes the result and returns it; it does not write
    # into the provided `out` buffer. The out variant is defined so that the
    # executorch export flow, which rewrites ops to their out variants, can
    # find one for this custom op.
    out = conv_with_clamp_impl(
        input,
        weight,
        bias,
        stride,
        padding,
        dilation,
        transposed,
        output_padding,
        groups,
        output_min,
        output_max,
    )
    return out


name = "conv_with_clamp.out"
lib.define(
    f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.impl(name, conv_with_clamp_out_impl, "CompositeExplicitAutograd")

#################
## grid_priors ##
#################


# x must have at least 2 dimensions; the last two are interpreted as (height, width).
def grid_priors_impl(
    x,
    stride,
    offset,
):
    height, width = x.shape[-2:]
    # Specify the device of torch.arange explicitly to avoid an executorch export error.
    shift_x = (torch.arange(0, width, device=x.device) + offset) * stride
    shift_y = (torch.arange(0, height, device=x.device) + offset) * stride
    # Specify the indexing parameter ('ij' is the default value) to avoid an executorch export error.
    shift_xx, shift_yy = torch.meshgrid([shift_y, shift_x], indexing="ij")
    shift_xx = shift_xx.reshape(-1)
    shift_yy = shift_yy.reshape(-1)
    shifts = torch.stack((shift_yy, shift_xx), dim=-1)
    return shifts


name = "grid_priors"
lib.define(f"{name}(Tensor self, int stride, float offset) -> Tensor")
lib.impl(name, grid_priors_impl, "CompositeExplicitAutograd")
grid_priors_op = getattr(getattr(torch.ops, namespace), name)
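

# Illustrative usage sketch (not part of the upstream file): grid_priors emits
# one (x, y) coordinate per spatial location of a feature map, scanned row by
# row. The feature map shape, stride, and offset below are arbitrary assumptions.
def _example_grid_priors_usage():
    feat = torch.randn(1, 1, 2, 3)  # (N, C, H, W); only H and W are used
    priors = grid_priors_op(feat, 4, 0.5)
    # priors == [[2., 2.], [6., 2.], [10., 2.], [2., 6.], [6., 6.], [10., 6.]]
    assert priors.shape == (2 * 3, 2)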


# When lowering to executorch, ops are converted from the default variant to
# the out variant, so custom ops define both variants.
def grid_priors_out_impl(
    x,
    stride,
    offset,
    out,
):
    # Note: this eager stub returns a freshly computed tensor and does not
    # write into the provided `out` buffer.
    out = grid_priors_impl(x, stride, offset)
    return out


name = "grid_priors_out"
lib.define(
    f"{name}(Tensor self, int stride, float offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.impl(name, grid_priors_out_impl, "CompositeExplicitAutograd")

########################
## linear_weight_int4 ##
########################


def linear_weight_int4_impl(
    x: torch.Tensor,
    weights_4x8: torch.Tensor,
    groupsize: int,
    scales_and_zeros: torch.Tensor,
    inner_k_tiles: int,
):
    original_x_size = x.size()
    out_features = weights_4x8.size(0)
    x = x.reshape(-1, original_x_size[-1])
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
        weights_4x8, inner_k_tiles
    )
    out = torch.ops.aten._weight_int4pack_mm(
        x, weight_int4pack, groupsize, scales_and_zeros
    )
    out_shape = original_x_size[:-1] + (out_features,)
    return out.reshape(out_shape)


name = "linear_weight_int4"
lib.define(
    f"{name}(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros, int inner_k_tiles) -> Tensor"
)
lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd")
linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name)
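
# Call pattern (illustrative sketch, not executed here): given activations `x`
# with inner dimension K, 4-bit packed weights for N output features, a group
# size that divides K, and matching scales/zero-points, the op computes a
# groupwise int4 weight-only linear layer:
#
#   out = torch.ops.et_vk.linear_weight_int4(
#       x, weights_4x8, groupsize, scales_and_zeros, inner_k_tiles
#   )
#
# The expected layouts of `weights_4x8` and `scales_and_zeros` follow the
# conventions of aten._convert_weight_to_int4pack and aten._weight_int4pack_mm,
# which the eager implementation above delegates to.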

######################
## apply_rotary_emb ##
######################


# This implementation is copied from executorch.examples.models.llama.rope so
# that this file does not introduce a dependency on the llama example code.
def apply_rotary_emb_impl(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
):
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        ndim = x.ndim
        freqs_cis_ndim = freqs_cis.ndim
        if freqs_cis_ndim == 3:
            # freqs_cis: (seq_len, n_heads, head_dim // 2)
            assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
            shape = [
                d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
                for i, d in enumerate(x.shape)
            ]
        else:
            # freqs_cis: (seq_len, head_dim // 2)
            assert freqs_cis.shape == (x.shape[1], x.shape[-1])
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(shape)

    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)


name = "apply_rotary_emb"
lib.define(
    f"{name}(Tensor xq, Tensor xk, Tensor freqs_cos, Tensor freqs_sin) -> (Tensor, Tensor)"
)
lib.impl(name, apply_rotary_emb_impl, "CompositeExplicitAutograd")
apply_rotary_emb_op = getattr(getattr(torch.ops, namespace), name)
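

# Illustrative usage sketch (not part of the upstream file): applies rotary
# position embeddings to query/key tensors laid out as (batch, seq_len,
# n_heads, head_dim). All sizes and the frequency base below are arbitrary
# assumptions; freqs_cos and freqs_sin must have shape (seq_len, head_dim // 2).
def _example_apply_rotary_emb_usage():
    bsz, seq_len, n_heads, head_dim = 1, 4, 2, 8
    xq = torch.randn(bsz, seq_len, n_heads, head_dim)
    xk = torch.randn(bsz, seq_len, n_heads, head_dim)
    theta = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), theta)
    xq_out, xk_out = apply_rotary_emb_op(xq, xk, torch.cos(angles), torch.sin(angles))
    assert xq_out.shape == xq.shape and xk_out.shape == xk.shape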