1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict 8*523fa7a6SAndroid Build Coastguard Worker# pyre-ignore-all-errors[6] 9*523fa7a6SAndroid Build Coastguard Worker# pyre-ignore-all-errors[16] 10*523fa7a6SAndroid Build Coastguard Workerfrom __future__ import annotations 11*523fa7a6SAndroid Build Coastguard Worker 12*523fa7a6SAndroid Build Coastguard Workerimport copy 13*523fa7a6SAndroid Build Coastguard Worker 14*523fa7a6SAndroid Build Coastguard Workerimport math 15*523fa7a6SAndroid Build Coastguard Workerimport typing 16*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, List, Optional, Tuple, Union 17*523fa7a6SAndroid Build Coastguard Worker 18*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.schema as schema 19*523fa7a6SAndroid Build Coastguard Workerimport torch 20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.error import internal_assert 21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.schema import ScalarType, TensorShapeDynamism 22*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.sym_util import eval_shape 23*523fa7a6SAndroid Build Coastguard Worker 24*523fa7a6SAndroid Build Coastguard Worker 25*523fa7a6SAndroid Build Coastguard Workerclass AddressSpaceOverflowException(Exception): 26*523fa7a6SAndroid Build Coastguard Worker pass 27*523fa7a6SAndroid Build Coastguard Worker 28*523fa7a6SAndroid Build Coastguard Worker 29*523fa7a6SAndroid Build Coastguard Workerdef num_bytes_from_shape_and_dtype(shape: torch.Size, dtype: torch.dtype) 
-> int: 30*523fa7a6SAndroid Build Coastguard Worker """ 31*523fa7a6SAndroid Build Coastguard Worker Assume the tensor is a contiguous one. 32*523fa7a6SAndroid Build Coastguard Worker """ 33*523fa7a6SAndroid Build Coastguard Worker 34*523fa7a6SAndroid Build Coastguard Worker return math.prod(shape) * torch._utils._element_size(dtype) 35*523fa7a6SAndroid Build Coastguard Worker 36*523fa7a6SAndroid Build Coastguard Worker 37*523fa7a6SAndroid Build Coastguard Workerdef contiguous_stride_from_shape(shape: torch.Size) -> Tuple[int]: 38*523fa7a6SAndroid Build Coastguard Worker strides = [] 39*523fa7a6SAndroid Build Coastguard Worker accum = 1 40*523fa7a6SAndroid Build Coastguard Worker for sz in reversed(shape): 41*523fa7a6SAndroid Build Coastguard Worker strides.append(accum) 42*523fa7a6SAndroid Build Coastguard Worker # For sizes[i] == 0, treat it as 1 to be consistent with core Pytorch 43*523fa7a6SAndroid Build Coastguard Worker # This preserves the PT equivalent behavior for dims with 0 elements 44*523fa7a6SAndroid Build Coastguard Worker if isinstance(sz, int): 45*523fa7a6SAndroid Build Coastguard Worker if sz != 0: 46*523fa7a6SAndroid Build Coastguard Worker accum *= sz 47*523fa7a6SAndroid Build Coastguard Worker else: 48*523fa7a6SAndroid Build Coastguard Worker # Unbacked symints may error on the != 0 check 49*523fa7a6SAndroid Build Coastguard Worker accum *= sz 50*523fa7a6SAndroid Build Coastguard Worker return tuple(reversed(strides)) 51*523fa7a6SAndroid Build Coastguard Worker 52*523fa7a6SAndroid Build Coastguard Worker 53*523fa7a6SAndroid Build Coastguard Workerdef dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]: 54*523fa7a6SAndroid Build Coastguard Worker """ 55*523fa7a6SAndroid Build Coastguard Worker Dimension order represents how dimensions are laid out in memory, 56*523fa7a6SAndroid Build Coastguard Worker starting from the outer-most to the inner-most dimension. 
57*523fa7a6SAndroid Build Coastguard Worker Thus, the conversion from strides is done by sorting the strides 58*523fa7a6SAndroid Build Coastguard Worker from larger to smaller since the dimension with the largest stride 59*523fa7a6SAndroid Build Coastguard Worker is the outer-most and the dimension with the smallest stride is the inner-most. 60*523fa7a6SAndroid Build Coastguard Worker For example, tensor with sizes = (3, 5, 2) and strides = (5, 1, 15), implies 61*523fa7a6SAndroid Build Coastguard Worker dimension order of (2, 0, 1). Dimension order of (2, 0, 1) can be obtained 62*523fa7a6SAndroid Build Coastguard Worker by sorting strides from large to smaller. 63*523fa7a6SAndroid Build Coastguard Worker 64*523fa7a6SAndroid Build Coastguard Worker When strides do not convey dimension order unambiguously, dimension order 65*523fa7a6SAndroid Build Coastguard Worker returned is dependent on stability of sort. In python same key elements are kept 66*523fa7a6SAndroid Build Coastguard Worker in original order. 
Thus when strides = (4, 3, 1, 1) returned value is (0, 1, 2, 3) 67*523fa7a6SAndroid Build Coastguard Worker Another example is: sizes = (1, 3, 1, 1) with strides = (3, 1, 3, 3), returned 68*523fa7a6SAndroid Build Coastguard Worker value is (0, 2, 3, 1) 69*523fa7a6SAndroid Build Coastguard Worker """ 70*523fa7a6SAndroid Build Coastguard Worker for _, s in enumerate(stride): 71*523fa7a6SAndroid Build Coastguard Worker if s == 0: 72*523fa7a6SAndroid Build Coastguard Worker raise ValueError("0 in strides is not supported for ExecuTorch.") 73*523fa7a6SAndroid Build Coastguard Worker sorted_dims = [ 74*523fa7a6SAndroid Build Coastguard Worker i[0] for i in sorted(enumerate(stride), key=lambda x: x[1], reverse=True) 75*523fa7a6SAndroid Build Coastguard Worker ] 76*523fa7a6SAndroid Build Coastguard Worker return tuple(typing.cast(Tuple[bytes], sorted_dims)) 77*523fa7a6SAndroid Build Coastguard Worker 78*523fa7a6SAndroid Build Coastguard Worker 79*523fa7a6SAndroid Build Coastguard Workerdef stride_from_dim_order(sizes: List[int], dim_order: List[bytes]) -> List[int]: 80*523fa7a6SAndroid Build Coastguard Worker """ 81*523fa7a6SAndroid Build Coastguard Worker Converts dim order to stride using sizes 82*523fa7a6SAndroid Build Coastguard Worker e.g. 
if sizes = (2, 3, 4) and dim_order = (0, 1, 2) then strides = (12, 4, 1) 83*523fa7a6SAndroid Build Coastguard Worker while for the same size if dim_order = (0, 2, 1) then strides = (12, 1, 3) 84*523fa7a6SAndroid Build Coastguard Worker See executorch/runtime/core/exec_aten/util/dim_order_util.h for details 85*523fa7a6SAndroid Build Coastguard Worker Args: 86*523fa7a6SAndroid Build Coastguard Worker sizes (Tuple[int]): sizes of the tensor 87*523fa7a6SAndroid Build Coastguard Worker dim_order (Tuple[bytes]): dim order of the tensor 88*523fa7a6SAndroid Build Coastguard Worker Returns: 89*523fa7a6SAndroid Build Coastguard Worker Tuple[int]: stride 90*523fa7a6SAndroid Build Coastguard Worker """ 91*523fa7a6SAndroid Build Coastguard Worker if len(sizes) == 0: 92*523fa7a6SAndroid Build Coastguard Worker return [] 93*523fa7a6SAndroid Build Coastguard Worker strides = copy.deepcopy(sizes) 94*523fa7a6SAndroid Build Coastguard Worker ndim = len(sizes) 95*523fa7a6SAndroid Build Coastguard Worker strides[dim_order[ndim - 1]] = 1 96*523fa7a6SAndroid Build Coastguard Worker for i in range(ndim - 2, -1, -1): 97*523fa7a6SAndroid Build Coastguard Worker if sizes[dim_order[i + 1]] == 0: 98*523fa7a6SAndroid Build Coastguard Worker strides[dim_order[i]] = strides[dim_order[i + 1]] 99*523fa7a6SAndroid Build Coastguard Worker else: 100*523fa7a6SAndroid Build Coastguard Worker strides[dim_order[i]] = sizes[dim_order[i + 1]] * strides[dim_order[i + 1]] 101*523fa7a6SAndroid Build Coastguard Worker return strides 102*523fa7a6SAndroid Build Coastguard Worker 103*523fa7a6SAndroid Build Coastguard Worker 104*523fa7a6SAndroid Build Coastguard Workerdef calculate_aligned_num_bytes(num: int, alignment: int) -> int: 105*523fa7a6SAndroid Build Coastguard Worker return math.ceil(num / alignment) * alignment 106*523fa7a6SAndroid Build Coastguard Worker 107*523fa7a6SAndroid Build Coastguard Worker 108*523fa7a6SAndroid Build Coastguard Workerdef determine_tensor_dynanism(shape: torch.Size) -> 
TensorShapeDynamism: 109*523fa7a6SAndroid Build Coastguard Worker if all(isinstance(s, int) for s in shape): 110*523fa7a6SAndroid Build Coastguard Worker return TensorShapeDynamism.STATIC 111*523fa7a6SAndroid Build Coastguard Worker else: 112*523fa7a6SAndroid Build Coastguard Worker try: 113*523fa7a6SAndroid Build Coastguard Worker _ = eval_shape(shape) 114*523fa7a6SAndroid Build Coastguard Worker return TensorShapeDynamism.DYNAMIC_BOUND 115*523fa7a6SAndroid Build Coastguard Worker except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: 116*523fa7a6SAndroid Build Coastguard Worker return TensorShapeDynamism.DYNAMIC_UNBOUND 117*523fa7a6SAndroid Build Coastguard Worker 118*523fa7a6SAndroid Build Coastguard Worker 119*523fa7a6SAndroid Build Coastguard WorkerALIGNMENT = 16 120*523fa7a6SAndroid Build Coastguard Worker 121*523fa7a6SAndroid Build Coastguard Worker 122*523fa7a6SAndroid Build Coastguard Workerclass TensorSpec: 123*523fa7a6SAndroid Build Coastguard Worker """ 124*523fa7a6SAndroid Build Coastguard Worker Captures the metadata for a given Tensor (ex. scalar type, storage, etc.). 
125*523fa7a6SAndroid Build Coastguard Worker """ 126*523fa7a6SAndroid Build Coastguard Worker 127*523fa7a6SAndroid Build Coastguard Worker def __init__( 128*523fa7a6SAndroid Build Coastguard Worker self, 129*523fa7a6SAndroid Build Coastguard Worker dtype: torch.dtype, 130*523fa7a6SAndroid Build Coastguard Worker shape: torch.Size, 131*523fa7a6SAndroid Build Coastguard Worker layout: torch.layout = torch.strided, 132*523fa7a6SAndroid Build Coastguard Worker is_sparse: bool = False, 133*523fa7a6SAndroid Build Coastguard Worker const: bool = False, 134*523fa7a6SAndroid Build Coastguard Worker requires_grad: bool = False, 135*523fa7a6SAndroid Build Coastguard Worker ) -> None: 136*523fa7a6SAndroid Build Coastguard Worker self.scalar_type = dtype 137*523fa7a6SAndroid Build Coastguard Worker self.const = const 138*523fa7a6SAndroid Build Coastguard Worker self.alignment: int = ALIGNMENT 139*523fa7a6SAndroid Build Coastguard Worker self.storage: Optional[torch.UntypedStorage] = None 140*523fa7a6SAndroid Build Coastguard Worker # convert to list making it easier to handle type checking 141*523fa7a6SAndroid Build Coastguard Worker self.shape: List[int] = list(shape) 142*523fa7a6SAndroid Build Coastguard Worker self.stride: Tuple[int] = contiguous_stride_from_shape(shape) 143*523fa7a6SAndroid Build Coastguard Worker self.dim_order: Tuple[bytes] = dim_order_from_stride(self.stride) 144*523fa7a6SAndroid Build Coastguard Worker self.requires_grad = requires_grad 145*523fa7a6SAndroid Build Coastguard Worker self.layout = layout 146*523fa7a6SAndroid Build Coastguard Worker self.is_sparse = is_sparse 147*523fa7a6SAndroid Build Coastguard Worker self.init_mem_planning_fields() 148*523fa7a6SAndroid Build Coastguard Worker self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(self.shape) 149*523fa7a6SAndroid Build Coastguard Worker 150*523fa7a6SAndroid Build Coastguard Worker @property 151*523fa7a6SAndroid Build Coastguard Worker def allocated_memory(self) -> int: 
152*523fa7a6SAndroid Build Coastguard Worker nbytes = num_bytes_from_shape_and_dtype(self.shape, self.dtype) 153*523fa7a6SAndroid Build Coastguard Worker return calculate_aligned_num_bytes(nbytes, self.alignment) 154*523fa7a6SAndroid Build Coastguard Worker 155*523fa7a6SAndroid Build Coastguard Worker def realign(self, new_alignment: int) -> int: 156*523fa7a6SAndroid Build Coastguard Worker self.alignment = new_alignment 157*523fa7a6SAndroid Build Coastguard Worker return self.allocated_memory 158*523fa7a6SAndroid Build Coastguard Worker 159*523fa7a6SAndroid Build Coastguard Worker def nbytes(self) -> int: 160*523fa7a6SAndroid Build Coastguard Worker return num_bytes_from_shape_and_dtype(self.shape, self.dtype) 161*523fa7a6SAndroid Build Coastguard Worker 162*523fa7a6SAndroid Build Coastguard Worker @classmethod 163*523fa7a6SAndroid Build Coastguard Worker def from_tensor(cls, tensor: torch.Tensor, const: bool = False) -> TensorSpec: 164*523fa7a6SAndroid Build Coastguard Worker if const: 165*523fa7a6SAndroid Build Coastguard Worker # for non-contigous tensors, convert to a contiguous one 166*523fa7a6SAndroid Build Coastguard Worker tensor = tensor.contiguous() 167*523fa7a6SAndroid Build Coastguard Worker # Weights cannot be views during emission or serialization 168*523fa7a6SAndroid Build Coastguard Worker if tensor.nbytes != tensor.untyped_storage().nbytes(): 169*523fa7a6SAndroid Build Coastguard Worker tensor = tensor.clone() 170*523fa7a6SAndroid Build Coastguard Worker 171*523fa7a6SAndroid Build Coastguard Worker spec = cls( 172*523fa7a6SAndroid Build Coastguard Worker dtype=tensor.dtype, 173*523fa7a6SAndroid Build Coastguard Worker shape=tensor.shape, 174*523fa7a6SAndroid Build Coastguard Worker layout=tensor.layout, 175*523fa7a6SAndroid Build Coastguard Worker const=const, 176*523fa7a6SAndroid Build Coastguard Worker is_sparse=tensor.is_sparse, 177*523fa7a6SAndroid Build Coastguard Worker ) 178*523fa7a6SAndroid Build Coastguard Worker spec.stride = 
tensor.stride() 179*523fa7a6SAndroid Build Coastguard Worker spec.dim_order = dim_order_from_stride(spec.stride) 180*523fa7a6SAndroid Build Coastguard Worker spec.requires_grad = tensor.requires_grad 181*523fa7a6SAndroid Build Coastguard Worker spec.storage = tensor.untyped_storage() if const else None 182*523fa7a6SAndroid Build Coastguard Worker 183*523fa7a6SAndroid Build Coastguard Worker return spec 184*523fa7a6SAndroid Build Coastguard Worker 185*523fa7a6SAndroid Build Coastguard Worker def init_mem_planning_fields(self) -> None: 186*523fa7a6SAndroid Build Coastguard Worker self.lifetime = [None, None] 187*523fa7a6SAndroid Build Coastguard Worker self.mem_id = None 188*523fa7a6SAndroid Build Coastguard Worker self.mem_obj_id = None 189*523fa7a6SAndroid Build Coastguard Worker self.mem_offset = None 190*523fa7a6SAndroid Build Coastguard Worker 191*523fa7a6SAndroid Build Coastguard Worker @property 192*523fa7a6SAndroid Build Coastguard Worker def dtype(self) -> torch.dtype: 193*523fa7a6SAndroid Build Coastguard Worker return self.scalar_type 194*523fa7a6SAndroid Build Coastguard Worker 195*523fa7a6SAndroid Build Coastguard Worker @property 196*523fa7a6SAndroid Build Coastguard Worker def is_dynamic_shape_tensor(self) -> bool: 197*523fa7a6SAndroid Build Coastguard Worker return self.shape_dynamism != schema.TensorShapeDynamism.STATIC 198*523fa7a6SAndroid Build Coastguard Worker 199*523fa7a6SAndroid Build Coastguard Worker @property 200*523fa7a6SAndroid Build Coastguard Worker def is_static_shape_tensor(self) -> bool: 201*523fa7a6SAndroid Build Coastguard Worker return self.shape_dynamism == TensorShapeDynamism.STATIC 202*523fa7a6SAndroid Build Coastguard Worker 203*523fa7a6SAndroid Build Coastguard Worker @property 204*523fa7a6SAndroid Build Coastguard Worker def is_upper_bound_tensor(self) -> bool: 205*523fa7a6SAndroid Build Coastguard Worker return self.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND 206*523fa7a6SAndroid Build Coastguard Worker 
207*523fa7a6SAndroid Build Coastguard Worker @property 208*523fa7a6SAndroid Build Coastguard Worker def is_dynamic_unbound_tensor(self) -> bool: 209*523fa7a6SAndroid Build Coastguard Worker return self.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND 210*523fa7a6SAndroid Build Coastguard Worker 211*523fa7a6SAndroid Build Coastguard Worker def debug(self) -> str: 212*523fa7a6SAndroid Build Coastguard Worker return ( 213*523fa7a6SAndroid Build Coastguard Worker f"TensorSpec(id={id(self)}, const={self.const}, scalar_type={self.scalar_type}" 214*523fa7a6SAndroid Build Coastguard Worker + f", allocated_memory={self.allocated_memory}, mem_id={self.mem_id}" 215*523fa7a6SAndroid Build Coastguard Worker + f", mem_offset={self.mem_offset}, lifetime={self.lifetime}" 216*523fa7a6SAndroid Build Coastguard Worker + f", shape_dynamism={self.shape_dynamism}" 217*523fa7a6SAndroid Build Coastguard Worker + (f", shape={self.shape}") 218*523fa7a6SAndroid Build Coastguard Worker + ")" 219*523fa7a6SAndroid Build Coastguard Worker ) 220*523fa7a6SAndroid Build Coastguard Worker 221*523fa7a6SAndroid Build Coastguard Worker def __repr__(self) -> str: 222*523fa7a6SAndroid Build Coastguard Worker """ 223*523fa7a6SAndroid Build Coastguard Worker Round-trippable printing function 224*523fa7a6SAndroid Build Coastguard Worker """ 225*523fa7a6SAndroid Build Coastguard Worker return ( 226*523fa7a6SAndroid Build Coastguard Worker f"TensorSpec(dtype={self.scalar_type}, shape={self.shape}" 227*523fa7a6SAndroid Build Coastguard Worker + f", layout={self.layout}" 228*523fa7a6SAndroid Build Coastguard Worker + f", is_sparse={self.is_sparse}" 229*523fa7a6SAndroid Build Coastguard Worker + f", shape_dynamism={self.shape_dynamism}" 230*523fa7a6SAndroid Build Coastguard Worker + f", const={self.const}, requires_grad={self.requires_grad}" 231*523fa7a6SAndroid Build Coastguard Worker + ")" 232*523fa7a6SAndroid Build Coastguard Worker ) 233*523fa7a6SAndroid Build Coastguard Worker 234*523fa7a6SAndroid 
def memory_format_enum(memory_format: torch.memory_format) -> int:
    """
    Map a ``torch.memory_format`` to its integer code.

    Only ``torch.contiguous_format`` (0) and ``torch.preserve_format`` (1)
    are in the table; any other format raises ``KeyError``.
    """
    internal_assert(
        isinstance(memory_format, torch.memory_format),
        "We only support torch.memory_format",
    )
    table = {
        torch.contiguous_format: 0,
        torch.preserve_format: 1,
    }
    return table[memory_format]


# torch dtype -> serialized ScalarType enum.
scalar_type_table: Dict[torch.dtype, ScalarType] = {
    torch.uint8: ScalarType.BYTE,
    torch.int8: ScalarType.CHAR,
    torch.int16: ScalarType.SHORT,
    torch.int32: ScalarType.INT,
    torch.int64: ScalarType.LONG,
    torch.half: ScalarType.HALF,
    torch.float: ScalarType.FLOAT,
    torch.double: ScalarType.DOUBLE,
    torch.complex32: ScalarType.COMPLEX32,
    torch.complex64: ScalarType.COMPLEX64,
    torch.complex128: ScalarType.COMPLEX128,
    torch.bool: ScalarType.BOOL,
    torch.qint8: ScalarType.QINT8,
    torch.quint8: ScalarType.QUINT8,
    torch.qint32: ScalarType.QINT32,
    torch.bfloat16: ScalarType.BFLOAT16,
    torch.quint4x2: ScalarType.QUINT4x2,
    torch.uint16: ScalarType.UINT16,
}


# Inverse of scalar_type_table: ScalarType enum -> torch dtype.
enum_to_scalar_map: Dict[ScalarType, torch.dtype] = {
    scalar_type: dtype for dtype, scalar_type in scalar_type_table.items()
}


def scalar_type_enum(dtype: torch.dtype) -> ScalarType:
    """Return the ScalarType enum for ``dtype``; ``KeyError`` if unmapped."""
    # TODO (zhengxu) single source of truth from c10/core/ScalarType.h.
    internal_assert(
        isinstance(dtype, torch.dtype), "We only support dtypes defined in Pytorch Core"
    )
    return scalar_type_table[dtype]


def get_scalar_type(enum: ScalarType) -> torch.dtype:
    """Return the torch dtype for a ScalarType enum; ``KeyError`` if unmapped."""
    return enum_to_scalar_map[enum]


def layout_enum(layout: torch.layout) -> int:
    """
    Map a ``torch.layout`` to its integer code: strided -> 0, sparse_coo -> 1.
    Other layouts raise ``KeyError``.
    """
    # TODO single source of truth.
    table = {
        torch.strided: 0,
        torch.sparse_coo: 1,
    }
    return table[layout]


def make_allocation_info(mem_id: int, mem_offset: int) -> schema.AllocationDetails:
    """
    Creates the allocation_details object for creating tensors.

    The byte offset is split into low and high 32-bit halves.

    Raises:
        ValueError: if ``mem_offset`` is negative.
        AddressSpaceOverflowException: if ``mem_offset`` needs more than 64 bits.
    """
    if mem_offset < 0:
        raise ValueError(f"mem_offset {mem_offset} must not be negative")
    memory_offset_low = mem_offset & ((1 << 32) - 1)
    memory_offset_high = mem_offset >> 32
    if memory_offset_high >= 1 << 32:
        raise AddressSpaceOverflowException(
            f"mem_offset {mem_offset} does not fit in 64 bits"
        )

    allocation_info = schema.AllocationDetails(
        memory_id=mem_id,
        memory_offset_low=memory_offset_low,
        memory_offset_high=memory_offset_high,
    )
    return allocation_info


def make_tensor_value(
    data_buffer_idx: int,
    allocation_info: Optional[schema.AllocationDetails],
    spec: TensorSpec,
) -> schema.Tensor:
    """
    Build the serialized ``schema.Tensor`` description for ``spec``.

    Args:
        data_buffer_idx: index of the tensor's data buffer.
        allocation_info: memory-planning placement, or None.
        spec: the TensorSpec supplying dtype, sizes, dim order, layout, etc.
    """

    def to_list(
        x: Union[torch.Size, int, List[int], Tuple[int]]
    ) -> Union[List[int], List[torch.Size]]:
        # torch.Size is a tuple subclass, so one isinstance check covers both.
        if isinstance(x, tuple):
            return list(x)
        elif isinstance(x, int):
            return [x]
        else:
            return x

    tensor_size = to_list(spec.shape)
    tensor_dim_order = to_list(spec.dim_order)

    flatbuffer_tensor = schema.Tensor(
        scalar_type=scalar_type_enum(spec.scalar_type),
        # The runtime currently only supports tensors with offsets of zero.
        storage_offset=0,
        sizes=tensor_size,
        dim_order=tensor_dim_order,
        requires_grad=spec.requires_grad,
        data_buffer_idx=data_buffer_idx,
        allocation_info=allocation_info,
        layout=layout_enum(spec.layout),
        shape_dynamism=spec.shape_dynamism,
    )
    return flatbuffer_tensor


def check_spec(tensor: torch.Tensor, spec: TensorSpec) -> None:
    """
    Assert that ``tensor`` agrees with ``spec`` on sparsity, shape and dtype.
    """
    internal_assert(
        tensor.is_sparse == spec.is_sparse,
        f"Tensor attribute 'is_sparse' is expected to be equal to '{spec.is_sparse}', actually got: '{tensor.is_sparse}'",
    )
    # spec.shape is stored as a list while tensor.shape is a torch.Size (a
    # tuple); a tuple never compares equal to a list, so normalize both.
    internal_assert(
        tuple(tensor.shape) == tuple(spec.shape),
        f"Tensor attribute 'shape' is expected to be equal to '{spec.shape}', actually got: '{tensor.shape}'",
    )
    internal_assert(
        tensor.dtype == spec.dtype,
        f"Tensor attribute 'dtype' is expected to be equal to '{spec.dtype}', actually got: '{tensor.dtype}'",
    )