1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/scalar_utils.h>
10 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
11 #include <executorch/kernels/portable/cpu/util/math_util.h>
12 #include <executorch/runtime/kernel/kernel_includes.h>
13
14 namespace torch {
15 namespace executor {
16 namespace native {
17
maximum_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)18 Tensor& maximum_out(
19 KernelRuntimeContext& ctx,
20 const Tensor& a,
21 const Tensor& b,
22 Tensor& out) {
23 // Common Dtype
24 ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
25
26 // Check Common Dtype
27 ET_KERNEL_CHECK(
28 ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
29
30 // Check Dim Order
31 ET_KERNEL_CHECK(
32 ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
33
34 // Resize
35 ET_KERNEL_CHECK(
36 ctx,
37 resize_to_broadcast_target_size(a, b, out) == Error::Ok,
38 InvalidArgument,
39 out);
40
41 // Compute Dtype
42 ScalarType compute_type = utils::get_compute_type(common_type);
43
44 // @lint-ignore CLANGTIDY facebook-hte-CArray
45 static constexpr const char op_name[] = "maximum.out";
46
47 ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
48 utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
49 [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
50 return utils::max_override(val_a, val_b);
51 },
52 ctx,
53 a,
54 utils::SupportedTensorDtypes::REALHBBF16,
55 b,
56 utils::SupportedTensorDtypes::REALHBBF16,
57 out,
58 utils::SupportedTensorDtypes::REALHBBF16);
59 });
60
61 return out;
62 }
63
64 } // namespace native
65 } // namespace executor
66 } // namespace torch
67