#pragma once

#include <torch/nn/options/embedding.h>

namespace torch {
namespace nn {
namespace functional {

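/// Returns a one-hot encoding of the given tensor of indices; when
/// `num_classes` is -1, the number of classes is inferred as one greater than
/// the largest index in the input.
///
/// Example (values are illustrative):
/// ```
/// namespace F = torch::nn::functional;
/// // Encodes {0, 1, 2, 0} as rows of a 4 x 3 one-hot matrix.
/// auto hot = F::one_hot(torch::tensor({0, 1, 2, 0}), /*num_classes=*/3);
/// ```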
inline Tensor one_hot(const Tensor& tensor, int64_t num_classes = -1) {
  return torch::one_hot(tensor, num_classes);
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
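// Renormalizes the rows of `weight` referenced by `input` in place, without
// recording the operation for autograd (mirrors Python's `torch.no_grad()`).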
inline void _no_grad_embedding_renorm_(
    Tensor weight,
    const Tensor& input,
    float max_norm,
    float norm_type) {
  torch::NoGradGuard no_grad;
  torch::embedding_renorm_(weight, input, max_norm, norm_type);
}

inline Tensor embedding(
    const Tensor& input,
    const Tensor& weight,
    std::optional<int64_t> padding_idx,
    std::optional<double> max_norm,
    double norm_type,
    bool scale_grad_by_freq,
    bool sparse) {
  auto input_ = input;

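  // Normalize padding_idx: a negative value counts from the end of the
  // embedding table, and -1 is the sentinel torch::embedding takes to mean
  // that no index is treated as padding.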
  if (padding_idx != std::nullopt) {
    if (*padding_idx > 0) {
      TORCH_CHECK(
          *padding_idx < weight.size(0),
          "Padding_idx must be within num_embeddings");
    } else if (*padding_idx < 0) {
      TORCH_CHECK(
          *padding_idx >= -weight.size(0),
          "Padding_idx must be within num_embeddings");
      padding_idx = weight.size(0) + *padding_idx;
    }
  } else {
    padding_idx = -1;
  }

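  // embedding_renorm_ rewrites the referenced rows of `weight` in place, so
  // any row whose norm exceeds max_norm is rescaled before the lookup below.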
  if (max_norm != std::nullopt) {
    input_ = input_.contiguous();
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type);
  }
  return torch::embedding(
      weight, input_, *padding_idx, scale_grad_by_freq, sparse);
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.embedding
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::EmbeddingFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::embedding(input, weight,
/// F::EmbeddingFuncOptions().norm_type(2.5).scale_grad_by_freq(true).sparse(true));
/// ```
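///
/// A fuller sketch (shapes and values below are illustrative, not part of the
/// API):
/// ```
/// auto weight = torch::randn({10, 3}); // 10 embeddings of dimension 3
/// auto input = torch::tensor({{1, 2, 4, 5}, {4, 3, 2, 9}});
/// auto output = F::embedding(input, weight); // shape [2, 4, 3]
/// ```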
inline Tensor embedding(
    const Tensor& input,
    const Tensor& weight,
    const EmbeddingFuncOptions& options = {}) {
  return detail::embedding(
      input,
      weight,
      options.padding_idx(),
      options.max_norm(),
      options.norm_type(),
      options.scale_grad_by_freq(),
      options.sparse());
}

#ifndef DOXYGEN_SHOULD_SKIP_THIS
namespace detail {
inline Tensor embedding_bag(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& offsets,
    std::optional<double> max_norm,
    double norm_type,
    bool scale_grad_by_freq,
    EmbeddingBagMode mode,
    bool sparse,
    const Tensor& per_sample_weights,
    bool include_last_offset,
    std::optional<int64_t> padding_idx) {
  auto input_ = input;
  auto offsets_ = offsets;
  auto per_sample_weights_ = per_sample_weights;
  TORCH_CHECK(
      !per_sample_weights_.defined() ||
          input_.sizes() == per_sample_weights_.sizes(),
      "embedding_bag: If per_sample_weights (",
      per_sample_weights_.sizes(),
      ") is not null, then it must have the same shape as the input (",
      input_.sizes(),
      ")");
  if (input_.dim() == 2) {
    TORCH_CHECK(
        !offsets_.defined(),
        "If input is 2D, then offsets has to be null, as input is treated as a mini-batch of fixed length sequences. However, found offsets of type Tensor");
    offsets_ = torch::arange(
        0,
        input_.numel(),
        input_.size(1),
        torch::TensorOptions().dtype(torch::kLong).device(input_.device()));
    input_ = input_.reshape(-1);
    if (per_sample_weights_.defined()) {
      per_sample_weights_ = per_sample_weights_.reshape(-1);
    }
  } else if (input_.dim() == 1) {
    TORCH_CHECK(
        offsets_.defined(), "offsets has to be a 1D Tensor but got null");
    TORCH_CHECK(offsets_.dim() == 1, "offsets has to be a 1D Tensor");
  } else {
    TORCH_CHECK(
        false,
        "input has to be 1D or 2D Tensor, but got Tensor of dimension ",
        input_.dim());
  }

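  // Translate the mode enum into the integer code expected by the underlying
  // torch::embedding_bag op: 0 = sum, 1 = mean, 2 = max.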
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int mode_enum;
  if (std::holds_alternative<enumtype::kSum>(mode)) {
    mode_enum = 0;
  } else if (std::holds_alternative<enumtype::kMean>(mode)) {
    mode_enum = 1;
  } else if (std::holds_alternative<enumtype::kMax>(mode)) {
    mode_enum = 2;
    TORCH_CHECK(
        !scale_grad_by_freq,
        "max mode does not support scaling the gradient by the frequency");
    TORCH_CHECK(!sparse, "max mode does not support sparse weights");
  } else {
    TORCH_CHECK(false, "mode has to be one of sum, mean or max");
  }

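  // As in embedding() above: renormalize the referenced weight rows in place
  // before the lookup.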
  if (max_norm != std::nullopt) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    _no_grad_embedding_renorm_(weight, input_, *max_norm, norm_type);
  }

  TORCH_CHECK(
      !per_sample_weights_.defined() || std::get_if<enumtype::kSum>(&mode),
      "embedding_bag: per_sample_weights was not null. ",
      "per_sample_weights is only supported for mode='kSum' (got mode='",
      torch::enumtype::get_enum_name(mode),
      "'). Please open a feature request on GitHub.");

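  // torch::embedding_bag returns a tuple of (output, offset2bag, bag_size,
  // max_indices); only the reduced output is exposed by this functional.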
  return std::get<0>(torch::embedding_bag(
      weight,
      input_,
      offsets_,
      scale_grad_by_freq,
      mode_enum,
      sparse,
      per_sample_weights_,
      include_last_offset,
      padding_idx));
}
} // namespace detail
#endif /* DOXYGEN_SHOULD_SKIP_THIS */

/// See
/// https://pytorch.org/docs/main/nn.functional.html#torch.nn.functional.embedding_bag
/// about the exact behavior of this functional.
///
/// See the documentation for `torch::nn::functional::EmbeddingBagFuncOptions`
/// class to learn what optional arguments are supported for this functional.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::embedding_bag(input, weight,
/// F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets));
/// ```
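///
/// A fuller sketch (shapes and values below are illustrative, not part of the
/// API):
/// ```
/// auto weight = torch::randn({10, 3});
/// auto input = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9});
/// auto offsets = torch::tensor({0, 4}); // two bags: indices 0-3 and 4-7
/// auto output = F::embedding_bag(input, weight,
///     F::EmbeddingBagFuncOptions().mode(torch::kSum).offsets(offsets));
/// // output has shape [2, 3]: one summed embedding per bag
/// ```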
inline Tensor embedding_bag(
    const Tensor& input,
    const Tensor& weight,
    const EmbeddingBagFuncOptions& options = {}) {
  return detail::embedding_bag(
      input,
      weight,
      options.offsets(),
      options.max_norm(),
      options.norm_type(),
      options.scale_grad_by_freq(),
      options.mode(),
      options.sparse(),
      options.per_sample_weights(),
      options.include_last_offset(),
      options.padding_idx());
}

} // namespace functional
} // namespace nn
} // namespace torch