1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/runtime/core/exec_aten/exec_aten.h> 12 #include <executorch/runtime/core/exec_aten/util/tensor_util.h> 13 14 namespace torch { 15 namespace executor { 16 17 bool check_gather_args( 18 const Tensor& in, 19 int64_t dim, 20 const Tensor& index, 21 bool sparse_grad, 22 Tensor& output); 23 24 bool check_index_select_args( 25 const Tensor& in, 26 int64_t dim, 27 const Tensor& index, 28 Tensor& out); 29 30 void get_index_select_out_target_size( 31 const Tensor& in, 32 int64_t dim, 33 const Tensor& index, 34 exec_aten::SizesType* out_sizes, 35 size_t* out_ndim); 36 37 bool check_nonzero_args(const Tensor& in, const Tensor& out); 38 39 bool check_scatter_add_args( 40 const Tensor& self, 41 int64_t dim, 42 const Tensor& index, 43 const Tensor& src, 44 Tensor& out); 45 46 bool check_scatter_src_args( 47 const Tensor& self, 48 int64_t dim, 49 const Tensor& index, 50 const Tensor& src, 51 Tensor& out); 52 53 bool check_scatter_value_args( 54 const Tensor& self, 55 int64_t dim, 56 const Tensor& index, 57 const Scalar& value, 58 Tensor& out); 59 60 bool check_select_scatter_args( 61 const Tensor& in, 62 const Tensor& src, 63 int64_t dim, 64 int64_t index, 65 Tensor& output); 66 67 } // namespace executor 68 } // namespace torch 69