1 #pragma once 2 3 #include <torch/data/example.h> 4 #include <torch/data/transforms/lambda.h> 5 6 #include <vector> 7 8 namespace torch { 9 namespace data { 10 namespace transforms { 11 12 /// A `Collation` is a transform that reduces a batch into a single value. 13 /// The result is a `BatchDataset` that has the type of the single value as its 14 /// `BatchType`. 15 template <typename T, typename BatchType = std::vector<T>> 16 using Collation = BatchTransform<BatchType, T>; 17 18 /// A `Collate` allows passing a custom function to reduce/collate a batch 19 /// into a single value. It's effectively the lambda version of `Collation`, 20 /// which you could subclass and override `operator()` to achieve the same. 21 /// 22 /// \rst 23 /// .. code-block:: cpp 24 /// using namespace torch::data; 25 /// 26 /// auto dataset = datasets::MNIST("path/to/mnist") 27 /// .map(transforms::Collate<Example<>>([](std::vector<Example<>> e) { 28 /// return std::move(e.front()); 29 /// })); 30 /// \endrst 31 template <typename T, typename BatchType = std::vector<T>> 32 using Collate = BatchLambda<BatchType, T>; 33 } // namespace transforms 34 } // namespace data 35 } // namespace torch 36