# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
# pyre-ignore-all-errors[6]
# pyre-ignore-all-errors[16]
from __future__ import annotations

import copy

import math
import typing
from typing import Dict, List, Optional, Tuple, Union

import executorch.exir.schema as schema
import torch
from executorch.exir.error import internal_assert
from executorch.exir.schema import ScalarType, TensorShapeDynamism
from executorch.exir.sym_util import eval_shape


class AddressSpaceOverflowException(Exception):
    pass


def num_bytes_from_shape_and_dtype(shape: torch.Size, dtype: torch.dtype) -> int:
    """
    Assumes the tensor is contiguous.
    """

    return math.prod(shape) * torch._utils._element_size(dtype)

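# Example (illustrative): a contiguous float32 tensor of shape (2, 3) has
# 2 * 3 = 6 elements of 4 bytes each, so
# num_bytes_from_shape_and_dtype(torch.Size([2, 3]), torch.float32) == 24.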

def contiguous_stride_from_shape(shape: torch.Size) -> Tuple[int]:
    strides = []
    accum = 1
    for sz in reversed(shape):
        strides.append(accum)
        # For sizes[i] == 0, treat it as 1 to be consistent with core PyTorch;
        # this preserves PyTorch's behavior for dims with 0 elements.
        if isinstance(sz, int):
            if sz != 0:
                accum *= sz
        else:
            # Unbacked symints may error on the != 0 check
            accum *= sz
    return tuple(reversed(strides))

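# Example (illustrative): contiguous_stride_from_shape(torch.Size([2, 3, 4]))
# returns (12, 4, 1) -- the innermost dimension has stride 1, and each outer
# stride is the product of the sizes to its right.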

def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]:
    """
    Dimension order represents how dimensions are laid out in memory,
    starting from the outermost to the innermost dimension.
    Thus, the conversion from strides is done by sorting the strides
    from larger to smaller, since the dimension with the largest stride
    is the outermost and the dimension with the smallest stride is the innermost.
    For example, a tensor with sizes = (3, 5, 2) and strides = (5, 1, 15) implies
    a dimension order of (2, 0, 1), which is obtained by sorting the strides
    from largest to smallest.

    When the strides do not convey the dimension order unambiguously, the
    returned dimension order depends on the stability of the sort. Python's sort
    is stable, so elements with equal keys keep their original order. Thus, when
    strides = (4, 3, 1, 1) the returned value is (0, 1, 2, 3). Another example:
    sizes = (1, 3, 1, 1) with strides = (3, 1, 3, 3) returns (0, 2, 3, 1).
    """
    for s in stride:
        if s == 0:
            raise ValueError("0 in strides is not supported for ExecuTorch.")
    sorted_dims = [
        i[0] for i in sorted(enumerate(stride), key=lambda x: x[1], reverse=True)
    ]
    return tuple(typing.cast(Tuple[bytes], sorted_dims))

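# Example (illustrative): dim_order_from_stride((5, 1, 15)) == (2, 0, 1), since
# sorting the strides from largest to smallest puts dim 2 (stride 15) first and
# dim 1 (stride 1) last.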

def stride_from_dim_order(sizes: List[int], dim_order: List[bytes]) -> List[int]:
    """
    Converts a dim order to strides using the sizes.
    e.g. if sizes = (2, 3, 4) and dim_order = (0, 1, 2) then strides = (12, 4, 1),
    while for the same sizes, if dim_order = (0, 2, 1) then strides = (12, 1, 3).
    See executorch/runtime/core/exec_aten/util/dim_order_util.h for details.
    Args:
        sizes (List[int]): sizes of the tensor
        dim_order (List[bytes]): dim order of the tensor
    Returns:
        List[int]: strides
    """
    if len(sizes) == 0:
        return []
    strides = copy.deepcopy(sizes)
    ndim = len(sizes)
    strides[dim_order[ndim - 1]] = 1
    for i in range(ndim - 2, -1, -1):
        if sizes[dim_order[i + 1]] == 0:
            strides[dim_order[i]] = strides[dim_order[i + 1]]
        else:
            strides[dim_order[i]] = sizes[dim_order[i + 1]] * strides[dim_order[i + 1]]
    return strides

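# Example (illustrative): stride_from_dim_order([2, 3, 4], [0, 2, 1]) == [12, 1, 3];
# the last entry of dim_order gets stride 1, and each earlier entry gets the next
# entry's stride multiplied by the next entry's size.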

def calculate_aligned_num_bytes(num: int, alignment: int) -> int:
    return math.ceil(num / alignment) * alignment

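# Example (illustrative): calculate_aligned_num_bytes(24, 16) == 32, i.e. 24
# bytes rounded up to the next multiple of a 16-byte alignment.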

def determine_tensor_dynanism(shape: torch.Size) -> TensorShapeDynamism:
    if all(isinstance(s, int) for s in shape):
        return TensorShapeDynamism.STATIC
    else:
        try:
            _ = eval_shape(shape)
            return TensorShapeDynamism.DYNAMIC_BOUND
        except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
            return TensorShapeDynamism.DYNAMIC_UNBOUND


ALIGNMENT = 16


class TensorSpec:
    """
    Captures the metadata for a given Tensor (e.g. scalar type, storage, etc.).
    """

    def __init__(
        self,
        dtype: torch.dtype,
        shape: torch.Size,
        layout: torch.layout = torch.strided,
        is_sparse: bool = False,
        const: bool = False,
        requires_grad: bool = False,
    ) -> None:
        self.scalar_type = dtype
        self.const = const
        self.alignment: int = ALIGNMENT
        self.storage: Optional[torch.UntypedStorage] = None
        # Convert to a list to make type checking easier to handle.
        self.shape: List[int] = list(shape)
        self.stride: Tuple[int] = contiguous_stride_from_shape(shape)
        self.dim_order: Tuple[bytes] = dim_order_from_stride(self.stride)
        self.requires_grad = requires_grad
        self.layout = layout
        self.is_sparse = is_sparse
        self.init_mem_planning_fields()
        self.shape_dynamism: TensorShapeDynamism = determine_tensor_dynanism(
            self.shape
        )

    @property
    def allocated_memory(self) -> int:
        nbytes = num_bytes_from_shape_and_dtype(self.shape, self.dtype)
        return calculate_aligned_num_bytes(nbytes, self.alignment)

    def realign(self, new_alignment: int) -> int:
        self.alignment = new_alignment
        return self.allocated_memory

    def nbytes(self) -> int:
        return num_bytes_from_shape_and_dtype(self.shape, self.dtype)

    @classmethod
    def from_tensor(cls, tensor: torch.Tensor, const: bool = False) -> TensorSpec:
        if const:
            # For non-contiguous tensors, convert to a contiguous one.
            tensor = tensor.contiguous()
            # Weights cannot be views during emission or serialization.
            if tensor.nbytes != tensor.untyped_storage().nbytes():
                tensor = tensor.clone()

        spec = cls(
            dtype=tensor.dtype,
            shape=tensor.shape,
            layout=tensor.layout,
            const=const,
            is_sparse=tensor.is_sparse,
        )
        spec.stride = tensor.stride()
        spec.dim_order = dim_order_from_stride(spec.stride)
        spec.requires_grad = tensor.requires_grad
        spec.storage = tensor.untyped_storage() if const else None

        return spec

    def init_mem_planning_fields(self) -> None:
        self.lifetime = [None, None]
        self.mem_id = None
        self.mem_obj_id = None
        self.mem_offset = None

    @property
    def dtype(self) -> torch.dtype:
        return self.scalar_type

    @property
    def is_dynamic_shape_tensor(self) -> bool:
        return self.shape_dynamism != schema.TensorShapeDynamism.STATIC

    @property
    def is_static_shape_tensor(self) -> bool:
        return self.shape_dynamism == TensorShapeDynamism.STATIC

    @property
    def is_upper_bound_tensor(self) -> bool:
        return self.shape_dynamism == TensorShapeDynamism.DYNAMIC_BOUND

    @property
    def is_dynamic_unbound_tensor(self) -> bool:
        return self.shape_dynamism == TensorShapeDynamism.DYNAMIC_UNBOUND

    def debug(self) -> str:
        return (
            f"TensorSpec(id={id(self)}, const={self.const}, scalar_type={self.scalar_type}"
            + f", allocated_memory={self.allocated_memory}, mem_id={self.mem_id}"
            + f", mem_offset={self.mem_offset}, lifetime={self.lifetime}"
            + f", shape_dynamism={self.shape_dynamism}"
            + f", shape={self.shape}"
            + ")"
        )

    def __repr__(self) -> str:
        """
        Round-trippable printing function
        """
        return (
            f"TensorSpec(dtype={self.scalar_type}, shape={self.shape}"
            + f", layout={self.layout}"
            + f", is_sparse={self.is_sparse}"
            + f", shape_dynamism={self.shape_dynamism}"
            + f", const={self.const}, requires_grad={self.requires_grad}"
            + ")"
        )

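# Illustrative usage sketch (assumes a plain contiguous tensor):
#   spec = TensorSpec.from_tensor(torch.randn(2, 3))
#   spec.shape              # [2, 3]
#   spec.stride             # (3, 1)
#   spec.dim_order          # (0, 1)
#   spec.nbytes()           # 24
#   spec.allocated_memory   # 32, i.e. 24 bytes rounded up to ALIGNMENT (16)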

def memory_format_enum(memory_format: torch.memory_format) -> int:
    internal_assert(
        isinstance(memory_format, torch.memory_format),
        "We only support torch.memory_format",
    )
    table = {
        torch.contiguous_format: 0,
        torch.preserve_format: 1,
    }
    return table[memory_format]


scalar_type_table: Dict[torch.dtype, ScalarType] = {
    torch.uint8: ScalarType.BYTE,
    torch.int8: ScalarType.CHAR,
    torch.int16: ScalarType.SHORT,
    torch.int32: ScalarType.INT,
    torch.int64: ScalarType.LONG,
    torch.half: ScalarType.HALF,
    torch.float: ScalarType.FLOAT,
    torch.double: ScalarType.DOUBLE,
    torch.complex32: ScalarType.COMPLEX32,
    torch.complex64: ScalarType.COMPLEX64,
    torch.complex128: ScalarType.COMPLEX128,
    torch.bool: ScalarType.BOOL,
    torch.qint8: ScalarType.QINT8,
    torch.quint8: ScalarType.QUINT8,
    torch.qint32: ScalarType.QINT32,
    torch.bfloat16: ScalarType.BFLOAT16,
    torch.quint4x2: ScalarType.QUINT4x2,
    torch.uint16: ScalarType.UINT16,
}


enum_to_scalar_map: Dict[ScalarType, torch.dtype] = {
    scalar_type_table[key]: key for key in scalar_type_table
}


def scalar_type_enum(dtype: torch.dtype) -> ScalarType:
    # TODO (zhengxu) single source of truth from c10/core/ScalarType.h.
    internal_assert(
        isinstance(dtype, torch.dtype), "We only support dtypes defined in PyTorch Core"
    )
    return scalar_type_table[dtype]


def get_scalar_type(enum: ScalarType) -> torch.dtype:
    return enum_to_scalar_map[enum]

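# Example (illustrative): scalar_type_enum(torch.float) is ScalarType.FLOAT and
# get_scalar_type(ScalarType.HALF) is torch.half; the two tables above are
# inverses of each other.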

def layout_enum(layout: torch.layout) -> int:
    # TODO single source of truth.
    table = {
        torch.strided: 0,
        torch.sparse_coo: 1,
    }
    return table[layout]


def make_allocation_info(mem_id: int, mem_offset: int) -> schema.AllocationDetails:
    """
    Creates the AllocationDetails object used when creating tensors.
    """
    if mem_offset < 0:
        raise ValueError(f"mem_offset {mem_offset} must not be negative")
    memory_offset_low = mem_offset & ((1 << 32) - 1)
    memory_offset_high = mem_offset >> 32
    if memory_offset_high >= 1 << 32:
        raise AddressSpaceOverflowException(
            f"mem_offset {mem_offset} does not fit in 64 bits"
        )

    allocation_info = schema.AllocationDetails(
        memory_id=mem_id,
        memory_offset_low=memory_offset_low,
        memory_offset_high=memory_offset_high,
    )
    return allocation_info

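# Example (illustrative): an offset of 5 GiB does not fit in 32 bits, so it is
# split across the two offset fields:
#   info = make_allocation_info(mem_id=1, mem_offset=5 * 2**30)
#   info.memory_offset_high  # 1
#   info.memory_offset_low   # 1073741824  (high * 2**32 + low == 5 * 2**30)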

def make_tensor_value(
    data_buffer_idx: int,
    allocation_info: Optional[schema.AllocationDetails],
    spec: TensorSpec,
) -> schema.Tensor:
    """
    Converts the metadata in a TensorSpec into a flatbuffer Tensor value.
    """

    def to_list(
        x: Union[torch.Size, int, List[int], Tuple[int]]
    ) -> Union[List[int], List[torch.Size]]:
        if isinstance(x, torch.Size) or isinstance(x, tuple):
            return list(x)
        elif isinstance(x, int):
            return [x]
        else:
            return x

    tensor_size = to_list(spec.shape)
    tensor_dim_order = to_list(spec.dim_order)

    flatbuffer_tensor = schema.Tensor(
        scalar_type=scalar_type_enum(spec.scalar_type),
        # The runtime currently only supports tensors with offsets of zero.
        storage_offset=0,
        sizes=tensor_size,
        dim_order=tensor_dim_order,
        requires_grad=spec.requires_grad,
        data_buffer_idx=data_buffer_idx,
        allocation_info=allocation_info,
        layout=layout_enum(spec.layout),
        shape_dynamism=spec.shape_dynamism,
    )
    return flatbuffer_tensor


def check_spec(tensor: torch.Tensor, spec: TensorSpec) -> None:
    internal_assert(
        tensor.is_sparse == spec.is_sparse,
        f"Tensor attribute 'is_sparse' is expected to be equal to '{spec.is_sparse}', actually got: '{tensor.is_sparse}'",
    )
    internal_assert(
        tensor.shape == spec.shape,
        f"Tensor attribute 'shape' is expected to be equal to '{spec.shape}', actually got: '{tensor.shape}'",
    )
    internal_assert(
        tensor.dtype == spec.dtype,
        f"Tensor attribute 'dtype' is expected to be equal to '{spec.dtype}', actually got: '{tensor.dtype}'",
    )