/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// NOTE(review): in this view of the file the header names of the #include
// directives below are absent, as are what look like template argument lists
// (after `template`, `TensorFactory`, `std::vector`, and on the `test_*`
// calls inside the macros). They were presumably lost in extraction. The
// code tokens are left exactly as found; restore the missing pieces from the
// original file before building.
#include // Declares the operator
#include
#include
#include
#include
#include
#include

using namespace ::testing;
using exec_aten::Scalar;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
using torch::executor::testing::TensorFactory;

// Fixture for the `lt` out-variant that compares a tensor against a Scalar.
class OpLtScalarOutTest : public OperatorTest {
 protected:
  // Thin wrapper: forwards its arguments, together with `context_`
  // (presumably inherited from OperatorTest), to
  // torch::executor::aten::lt_outf and returns that call's result.
  Tensor& op_lt_scalar_out(const Tensor& self, Scalar& other, Tensor& out) {
    return torch::executor::aten::lt_outf(context_, self, other, out);
  }

  // Shared body for the dtype sweep: compares the 2x2 input {3, 1, 2, 4}
  // against the scalar 2 and expects only the second element (1) to be
  // reported as less-than.
  // NOTE(review): `tf` and `tf_out` are presumably parameterized on the
  // input and output dtypes respectively — the template arguments are not
  // visible here.
  template
  void test_lt_scalar_out() {
    TensorFactory tf;
    TensorFactory tf_out;

    const std::vector sizes = {2, 2};

    // `out` starts filled with ones, so the three `false` results in the
    // expectation below have to be written by the call under test.
    Tensor out = tf_out.ones(sizes);
    Scalar other = 2;

    // Valid input should give the expected output
    op_lt_scalar_out(tf.make(sizes, /*data=*/{3, 1, 2, 4}), other, out);
    EXPECT_TENSOR_EQ(
        out, tf_out.make(sizes, /*data=*/{false, true, false, false}));
  }
};

// Fixture for the `lt` out-variant that compares two tensors element-wise.
class OpLtTensorOutTest : public OperatorTest {
 protected:
  // Thin wrapper: forwards its arguments, together with `context_`
  // (presumably inherited from OperatorTest), to
  // torch::executor::aten::lt_outf and returns that call's result.
  Tensor& op_lt_tensor_out(const Tensor& self, const Tensor& other, Tensor& out) {
    return torch::executor::aten::lt_outf(context_, self, other, out);
  }

  // Shared body for the dtype sweep. The element pairs (2,1), (3,4), (2,2),
  // (4,3) cover a > b, a < b, a == b and a > b, so only the second element
  // is expected to be true.
  template
  void test_dtype() {
    TensorFactory tf_input;
    TensorFactory tf_out;
    Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 3, 2, 4});
    Tensor b = tf_input.make({2, 2}, {1, 4, 2, 3});
    Tensor out = tf_out.zeros({2, 2});

    op_lt_tensor_out(a, b, out);
    EXPECT_TENSOR_EQ(out, tf_out.make({2, 2}, {false, true, false, false}));
  }
};

// Runs test_lt_scalar_out once per expansion produced by the
// ET_FORALL_REAL_TYPES / ET_FORALL_REAL_TYPES_WITH2 macros.
// NOTE(review): those macros are defined outside this file; presumably they
// enumerate every real input dtype crossed with every real output dtype,
// with the trailing call in TEST_FORALL_OUT_TYPES adding a Bool output per
// input dtype — confirm against the macro definitions. As shown, the macro
// parameters are not referenced by the `test_lt_scalar_out()` calls, which
// is consistent with the template arguments having been stripped.
TEST_F(OpLtScalarOutTest, AllRealInputBoolOutputSupport) {
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
  test_lt_scalar_out();

#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in)            \
  ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
  test_lt_scalar_out();

  ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES)

  // Undefine the helper macros so they do not leak into the tests below.
#undef TEST_FORALL_OUT_TYPES
#undef TEST_ENTRY
}
// Bool input compared against the scalar 0.5. The expected output is the
// element-wise negation of the input: `false` entries compare as less than
// 0.5 and `true` entries do not.
// NOTE(review): the TensorFactory template arguments in this and the
// following tests are not visible in this view of the file; the variable
// names (`tf_bool`, `tf_int`) suggest Bool and Int factories — confirm.
TEST_F(OpLtScalarOutTest, BoolInputDtype) {
  TensorFactory tf_bool;

  const std::vector sizes = {2, 2};
  Tensor a = tf_bool.make(sizes, /*data=*/{false, true, false, true});
  Tensor out = tf_bool.zeros(sizes);
  Scalar other = 0.5;

  op_lt_scalar_out(a, other, out);
  EXPECT_TENSOR_EQ(
      out, tf_bool.make(sizes, /*data=*/{true, false, true, false}));
}

// Mismatched shape tests.
// Input of shape {4} with an output of shape {2, 2}: the call is expected to
// be reported as a kernel failure. Skipped when the ATen kernel is in use.
TEST_F(OpLtScalarOutTest, MismatchedInOutShapesDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }
  TensorFactory tf_int;
  TensorFactory tf_bool;

  Tensor a = tf_int.ones(/*sizes=*/{4});
  Tensor out = tf_bool.ones(/*sizes=*/{2, 2});
  Scalar other = 3;

  ET_EXPECT_KERNEL_FAILURE(context_, op_lt_scalar_out(a, other, out));
}

// `out` is created with sizes {4, 1} and DYNAMIC_BOUND dynamism, while the
// input and the expected tensor use sizes {2, 2}.
// NOTE(review): presumably the kernel resizes `out` to the input's shape and
// EXPECT_TENSOR_EQ checks shape as well as data — confirm against the test
// utilities.
TEST_F(OpLtScalarOutTest, DynamicOutShapeTest) {
  TensorFactory tf;

  const std::vector sizes = {2, 2};
  const std::vector out_sizes = {4, 1};

  Tensor out =
      tf.zeros(out_sizes, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  Scalar other = 2;

  // Valid input should give the expected output
  op_lt_scalar_out(tf.make(sizes, /*data=*/{3, 1, 2, 4}), other, out);
  EXPECT_TENSOR_EQ(out, tf.make(sizes, /*data=*/{false, true, false, false}));
}

// Runs test_dtype once per expansion produced by the ET_FORALL_REAL_TYPES /
// ET_FORALL_REAL_TYPES_WITH2 macros.
// NOTE(review): those macros are defined outside this file; presumably they
// enumerate input/output dtype combinations — confirm against their
// definitions. As shown, the macro parameters are not referenced by the
// `test_dtype()` calls, which is consistent with stripped template arguments.
TEST_F(OpLtTensorOutTest, AllDtypesSupported) {
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
  test_dtype();

#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in)            \
  ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
  test_dtype();

  ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES);

  // Undefine the helper macros so they do not leak into the tests below.
#undef TEST_FORALL_OUT_TYPES
#undef TEST_ENTRY
}

// The two inputs disagree in shape ({4} vs {2, 2}) while `out` matches the
// first input: the call is expected to be reported as a kernel failure.
// Skipped when the ATen kernel is in use.
TEST_F(OpLtTensorOutTest, MismatchedInShapesDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }
  TensorFactory tf_int;
  TensorFactory tf_bool;

  Tensor a = tf_int.ones(/*sizes=*/{4});
  Tensor b = tf_int.ones(/*sizes=*/{2, 2});
  Tensor out = tf_bool.ones(/*sizes=*/{4});

  ET_EXPECT_KERNEL_FAILURE(context_, op_lt_tensor_out(a, b, out));
}

// Both inputs have shape {4} but `out` has shape {2, 2}: the call is
// expected to be reported as a kernel failure. Skipped when the ATen kernel
// is in use.
TEST_F(OpLtTensorOutTest, MismatchedInOutShapesDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }
  TensorFactory tf_int;
  TensorFactory tf_bool;

  Tensor a = tf_int.ones(/*sizes=*/{4});
  Tensor b = tf_int.ones(/*sizes=*/{4});
  Tensor out = tf_bool.ones(/*sizes=*/{2, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_lt_tensor_out(a, b, out));
}

// `out` is created with sizes {1, 4} and DYNAMIC_BOUND dynamism, while both
// inputs and the expected tensor use sizes {2, 2}. The data is the same as
// in test_dtype, so only the second element is expected to be true.
// NOTE(review): presumably the kernel resizes `out` to the inputs' shape —
// confirm against the kernel and the test utilities.
TEST_F(OpLtTensorOutTest, DynamicOutShapeTest) {
  TensorFactory tf;

  Tensor a = tf.make(/*sizes=*/{2, 2}, /*data=*/{2, 3, 2, 4});
  Tensor b = tf.make({2, 2}, {1, 4, 2, 3});

  Tensor out =
      tf.zeros({1, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  op_lt_tensor_out(a, b, out);
  EXPECT_TENSOR_EQ(out, tf.make({2, 2}, {false, true, false, false}));
}