# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Set of simple utilities for translating between torch.memory_format and dim_order
"""

from typing import List, Optional, Sequence

import torch


def _get_contiguous_dim_order(ndim: int) -> List[int]:
    """
    Return the contiguous (identity) dim order for a tensor of rank ``ndim``,
    i.e. ``[0, 1, ..., ndim - 1]``.

    Raises:
        AssertionError: if ``ndim`` is negative.
    """
    if ndim < 0:
        raise AssertionError(
            f"Unsupported rank for contiguous dim order. Only supports ndim greater than or equal to 0, but got {ndim}"
        )

    return list(range(ndim))


def _get_channels_last_dim_order(ndim: int) -> List[int]:
    """
    Return the channels-last dim order ``[0, 2, 3, 1]``.

    Only rank 4 is supported.

    Raises:
        AssertionError: if ``ndim`` is not 4.
    """
    if ndim == 4:
        return [0, 2, 3, 1]

    raise AssertionError(
        f"Unsupported rank for channels last dim order. Only support ndim equal to 4, but got {ndim}"
    )


def get_memory_format(dim_order: Optional[Sequence[int]]) -> torch.memory_format:
    """
    Given a dim_order try to map it to torch.memory_format

    Args:
        dim_order: A dim order as a list or tuple of ints, or None.

    Returns:
        ``torch.preserve_format`` if ``dim_order`` is None;
        ``torch.contiguous_format`` if it equals ``[0, 1, ..., len - 1]``
        (including the empty order for rank 0);
        ``torch.channels_last`` if it equals the 4D order ``[0, 2, 3, 1]``.

    Raises:
        AssertionError: if ``dim_order`` matches none of the above.
    """
    if dim_order is None:
        return torch.preserve_format

    # Normalize to a list: the reference orders are lists, and a tuple
    # (e.g. what Tensor.dim_order() yields) never compares equal to a list.
    normalized = list(dim_order)
    ndim = len(normalized)
    if normalized == _get_contiguous_dim_order(ndim):
        return torch.contiguous_format
    # The rank check must come first: _get_channels_last_dim_order raises
    # for any rank other than 4.
    if ndim == 4 and normalized == _get_channels_last_dim_order(ndim):
        return torch.channels_last

    raise AssertionError(
        f"Failed to map a given dim_order: {dim_order} to a torch.memory_format"
    )


def get_dim_order(
    memory_format: Optional[torch.memory_format], ndim: int
) -> Optional[List[int]]:
    """
    Given a memory_format and a tensor rank, generate a dim_order

    Args:
        memory_format: ``None``, ``torch.preserve_format``,
            ``torch.contiguous_format`` or ``torch.channels_last``.
        ndim: Rank of the tensor the dim order is for.

    Returns:
        ``None`` for ``None`` / ``torch.preserve_format``, otherwise the dim
        order as a list of ints.

    Raises:
        AssertionError: for any other memory format, for a negative ``ndim``
            with contiguous format, or for ``ndim != 4`` with channels_last.
    """
    if memory_format in [None, torch.preserve_format]:
        return None
    elif memory_format == torch.contiguous_format:
        return _get_contiguous_dim_order(ndim)
    elif memory_format == torch.channels_last:
        return _get_channels_last_dim_order(ndim)

    raise AssertionError(
        f"Failed to generate dim_order for a given memory format: {memory_format}"
    )


def is_channel_last_dim_order(tensor: torch.Tensor) -> bool:
    """
    Check if a tensor has channels last dim order

    Returns False for any tensor whose rank is not 4.
    """
    if tensor.dim() != 4:
        # Only 4D tensors are supported for the channels-last memory format.
        return False

    return tensor.dim_order() == tuple(_get_channels_last_dim_order(tensor.dim()))


def is_contiguous_dim_order(tensor: torch.Tensor) -> bool:
    """
    Check if a tensor has contiguous dim order

    Compares ``tensor.dim_order()`` against ``(0, 1, ..., tensor.dim() - 1)``.
    """
    return tensor.dim_order() == tuple(_get_contiguous_dim_order(tensor.dim()))