1#!/usr/bin/env python3 2# Owner(s): ["module: internals"] 3 4import torch 5from torch.testing._internal.common_utils import run_tests, TestCase 6 7 8class TestComparisonUtils(TestCase): 9 def test_all_equal_no_assert(self): 10 t = torch.tensor([0.5]) 11 torch._assert_tensor_metadata(t, [1], [1], torch.float) 12 13 def test_all_equal_no_assert_nones(self): 14 t = torch.tensor([0.5]) 15 torch._assert_tensor_metadata(t, None, None, None) 16 17 def test_assert_dtype(self): 18 t = torch.tensor([0.5]) 19 20 with self.assertRaises(RuntimeError): 21 torch._assert_tensor_metadata(t, None, None, torch.int32) 22 23 def test_assert_strides(self): 24 t = torch.tensor([0.5]) 25 26 with self.assertRaises(RuntimeError): 27 torch._assert_tensor_metadata(t, None, [3], torch.float) 28 29 def test_assert_sizes(self): 30 t = torch.tensor([0.5]) 31 32 with self.assertRaises(RuntimeError): 33 torch._assert_tensor_metadata(t, [3], [1], torch.float) 34 35 36if __name__ == "__main__": 37 run_tests() 38