xref: /aosp_15_r20/external/executorch/backends/cadence/runtime/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7
8import logging
9import typing
10from typing import Callable, Union
11
12import numpy as np
13import torch
14
15
16# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def distance(fn: Callable[[np.ndarray, np.ndarray], float]) -> Callable[
    [
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        typing.Union[np.ndarray, torch._tensor.Tensor],
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        typing.Union[np.ndarray, torch._tensor.Tensor],
    ],
    float,
]:
    """Wrap a distance metric with the sanity checks shared by all metrics.

    The returned wrapper accepts NumPy arrays or torch tensors. It converts
    both inputs with ``to_np_arr_fp64`` and then applies these rules in order:

    * It returns NaN if the shapes differ.
    * It returns 0.0 if the inputs are empty.
    * It returns NaN if Inf/NaN values are not at the same positions in
      ``a`` and ``b``. This includes an Inf whose sign differs, such as
      ``+inf`` vs ``-inf``.
    * It drops the positions holding Inf/NaN, logging a warning if there
      were any. It returns 0.0 if nothing is left.
    * Otherwise it returns ``fn`` applied to the remaining values. These are
      passed as 1-D arrays, because boolean-mask indexing flattens them.

    Args:
        fn: Metric computed on two equally sized float64 arrays, for example
            an RMS or a maximum absolute difference.

    Returns:
        The wrapped metric. ``fn``'s name and docstring are preserved via
        ``functools.wraps``.
    """
    import functools  # local import keeps this decorator self-contained

    @functools.wraps(fn)
    def wrapper(
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        a: Union[np.ndarray, torch.Tensor],
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        b: Union[np.ndarray, torch.Tensor],
    ) -> float:
        # Compare everything as float64 NumPy arrays.
        a = to_np_arr_fp64(a)
        b = to_np_arr_fp64(b)

        # Arrays of different shapes have no meaningful distance.
        if a.shape != b.shape:
            return np.nan

        # Shapes match; two empty arrays are trivially identical.
        if a.size == 0:
            return 0.0

        # Compute the special-value masks once and reuse them below.
        a_inf, b_inf = np.isinf(a), np.isinf(b)
        a_nan, b_nan = np.isnan(a), np.isnan(b)

        # Inf and NaN must occur at the same positions in a and b.
        if np.any(a_inf != b_inf) or np.any(a_nan != b_nan):
            return np.nan
        # The positions match, but +inf vs -inf is still a mismatch. Without
        # this check both would be masked out below and reported as distance 0.
        if np.any(a[a_inf] != b[a_inf]):
            return np.nan

        # Exclude the (matching) Inf/NaN positions from the metric.
        mask = a_inf | a_nan
        if np.any(mask):
            logging.warning("Found inf/nan in tensor when calculating the distance")

        a_masked = a[~mask]
        b_masked = b[~mask]

        # Every element may have been masked out.
        if a_masked.size == 0:
            return 0.0

        # Only the finite values reach the metric.
        return fn(a_masked, b_masked)

    return wrapper
68
69
@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def rms(a: np.ndarray, b: np.ndarray) -> float:
    """Return the root-mean-square of the element-wise difference ``a - b``.

    The ``distance`` decorator has already converted the inputs to float64,
    checked their shapes and removed matching Inf/NaN positions.
    """
    delta = a - b
    mean_square = np.square(delta).mean()
    return mean_square ** 0.5
74
75
@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Return the largest absolute element-wise difference between ``a`` and ``b``.

    The inputs have already been sanitized by the ``distance`` decorator.
    """
    return np.max(np.absolute(np.subtract(a, b)))
80
81
@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def max_rel_diff(x: np.ndarray, x_ref: np.ndarray) -> float:
    """Return the maximum relative difference ``|x - x_ref| / |x_ref|``.

    ``x_ref`` is the reference. Positions where ``x`` equals ``x_ref``
    contribute 0, even when the reference value is 0 (previously 0/0 gave
    NaN). A nonzero difference against a zero reference contributes +inf.

    The ``distance`` decorator handles float64 conversion, shape checks and
    the removal of matching Inf/NaN positions before this is called.
    """
    abs_diff = np.abs(x - x_ref)
    # Zero references are handled explicitly below, so silence NumPy's
    # divide-by-zero / invalid-value RuntimeWarnings for them.
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_diff = abs_diff / np.abs(x_ref)
    # An exact match (including 0 vs 0) is not a relative error.
    rel_diff[abs_diff == 0] = 0.0
    return rel_diff.max()
86
87
88# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def to_np_arr_fp64(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Convert a torch tensor or NumPy array to a new float64 NumPy array.

    Torch tensors are detached and moved to the CPU first. Inputs that are
    neither a tensor nor an ndarray are returned unchanged.

    Args:
        x: Tensor or array to convert.

    Returns:
        A float64 ``np.ndarray`` that does not share memory with ``x``, or
        ``x`` itself for other input types.
    """
    if isinstance(x, torch.Tensor):
        # Upcast inside torch: Tensor.numpy() rejects dtypes NumPy has no
        # equivalent for (e.g. torch.bfloat16).
        x = x.detach().cpu().to(torch.float64).numpy()
    if isinstance(x, np.ndarray):
        # astype copies, so the result never aliases the caller's data, even
        # when the tensor was already a float64 CPU tensor.
        x = x.astype(np.float64)
    return x
95
96
def normalized_rms(
    # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
    predicted: Union[np.ndarray, torch.Tensor],
    # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
    ground_truth: Union[np.ndarray, torch.Tensor],
) -> float:
    """Return ``rms(predicted, ground_truth)`` divided by the norm of ``ground_truth``.

    The denominator is ``np.linalg.norm`` of the flattened float64 ground
    truth, i.e. its L2 norm rather than its RMS. Unlike the numerator, the
    denominator is computed on the unmasked ground truth. As a result, Inf/NaN
    values in ``ground_truth`` propagate into it.

    Args:
        predicted: Values to evaluate.
        ground_truth: Reference values.

    Returns:
        0.0 if the RMS is exactly 0. Otherwise it returns the ratio, which is
        NaN when ``rms`` reports a mismatch and +inf when the ground truth is
        all zeros while the RMS is nonzero.
    """
    num = rms(predicted, ground_truth)
    if num == 0:
        return 0.0
    den = np.linalg.norm(to_np_arr_fp64(ground_truth))
    # A zero denominator yields inf/nan on purpose; suppress NumPy's
    # RuntimeWarning instead of changing the value.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.float64(num) / np.float64(den)
109