#include #include #include #include namespace torch { namespace lazy { template void test_hash_repeatable_sensitive(const T& example_a, const T& example_b) { // repeatable EXPECT_EQ(Hash(example_a), Hash(example_a)); EXPECT_EQ(MHash(example_a), MHash(example_a)); EXPECT_EQ(MHash(example_a, example_a), MHash(example_a, example_a)); // sensitive EXPECT_NE(Hash(example_a), Hash(example_b)); EXPECT_NE(MHash(example_a), MHash(example_b)); EXPECT_NE(MHash(example_a, example_a), MHash(example_a, example_b)); } TEST(HashTest, Scalar) { GTEST_SKIP() << "Broken test. See https://github.com/pytorch/pytorch/issues/99883"; c10::Scalar a(0); c10::Scalar b(0); // simulate some garbage in the unused bits of the // the tagged union that is c10::Scalar, which is bigger // than the size of the int64_t we're currently using it with *((uint8_t*)&b) = 1; // actual 'value' of the Scalar as a 64 bit int shouldn't have changed EXPECT_EQ(a.toLong(), b.toLong()); // and hash should ignore this garbage EXPECT_EQ(Hash(a), Hash(b)); EXPECT_EQ(MHash(a), MHash(b)); EXPECT_EQ(MHash(a, a), MHash(a, b)); } TEST(HashTest, Sanity) { // String test_hash_repeatable_sensitive( std::string( "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Ut at suscipit purus."), std::string( "Lorem Jpsum dolor sit amet, consectetur adipiscing elit. Ut at suscipit purus.")); // Number types test_hash_repeatable_sensitive(true, false); test_hash_repeatable_sensitive((int8_t)0xfa, (int8_t)0xfb); test_hash_repeatable_sensitive((int16_t)0xface, (int16_t)0xfade); test_hash_repeatable_sensitive((int32_t)0xfaceb000, (int32_t)0xfadeb000); test_hash_repeatable_sensitive((int64_t)0x1faceb000, (int64_t)0x1fadeb000); test_hash_repeatable_sensitive((uint8_t)0xfa, (uint8_t)0xfb); test_hash_repeatable_sensitive((uint16_t)0xface, (uint16_t)0xfade); test_hash_repeatable_sensitive((uint32_t)0xfaceb000, (uint32_t)0xfadeb000); test_hash_repeatable_sensitive((uint64_t)0x1faceb000, (uint64_t)0x1fadeb000); // c10 types test_hash_repeatable_sensitive(c10::ScalarType::Bool, c10::ScalarType::Byte); test_hash_repeatable_sensitive(c10::Scalar(1.334), c10::Scalar(1.335)); test_hash_repeatable_sensitive(c10::Scalar(true), c10::Scalar(false)); test_hash_repeatable_sensitive(c10::Scalar(12345), c10::Scalar(12354)); // std::optional test_hash_repeatable_sensitive( std::optional("I have value!"), std::optional(std::nullopt)); // Containers auto a = std::vector({0, 1, 1, 2, 3, 5, 8}); auto b = std::vector({1, 1, 2, 3, 5, 8, 12}); test_hash_repeatable_sensitive(a, b); test_hash_repeatable_sensitive( c10::ArrayRef(a), c10::ArrayRef(b)); // vector is a special case bc it is implemented as vector auto bool_a = std::vector({true, false, false, true}); auto bool_b = std::vector({true, true, false, true}); test_hash_repeatable_sensitive(bool_a, bool_b); } } // namespace lazy } // namespace torch