1 #pragma once 2 3 #include <torch/types.h> 4 5 #include <utility> 6 #include <vector> 7 8 namespace torch { 9 namespace data { 10 namespace transforms { 11 12 /// A transformation of a batch to a new batch. 13 template <typename InputBatch, typename OutputBatch> 14 class BatchTransform { 15 public: 16 using InputBatchType = InputBatch; 17 using OutputBatchType = OutputBatch; 18 19 virtual ~BatchTransform() = default; 20 21 /// Applies the transformation to the given `input_batch`. 22 virtual OutputBatch apply_batch(InputBatch input_batch) = 0; 23 }; 24 25 /// A transformation of individual input examples to individual output examples. 26 /// 27 /// Just like a `Dataset` is a `BatchDataset`, a `Transform` is a 28 /// `BatchTransform` that can operate on the level of individual examples rather 29 /// than entire batches. The batch-level transform is implemented (by default) 30 /// in terms of the example-level transform, though this can be customized. 31 template <typename Input, typename Output> 32 class Transform 33 : public BatchTransform<std::vector<Input>, std::vector<Output>> { 34 public: 35 using InputType = Input; 36 using OutputType = Output; 37 38 /// Applies the transformation to the given `input`. 39 virtual OutputType apply(InputType input) = 0; 40 41 /// Applies the `transformation` over the entire `input_batch`. apply_batch(std::vector<Input> input_batch)42 std::vector<Output> apply_batch(std::vector<Input> input_batch) override { 43 std::vector<Output> output_batch; 44 output_batch.reserve(input_batch.size()); 45 for (auto&& input : input_batch) { 46 output_batch.push_back(apply(std::move(input))); 47 } 48 return output_batch; 49 } 50 }; 51 } // namespace transforms 52 } // namespace data 53 } // namespace torch 54