xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/transforms/stack.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/data/example.h>
4 #include <torch/data/transforms/collate.h>
5 #include <torch/types.h>
6 
7 #include <utility>
8 #include <vector>
9 
10 namespace torch {
11 namespace data {
12 namespace transforms {
13 
14 template <typename T = Example<>>
15 struct Stack;
16 
17 /// A `Collation` for `Example<Tensor, Tensor>` types that stacks all data
18 /// tensors into one tensor, and all target (label) tensors into one tensor.
19 template <>
20 struct Stack<Example<>> : public Collation<Example<>> {
21   Example<> apply_batch(std::vector<Example<>> examples) override {
22     std::vector<torch::Tensor> data, targets;
23     data.reserve(examples.size());
24     targets.reserve(examples.size());
25     for (auto& example : examples) {
26       data.push_back(std::move(example.data));
27       targets.push_back(std::move(example.target));
28     }
29     return {torch::stack(data), torch::stack(targets)};
30   }
31 };
32 
33 /// A `Collation` for `Example<Tensor, NoTarget>` types that stacks all data
34 /// tensors into one tensor.
35 template <>
36 struct Stack<TensorExample>
37     : public Collation<Example<Tensor, example::NoTarget>> {
38   TensorExample apply_batch(std::vector<TensorExample> examples) override {
39     std::vector<torch::Tensor> data;
40     data.reserve(examples.size());
41     for (auto& example : examples) {
42       data.push_back(std::move(example.data));
43     }
44     return torch::stack(data);
45   }
46 };
47 } // namespace transforms
48 } // namespace data
49 } // namespace torch
50