# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from dataclasses import dataclass
from typing import List, Optional, Union

from executorch.backends.vulkan.test.op_tests.utils.aten_types import (
    AT_INT_ARRAY_REF,
    AT_SCALAR,
    AT_TENSOR,
    AT_TENSOR_LIST,
    BOOL,
    DOUBLE,
    INT,
    OPT_AT_DOUBLE_ARRAY_REF,
    OPT_AT_INT_ARRAY_REF,
    OPT_AT_TENSOR,
    OPT_BOOL,
    OPT_DEVICE,
    OPT_INT64,
    OPT_LAYOUT,
    OPT_MEMORY_FORMAT,
    OPT_SCALAR_TYPE,
    STRING,
    TENSOR_VECTOR,
    THREE_TENSOR_TUPLE,
    TWO_TENSOR_TUPLE,
)
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite

from torchgen.api import cpp
from torchgen.api.types import CppSignatureGroup
from torchgen.gen import generate_static_dispatch_backend_call, translate_args
from torchgen.gen_aoti_c_shim import gen_static_dispatch_backend_call_signature
from torchgen.model import NativeFunction, Variant

###################################
## Compute Graph Code Generation ##
###################################


@dataclass
class ATenArg:
    name: str
    cpp_type: str
    default: Optional[str]


@dataclass
class ValueRef:
    name: str
    src_cpp_name: str
    src_cpp_type: str
    is_in: bool = False
    is_out: bool = False
    requires_prepack: bool = False
    supports_prepack: bool = False
    # When is_dynamic_size is true, the underlying object's size is not known
    # during code-gen. An example is the out value for aten.split, where the out
    # value is a vector<Tensor>. In these cases, we need an additional vector or
    # at::TensorList to track these values.
    is_dynamic_size: bool = False

    @property
    def io_value_list_name(self):
        assert self.is_dynamic_size
        return f"{self.name}_io_value_list"

    @property
    def value_list_name(self):
        assert self.is_dynamic_size
        return f"{self.name}_value_list"

    @property
    def vk_out(self):
        assert self.is_out
        return f"vk_{self.name}"


ValueRefList = Union[ValueRef, List[ValueRef]]

InableCppType = frozenset([AT_TENSOR, AT_TENSOR_LIST])


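# Rough usage sketch (illustrative only; `f`, `suite_def`, and the registry name
# below are assumptions rather than values taken from this file).
# ComputeGraphGen.backend_key is presumably set by the calling code before
# generation, since it is forwarded to torchgen's static dispatch helper:
#
#   gen = ComputeGraphGen("aten.add.Tensor", f, suite_def)
#   check_fn_cpp = gen.gen_op_check_fn()       # C++ source for a GTest check fn
#   build_fn_cpp = gen.gen_build_graph_fn()    # graph construction only
#   bench_fn_cpp = gen.gen_op_exec_graph_fn()  # execution/benchmark body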
class ComputeGraphGen:
    backend_key = None

    def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
        self.op_reg_name = op_reg_name
        self.f = f
        self.suite_def = suite_def

        self.f_sig = CppSignatureGroup.from_native_function(
            self.f, method=False, fallback_binding=self.f.manual_cpp_binding
        ).most_faithful_signature()

        self.graph = "graph"
        self.dot = "->"

        self.args = []
        self.refs = {}

        self.should_prepack = False

        for binding in self.f_sig.arguments():
            arg = binding.argument
            ctype = cpp.argumenttype_type(
                arg.type, mutable=arg.is_write, binds=arg.name
            )
            cpp_type = ctype.cpp_type(strip_ref=True)

            self.args.append(
                ATenArg(name=arg.name, cpp_type=cpp_type, default=arg.default)
            )

            # These arguments will be passed in as "weight" tensors; the
            # corresponding object in the compute graph will be a TensorRef.
            requires_prepack = (
                "weight" in arg.name
                or "bias" in arg.name
                or "running_mean" in arg.name
                or "running_var" in arg.name
            )
            supports_prepack = False
            if arg.name in self.suite_def.prepacked_args:
                supports_prepack = True

            self.refs[arg.name] = ValueRef(
                name=f"{arg.name}_ref",
                src_cpp_name=arg.name,
                src_cpp_type=cpp_type,
                is_in=(cpp_type in InableCppType),
                requires_prepack=requires_prepack,
                supports_prepack=supports_prepack,
            )

        ret_type = cpp.returns_type(self.f.func.returns, symint=False).cpp_type()
        self.out = ATenArg(name="out", cpp_type=ret_type, default=None)
        if ret_type == AT_TENSOR:
            self.refs["out"] = ValueRef(
                name="out_ref", src_cpp_name="out", src_cpp_type=ret_type, is_out=True
            )
        elif ret_type == TWO_TENSOR_TUPLE:
            self.refs["out"] = [
                ValueRef(
                    name="out_ref_first",
                    src_cpp_name="std::get<0>(out)",
                    src_cpp_type="at::Tensor",
                    is_out=True,
                ),
                ValueRef(
                    name="out_ref_second",
                    src_cpp_name="std::get<1>(out)",
                    src_cpp_type="at::Tensor",
                    is_out=True,
                ),
                ValueRef(
                    name="out_ref",
                    src_cpp_name="out",
                    src_cpp_type=ret_type,
                    is_out=False,
                ),
            ]
        elif ret_type == THREE_TENSOR_TUPLE:
            self.refs["out"] = [
                ValueRef(
                    name="out_ref_first",
                    src_cpp_name="std::get<0>(out)",
                    src_cpp_type="at::Tensor",
                    is_out=True,
                ),
                ValueRef(
                    name="out_ref_second",
                    src_cpp_name="std::get<1>(out)",
                    src_cpp_type="at::Tensor",
                    is_out=True,
                ),
                ValueRef(
                    name="out_ref_third",
                    src_cpp_name="std::get<2>(out)",
                    src_cpp_type="at::Tensor",
                    is_out=True,
                ),
                ValueRef(
                    name="out_ref",
                    src_cpp_name="out",
                    src_cpp_type=ret_type,
                    is_out=False,
                ),
            ]
        elif ret_type == TENSOR_VECTOR:
            self.refs["out"] = ValueRef(
                name="out_ref",
                src_cpp_name="out",
                src_cpp_type=ret_type,
                is_out=True,
                is_dynamic_size=True,
            )
        else:
            raise NotImplementedError(
                f"ret_type: {ret_type} not supported for out value"
            )

    ## ATen code generation

    def gen_decl(self, fn_name: str, ret_type: str = "void") -> str:
        cpp_args = [a.decl() for a in self.f_sig.arguments()]
        cpp_args_str = ", ".join(cpp_args)
        return f"{ret_type} {fn_name}({cpp_args_str})"

    def create_aten_fn_call(self) -> str:
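        # generate_static_dispatch_backend_call() yields a statement of the form
        # "return at::cpu::<op>(...);". Strip the leading "return " (7 characters)
        # and the "::cpu" namespace so the call can be assigned to a local `out`.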
        func_call = generate_static_dispatch_backend_call(
            self.f_sig, self.f, ComputeGraphGen.backend_key
        )[7:].replace("::cpu", "")

        return func_call

    def create_aten_method_call(self) -> str:
        # For functions with only a Method variant, we fall back to the function
        # declared in MethodOperators.h. The method is declared as
        # at::_ops::{name}::call(*), and ATEN_FN is a handy macro.
        cpp_sig = gen_static_dispatch_backend_call_signature(self.f_sig, self.f)
        exprs = translate_args(self.f_sig, cpp_sig)
        func_call = f"ATEN_FN({self.f_sig.name()})({exprs});"
        return func_call

    def create_out_src(self, include_declarations: bool = True) -> str:
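        # Emits the reference ATen computation whose result the Vulkan graph output
        # is later compared against. Illustrative shape of the emitted line (the
        # actual call depends on the operator and its variant):
        #
        #   at::Tensor out = at::add(self, other, alpha);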
        cpp_type = self.out.cpp_type if include_declarations else ""
        if Variant.function in self.f.variants:
            return f"{cpp_type} out = " + self.create_aten_fn_call() + "\n"
        else:
            return f"{cpp_type} out = " + self.create_aten_method_call() + "\n"

    ## Graph code generation utils

    def prepack_ref(self, ref: ValueRef) -> bool:
        if ref.requires_prepack:
            return True
        else:
            return ref.supports_prepack and self.should_prepack

    def create_value_decl_for(self, ref: ValueRefList) -> str:  # noqa: C901
        if isinstance(ref, list):
            ret_str = ""
            for r in ref:
                ret_str += self.create_value_decl_for(r)
            return ret_str

        cpp_type = "IOValueRef" if (ref.is_in or ref.requires_prepack) else "ValueRef"
        if ref.src_cpp_type == AT_TENSOR_LIST:
            ret_str = f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n"
            ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n"
            return ret_str
        elif ref.src_cpp_type == TENSOR_VECTOR:
            ret_str = f"std::vector<IOValueRef> {ref.io_value_list_name};\n"
            ret_str += f"std::vector<ValueRef> {ref.value_list_name};\n"
            return ret_str
        else:
            return f"{cpp_type} {ref.name};\n"

    def create_value_for(  # noqa: C901
        self, ref: ValueRefList, include_declarations: bool = True
    ) -> str:
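        # Returns C++ code that adds the graph Value corresponding to a single
        # argument (or to the out value). For example, a non-prepacked at::Tensor
        # input named "self" would produce roughly:
        #
        #   IOValueRef self_ref = graph->add_input_tensor(self.sizes().vec(),
        #       from_at_scalartype(self.scalar_type()));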
        if isinstance(ref, list):
            ret_str = ""
            for r in ref:
                ret_str += self.create_value_for(r)
            return ret_str

        prepack = self.prepack_ref(ref)
        ref_is_view = self.suite_def.is_view_op and ref.is_out

        cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef"
        if not include_declarations:
            cpp_type = ""

        if ref.src_cpp_type == OPT_AT_TENSOR:
            ret_str = f"{cpp_type} {ref.name} = "
            if prepack:
                ret_str = ""
                if include_declarations:
                    ret_str += f"IOValueRef {ref.name};\n"
                ret_str += f"{ref.name}.value = "
            ret_str += f"!{ref.src_cpp_name}.has_value() ? "
            ret_str += f"{self.graph}{self.dot}add_none() : "
            if not prepack:
                ret_str += f"{self.graph}{self.dot}"
                ret_str += "add_input_tensor(" if ref.is_in else "add_tensor("
                ret_str += f"{ref.src_cpp_name}->sizes().vec(), "
                ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type())); \n"
            else:
                ret_str += f"{self.graph}{self.dot}"
                ret_str += f"add_tensorref({ref.src_cpp_name}->sizes().vec(), "
                ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type()), "
                ret_str += f"{ref.src_cpp_name}->const_data_ptr()); \n"
            return ret_str
        elif ref.src_cpp_type == OPT_INT64:
            ret_str = f"{cpp_type} {ref.name} = "
            ret_str += f"!{ref.src_cpp_name}.has_value() ? "
            ret_str += f"{self.graph}{self.dot}add_none() : "
            ret_str += f"{self.graph}{self.dot}add_scalar<int64_t>"
            ret_str += f"({ref.src_cpp_name}.value());\n"
            return ret_str
        elif (
            ref.src_cpp_type == OPT_AT_DOUBLE_ARRAY_REF
            or ref.src_cpp_type == OPT_AT_INT_ARRAY_REF
        ):
            ret_str = f"{cpp_type} {ref.name} = "
            ret_str += f"!{ref.src_cpp_name}.has_value() ? "
            ret_str += f"{self.graph}{self.dot}add_none() : "
            ret_str += f"{self.graph}{self.dot}add_scalar_list"
            ret_str += f"({ref.src_cpp_name}->vec());\n"
            return ret_str
        elif ref.src_cpp_type == AT_TENSOR_LIST:
            assert ref.is_in, "AT_TENSOR_LIST must be an input"
            # This logic is a bit convoluted. We need to create an IOValueRef for
            # each tensor to facilitate staging. On the other hand, we use the
            # .value of each tensor to create a ValueList, which will be passed
            # to the corresponding ops.
            ret_str = ""
            if include_declarations:
                ret_str += f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n"
                ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n"
            ret_str += f"for (int i=0; i < {ref.src_cpp_name}.size(); i++) {{\n"
            ret_str += (
                f"  IOValueRef io_value_ref = {self.graph}{self.dot}add_input_tensor(\n"
            )
            ret_str += f"      {ref.src_cpp_name}[i].sizes().vec(),\n"
            ret_str += (
                f"      from_at_scalartype({ref.src_cpp_name}[i].scalar_type())); \n"
            )
            ret_str += f"  {ref.name}_value_refs.emplace_back(io_value_ref.value);\n"
            ret_str += f"  {ref.name}_io_value_refs.emplace_back(io_value_ref);\n"
            ret_str += "}\n"
            ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n"
            return ret_str
        elif ref.src_cpp_type == TENSOR_VECTOR:
            ret_str = ""
            if include_declarations:
                ret_str += f"std::vector<IOValueRef> {ref.io_value_list_name};\n"
                ret_str += f"std::vector<ValueRef> {ref.value_list_name};\n"
            ret_str += f"""
for (int i=0; i<out.size(); i++) {{
    const at::Tensor& cur = out[i];
    IOValueRef io_value_ref;
    io_value_ref.value = {self.graph}{self.dot}add_tensor(
        cur.sizes().vec(), from_at_scalartype(cur.scalar_type()));
    {ref.io_value_list_name}.emplace_back(io_value_ref);
    {ref.value_list_name}.emplace_back(io_value_ref.value);
}}
ValueRef out_ref = {self.graph}{self.dot}add_value_list(std::move({ref.value_list_name}));
"""
            return ret_str

        ret_str = f"{cpp_type} {ref.name} = {self.graph}{self.dot}"
        if prepack:
            ret_str = ""
            if include_declarations:
                ret_str = f"IOValueRef {ref.name};\n"
            ret_str += f"{ref.name}.value = {self.graph}{self.dot}"

        if ref.src_cpp_type == AT_TENSOR and ref_is_view:
            input_name = None
            # Use a distinct loop variable so the `ref` parameter is not shadowed.
            for _name, in_ref in self.refs.items():
                if in_ref.is_in and in_ref.src_cpp_type == AT_TENSOR:
                    input_name = in_ref.name

            assert input_name is not None
            ret_str += f"add_tensor_view({input_name}.value);"
        elif ref.src_cpp_type == AT_TENSOR and not prepack:
            ret_str += "add_input_tensor(" if ref.is_in else "add_tensor("
            ret_str += f"{ref.src_cpp_name}.sizes().vec(), "
            ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())); \n"
        elif ref.src_cpp_type == AT_TENSOR and prepack:
            ret_str += f"add_tensorref({ref.src_cpp_name}.sizes().vec(), "
            ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type()), "
            ret_str += f"{ref.src_cpp_name}.const_data_ptr()); \n"
        elif ref.src_cpp_type == AT_SCALAR:
            # TODO(ssjia): generalize this to work with all scalar types
            ret_str += f"add_scalar<double>({ref.src_cpp_name}.toDouble()); \n"
        elif ref.src_cpp_type == AT_INT_ARRAY_REF:
            ret_str += f"add_scalar_list({ref.src_cpp_name}.vec()); \n"
        elif ref.src_cpp_type == BOOL:
            ret_str += f"add_scalar<bool>({ref.src_cpp_name}); \n"
        elif ref.src_cpp_type == INT:
            ret_str += f"add_scalar<int64_t>({ref.src_cpp_name}); \n"
        elif ref.src_cpp_type == DOUBLE:
            ret_str += f"add_scalar<double>({ref.src_cpp_name}); \n"
        elif (
            ref.src_cpp_type == OPT_SCALAR_TYPE
            or ref.src_cpp_type == OPT_LAYOUT
            or ref.src_cpp_type == OPT_DEVICE
            or ref.src_cpp_type == OPT_BOOL
            or ref.src_cpp_type == OPT_MEMORY_FORMAT
        ):
            ret_str += "add_none(); \n"
        elif ref.src_cpp_type == STRING:
            ret_str += f"add_string(std::string({ref.src_cpp_name})); \n"
        elif ref.src_cpp_type == TWO_TENSOR_TUPLE:
            ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n"
        elif ref.src_cpp_type == THREE_TENSOR_TUPLE:
            ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second, {ref.name}_third}}); \n"
        else:
            raise RuntimeError(f"Unsupported cpp type {ref.src_cpp_type}")

        return ret_str

    def create_op_call(self) -> str:
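        # Emits the call that adds the operator's nodes to the ComputeGraph.
        # Illustrative output for a hypothetical binary op (names are assumptions):
        #
        #   VK_GET_OP_FN("aten.add.Tensor")(*graph,
        #       {self_ref.value, other_ref.value, alpha_ref, out_ref});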
        deref = "*" if self.dot == "->" else ""
        op_create_code = f'VK_GET_OP_FN("{self.op_reg_name}")({deref}{self.graph}, {{'

        for aten_arg in self.args:
            ref = self.refs[aten_arg.name]
            if ref.src_cpp_type == AT_TENSOR_LIST:
                # Special case. Underlying tensors are input tensors, but the
                # container itself is just a normal value.
                op_create_code += f"{ref.name}, "
            else:
                op_create_code += (
                    f"{ref.name}.value, "
                    if ref.is_in or ref.requires_prepack or ref.is_out
                    else f"{ref.name}, "
                )
                # op_create_code += f"{ref.name}, "

        op_create_code += "out_ref});\n"
        return op_create_code

    def gen_output_staging_valueref_decl(self, ref: ValueRefList) -> str:
        if isinstance(ref, list):
            ret_str = ""
            for r in ref[:-1]:
                ret_str += self.gen_output_staging_valueref_decl(r)
            return ret_str
        elif ref.src_cpp_type == TENSOR_VECTOR:
            assert ref.is_out
            ret_str = ""
            return ret_str

        assert ref.src_cpp_type == AT_TENSOR and ref.is_out
        return f"ValueRef {ref.name}_staging;\n"

    def set_output(self, ref: ValueRefList, include_declarations: bool = True) -> str:
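        # Registers the graph value(s) backing `out` as graph outputs and records
        # the returned staging reference, e.g. for a single-tensor output:
        #
        #   ValueRef out_ref_staging = graph->set_output_tensor(out_ref);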
        if isinstance(ref, list):
            ret_str = ""
            for r in ref[:-1]:
                ret_str += self.set_output(r, include_declarations)
            return ret_str
        elif ref.src_cpp_type == TENSOR_VECTOR:
            assert ref.is_out
            ret_str = f"""
for (int i=0; i<out.size(); i++) {{
    {ref.io_value_list_name}[i].staging = {self.graph}{self.dot}set_output_tensor(
        {ref.io_value_list_name}[i].value);
}}
"""
            return ret_str

        assert ref.src_cpp_type == AT_TENSOR and ref.is_out
        cpptype = "ValueRef" if include_declarations else ""
        ret_str = f"{cpptype} {ref.name}_staging = {self.graph}{self.dot}"
        ret_str += f"set_output_tensor({ref.name});\n"
        return ret_str

    def virtual_resize(self, ref: ValueRefList) -> str:
        assert isinstance(ref, ValueRef)
        assert ref.src_cpp_type in InableCppType and ref.is_in
        if self.prepack_ref(ref):
            return ""

        if ref.src_cpp_type == AT_TENSOR:
            ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)"
            ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n"
        elif ref.src_cpp_type == AT_TENSOR_LIST:
            ret_str = ""
            ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n"
            ret_str += (
                f"  {self.graph}{self.dot}get_tensor({ref.name}_io_value_refs[i].value)"
            )
            ret_str += f"->virtual_resize({ref.src_cpp_name}[i].sizes().vec());\n"
            ret_str += "}\n"
        else:
            raise AssertionError(f"{ref.src_cpp_type} not expected")

        return ret_str

    def copy_into_staging(self, ref: ValueRefList) -> str:
        assert isinstance(ref, ValueRef)
        assert ref.src_cpp_type in InableCppType and ref.is_in

        if self.prepack_ref(ref):
            return ""

        if ref.src_cpp_type == AT_TENSOR:
            ret_str = f"{self.graph}{self.dot}copy_into_staging("
            ret_str += f"{ref.name}.staging, "
            ret_str += f"{ref.src_cpp_name}.const_data_ptr(), "
            ret_str += f"{ref.src_cpp_name}.numel());\n"
        elif ref.src_cpp_type == AT_TENSOR_LIST:
            ret_str = ""
            ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n"
            ret_str += f"  {self.graph}{self.dot}copy_into_staging("
            ret_str += f"{ref.name}_io_value_refs[i].staging, "
            ret_str += f"{ref.src_cpp_name}[i].const_data_ptr(), "
            ret_str += f"{ref.src_cpp_name}[i].numel());\n"
            ret_str += "}\n"
        else:
            raise AssertionError(f"{ref.src_cpp_type} not expected")
        return ret_str

    def declare_vk_out_for(self, ref: Union[ValueRef, List[ValueRef]]) -> str:
        if isinstance(ref, list):
            ret_str = ""
            for r in ref[:-1]:
                ret_str += self.declare_vk_out_for(r)
            return ret_str
        elif ref.src_cpp_type == TENSOR_VECTOR:
            assert ref.is_out
            ret_str = f"""
std::vector<at::Tensor> {ref.vk_out};
for (int i=0; i<out.size(); i++) {{
    {ref.vk_out}.emplace_back(at::empty_like(out[i]).contiguous());
}}
"""
            return ret_str

        assert ref.src_cpp_type == AT_TENSOR and ref.is_out
        ret_str = f"at::Tensor vk_{ref.name} = at::empty_like({ref.src_cpp_name})"
        ret_str += ".contiguous();\n"
        return ret_str

    def copy_from_staging(self, ref: ValueRefList) -> str:
        if isinstance(ref, list):
            ret_str = ""
            for r in ref[:-1]:
                ret_str += self.copy_from_staging(r)
            return ret_str
        elif ref.src_cpp_type == TENSOR_VECTOR:
            assert ref.is_out
            ret_str = f"""
for (int i=0; i<out.size(); i++) {{
    {self.graph}{self.dot}copy_from_staging(
        {ref.io_value_list_name}[i].staging,
        {ref.vk_out}[i].mutable_data_ptr(),
        {ref.vk_out}[i].numel());
}}
"""
            return ret_str

        assert ref.src_cpp_type == AT_TENSOR and ref.is_out
        ret_str = f"{self.graph}{self.dot}copy_from_staging({ref.name}_staging, "
        ret_str += f"vk_{ref.name}.mutable_data_ptr(), vk_{ref.name}.numel());\n"

        return ret_str

    ## Misc. code generation utilities

    def check_graph_out(self, ref: ValueRefList) -> str:
        if isinstance(ref, list):
            ret_str = ""
            for r in ref[:-1]:
                ret_str += self.check_graph_out(r)
            return ret_str
        elif ref.src_cpp_type == TENSOR_VECTOR:
            assert ref.is_out
            ret_str = f"""
for (int i=0; i<out.size(); i++) {{
    EXPECT_TRUE(check_close(out[i], {ref.vk_out}[i], rtol, atol));
}}
"""
            return ret_str

        return (
            f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}, rtol, atol));"
        )

    ## Top level code generation

    def gen_arg_valueref_decls(self) -> str:
        ret_str = ""
        for aten_arg in self.args:
            ref = self.refs[aten_arg.name]
            ret_str += self.create_value_decl_for(ref)

        ret_str += self.create_value_decl_for(self.refs["out"])
        ret_str += f"{self.out.cpp_type} out;\n"
        ret_str += self.gen_output_staging_valueref_decl(self.refs["out"])
        return ret_str

    def gen_graph_build_code(self, include_declarations: bool = True) -> str:
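        # Assembles the graph-construction section of the generated C++: the ATen
        # reference call, a create_value_for() block per argument and for the
        # output, the op call, output registration, and finally the
        # prepare()/encode_prepack()/prepack()/encode_execute() calls on the graph.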
        graph_build = self.create_out_src(include_declarations)
        for aten_arg in self.args:
            graph_build += self.create_value_for(
                self.refs[aten_arg.name], include_declarations
            )

        graph_build += self.create_value_for(self.refs["out"], include_declarations)
        graph_build += self.create_op_call()

        graph_build += self.set_output(self.refs["out"], include_declarations)

        graph_build += f"{self.graph}{self.dot}prepare();\n"
        graph_build += f"{self.graph}{self.dot}encode_prepack();\n"
        graph_build += f"{self.graph}{self.dot}prepack();\n"
        graph_build += f"{self.graph}{self.dot}encode_execute();\n"

        graph_build += "\n"
        return graph_build

    def gen_graph_exec_code(self, check_output=True) -> str:
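        # Assembles the execution section of the generated C++: resize inputs to
        # the reference sizes, copy inputs into staging, propagate_resize() and
        # execute() the graph, copy results back into vk_* tensors, and optionally
        # compare them against the ATen reference output.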
        graph_exec = ""
        for aten_arg in self.args:
            ref = self.refs[aten_arg.name]
            if ref.is_in:
                graph_exec += self.virtual_resize(ref)
                graph_exec += self.copy_into_staging(ref)

        graph_exec += f"{self.graph}{self.dot}propagate_resize();\n"
        graph_exec += f"{self.graph}{self.dot}execute();\n"

        graph_exec += self.declare_vk_out_for(self.refs["out"])
        graph_exec += self.copy_from_staging(self.refs["out"])
        if check_output:
            graph_exec += self.check_graph_out(self.refs["out"])

        graph_exec = re.sub(r"^", "  ", graph_exec, flags=re.M)
        graph_exec = "{\n" + graph_exec + "\n}"

        return graph_exec

    def gen_conditional_skips(self, skip_str: str = "GTEST_SKIP();") -> str:
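        # Emits guards that skip the test when the active Vulkan adapter lacks
        # full fp16 or int8 buffer support. Illustrative emitted guard:
        #
        #   if (test_dtype == at::kHalf) {
        #     if (!graph->context()->adapter_ptr()->has_full_float16_buffers_support()) {
        #       GTEST_SKIP();
        #     }
        #   }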
        fp16_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_float16_buffers_support()) {{\n"
        fp16_skip += f"  {skip_str}\n"
        fp16_skip += "}"
        fp16_skip = re.sub(r"^", "  ", fp16_skip, flags=re.M) + "\n"

        int8_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_int8_buffers_support()) {{\n"
        int8_skip += f"  {skip_str}\n"
        int8_skip += "}\n"

        skips = ""

        skips += "if (test_dtype == at::kHalf) {\n"
        skips += fp16_skip
        skips += "}\n"

        for _, dtype in self.suite_def.arg_dtype.items():
            if dtype == "at::kChar" or dtype == "at::kQInt8":
                skips += int8_skip
                continue

        skips += "\n"
        return skips

    def gen_op_check_fn(self) -> str:
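        # Generates the full check_<op> (or prepacked_check_<op>) C++ test helper:
        # conditional skips, graph construction, then graph execution with output
        # comparison, wrapped in the declaration produced by gen_decl().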
        op_name = self.f.func.name.unambiguous_name()
        if self.suite_def.test_name_suffix is not None:
            op_name += "_" + self.suite_def.test_name_suffix

        op_check_fn = self.gen_decl(f"check_{op_name}") + " {\n"
        if self.should_prepack:
            op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {\n"

        op_check_fn_body = ""
        op_check_fn_body += self.gen_conditional_skips()
        op_check_fn_body += self.gen_graph_build_code()
        op_check_fn_body += self.gen_graph_exec_code()

        op_check_fn_body = re.sub(r"^", "    ", op_check_fn_body, flags=re.M)

        op_check_fn += op_check_fn_body
        op_check_fn += "\n  }"

        return op_check_fn

    def gen_build_graph_fn(self, include_declarations: bool = False) -> str:
        op_name = self.f.func.name.unambiguous_name()
        if self.suite_def.test_name_suffix is not None:
            op_name += "_" + self.suite_def.test_name_suffix
        op_build_graph_fn = self.gen_decl(f"build_graph_{op_name}") + " {\n"
        if self.should_prepack:
            op_build_graph_fn = (
                self.gen_decl(f"prepacked_build_graph_{op_name}") + " {\n"
            )

        op_build_graph_fn_body = ""
        op_build_graph_fn_body += self.gen_graph_build_code(include_declarations)

        op_build_graph_fn += op_build_graph_fn_body
        op_build_graph_fn += "\n  }"
        return op_build_graph_fn

    def gen_op_exec_graph_fn(self) -> str:
        op_name = self.f.func.name.unambiguous_name()
        if self.suite_def.test_name_suffix is not None:
            op_name += "_" + self.suite_def.test_name_suffix
        op_benchmark_fn = self.gen_decl(f"benchmark_{op_name}") + " {\n"
        if self.should_prepack:
            op_benchmark_fn = self.gen_decl(f"prepacked_benchmark_{op_name}") + " {\n"

        op_benchmark_fn_body = ""
        op_benchmark_fn_body += self.gen_graph_exec_code(False)

        op_benchmark_fn_body = re.sub(r"^", "    ", op_benchmark_fn_body, flags=re.M)

        op_benchmark_fn += op_benchmark_fn_body
        op_benchmark_fn += "\n  }"
        return op_benchmark_fn