xref: /aosp_15_r20/external/pytorch/torchgen/gen_lazy_tensor.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import argparse
4import os
5from collections import namedtuple
6from pathlib import Path
7from typing import Any, Callable, Iterable, Iterator, Sequence
8
9import yaml
10
11import torchgen.dest as dest
12from torchgen.api.lazy import setValueT
13from torchgen.api.types import BaseCppType
14from torchgen.dest.lazy_ir import GenLazyIR, GenLazyNativeFuncDefinition, GenTSLazyIR
15from torchgen.gen import get_grouped_native_functions, parse_native_yaml
16from torchgen.gen_backend_stubs import (
17    error_on_missing_kernels,
18    gen_dispatcher_registrations,
19    gen_dispatchkey_nativefunc_headers,
20    parse_backend_yaml,
21)
22from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
23from torchgen.selective_build.selector import SelectiveBuilder
24from torchgen.utils import FileManager, NamespaceHelper
25from torchgen.yaml_utils import YamlLoader
26
27
28# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
29#
30#                        Lazy Tensor Codegen
31#
32# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
33# Overview
34# ~~~~~~~~
35#
36# This codegen script builds on existing data models and helpers used
37# by all ATen backends, and adds new functionality specific to lazy
38# tensor backends.
39#
40# Inputs:
41# - <backend>_native_functions.yaml: controls which operators are
42#   supported by the backend.
43#
44# Outputs:
45# (for all backends)
46# <DispatchKey>Ir.h defines Lazy IR classes to be constructed during tracing
47# - opt-in: also generate 'lowering' methods for the TorchScript backend only
48# <DispatchKey>NativeFunctions.cpp defines implementations of native functions which perform lazy tracing
49# - opt-in: 'full_codegen' section of backend yaml; 'supported' section omits these implementations
50# <DispatchKey>NativeFunctions.h declares implementations of native functions for both 'supported' and 'full_codegen'
51# ops
52#
53# Register<DispatchKey>.cpp registers all op implementations with the dispatcher
54# RegisterAutograd<DispatchKey>.cpp registers all autograd implementations with the dispatcher
55#
56# Validation Helpers:
57# - Shape Inference: errs if any ops in backend yaml require shape inference not provided by meta kernels or
58#   implementations in torch/csrc/lazy/core/shape_inference.*
59# - native function impls: errs if any 'supported' ops do not have an implementation defined in the backend
60#   (non-codegen) implementation file
61#
62#
63# About the Data Model
64# ~~~~~~~~~~~~~~~~~~~~
65#
66# Modeled after ATen codegen, the first step is to parse yaml and build a data model for the operators
67# we care about.  In this case, the <backend>_native_functions yaml defines a subset of the core operators
68# (defined in more detail in the main native_functions.yaml), which will be supported by your backend.
69# Backends can list ops in two categories:
70#  - `supported` ops require hand-implementations but still get codegenned declarations and registrations
71#  - `full_codegen` ops get implementations (and IR classes) generated too
72#
73# Each native function is modeled as an object with a schema, and each schema has objects representing their
74# arguments.  Much of the codegen is manipulation of the arguments and their types.  For example, lazy tensor
75# backends need to transform 'at::Tensor' arguments into 'lazy::Value' objects, as well as replacing reference
76# types (stringref) with actual string objects, and this is done by manipulating the data model objects.
77# - see api/lazy.py for the lazy data model
78#
79# Once the data model is set up, the rest of this script processes a number of templates for output CPP file
80# and fills in the template values using helpers in `dest/lazy_ir.py` and `dest/lazy_ts_lowering.py`.  These
81# helpers mostly iterate over functions and their arguments, outputting different c++ snippets.
82#
83# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
84
85
86# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
87# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping, full_codegen)
88ParsedExternalYaml = namedtuple(
89    "ParsedExternalYaml",
90    ["backend_key", "autograd_key", "cpp_namespace", "backend_indices", "full_codegen"],
91)
92
93
94def parse_native_functions_keys(
95    backend_yaml_path: str,
96    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
97) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
98    with open(backend_yaml_path) as f:
99        yaml_values = yaml.load(f, Loader=YamlLoader)
100    assert isinstance(yaml_values, dict)
101
102    full_codegen = yaml_values.pop("full_codegen", [])
103    non_native = yaml_values.pop("non_native", [])
104    ir_gen = yaml_values.pop("ir_gen", [])
105    assert isinstance(full_codegen, list)
106    assert isinstance(non_native, list)
107    assert isinstance(ir_gen, list)
108    full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
109    ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
110    return full_codegen_opnames, non_native, ir_gen_opnames
111
112
113def validate_shape_inference_header(
114    shape_inference_hdr: str, expected_shape_infr_decls: list[str]
115) -> None:
116    try:
117        with open(shape_inference_hdr) as f:
118            shape_infr_decls = f.read()
119            shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
120    except OSError as e:
121        raise AssertionError(
122            f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
123        ) from e
124
125    # TODO(whc) add a check for shape inference functions that have meta kernels implement and should be retired.
126
127    missing_decls = [
128        decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
129    ]
130    if missing_decls:
131        raise Exception(  # noqa: TRY002
132            f"""Missing shape inference function.\n
133Please add declare this function in {shape_inference_hdr}:\n
134and implement it in the corresponding shape_inference.cpp file.\n
135{os.linesep.join(missing_decls)}"""
136        )
137
138
139# Some helper functions for the codegen.
140def get_ltc_helper_fns() -> str:
141    return """\
142at::Tensor to_meta(const at::Tensor& tensor) {
143  // undefined tensors can't be converted to the meta device, since they don't have sizes/strides
144  if (!tensor.defined()) return tensor;
145  auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
146/*dtype=*/std::make_optional(tensor.scalar_type()), /*layout=*/std::make_optional(tensor.layout()), \
147/*device=*/std::make_optional(c10::Device(c10::kMeta)), /*pin_memory=*/std::nullopt);
148  // needs to handle wrapped numbers, so dtype promotion works properly.
149  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
150    out.unsafeGetTensorImpl()->set_wrapped_number(true);
151  }
152  return out;
153}
154std::optional<at::Tensor> to_meta(const std::optional<at::Tensor>& tensor) {
155  if (tensor.has_value()) {
156    return to_meta(*tensor);
157  }
158  return std::nullopt;
159}
160
161std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
162  std::vector<at::Tensor> outs;
163  outs.reserve(t_list.size());
164  for (const auto& tensor : t_list) {
165    outs.push_back(to_meta(tensor));
166  }
167  return outs;
168}
169"""
170
171
172class default_args:
173    node_base: str = "Node"
174    node_base_hdr: str | None = None
175    shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
176    tensor_class: str = "torch::lazy::LazyTensor"
177    tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
178    lazy_ir_generator: type[GenLazyIR] = GenLazyIR
179    native_func_definition_generator: type[
180        GenLazyNativeFuncDefinition
181    ] = GenLazyNativeFuncDefinition
182    backend_name: str = "TorchScript"
183
184
185def main() -> None:
186    parser = argparse.ArgumentParser(description="Generate Lazy Tensor backend files")
187    parser.add_argument(
188        "-s",
189        "--source-yaml",
190        "--source_yaml",
191        help="path to source yaml file containing operator external definitions",
192    )
193    parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
194    parser.add_argument(
195        "--dry-run", "--dry_run", type=bool, default=False, help="output directory"
196    )
197    parser.add_argument(
198        "--impl-path",
199        "--impl_path",
200        type=str,
201        default=None,
202        help="path to the source C++ file containing kernel definitions",
203    )
204    parser.add_argument(
205        "--gen-ts-lowerings",
206        "--gen_ts_lowerings",
207        action="store_true",
208        help="Generate TorchScript lowerings in addition to Lazy IR and NativeFunctions",
209    )
210    parser.add_argument(
211        "--node-base",
212        "--node_base",
213        type=str,
214        default=default_args.node_base,
215        help="Name of backend specific custom Lazy IR Node base class",
216    )
217    parser.add_argument(
218        "--node-base-hdr",
219        "--node_base_hdr",
220        type=str,
221        default=default_args.node_base_hdr,
222        help="Path to header file defining custom Lazy IR Node base class",
223    )
224    parser.add_argument(
225        "--shape-inference-hdr",
226        "--shape_inference_hdr",
227        type=str,
228        default=default_args.shape_inference_hdr,
229        help="Path to header file defining custom Lazy shape inference functions",
230    )
231    parser.add_argument(
232        "--tensor-class",
233        "--tensor_class",
234        type=str,
235        default=default_args.tensor_class,
236        help="Name of backend specific custom Lazy Tensor class",
237    )
238    parser.add_argument(
239        "--tensor-class-hdr",
240        "--tensor_class_hdr",
241        type=str,
242        default=default_args.tensor_class_hdr,
243        help="Path to header file defining custom Lazy Tensor class",
244    )
245    parser.add_argument(
246        "--backend-name",
247        "--backend_name",
248        type=str,
249        default=default_args.backend_name,
250        help="Name of the backend to generate",
251    )
252    options = parser.parse_args()
253
254    # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
255    torch_root = Path(__file__).parent.parent.parent.absolute()
256    aten_path = str(torch_root / "aten" / "src" / "ATen")
257    lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
258    if options.gen_ts_lowerings:
259        lazy_ir_generator = GenTSLazyIR
260    native_func_definition_generator: type[
261        GenLazyNativeFuncDefinition
262    ] = default_args.native_func_definition_generator
263
264    run_gen_lazy_tensor(
265        aten_path,
266        options.source_yaml,
267        options.output_dir,
268        options.dry_run,
269        options.impl_path,
270        options.node_base,
271        options.node_base_hdr,
272        options.tensor_class,
273        options.tensor_class_hdr,
274        options.shape_inference_hdr,
275        lazy_ir_generator,
276        native_func_definition_generator,
277        options.backend_name,
278    )
279
280
281def run_gen_lazy_tensor(
282    aten_path: str,
283    source_yaml: str,
284    output_dir: str,
285    dry_run: bool,
286    impl_path: str | None,
287    node_base: str = default_args.node_base,
288    node_base_hdr: str | None = default_args.node_base_hdr,
289    tensor_class: str = default_args.tensor_class,
290    tensor_class_hdr: str = default_args.tensor_class_hdr,
291    shape_inference_hdr: str = default_args.shape_inference_hdr,
292    lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator,
293    native_func_definition_generator: type[
294        GenLazyNativeFuncDefinition
295    ] = default_args.native_func_definition_generator,
296    # build_in_tree is true for TS backend and affects include paths
297    build_in_tree: bool = False,
298    # per_operator_headers changes whether ATen/Functions.h or individual operator headers are used
299    # it must match how ATen was built
300    per_operator_headers: bool = False,
301    backend_name: str = default_args.backend_name,
302    gen_forced_fallback_code: bool = False,
303    use_lazy_shape: bool = True,
304    # the following arguments are temporary customization points for xla backend migration.
305    # do not rely on them otherwise, they should be removed once migration is complete
306    backend_namespace: str = "torch::lazy",
307    get_tensorlist: str = "GetTensorList",
308    get_tensor_or_wrap_number: str = "GetLtcTensorOrCreateForWrappedNumber",
309    try_get_tensor: str = "TryGetLtcTensor",
310    metrics_counter: str = 'TORCH_LAZY_FN_COUNTER("lazy::")',
311    create_tensor: str = "LazyTensor::Create",
312    create_from_first_tensor: bool = False,
313    create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor",
314    tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors",
315    lazy_value_class: str = "torch::lazy::Value",
316    lazy_tensor_ptr: str = "LazyTensorPtr",
317    get_device_fn: str = "torch::lazy::GetBackendDevice",
318) -> None:
319    lv_tokens = lazy_value_class.split("::")
320    lv_class = lv_tokens[-1]
321    lv_ns = "::".join(lv_tokens[:-1])
322    setValueT(BaseCppType(lv_ns, lv_class))
323    template_dir = os.path.join(aten_path, "templates")
324
325    def make_file_manager(install_dir: str) -> FileManager:
326        return FileManager(
327            install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
328        )
329
330    fm = make_file_manager(output_dir)
331
332    native_yaml_path = os.path.join(aten_path, "native/native_functions.yaml")
333    tags_yaml_path = os.path.join(aten_path, "native/tags.yaml")
334    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
335    native_functions, backend_indices = (
336        parsed_yaml.native_functions,
337        parsed_yaml.backend_indices,
338    )
339    grouped_native_functions = get_grouped_native_functions(native_functions)
340
341    def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str:
342        """
343        We sort the native function because of the note in concat_map_codegen.
344        TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
345        """
346        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
347        return str(func.name.name)
348
349    grouped_native_functions = sorted(
350        grouped_native_functions, key=sort_native_function
351    )
352
353    parsed_backend_yaml = parse_backend_yaml(
354        source_yaml, grouped_native_functions, backend_indices
355    )
356    backend_key = parsed_backend_yaml.backend_key
357    autograd_key = parsed_backend_yaml.autograd_key
358    cpp_namespace = parsed_backend_yaml.cpp_namespace
359    backend_indices = parsed_backend_yaml.backend_indices
360    # the following 3 keys are all processed differently
361    # for full_codegen, we generate IR, kernels, etc
362    # for ir_gen, we generate only IR
363    # non_native is used to register kernels not declared in
364    # native_functions.yaml
365    full_codegen, non_native, ir_gen = parse_native_functions_keys(
366        source_yaml, grouped_native_functions
367    )
368
369    def concat_map_codegen(
370        func: Callable[[NativeFunction], Sequence[str]],
371        xs: Iterable[NativeFunctionsGroup | NativeFunction],
372        ops_list: list[OperatorName] = full_codegen,
373    ) -> Iterator[str]:
374        """
375        We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
376        only code-gen additional entries for the inplace variant for the native functions.
377        """
378
379        for x in xs:
380            fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
381            for f in fs:
382                if f.func.name in ops_list:
383                    yield from func(f)
384
385    selector = SelectiveBuilder.get_nop_selector()
386
387    assert backend_key is not None
388    class_name = backend_indices[backend_key].native_function_class_name()
389
390    if impl_path is not None:
391        error_on_missing_kernels(
392            native_functions,
393            backend_indices,
394            backend_key,
395            autograd_key,
396            class_name,
397            impl_path,
398            full_codegen,
399        )
400
401    """ Validate Shape Inference Definitions
402
403    Generated lazy native functions all perform shape inference, by first using a meta:: kernel
404    if available for that op, and otherwise using a 'compute_shape_{op}' function instead.  The generator
405    knows the call signature for compute_shape_{op} because it matches the nativefunction (and meta::) signature,
406    so it just has to check whether the op is structured and generate a call for one or the other.  It's up to the dev
407    to supply the missing compute_shape_{op} function, but the codegen at least warns you about this and provides
408    the expected signature which can be copy-pasted into shape_inference.h.
409
410    compute_shape_{op} functions are handwritten and should be replaced over time as ops get ported
411    to structured kernels.
412
413    See torch/csrc/lazy/core/shape_inference.cpp #READ THIS! for more information.
414    """
415    if shape_inference_hdr is not None:
416        expected_shape_infr_decls = list(
417            concat_map_codegen(
418                dest.GenLazyShapeInferenceDefinition(
419                    backend_indices[backend_key], tensor_class
420                ),
421                grouped_native_functions,
422            )
423        )
424
425        validate_shape_inference_header(shape_inference_hdr, expected_shape_infr_decls)
426    assert class_name is not None
427
428    # Generate nativefunction declarations
429    # Note, eager registrations is set to False for the lazy TS backend as another LTC backend
430    # may want to register their own lazy kernels instead of registering the TS ones.
431    # The registration will lazily happen when init_ts_backend is called.
432    gen_dispatchkey_nativefunc_headers(
433        fm,
434        class_name,
435        cpp_namespace,
436        backend_indices,
437        grouped_native_functions,
438        backend_key,
439        autograd_key,
440        backend_name,
441    )
442
443    # Generate Dispatcher registrations which hook up the nativefunctions
444    for dispatch_key in (
445        [backend_key] if autograd_key is None else [backend_key, autograd_key]
446    ):
447        gen_dispatcher_registrations(
448            fm,
449            output_dir,
450            class_name,
451            backend_indices,
452            grouped_native_functions,
453            backend_key,
454            dispatch_key,
455            selector,
456            build_in_tree=build_in_tree,
457            per_operator_headers=per_operator_headers,
458            backend_name=backend_name,
459            eager_registration=False,
460        )
461
462    # Generate native function impls that build IR nodes
463    ns_helper = NamespaceHelper(cpp_namespace)
464    fm.write_with_template(
465        f"{backend_key}NativeFunctions.cpp",
466        "DispatchKeyNativeFunctions.cpp",
467        lambda: {
468            "includes": [
469                f"#include <{path}>"
470                for path in [
471                    tensor_class_hdr,
472                    shape_inference_hdr,
473                    "ATen/Functions.h",
474                    "ATen/native/TensorConversions.h",
475                    "ATen/NativeFunctions.h",
476                    "ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
477                    "ATen/MetaFunctions.h",
478                    "ATen/Operators.h",
479                    "ATen/native/CPUFallback.h",
480                    "torch/csrc/lazy/core/ir_builder.h",
481                    "torch/csrc/lazy/core/lazy_graph_executor.h",
482                    "torch/csrc/lazy/core/metrics.h",
483                    "torch/csrc/lazy/core/shape.h",
484                    f"{output_dir}/{backend_key}NativeFunctions.h",
485                    f"{output_dir}/LazyIr.h",
486                ]
487                + (
488                    ["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"]
489                    if gen_forced_fallback_code
490                    else []
491                )
492            ],
493            "helper_fns": get_ltc_helper_fns(),
494            "native_functions_include": "",
495            "namespace_prologue": ns_helper.prologue,
496            "namespace_epilogue": ns_helper.epilogue,
497            "native_function_definitions": list(
498                concat_map_codegen(
499                    native_func_definition_generator(
500                        f"{backend_key}NativeFunctions",
501                        backend_indices[backend_key],
502                        tensor_class,
503                        gen_forced_fallback_code,
504                        backend_namespace,
505                        get_tensorlist,
506                        get_tensor_or_wrap_number,
507                        try_get_tensor,
508                        metrics_counter,
509                        create_tensor,
510                        create_from_first_tensor,
511                        create_aten_from_ltc_tensor,
512                        tuple_aten_from_ltc_tensors,
513                        lazy_tensor_ptr,
514                        get_device_fn,
515                    ),
516                    grouped_native_functions,
517                )
518            ),
519        },
520    )
521    # Generate IR node classes
522    lazy_ir_obj = lazy_ir_generator(
523        backend_indices[backend_key], backend_name, node_base, use_lazy_shape
524    )
525
526    fm.write_with_template(
527        "LazyIr.h",
528        "LazyIr.h",
529        lambda: {
530            "lazy_ir_sysinc": [
531                f"#include <{path}>"
532                for path in [
533                    "ATen/core/Formatting.h",
534                    "c10/core/ScalarType.h",
535                    "torch/csrc/lazy/core/hash.h",
536                    "torch/csrc/lazy/core/ir.h",
537                    "torch/csrc/lazy/core/shape.h",
538                    "optional",
539                    "vector",
540                ]
541            ],
542            "lazy_ir_inc": [f'#include "{node_base_hdr}"']
543            if node_base_hdr is not None
544            else [],
545            "ir_declarations": list(
546                concat_map_codegen(
547                    lazy_ir_obj, grouped_native_functions, full_codegen + ir_gen
548                )
549            ),
550            "namespace_prologue": ns_helper.prologue,
551            "namespace_epilogue": ns_helper.epilogue,
552        },
553    )
554
555    # Generate Non Native IR Node classes
556    fm.write_with_template(
557        "LazyNonNativeIr.h",
558        "LazyNonNativeIr.h",
559        lambda: {
560            "lazy_non_native_ir_inc": [
561                f"#include <{path}>"
562                for path in [
563                    "torch/csrc/lazy/core/ir.h",
564                    "torch/csrc/lazy/core/ir_builder.h",
565                    "torch/csrc/lazy/core/internal_ops/ltc_ops.h",
566                    "torch/csrc/lazy/core/shape_inference.h",
567                ]
568                + ([node_base_hdr] if node_base_hdr else [])
569                if path
570            ],
571            "non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes(
572                non_native, lazy_ir_obj
573            ),
574            "namespace_prologue": ns_helper.prologue,
575            "namespace_epilogue": ns_helper.epilogue,
576        },
577    )
578
579
580if __name__ == "__main__":
581    main()
582