1# mypy: ignore-errors 2 3# Torch 4import torch 5import torch.cuda 6import torch.jit 7import torch.jit._logging 8import torch.jit.frontend 9import torch.jit.quantized 10 11# Testing utils 12from torch.testing._internal.common_dtype import floating_and_complex_types_and 13from torch.testing._internal.common_utils import TestCase, \ 14 freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors 15from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401 16 17# Standard library 18from itertools import chain 19from typing import List, Union 20from torch._C import TensorType 21 22import io 23 24def check_output_types(self, func, ref_outputs, args, kwargs): 25 graph = getattr(func, 'last_graph', None) 26 types = [o.type() for o in graph.outputs()] 27 self.assertTrue(len(types) == 1) 28 t = types[0] 29 torch._C._jit_assert_is_instance(ref_outputs, t) 30 31# Test names in this set are only checked for a single derivative 32nn_functional_single_grad = frozenset('test_nn_' + name for name in [ 33 'pdist', 34 'multilabel_margin_loss', 35 'max_unpool3d', 36 'multi_margin_loss', 37 'binary_cross_entropy', 38 'binary_cross_entropy_size_average', 39 'ctc_loss', 40 'grid_sample', 41]) 42 43def check_against_reference(self, func, reference_func, output_func, args, kwargs=None, 44 allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False): 45 """Verifies a function performs identically to some reference implementation. 46 47 Commonly, this is used to verify that a JIT implementation 48 (output_func) matches the behavior of the eager implementation 49 (reference_func). 50 """ 51 kwargs = kwargs if kwargs else {} 52 53 def allSum(vs): 54 if isinstance(vs, torch.Tensor): 55 vs = (vs,) 56 return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum() 57 for i, v in enumerate(vs) 58 if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16)) 59 60 def clone_tensor(t, preserve_requires_grad): 61 require_grad = preserve_requires_grad and t.requires_grad 62 return t.detach().clone().requires_grad_(require_grad) 63 64 def clone_inputs(preserve_requires_grad: bool): 65 inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = [] 66 67 for arg in args: 68 if isinstance(arg, torch.Tensor): 69 inputs.append(clone_tensor(arg, preserve_requires_grad)) 70 elif is_iterable_of_tensors(arg): 71 inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg]) 72 else: 73 inputs.append(arg) 74 75 return inputs 76 77 # Returns tensors in args that requires_grad, including tensors in TensorList args 78 def get_recording_tensors(args): 79 recording_tensors: List[torch.Tensor] = [] 80 81 for arg in args: 82 if isinstance(arg, torch.Tensor) and arg.requires_grad: 83 recording_tensors.append(arg) 84 elif is_iterable_of_tensors(arg): 85 recording_tensors.extend(filter(lambda t: t.requires_grad, arg)) 86 87 return recording_tensors 88 89 # test no gradients case 90 nograd_inputs = clone_inputs(preserve_requires_grad=False) 91 outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs) 92 with enable_profiling_mode_for_profiling_tests(): 93 outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs) 94 self.assertEqual(outputs, outputs_test) 95 96 if check_types: 97 check_output_types(self, func, outputs_test, nograd_inputs, kwargs) 98 99 if no_grad: 100 # skip grad tests 101 return 102 103 with enable_profiling_mode_for_profiling_tests(): 104 # test single grad case 105 recording_inputs = clone_inputs(preserve_requires_grad=True) 106 recording_tensors = get_recording_tensors(recording_inputs) 107 outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs)) 108 grads = torch.autograd.grad(allSum(outputs), recording_tensors, 109 allow_unused=allow_unused) 110 outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs)) 111 grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors, 112 allow_unused=allow_unused) 113 self.assertEqual(outputs, outputs_test) 114 self.assertEqual(grads, grads_test) 115 # test the grad grad case 116 if self._testMethodName in nn_functional_single_grad or no_gradgrad: 117 return 118 119 outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs)) 120 l1 = allSum(outputs) 121 grads = torch.autograd.grad(l1, recording_tensors, create_graph=True, 122 allow_unused=allow_unused) 123 124 l2 = (allSum(grads) * l1) 125 grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused) 126 recording_inputs = clone_inputs(preserve_requires_grad=True) 127 recording_tensors = get_recording_tensors(recording_inputs) 128 outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs)) 129 l1_test = allSum(outputs_test) 130 grads_test = torch.autograd.grad( 131 l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused) 132 133 l2_test = (allSum(grads_test) * l1_test) 134 grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused) 135 136 self.assertEqual(outputs, outputs_test) 137 self.assertEqual(grads, grads_test) 138 for g2, g2_test in zip(grads2, grads2_test): 139 if g2 is None and g2_test is None: 140 continue 141 self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4) 142 143class JitCommonTestCase(TestCase): 144 def createFunctionFromGraph(self, trace): 145 graph = trace if isinstance(trace, torch._C.Graph) else trace.graph() 146 return torch._C._create_function_from_graph("forward", graph) 147 148 def assertExportImport(self, trace, inputs): 149 m = self.createFunctionFromGraph(trace) 150 self.assertExportImportModule(m, inputs) 151 152 def assertExportImportModule(self, m, inputs): 153 m_import = self.getExportImportCopy(m) 154 a = self.runAndSaveRNG(m, inputs) 155 b = self.runAndSaveRNG(m_import, inputs) 156 self.assertEqual(a, b, "Results of original model and " 157 "exported/imported version of model differed") 158 159 def runAndSaveRNG(self, func, inputs, kwargs=None): 160 kwargs = kwargs if kwargs else {} 161 with freeze_rng_state(): 162 results = func(*inputs, **kwargs) 163 return results 164 165 def getExportImportCopy(self, m, also_test_file=True, map_location=None): 166 buffer = io.BytesIO() 167 torch.jit.save(m, buffer) 168 buffer.seek(0) 169 imported = torch.jit.load(buffer, map_location=map_location) 170 171 if not also_test_file: 172 return imported 173 174 with TemporaryFileName() as fname: 175 torch.jit.save(imported, fname) 176 return torch.jit.load(fname, map_location=map_location) 177 178 def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph, 179 fusion_nodes_not_found, non_fusible_nodes_being_fused, 180 fusion_nodes_found, nodes_in_diff_graph): 181 err_msg = "\nFailure in testing nodes' autodifferentiation. " 182 if should_autodiff_node: 183 err_msg += "One or more nodes were expected to be autodiffed, " \ 184 "but were not found in specified fusible/nonfusible " \ 185 "DifferentiableGraph groups. \nSpecifically:" 186 # The node is intended to appear in a differentiable graph but doesn't 187 diff_nodes_missing = [] 188 # The node is intended to appear in a differentiable graph 189 # outside of a fusion group but instead is in a fusion group 190 diff_nodes_in_fusion = [] 191 # The node is intended to appear in a fusion group but doesn't 192 fusion_nodes_missing = [] 193 # The node is intended to appear in a fusion group but instead 194 # is just in an outer differentiable graph 195 fusion_nodes_in_diff = [] 196 for node in nodes_not_in_diff_graph: 197 if node in non_fusible_nodes_being_fused: 198 diff_nodes_in_fusion.append(node) 199 else: 200 diff_nodes_missing.append(node) 201 for node in fusion_nodes_not_found: 202 if node in nodes_in_diff_graph: 203 fusion_nodes_in_diff.append(node) 204 else: 205 fusion_nodes_missing.append(node) 206 if len(diff_nodes_missing) > 0: 207 err_msg += f"\n {diff_nodes_missing} were not in one of the " \ 208 "DifferentiableGraphs when they were expected to be. " \ 209 "Did you intend for these nodes to be autodiffed? " \ 210 "If not, remove them from the list of nonfusible nodes." 211 if len(diff_nodes_in_fusion) > 0: 212 err_msg += f"\n {diff_nodes_in_fusion} were found in one of the FusionGroups " \ 213 "when they were expected to be just in a DifferentiableGraph. If it was " \ 214 "intended for these nodes to be in FusionGroups, reclassify these nodes as " \ 215 "fusible nodes. If these nodes were not intended to be fused, your " \ 216 "autodifferentiation logic might be wrong." 217 if len(fusion_nodes_missing) > 0: 218 err_msg += f"\n {fusion_nodes_missing} were not in one of the FusionGroups " \ 219 "of the DifferentiableGraphs when they were expected to be. " \ 220 "They were also not found in an outer DifferentiableGraph. Did you " \ 221 "intend for these nodes to be autodifferentiated? If not, you should " \ 222 "remove these nodes from the test's fusible nodes. Otherwise your " \ 223 "autodifferentiation logic might be wrong." 224 if len(fusion_nodes_in_diff) > 0: 225 err_msg += f"\n {fusion_nodes_in_diff} were not in one of the FusionGroups " \ 226 "of the DifferentiableGraphs when they were expected to be, " \ 227 "instead they were found just in an outer DifferentiableGraph. " \ 228 "Did you intend for these nodes to be fused? If not, you should " \ 229 "move these nodes into the test's nonfusible nodes. Otherwise your " \ 230 "autodifferentiation logic might be wrong." 231 else: 232 err_msg += "One or more nodes were not expected to be autodiffed " \ 233 "but were found in a DifferentiableGraph or in a FusionGroup " \ 234 "of a DifferentiableGraph. Did you intend for these nodes to be " \ 235 "autodiffed? If so, change this test to expect autodifferentiation. " \ 236 "\nSpecifically:" 237 if len(fusion_nodes_found) > 0: 238 err_msg += f"\n {fusion_nodes_found} were not expected to be in " \ 239 "one of the DifferentiableGraphs, but appeared in a FusionGroup " \ 240 "of a DifferentiableGraph. " 241 if len(nodes_in_diff_graph) > 0: 242 err_msg += f"\n {nodes_in_diff_graph} were not expected to " \ 243 "be in one of the DifferentiableGraphs but were." 244 return err_msg 245 246 def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes): 247 diff_nodes = graph.findAllNodes('prim::DifferentiableGraph') 248 diff_subgraphs = [node.g('Subgraph') for node in diff_nodes] 249 250 # Note: currently no tests have fusible_nodes 251 fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs])) 252 fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes] 253 254 # For any non-fusible node, it must show up in one of the DifferentiableGraphs. 255 nodes_in_diff_graph = [] 256 nodes_not_in_diff_graph = [] 257 non_fusible_nodes_being_fused = [] 258 for node in nonfusible_nodes: 259 if any(g.findNode(node) is not None for g in diff_subgraphs): 260 nodes_in_diff_graph.append(node) 261 else: 262 nodes_not_in_diff_graph.append(node) 263 if any(g.findNode(node) is not None for g in fusion_subgraphs): 264 non_fusible_nodes_being_fused.append(node) 265 found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes) 266 267 # For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs. 268 fusion_nodes_found = [] 269 fusion_nodes_not_found = [] 270 for node in fusible_nodes: 271 if any(g.findNode(node) is not None for g in fusion_subgraphs): 272 fusion_nodes_found.append(node) 273 else: 274 fusion_nodes_not_found.append(node) 275 found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes) 276 277 if should_autodiff_node is not None: 278 err_msg = self.autoDiffErrorMessage(should_autodiff_node, 279 nodes_not_in_diff_graph, 280 fusion_nodes_not_found, 281 non_fusible_nodes_being_fused, 282 fusion_nodes_found, 283 nodes_in_diff_graph) 284 self.assertEqual(should_autodiff_node, 285 found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg) 286 287 def checkShapeAnalysis(self, out_sizes: Union[List[int], List[List[int]]], 288 traced_graph, assert_propagation, constant_prop=True): 289 # repropagte input shapes provided by tracing, 290 prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled() 291 for enable_test_mode in [True, False]: 292 # here we are testing allowing/disallowing substituting in complete shapes as constants, 293 # disallowing constants helps stress test partial eval and substitution pipeline 294 torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode) 295 torch._C._jit_erase_non_input_shape_information(traced_graph) 296 if constant_prop: 297 torch._C._jit_pass_constant_propagation(traced_graph) 298 torch._C._jit_pass_propagate_shapes_on_graph(traced_graph) 299 # Add sizes to default tensor type to avoid checking something out of scope 300 # and difficulties with tracer leaving in other parts of tensor type 301 output = next(traced_graph.outputs()).type() 302 303 def test_type(type, actual_size): 304 sizes = type.symbolic_sizes() 305 out_type = TensorType.get().with_sizes(sizes) 306 actual_type = TensorType.get().with_sizes(actual_size) 307 308 # always check actual shape is a subtype of the output 309 self.assertTrue(actual_type.isSubtypeOf(out_type)) 310 311 # and then if assertion flag is provided, check shape analysis 312 # is successful 313 if assert_propagation: 314 self.assertEqual(out_type.sizes(), actual_size) 315 316 if output.isSubtypeOf(torch._C.TensorType.get()): 317 test_type(output, out_sizes) 318 else: 319 tuple_elements = output.elements() 320 for i in range(len(tuple_elements)): 321 test_type(tuple_elements[i], out_sizes[i]) 322 323 torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled) 324