"""Code generators that emit C++ source for lazy-tensor IR nodes, lazy native
function definitions and shape-inference declarations."""

from __future__ import annotations

import itertools
from abc import ABC
from dataclasses import dataclass
from typing import Any

import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import (
    getValueT,
    isValueType,
    LazyArgument,
    LazyIrProperties,
    LazyIrSchema,
    tensorListValueT,
)
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    deviceT,
    DispatcherSignature,
    kernel_signature,
    NativeSignature,
    OptionalCType,
    VectorCType,
)
from torchgen.context import method_with_native_function
from torchgen.dest.lazy_ts_lowering import ts_lowering_body
from torchgen.model import (
    Argument,
    BackendIndex,
    BackendMetadata,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    NativeFunctionsGroup,
)


def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
    """
    Given a LazyArgument,
    generate a c++ string for materializing an rvalue of that arg for passing into
    a lazy Node constructor.

    Value-typed arguments refer to the `node_{name}` / `lazy_{name}` /
    `lazy_{name}_tensorlist` locals, or are wrapped in `GetSymIntValue(...)`.
    Other arguments are passed by name, converted to a `std::vector` where the
    lazy type is a vector.

    Raises:
        AssertionError: if a value-typed argument is neither a BaseCType nor
            an OptionalCType.
    """

    # TODO: Matching on CType seems wrong; should be matching on Type
    if isValueType(arg.lazy_type):
        if isinstance(arg.lazy_type, BaseCType):
            if arg.is_wrapped_scalar:
                return f"node_{arg.name}"
            elif arg.lazy_type.type is tensorListValueT:
                return f"lazy_{arg.name}_tensorlist"
            elif arg.is_symint_or_list:
                return f"GetSymIntValue({arg.name})"
            return f"lazy_{arg.name}->GetIrValue()"
        elif isinstance(arg.lazy_type, OptionalCType):
            if arg.is_symint_or_list:
                # TODO: I don't understand when you should put lazy_ in the name
                # or not
                return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt"
            elif arg.is_wrapped_scalar:
                return f"node_{arg.name}"
            return (
                f"lazy_{arg.name} ? "
                f"std::make_optional(lazy_{arg.name}->GetIrValue()) : "
                "::std::nullopt"
            )
        else:
            raise AssertionError(
                f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
            )
    else:
        # NB: this is here because right now we aren't treating SymInt[] as a
        # value type; when we do this needs to move above
        # NB: we cannot test arg.lazy_type as we've already specified it is an
        # int64_t and so we cannot distinguish between SymInt and int64_t
        if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
            BaseTy.SymInt
        ):
            if arg.symint:
                return f"GetSymIntArrayRefValue({arg.name})"
            else:
                return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
        elif isinstance(arg.lazy_type, VectorCType) and isinstance(
            arg.lazy_type.elem, BaseCType
        ):
            return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
        elif (
            isinstance(arg.lazy_type, OptionalCType)
            and isinstance(arg.lazy_type.elem, VectorCType)
            and isinstance(arg.lazy_type.elem.elem, BaseCType)
        ):
            return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
        else:
            return f"{arg.name}"


def node_ctor_inputs(schema: LazyIrSchema) -> str:
    """
    Produce a formatted string with the arguments as passed into the constructor of a node class.
    """
    node_ctor_values = [
        node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
    ]
    return ", ".join(node_ctor_values)


def gen_fallback_code(
    schema: LazyIrSchema,
    sig: DispatcherSignature | NativeSignature,
    overload_name: str,
) -> str:
    """
    Generate code that falls back to eager conditioned on a predicate

    Args:
        schema: lazy IR schema of the op; supplies `func` and `aten_name`.
        sig: signature whose arguments are in scope in the generated kernel;
            they are translated to the dispatcher signature's arguments.
        overload_name: overload name of the op; when empty, `ATEN_OP` is
            emitted instead of `ATEN_OP2`.

    Returns:
        A C++ snippet that returns the result of the eager fallback when
        `force_eager_fallback(<aten symbol>)` is true.
    """
    dispatcher_sig = DispatcherSignature.from_schema(schema.func)
    exprs = translate(sig.arguments(), dispatcher_sig.arguments())
    fallback_args = ",\n                ".join([a.expr for a in exprs])
    if overload_name:
        aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
    else:
        aten_op_str = f"ATEN_OP({schema.aten_name})"
    # The first template argument is the address of the fallback function, so
    # the `&` must sit directly in front of `ltc_eager_fallback`.
    return f"""
        if (force_eager_fallback({aten_symbol(schema)})) {{
            return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
                {fallback_args}
            );
        }}
"""


def aten_symbol(schema: LazyIrSchema) -> str:
    """
    Return the C++ expression naming the aten symbol of `schema`.

    Names listed in `missing_interned_strings` are built at runtime with
    `c10::Symbol::fromQualString`. Names that already start with `at::` are
    returned unchanged; all others are prefixed with `at::aten::`.
    """
    missing_interned_strings = {
        "sigmoid_backward",
    }
    if schema.aten_name in missing_interned_strings:
        return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'

    if not schema.aten_name.startswith("at::"):
        return f"at::aten::{schema.aten_name}"
    else:
        return schema.aten_name


def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
    """
    Convert all tensor-like arguments to meta tensors.

    Returns:
        (1) a string containing all of the logic that does the conversions.
        (2) a context, to be used by translate(), with all of the relevant bindings.
    """
    context: list[Binding] = []
    unwrapped_tensor_args: list[str] = []
    for arg in sig.arguments():
        if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
            unwrapped_name = f"{arg.name}_meta"
            unwrapped_tensor_args.append(
                f"auto {unwrapped_name} = to_meta({arg.name});"
            )
            # Rebind the argument so translate() picks up the meta copy.
            context.append(arg.with_name(unwrapped_name))
        else:
            context.append(arg)
    unwrap_tensor_args_str = "\n        ".join(unwrapped_tensor_args)
    return unwrap_tensor_args_str, context


@dataclass(frozen=True)
class GenLazyIR(ABC):
    """
    Generates the C++ class declaration of a lazy IR node for a native function.

    Subclasses customize the emitted class by overriding `lowering_function`,
    `create_function`, `can_be_reused_function` and `node_base_ctor_call`.
    """

    # Used to look up the backend kernel metadata (and hence symint support).
    backend_index: BackendIndex
    backend_name: str
    # Name of the C++ base class the generated node derives from.
    node_base: str
    # When set (and the schema has ShapePrecompute), the node constructor takes
    # a `shapes` argument.
    use_lazy_shape: bool

    @method_with_native_function
    def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
        """Generate the node class for `f` (its functional variant if `f` is a group)."""
        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        metadata = self.backend_index.get_kernel(
            f.functional if isinstance(f, NativeFunctionsGroup) else f
        )
        schema = LazyIrSchema(
            func, symint=metadata is not None and metadata.supports_symint()
        )
        return self.gen(schema)

    # there is no lowering functionality generated unless this IR base class is subclassed and
    # implemented as a backend-specific node
    def lowering_function(self, schema: LazyIrSchema) -> str:
        """Return the C++ lowering member for the node; empty in the base class."""
        return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """Return the C++ `Create` member for the node; empty in the base class."""
        return ""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """Return a C++ `CanBeReused` member that always returns false."""
        return f"""bool CanBeReused({node_ctor_args}) const {{
    return false;
    }}"""

    def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
        """
        Return the C++ expression that invokes the node base class constructor.

        Raises:
            AssertionError: if a value argument has a lazy type other than
                BaseCType, VectorCType or OptionalCType.
        """
        value_args = schema.filtered_args(values=True, scalars=False)
        # backends can customize the way the node base class constructor is called,
        # as long as all of its arguments can be generated from information available from the schema
        base_ctor_value_args_list: list[str] = []
        for arg in value_args:
            if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
                base_ctor_value_args_list.append(f"{arg.name}")
            elif isinstance(arg.lazy_type, OptionalCType):
                base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
            else:
                raise AssertionError(
                    f"Unsupported type ({arg.lazy_type}) - add support if necessary"
                )
        base_ctor_value_args = ", ".join(base_ctor_value_args_list)

        scalar_args = schema.filtered_args(values=False, scalars=True)

        # Shape construction.
        # Conditionally build shape depending on specified shape property
        if schema.properties.ShapePrecompute:
            shape_ctor_arg = "std::move(shapes),"
        elif schema.properties.ShapeCompute:
            shape_args = [a.name for a in value_args]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
        elif schema.properties.ShapeCache:
            shape_args = [f"operand({i})" for i in range(len(value_args))]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
        else:
            shape_ctor_arg = ""

        scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)

        return f"""{self.node_base}(
              {schema.node_name}::ClassOpKind(),
              OpList{{{base_ctor_value_args}}},
              {shape_ctor_arg}
              /* num_outputs */ {len(schema.returns)},
              torch::lazy::MHash({scalar_hashes}))"""

    def gen(self, schema: LazyIrSchema) -> list[str]:
        """
        Generate the C++ class declaration for the IR node described by `schema`.

        Returns:
            A one-element list holding the class source text.
        """
        opkind = schema.opkind or aten_symbol(schema)

        # for now, we just want one IR class decl and soon after also the method defs
        # and we use the functional version not out/inplace.
        all_args = schema.filtered_args()
        scalar_args = schema.filtered_args(values=False, scalars=True)

        ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
        # Create/CanBeReused take the same arguments minus the optional `shapes`.
        reuse_ctor_args = ", ".join(ctor_args)
        if self.use_lazy_shape and schema.properties.ShapePrecompute:
            ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
        node_ctor_args = ", ".join(ctor_args)

        scalar_initializers = ",\n        ".join(
            [
                # This code is just special casing the mapping from string_view -> strings
                f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
                if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
                else f"{a.name}({a.name})"
                for a in scalar_args
            ]
        )
        if scalar_initializers:
            scalar_initializers = f",\n        {scalar_initializers}"
        # string_view scalars are stored as owning std::string members.
        scalar_decls = "\n  ".join(
            [
                f"std::string {a.name};"
                if a.lazy_type.cpp_type() == "c10::string_view"
                else f"::std::optional<std::string> {a.name};"
                if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
                else f"{a.lazy_type.cpp_type()} {a.name};"
                for a in scalar_args
            ]
        )
        optional_values = [
            arg.name
            for arg in schema.filtered_args(values=True, scalars=False)
            if isinstance(arg.lazy_type, OptionalCType)
        ]
        has_optional_decls = "\n  ".join(
            [f"bool has_{value}: 1;" for value in optional_values]
        )
        has_optional_defs = "\n    ".join(
            [f"has_{value} = !!{value};" for value in optional_values]
        )
        members_to_string = []
        for arg in scalar_args:
            if isinstance(arg.lazy_type, OptionalCType):
                value = f"{arg.name}.value()"
                if arg.is_generator:
                    # Generators are printed as a fixed placeholder string.
                    value = '"torch.Generator()"'
                members_to_string.append(
                    f"""if ({arg.name}.has_value()) {{
      ss << ", {arg.name}=" << {value};
    }} else {{
      ss << ", {arg.name}=null";
    }}"""
                )
            else:
                members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
        members_to_string_str = "\n    ".join(members_to_string)

        return [
            f"""\
class {schema.node_name} : public {self.node_base} {{
 public:
  static torch::lazy::OpKind ClassOpKind() {{
    return torch::lazy::OpKind({opkind});
  }}

  {schema.node_name}({node_ctor_args})
      : {self.node_base_ctor_call(schema)}{scalar_initializers}
  {{
    {has_optional_defs}
  }}

  std::string ToString() const override {{
    std::stringstream ss;
    ss << {self.node_base}::ToString();
    {members_to_string_str}
    return ss.str();
  }}

  {self.create_function(schema, reuse_ctor_args)}

  {self.can_be_reused_function(schema, reuse_ctor_args)}

  {self.lowering_function(schema)}

  {scalar_decls}
  {has_optional_decls}

}};

""",
        ]


@dataclass(frozen=True)
class GenTSLazyIR(GenLazyIR):
    """
    GenLazyIR variant that additionally emits `Lower`, `Create` and
    `CanBeReused` members, each controlled by the schema's properties.
    """

    def lowering_function(self, schema: LazyIrSchema) -> str:
        """Return the `Lower` declaration, its definition, or nothing, per the schema properties."""
        signature = """
  torch::lazy::TSOpVector Lower(
      std::shared_ptr<torch::jit::GraphFunction> function,
      torch::lazy::TSLoweringContext* loctx) const override"""

        if schema.properties.LowerDeclOnly:
            return f"{signature};"
        elif schema.properties.Lower:
            return f"""{signature} {{
    {ts_lowering_body(schema)}
  }}
            """
        else:
            return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """Return the static `Create` declaration, its definition, or nothing, per the schema properties."""
        signature = f"static NodePtr Create({node_ctor_args})"
        if schema.properties.CreateFnDeclOnly:
            return f"{signature};"
        elif not schema.properties.CreateFn:
            return ""
        return f"""{signature} {{
    return ReuseOrMakeNode<{schema.node_name}>(data);
  }}"""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """
        Return the `CanBeReused` declaration, its definition, or nothing, per
        the schema properties.

        The definition compares each value argument against the matching
        operand, in positional-then-keyword order, and each scalar argument
        against the member of the same name.
        """
        signature = f"bool CanBeReused({node_ctor_args}) const"
        if schema.properties.CanBeReusedDeclOnly:
            return f"{signature};"
        elif not schema.properties.CanBeReused:
            return ""
        value_comparison = []
        for arg in itertools.chain(schema.positional_values, schema.keyword_values):
            if isinstance(arg.lazy_type, OptionalCType):
                value_comparison.append(
                    f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
                )
            else:
                value_comparison.append(f"operand(i++) == {arg.name}")
        for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
            if isinstance(arg.lazy_type, OptionalCType):
                # Equal when both are empty, or both are set and hold equal values.
                value_comparison.append(
                    f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
                )
            else:
                value_comparison.append(f"this->{arg.name} == {arg.name}")
        value_comparison_str = " &&\n        ".join(value_comparison)

        return f"""{signature} {{
    size_t i = 0;
    return ({value_comparison_str});
  }}"""


@dataclass(frozen=True)
class GenLazyNativeFuncDefinition:
    """
    Generates the C++ definition of a lazy backend kernel for a native function.

    The generated body is assembled, in order, from `force_eager_fallback`,
    `metrics`, `get_device`, `lazy_tensor_decls`, `build_ir_node` and
    `return_aten_tensor`.
    """

    # Qualifier placed in front of the kernel name in the generated definition.
    class_method_name: str
    # Used to look up the kernel metadata of each function.
    backend_index: BackendIndex
    tensor_class: str
    # Whether to emit the forced eager-fallback preamble.
    gen_forced_fallback_code: bool
    # C++ namespace qualifying the tensor helper functions below.
    backend_namespace: str
    # Names of C++ helpers called to obtain lazy tensors from the inputs.
    get_tensorlist: str
    get_tensor_or_wrap_number: str
    try_get_tensor: str
    # C++ statement (without the trailing `;`) emitted for metrics.
    metrics_counter: str
    # Name of the C++ function/method used to create the output lazy tensor.
    create_tensor: str
    # If set, `create_tensor` is called as a member of the first tensor argument.
    create_from_first_tensor: bool
    # Names of C++ helpers that wrap lazy tensors into the returned aten value(s).
    create_aten_from_ltc_tensor: str
    tuple_aten_from_ltc_tensors: str
    # C++ type used for the `lazy_{name}` locals.
    lazy_tensor_ptr: str
    # Name of the C++ function called to compute `common_device`.
    get_device_fn: str

    def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """
        Return C++ declarations of the `node_{name}` / `lazy_{name}` locals for
        the value arguments of `schema`.

        Raises:
            AssertionError: if a value argument has an unsupported lazy type.
        """
        value_args = schema.filtered_args(values=True, scalars=False)
        # Generates lazy_{name} variables for LazyTensors wrapping input tensors
        lazy_tensor_decls: list[str] = []
        for arg in value_args:
            if arg.is_wrapped_scalar:
                if isinstance(arg.lazy_type, OptionalCType):
                    lazy_tensor_decls.append(
                        f"""auto node_{arg.name} = {arg.name} ?
                std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
                    GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
                ::std::nullopt;"""
                    )
                else:
                    lazy_tensor_decls.append(
                        f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
                            GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
                    )
            elif arg.is_symint_or_list:
                continue  # values are extracted in isValueType
            elif isinstance(arg.lazy_type, BaseCType):
                if arg.lazy_type.type is tensorListValueT:
                    lazy_tensor_decls.append(
                        f"auto lazy_{arg.name}_tensorlist = "
                        f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
                    )
                else:
                    lazy_tensor_decls.append(
                        f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                        f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
                    )
            elif isinstance(arg.lazy_type, OptionalCType):
                assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
                # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
                # until we encounter a real world example.
                lazy_tensor_decls.append(
                    f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                    f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
                )
            else:
                raise AssertionError(
                    f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
                )
        return ("\n        ").join(lazy_tensor_decls)

    def force_eager_fallback(
        self,
        func: NativeFunction,
        schema: LazyIrSchema,
        metadata: BackendMetadata,
        sig: DispatcherSignature | NativeSignature,
    ) -> str:
        """Return the eager-fallback preamble, or "" when `gen_forced_fallback_code` is off."""
        if self.gen_forced_fallback_code:
            return gen_fallback_code(
                schema, sig, overload_name=func.func.name.overload_name
            )
        return ""

    def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """Return the metrics statement (`metrics_counter` followed by `;`)."""
        return f"{self.metrics_counter};"

    def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """
        Return C++ code that computes and asserts `common_device`.

        The device is computed from the non-wrapped-scalar value arguments
        and the optional-device scalar arguments; at least one is required.
        """
        value_args = schema.filtered_args(values=True, scalars=False)
        scalar_args = schema.filtered_args(values=False, scalars=True)
        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
        optional_device = OptionalCType(BaseCType(deviceT))
        optional_devices = [
            a.name for a in scalar_args if a.lazy_type == optional_device
        ]
        assert (
            len(value_types_names) > 0 or len(optional_devices) > 0
        ), "Expected at least one Value or Device type"
        get_device_str = (
            f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
        )
        return f"""auto common_device = {get_device_str};
        TORCH_INTERNAL_ASSERT(common_device);
        """

    def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """
        Return C++ code that defines `shapes` for the outputs of `func`.

        Structured and view_copy ops call a kernel on meta tensors; all other
        ops call the `compute_shape_*` function.
        """
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        all_args = schema.filtered_args()
        returns_length = len(schema.returns)
        # call the meta kernel if it exists, to compute output shape/dtype for our IR
        # Note [Generated LTC Shape Functions]
        # LTC uses meta tensors from core to do shape inference when possible, and otherwise
        # we generate a shape function declaration that needs to be manually implemented.
        # How do we detect which ops are eligible to use meta tensors?
        # In general we should be able to use meta tensors not just on structured operators,
        # but also on composite operators that are implemented in terms of structured kernels.
        # We don't currently have a way of knowing at codegen time which ops are implemented that way.
        # This is the case for all view and view_copy operators however, so we're going to
        # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
        is_view_copy_op = "view_copy" in func.tags
        is_structured = func.structured or func.structured_delegate is not None
        if is_structured or is_view_copy_op:
            meta_out = """
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
            if returns_length > 1:

                def this_shape(i: int) -> str:
                    return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"

                shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
                meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"

            # Convert tensor args to the meta device and call it.
            # (We can't pass in the input tensors directly, because they are "functional wrappers".
            # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
            # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
            dispatcher_sig = DispatcherSignature.from_schema(func.func)
            meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
            meta_call_args = [
                e.expr
                for e in translate(
                    meta_call_ctx, dispatcher_sig.arguments(), method=False
                )
            ]
            if is_view_copy_op:
                # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
                assert func.has_composite_explicit_autograd_non_functional_kernel
                dispatch_ns = "compositeexplicitautogradnonfunctional"
            else:
                dispatch_ns = "meta"
            aten_name = schema.aten_name
            # TODO: this is trolling
            if func.func.has_symint() and metadata.supports_symint():
                aten_name += "_symint"
            shape_str = f"""\
        {meta_conversion_str}
        auto out_meta = at::{dispatch_ns}::{aten_name}({', '.join(meta_call_args)});
        {meta_out}"""
        else:
            shape_sig = ComputeShapeSignature(
                metadata.kernel, func, symint=metadata.supports_symint()
            )
            shape_str = f"""
            auto shapes = {shape_sig.shape_call};"""

        shape_str += f"""
            TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""

        # Calculating which dimensions are symbolic
        func_schema_str = "aten::" + str(func.func)
        shape_str += f"""
            if(torch::lazy::symbolicShapeEnabled()){{
                std::vector<torch::jit::IValue> inputs = {{ {', '.join(str(a.name) for a in all_args)} }};
                const char* schema_str = "{func_schema_str}";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }}
        """
        return shape_str

    def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """
        Return C++ code that sets `node`: first via `ReuseNode`, and if that
        yields nothing, by running shape inference and calling `MakeNode`.
        """
        node_ctor_input_str = node_ctor_inputs(schema)
        return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
        if (!node) {{
            {self.shape_inference(func, schema)}
            node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
            CacheNode(node);
        }}
        """

    def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
        """
        Return the C++ callee used to create a lazy tensor.

        Raises:
            AssertionError: if `create_from_first_tensor` is set and no
                `first_tensor_name` is given.
        """
        # xla uses an instance method for tensor creation, for the time being
        if self.create_from_first_tensor:
            # TODO(whc) remove this if XLA switches to using static method for creation
            assert (
                first_tensor_name is not None
            ), "Requires first tensor to create lazy tensor"
            return f"{first_tensor_name}.{self.create_tensor}"
        return f"{self.backend_namespace}::{self.create_tensor}"

    def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
        """
        Return C++ code that wraps `node` into the kernel's result and returns it.

        Single outputs, tuple outputs and in-place/out variants each get their
        own bridge code; in-place/out variants must have exactly one return.
        """
        returns_length = len(schema.returns)
        value_args = schema.filtered_args(values=True, scalars=False)
        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
        first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
        bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
                {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""

        if returns_length > 1:
            assert (
                len(value_types_names) > 0
            ), "Code below assumes there is at least one tensor arg"
            bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
        for (int i = 0; i < {returns_length}; i++) {{
            lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
        }}
        auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""

        if schema.name.name.inplace or func.func.is_out_fn():
            assert returns_length == 1, (
                "We assumed there was no such case where an op is an in-place variant "
                f"and has tuple outputs, but got tuple of len {returns_length}."
            )
            bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
        auto& result = {first_tensor_name};"""

        bridge_str += """
        return result;"""
        return bridge_str

    @method_with_native_function
    def __call__(self, func: NativeFunction) -> list[str]:
        """Return a one-element list holding the C++ kernel definition for `func`."""
        sig = kernel_signature(func, self.backend_index)
        metadata = self.backend_index.get_kernel(func)
        assert metadata is not None
        schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
        return [
            f"""\
    {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
        {self.force_eager_fallback(func, schema, metadata, sig)}
        {self.metrics(func, schema)}
        {self.get_device(func, schema)}
        {self.lazy_tensor_decls(func, schema)}
        {self.build_ir_node(func, schema)}
        {self.return_aten_tensor(func, schema)}
    }}\n
    """
        ]


class ComputeShapeSignature:
    """
    Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
    """

    def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
        self.__schema = LazyIrSchema(f.func, symint=symint)
        # Declaration uses the dispatcher argument declarations...
        self.__dispatch_args = ", ".join(
            [a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
        )
        # ...while the call site passes the schema's arguments by name.
        self.__call_args = ", ".join(
            [f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
        )
        self.__kernel_name = kernel_name

    def __decl_suffix(self) -> str:
        return f"{self.__kernel_name}({self.__dispatch_args})"

    def __call_suffix(self) -> str:
        return f"{self.__kernel_name}({self.__call_args})"

    @property
    def shape_decl(self) -> str:
        """C++ declaration of the `compute_shape_<kernel>` function."""
        return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"

    @property
    def shape_call(self) -> str:
        """C++ call expression invoking `torch::lazy::compute_shape_<kernel>`."""
        return f"torch::lazy::compute_shape_{self.__call_suffix()}"


@dataclass(frozen=True)
class GenLazyShapeInferenceDefinition:
    """
    Generates `compute_shape_*` declarations for ops that are neither
    structured nor view_copy ops.
    """

    # Used to look up the kernel metadata of each function.
    backend_index: BackendIndex
    tensor_class: str

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> list[str]:
        """Return the shape function declaration for `f`, or [] if none is needed."""
        metadata = self.backend_index.get_kernel(f)
        assert metadata is not None

        # See Note [Generated LTC Shape Functions]
        is_view_copy_op = "view_copy" in f.tags
        is_structured = f.structured or f.structured_delegate is not None
        if is_structured or is_view_copy_op:
            return []
        else:
            shape_sig = ComputeShapeSignature(
                metadata.kernel, f, symint=metadata.supports_symint()
            )
            return ["\n".join([f"{shape_sig.shape_decl};"])]


def generate_non_native_lazy_ir_nodes(
    non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> list[str]:
    """Generate the non-native lazy IR node classes"""
    nodes = []
    for op in non_native:
        # Set default properties for Non-Native IRs
        properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
        for p in op.get("properties", []):
            setattr(properties, p, True)

        # non-native is assumed to want symint bindings if you wrote symint
        schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
        schema.opkind = op.get("opkind")
        nodes.append(gen_lazy_ir.gen(schema)[0])

    return nodes