Searched refs:tensor_batch_dim (Results 1 – 3 of 3) sorted by relevance
/aosp_15_r20/external/pytorch/aten/src/ATen/functorch/ |
H A D | BatchRulesBinaryOps.cpp | 18 const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, in _binary_pointwise_batch_rule() argument 23 tensor, tensor_batch_dim, other, other_batch_dim); in _binary_pointwise_batch_rule() 37 const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, in apply() 41 tensor, tensor_batch_dim, other, other_batch_dim, in apply() 97 Tensor& tensor, std::optional<int64_t> tensor_batch_dim, in binary_pointwise_inplace_batch_rule() argument 100 if (!tensor_batch_dim && other_batch_dim) { in binary_pointwise_inplace_batch_rule() 105 auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim); in binary_pointwise_inplace_batch_rule() 109 auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim); in binary_pointwise_inplace_batch_rule() 116 tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank); in binary_pointwise_inplace_batch_rule() 124 const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, in comparison_pointwise_batch_rule() argument [all …]
|
H A D | BatchRulesHelper.cpp | 172 const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, in _binary_pointwise_helper() argument 176 auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim); in _binary_pointwise_helper() 180 auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim); in _binary_pointwise_helper() 185 auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value()); in _binary_pointwise_helper() 199 tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank); in _binary_pointwise_helper()
|
H A D | BatchRulesHelper.h | 472 …const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, const Tensor& other, std::optional<…
|