# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""
This module contains APIs to manipulate ops.
"""
import functools
from dataclasses import dataclass
from typing import Callable, Dict, TypeVar

import torch
from executorch.exir.tensor import TensorSpec


@dataclass
class ScratchTensorMetadata:
    """
    Metadata describing one scratch tensor, as returned (keyed by name) from a
    get_scratch_metas function registered via attach_get_scratch_metas_fn.
    """

    dtype: torch.dtype
    shape: torch.Size
    layout: torch.layout = torch.strided
    # NOTE(review): `device` is not forwarded when this metadata is converted
    # to a TensorSpec inside attach_get_scratch_metas_fn; confirm that this
    # is intended.
    device: torch.device = torch.device("cpu")
    is_sparse: bool = False


# Signature of a user-provided get_scratch_metas function: it accepts the same
# arguments as the out-variant op (Tensor arguments replaced with TensorSpec)
# and returns a mapping from scratch-tensor name to ScratchTensorMetadata.
ScratchCallableType = Callable[..., Dict[str, ScratchTensorMetadata]]

# Generic callable type so that registration decorators hand back the
# decorated function with its original static type.
_FnT = TypeVar("_FnT", bound=Callable[..., object])


def attach_get_scratch_metas_fn(
    out_variant: torch._ops.OpOverload,
) -> Callable[[ScratchCallableType], ScratchCallableType]:
    """
    Apply this decorator to the get_scratch_metas method for the `out_variant`
    op. The decorator attaches an adapted version of the method to the
    out-variant op as `out_variant.get_scratch_metas` and returns the
    decorated function itself unchanged.

    The get_scratch_metas method has the same signature as the out-variant op.
    There are two differences though:
      - the Tensor input arguments are all replaced with TensorSpec
      - the output is a dictionary of ScratchTensorMetadata

    The attached `out_variant.get_scratch_metas` calls the decorated function
    and converts every ScratchTensorMetadata value in the returned dictionary
    into a TensorSpec (keys are kept as-is).

    Args:
        out_variant: the out-variant OpOverload to attach the method to.

    Returns:
        A decorator that registers its argument on `out_variant` and returns
        that argument unchanged.
    """

    def to_tensor_spec(meta: ScratchTensorMetadata) -> TensorSpec:
        # Scratch tensors are created as non-constant and not requiring grad;
        # the remaining fields are copied from the metadata. `meta.device` is
        # not passed to TensorSpec.
        return TensorSpec(
            const=False,
            requires_grad=False,
            dtype=meta.dtype,
            shape=meta.shape,
            layout=meta.layout,
            is_sparse=meta.is_sparse,
        )

    def adapt_return_value(
        get_scratch_metas_fn: ScratchCallableType,
    ) -> Callable[..., Dict[str, TensorSpec]]:
        """
        Wrap `get_scratch_metas_fn` so that each ScratchTensorMetadata value in
        its returned dictionary is converted to a TensorSpec.
        """

        # functools.wraps keeps the user function's name, qualname and
        # docstring on the attached callable so it stays identifiable when
        # inspecting `out_variant.get_scratch_metas`.
        @functools.wraps(get_scratch_metas_fn)
        def adapted(*args: object, **kwargs: object) -> Dict[str, TensorSpec]:
            # Arguments are forwarded untouched: Tensor args are already
            # TensorSpec per the contract above; scalar args stay as they are.
            meta_dict = get_scratch_metas_fn(*args, **kwargs)
            return {name: to_tensor_spec(meta) for name, meta in meta_dict.items()}

        return adapted

    def register(get_scratch_metas_fn: ScratchCallableType) -> ScratchCallableType:
        # pyre-fixme[16]: `OpOverload` has no attribute `get_scratch_metas`.
        out_variant.get_scratch_metas = adapt_return_value(get_scratch_metas_fn)
        # Return the original (unadapted) function so the decorated name in
        # the caller's module keeps returning ScratchTensorMetadata.
        return get_scratch_metas_fn

    return register


def attach_calculate_upper_bound_shape_fn(
    func_op: torch._ops.OpOverload,
) -> Callable[[_FnT], _FnT]:
    """
    Apply this decorator to a calculate_upper_bound_shape function to attach it
    to `func_op` as `func_op.calculate_upper_bound_shape`.

    Args:
        func_op: the OpOverload for the functional op.

    Returns:
        A decorator that registers its argument on `func_op` and returns that
        argument unchanged.
    """

    def register(calculate_upper_bound_shape_fn: _FnT) -> _FnT:
        # pyre-fixme[16]: `OpOverload` has no attribute `calculate_upper_bound_shape`.
        func_op.calculate_upper_bound_shape = calculate_upper_bound_shape_fn
        return calculate_upper_bound_shape_fn

    return register