Home
last modified time | relevance | path

Searched defs:TensorDescriptorListParams (Results 1 – 2 of 2) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/miopen/
H A DRNN_miopen.cpp158 struct TensorDescriptorListParams { struct
159 IntArrayRef batch_sizes;
160 int64_t seq_length;
161 int64_t mini_batch;
163 int64_t input_size;
164 int64_t batch_sizes_sum;
166 bool is_input_packed() const { in is_input_packed()
170 void set(IntArrayRef input_sizes, IntArrayRef batch_sizes_, bool batch_first) { in set()
190 std::vector<TensorDescriptor> descriptors(Tensor x) const { in descriptors()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/
H A DRNN.cpp423 struct TensorDescriptorListParams { struct
424 IntArrayRef batch_sizes;
425 int64_t seq_length;
426 int64_t mini_batch;
431 int64_t input_size;
433 int64_t batch_sizes_sum; // == sum(batch_sizes)
435 bool is_input_packed() const { in is_input_packed()
439 void set( in set()
467 std::vector<TensorDescriptor> descriptors(Tensor x) const { in descriptors()
476 auto descriptors(Tensor x) const { in descriptors()