Home
last modified time | relevance | path

Searched full:plain_indices (Results 1 – 25 of 28) sorted by relevance

12

/aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/
H A DSparseCsrTensor.cpp128 …ed_tensor_args_worker(const Tensor& compressed_indices, const Tensor& plain_indices, const Tensor&… in _validate_sparse_compressed_tensor_args_worker() argument
149 TORCH_CHECK(plain_indices.layout() == kStrided, in _validate_sparse_compressed_tensor_args_worker()
150 …"expected ", plain_indices_name, " to be a strided tensor but got ", plain_indices.layout(), " ten… in _validate_sparse_compressed_tensor_args_worker()
180 compressed_indices.dim() == plain_indices.dim(), in _validate_sparse_compressed_tensor_args_worker()
182 compressed_indices.dim(), " and ", plain_indices.dim(), ", respectively"); in _validate_sparse_compressed_tensor_args_worker()
191 TORCH_CHECK(plain_indices.stride(-1) == 1, in _validate_sparse_compressed_tensor_args_worker()
217 DimVector plain_indices_batchsize = DimVector(plain_indices.sizes().slice(0, batch_ndim)); in _validate_sparse_compressed_tensor_args_worker()
247 plain_indices.size(-1) == values_nnz, in _validate_sparse_compressed_tensor_args_worker()
249 ") as defined by values.shape[", batch_ndim, "], but got ", plain_indices.size(-1)); in _validate_sparse_compressed_tensor_args_worker()
252 auto plain_indices_type = plain_indices.scalar_type(); in _validate_sparse_compressed_tensor_args_worker()
[all …]
H A DSparseCsrTensorMath.cpp226 auto [compressed_indices, plain_indices] = getCompressedPlainIndices(sparse); in intersection_binary_op_with_wrapped_scalar()
229 plain_indices.clone(), in intersection_binary_op_with_wrapped_scalar()
302 auto plain_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(input.layout(), in get_result_tensor_for_unary_op() local
309 plain_indices.clone(), in get_result_tensor_for_unary_op()
351 auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(mask); in sparse_mask_sparse_compressed()
355 plain_indices, in sparse_mask_sparse_compressed()
H A DSparseBlasImpl.cpp178 auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(compressed); in _compressed_row_strided_mm_out()
181 auto strided_tiled_selected_rows = strided_tiled.index_select(-4, plain_indices); in _compressed_row_strided_mm_out()
199 plain_indices, in _compressed_row_strided_mm_out()
H A DValidateCompressedIndicesCommon.h39 // use `cidx/idx` to refer to `compressed_indices/plain_indices` respectively.
117 // plain_indices[..., compressed_indices[..., i - 1]:compressed_indices[..., i]]
/aosp_15_r20/external/pytorch/test/
H A Dtest_sparse_csr.py249 … for (compressed_indices, plain_indices, values), kwargs in self.generate_simple_inputs(
260 # max(plain_indices) is undefined if
261 # plain_indices has no values
264 plain_indices_expect = plain_indices
269 plain_indices = plain_indices.tolist()
276 … compressed_indices, plain_indices, values, requires_grad=requires_grad)
279 compressed_indices, plain_indices, values, size,
284 compressed_indices, plain_indices, values,
288 compressed_indices, plain_indices, values, size,
350 compressed_indices, plain_indices = s.crow_indices(), s.col_indices()
[all …]
H A Dtest_sparse.py4280 plain_indices = torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype)
4285 plain_indices,
4342 plain_indices = torch.empty((*batchsize, nnz), device='meta', dtype=index_dtype)
4345 plain_indices,
4911 compressed_indices, plain_indices = r.crow_indices(), r.col_indices()
4913 compressed_indices, plain_indices = r.ccol_indices(), r.row_indices()
4914 … torch._validate_sparse_compressed_tensor_args(compressed_indices, plain_indices, r.values(),
4918 self.assertEqual(plain_indices.dtype, torch.int64)
4921 self.assertEqual(plain_indices.dtype, index_dtype)
4971 plain_indices = torch.tensor([0, 1, 2, 3] * 5, dtype=index_dtype, device=device)
[all …]
/aosp_15_r20/external/pytorch/aten/src/ATen/native/
H A DTensorConversions.cpp178 // plain_indices it can be done with reshape/unflatten.
183 Tensor& plain_indices, in reshape_2d_sparse_compressed_members_to_nd_batched() argument
200 plain_indices = plain_indices.reshape(batchsize_infer_last); in reshape_2d_sparse_compressed_members_to_nd_batched()
301 auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(self); in _to_copy()
324 plain_indices, in _to_copy()
325 plain_indices.scalar_type(), in _to_copy()
648 auto [compressed_indices, plain_indices] = in sparse_compressed_to_dense()
658 plain_indices.unsqueeze_(0); in sparse_compressed_to_dense()
665 plain_indices = plain_indices.flatten(0, batch_ndim - 1); in sparse_compressed_to_dense()
707 plain_indices.flatten(), in sparse_compressed_to_dense()
[all …]
H A DTensorFactories.cpp1293 auto plain_indices = at::empty(plain_indices_and_values_size, options); in zeros_sparse_compressed_symint() local
1297 plain_indices, in zeros_sparse_compressed_symint()
1406 auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(res); in zeros_like()
/aosp_15_r20/external/pytorch/aten/src/ATen/
H A DSparseCsrTensorImpl.cpp173 auto [compressed_indices, plain_indices] = in resize_as_sparse_compressed_tensor_()
179 if (col_indices_.sizes() != plain_indices.sizes()) { in resize_as_sparse_compressed_tensor_()
180 col_indices_.resize_as_(plain_indices); in resize_as_sparse_compressed_tensor_()
185 col_indices_.copy_(plain_indices); in resize_as_sparse_compressed_tensor_()
H A DSparseCsrTensorUtils.h367 auto [compressed_indices, plain_indices] = in only_sparse_compressed_binary_op_trivial_cases()
372 plain_indices, in only_sparse_compressed_binary_op_trivial_cases()
399 auto [compressed_indices, plain_indices] = in to_type()
403 plain_indices, in to_type()
H A DSparseCsrTensorImpl.h59 const Tensor& plain_indices() const { in plain_indices() function
201 dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices(); in copy_tensor_metadata()
/aosp_15_r20/external/pytorch/torch/multiprocessing/
H A Dreductions.py471 plain_indices = rebuild_plain_indices_func(*rebuild_plain_indices_args)
474 compressed_indices, plain_indices, values, shape, layout=layout
496 plain_indices = sparse.col_indices()
499 plain_indices = sparse.row_indices()
507 plain_indices
/aosp_15_r20/external/pytorch/torch/csrc/utils/
H A Dtensor_new.cpp976 Tensor plain_indices = internal_new_from_data( in sparse_compressed_tensor_ctor_worker() local
999 plain_indices, in sparse_compressed_tensor_ctor_worker()
1032 Tensor plain_indices = internal_new_from_data( in sparse_compressed_tensor_ctor_worker() local
1055 plain_indices, in sparse_compressed_tensor_ctor_worker()
1309 …"_validate_sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObjec… in _validate_sparse_compressed_tensor_args()
1331 Tensor plain_indices = internal_new_from_data( in _validate_sparse_compressed_tensor_args() local
1341 plain_indices, in _validate_sparse_compressed_tensor_args()
1402 Tensor plain_indices = internal_new_from_data( in _validate_sparse_compressed_tensor_args_template() local
1412 compressed_indices, plain_indices, values, r.intlist(3), required_layout); in _validate_sparse_compressed_tensor_args_template()
/aosp_15_r20/external/pytorch/torch/
H A D_utils.py266 compressed_indices, plain_indices = (
271 compressed_indices, plain_indices = (
276 compressed_indices, plain_indices, t.values(), t.size(), t.layout
314 compressed_indices, plain_indices, values, size = data
317 plain_indices,
H A D_tensor_str.py525 plain_indices = plain_indices_method(self).detach()
530 plain_indices, indent + len(plain_indices_prefix)
532 if plain_indices.numel() == 0 or is_meta:
533 plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
H A D_tensor.py381 compressed_indices, plain_indices = (
386 compressed_indices, plain_indices = (
394 plain_indices,
H A D_torch_docs.py9901 r"""sparse_compressed_tensor(compressed_indices, plain_indices, values, size=None, """
9906 the given :attr:`compressed_indices` and :attr:`plain_indices`. Sparse
9918 tensor encodes the index in ``values`` and ``plain_indices``
9923 plain_indices (array_like): Plain dimension (column or row)
9959 >>> plain_indices = [0, 1, 0, 1]
9962 ... torch.tensor(plain_indices, dtype=torch.int64),
/aosp_15_r20/external/pytorch/torch/sparse/
H A D__init__.py626 plain_indices=obj.col_indices(),
632 plain_indices=obj.row_indices(),
660 d["plain_indices"],
/aosp_15_r20/external/pytorch/torch/csrc/autograd/
H A Dpython_torch_functions_manual.cpp195 …({"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* value…
196 …"sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values,…
H A Dinput_buffer.cpp51 impl->plain_indices().storage().data_ptr(), stream); in record_stream_any_impl()
/aosp_15_r20/external/pytorch/docs/source/
H A Dsparse.rst1095 >>> plain_indices = torch.tensor([0, 1, 0, 1])
1097 …>>> csr = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, layout=torch.s…
1103 …>>> csc = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, layout=torch.s…
/aosp_15_r20/external/pytorch/torch/_dynamo/
H A Dutils.py823 plain_indices = x.col_indices()
826 plain_indices = x.row_indices()
829 torch_clone(plain_indices),
/aosp_15_r20/external/pytorch/torch/autograd/
H A Dgradcheck.py1678 compressed_indices, plain_indices = x.crow_indices(), x.col_indices()
1680 compressed_indices, plain_indices = x.ccol_indices(), x.row_indices()
1691 plain_indices,
/aosp_15_r20/external/pytorch/torch/testing/_internal/
H A Dcommon_utils.py3317 plain_indices = torch.zeros(nnz, dtype=index_dtype, device=device)
3320 plain_indices[compressed_indices[i]:compressed_indices[i + 1]], _ = torch.sort(
3325 return values, compressed_indices, plain_indices
3340 plain_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
3341 return torch.sparse_compressed_tensor(compressed_indices, plain_indices,
3419 …(compressed_indices, plain_indices, values), dict(size=expected_size_from_shape_inference, device=…
/aosp_15_r20/external/pytorch/tools/pyi/
H A Dgen_pyi.py767 "plain_indices: Union[Tensor, List]",

12