/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/cpu/binary_ops.h>
#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/optimized/vec/vec.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

namespace torch {
namespace executor {
namespace native {

namespace {

// Returns the floating-point type the division is computed in. Integral (and
// bool) inputs are promoted to Float; complex, quantized, and bits types are
// rejected.
ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
  ET_CHECK(
      !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
  ET_CHECK(
      !isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));

  if (isFloatingType(a_type) && isFloatingType(b_type)) {
    return promoteTypes(a_type, b_type);
  } else if (isFloatingType(a_type)) {
    return a_type;
  } else if (isFloatingType(b_type)) {
    return b_type;
  }
  return ScalarType::Float;
}

} // namespace

Tensor& opt_div_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = b.scalar_type();
  ScalarType out_type = out.scalar_type();

  // Fast path: one operand is a single element and all dtypes match, so the
  // whole division can be vectorized against a broadcast scalar.
  if (a.numel() == 1 || b.numel() == 1) {
    if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
      const Tensor* tensor;
      const Tensor* scalar;
      ScalarType tensor_type;
      ScalarType scalar_type;
      if (a.numel() == 1) {
        tensor = &b;
        tensor_type = b_type;
        scalar = &a;
        scalar_type = a_type;
      } else {
        tensor = &a;
        tensor_type = a_type;
        scalar = &b;
        scalar_type = b_type;
      }
      ET_KERNEL_CHECK(
          ctx,
          resize_to_broadcast_target_size(a, b, out) == Error::Ok,
          InvalidArgument,
          out);
      ET_SWITCH_REALB_TYPES(tensor_type, ctx, "div.out", CTYPE, [&]() {
        ET_SWITCH_REALB_TYPES(
            scalar_type, ctx, "div.out", CTYPE_SCALAR, [&]() {
              CTYPE_SCALAR scalar_val =
                  *scalar->const_data_ptr<CTYPE_SCALAR>();
              CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);

              using Vec = executorch::vec::Vectorized<CTYPE>;
              if (a.numel() == 1) {
                // a is the scalar operand: out = scalar / b.
                executorch::vec::map<CTYPE>(
                    [scalar_casted](Vec x) { return Vec(scalar_casted) / x; },
                    out.mutable_data_ptr<CTYPE>(),
                    tensor->const_data_ptr<CTYPE>(),
                    out.numel());
              } else {
                // b is the scalar operand: multiply by its reciprocal instead
                // of dividing every element.
                Vec inv_scalar_casted_vec(CTYPE(1) / scalar_casted);
                executorch::vec::map<CTYPE>(
                    [inv_scalar_casted_vec](Vec x) {
                      return x * inv_scalar_casted_vec;
                    },
                    out.mutable_data_ptr<CTYPE>(),
                    tensor->const_data_ptr<CTYPE>(),
                    out.numel());
              }
            });
      });
      return out;
    }
  }

  auto selected_optimized_path = select_optimized_path(a, b, out);
  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
    // Same shapes: treat both operands as flat 1-d buffers.
    // Resize for dynamic shape.
    auto error = resize_tensor(out, a.sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");

    ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "div.out", CTYPE, [&]() {
      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::map2<CTYPE>(
          [](Vec x, Vec y) { return x / y; },
          out.mutable_data_ptr<CTYPE>(),
          a.const_data_ptr<CTYPE>(),
          b.const_data_ptr<CTYPE>(),
          out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
    // Vectorized 2-d-by-1-d broadcasting.
    const Tensor* lhs;
    const Tensor* rhs;
    if (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
      lhs = &b;
      rhs = &a;
    } else {
      // Catch failure to update this logic when adding a new broadcasting
      // possibility.
      ET_DCHECK(
          selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1d);
      lhs = &a;
      rhs = &b;
    }
    auto error = resize_tensor(out, lhs->sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");
    ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE, [&]() {
      using Vec = executorch::vec::Vectorized<CTYPE>;
      if (selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
        // Arguments were swapped above, so divide in reverse order.
        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
            [](Vec x, Vec y) { return y / x; },
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            lhs->sizes()[lhs->dim() - 2],
            lhs->sizes()[lhs->dim() - 1]);
      } else {
        executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
            [](Vec x, Vec y) { return x / y; },
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            lhs->sizes()[lhs->dim() - 2],
            lhs->sizes()[lhs->dim() - 1]);
      }
    });
  } else {
    // Fallback: mixed dtypes and/or general broadcasting, computed element by
    // element with per-element casts.
    ScalarType common_type = get_compute_type(a_type, b_type);
    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

    ET_KERNEL_CHECK(
        ctx,
        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
        InvalidArgument,
        out);

    ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
      ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
        ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
          ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
            apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
                [](const CTYPE_A val_a, const CTYPE_B val_b) {
                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
                  CTYPE_IN value = a_casted / b_casted;

                  return static_cast<CTYPE_OUT>(value);
                },
                a,
                b,
                out);
          });
        });
      });
    });
  }

  return out;
}

Tensor& opt_div_scalar_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  (void)ctx;

  ScalarType a_type = a.scalar_type();
  ScalarType b_type = utils::get_scalar_dtype(b);
  ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
  ScalarType out_type = out.scalar_type();

  ET_CHECK(common_type == out_type);

  // Resize for dynamic shape.
  auto error = resize_tensor(out, a.sizes());
  ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

  if (a_type == common_type && a_type == out_type) {
    // All dtypes match: vectorize by multiplying with the scalar's reciprocal.
    ET_SWITCH_REAL_TYPES(a_type, ctx, "div.Scalar_out", CTYPE, [&]() {
      ET_SWITCH_REAL_TYPES_AND(
          Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
            CTYPE_B b_val;
            ET_EXTRACT_SCALAR(b, b_val);
            CTYPE b_casted = static_cast<CTYPE>(b_val);

            using Vec = executorch::vec::Vectorized<CTYPE>;
            Vec inv_b_casted_vec(CTYPE(1) / b_casted);
            executorch::vec::map<CTYPE>(
                [inv_b_casted_vec](Vec x) { return x * inv_b_casted_vec; },
                out.mutable_data_ptr<CTYPE>(),
                a.const_data_ptr<CTYPE>(),
                out.numel());
          });
    });
  } else {
    // Mixed dtypes: plain scalar loop with per-element casts.
    ET_SWITCH_REAL_TYPES_AND(
        Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
          ET_SWITCH_REAL_TYPES_AND(
              Bool, b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
                ET_SWITCH_REAL_TYPES(
                    common_type, ctx, "div.Scalar_out", CTYPE_IN, [&]() {
                      ET_SWITCH_REAL_TYPES(
                          out_type, ctx, "div.Scalar_out", CTYPE_OUT, [&]() {
                            CTYPE_B b_val;
                            ET_EXTRACT_SCALAR(b, b_val);
                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
                            CTYPE_IN inv_b_casted = CTYPE_IN(1) / b_casted;

                            const size_t n = a.numel();
                            const CTYPE_A* a_data =
                                a.const_data_ptr<CTYPE_A>();
                            CTYPE_OUT* out_data =
                                out.mutable_data_ptr<CTYPE_OUT>();
                            for (size_t i = 0; i < n; ++i) {
                              out_data[i] = static_cast<CTYPE_OUT>(
                                  static_cast<CTYPE_IN>(a_data[i]) *
                                  inv_b_casted);
                            }
                          });
                    });
              });
        });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
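// Minimal usage sketch (illustration only, not part of this kernel's sources):
// the out-variant expects `out` to already exist with a compatible dtype, and
// the kernel resizes it to the broadcast target shape. TensorFactory is
// assumed here from ExecuTorch's testing utilities
// (executorch/runtime/core/exec_aten/testing_util/tensor_factory.h) purely
// for the example.
//
//   using torch::executor::testing::TensorFactory;
//   TensorFactory<ScalarType::Float> tf;
//   Tensor a = tf.make({2, 2}, {1.0, 2.0, 3.0, 4.0});
//   Tensor b = tf.make({2, 2}, {2.0, 2.0, 2.0, 2.0});
//   Tensor out = tf.zeros({2, 2});
//   KernelRuntimeContext ctx;
//   opt_div_out(ctx, a, b, out); // out now holds {0.5, 1.0, 1.5, 2.0}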