xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/tpu_ops_c_api.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_
16 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_
17 
18 #include <stddef.h>
19 
20 #include <cstdint>
21 
22 #include "absl/types/optional.h"
23 #include "tensorflow/c/tf_tensor.h"
24 #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h"
25 #include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h"
26 #include "tensorflow/core/tpu/libtftpu.h"
27 
// Forward declaration only: TpuSerializedProto is not defined in this header
// (presumably it comes from one of the includes above — TODO confirm). The
// declarations below use it both by value and by pointer.
28 typedef struct TpuSerializedProto TpuSerializedProto;
29 
// Forward declarations of C++ classes that the declarations below refer to by
// pointer only: `ResourceMgr` (CompilationCacheKeyProperty::resource_mgr) and
// `TpuMeshCommonState` (WaitForDistributedTpuOp_DoWork_Params). No declaration
// in this header refers to `tensorflow::TpuEmbeddingEngineState`.
30 namespace tensorflow {
31 
32 class TpuMeshCommonState;
33 class TpuEmbeddingEngineState;
34 class ResourceMgr;
35 
36 }  // namespace tensorflow
37 
38 extern "C" {
39 
// Opaque TPU program handle; only forward-declared in this header.
40 typedef struct XLA_TpuProgram XLA_TpuProgram;
41 
42 // Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj.
// `kInvalid` is explicitly 0; the remaining enumerators take the consecutive
// values 1..3. Used as the `type` argument of `TpuProgram_GetTpuProgram`.
43 enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };
44 
// The four structs below share one shape: a `bytes` pointer plus a `size`.
// Presumably `size` is the length in bytes of the buffer at `bytes`; buffer
// ownership is not stated in this header — TODO confirm.

// Returned by value from `TpuProgram_GetFingerprint` and accepted by
// `TpuProgram_DestroyFingerprint`.
45 struct TpuProgramFingerprint {
46   const char* bytes;
47   size_t size;
48 };
49 
// Output argument of `TpuProgram_SerializeTpuExecutable`.
50 struct TpuExecutableSerializedProto {
51   const char* bytes;
52   size_t size;
53 };
54 
// Output argument of `TpuProgram_SerializeCompilerMetadata`.
55 struct CompilerMetadataSerializedProto {
56   const char* bytes;
57   size_t size;
58 };
59 
// Not referenced by any other declaration in this header.
60 struct HostComputeMetadataSerializedProto {
61   const char* bytes;
62   size_t size;
63 };
64 
// Opaque handle types. They are only forward-declared here, so users of this
// header can hold pointers to them but cannot inspect their contents.
65 typedef struct XLA_TpuMeshState XLA_TpuMeshState;
66 
67 typedef struct XLA_TpuEmbeddingEngineState XLA_TpuEmbeddingEngineState;
68 
69 typedef struct TpuEmbedding_TensorBatchFixedState
70     TpuEmbedding_TensorBatchFixedState;
71 
72 typedef struct TpuProfiler TpuProfiler;
73 
// Pointer/size pair used as the `device_assignment` field of
// `TpuExecutable_LoadProgramAndEnqueueToStream_Params`.
// NOTE(review): presumably a serialized device assignment of `size` bytes;
// the encoding is not stated in this header — confirm.
74 typedef struct XLA_DeviceAssignment {
75   const char* bytes;
76   size_t size;
77 } XLA_DeviceAssignment;
78 
79 // Property for creating compilation cache key.
// Passed by value to `TpuCompile_CreateCompilationCacheKey`.
// NOTE(review): `device_ids_size` is presumably the element count of
// `device_ids`; the string fields are presumably NUL-terminated since no
// lengths accompany them — confirm against the implementation.
80 struct CompilationCacheKeyProperty {
81   const char* config_prefix;
82   const char* shapes_prefix;
83   const char* function_name;
84   uint64_t mlir_module_fingerprint;
85   const int32_t* device_ids;
86   size_t device_ids_size;
87   int32_t guaranteed_constants_size;
88   uint64_t function_library_fingerprint;
89   int32_t num_cores_per_replica;
90   int32_t num_replicas;
91   const XLA_TpuMeshState* mesh_state;
92   uint64_t session_id;
93   tensorflow::ResourceMgr* resource_mgr;
94 };
95 
96 // Compilation cache key result returning both the key and a more verbose debug
97 // version. Returned by value from `TpuCompile_CreateCompilationCacheKey` and
// accepted by `TpuCompile_DestroyCompilationCacheKey`.
98 struct CompilationCacheKeyResult {
99   const char* key;
100   const char* debug_string;
101 };
102 
// Opaque handle used by the `TpuNodeContext_*` functions.
103 typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;
104 
// Opaque handle used by the `TfTpuOrdinalSelector_*` functions. Note that the
// struct tag (`TfTpu_OrdinalSelector`) and the typedef alias
// (`TfTpuOrdinalSelector`) are spelled differently.
105 typedef struct TfTpu_OrdinalSelector TfTpuOrdinalSelector;
106 
// Parameter set passed by pointer to `TfTpu_GetTpuPartitionedCallParams`.
// The first four fields are not documented in this header.
107 struct TpuPartitionedCall_Params {
108   bool input_shape_opt;
109   bool group_tensors_for_packing;
110   int32_t minimum_input_tensors_packing;
111   int32_t minimum_output_tensors_packing;
112 
113   // Whether to attempt to automatically shard inputs by adding an
114   // XlaSharding op after each input.
115   bool enable_auto_xla_input_sharding;
116 
117   // The dimension of each input to shard if
118   // enable_auto_xla_input_sharding is set to true. Negative numbers are
119   // allowed and refer to dimensions starting from the end.
120   int32_t auto_xla_input_sharding_dim;
121 
122   // If true, only create one variable on the TPU for each variable on the CPU.
123   bool enable_variable_deduplication;
124 };
125 
126 // Compiles Mlir or TF function computation by lowering into HLO IR and returns
127 // `count` number of TPU programs ready for execution.
128 // The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates
129 // `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is
130 // responsible for deallocating both the `XLA_TpuProgram*[]` array and the
131 // `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free`
132 // API respectively.
// The request is passed by value as a `TpuSerializedProto`; `tpu_programs`,
// `count` and `status` are pointer arguments.
133 TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild(
134     TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state,
135     XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);
136 
137 // Compiles a HLO IR and returns `count` number of TPU programs ready for
138 // execution. The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and
139 // creates `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller
140 // is responsible for deallocating both the `XLA_TpuProgram*[]` array and the
141 // `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free`
142 // API respectively.
// Same parameter shape as `TpuCompile_CompileAndBuild`, except that the input
// proto is named `xrt_computation`.
143 TFTPU_CAPI_EXPORT void TpuCompile_XrtCompileAndBuild(
144     TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state,
145     XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);
146 
147 // Creates a TPU profiler that is ready to start profiling.
// The function returns void and takes `TpuProfiler**`; presumably the new
// handle is stored through `tpu_profiler` — TODO confirm.
148 TFTPU_CAPI_EXPORT void TpuProfiler_Create(TpuProfiler** tpu_profiler,
149                                           TF_Status* status);
150 // Destroys the given TPU profiler.
151 TFTPU_CAPI_EXPORT void TpuProfiler_Destroy(TpuProfiler* tpu_profiler);
152 // Starts profiling if not already started, returns an error otherwise.
// The function returns void, so the error is presumably reported via `status`.
153 TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler,
154                                          TF_Status* status);
155 // Stops profiling if not already stopped, returns an error otherwise.
// The function returns void, so the error is presumably reported via `status`.
156 TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler,
157                                         TF_Status* status);
158 // Serializes profiled data into `buffer` and reports its size through
159 // `size_in_bytes` (the function itself returns void). The profile data held
// by the TPU driver will be cleared after retrieval.
160 //
161 // Step 1. Query the size of buffer required into `size_in_bytes`.
162 //
163 //   size_t size_in_bytes;
164 //   TpuProfiler_CollectData(profiler, status, nullptr, &size_in_bytes);
165 //
166 // Step 2. Retrieve the data into a `buffer` of size `size_in_bytes`.
167 //         Subsequently, the TPU driver clears its copy of the profile data.
168 //
169 //   uint8_t* buffer = new uint8_t[size_in_bytes];
170 //   TpuProfiler_CollectData(profiler, status, buffer, &size_in_bytes);
171 //
172 // Step 3. Unpack the data into an XSpace.
173 //
174 //   tensorflow::profiler::XSpace space;
175 //   space.ParseFromArray(buffer, size_in_bytes);
176 //
177 TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(TpuProfiler* tpu_profiler,
178                                                TF_Status* status,
179                                                uint8_t* buffer,
180                                                size_t* size_in_bytes);
181 
182 // Creates a new TPU mesh state object.
183 TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();
184 
185 // Deletes the given TPU `mesh_state` object. Once deleted the object is
186 // unusable.
187 TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);
188 
189 // Returns a pointer to an opaque mesh data structure used internally.
// NOTE(review): the result is an untyped `void*`; presumably it points to the
// `tensorflow::TpuMeshCommonState` forward-declared in this header — confirm.
190 TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
191     XLA_TpuMeshState* mesh_state);
192 
193 // Creates a new TPU embedding engine state object.
194 TFTPU_CAPI_EXPORT XLA_TpuEmbeddingEngineState* TpuEmbeddingEngineState_Create();
195 
196 // Deletes the given TPU embedding engine state object. Once deleted the object
197 // is unusable.
198 TFTPU_CAPI_EXPORT void TpuEmbeddingEngineState_Free(
199     XLA_TpuEmbeddingEngineState* engine_state);
200 
201 // Returns a pointer to an opaque embedding engine state data structure used
202 // internally. The result is an untyped `void*`.
203 TFTPU_CAPI_EXPORT void* TpuEmbeddingEngineState_GetState(
204     XLA_TpuEmbeddingEngineState* engine_state);
205 
// Ordinal selector API. None of these functions take a `TF_Status*`, and all
// return void.

// Takes `TfTpuOrdinalSelector**`; presumably the created selector is stored
// through `ordinal_selector` — TODO confirm.
206 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Create(
207     TfTpuOrdinalSelector** ordinal_selector, int num_cores_per_replica);
208 
209 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Destroy(
210     TfTpuOrdinalSelector* ordinal_selector);
211 
// `req_id` and `ordinal` are pointer arguments, presumably outputs — TODO
// confirm.
// NOTE(review): `key` is a `std::optional<uint64_t>`, a C++ type, although
// this declaration sits inside `extern "C"`; it cannot be called from C. This
// header includes "absl/types/optional.h" but not <optional> directly.
212 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_GetOrdinal(
213     TfTpuOrdinalSelector* ordinal_selector, std::optional<uint64_t> key,
214     int64_t* req_id, int64_t* ordinal);
215 
216 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_DequeueFromCoreSelector(
217     TfTpuOrdinalSelector* ordinal_selector, int32_t device_ordinal,
218     int64_t req_id);
219 
// Takes a pointer to a `TpuPartitionedCall_Params` and returns void;
// presumably the struct is filled in by the callee — TODO confirm.
220 TFTPU_CAPI_EXPORT void TfTpu_GetTpuPartitionedCallParams(
221     TpuPartitionedCall_Params* params);
222 
// Argument bundle for `TpuExecutable_LoadProgramAndEnqueueToStream`.
// NOTE(review): `struct_size` presumably carries the `_SIZE` macro value
// defined below, and `priv` is presumably reserved — confirm with the
// implementation. `arguments_len` presumably counts the entries of `arguments`.
223 typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params {
224   int32_t struct_size;
225   void* priv;
226 
227   const XLA_TpuProgram* program;
228   SE_DeviceMemoryBase* arguments;
229   size_t arguments_len;
230   SE_DeviceMemoryBase* result;
231   bool has_cross_program_prefetch_addr;
232   SE_DeviceMemoryBase* cross_program_prefetch_addr;
233   int32_t rng_seed;
234   XLA_DeviceAssignment* device_assignment;
235   SE_Stream* stream;
236 
237   TF_Status* status;  // out
238 } TpuExecutable_LoadProgramAndEnqueueToStream_Params;
239 
// sizeof() of the params struct above.
240 #define TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE \
241   (sizeof(struct TpuExecutable_LoadProgramAndEnqueueToStream_Params))
242 
// All inputs and the `status` output travel in `params`.
243 TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
244     TpuExecutable_LoadProgramAndEnqueueToStream_Params* params);
245 
// Shape helpers operating on `XLA_Shape` (declared outside this header).
// NOTE(review): `device_shape` is presumably the output of
// HostShapeToDeviceShape, and the three ShapeSize* variants presumably return
// byte sizes under different layouts; none of this is stated here — confirm.
246 TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
247     XLA_Shape* host_shape, XLA_Shape* device_shape);
248 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
249 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape);
250 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);
251 
// Argument bundle for `TpuExecute_RuntimeInputToPaddedData`.
// NOTE(review): only `status` is marked as an output; going by the names,
// `padded_data_ptr` is presumably written by the callee — confirm. The units
// of the two `*_size` fields (bytes vs. elements) are not stated here.
252 typedef struct TpuExecute_RuntimeInputToPaddedData_Params {
253   int32_t struct_size;
254   void* priv;
255 
256   uint32_t* runtime_input_ptr;
257   size_t runtime_input_size;
258   int8_t* padded_data_ptr;
259   size_t padded_data_size;
260   XLA_Shape* runtime_shape;
261   XLA_Shape* compile_time_shape;
262 
263   TF_Status* status;  // out
264 } TpuExecute_RuntimeInputToPaddedData_Params;
265 
// sizeof() of the params struct above.
266 #define TpuExecute_RuntimeInputToPaddedData_Params_SIZE \
267   (sizeof(struct TpuExecute_RuntimeInputToPaddedData_Params))
268 
269 TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
270     TpuExecute_RuntimeInputToPaddedData_Params* params);
271 
// Argument bundle for `ConfigureDistributedTpuOp_DoWork`. Inputs are two
// (size, pointer) pairs; the fields marked `// out` are outputs.
// NOTE(review): `host_config_output` is a `char**`, which matches the
// argument type of `TpuConfigurationApi_FreeCharArray`; presumably that is
// the matching deallocator — confirm.
272 typedef struct ConfigureDistributedTpuOp_DoWork_Params {
273   int32_t struct_size;
274   void* priv;
275 
276   size_t num_cores_per_host_size;
277   const int32_t* num_cores_per_host;
278   size_t server_address_size;
279   const char* server_address;
280 
281   size_t* host_config_output_size;  // out
282   char** host_config_output;        // out
283   TF_Status* status;                // out
284 } ConfigureDistributedTpuOp_DoWork_Params;
285 
// sizeof() of the params struct above.
286 #define ConfigureDistributedTpuOp_DoWork_Params_SIZE \
287   (sizeof(struct ConfigureDistributedTpuOp_DoWork_Params))
288 
289 TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
290     ConfigureDistributedTpuOp_DoWork_Params* params);
291 
// Argument bundle for `WaitForDistributedTpuOp_DoWork`. The fields marked
// `// out` are outputs.
// NOTE(review): `host_ordinal_to_global_core_id_map` is presumably an array of
// `num_hosts` pointers, each to `num_cores_per_host` ids — confirm.
292 typedef struct WaitForDistributedTpuOp_DoWork_Params {
293   int32_t struct_size;
294   void* priv;
295 
296   size_t num_hosts;
297   size_t num_cores_per_host;
298   const int32_t** host_ordinal_to_global_core_id_map;
299   tensorflow::TpuMeshCommonState* tpu_mesh_common_state;
300 
301   size_t* tpu_topology_output_size;  // out
302   char** tpu_topology_output;        // out
303   TF_Status* status;                 // out
304 } WaitForDistributedTpuOp_DoWork_Params;
305 
// sizeof() of the params struct above.
306 #define WaitForDistributedTpuOp_DoWork_Params_SIZE \
307   (sizeof(struct WaitForDistributedTpuOp_DoWork_Params))
308 
309 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
310     WaitForDistributedTpuOp_DoWork_Params* params);
311 
// Argument bundle for `InitializeHostForDistributedTpuOp_DoWork`. The fields
// marked `// out` are outputs.
// NOTE(review): `core_id_output` is an `int32_t**`, which matches the argument
// type of `TpuConfigurationApi_FreeInt32Array`; presumably that is the
// matching deallocator — confirm.
312 typedef struct InitializeHostForDistributedTpuOp_DoWork_Params {
313   int32_t struct_size;
314   void* priv;
315 
316   size_t tpu_host_config_size;
317   const char* tpu_host_config;
318   bool enable_whole_mesh_compilations;
319   bool is_master_worker;
320 
321   size_t* core_id_output_size;  // out
322   int32_t** core_id_output;     // out
323   TF_Status* status;            // out
324 } InitializeHostForDistributedTpuOp_DoWork_Params;
325 
// sizeof() of the params struct above.
326 #define InitializeHostForDistributedTpuOp_DoWork_Params_SIZE \
327   (sizeof(struct InitializeHostForDistributedTpuOp_DoWork_Params))
328 
329 TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
330     InitializeHostForDistributedTpuOp_DoWork_Params* params);
331 
// Unlike the neighbouring *_DoWork entry points, these two take their
// arguments directly rather than through a *_Params struct.

// Takes a topology buffer as a (size, pointer) pair plus a `status` pointer.
332 TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
333     const size_t tpu_topology_size, const char* tpu_topology,
334     TF_Status* status);
335 
// `number_of_chips_output` is a pointer argument, presumably an output as its
// name suggests — TODO confirm.
336 TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
337     int32_t* number_of_chips_output, TF_Status* status);
338 
// Deallocators for `char*` / `int32_t*` arrays.
// NOTE(review): presumably meant for the `char**` / `int32_t**` outputs of the
// *_DoWork and TpuConfigurationApi_* calls in this header — confirm.
339 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
340 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);
341 
// Takes no arguments and returns a bool.
342 TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();
343 
// The three queries below return void and take a pointer to the value
// (`tpus`, `memory_limit`, `cache_size_in_bytes`), presumably written by the
// callee. The last one takes no `TF_Status*`.
344 TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus,
345                                                        TF_Status* status);
346 TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
347                                                           TF_Status* status);
348 TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
349     int64_t* cache_size_in_bytes);
350 
// Argument bundle for
// `TpuConfigurationApi_CompilationCacheServerAddressFromConfig`. The fields
// marked `// out` are outputs.
// Naming caveat: the struct tag is spelled "...ServerAddrFromConfig_Params"
// while the typedef alias is spelled "...ServerAddressFromConfig_Params". The
// `_SIZE` macro and the function parameter below use the tag spelling.
351 typedef struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params {
352   int32_t struct_size;
353   void* priv;
354 
355   size_t tpu_host_config_size;
356   const char* tpu_host_config;
357 
358   size_t* server_address_output_size;  // out
359   char** server_address_output;        // out
360   TF_Status* status;                   // out
361 } TpuConfigurationApi_CompilationCacheServerAddressFromConfig_Params;
362 
// sizeof() of the params struct above (referenced by its tag).
363 #define TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE \
364   (sizeof(                                                                   \
365       struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params))
366 
367 TFTPU_CAPI_EXPORT
368 void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
369     TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params* params);
370 
// Argument bundle for `TpuConfigurationApi_GetServerAddressAndPort`. Apart
// from `struct_size` and `priv`, every field is marked as an output.
371 typedef struct TpuConfigurationApi_GetServerAddressAndPort_Params {
372   int32_t struct_size;
373   void* priv;
374 
375   size_t* server_address_output_size;  // out
376   char** server_address_output;        // out
377   int* port_output;                    // out
378   TF_Status* status;                   // out
379 } TpuConfigurationApi_GetServerAddressAndPort_Params;
380 
// sizeof() of the params struct above.
381 #define TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE \
382   (sizeof(struct TpuConfigurationApi_GetServerAddressAndPort_Params))
383 
384 TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
385     TpuConfigurationApi_GetServerAddressAndPort_Params* params);
386 
387 // Creates a new TPU program.
388 TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New();
389 
390 // Destroys the `tpu_program`.
391 TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program);
392 
393 // Creates an array of `XLA_TpuProgram*` sized by `count`.
394 TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count);
395 
396 // Destroys an array of `XLA_TpuProgram*`.
// NOTE(review): no element count is passed, so this presumably frees only the
// array and not the programs in it (the TpuCompile_CompileAndBuild comment
// lists `TpuProgram_Free` separately for the objects) — confirm.
397 TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]);
398 
399 // Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
400 // destroyed, it is in an unusable state.
401 TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
402                                                    TF_Status* status);
403 
404 // Gets TPU program size in bytes from the `tpu_program`.
405 TFTPU_CAPI_EXPORT int64_t
406 TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);
407 
408 // Logs the summary of current memory state snapshot of the `tpu_program`.
// The meaning of the bool result is not documented here (presumably success).
409 TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary(
410     const XLA_TpuProgram* tpu_program);
411 
// The three proto getters below share one shape: a const program, a
// `TpuSerializedProto*` and a `TF_Status*`, returning void. Presumably the
// proto is written through the pointer — TODO confirm.

412 // Gets TPU program executable info from the `tpu_program`.
413 TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo(
414     const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info,
415     TF_Status* status);
416 
417 // Gets host transfer info proto.
418 TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo(
419     const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info,
420     TF_Status* status);
421 
422 // Gets HLO metadata proto.
423 TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
424     const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata,
425     TF_Status* status);
426 
427 // Gets the "may modify variables" boolean value. Takes a `bool*` and no
// `TF_Status*`.
428 TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
429     const XLA_TpuProgram* tpu_program, bool* may_modify_variables);
430 
431 // Checks if TPU program has sharding.
432 TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
433     const XLA_TpuProgram* tpu_program);
434 
435 // Gets TPU program by sharding type. Return value is valid only when the
436 // `status.status()` returns `OK`.
// NOTE(review): this declaration takes no status argument, so it is unclear
// which status the sentence above refers to; the comment may be stale.
437 TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
438     XLA_TpuProgram* tpu_program, TpuProgramShardingType type);
439 
440 // Gets TPU executable proto from a `tpu_program`.
// The result type is `TpuExecutableSerializedProto` (a bytes/size pair).
441 TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
442     const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
443     TF_Status* status);
444 
445 // Gets compilation metadata proto from a `tpu_program`.
// The result type is `CompilerMetadataSerializedProto` (a bytes/size pair).
446 TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
447     const XLA_TpuProgram* tpu_program,
448     CompilerMetadataSerializedProto* compiler_metadata, TF_Status* status);
449 
450 // Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
// The proto is passed by value; `tpu_program` is a non-const pointer to an
// existing program object.
451 TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
452     TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
453     TF_Status* status);
454 
// Returns the fingerprint of `tpu_program` by value, as a bytes/size pair.
455 TFTPU_CAPI_EXPORT TpuProgramFingerprint
456 TpuProgram_GetFingerprint(const XLA_TpuProgram* tpu_program);
457 
// Takes a `TpuProgramFingerprint` by value; presumably releases the buffer
// obtained from `TpuProgram_GetFingerprint` — TODO confirm.
458 TFTPU_CAPI_EXPORT void TpuProgram_DestroyFingerprint(
459     TpuProgramFingerprint fingerprint);
460 
461 // Checks whether TPU compilation is enabled.
462 TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled();
463 
464 // XLA compilation cannot be cancelled. To avoid hanging, the TF worker will
465 // exit when cancellation is requested for an XLA compile op. Some tests
466 // require this behavior to be disabled, and we test for this condition with
467 // the following flag function.
468 TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();
469 
470 // Returns the number of available TPU cores of the given `tpu_core_type`.
// `TpuCoreTypeEnum` is declared outside this header.
471 TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
472     const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);
473 
474 // Recycles an unused service port.
475 TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port);
476 
477 // Creates a unique compilation cache `key` used for `put` and `get` operations.
478 // Returned buffers are heap-allocated and owned by the caller; pass the result
// to `TpuCompile_DestroyCompilationCacheKey` when done.
479 TFTPU_CAPI_EXPORT CompilationCacheKeyResult
480 TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property);
481 
482 // Destroys the CompilationCacheKeyResult returned by calling the
483 // `TpuCompile_CreateCompilationCacheKey` API.
484 TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey(
485     CompilationCacheKeyResult result);
486 
487 // Creates a guaranteed const fingerprint. Guarantee const is normally used in
488 // TPU inference to avoid re-copying unchanged variables onto the TPU device.
489 // It promises the value is identical for every execution in the same session
490 // even if the actual value changes in later executions.
// Takes an existing `fingerprint` plus a (`data`, `size`) buffer and returns a
// uint64_t; presumably the buffer is folded into the given fingerprint — TODO
// confirm.
491 TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
492     uint64_t fingerprint, const char* data, size_t size);
493 
// TPU node context API. These functions are not documented in this header.
// NOTE(review): unlike every other function declared here, these six are not
// marked TFTPU_CAPI_EXPORT, although they do appear in `TfTpu_OpsApiFn`;
// confirm whether the omission is intentional.

// Returns a node context handle for `device_ordinal`.
494 XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
495                                           TF_Status* status);
// Takes a node context handle; presumably releases one returned by
// `TpuNodeContext_Create` — TODO confirm.
496 void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
497 
498 void TpuNodeContext_StopChipHeartbeats(TF_Status* status);
499 
500 void TpuNodeContext_CloseTpuHost(TF_Status* status);
501 
502 void TpuNodeContext_Initialize(int device_ordinal, TF_Status* status);
503 
504 bool TpuNodeContext_CompactionSupported(int device_ordinal);
505 
506 // Globally initializes the TPU system for inference. Takes no arguments and
// reports no status.
507 TFTPU_CAPI_EXPORT void TfTpu_InitializeTpuModelServer();
508 
// Argument bundle for `TpuEmbeddingEngine_ExecutePartitioner`. The embedding
// config is passed by value as a `TpuSerializedProto`; the fields after
// `// out` are outputs. Unlike the earlier *_Params structs, this header
// defines no `_SIZE` macro for this struct.
509 typedef struct TpuEmbeddingEngine_ExecutePartitioner_Params {
510   int32_t struct_size;
511   void* priv;
512   TpuSerializedProto tpu_embedding_config;
513 
514   // out
515   size_t* common_config_size;
516   char** common_config;
517   TF_Status* status;
518 } TpuEmbeddingEngine_ExecutePartitioner_Params;
519 
520 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ExecutePartitioner(
521     TpuEmbeddingEngine_ExecutePartitioner_Params* params);
522 
// Argument bundle for `TpuEmbeddingEngine_ConfigureMemory`. Inputs are
// `num_inputs` and a `common_config` (size, pointer) pair; the fields after
// `// out` are outputs. No `_SIZE` macro is defined for this struct here.
523 typedef struct TpuEmbeddingEngine_ConfigureMemory_Params {
524   int32_t struct_size;
525   void* priv;
526 
527   int num_inputs;
528   size_t common_config_size;
529   const char* common_config;
530 
531   // out
532   size_t* memory_config_size;
533   char** memory_config;
534   TF_Status* status;
535 } TpuEmbeddingEngine_ConfigureMemory_Params;
536 
537 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConfigureMemory(
538     TpuEmbeddingEngine_ConfigureMemory_Params* params);
539 
// Argument bundle for `TpuEmbeddingEngine_CollateMemory`. The fields after
// `// out` are outputs.
// NOTE(review): `memory_configs_size` is presumably the number of
// `TpuSerializedProto` entries at `memory_configs` — confirm.
540 typedef struct TpuEmbeddingEngine_CollateMemory_Params {
541   int32_t struct_size;
542   void* priv;
543 
544   size_t memory_configs_size;
545   const TpuSerializedProto* memory_configs;
546 
547   // out
548   size_t* merged_memory_config_size;
549   char** merged_memory_config;
550   TF_Status* status;
551 } TpuEmbeddingEngine_CollateMemory_Params;
552 
553 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_CollateMemory(
554     TpuEmbeddingEngine_CollateMemory_Params* params);
555 
// Argument bundle for `TpuEmbeddingEngine_ConfigureHost`. Inputs are
// `num_inputs`, the common and memory configs as (size, pointer) pairs, and
// the embedding config by value; the fields after `// out` are outputs.
556 typedef struct TpuEmbeddingEngine_ConfigureHost_Params {
557   int32_t struct_size;
558   void* priv;
559 
560   int num_inputs;
561   size_t common_config_size;
562   const char* common_config;
563   size_t memory_config_size;
564   const char* memory_config;
565   TpuSerializedProto tpu_embedding_config;
566 
567   // out
568   size_t* network_config_size;
569   char** network_config;
570   TF_Status* status;
571 } TpuEmbeddingEngine_ConfigureHost_Params;
572 
573 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConfigureHost(
574     TpuEmbeddingEngine_ConfigureHost_Params* params);
575 
// Argument bundle for `TpuEmbeddingEngine_ConnectHosts`. The only output is
// `status`.
// NOTE(review): `network_configs_size` is presumably the number of
// `TpuSerializedProto` entries at `network_configs` — confirm.
576 typedef struct TpuEmbeddingEngine_ConnectHosts_Params {
577   int32_t struct_size;
578   void* priv;
579 
580   size_t network_configs_size;
581   const TpuSerializedProto* network_configs;
582 
583   // out
584   TF_Status* status;
585 } TpuEmbeddingEngine_ConnectHosts_Params;
586 
587 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConnectHosts(
588     TpuEmbeddingEngine_ConnectHosts_Params* params);
589 
// Argument bundle for `TpuEmbeddingEngine_Finalize`. Inputs are a const mesh
// state plus the common and memory configs as (size, pointer) pairs; the only
// output is `status`.
590 typedef struct TpuEmbeddingEngine_Finalize_Params {
591   int32_t struct_size;
592   void* priv;
593   const XLA_TpuMeshState* tpu_mesh_state;
594 
595   size_t common_config_size;
596   const char* common_config;
597   size_t memory_config_size;
598   const char* memory_config;
599 
600   // out
601   TF_Status* status;
602 } TpuEmbeddingEngine_Finalize_Params;
603 
604 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_Finalize(
605     TpuEmbeddingEngine_Finalize_Params* params);
606 
// Argument bundle for `TpuEmbeddingEngine_IsInitialized`. The input is a
// config string as a (size, pointer) pair; the answer comes back through
// `is_tpu_embedding_initialized`, alongside `status`.
607 typedef struct TpuEmbeddingEngine_IsInitialized_Params {
608   int32_t struct_size;
609   void* priv;
610 
611   size_t config_string_size;
612   const char* config_string;
613 
614   // out
615   bool* is_tpu_embedding_initialized;
616   TF_Status* status;
617 } TpuEmbeddingEngine_IsInitialized_Params;
618 
619 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_IsInitialized(
620     TpuEmbeddingEngine_IsInitialized_Params* params);
621 
// Write/read entry points for embedding parameters. Both take a non-const
// `TpuEmbeddingEngineParameters*` (declared outside this header) and a
// `TF_Status*`. Which fields are inputs or outputs is not stated here.
622 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_WriteParameters(
623     TpuEmbeddingEngineParameters* params, TF_Status* status);
624 
625 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ReadParameters(
626     TpuEmbeddingEngineParameters* params, TF_Status* status);
627 
// Argument bundle for `TpuEmbeddingEngine_EnqueueTensorBatch`. It carries
// three arrays of `TF_Tensor*`, each with its own size field, plus a
// `fixed_state` handle.
// NOTE(review): the `*_tensors_size` fields are presumably element counts,
// and the valid values of `mode` are not stated in this header — confirm.
628 typedef struct TpuEmbeddingEngine_EnqueueTensorBatch_Params {
629   int32_t struct_size;
630   void* priv;
631 
632   int32_t mode;
633   int32_t local_device_ordinal;
634   TpuEmbedding_TensorBatchFixedState* fixed_state;
635 
636   TF_Tensor** sample_indices_tensors;
637   size_t sample_indices_tensors_size;
638   TF_Tensor** embedding_indices_tensors;
639   size_t embedding_indices_tensors_size;
640   TF_Tensor** aggregation_weights_tensors;
641   size_t aggregation_weights_tensors_size;
642   TF_Status* status;
643 } TpuEmbeddingEngine_EnqueueTensorBatch_Params;
644 
645 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_EnqueueTensorBatch(
646     TpuEmbeddingEngine_EnqueueTensorBatch_Params* params);
647 
// Argument bundle for `TpuEmbeddingTensorBatchFixedState_Create`.
// NOTE(review): `combiners` is presumably an array of `combiners_size`
// C strings — confirm.
648 typedef struct TpuEmbedding_TensorBatchFixedState_Create_Params {
649   int32_t struct_size;
650   void* priv;
651 
652   size_t combiners_size;
653   char** combiners;
654 
655   // out
656   TF_Status* status;
657 } TpuEmbedding_TensorBatchFixedState_Create_Params;
658 
// Create returns a fixed-state handle; Destroy takes one. The function names
// are spelled `TpuEmbeddingTensorBatchFixedState_*`, whereas the type names
// are spelled `TpuEmbedding_TensorBatchFixedState*`.
659 TFTPU_CAPI_EXPORT TpuEmbedding_TensorBatchFixedState*
660 TpuEmbeddingTensorBatchFixedState_Create(
661     TpuEmbedding_TensorBatchFixedState_Create_Params* params);
662 TFTPU_CAPI_EXPORT void TpuEmbeddingTensorBatchFixedState_Destroy(
663     TpuEmbedding_TensorBatchFixedState* fixed_state);
664 
// Argument bundle for `TpuEmbeddingEngine_RecvActivationsComputation`. The
// fields after `// out` are outputs.
// NOTE(review): `config_string_size` has no accompanying `config_string`
// pointer in this struct (compare TpuEmbeddingEngine_IsInitialized_Params);
// confirm whether the field is still used.
665 typedef struct TpuEmbeddingEngine_RecvActivationsComputation_Params {
666   int32_t struct_size;
667   void* priv;
668 
669   size_t config_string_size;
670   XLA_Shape* deduplication_data_shape;
671   const XLA_TpuMeshState* tpu_mesh_state;
672 
673   // out
674   TpuSerializedProto* xla_computation;
675   TF_Status* status;
676 } TpuEmbeddingEngine_RecvActivationsComputation_Params;
677 
678 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_RecvActivationsComputation(
679     TpuEmbeddingEngine_RecvActivationsComputation_Params* params);
680 
// Argument bundle for
// `TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation`. The only
// input besides `struct_size`/`priv` is the const mesh state; the fields after
// `// out` are outputs.
681 typedef struct
682     TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params {
683   int32_t struct_size;
684   void* priv;
685 
686   const XLA_TpuMeshState* tpu_mesh_state;
687   // out
688   TpuSerializedProto* xla_computation;
689   TF_Status* status;
690 } TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params;
691 
692 TFTPU_CAPI_EXPORT void
693 TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation(
694     TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params*
695         params);
696 
// Argument bundle for
// `TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation`. Inputs are
// `num_inputs`, the const mesh state and three `XLA_Shape*` values; the fields
// after `// out` are outputs.
697 typedef struct TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params {
698   int32_t struct_size;
699   void* priv;
700 
701   int32_t num_inputs;
702   const XLA_TpuMeshState* tpu_mesh_state;
703   XLA_Shape* learning_rate_tuple_shape;
704   XLA_Shape* deduplication_data_shape;
705   XLA_Shape* gradient_tuple_shape;
706   // out
707   TpuSerializedProto* xla_computation;
708   TF_Status* status;
709 } TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params;
710 
711 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation(
712     TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params* params);
713 
// Table with one `TFTPU_ADD_FN_IN_STRUCT(...)` entry for every function
// declared above in this header. The macro is not defined in this header
// (presumably it comes from "tensorflow/core/tpu/libtftpu.h" and declares one
// function-pointer member per entry — TODO confirm).
// NOTE(review): entry order presumably determines member layout; confirm
// before reordering or inserting entries in the middle.
714 struct TfTpu_OpsApiFn {
715   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild);
716   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_XrtCompileAndBuild);
717 
718   TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create);
719   TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
720   TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);
721 
722   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_Create);
723   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_Free);
724   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_GetState);
725 
726   TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create);
727   TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Destroy);
728   TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start);
729   TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop);
730   TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData);
731 
732   TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
733   TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
734   TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
735   TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact);
736   TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw);
737 
738   TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData);
739 
740   TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
741   TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
742   TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
743   TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
744   TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
745   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
746   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
747   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
748   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost);
749   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit);
750   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
751   TFTPU_ADD_FN_IN_STRUCT(
752       TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
753   TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort);
754 
755   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
756   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
757   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray);
758   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray);
759   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy);
760   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize);
761   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary);
762   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo);
763   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
764   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
765   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
766   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
767   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
768   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable);
769   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata);
770   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto);
771   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetFingerprint);
772   TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DestroyFingerprint);
773 
774   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
775   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
776   TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
777   TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort);
778   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
779   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey);
780   TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);
781 
782   TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
783   TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
784   TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
785   TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
786   TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
787   TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CompactionSupported);
788 
789   TFTPU_ADD_FN_IN_STRUCT(TfTpu_InitializeTpuModelServer);
790 
791   TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Create);
792   TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Destroy);
793   TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_GetOrdinal);
794   TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_DequeueFromCoreSelector);
795   TFTPU_ADD_FN_IN_STRUCT(TfTpu_GetTpuPartitionedCallParams);
796 
797   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ExecutePartitioner);
798   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConfigureMemory);
799   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_CollateMemory);
800   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConfigureHost);
801   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConnectHosts);
802   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_Finalize);
803   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_IsInitialized);
804   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_WriteParameters);
805   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ReadParameters);
806   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Create);
807   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Destroy);
808   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_EnqueueTensorBatch);
809   TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_RecvActivationsComputation);
810   TFTPU_ADD_FN_IN_STRUCT(
811       TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation);
812   TFTPU_ADD_FN_IN_STRUCT(
813       TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation);
814 };
815 
816 }  // extern "C"
817 
818 #endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_
819