1 #pragma once 2 3 #include <torch/data/example.h> 4 #include <torch/data/transforms/base.h> 5 #include <torch/types.h> 6 7 #include <functional> 8 #include <utility> 9 10 namespace torch { 11 namespace data { 12 namespace transforms { 13 14 /// A `Transform` that is specialized for the typical `Example<Tensor, Tensor>` 15 /// combination. It exposes a single `operator()` interface hook (for 16 /// subclasses), and calls this function on input `Example` objects. 17 template <typename Target = Tensor> 18 class TensorTransform 19 : public Transform<Example<Tensor, Target>, Example<Tensor, Target>> { 20 public: 21 using E = Example<Tensor, Target>; 22 using typename Transform<E, E>::InputType; 23 using typename Transform<E, E>::OutputType; 24 25 /// Transforms a single input tensor to an output tensor. 26 virtual Tensor operator()(Tensor input) = 0; 27 28 /// Implementation of `Transform::apply` that calls `operator()`. apply(InputType input)29 OutputType apply(InputType input) override { 30 input.data = (*this)(std::move(input.data)); 31 return input; 32 } 33 }; 34 35 /// A `Lambda` specialized for the typical `Example<Tensor, Tensor>` input type. 36 template <typename Target = Tensor> 37 class TensorLambda : public TensorTransform<Target> { 38 public: 39 using FunctionType = std::function<Tensor(Tensor)>; 40 41 /// Creates a `TensorLambda` from the given `function`. TensorLambda(FunctionType function)42 explicit TensorLambda(FunctionType function) 43 : function_(std::move(function)) {} 44 45 /// Applies the user-provided functor to the input tensor. operator()46 Tensor operator()(Tensor input) override { 47 return function_(std::move(input)); 48 } 49 50 private: 51 FunctionType function_; 52 }; 53 54 /// Normalizes input tensors by subtracting the supplied mean and dividing by 55 /// the given standard deviation. 56 template <typename Target = Tensor> 57 struct Normalize : public TensorTransform<Target> { 58 /// Constructs a `Normalize` transform. The mean and standard deviation can be 59 /// anything that is broadcastable over the input tensors (like single 60 /// scalars). NormalizeNormalize61 Normalize(ArrayRef<double> mean, ArrayRef<double> stddev) 62 : mean(torch::tensor(mean, torch::kFloat32) 63 .unsqueeze(/*dim=*/1) 64 .unsqueeze(/*dim=*/2)), 65 stddev(torch::tensor(stddev, torch::kFloat32) 66 .unsqueeze(/*dim=*/1) 67 .unsqueeze(/*dim=*/2)) {} 68 operatorNormalize69 torch::Tensor operator()(Tensor input) override { 70 return input.sub(mean).div(stddev); 71 } 72 73 torch::Tensor mean, stddev; 74 }; 75 } // namespace transforms 76 } // namespace data 77 } // namespace torch 78