# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

# Prefer the upstream PyTorch implementation of the `executorch_call_delegate`
# higher-order operator and its helpers. If the installed PyTorch does not
# provide that module (ImportError), fall back to the local copy defined in the
# `except` branch, which defines the same three public names (plus the dispatch
# kernels they rely on). All registration below happens at import time, and
# only on the fallback path.
try:  # noqa: C901
    from torch._higher_order_ops.executorch_call_delegate import (
        executorch_call_delegate as executorch_call_delegate,
        get_lowered_module_name as get_lowered_module_name,
        is_lowered_module as is_lowered_module,
    )

except ImportError:

    # TODO: Delete this code once pytorch pin advances

    from typing import Any, cast

    import torch
    import torch.utils._pytree as pytree
    from torch._ops import HigherOrderOperator
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.proxy_tensor import (
        disable_proxy_modes_tracing,
        get_proxy_slot,
        ProxyTorchDispatchMode,
        track_tensor_tree,
    )
    from torch.utils._pytree import tree_flatten

    # The operator object itself. The dispatch keys below are marked as
    # fallthroughs for this operator; kernels for other dispatch keys and modes
    # are registered with the `py_impl` / `py_functionalize_impl` decorators
    # further down.
    executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
    executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)

    # Class name matched by `is_lowered_module`. The name is compared as a
    # string so that LoweredBackendModule never has to be imported here.
    LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"

    # pyre-ignore
    def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
        """
        Record a call to ``func_overload`` on ``lowered_module`` as a single
        ``call_function`` node in the graph traced by ``proxy_mode``.

        Args:
            proxy_mode: the active ProxyTorchDispatchMode; its ``tracer`` builds
                the graph and owns the ``root`` module.
            func_overload: the callable stored as the new node's target (the
                caller in this file passes ``executorch_call_delegate``).
            lowered_module: must pass ``is_lowered_module``; it is registered
                as a submodule of ``proxy_mode.tracer.root``.
            *args: operands forwarded to the delegate.

        Returns:
            The outputs of ``call_delegate_cpu(lowered_module, *args)``, tracked
            against the new node's proxy via ``track_tensor_tree``.

        Raises:
            ValueError: if ``lowered_module`` is not a LoweredBackendModule (as
                judged by its class name).
        """

        # Map tensors / SymInts / SymFloats to the proxies the tracer recorded
        # for them; every other value is returned unchanged. The `cast` only
        # satisfies the type checker and does nothing at runtime.
        # pyre-ignore
        def _unwrap_proxy(e):
            if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
                return e
            return get_proxy_slot(
                cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy
            )

        if not is_lowered_module(lowered_module):
            raise ValueError(
                "executorch_call_delegate()'s first argument must be a LoweredBackendModule"
            )

        # Run the delegate with proxy-mode tracing disabled, so only its
        # outputs are produced here rather than its internal ops being recorded.
        with disable_proxy_modes_tracing():
            out = call_delegate_cpu(lowered_module, *args)

        # Called for its side effect: attaches `lowered_module` to the tracer's
        # root module under a fresh `lowered_module_<i>` name. The returned
        # name is not used here.
        get_lowered_module_name(proxy_mode.tracer.root, lowered_module)

        node_args = (lowered_module, *args)
        proxy_args = pytree.tree_map(_unwrap_proxy, node_args)
        out_proxy = proxy_mode.tracer.create_proxy(
            "call_function",
            func_overload,
            proxy_args,
            {},
            name="executorch_call_delegate",
        )
        return track_tensor_tree(
            out, out_proxy, constant=None, tracer=proxy_mode.tracer
        )

    # Kernel registered for the CompositeExplicitAutograd dispatch key. The
    # proxy-tracing helper and the FakeTensorMode kernel also call it directly.
    @executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
    # pyre-ignore
    def call_delegate_cpu(lowered_module, *args):
        """
        Call ``lowered_module.original_module.module()`` with ``args``.

        Before the call, FX ``immutable_dict`` / ``immutable_list`` containers
        found while walking ``args`` as a pytree are converted to plain
        ``dict`` / ``list``. Only the outermost matching container is
        converted. Matching containers nested inside it are left as-is, because
        the container is treated as a leaf. Returns whatever the module call
        returns.
        """
        # FX creates this immutable_dict/list concept. Get rid of this.
        map_types = {
            torch.fx.immutable_collections.immutable_dict: dict,
            torch.fx.immutable_collections.immutable_list: list,
        }
        # The trailing predicate is passed positionally as pytree's `is_leaf`.
        # It stops flattening at the matching containers, so the lambda
        # receives each container whole.
        # NOTE(review): `map_types[type(a)]` looks up the exact type, so a
        # subclass of immutable_dict/immutable_list would raise KeyError here.
        new_args = pytree.tree_map_only(
            tuple(map_types.keys()),
            lambda a: map_types[type(a)](a),
            args,
            lambda a: isinstance(a, tuple(map_types.keys())),
        )
        return lowered_module.original_module.module()(*new_args)

    # Kernel registered for the Autograd dispatch key. Real autograd support is
    # not implemented (see TODO). Outputs are never connected to the inputs'
    # autograd graph.
    @executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd)
    # pyre-ignore
    def call_delegate_autograd(lowered_module, *args):
        """
        Re-dispatch the call with AutogradCPU excluded and post-process the
        result.

        Returns:
            The re-dispatched result. If any tensor operand requires grad, each
            output tensor is replaced by ``tensor.detach()``. Floating-point
            and complex outputs are additionally marked ``requires_grad=True``.
        """
        # TODO: support autograd
        # Only torch.Tensor leaves are considered; `lowered_module` itself is
        # filtered out by the isinstance check.
        flat_operands, _ = tree_flatten([lowered_module, *args])
        requires_grad = any(
            f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
        )

        # Exclude AutogradCPU so the nested call dispatches past this kernel
        # instead of re-entering it.
        # NOTE(review): only AutogradCPU is excluded; confirm how operands on
        # other devices (other Autograd* keys) are handled.
        with torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
        ):
            res = executorch_call_delegate(lowered_module, *args)

            if requires_grad:
                # Return detached aliases of the outputs, marked
                # requires_grad=True where the dtype allows it, so callers see
                # outputs that require grad.
                # NOTE(review): the detached outputs are autograd leaves with no
                # grad_fn, so no gradient reaches the inputs. The comment this
                # replaces mentioned an `err_fn` that does not exist in this
                # file; it appears to have been copied from another operator.

                # pyre-ignore
                def fake_requires_grad(var):
                    if var is not None:
                        var = var.detach()
                        if torch.is_floating_point(var) or torch.is_complex(var):
                            var.requires_grad = True
                    return var

                return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res)

            return res

    # Kernel for ProxyTorchDispatchMode (graph tracing). It records the call as
    # a single graph node via `trace_call_delegate`.
    @executorch_call_delegate.py_impl(ProxyTorchDispatchMode)
    # pyre-ignore
    def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args):
        res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args)
        return res

    # Kernel for FakeTensorMode. It runs `call_delegate_cpu` with `mode` entered,
    # so the original module executes under the fake-tensor mode.
    @executorch_call_delegate.py_impl(FakeTensorMode)
    # pyre-ignore
    def call_delegate_fake_tensor_mode(mode, lowered_module, *args):
        with mode:
            return call_delegate_cpu(lowered_module, *args)

    # Functionalization kernel. It unwraps the operands, redispatches the
    # operator to the next key, and wraps the results back up via `ctx`.
    @executorch_call_delegate.py_functionalize_impl
    # pyre-ignore
    def call_delegate_functionalize(ctx, lowered_module, *args):
        unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
        with ctx.redispatch_to_next():
            res = executorch_call_delegate(lowered_module, *unwrapped_args)
            return ctx.wrap_tensors(res)

    # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre
    def is_lowered_module(obj: Any) -> bool:
        """
        Return True if ``obj``'s class is named ``LoweredBackendModule``.

        The class name is compared instead of using
        ``isinstance(obj, LoweredBackendModule)``, because importing
        LoweredBackendModule here may cause a circular import. As a
        consequence, any class with exactly that name matches, and subclasses
        with a different name do not.
        """
        return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE

    def get_lowered_module_name(
        root: torch.nn.Module,
        # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
        lowered_module: LOWERED_BACKEND_MODULE_TYPE,  # noqa
    ) -> str:
        """
        Register ``lowered_module`` as a submodule of ``root`` and return the
        attribute name it was added under.

        The name chosen is the first ``lowered_module_<i>`` (i = 0, 1, 2, ...)
        that ``root`` does not already have as an attribute. No de-duplication
        is done: registering the same module twice adds it under two different
        names.
        """
        # Find a qualifying name for the lowered submodule
        qualname = None
        i = 0
        while True:
            qualname = f"lowered_module_{i}"
            if not hasattr(root, qualname):
                break
            i += 1
        # Always true after the loop. Presumably kept to narrow the
        # Optional type for the type checker.
        assert qualname is not None

        root.add_module(qualname, lowered_module)
        return qualname