xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/data/transforms/tensor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/data/example.h>
4 #include <torch/data/transforms/base.h>
5 #include <torch/types.h>
6 
7 #include <functional>
8 #include <utility>
9 
10 namespace torch {
11 namespace data {
12 namespace transforms {
13 
14 /// A `Transform` that is specialized for the typical `Example<Tensor, Tensor>`
15 /// combination. It exposes a single `operator()` interface hook (for
16 /// subclasses), and calls this function on input `Example` objects.
17 template <typename Target = Tensor>
18 class TensorTransform
19     : public Transform<Example<Tensor, Target>, Example<Tensor, Target>> {
20  public:
21   using E = Example<Tensor, Target>;
22   using typename Transform<E, E>::InputType;
23   using typename Transform<E, E>::OutputType;
24 
25   /// Transforms a single input tensor to an output tensor.
26   virtual Tensor operator()(Tensor input) = 0;
27 
28   /// Implementation of `Transform::apply` that calls `operator()`.
apply(InputType input)29   OutputType apply(InputType input) override {
30     input.data = (*this)(std::move(input.data));
31     return input;
32   }
33 };
34 
35 /// A `Lambda` specialized for the typical `Example<Tensor, Tensor>` input type.
36 template <typename Target = Tensor>
37 class TensorLambda : public TensorTransform<Target> {
38  public:
39   using FunctionType = std::function<Tensor(Tensor)>;
40 
41   /// Creates a `TensorLambda` from the given `function`.
TensorLambda(FunctionType function)42   explicit TensorLambda(FunctionType function)
43       : function_(std::move(function)) {}
44 
45   /// Applies the user-provided functor to the input tensor.
operator()46   Tensor operator()(Tensor input) override {
47     return function_(std::move(input));
48   }
49 
50  private:
51   FunctionType function_;
52 };
53 
54 /// Normalizes input tensors by subtracting the supplied mean and dividing by
55 /// the given standard deviation.
56 template <typename Target = Tensor>
57 struct Normalize : public TensorTransform<Target> {
58   /// Constructs a `Normalize` transform. The mean and standard deviation can be
59   /// anything that is broadcastable over the input tensors (like single
60   /// scalars).
NormalizeNormalize61   Normalize(ArrayRef<double> mean, ArrayRef<double> stddev)
62       : mean(torch::tensor(mean, torch::kFloat32)
63                  .unsqueeze(/*dim=*/1)
64                  .unsqueeze(/*dim=*/2)),
65         stddev(torch::tensor(stddev, torch::kFloat32)
66                    .unsqueeze(/*dim=*/1)
67                    .unsqueeze(/*dim=*/2)) {}
68 
operatorNormalize69   torch::Tensor operator()(Tensor input) override {
70     return input.sub(mean).div(stddev);
71   }
72 
73   torch::Tensor mean, stddev;
74 };
75 } // namespace transforms
76 } // namespace data
77 } // namespace torch
78