import os
import shutil
import unittest
import warnings
from collections import namedtuple

import torch
import torch.testing._internal.common_nn as common_nn
import torch.utils.cpp_extension
from torch.testing._internal.common_cuda import TEST_CUDA


# Note that this namedtuple is for the C++ parity test mechanism's internal use.
# For guidance on how to add a new C++ parity test, please see
# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
TorchNNModuleTestParams = namedtuple(
    "TorchNNModuleTestParams",
    [
        # NN module name (e.g. "BCELoss")
        "module_name",
        # Unique identifier for this module config (e.g. "BCELoss_weights_cuda")
        "module_variant_name",
        # An instance of an NN test class (e.g. `CriterionTest`) which stores
        # necessary information (e.g. input / target / extra_args) for running the Python test
        "test_instance",
        # Constructor arguments passed to the C++ module constructor, which must be
        # strictly equivalent to the Python module constructor arguments
        # (e.g. `torch::nn::BCELossOptions().weight(torch::rand(10))`,
        # which is strictly equivalent to passing `torch.rand(10)` to the `torch.nn.BCELoss`
        # constructor in Python)
        "cpp_constructor_args",
        # All arguments used in the NN module's forward pass.
        # Please see the `compute_arg_dict` function for details on how we construct this dict.
        # (e.g.
        # ```
        # arg_dict = {
        #     'input': [python_input_tensor],
        #     'target': [python_target_tensor],
        #     'extra_args': [],
        #     'other': [],
        # }
        # ```
        # )
        "arg_dict",
        # Whether we expect this NN module test to pass the Python/C++ parity test
        # (e.g. `True`)
        "has_parity",
        # Device (e.g. "cuda")
        "device",
        # Temporary folder to store C++ outputs (to be compared with Python outputs later)
        "cpp_tmp_folder",
    ],
)

# Note that this namedtuple is for the C++ parity test mechanism's internal use.
# For guidance on how to add a new C++ parity test, please see
# NOTE [How to check NN module / functional API parity between Python and C++ frontends]
TorchNNFunctionalTestParams = namedtuple(
    "TorchNNFunctionalTestParams",
    [
        # NN functional name (e.g. "binary_cross_entropy")
        "functional_name",
        # Unique identifier for this functional config (e.g. "BCELoss_no_reduce_cuda")
        "functional_variant_name",
        # An instance of an NN test class (e.g. `NewModuleTest`) which stores
        # necessary information (e.g. input / target / extra_args) for running the Python test
        "test_instance",
        # The C++ function call that is strictly equivalent to the Python function call
        # (e.g. "F::binary_cross_entropy(
        #            i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))",
        # which is strictly equivalent to `F.binary_cross_entropy(i, t.type_as(i), reduction='none')` in Python)
        "cpp_function_call",
        # All arguments used in the NN functional's function call.
        # Please see the `compute_arg_dict` function for details on how we construct this dict.
        # (e.g.
        # ```
        # arg_dict = {
        #     'input': [python_input_tensor],
        #     'target': [python_target_tensor],
        #     'extra_args': [],
        #     'other': [],
        # }
        # ```
        # )
        "arg_dict",
        # Whether we expect this NN functional test to pass the Python/C++ parity test
        # (e.g. `True`)
        "has_parity",
        # Device (e.g. "cuda")
        "device",
        # Temporary folder to store C++ outputs (to be compared with Python outputs later)
        "cpp_tmp_folder",
    ],
)

CppArg = namedtuple("CppArg", ["name", "value"])

TORCH_NN_COMMON_TEST_HARNESS = """
#include <torch/script.h>

void write_ivalue_to_file(const torch::IValue& ivalue, const std::string& file_path) {
  auto bytes = torch::jit::pickle_save(ivalue);
  std::ofstream fout(file_path, std::ios::out | std::ios::binary);
  fout.write(bytes.data(), bytes.size());
  fout.close();
}

c10::Dict<std::string, torch::Tensor> load_dict_from_file(const std::string& file_path) {
  c10::Dict<std::string, torch::Tensor> arg_dict;
  auto arg_dict_module = torch::jit::load(file_path);
  for (const auto& p : arg_dict_module.named_buffers(/*recurse=*/false)) {
    arg_dict.insert(p.name, p.value);
  }
  return arg_dict;
}

// Generates a rand tensor with non-equal values. This ensures that duplicate
// values won't cause test failures for modules like MaxPooling.
// size should be small, otherwise randperm fails / long overflows.
torch::Tensor _rand_tensor_non_equal(torch::IntArrayRef size) {
  int64_t total = 1;
  for (int64_t elem : size) {
    total *= elem;
  }
  return torch::randperm(total).view(size).to(torch::kDouble);
}
"""


def compile_cpp_code_inline(name, cpp_sources, functions):
    cpp_module = torch.utils.cpp_extension.load_inline(
        name=name,
        cpp_sources=cpp_sources,
        # Enable debug symbols by default for debugging test failures.
        extra_cflags=["-g"],
        functions=functions,
        verbose=False,
    )
    return cpp_module


def compute_temp_file_path(cpp_tmp_folder, variant_name, file_suffix):
    return os.path.join(cpp_tmp_folder, f"{variant_name}_{file_suffix}.pt")


def is_torch_nn_functional_test(test_params_dict):
    return "wrap_functional" in str(test_params_dict.get("constructor", ""))


def convert_to_list(python_input):
    if isinstance(python_input, torch.Tensor):
        return [python_input]
    else:
        return list(python_input)


def set_python_tensors_requires_grad(python_tensors):
    return [
        tensor.requires_grad_(True) if tensor.dtype != torch.long else tensor
        for tensor in python_tensors
    ]


def move_python_tensors_to_device(python_tensors, device):
    return [tensor.to(device) for tensor in python_tensors]


def has_test(unit_test_class, test_name):
    return hasattr(unit_test_class, test_name)


def add_test(unit_test_class, test_name, test_fn):
    if has_test(unit_test_class, test_name):
        raise RuntimeError("Found two tests with the same name: " + test_name)
    setattr(unit_test_class, test_name, test_fn)


def set_cpp_tensors_requires_grad(cpp_tensor_stmts, python_tensors):
    assert len(cpp_tensor_stmts) == len(python_tensors)
    return [
        f"{tensor_stmt}.requires_grad_(true)"
        if tensor.dtype != torch.long
        else tensor_stmt
        for tensor_stmt, (_, tensor) in zip(cpp_tensor_stmts, python_tensors)
    ]


def move_cpp_tensors_to_device(cpp_tensor_stmts, device):
    return [f'{tensor_stmt}.to("{device}")' for tensor_stmt in cpp_tensor_stmts]


def is_criterion_test(test_instance):
    return isinstance(test_instance, common_nn.CriterionTest)


# This function computes the following:
# - What variable declaration statements should show up in the C++ parity test function
# - What arguments should be passed into the C++ module/functional's forward function
#
# For example, for the "L1Loss" test, the return values from this function are:
# ```
# // Note that `arg_dict` stores all tensor values we transfer from Python to C++
# cpp_args_construction_stmts = [
#   "auto i0 = arg_dict.at("i0").to("cpu").requires_grad_(true)",
#   "auto t0 = arg_dict.at("t0").to("cpu")",
# ],
# cpp_forward_args_symbols = [
#   "i0",
#   "t0",
# ]
# ```
def compute_cpp_args_construction_stmts_and_forward_arg_symbols(test_params):
    device = test_params.device
    cpp_forward_args_symbols = []

    def add_cpp_forward_args(args):
        args_stmts = []
        for arg_name, _ in args:
            args_stmts.append(f'auto {arg_name} = arg_dict.at("{arg_name}")')
            cpp_forward_args_symbols.append(arg_name)
        return args_stmts

    cpp_forward_input_args_stmts = set_cpp_tensors_requires_grad(
        move_cpp_tensors_to_device(
            add_cpp_forward_args(test_params.arg_dict["input"]), device
        ),
        test_params.arg_dict["input"],
    )
    cpp_forward_target_args_stmts = move_cpp_tensors_to_device(
        add_cpp_forward_args(test_params.arg_dict["target"]), device
    )
    cpp_forward_extra_args_stmts = move_cpp_tensors_to_device(
        add_cpp_forward_args(test_params.arg_dict["extra_args"]), device
    )

    # Build the list of other arguments needed
    cpp_other_args_stmts = []
    for arg_name, _ in test_params.arg_dict["other"]:
        cpp_other_args_stmts.append(f'auto {arg_name} = arg_dict.at("{arg_name}")')
    cpp_other_args_stmts = move_cpp_tensors_to_device(cpp_other_args_stmts, device)

    cpp_args_construction_stmts = (
        cpp_forward_input_args_stmts
        + cpp_forward_target_args_stmts
        + cpp_forward_extra_args_stmts
        + cpp_other_args_stmts
    )

    return cpp_args_construction_stmts, cpp_forward_args_symbols


def serialize_arg_dict_as_script_module(arg_dict):
    arg_dict_flat = dict(
        arg_dict["input"]
        + arg_dict["target"]
        + arg_dict["extra_args"]
        + arg_dict["other"]
    )
    arg_dict_module = torch.nn.Module()
    for arg_name, arg_value in arg_dict_flat.items():
        assert isinstance(arg_value, torch.Tensor)
        arg_dict_module.register_buffer(arg_name, arg_value)

    return torch.jit.script(arg_dict_module)
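

# Illustrative sketch (assumed usage, not part of this file's API): the scripted
# module returned by `serialize_arg_dict_as_script_module` is typically saved to
# the temp path from `compute_temp_file_path`, so that the compiled C++ test can
# read it back with the `load_dict_from_file` helper in `TORCH_NN_COMMON_TEST_HARNESS`.
# The variable names and the "arg_dict" suffix below are hypothetical placeholders.
# ```
# script_module = serialize_arg_dict_as_script_module(test_params.arg_dict)
# arg_dict_file_path = compute_temp_file_path(
#     test_params.cpp_tmp_folder, test_params.module_variant_name, "arg_dict"
# )
# script_module.save(arg_dict_file_path)
# # The C++ side then calls `load_dict_from_file(arg_dict_file_path)`.
# ```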


# NOTE: any argument symbol used in `cpp_constructor_args` / `cpp_options_args` / `cpp_function_call`
# must have a mapping in `cpp_var_map`.
#
# The mapping can take one of the following formats:
#
# 1. `argument_name` -> Python value
# 2. `argument_name` -> '_get_input()' (which means `argument_name` in C++ will be bound to
#    `test_instance._get_input()`)
#
# For example:
# ```
# def bceloss_weights_no_reduce_test():
#     t = torch.randn(15, 10).gt(0).double()
#     weights = torch.rand(10)
#     return dict(
#         fullname='BCELoss_weights_no_reduce',
#         constructor=wrap_functional(
#             lambda i: F.binary_cross_entropy(i, t.type_as(i),
#                                              weight=weights.type_as(i), reduction='none')),
#         cpp_function_call='''F::binary_cross_entropy(
#             i, t.to(i.options()),
#             F::BinaryCrossEntropyFuncOptions()
#                 .weight(weights.to(i.options()))
#                 .reduction(torch::kNone))''',
#         input_fn=lambda: torch.rand(15, 10).clamp_(2.8e-2, 1 - 2.8e-2),
#         cpp_var_map={'i': '_get_input()', 't': t, 'weights': weights},
#         reference_fn=lambda i, p, m: -(t * i.log() + (1 - t) * (1 - i).log()) * weights,
#     )
# ```
def compute_arg_dict(test_params_dict, test_instance):
    arg_dict = {
        "input": [],
        "target": [],
        "extra_args": [],
        "other": [],
    }

    def put_args_into_arg_dict(arg_type, arg_type_prefix, args):
        for i, arg in enumerate(args):
            arg_dict[arg_type].append(CppArg(name=arg_type_prefix + str(i), value=arg))

    put_args_into_arg_dict("input", "i", convert_to_list(test_instance._get_input()))
    if is_criterion_test(test_instance):
        put_args_into_arg_dict(
            "target", "t", convert_to_list(test_instance._get_target())
        )
    if test_instance.extra_args:
        put_args_into_arg_dict(
            "extra_args", "e", convert_to_list(test_instance.extra_args)
        )

    cpp_var_map = test_params_dict.get("cpp_var_map", {})
    for arg_name, arg_value in cpp_var_map.items():
        if isinstance(arg_value, str):
            if arg_value == "_get_input()":
                arg_dict["other"].append(
                    CppArg(name=arg_name, value=test_instance._get_input())
                )
            else:
                raise RuntimeError(
                    f"`{arg_name}` has unsupported string value: {arg_value}"
                )
        elif isinstance(arg_value, torch.Tensor):
            arg_dict["other"].append(CppArg(name=arg_name, value=arg_value))
        else:
            raise RuntimeError(f"`{arg_name}` has unsupported value: {arg_value}")

    return arg_dict
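

# For the `BCELoss_weights_no_reduce` example above, `compute_arg_dict` would
# produce roughly the following (a sketch; tensor values elided). Since that test
# is a functional (non-criterion) test, 'target' stays empty and the `cpp_var_map`
# entries land in 'other':
# ```
# arg_dict = {
#     'input': [CppArg(name='i0', value=<tensor from input_fn>)],
#     'target': [],
#     'extra_args': [],
#     'other': [
#         CppArg(name='i', value=<tensor from test_instance._get_input()>),
#         CppArg(name='t', value=<tensor t>),
#         CppArg(name='weights', value=<tensor weights>),
#     ],
# }
# ```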


def decorate_test_fn(test_fn, test_cuda, has_impl_parity, device):
    if device == "cuda":
        test_fn = unittest.skipIf(not TEST_CUDA, "CUDA unavailable")(test_fn)
        test_fn = unittest.skipIf(not test_cuda, "Excluded from CUDA tests")(test_fn)

    # If the `Implementation Parity` entry in the parity table for this module is `No`,
    # or the `has_parity` entry in the test params dict is `False`, we mark the test as
    # an expected failure.
    if not has_impl_parity:
        test_fn = unittest.expectedFailure(test_fn)

    return test_fn


MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE = """
What should I do when the C++ API parity test is failing?

- If you are changing the implementation of an existing `torch.nn` module / `torch.nn.functional` function:
Answer: Ideally you should also change the C++ API implementation for that module / function
(you can start by searching for the module / function name in the `torch/csrc/api/` folder).

- If you are adding a new test for an existing `torch.nn` module / `torch.nn.functional` function:
Answer: Ideally you should fix the C++ API implementation for that module / function
to exactly match the Python API implementation (you can start by searching for the module /
function name in the `torch/csrc/api/` folder).

- If you are adding a test for a *new* `torch.nn` module / `torch.nn.functional` function:
Answer: Ideally you should add the corresponding C++ API implementation for that module / function,
and it should exactly match the Python API implementation. (We have done a large effort on this
which is tracked at https://github.com/pytorch/pytorch/issues/25883.)

However, if any of the above proves to be too complicated, you can just add
`test_cpp_api_parity=False` to any failing test in `torch/testing/_internal/common_nn.py`,
and the C++ API parity test will be skipped accordingly. Note that you should
also file an issue when you do this.

For more details on how to add a C++ API parity test, please see:
NOTE [How to check NN module / functional API parity between Python and C++ frontends]
"""


def generate_error_msg(name, cpp_value, python_value):
    return (
        f"Parity test failed: {name} in C++ has value: {cpp_value}, "
        f"which does not match the corresponding value in Python: {python_value}.\n"
        f"{MESSAGE_HOW_TO_FIX_CPP_PARITY_TEST_FAILURE}"
    )


def try_remove_folder(folder_path):
    if os.path.exists(folder_path):
        # Don't block the process if this fails, but show the error message as a warning.
        try:
            shutil.rmtree(folder_path)
        except Exception as e:
            warnings.warn(
                f"Non-blocking folder removal failed with the following error:\n{str(e)}"
            )
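

# A rough sketch of how the helpers in this file are typically combined by the
# parity test generator (assumed flow; the generator lives outside this file, and
# `unit_test_class`, `test_fn`, `variant_name`, and the generated C++ source names
# are hypothetical placeholders):
# ```
# arg_dict = compute_arg_dict(test_params_dict, test_instance)
# script_module = serialize_arg_dict_as_script_module(arg_dict)
# script_module.save(compute_temp_file_path(cpp_tmp_folder, variant_name, "arg_dict"))
# cpp_module = compile_cpp_code_inline(
#     name=variant_name,
#     cpp_sources=TORCH_NN_COMMON_TEST_HARNESS + generated_cpp_test_source,
#     functions=[generated_cpp_test_function_name],
# )
# test_fn = decorate_test_fn(test_fn, test_cuda, has_impl_parity, device)
# add_test(unit_test_class, f"test_torch_nn_{variant_name}", test_fn)
# ```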