1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import re 8from dataclasses import dataclass 9from typing import List, Optional, Union 10 11from executorch.backends.vulkan.test.op_tests.utils.aten_types import ( 12 AT_INT_ARRAY_REF, 13 AT_SCALAR, 14 AT_TENSOR, 15 AT_TENSOR_LIST, 16 BOOL, 17 DOUBLE, 18 INT, 19 OPT_AT_DOUBLE_ARRAY_REF, 20 OPT_AT_INT_ARRAY_REF, 21 OPT_AT_TENSOR, 22 OPT_BOOL, 23 OPT_DEVICE, 24 OPT_INT64, 25 OPT_LAYOUT, 26 OPT_MEMORY_FORMAT, 27 OPT_SCALAR_TYPE, 28 STRING, 29 TENSOR_VECTOR, 30 THREE_TENSOR_TUPLE, 31 TWO_TENSOR_TUPLE, 32) 33from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite 34 35from torchgen.api import cpp 36from torchgen.api.types import CppSignatureGroup 37from torchgen.gen import generate_static_dispatch_backend_call, translate_args 38from torchgen.gen_aoti_c_shim import gen_static_dispatch_backend_call_signature 39from torchgen.model import NativeFunction, Variant 40 41################################### 42## Compute Graph Code Generation ## 43################################### 44 45 46@dataclass 47class ATenArg: 48 name: str 49 cpp_type: str 50 default: Optional[str] 51 52 53@dataclass 54class ValueRef: 55 name: str 56 src_cpp_name: str 57 src_cpp_type: str 58 is_in: bool = False 59 is_out: bool = False 60 requires_prepack: bool = False 61 supports_prepack: bool = False 62 # When is_dynamic_size is true, the underlying object size is not known 63 # during code-gen. Example is the out value for aten.split where the out 64 # value is a vector<Tensor>. In these cases, we need to use an additional 65 # vector or at::TensorList to track these values. 66 is_dynamic_size: bool = False 67 68 @property 69 def io_value_list_name(self): 70 assert self.is_dynamic_size 71 return f"{self.name}_io_value_list" 72 73 @property 74 def value_list_name(self): 75 assert self.is_dynamic_size 76 return f"{self.name}_value_list" 77 78 @property 79 def vk_out(self): 80 assert self.is_out 81 return f"vk_{self.name}" 82 83 84ValueRefList = Union[ValueRef, List[ValueRef]] 85 86InableCppType = frozenset([AT_TENSOR, AT_TENSOR_LIST]) 87 88 89class ComputeGraphGen: 90 backend_key = None 91 92 def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite): 93 self.op_reg_name = op_reg_name 94 self.f = f 95 self.suite_def = suite_def 96 97 self.f_sig = CppSignatureGroup.from_native_function( 98 self.f, method=False, fallback_binding=self.f.manual_cpp_binding 99 ).most_faithful_signature() 100 101 self.graph = "graph" 102 self.dot = "->" 103 104 self.args = [] 105 self.refs = {} 106 107 self.should_prepack = False 108 109 for binding in self.f_sig.arguments(): 110 arg = binding.argument 111 ctype = cpp.argumenttype_type( 112 arg.type, mutable=arg.is_write, binds=arg.name 113 ) 114 cpp_type = ctype.cpp_type(strip_ref=True) 115 116 self.args.append( 117 ATenArg(name=arg.name, cpp_type=cpp_type, default=arg.default) 118 ) 119 120 # These are the argument will be passed as a "weight" tensor, the 121 # corresponding object will be TensorRef in the compute graph. 122 requires_prepack = ( 123 "weight" in arg.name 124 or "bias" in arg.name 125 or "running_mean" in arg.name 126 or "running_var" in arg.name 127 ) 128 supports_prepack = False 129 if arg.name in self.suite_def.prepacked_args: 130 supports_prepack = True 131 132 self.refs[arg.name] = ValueRef( 133 name=f"{arg.name}_ref", 134 src_cpp_name=arg.name, 135 src_cpp_type=cpp_type, 136 is_in=(cpp_type in InableCppType), 137 requires_prepack=requires_prepack, 138 supports_prepack=supports_prepack, 139 ) 140 141 ret_type = cpp.returns_type(self.f.func.returns, symint=False).cpp_type() 142 self.out = ATenArg(name="out", cpp_type=ret_type, default=None) 143 if ret_type == AT_TENSOR: 144 self.refs["out"] = ValueRef( 145 name="out_ref", src_cpp_name="out", src_cpp_type=ret_type, is_out=True 146 ) 147 elif ret_type == TWO_TENSOR_TUPLE: 148 self.refs["out"] = [ 149 ValueRef( 150 name="out_ref_first", 151 src_cpp_name="std::get<0>(out)", 152 src_cpp_type="at::Tensor", 153 is_out=True, 154 ), 155 ValueRef( 156 name="out_ref_second", 157 src_cpp_name="std::get<1>(out)", 158 src_cpp_type="at::Tensor", 159 is_out=True, 160 ), 161 ValueRef( 162 name="out_ref", 163 src_cpp_name="out", 164 src_cpp_type=ret_type, 165 is_out=False, 166 ), 167 ] 168 elif ret_type == THREE_TENSOR_TUPLE: 169 self.refs["out"] = [ 170 ValueRef( 171 name="out_ref_first", 172 src_cpp_name="std::get<0>(out)", 173 src_cpp_type="at::Tensor", 174 is_out=True, 175 ), 176 ValueRef( 177 name="out_ref_second", 178 src_cpp_name="std::get<1>(out)", 179 src_cpp_type="at::Tensor", 180 is_out=True, 181 ), 182 ValueRef( 183 name="out_ref_third", 184 src_cpp_name="std::get<2>(out)", 185 src_cpp_type="at::Tensor", 186 is_out=True, 187 ), 188 ValueRef( 189 name="out_ref", 190 src_cpp_name="out", 191 src_cpp_type=ret_type, 192 is_out=False, 193 ), 194 ] 195 elif ret_type == TENSOR_VECTOR: 196 self.refs["out"] = ValueRef( 197 name="out_ref", 198 src_cpp_name="out", 199 src_cpp_type=ret_type, 200 is_out=True, 201 is_dynamic_size=True, 202 ) 203 else: 204 raise NotImplementedError( 205 f"ret_type: {ret_type} not supported for out value" 206 ) 207 208 ## ATen code generation 209 210 def gen_decl(self, fn_name: str, ret_type: str = "void") -> str: 211 cpp_args = [a.decl() for a in self.f_sig.arguments()] 212 cpp_args_str = ", ".join(cpp_args) 213 return f"{ret_type} {fn_name}({cpp_args_str})" 214 215 def create_aten_fn_call(self) -> str: 216 func_call = generate_static_dispatch_backend_call( 217 self.f_sig, self.f, ComputeGraphGen.backend_key 218 )[7:].replace("::cpu", "") 219 220 return func_call 221 222 def create_aten_method_call(self) -> str: 223 # For functions with only Method variant, we fallback to the function 224 # declared in MethodOperators.h. The method is declared as 225 # at::_ops::{name}::call(*), and ATEN_FN is a handly macro. 226 cpp_sig = gen_static_dispatch_backend_call_signature(self.f_sig, self.f) 227 exprs = translate_args(self.f_sig, cpp_sig) 228 func_call = f"ATEN_FN({self.f_sig.name()})({exprs});" 229 return func_call 230 231 def create_out_src(self, include_declarations: bool = True) -> str: 232 cpp_type = self.out.cpp_type if include_declarations else "" 233 if Variant.function in self.f.variants: 234 return f"{cpp_type} out = " + self.create_aten_fn_call() + "\n" 235 else: 236 return f"{cpp_type} out = " + self.create_aten_method_call() + "\n" 237 238 ## Graph code generation utils 239 240 def prepack_ref(self, ref: ValueRef) -> bool: 241 if ref.requires_prepack: 242 return True 243 else: 244 return ref.supports_prepack and self.should_prepack 245 246 def create_value_decl_for(self, ref: ValueRefList) -> str: # noqa: C901 247 if isinstance(ref, list): 248 ret_str = "" 249 for r in ref: 250 ret_str += self.create_value_decl_for(r) 251 return ret_str 252 253 cpp_type = "IOValueRef" if (ref.is_in or ref.requires_prepack) else "ValueRef" 254 if ref.src_cpp_type == AT_TENSOR_LIST: 255 ret_str = f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n" 256 ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n" 257 return ret_str 258 elif ref.src_cpp_type == TENSOR_VECTOR: 259 ret_str = f"std::vector<IOValueRef> {ref.io_value_list_name};\n" 260 ret_str += f"std::vector<ValueRef> {ref.value_list_name};\n" 261 return ret_str 262 else: 263 return f"{cpp_type} {ref.name};\n" 264 265 def create_value_for( # noqa: C901 266 self, ref: ValueRefList, include_declarations: bool = True 267 ) -> str: 268 if isinstance(ref, list): 269 ret_str = "" 270 for r in ref: 271 ret_str += self.create_value_for(r) 272 return ret_str 273 274 prepack = self.prepack_ref(ref) 275 ref_is_view = self.suite_def.is_view_op and ref.is_out 276 277 cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef" 278 if not include_declarations: 279 cpp_type = "" 280 281 if ref.src_cpp_type == OPT_AT_TENSOR: 282 ret_str = f"{cpp_type} {ref.name} = " 283 if prepack: 284 ret_str = "" 285 if include_declarations: 286 ret_str += f"IOValueRef {ref.name};\n" 287 ret_str += f"{ref.name}.value = " 288 ret_str += f"!{ref.src_cpp_name}.has_value() ? " 289 ret_str += f"{self.graph}{self.dot}add_none() : " 290 if not prepack: 291 ret_str += f"{self.graph}{self.dot}" 292 ret_str += "add_input_tensor(" if ref.is_in else "add_tensor(" 293 ret_str += f"{ref.src_cpp_name}->sizes().vec(), " 294 ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type())); \n" 295 elif prepack: 296 ret_str += f"{self.graph}{self.dot}" 297 ret_str += f"add_tensorref({ref.src_cpp_name}->sizes().vec(), " 298 ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type()), " 299 ret_str += f"{ref.src_cpp_name}->const_data_ptr()); \n" 300 return ret_str 301 elif ref.src_cpp_type == OPT_INT64: 302 ret_str = f"{cpp_type} {ref.name} = " 303 ret_str += f"!{ref.src_cpp_name}.has_value() ? " 304 ret_str += f"{self.graph}{self.dot}add_none() : " 305 ret_str += f"{self.graph}{self.dot}add_scalar<int64_t>" 306 ret_str += f"({ref.src_cpp_name}.value());\n" 307 return ret_str 308 elif ( 309 ref.src_cpp_type == OPT_AT_DOUBLE_ARRAY_REF 310 or ref.src_cpp_type == OPT_AT_INT_ARRAY_REF 311 ): 312 ret_str = f"{cpp_type} {ref.name} = " 313 ret_str += f"!{ref.src_cpp_name}.has_value() ? " 314 ret_str += f"{self.graph}{self.dot}add_none() : " 315 ret_str += f"{self.graph}{self.dot}add_scalar_list" 316 ret_str += f"({ref.src_cpp_name}->vec());\n" 317 return ret_str 318 elif ref.src_cpp_type == AT_TENSOR_LIST: 319 assert ref.is_in, "AT_TENSOR_LIST must be an input" 320 # This logic is a bit convoluted. We need to create a IOValueRef for 321 # each tensor, to facilate staging. On the other hand, we will 322 # use the .value tensor to create a ValueList, which will be passed 323 # to the corresponding ops. 324 ret_str = "" 325 if include_declarations: 326 ret_str += f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n" 327 ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n" 328 ret_str += f"for (int i=0; i < {ref.src_cpp_name}.size(); i++) {{\n" 329 ret_str += ( 330 f" IOValueRef io_value_ref = {self.graph}{self.dot}add_input_tensor(\n" 331 ) 332 ret_str += f" {ref.src_cpp_name}[i].sizes().vec(),\n" 333 ret_str += ( 334 f" from_at_scalartype({ref.src_cpp_name}[i].scalar_type())); \n" 335 ) 336 ret_str += f" {ref.name}_value_refs.emplace_back(io_value_ref.value);\n" 337 ret_str += f" {ref.name}_io_value_refs.emplace_back(io_value_ref);\n" 338 ret_str += "}\n" 339 ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n" 340 return ret_str 341 elif ref.src_cpp_type == TENSOR_VECTOR: 342 ret_str = "" 343 if include_declarations: 344 ret_str += f"std::vector<IOValueRef> {ref.io_value_list_name};\n" 345 ret_str += f"std::vector<ValueRef> {ref.value_list_name};\n" 346 ret_str += f""" 347for (int i=0; i<out.size(); i++) {{ 348 const at::Tensor& cur = out[i]; 349 IOValueRef io_value_ref; 350 io_value_ref.value = {self.graph}{self.dot}add_tensor( 351 cur.sizes().vec(), from_at_scalartype(cur.scalar_type())); 352 {ref.io_value_list_name}.emplace_back(io_value_ref); 353 {ref.value_list_name}.emplace_back(io_value_ref.value); 354}} 355ValueRef out_ref = {self.graph}{self.dot}add_value_list(std::move({ref.value_list_name})); 356""" 357 return ret_str 358 359 ret_str = f"{cpp_type} {ref.name} = {self.graph}{self.dot}" 360 if prepack: 361 ret_str = "" 362 if include_declarations: 363 ret_str = f"IOValueRef {ref.name};\n" 364 ret_str += f"{ref.name}.value = {self.graph}{self.dot}" 365 366 if ref.src_cpp_type == AT_TENSOR and ref_is_view: 367 input_name = None 368 for _name, ref in self.refs.items(): 369 if ref.is_in and ref.src_cpp_type == AT_TENSOR: 370 input_name = ref.name 371 372 assert input_name is not None 373 ret_str += f"add_tensor_view({input_name}.value);" 374 elif ref.src_cpp_type == AT_TENSOR and not prepack: 375 ret_str += "add_input_tensor(" if ref.is_in else "add_tensor(" 376 ret_str += f"{ref.src_cpp_name}.sizes().vec(), " 377 ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())); \n" 378 elif ref.src_cpp_type == AT_TENSOR and prepack: 379 ret_str += f"add_tensorref({ref.src_cpp_name}.sizes().vec(), " 380 ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type()), " 381 ret_str += f"{ref.src_cpp_name}.const_data_ptr()); \n" 382 elif ref.src_cpp_type == AT_SCALAR: 383 # TODO(ssjia): generalize this to work with all scalar types 384 ret_str += f"add_scalar<double>({ref.src_cpp_name}.toDouble()); \n" 385 elif ref.src_cpp_type == AT_INT_ARRAY_REF: 386 ret_str += f"add_scalar_list({ref.src_cpp_name}.vec()); \n" 387 elif ref.src_cpp_type == BOOL: 388 ret_str += f"add_scalar<bool>({ref.src_cpp_name}); \n" 389 elif ref.src_cpp_type == INT: 390 ret_str += f"add_scalar<int64_t>({ref.src_cpp_name}); \n" 391 elif ref.src_cpp_type == DOUBLE: 392 ret_str += f"add_scalar<double>({ref.src_cpp_name}); \n" 393 elif ( 394 ref.src_cpp_type == OPT_SCALAR_TYPE 395 or ref.src_cpp_type == OPT_LAYOUT 396 or ref.src_cpp_type == OPT_DEVICE 397 or ref.src_cpp_type == OPT_BOOL 398 or ref.src_cpp_type == OPT_MEMORY_FORMAT 399 ): 400 ret_str += "add_none(); \n" 401 elif ref.src_cpp_type == STRING: 402 ret_str += f"add_string(std::string({ref.src_cpp_name})); \n" 403 elif ref.src_cpp_type == TWO_TENSOR_TUPLE: 404 ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n" 405 elif ref.src_cpp_type == THREE_TENSOR_TUPLE: 406 ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second, {ref.name}_third}}); \n" 407 else: 408 raise RuntimeError(f"Unsupported cpp type {ref.src_cpp_type}") 409 410 return ret_str 411 412 def create_op_call(self) -> str: 413 deref = "*" if self.dot == "->" else "" 414 op_create_code = f'VK_GET_OP_FN("{self.op_reg_name}")({deref}{self.graph}, {{' 415 416 for aten_arg in self.args: 417 ref = self.refs[aten_arg.name] 418 if ref.src_cpp_type == AT_TENSOR_LIST: 419 # Special case. Underlying tensors are input tensors, but the 420 # container itself is just a normal value. 421 op_create_code += f"{ref.name}, " 422 else: 423 op_create_code += ( 424 f"{ref.name}.value, " 425 if ref.is_in or ref.requires_prepack or ref.is_out 426 else f"{ref.name}, " 427 ) 428 # op_create_code += f"{ref.name}, " 429 430 op_create_code += "out_ref});\n" 431 return op_create_code 432 433 def gen_output_staging_valueref_decl(self, ref: ValueRefList) -> str: 434 if isinstance(ref, list): 435 ret_str = "" 436 for r in ref[:-1]: 437 ret_str += self.gen_output_staging_valueref_decl(r) 438 return ret_str 439 elif ref.src_cpp_type == TENSOR_VECTOR: 440 assert ref.is_out 441 ret_str = "" 442 return ret_str 443 444 assert ref.src_cpp_type == AT_TENSOR and ref.is_out 445 return f"ValueRef {ref.name}_staging;\n" 446 447 def set_output(self, ref: ValueRefList, include_declarations: bool = True) -> str: 448 if isinstance(ref, list): 449 ret_str = "" 450 for r in ref[:-1]: 451 ret_str += self.set_output(r, include_declarations) 452 return ret_str 453 elif ref.src_cpp_type == TENSOR_VECTOR: 454 assert ref.is_out 455 ret_str = f""" 456for (int i=0; i<out.size(); i++) {{ 457 {ref.io_value_list_name}[i].staging = {self.graph}{self.dot}set_output_tensor( 458 {ref.io_value_list_name}[i].value); 459}} 460""" 461 return ret_str 462 463 assert ref.src_cpp_type == AT_TENSOR and ref.is_out 464 cpptype = "ValueRef" if include_declarations else "" 465 ret_str = f"{cpptype} {ref.name}_staging = {self.graph}{self.dot}" 466 ret_str += f"set_output_tensor({ref.name});\n" 467 return ret_str 468 469 def virtual_resize(self, ref: ValueRefList) -> str: 470 assert isinstance(ref, ValueRef) 471 assert ref.src_cpp_type in InableCppType and ref.is_in 472 if self.prepack_ref(ref): 473 return "" 474 475 if ref.src_cpp_type == AT_TENSOR: 476 ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)" 477 ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n" 478 elif ref.src_cpp_type == AT_TENSOR_LIST: 479 ret_str = "" 480 ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n" 481 ret_str += ( 482 f" {self.graph}{self.dot}get_tensor({ref.name}_io_value_refs[i].value)" 483 ) 484 ret_str += f"->virtual_resize({ref.src_cpp_name}[i].sizes().vec());\n" 485 ret_str += "}\n" 486 else: 487 raise AssertionError(f"{ref.src_cpp_type} not expected") 488 489 return ret_str 490 491 def copy_into_staging(self, ref: ValueRefList) -> str: 492 assert isinstance(ref, ValueRef) 493 assert ref.src_cpp_type in InableCppType and ref.is_in 494 495 if self.prepack_ref(ref): 496 return "" 497 498 if ref.src_cpp_type == AT_TENSOR: 499 ret_str = f"{self.graph}{self.dot}copy_into_staging(" 500 ret_str += f"{ref.name}.staging, " 501 ret_str += f"{ref.src_cpp_name}.const_data_ptr(), " 502 ret_str += f"{ref.src_cpp_name}.numel());\n" 503 elif ref.src_cpp_type == AT_TENSOR_LIST: 504 ret_str = "" 505 ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n" 506 ret_str += f" {self.graph}{self.dot}copy_into_staging(" 507 ret_str += f"{ref.name}_io_value_refs[i].staging, " 508 ret_str += f"{ref.src_cpp_name}[i].const_data_ptr(), " 509 ret_str += f"{ref.src_cpp_name}[i].numel());\n" 510 ret_str += "}\n" 511 else: 512 raise AssertionError(f"{ref.src_cpp_type} not expected") 513 return ret_str 514 515 def declare_vk_out_for(self, ref: Union[ValueRef, List[ValueRef]]) -> str: 516 if isinstance(ref, list): 517 ret_str = "" 518 for r in ref[:-1]: 519 ret_str += self.declare_vk_out_for(r) 520 return ret_str 521 elif ref.src_cpp_type == TENSOR_VECTOR: 522 assert ref.is_out 523 ret_str = f""" 524std::vector<at::Tensor> {ref.vk_out}; 525for (int i=0; i<out.size(); i++) {{ 526 {ref.vk_out}.emplace_back(at::empty_like(out[i]).contiguous()); 527}} 528""" 529 return ret_str 530 531 assert ref.src_cpp_type == AT_TENSOR and ref.is_out 532 ret_str = f"at::Tensor vk_{ref.name} = at::empty_like({ref.src_cpp_name})" 533 ret_str += ".contiguous();\n" 534 return ret_str 535 536 def copy_from_staging(self, ref: ValueRefList) -> str: 537 if isinstance(ref, list): 538 ret_str = "" 539 for r in ref[:-1]: 540 ret_str += self.copy_from_staging(r) 541 return ret_str 542 elif ref.src_cpp_type == TENSOR_VECTOR: 543 assert ref.is_out 544 ret_str = f""" 545for (int i=0; i<out.size(); i++) {{ 546 {self.graph}{self.dot}copy_from_staging( 547 {ref.io_value_list_name}[i].staging, 548 {ref.vk_out}[i].mutable_data_ptr(), 549 {ref.vk_out}[i].numel()); 550}} 551""" 552 return ret_str 553 554 assert ref.src_cpp_type == AT_TENSOR and ref.is_out 555 ret_str = f"{self.graph}{self.dot}copy_from_staging({ref.name}_staging, " 556 ret_str += f"vk_{ref.name}.mutable_data_ptr(), vk_{ref.name}.numel());\n" 557 558 return ret_str 559 560 ## Misc. code generation utilities 561 562 def check_graph_out(self, ref: ValueRefList) -> str: 563 if isinstance(ref, list): 564 ret_str = "" 565 for r in ref[:-1]: 566 ret_str += self.check_graph_out(r) 567 return ret_str 568 elif ref.src_cpp_type == TENSOR_VECTOR: 569 assert ref.is_out 570 ret_str = f""" 571for (int i=0; i<out.size(); i++) {{ 572 EXPECT_TRUE(check_close(out[i], {ref.vk_out}[i], rtol, atol)); 573}} 574""" 575 return ret_str 576 577 return ( 578 f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}, rtol, atol));" 579 ) 580 581 ## Top level code generation 582 583 def gen_arg_valueref_decls(self) -> str: 584 ret_str = "" 585 for aten_arg in self.args: 586 ref = self.refs[aten_arg.name] 587 ret_str += self.create_value_decl_for(ref) 588 589 ret_str += self.create_value_decl_for(self.refs["out"]) 590 ret_str += f"{self.out.cpp_type} out;\n" 591 ret_str += self.gen_output_staging_valueref_decl(self.refs["out"]) 592 return ret_str 593 594 def gen_graph_build_code(self, include_declarations: bool = True) -> str: 595 graph_build = self.create_out_src(include_declarations) 596 for aten_arg in self.args: 597 graph_build += self.create_value_for( 598 self.refs[aten_arg.name], include_declarations 599 ) 600 601 graph_build += self.create_value_for(self.refs["out"], include_declarations) 602 graph_build += self.create_op_call() 603 604 graph_build += self.set_output(self.refs["out"], include_declarations) 605 606 graph_build += f"{self.graph}{self.dot}prepare();\n" 607 graph_build += f"{self.graph}{self.dot}encode_prepack();\n" 608 graph_build += f"{self.graph}{self.dot}prepack();\n" 609 graph_build += f"{self.graph}{self.dot}encode_execute();\n" 610 611 graph_build += "\n" 612 return graph_build 613 614 def gen_graph_exec_code(self, check_output=True) -> str: 615 graph_exec = "" 616 for aten_arg in self.args: 617 ref = self.refs[aten_arg.name] 618 if ref.is_in: 619 graph_exec += self.virtual_resize(ref) 620 graph_exec += self.copy_into_staging(ref) 621 622 graph_exec += f"{self.graph}{self.dot}propagate_resize();\n" 623 graph_exec += f"{self.graph}{self.dot}execute();\n" 624 625 graph_exec += self.declare_vk_out_for(self.refs["out"]) 626 graph_exec += self.copy_from_staging(self.refs["out"]) 627 if check_output: 628 graph_exec += self.check_graph_out(self.refs["out"]) 629 630 graph_exec = re.sub(r"^", " ", graph_exec, flags=re.M) 631 graph_exec = "{\n" + graph_exec + "\n}" 632 633 return graph_exec 634 635 def gen_conditional_skips(self, skip_str: str = "GTEST_SKIP();") -> str: 636 fp16_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_float16_buffers_support()) {{\n" 637 fp16_skip += f" {skip_str}\n" 638 fp16_skip += "}" 639 fp16_skip = re.sub(r"^", " ", fp16_skip, flags=re.M) + "\n" 640 641 int8_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_int8_buffers_support()) {{\n" 642 int8_skip += f" {skip_str};\n" 643 int8_skip += "}\n" 644 645 skips = "" 646 647 skips += "if (test_dtype == at::kHalf) {\n" 648 skips += fp16_skip 649 skips += "}\n" 650 651 for _, dtype in self.suite_def.arg_dtype.items(): 652 if dtype == "at::kChar" or dtype == "at::kQInt8": 653 skips += int8_skip 654 continue 655 656 skips += "\n" 657 return skips 658 659 def gen_op_check_fn(self) -> str: 660 op_name = self.f.func.name.unambiguous_name() 661 if self.suite_def.test_name_suffix is not None: 662 op_name += "_" + self.suite_def.test_name_suffix 663 664 op_check_fn = self.gen_decl(f"check_{op_name}") + " {\n" 665 if self.should_prepack: 666 op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {\n" 667 668 op_check_fn_body = "" 669 op_check_fn_body += self.gen_conditional_skips() 670 op_check_fn_body += self.gen_graph_build_code() 671 op_check_fn_body += self.gen_graph_exec_code() 672 673 op_check_fn_body = re.sub(r"^", " ", op_check_fn_body, flags=re.M) 674 675 op_check_fn += op_check_fn_body 676 op_check_fn += "\n }" 677 678 return op_check_fn 679 680 def gen_build_graph_fn(self, include_declarations: bool = False) -> str: 681 op_name = self.f.func.name.unambiguous_name() 682 if self.suite_def.test_name_suffix is not None: 683 op_name += "_" + self.suite_def.test_name_suffix 684 op_build_graph_fn = self.gen_decl(f"build_graph_{op_name}") + " {\n" 685 if self.should_prepack: 686 op_build_graph_fn = ( 687 self.gen_decl(f"prepacked_build_graph_{op_name}") + " {\n" 688 ) 689 690 op_build_graph_fn_body = "" 691 op_build_graph_fn_body += self.gen_graph_build_code(include_declarations) 692 693 op_build_graph_fn += op_build_graph_fn_body 694 op_build_graph_fn += "\n }" 695 return op_build_graph_fn 696 697 def gen_op_exec_graph_fn(self) -> str: 698 op_name = self.f.func.name.unambiguous_name() 699 if self.suite_def.test_name_suffix is not None: 700 op_name += "_" + self.suite_def.test_name_suffix 701 op_benchmark_fn = self.gen_decl(f"benchmark_{op_name}") + " {\n" 702 if self.should_prepack: 703 op_benchmark_fn = self.gen_decl(f"prepacked_benchmark_{op_name}") + " {\n" 704 705 op_benchmark_fn_body = "" 706 op_benchmark_fn_body += self.gen_graph_exec_code(False) 707 708 op_benchmark_fn_body = re.sub(r"^", " ", op_benchmark_fn_body, flags=re.M) 709 710 op_benchmark_fn += op_benchmark_fn_body 711 op_benchmark_fn += "\n }" 712 return op_benchmark_fn 713