1 #pragma once 2 3 #include <torch/data/example.h> 4 #include <torch/data/transforms/collate.h> 5 #include <torch/types.h> 6 7 #include <utility> 8 #include <vector> 9 10 namespace torch { 11 namespace data { 12 namespace transforms { 13 14 template <typename T = Example<>> 15 struct Stack; 16 17 /// A `Collation` for `Example<Tensor, Tensor>` types that stacks all data 18 /// tensors into one tensor, and all target (label) tensors into one tensor. 19 template <> 20 struct Stack<Example<>> : public Collation<Example<>> { 21 Example<> apply_batch(std::vector<Example<>> examples) override { 22 std::vector<torch::Tensor> data, targets; 23 data.reserve(examples.size()); 24 targets.reserve(examples.size()); 25 for (auto& example : examples) { 26 data.push_back(std::move(example.data)); 27 targets.push_back(std::move(example.target)); 28 } 29 return {torch::stack(data), torch::stack(targets)}; 30 } 31 }; 32 33 /// A `Collation` for `Example<Tensor, NoTarget>` types that stacks all data 34 /// tensors into one tensor. 35 template <> 36 struct Stack<TensorExample> 37 : public Collation<Example<Tensor, example::NoTarget>> { 38 TensorExample apply_batch(std::vector<TensorExample> examples) override { 39 std::vector<torch::Tensor> data; 40 data.reserve(examples.size()); 41 for (auto& example : examples) { 42 data.push_back(std::move(example.data)); 43 } 44 return torch::stack(data); 45 } 46 }; 47 } // namespace transforms 48 } // namespace data 49 } // namespace torch 50