/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/tpu_embedding_load_retrieve_ops.h"

#include <stddef.h>

#include <array>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"
#include "tensorflow/core/tpu/ops/tpu_embedding_shape_util.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite.h"
#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"

using tensorflow::tpu::TPUEmbeddingConfiguration;
using tensorflow::tpu::TpuEmbeddingShapeUtil;

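// These kernels load all embedding table shards (and their optimizer state)
// into the TPU embedding engine in one call, and retrieve them back into host
// tensors, using the TpuEmbeddingEngine C API.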
namespace tensorflow {

// Computes (and VLOGs) the expected shapes of the embedding table shards.
Status ComputeExpectedTableShardShapes(const TPUEmbeddingConfiguration& config,
                                       int shard_id, int num_shards,
                                       const string& op_name,
                                       std::vector<TensorShape>* table_shapes) {
  std::vector<TensorShapeProto> shape_protos;
  const int num_tables = config.table_descriptor_size();
  TF_RETURN_IF_ERROR(TpuEmbeddingShapeUtil::ComputeTableShapes(
      config, shard_id, num_shards, &shape_protos));
  if (num_tables != shape_protos.size()) {
    return errors::InvalidArgument(
        op_name, ": The size of the shape_protos vector ", shape_protos.size(),
        " must be the same as the number of tables ", num_tables);
  }
  for (int table_id = 0; table_id < num_tables; ++table_id) {
    const TensorShape shape(shape_protos[table_id]);
    table_shapes->push_back(shape);

    const auto& table_descriptor = config.table_descriptor(table_id);
    VLOG(1) << "Table " << table_id << " (name " << table_descriptor.name()
            << ") has shape: " << shape.DebugString()
            << " on shard: " << shard_id << " (of " << num_shards << ").";
  }

  return OkStatus();
}

// Logs min/max/avg for the specified state_variable array.
void LogRangeStatistics(int32 table_id, int32 state_variable_index,
                        absl::Span<const float> state_variable) {
  if (VLOG_IS_ON(5)) {
    float min = std::numeric_limits<float>::infinity();
    float max = -std::numeric_limits<float>::infinity();
    double avg = 0.0;
    for (int elt = 0; elt < state_variable.size(); ++elt) {
      if (state_variable[elt] < min) min = state_variable[elt];
      if (state_variable[elt] > max) max = state_variable[elt];
      avg += state_variable[elt];
    }
    LOG(INFO) << "Table " << table_id << " state_variable "
              << state_variable_index << " min " << min << " max " << max
              << " avg " << avg / state_variable.size() << " total elts "
              << state_variable.size();
  }
}

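// Parses the TPUEmbeddingConfiguration from the "config" attr and precomputes
// the expected per-shard table shapes used to validate the op's inputs.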
LoadAllTPUEmbeddingParametersOp::LoadAllTPUEmbeddingParametersOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  string config_string;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &config_string));

  OP_REQUIRES(
      ctx, config_.ParseFromString(config_string),
      errors::InvalidArgument("LoadAllTPUEmbeddingParametersOp: Failed to "
                              "parse TPUEmbeddingConfiguration "
                              "proto from config attr"));

  // Auto-populate the feature descriptor.
  // TODO(b/201806244): Remove this logic after the change to the
  // initialization of the config proto.
  OP_REQUIRES_OK(ctx, PopulateMissingFieldsInTPUEmbeddingConfig(&config_));

  int num_shards;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("num_shards", &num_shards));
  int shard_id;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shard_id", &shard_id));

  OP_REQUIRES_OK(ctx, ComputeExpectedTableShardShapes(
                          config_, shard_id, num_shards,
                          "LoadAllTPUEmbeddingParametersOp", &table_shapes_));
}

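// Collects the input tensors for the embedding parameters and all auxiliary
// optimizer state, validates their shapes and element counts against the
// config, and exposes each table's data as read-only spans in
// state_variable_vector.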
void LoadAllTPUEmbeddingParametersOp::GetStateVariables(
    OpKernelContext* ctx,
    std::array<std::vector<absl::Span<const float>>,
               tpu::kMaxAuxiliaryParameterCount + 1>& state_variable_vector) {
  std::array<OpInputList, tpu::kMaxAuxiliaryParameterCount + 1>
      state_variable;
  OP_REQUIRES_OK(ctx, ctx->input_list("parameters", &state_variable[0]));
  for (int i = 1; i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
    OP_REQUIRES_OK(ctx, ctx->input_list(absl::StrCat("auxiliary", i),
                                        &state_variable[i]));
  }
  const int num_tables = state_variable[0].size();
  // This should be enforced by TensorFlow's type system.
  for (int i = 1; i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
    CHECK_EQ(num_tables, state_variable[i].size());
  }

  OP_REQUIRES(ctx, num_tables == table_shapes_.size(),
              errors::InvalidArgument(
                  "LoadAllTPUEmbeddingParametersOp has ", num_tables,
                  " inputs in lists but config specifies ",
                  table_shapes_.size(), " embedding tables."));

  CHECK_EQ(num_tables, config_.table_descriptor_size());
  for (int table_id = 0; table_id < num_tables; ++table_id) {
    const auto& table_descriptor = config_.table_descriptor(table_id);
    std::vector<tpu::StateVariableSpecification> state_variable_specs;
    Status status = tpu::GetOptimizationAlgorithmStateVariables(
        table_descriptor.optimization_parameters(), &state_variable_specs);
    OP_REQUIRES(ctx, status.ok(),
                errors::InvalidArgument(
                    "LoadAllTPUEmbeddingParametersOp: No optimization "
                    "algorithm specified for table ",
                    table_id, " (named ", table_descriptor.name(), ")"));
    OP_REQUIRES(
        ctx, state_variable[0][table_id].shape() == table_shapes_[table_id],
        errors::InvalidArgument(
            "LoadAllTPUEmbeddingParametersOp: Embeddings for table ",
            table_id, " (named ", table_descriptor.name(), ") have shape ",
            state_variable[0][table_id].shape().DebugString(),
            " but config specifies table shape ",
            table_shapes_[table_id].DebugString()));
    for (int i = 1; i < state_variable_specs.size(); ++i) {
      OP_REQUIRES(
          ctx, state_variable[i][table_id].shape() == table_shapes_[table_id],
          errors::InvalidArgument(
              "LoadAllTPUEmbeddingParametersOp: Auxiliary ", i - 1,
              " for table ", table_id, " has shape ",
              state_variable[i][table_id].shape().DebugString(),
              " but config specifies table shape ",
              table_shapes_[table_id].DebugString()));
    }
    const int64 num_elements = state_variable[0][table_id].NumElements();
    VLOG(1) << "Table " << table_id << " (name " << table_descriptor.name()
            << ") has shape: " << table_shapes_[table_id].DebugString()
            << ", number of elements: " << num_elements;
    for (int i = 0; i < state_variable_specs.size(); ++i) {
      OP_REQUIRES(
          ctx, state_variable[i][table_id].NumElements() == num_elements,
          errors::InvalidArgument(
              "LoadAllTPUEmbeddingParametersOp: Embeddings/auxiliary ", i,
              " for table ", table_id, " has element count ",
              state_variable[i][table_id].NumElements(),
              " but config requires count ", num_elements));
      const float* state_variable_i_ptr =
          state_variable[i][table_id].flat<float>().data();
      state_variable_vector[i].push_back(absl::MakeConstSpan(
          state_variable_i_ptr, static_cast<size_t>(num_elements)));
      LogRangeStatistics(
          table_id, i,
          absl::MakeConstSpan(state_variable_i_ptr, num_elements));
    }
    // Auxiliary slots beyond those used by this table's optimizer must be
    // empty.
    for (int i = state_variable_specs.size();
         i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
      OP_REQUIRES(ctx, state_variable[i][table_id].NumElements() == 0,
                  errors::InvalidArgument(
                      "LoadAllTPUEmbeddingParametersOp: Auxiliary ", i,
                      " for table ", table_id, " has element count ",
                      state_variable[i][table_id].NumElements(),
                      " but config requires empty tensor"));
      state_variable_vector[i].push_back(absl::Span<const float>());
    }
  }
}

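// Wraps the validated input spans in FloatListRef structs and hands them to
// the TPU embedding engine through the C API in a single write covering all
// tables and optimizer slots.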
void LoadAllTPUEmbeddingParametersOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "LoadAllTPUEmbeddingParameters::Compute";

  std::array<std::vector<absl::Span<const float>>,
             tpu::kMaxAuxiliaryParameterCount + 1> state_variable_vector;

  GetStateVariables(ctx, state_variable_vector);
  const int num_tables = state_variable_vector[0].size();

  std::unique_ptr<ApiConverter::TpuEmbeddingEngineParametersData> params =
      ApiConverter::Create(num_tables);
  std::array<std::vector<FloatListRef>,
             tpu::kMaxAuxiliaryParameterCount + 1> params_data;
  for (size_t i = 0; i < tpu::kMaxAuxiliaryParameterCount + 1; i++) {
    params_data[i] = std::vector<FloatListRef>(num_tables);
    for (size_t table_id = 0; table_id < num_tables; table_id++) {
      params->c_params.parameters[i][table_id] = &(params_data[i][table_id]);
      params->c_params.parameters[i][table_id]->size =
          state_variable_vector[i][table_id].size();
      params->c_params.parameters[i][table_id]->ptr =
          const_cast<float*>(state_variable_vector[i][table_id].data());
    }
  }
  StatusHelper status;
  tpu::OpsApiFn()->TpuEmbeddingEngine_WriteParametersFn(&(params->c_params),
                                                        status.c_status);
  OP_REQUIRES_OK(ctx, status.status());

  VLOG(1) << "LoadAllTPUEmbeddingParameters::Compute done";
}

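// Parses the TPUEmbeddingConfiguration from the "config" attr and precomputes
// the expected per-shard table shapes, mirroring
// LoadAllTPUEmbeddingParametersOp.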
RetrieveAllTPUEmbeddingParametersOp::RetrieveAllTPUEmbeddingParametersOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  string config_string;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("config", &config_string));

  OP_REQUIRES(
      ctx, config_.ParseFromString(config_string),
      errors::InvalidArgument("Failed to parse TPUEmbeddingConfiguration "
                              "proto from config attr"));

  // Auto-populate the feature descriptor.
  // TODO(b/201806244): Remove this logic after the change to the
  // initialization of the config proto.
  OP_REQUIRES_OK(ctx, PopulateMissingFieldsInTPUEmbeddingConfig(&config_));

  int num_shards;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("num_shards", &num_shards));
  int shard_id;
  OP_REQUIRES_OK(ctx, ctx->GetAttr("shard_id", &shard_id));

  OP_REQUIRES_OK(ctx,
                 ComputeExpectedTableShardShapes(
                     config_, shard_id, num_shards,
                     "RetrieveAllTPUEmbeddingParametersOp", &table_shapes_));
}

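// Allocates the op's outputs using the expected per-shard table shapes and
// records writable spans over them; auxiliary slots beyond those used by a
// table's optimizer get empty 2-D outputs so every output list stays aligned
// with the table count.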
void RetrieveAllTPUEmbeddingParametersOp::GetStateVariables(
    OpKernelContext* ctx,
    std::array<std::vector<absl::Span<float>>,
               tpu::kMaxAuxiliaryParameterCount + 1>& state_variable_vector,
    std::vector<int>& num_state_variables) {
  std::array<OpOutputList, tpu::kMaxAuxiliaryParameterCount + 1>
      state_variable;
  OP_REQUIRES_OK(ctx, ctx->output_list("parameters", &state_variable[0]));
  for (int i = 1; i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
    OP_REQUIRES_OK(ctx, ctx->output_list(absl::StrCat("auxiliary", i),
                                         &state_variable[i]));
  }
  const int num_tables = state_variable[0].size();
  // This should be enforced by TensorFlow's type system.
  for (int i = 1; i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
    CHECK_EQ(num_tables, state_variable[i].size());
  }

  OP_REQUIRES(ctx, num_tables == table_shapes_.size(),
              errors::InvalidArgument(
                  "RetrieveAllTPUEmbeddingParametersOp has ", num_tables,
                  " outputs in lists but config specifies ",
                  table_shapes_.size(), " embedding tables."));

  for (auto& v : state_variable_vector) {
    v.resize(num_tables);
  }
  num_state_variables.resize(num_tables);

  // Get the locations to write the returned data to.
  for (int table_id = 0; table_id < num_tables; ++table_id) {
    const auto& table_descriptor = config_.table_descriptor(table_id);

    std::vector<tpu::StateVariableSpecification> state_variable_specs;
    Status status = tpu::GetOptimizationAlgorithmStateVariables(
        table_descriptor.optimization_parameters(), &state_variable_specs);
    OP_REQUIRES(
        ctx, status.ok(),
        errors::InvalidArgument("RetrieveAllTPUEmbeddingParametersOp: No "
                                "optimization algorithm specified for table ",
                                table_id));
    num_state_variables[table_id] = state_variable_specs.size();
    const int64 num_elements = table_shapes_[table_id].num_elements();
    for (int i = 0; i < state_variable_specs.size(); ++i) {
      Tensor* state_variable_tensor;
      OP_REQUIRES_OK(
          ctx, state_variable[i].allocate(table_id, table_shapes_[table_id],
                                          &state_variable_tensor));
      float* state_variable_ptr = state_variable_tensor->flat<float>().data();
      state_variable_vector[i][table_id] =
          absl::MakeSpan(state_variable_ptr, num_elements);
    }
    // Fill in the auxiliary values after the number actually used for
    // table_id with empty 2-D tensors.
    for (int i = state_variable_specs.size();
         i <= tpu::kMaxAuxiliaryParameterCount; ++i) {
      Tensor* auxiliary_tensor;
      TensorShape shape;
      std::array<int32, 2> dims = {{0, 0}};
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(dims, &shape));
      OP_REQUIRES_OK(ctx, state_variable[i].allocate(table_id, shape,
                                                     &auxiliary_tensor));
      state_variable_vector[i][table_id] = absl::Span<float>();
    }
  }
}

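// Allocates the outputs, then asks the TPU embedding engine to copy the
// current parameters and optimizer state into them through the C API.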
void RetrieveAllTPUEmbeddingParametersOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "RetrieveAllTPUEmbeddingParameters::Compute";

  std::array<std::vector<absl::Span<float>>,
             tpu::kMaxAuxiliaryParameterCount + 1> state_variable_vector;
  std::vector<int> num_state_variables;

  GetStateVariables(ctx, state_variable_vector, num_state_variables);
  const int num_tables = state_variable_vector[0].size();

  std::unique_ptr<ApiConverter::TpuEmbeddingEngineParametersData> params =
      ApiConverter::Create(num_tables);
  std::array<std::vector<FloatListRef>,
             tpu::kMaxAuxiliaryParameterCount + 1> params_data;
  for (size_t i = 0; i < tpu::kMaxAuxiliaryParameterCount + 1; i++) {
    params_data[i] = std::vector<FloatListRef>(num_tables);
    for (size_t table_id = 0; table_id < num_tables; table_id++) {
      params->c_params.parameters[i][table_id] = &(params_data[i][table_id]);
      params->c_params.parameters[i][table_id]->size =
          state_variable_vector[i][table_id].size();
      params->c_params.parameters[i][table_id]->ptr =
          state_variable_vector[i][table_id].data();
    }
  }
  StatusHelper status;
  tpu::OpsApiFn()->TpuEmbeddingEngine_ReadParametersFn(&(params->c_params),
                                                       status.c_status);
  OP_REQUIRES_OK(ctx, status.status());

  if (VLOG_IS_ON(5)) {
    for (int table_id = 0; table_id < num_tables; ++table_id) {
      for (int i = 0; i < num_state_variables[table_id]; ++i) {
        LogRangeStatistics(table_id, i, state_variable_vector[i][table_id]);
      }
    }
  }
}

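// The kernels run on the CPU of the TPU host; the LIBTPU_ON_GCE guard limits
// registration to libtpu (GCE) builds.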
#ifdef LIBTPU_ON_GCE

REGISTER_KERNEL_BUILDER(
    Name("LoadAllTPUEmbeddingParameters").Device(DEVICE_CPU),
    LoadAllTPUEmbeddingParametersOp);
REGISTER_KERNEL_BUILDER(
    Name("RetrieveAllTPUEmbeddingParameters").Device(DEVICE_CPU),
    RetrieveAllTPUEmbeddingParametersOp);

#endif  // LIBTPU_ON_GCE

}  // namespace tensorflow