Searched refs:tensor2_expand_size (Results 1 – 4 of 4) sorted by relevance
/aosp_15_r20/external/pytorch/aten/src/ATen/native/ |
H A D | Distance.cpp | 119 std::vector<SymInt> tensor2_expand_size(expand_batch_portion); in cdist_impl() local 120 tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2}); in cdist_impl() 127 …Tensor tensor2_expanded = x2.expand_symint(tensor2_expand_size).contiguous().view_symint(tensor2_v… in cdist_impl() 204 std::vector<int64_t> tensor2_expand_size(expand_batch_portion); in _cdist_backward() local 205 tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2}); in _cdist_backward() 220 if (tensor2_expand_size != x2.sizes()) { in _cdist_backward() 221 x2 = x2.expand(tensor2_expand_size); in _cdist_backward()
|
H A D | LinearAlgebra.cpp | 2134 const auto tensor2_expand_size = [&output_shape, m2, p, vector_rhs]{ in _matmul_impl() local 2143 auto tensor2_expanded = tensor2.expand(tensor2_expand_size); in _matmul_impl()
|
/aosp_15_r20/external/pytorch/aten/src/ATen/native/mps/operations/ |
H A D | ReduceOps.mm | 1097 std::vector<int64_t> tensor2_expand_size(expand_batch_portion); 1098 tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2}); 1119 toShape:getMPSShape(tensor2_expand_size)
|
/aosp_15_r20/external/pytorch/torch/_decomp/ |
H A D | decompositions.py | 4411 tensor2_expand_size = expand_batch_portion + [m2] 4413 tensor2.expand(tensor2_expand_size) 4418 tensor2_expand_size = expand_batch_portion + [m2, p] 4419 tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(
|