# xref: /aosp_15_r20/external/executorch/exir/dim_order_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import torch

"""
Set of simple utilities for translating between torch.memory_format and dim_order
"""

16def _get_contiguous_dim_order(ndim: int) -> List[int]:
17    if ndim < 0:
18        raise AssertionError(
19            f"Unsupported rank for contiguous dim order. Only supports ndim greater than or equal to 0, but got {ndim}"
20        )
21
22    return list(range(ndim))
23
24
25def _get_channels_last_dim_order(ndim: int) -> List[int]:
26    if ndim == 4:
27        return [0, 2, 3, 1]
28
29    raise AssertionError(
30        f"Unsupported rank for channels last dim order. Only support ndim equal to 4, but got {ndim}"
31    )
32
33
def get_memory_format(dim_order: Optional[List[int]]) -> torch.memory_format:
    """Map a dim order to the equivalent ``torch.memory_format``.

    Args:
        dim_order: Permutation of ``range(len(dim_order))`` describing how the
            tensor's dimensions are laid out in memory, or None.

    Returns:
        ``torch.preserve_format`` when ``dim_order`` is None,
        ``torch.contiguous_format`` for the identity order, or
        ``torch.channels_last`` for the rank-4 order ``[0, 2, 3, 1]``.

    Raises:
        AssertionError: If ``dim_order`` matches no known memory format.
    """
    if dim_order is None:
        return torch.preserve_format

    rank = len(dim_order)
    if dim_order == _get_contiguous_dim_order(rank):
        return torch.contiguous_format
    # Check the rank first so the helper's rank-4 assertion cannot fire here.
    if rank == 4 and dim_order == _get_channels_last_dim_order(rank):
        return torch.channels_last

    raise AssertionError(
        f"Failed to map a given dim_order: {dim_order} to a torch.memory_format"
    )


def get_dim_order(
    memory_format: Optional[torch.memory_format], ndim: int
) -> Optional[List[int]]:
    """Generate the dim order for a memory format at the given tensor rank.

    Args:
        memory_format: Target ``torch.memory_format``, or None.
        ndim: Rank of the tensor the dim order is for.

    Returns:
        None for ``None`` / ``torch.preserve_format``; otherwise the list of
        dimension indices describing the memory layout.

    Raises:
        AssertionError: If the memory format is unsupported, or if ``ndim`` is
            invalid for the requested format (propagated from the helpers).
    """
    if memory_format is None or memory_format == torch.preserve_format:
        return None
    if memory_format == torch.contiguous_format:
        return _get_contiguous_dim_order(ndim)
    if memory_format == torch.channels_last:
        return _get_channels_last_dim_order(ndim)

    raise AssertionError(
        f"Failed to generate dim_order for a given memory format: {memory_format}"
    )


def is_channel_last_dim_order(tensor: torch.Tensor) -> bool:
    """Return True if ``tensor`` has the channels-last (NHWC) dim order.

    Channels-last dim order is only defined for rank-4 tensors, so any other
    rank yields False.
    """
    if tensor.dim() != 4:
        return False

    expected = tuple(_get_channels_last_dim_order(tensor.dim()))
    return tensor.dim_order() == expected


def is_contiguous_dim_order(tensor: torch.Tensor) -> bool:
    """Return True if ``tensor`` has the contiguous (row-major) dim order."""
    expected = tuple(_get_contiguous_dim_order(tensor.dim()))
    return tensor.dim_order() == expected
86