// xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/transforms/base.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/types.h>

#include <utility>
#include <vector>

namespace torch {
namespace data {
namespace transforms {

/// A transformation of a batch to a new batch.
///
/// Abstract interface: derived classes implement `apply_batch`, which receives
/// the whole batch by value (so an implementation may consume/move out of it)
/// and returns the transformed batch.
///
/// \tparam InputBatch  type of the batch passed to `apply_batch`.
/// \tparam OutputBatch type of the batch returned by `apply_batch`.
template <typename InputBatch, typename OutputBatch>
class BatchTransform {
 public:
  using InputBatchType = InputBatch;
  using OutputBatchType = OutputBatch;

  // Virtual so that a derived transform can be destroyed through a pointer or
  // reference to this base.
  virtual ~BatchTransform() = default;

  /// Applies the transformation to the given `input_batch`.
  virtual OutputBatch apply_batch(InputBatch input_batch) = 0;
};

/// A transformation of individual input examples to individual output examples.
///
/// Just like a `Dataset` is a `BatchDataset`, a `Transform` is a
/// `BatchTransform` that can operate on the level of individual examples rather
/// than entire batches. The batch-level transform is implemented (by default)
/// in terms of the example-level transform, though this can be customized.
///
/// \tparam Input  type of a single example passed to `apply`.
/// \tparam Output type of a single example returned by `apply`.
template <typename Input, typename Output>
class Transform
    : public BatchTransform<std::vector<Input>, std::vector<Output>> {
 public:
  using InputType = Input;
  using OutputType = Output;

  /// Applies the transformation to the given `input`.
  virtual OutputType apply(InputType input) = 0;

  /// Applies the transformation over the entire `input_batch`.
  ///
  /// Calls `apply` once per element, in order, and collects the results, so the
  /// returned vector has the same size and ordering as `input_batch`. Each
  /// element is moved into `apply`, which also makes move-only `Input` types
  /// work. If `apply` throws, the exception propagates and the partially built
  /// output is discarded.
  std::vector<Output> apply_batch(std::vector<Input> input_batch) override {
    std::vector<Output> output_batch;
    // Exactly one output is produced per input, so allocate once up front.
    output_batch.reserve(input_batch.size());
    for (auto&& input : input_batch) {
      // `input_batch` is owned by this call, so its elements may be consumed.
      output_batch.push_back(apply(std::move(input)));
    }
    return output_batch;
  }
};

} // namespace transforms
} // namespace data
} // namespace torch