# Owner(s): ["oncall: jit"] import unittest import numpy as np import torch from torch.testing import FileCheck from torch.testing._internal.common_utils import IS_MACOS from torch.testing._internal.jit_utils import JitTestCase if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestPythonIr(JitTestCase): def test_param_strides(self): def trace_me(arg): return arg t = torch.zeros(1, 3, 16, 16) traced = torch.jit.trace(trace_me, t) value = list(traced.graph.param_node().outputs())[0] real_strides = list(t.stride()) type_strides = value.type().strides() self.assertEqual(real_strides, type_strides) def test_permute_inputs_binding(self): @torch.jit.script def foo(i, j, k): pass g = foo.graph idxs = [] for i, inp in enumerate(g.inputs()): inp.setDebugName(f"inp{i}") idxs.append(i) permuted_idxs = list(np.random.permutation(idxs)) g.permuteInputs(permuted_idxs) for i, inp in enumerate(g.inputs()): self.assertEqual(f"inp{permuted_idxs[i]}", inp.debugName()) @unittest.skipIf(IS_MACOS, "Failing on MacOS only") def test_python_ir_utils(self): @torch.jit.script def foo(inp): x = inp + 1 y = x / 2 z = y * y return z add_node = foo.graph.findNode("aten::add") div_node = foo.graph.findNode("aten::div") with foo.graph.insert_point_guard(add_node): with foo.graph.insert_point_guard(div_node): foo.graph.insertConstant("goodbye") foo.graph.insertConstant("hello") with foo.graph.insert_point_guard(foo.graph.findNode("aten::mul")): foo.graph.insertConstant("hello") FileCheck().check("hello").check("goodbye").check("hello").run(foo.graph) self.assertTrue(add_node.matches(add_node.schema())) self.assertFalse(add_node.matches(div_node.schema())) def test_python_ir_utils_graph(self): @torch.jit.script def unrolled_mul(x: torch.Tensor, y: int): out = x for _ in range(y - 1): out = out + x return out @torch.jit.script def foo(x): return x * 4 g = foo.graph muls = g.findAllNodes("aten::mul") scalar_muls = filter( lambda x: x.matches("aten::mul(Tensor self, Scalar other) -> Tensor"), muls ) mul_constant_int = filter( lambda x: isinstance(list(x.inputs())[1].toIValue(), int), scalar_muls ) for mul in mul_constant_int: with g.insert_point_guard(mul): outputs = g.insertGraph(unrolled_mul.graph, list(mul.inputs())) assert len(outputs) == len(list(mul.outputs())) for new_out, old_out in zip(outputs, g.outputs()): old_out.replaceAllUsesWith(new_out) mul.destroy() FileCheck().check_not("aten::mul").check("aten::add").run(foo.graph) self.assertEqual(foo(torch.ones([2, 2])), torch.ones([2, 2]) * 4)