# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# PyTorch to Tosa mapping - simple mapping functions and multi-type extraction
# of key information. These are used by the initial compile stage which captures
# the standardised TOSA representation.
#

import serializer.tosa_serializer as ts
import torch


# Torch dtypes that are explicitly rejected by map_dtype().
UNSUPPORTED_DTYPES = (
    torch.float64,
    torch.double,
    torch.complex64,
    torch.cfloat,
    torch.complex128,
    torch.cdouble,
    torch.uint8,
    torch.int64,
    torch.long,
)

# Torch dtype -> TOSA serializer dtype. Torch alias names (e.g. torch.float
# and torch.float32) are listed side by side for readability.
DTYPE_MAP = {
    torch.float32: ts.DType.FP32,
    torch.float: ts.DType.FP32,
    torch.float16: ts.DType.FP16,
    torch.half: ts.DType.FP16,
    torch.bfloat16: ts.DType.BF16,
    torch.int8: ts.DType.INT8,
    torch.int16: ts.DType.INT16,
    torch.short: ts.DType.INT16,
    torch.int32: ts.DType.INT32,
    torch.int: ts.DType.INT32,
    torch.bool: ts.DType.BOOL,
}


def map_dtype(data_type):
    """Return the TOSA serializer dtype that corresponds to a torch dtype.

    Args:
        data_type: The torch dtype to translate.

    Returns:
        The matching entry of ``DTYPE_MAP``.

    Raises:
        AssertionError: If ``data_type`` is listed in ``UNSUPPORTED_DTYPES``
            or is missing from ``DTYPE_MAP``.
    """
    assert data_type not in UNSUPPORTED_DTYPES, f"Unsupported type: {data_type}"
    assert data_type in DTYPE_MAP, f"Unknown type: {data_type}"
    return DTYPE_MAP[data_type]


# Returns the shape and type of a node
# TODO: other types, can be
# SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
def extract_tensor_meta(meta):
    """Extract ``(dtype, shape, dim_order)`` from a node's ``meta`` mapping.

    Args:
        meta: Mapping that must hold a non-None ``"val"`` entry and may hold
            a ``"tosa_dim_order"`` entry.

    Returns:
        A tuple ``(dtype, shape, dim_order)`` where ``dtype`` is the mapped
        TOSA dtype, ``shape`` is the tensor size as a tuple, and
        ``dim_order`` is ``meta["tosa_dim_order"]`` when that entry is
        present and not None, otherwise ``(0, 1, ..., len(shape) - 1)``.

    Raises:
        AssertionError: If ``"val"`` is missing/None, if the value (or the
            first element of a tuple value) is not exactly a ``FakeTensor``,
            or if its dtype cannot be mapped.
    """
    assert meta.get("val") is not None
    val = meta["val"]
    if type(val) is tuple:
        # TODO: should use first concrete representation
        val = val[0]

    assert torch._subclasses.fake_tensor.FakeTensor == type(val)
    dtype = map_dtype(val.dtype)
    shape = tuple(val.size())

    dim_order = meta.get("tosa_dim_order")
    if dim_order is None:
        # No explicit order recorded: fall back to the identity ordering.
        dim_order = tuple(range(len(shape)))
    return (dtype, shape, dim_order)


# Class to capture arguments and turn into tensor references for TOSA OPs
class TosaArg:
    """Wrap a single operator argument for later use when emitting TOSA ops.

    Depending on the type of the wrapped argument, different attributes are
    populated:

    * ``torch.fx.node.Node``: ``name``, ``dtype``, ``shape`` and
      ``dim_order`` are filled in from the node and its ``meta``.
    * ``list``: ``special`` holds a shallow copy of the list.
    * ``int`` / ``float`` (``bool`` included, being an ``int`` subclass):
      ``number`` holds the value. ``number`` is only set in this case.
    * ``None``: all attributes keep their ``None`` defaults.
    """

    def __process_node(self, argument):
        """Record name, dtype, shape and dim order of an fx node."""
        assert isinstance(argument, torch.fx.node.Node)
        self.name = argument.name
        self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta)

    def __process_list(self, argument):
        """Store a shallow copy of a list argument in ``special``."""
        self.special = list(argument)

    def __process_number(self, argument):
        """Store a scalar argument in ``number``."""
        self.number = argument

    def __init__(self, argument) -> None:
        """Build a ``TosaArg`` from ``argument``.

        Args:
            argument: ``None``, a ``torch.fx.node.Node``, a ``list``, an
                ``int`` or a ``float``.

        Raises:
            RuntimeError: If ``argument`` is of any other type.
        """
        self.name = None
        self.dtype = None
        self.shape = None
        self.dim_order = None
        self.special = None

        if argument is None:
            return

        if isinstance(argument, torch.fx.node.Node):
            self.__process_node(argument)
            return
        if isinstance(argument, list):
            self.__process_list(argument)
            return
        if isinstance(argument, (int, float)):
            self.__process_number(argument)
            return

        # NOTE(review): this exception used to be constructed but never
        # raised, so unsupported argument types silently produced an empty
        # TosaArg. Raising makes such inputs fail loudly; confirm no caller
        # relied on the silent fall-through.
        raise RuntimeError(
            f"Unhandled node input argument: {argument}, of type {type(argument)}"
        )