/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

// copy.out(const Tensor& in, const Tensor& src, bool non_blocking, Tensor(a!)
// out) -> Tensor(a!), see caffe2/aten/src/ATen/native/Copy.cpp
// TODO: With proper functionalization this op should never be seen, and it
// should be deleted.
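//
// Note: `out` is resized to `in`'s shape and dtype-checked against `in`;
// `src` is then broadcast over that shape and its values fully overwrite
// `out`, so `in` only contributes the output geometry (see the checks below).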
Tensor& copy_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& src,
    bool non_blocking,
    Tensor& out) {
  (void)ctx;
  // Right now we only support blocking data transfer
  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);

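  // `src` only needs to be broadcastable to `in`'s shape; `out` itself is
  // resized to `in`'s shape below, so it always has room for the result.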
  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "copy.out";

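  // Dispatch over the real, Half, Bool and BFloat16 dtypes. The element-wise
  // functor ignores the value coming from `in` and returns the (broadcast)
  // value from `src`, so every element of `out` is a copy of `src`.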
  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
        [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
        ctx,
        in,
        utils::SupportedTensorDtypes::REALHBBF16,
        src,
        utils::SupportedTensorDtypes::REALHBBF16,
        out,
        utils::SupportedTensorDtypes::REALHBBF16);
  });

  return out;
}

Tensor& copy_(
    KernelRuntimeContext& ctx,
    Tensor& in,
    const Tensor& src,
    bool non_blocking) {
  (void)ctx;
  // Right now we only support blocking data transfer
  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, in);

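  // In-place variant: `in` keeps its shape and `src` is broadcast over it, so
  // only broadcastability and matching dim order are checked here; no resize
  // is needed.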
  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, src), InvalidArgument, in);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "copy_";

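  // Same element-wise copy as copy.out, except the destination is `in`
  // itself: the functor discards the current value of `in` and writes the
  // (broadcast) value from `src` back into `in`.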
  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
    utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
        [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
        ctx,
        in,
        utils::SupportedTensorDtypes::REALHBBF16,
        src,
        utils::SupportedTensorDtypes::REALHBBF16,
        in,
        utils::SupportedTensorDtypes::REALHBBF16);
  });

  return in;
}

} // namespace native
} // namespace executor
} // namespace torch