Home
last modified time | relevance | path

Searched refs:tensor_batch_dim (Results 1 – 3 of 3) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/functorch/
H A DBatchRulesBinaryOps.cpp18 const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, in _binary_pointwise_batch_rule() argument
23 tensor, tensor_batch_dim, other, other_batch_dim); in _binary_pointwise_batch_rule()
37 const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, in apply()
41 tensor, tensor_batch_dim, other, other_batch_dim, in apply()
97 Tensor& tensor, std::optional<int64_t> tensor_batch_dim, in binary_pointwise_inplace_batch_rule() argument
100 if (!tensor_batch_dim && other_batch_dim) { in binary_pointwise_inplace_batch_rule()
105 auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim); in binary_pointwise_inplace_batch_rule()
109 auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim); in binary_pointwise_inplace_batch_rule()
116 tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank); in binary_pointwise_inplace_batch_rule()
124 const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, in comparison_pointwise_batch_rule() argument
[all …]
H A DBatchRulesHelper.cpp172 const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, in _binary_pointwise_helper() argument
176 auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim); in _binary_pointwise_helper()
180 auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim); in _binary_pointwise_helper()
185 auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value()); in _binary_pointwise_helper()
199 tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank); in _binary_pointwise_helper()
H A DBatchRulesHelper.h472 …const Tensor& tensor, std::optional<int64_t> tensor_batch_dim, const Tensor& other, std::optional<…