xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/transforms/collate.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/data/example.h>
4 #include <torch/data/transforms/lambda.h>
5 
6 #include <vector>
7 
8 namespace torch {
9 namespace data {
10 namespace transforms {
11 
12 /// A `Collation` is a transform that reduces a batch into a single value.
13 /// The result is a `BatchDataset` that has the type of the single value as its
14 /// `BatchType`.
15 template <typename T, typename BatchType = std::vector<T>>
16 using Collation = BatchTransform<BatchType, T>;
17 
18 /// A `Collate` allows passing a custom function to reduce/collate a batch
19 /// into a single value. It's effectively the lambda version of `Collation`,
20 /// which you could subclass and override `operator()` to achieve the same.
21 ///
22 /// \rst
23 /// .. code-block:: cpp
24 ///   using namespace torch::data;
25 ///
26 ///   auto dataset = datasets::MNIST("path/to/mnist")
27 ///     .map(transforms::Collate<Example<>>([](std::vector<Example<>> e) {
28 ///       return std::move(e.front());
29 ///     }));
30 /// \endrst
31 template <typename T, typename BatchType = std::vector<T>>
32 using Collate = BatchLambda<BatchType, T>;
33 } // namespace transforms
34 } // namespace data
35 } // namespace torch
36