# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import functools
import logging
import typing
from typing import Callable, Union

import numpy as np
import torch


# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def distance(fn: Callable[[np.ndarray, np.ndarray], float]) -> Callable[
    [
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        typing.Union[np.ndarray, torch.Tensor],
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        typing.Union[np.ndarray, torch.Tensor],
    ],
    float,
]:
    """Decorate a distance metric with the checks shared by all metrics.

    ``fn`` may be an RMS, a maximum absolute difference, or any other metric
    that takes two equally-shaped float64 arrays of finite values. The
    returned wrapper accepts ``np.ndarray`` or ``torch.Tensor`` inputs and:

      1. converts both inputs to float64 ``np.ndarray`` via ``to_np_arr_fp64``;
      2. returns NaN if the shapes differ;
      3. returns 0.0 if the inputs are empty;
      4. returns NaN if Inf or NaN do not occur at exactly the same positions
         in both inputs, or if a shared Inf position has opposite signs;
      5. drops the shared non-finite positions (logging a warning), returning
         0.0 if nothing remains;
      6. otherwise returns ``fn`` applied to the remaining finite elements.
    """

    # wraps() keeps the metric's __name__/__doc__ on the decorated function.
    @functools.wraps(fn)
    def wrapper(
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        a: Union[np.ndarray, torch.Tensor],
        # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
        b: Union[np.ndarray, torch.Tensor],
    ) -> float:
        # Convert a and b to float64 np.ndarray.
        a = to_np_arr_fp64(a)
        b = to_np_arr_fp64(b)

        # Incomparable shapes have no meaningful distance.
        if a.shape != b.shape:
            return np.nan

        # Shapes match; empty inputs are trivially identical.
        if a.size == 0:
            return 0.0

        # Inf and NaN must occur at the same places in a and b.
        a_inf = np.isinf(a)
        a_nan = np.isnan(a)
        if np.any(a_inf != np.isinf(b)) or np.any(a_nan != np.isnan(b)):
            return np.nan
        # Matching Inf positions are only equal if the signs agree too:
        # +inf vs -inf is a genuine mismatch, not something to mask away.
        if np.any(a[a_inf] != b[a_inf]):
            return np.nan

        # Mask out the (shared) non-finite positions.
        mask = a_inf | a_nan
        if np.any(mask):
            logging.warning("Found inf/nan in tensor when calculating the distance")

        a_masked = a[~mask]
        b_masked = b[~mask]

        # After masking, nothing may be left to compare.
        if a_masked.size == 0:
            return 0.0

        # Only the finite elements are passed to the metric.
        return fn(a_masked, b_masked)

    return wrapper


@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def rms(a: np.ndarray, b: np.ndarray) -> float:
    """Return the root-mean-square of ``a - b`` over finite elements."""
    return ((a - b) ** 2).mean() ** 0.5


@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def max_abs_diff(a: np.ndarray, b: np.ndarray) -> float:
    """Return ``max(|a - b|)`` over finite elements."""
    return np.abs(a - b).max()


@distance
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def max_rel_diff(x: np.ndarray, x_ref: np.ndarray) -> float:
    """Return ``max(|x - x_ref| / |x_ref|)`` over finite elements.

    Elements where ``x == x_ref`` contribute 0, so matching zeros do not
    produce 0/0 = NaN. A nonzero difference against a zero reference yields
    ``inf``.
    """
    abs_diff = np.abs(x - x_ref)
    rel = np.zeros_like(abs_diff)
    differs = abs_diff != 0
    # A nonzero diff over a zero reference is intentionally inf; silence the
    # divide-by-zero warning for that case.
    with np.errstate(divide="ignore"):
        np.divide(abs_diff, np.abs(x_ref), out=rel, where=differs)
    return rel.max()


# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def to_np_arr_fp64(x: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Convert a tensor or array to a new float64 ``np.ndarray``.

    Torch tensors are detached, moved to CPU and cast to float64 *before*
    calling ``Tensor.numpy()``. NumPy has no bfloat16 dtype, so converting
    first would fail for bfloat16 tensors. ``copy=True`` ensures the result
    never aliases the caller's tensor storage. NumPy arrays are cast with
    ``astype``, which copies. Any other input is returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().to(torch.float64, copy=True).numpy()
    if isinstance(x, np.ndarray):
        return x.astype(np.float64)
    return x


def normalized_rms(
    # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
    predicted: Union[np.ndarray, torch.Tensor],
    # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
    ground_truth: Union[np.ndarray, torch.Tensor],
) -> float:
    """Return ``rms(predicted, ground_truth)`` divided by ``||ground_truth||_2``.

    The denominator is the L2 norm of ``ground_truth`` (not its RMS). It is
    taken over all elements, including any non-finite ones. Returns 0.0 when
    the RMS is exactly 0, and NaN when ``rms`` returns NaN (e.g. shape
    mismatch).
    """
    num = rms(predicted, ground_truth)
    if num == 0:
        return 0.0
    den = np.linalg.norm(to_np_arr_fp64(ground_truth))
    return np.float64(num) / np.float64(den)