Searched refs:batch_tensor2 (Results 1 – 6 of 6) sorted by relevance
/aosp_15_r20/external/pytorch/aten/src/ATen/native/ |
H A D | Distance.cpp |
    115 SymIntArrayRef batch_tensor2(x2.sym_sizes().data(), dim2 - 2);  (in cdist_impl(); local)
    116 std::vector<SymInt> expand_batch_portion = infer_size_symint(batch_tensor1, batch_tensor2);  (in cdist_impl())
    200 IntArrayRef batch_tensor2(_x2.sizes().data(), dim2 - 2);  (in _cdist_backward(); local)
    201 std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);  (in _cdist_backward())
|
H A D | LinearAlgebra.cpp |
    2108 const IntArrayRef batch_tensor2(tensor2.sizes().data(),  (in _matmul_impl(); local)
    2113 if (dim_tensor1 == 3 && dim_tensor2 == 3 && batch_tensor1[0] != batch_tensor2[0]) {  (in _matmul_impl())
    2117 if (batch_tensor2[0] == 1 && (tensor2.requires_grad() || isTensorSubclassLike(tensor2))) {  (in _matmul_impl())
    2122 auto output_shape = infer_size_dimvector(batch_tensor1, batch_tensor2);  (in _matmul_impl())
|
/aosp_15_r20/external/pytorch/torch/jit/ |
H A D | _shape_functions.py |
    618 batch_tensor2: List[int] = []
    621 batch_tensor2.append(tensor2[i])
    624 expand_batch_portion = broadcast(batch_tensor1, batch_tensor2)
|
/aosp_15_r20/external/pytorch/torch/_decomp/ |
H A D | decompositions.py |
    4378 batch_tensor2: List[int] = []
    4381 batch_tensor2.append(tensor2.size(i))
    4388 and batch_tensor1[0] != batch_tensor2[0]
    4392 if batch_tensor2[0] == 1 and tensor2.requires_grad:
    4397 torch.broadcast_shapes(batch_tensor1, batch_tensor2)
|
/aosp_15_r20/external/pytorch/aten/src/ATen/native/mps/operations/ |
H A D | ReduceOps.mm |
    1093 IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2);
    1094 std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
|
/aosp_15_r20/external/pytorch/torch/ |
H A D | _meta_registrations.py |
    3319 batch_tensor2 = x2.shape[:-2]
    3320 output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
    3332 batch_tensor2 = x2.shape[:-2]
    3333 expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
|