1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ 16 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ 17 18 #include <stddef.h> 19 20 #include <cstdint> 21 22 #include "absl/types/optional.h" 23 #include "tensorflow/c/tf_tensor.h" 24 #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" 25 #include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h" 26 #include "tensorflow/core/tpu/libtftpu.h" 27 28 typedef struct TpuSerializedProto TpuSerializedProto; 29 30 namespace tensorflow { 31 32 class TpuMeshCommonState; 33 class TpuEmbeddingEngineState; 34 class ResourceMgr; 35 36 } // namespace tensorflow 37 38 extern "C" { 39 40 typedef struct XLA_TpuProgram XLA_TpuProgram; 41 42 // Enum for choosing sharding/unsharding program from a `XLA_TpuProgram` obj. 43 enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding }; 44 45 struct TpuProgramFingerprint { 46 const char* bytes; 47 size_t size; 48 }; 49 50 struct TpuExecutableSerializedProto { 51 const char* bytes; 52 size_t size; 53 }; 54 55 struct CompilerMetadataSerializedProto { 56 const char* bytes; 57 size_t size; 58 }; 59 60 struct HostComputeMetadataSerializedProto { 61 const char* bytes; 62 size_t size; 63 }; 64 65 typedef struct XLA_TpuMeshState XLA_TpuMeshState; 66 67 typedef struct XLA_TpuEmbeddingEngineState XLA_TpuEmbeddingEngineState; 68 69 typedef struct TpuEmbedding_TensorBatchFixedState 70 TpuEmbedding_TensorBatchFixedState; 71 72 typedef struct TpuProfiler TpuProfiler; 73 74 typedef struct XLA_DeviceAssignment { 75 const char* bytes; 76 size_t size; 77 } XLA_DeviceAssignment; 78 79 // Property for creating compilation cache key. 80 struct CompilationCacheKeyProperty { 81 const char* config_prefix; 82 const char* shapes_prefix; 83 const char* function_name; 84 uint64_t mlir_module_fingerprint; 85 const int32_t* device_ids; 86 size_t device_ids_size; 87 int32_t guaranteed_constants_size; 88 uint64_t function_library_fingerprint; 89 int32_t num_cores_per_replica; 90 int32_t num_replicas; 91 const XLA_TpuMeshState* mesh_state; 92 uint64_t session_id; 93 tensorflow::ResourceMgr* resource_mgr; 94 }; 95 96 // Compilation cache key result returning both the key and a more verbose debug 97 // version. 98 struct CompilationCacheKeyResult { 99 const char* key; 100 const char* debug_string; 101 }; 102 103 typedef struct XLA_TpuNodeContext XLA_TpuNodeContext; 104 105 typedef struct TfTpu_OrdinalSelector TfTpuOrdinalSelector; 106 107 struct TpuPartitionedCall_Params { 108 bool input_shape_opt; 109 bool group_tensors_for_packing; 110 int32_t minimum_input_tensors_packing; 111 int32_t minimum_output_tensors_packing; 112 113 // Whether to attempt to automatically shard inputs by adding an 114 // XlaSharding op after each input. 115 bool enable_auto_xla_input_sharding; 116 117 // The dimension of each input to shard if 118 // enable_auto_xla_input_sharding is set to true. Negative numbers are 119 // allowed and refers to dimensions starting from the end. 120 int32_t auto_xla_input_sharding_dim; 121 122 // If true, only create one variable on the TPU for each variable on the CPU. 123 bool enable_variable_deduplication; 124 }; 125 126 // Compiles Mlir or TF function computation by lowering into HLO IR and returns 127 // `count` number of TPU programs ready for execution. 128 // The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates 129 // `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is 130 // responsible to deallocate both the `XLA_TpuProgram*[]` array and the 131 // `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free` 132 // API respectively. 133 TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild( 134 TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state, 135 XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status); 136 137 // Compiles a HLO IR and returns `count` number of TPU programs ready for 138 // execution. The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and 139 // creates `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller 140 // is responsible to deallocate both the `XLA_TpuProgram*[]` array and the 141 // `XLA_TpuProgram` object(s) using `TpuProgram_FreeArray` and `TpuProgram_Free` 142 // API respectively. 143 TFTPU_CAPI_EXPORT void TpuCompile_XrtCompileAndBuild( 144 TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state, 145 XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status); 146 147 // Creates a TPU profiler that is ready to start profiling. 148 TFTPU_CAPI_EXPORT void TpuProfiler_Create(TpuProfiler** tpu_profiler, 149 TF_Status* status); 150 // Destroys the given TPU profiler. 151 TFTPU_CAPI_EXPORT void TpuProfiler_Destroy(TpuProfiler* tpu_profiler); 152 // Starts profiling if not already started, returns an error otherwise. 153 TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler, 154 TF_Status* status); 155 // Stops profiling if not already stopped, returns an error otherwise. 156 TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler, 157 TF_Status* status); 158 // Serializes profiled data into `buffer` and returns the size of `buffer`. The 159 // profile data held by the TPU driver will be cleared after retrieval. 160 // 161 // Step 1. Query the size of buffer required into `size_in_bytes`. 162 // 163 // size_t size_in_bytes; 164 // TpuProfiler_CollectData(profiler, status, nullptr, &size_in_bytes); 165 // 166 // Step 2. Retrieve the data into a `buffer` of size `size_in_bytes`. 167 // Subsequently,The TPU driver clears its copy of the profile data. 168 // 169 // uint8_t buffer = new uint8_t[size_in_bytes]; 170 // TpuProfiler_CollectData(profiler, status, buffer, size_in_bytes); 171 // 172 // Step 3. Unpack the data into an XSpace. 173 // 174 // tensorflow::profiler::XSpace space; 175 // space.ParseFromArray(buffer, size_in_bytes); 176 // 177 TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(TpuProfiler* tpu_profiler, 178 TF_Status* status, 179 uint8_t* buffer, 180 size_t* size_in_bytes); 181 182 // Creates a new TPU mesh state object. 183 TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create(); 184 185 // Deletes the given TPU `mesh_state` object. Once deleted the object is 186 // unusable. 187 TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state); 188 189 // Returns a pointer to an opaque mesh data structure used internally. 190 TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState( 191 XLA_TpuMeshState* mesh_state); 192 193 // Creates a new TPU embedding engine state object. 194 TFTPU_CAPI_EXPORT XLA_TpuEmbeddingEngineState* TpuEmbeddingEngineState_Create(); 195 196 // Delete the given TPU embedding engine state object. Once deleted the object 197 // is unusable. 198 TFTPU_CAPI_EXPORT void TpuEmbeddingEngineState_Free( 199 XLA_TpuEmbeddingEngineState* engine_state); 200 201 // Returns a pointer to an opaque embedding engine state data structure used 202 // internally. 203 TFTPU_CAPI_EXPORT void* TpuEmbeddingEngineState_GetState( 204 XLA_TpuEmbeddingEngineState* engine_state); 205 206 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Create( 207 TfTpuOrdinalSelector** ordinal_selector, int num_cores_per_replica); 208 209 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Destroy( 210 TfTpuOrdinalSelector* ordinal_selector); 211 212 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_GetOrdinal( 213 TfTpuOrdinalSelector* ordinal_selector, std::optional<uint64_t> key, 214 int64_t* req_id, int64_t* ordinal); 215 216 TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_DequeueFromCoreSelector( 217 TfTpuOrdinalSelector* ordinal_selector, int32_t device_ordinal, 218 int64_t req_id); 219 220 TFTPU_CAPI_EXPORT void TfTpu_GetTpuPartitionedCallParams( 221 TpuPartitionedCall_Params* params); 222 223 typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params { 224 int32_t struct_size; 225 void* priv; 226 227 const XLA_TpuProgram* program; 228 SE_DeviceMemoryBase* arguments; 229 size_t arguments_len; 230 SE_DeviceMemoryBase* result; 231 bool has_cross_program_prefetch_addr; 232 SE_DeviceMemoryBase* cross_program_prefetch_addr; 233 int32_t rng_seed; 234 XLA_DeviceAssignment* device_assignment; 235 SE_Stream* stream; 236 237 TF_Status* status; // out 238 } TpuExecutable_LoadProgramAndEnqueueToStream_Params; 239 240 #define TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE \ 241 (sizeof(struct TpuExecutable_LoadProgramAndEnqueueToStream_Params)) 242 243 TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream( 244 TpuExecutable_LoadProgramAndEnqueueToStream_Params* params); 245 246 TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape( 247 XLA_Shape* host_shape, XLA_Shape* device_shape); 248 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape); 249 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape); 250 TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape); 251 252 typedef struct TpuExecute_RuntimeInputToPaddedData_Params { 253 int32_t struct_size; 254 void* priv; 255 256 uint32_t* runtime_input_ptr; 257 size_t runtime_input_size; 258 int8_t* padded_data_ptr; 259 size_t padded_data_size; 260 XLA_Shape* runtime_shape; 261 XLA_Shape* compile_time_shape; 262 263 TF_Status* status; // out 264 } TpuExecute_RuntimeInputToPaddedData_Params; 265 266 #define TpuExecute_RuntimeInputToPaddedData_Params_SIZE \ 267 (sizeof(struct TpuExecute_RuntimeInputToPaddedData_Params)) 268 269 TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData( 270 TpuExecute_RuntimeInputToPaddedData_Params* params); 271 272 typedef struct ConfigureDistributedTpuOp_DoWork_Params { 273 int32_t struct_size; 274 void* priv; 275 276 size_t num_cores_per_host_size; 277 const int32_t* num_cores_per_host; 278 size_t server_address_size; 279 const char* server_address; 280 281 size_t* host_config_output_size; // out 282 char** host_config_output; // out 283 TF_Status* status; // out 284 } ConfigureDistributedTpuOp_DoWork_Params; 285 286 #define ConfigureDistributedTpuOp_DoWork_Params_SIZE \ 287 (sizeof(struct ConfigureDistributedTpuOp_DoWork_Params)) 288 289 TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork( 290 ConfigureDistributedTpuOp_DoWork_Params* params); 291 292 typedef struct WaitForDistributedTpuOp_DoWork_Params { 293 int32_t struct_size; 294 void* priv; 295 296 size_t num_hosts; 297 size_t num_cores_per_host; 298 const int32_t** host_ordinal_to_global_core_id_map; 299 tensorflow::TpuMeshCommonState* tpu_mesh_common_state; 300 301 size_t* tpu_topology_output_size; // out 302 char** tpu_topology_output; // out 303 TF_Status* status; // out 304 } WaitForDistributedTpuOp_DoWork_Params; 305 306 #define WaitForDistributedTpuOp_DoWork_Params_SIZE \ 307 (sizeof(struct WaitForDistributedTpuOp_DoWork_Params)) 308 309 TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork( 310 WaitForDistributedTpuOp_DoWork_Params* params); 311 312 typedef struct InitializeHostForDistributedTpuOp_DoWork_Params { 313 int32_t struct_size; 314 void* priv; 315 316 size_t tpu_host_config_size; 317 const char* tpu_host_config; 318 bool enable_whole_mesh_compilations; 319 bool is_master_worker; 320 321 size_t* core_id_output_size; // out 322 int32_t** core_id_output; // out 323 TF_Status* status; // out 324 } InitializeHostForDistributedTpuOp_DoWork_Params; 325 326 #define InitializeHostForDistributedTpuOp_DoWork_Params_SIZE \ 327 (sizeof(struct InitializeHostForDistributedTpuOp_DoWork_Params)) 328 329 TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork( 330 InitializeHostForDistributedTpuOp_DoWork_Params* params); 331 332 TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork( 333 const size_t tpu_topology_size, const char* tpu_topology, 334 TF_Status* status); 335 336 TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork( 337 int32_t* number_of_chips_output, TF_Status* status); 338 339 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output); 340 TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output); 341 342 TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState(); 343 344 TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus, 345 TF_Status* status); 346 TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit, 347 TF_Status* status); 348 TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes( 349 int64_t* cache_size_in_bytes); 350 351 typedef struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params { 352 int32_t struct_size; 353 void* priv; 354 355 size_t tpu_host_config_size; 356 const char* tpu_host_config; 357 358 size_t* server_address_output_size; // out 359 char** server_address_output; // out 360 TF_Status* status; // out 361 } TpuConfigurationApi_CompilationCacheServerAddressFromConfig_Params; 362 363 #define TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE \ 364 (sizeof( \ 365 struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params)) 366 367 TFTPU_CAPI_EXPORT 368 void TpuConfigurationApi_CompilationCacheServerAddressFromConfig( 369 TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params* params); 370 371 typedef struct TpuConfigurationApi_GetServerAddressAndPort_Params { 372 int32_t struct_size; 373 void* priv; 374 375 size_t* server_address_output_size; // out 376 char** server_address_output; // out 377 int* port_output; // out 378 TF_Status* status; // out 379 } TpuConfigurationApi_GetServerAddressAndPort_Params; 380 381 #define TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE \ 382 (sizeof(struct TpuConfigurationApi_GetServerAddressAndPort_Params)) 383 384 TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort( 385 TpuConfigurationApi_GetServerAddressAndPort_Params* params); 386 387 // Creates a new TPU program. 388 TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New(); 389 390 // Destroys the `tpu_program`. 391 TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program); 392 393 // Creates an array of `XLA_TpuProgram*`. 394 TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count); 395 396 // Destroys an array of `XLA_TpuProgram*`. 397 TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]); 398 399 // Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and 400 // destroyed, it is in an unusable state. 401 TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program, 402 TF_Status* status); 403 404 // Gets TPU program size in bytes from the `tpu_program`. 405 TFTPU_CAPI_EXPORT int64_t 406 TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program); 407 408 // Logs the summary of current memory state snapshot of the `tpu_program`. 409 TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary( 410 const XLA_TpuProgram* tpu_program); 411 412 // Gets TPU program executable info from the `tpu_program`. 413 TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo( 414 const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info, 415 TF_Status* status); 416 417 // Gets host transfer info proto. 418 TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo( 419 const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info, 420 TF_Status* status); 421 422 // Gets HLO metadata proto. 423 TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata( 424 const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata, 425 TF_Status* status); 426 427 // Gets may modify variables boolean value. 428 TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables( 429 const XLA_TpuProgram* tpu_program, bool* may_modify_variables); 430 431 // Checks if TPU program has sharding. 432 TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding( 433 const XLA_TpuProgram* tpu_program); 434 435 // Gets TPU program by sharding type. Return value is valid only when the 436 // `status.status()` returns `OK`. 437 TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram( 438 XLA_TpuProgram* tpu_program, TpuProgramShardingType type); 439 440 // Gets TPU executable proto from a `tpu_program`. 441 TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable( 442 const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable, 443 TF_Status* status); 444 445 // Gets compilation metadata proto from a `tpu_program`. 446 TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata( 447 const XLA_TpuProgram* tpu_program, 448 CompilerMetadataSerializedProto* compiler_metadata, TF_Status* status); 449 450 // Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`. 451 TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto( 452 TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program, 453 TF_Status* status); 454 455 TFTPU_CAPI_EXPORT TpuProgramFingerprint 456 TpuProgram_GetFingerprint(const XLA_TpuProgram* tpu_program); 457 458 TFTPU_CAPI_EXPORT void TpuProgram_DestroyFingerprint( 459 TpuProgramFingerprint fingerprint); 460 461 // Checks if whether a TPU compilation is enabled. 462 TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled(); 463 464 // XLA compilation cannot be cancelled. To avoid hanging the TF worker will exit 465 // when cancellation is requested for an XLA compile op. Some tests require this 466 // behavior to be disabled, and we test for this condition with the following 467 // flag function. 468 TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation(); 469 470 // Returns the number of available TPU core count. 471 TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount( 472 const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type); 473 474 // Recycle unused service port. 475 TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port); 476 477 // Creates a unique compilation cache `key` used for `put` and `get` operations. 478 // Returned buffers are heap-allocated and must be owned. 479 TFTPU_CAPI_EXPORT CompilationCacheKeyResult 480 TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property); 481 482 // Destroys the CompilationCacheKeyResult returned by calling the 483 // `TpuCompile_CreateCompilationCacheKey` API. 484 TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey( 485 CompilationCacheKeyResult result); 486 487 // Creates a guaranteed const fingerprint. Guarantee const is normally used in 488 // TPU inference to avoid re-copying unchanged variables onto the TPU device. 489 // It promises the value is identical for every execution in the same session 490 // even if the actual value changes in later executions. 491 TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint( 492 uint64_t fingerprint, const char* data, size_t size); 493 494 XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal, 495 TF_Status* status); 496 void TpuNodeContext_Free(XLA_TpuNodeContext* node_context); 497 498 void TpuNodeContext_StopChipHeartbeats(TF_Status* status); 499 500 void TpuNodeContext_CloseTpuHost(TF_Status* status); 501 502 void TpuNodeContext_Initialize(int device_ordinal, TF_Status* status); 503 504 bool TpuNodeContext_CompactionSupported(int device_ordinal); 505 506 // Globally initialize the TPU system for inference. 507 TFTPU_CAPI_EXPORT void TfTpu_InitializeTpuModelServer(); 508 509 typedef struct TpuEmbeddingEngine_ExecutePartitioner_Params { 510 int32_t struct_size; 511 void* priv; 512 TpuSerializedProto tpu_embedding_config; 513 514 // out 515 size_t* common_config_size; 516 char** common_config; 517 TF_Status* status; 518 } TpuEmbeddingEngine_ExecutePartitioner_Params; 519 520 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ExecutePartitioner( 521 TpuEmbeddingEngine_ExecutePartitioner_Params* params); 522 523 typedef struct TpuEmbeddingEngine_ConfigureMemory_Params { 524 int32_t struct_size; 525 void* priv; 526 527 int num_inputs; 528 size_t common_config_size; 529 const char* common_config; 530 531 // out 532 size_t* memory_config_size; 533 char** memory_config; 534 TF_Status* status; 535 } TpuEmbeddingEngine_ConfigureMemory_Params; 536 537 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConfigureMemory( 538 TpuEmbeddingEngine_ConfigureMemory_Params* params); 539 540 typedef struct TpuEmbeddingEngine_CollateMemory_Params { 541 int32_t struct_size; 542 void* priv; 543 544 size_t memory_configs_size; 545 const TpuSerializedProto* memory_configs; 546 547 // out 548 size_t* merged_memory_config_size; 549 char** merged_memory_config; 550 TF_Status* status; 551 } TpuEmbeddingEngine_CollateMemory_Params; 552 553 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_CollateMemory( 554 TpuEmbeddingEngine_CollateMemory_Params* params); 555 556 typedef struct TpuEmbeddingEngine_ConfigureHost_Params { 557 int32_t struct_size; 558 void* priv; 559 560 int num_inputs; 561 size_t common_config_size; 562 const char* common_config; 563 size_t memory_config_size; 564 const char* memory_config; 565 TpuSerializedProto tpu_embedding_config; 566 567 // out 568 size_t* network_config_size; 569 char** network_config; 570 TF_Status* status; 571 } TpuEmbeddingEngine_ConfigureHost_Params; 572 573 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConfigureHost( 574 TpuEmbeddingEngine_ConfigureHost_Params* params); 575 576 typedef struct TpuEmbeddingEngine_ConnectHosts_Params { 577 int32_t struct_size; 578 void* priv; 579 580 size_t network_configs_size; 581 const TpuSerializedProto* network_configs; 582 583 // out 584 TF_Status* status; 585 } TpuEmbeddingEngine_ConnectHosts_Params; 586 587 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConnectHosts( 588 TpuEmbeddingEngine_ConnectHosts_Params* params); 589 590 typedef struct TpuEmbeddingEngine_Finalize_Params { 591 int32_t struct_size; 592 void* priv; 593 const XLA_TpuMeshState* tpu_mesh_state; 594 595 size_t common_config_size; 596 const char* common_config; 597 size_t memory_config_size; 598 const char* memory_config; 599 600 // out 601 TF_Status* status; 602 } TpuEmbeddingEngine_Finalize_Params; 603 604 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_Finalize( 605 TpuEmbeddingEngine_Finalize_Params* params); 606 607 typedef struct TpuEmbeddingEngine_IsInitialized_Params { 608 int32_t struct_size; 609 void* priv; 610 611 size_t config_string_size; 612 const char* config_string; 613 614 // out 615 bool* is_tpu_embedding_initialized; 616 TF_Status* status; 617 } TpuEmbeddingEngine_IsInitialized_Params; 618 619 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_IsInitialized( 620 TpuEmbeddingEngine_IsInitialized_Params* params); 621 622 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_WriteParameters( 623 TpuEmbeddingEngineParameters* params, TF_Status* status); 624 625 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ReadParameters( 626 TpuEmbeddingEngineParameters* params, TF_Status* status); 627 628 typedef struct TpuEmbeddingEngine_EnqueueTensorBatch_Params { 629 int32_t struct_size; 630 void* priv; 631 632 int32_t mode; 633 int32_t local_device_ordinal; 634 TpuEmbedding_TensorBatchFixedState* fixed_state; 635 636 TF_Tensor** sample_indices_tensors; 637 size_t sample_indices_tensors_size; 638 TF_Tensor** embedding_indices_tensors; 639 size_t embedding_indices_tensors_size; 640 TF_Tensor** aggregation_weights_tensors; 641 size_t aggregation_weights_tensors_size; 642 TF_Status* status; 643 } TpuEmbeddingEngine_EnqueueTensorBatch_Params; 644 645 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_EnqueueTensorBatch( 646 TpuEmbeddingEngine_EnqueueTensorBatch_Params* params); 647 648 typedef struct TpuEmbedding_TensorBatchFixedState_Create_Params { 649 int32_t struct_size; 650 void* priv; 651 652 size_t combiners_size; 653 char** combiners; 654 655 // out 656 TF_Status* status; 657 } TpuEmbedding_TensorBatchFixedState_Create_Params; 658 659 TFTPU_CAPI_EXPORT TpuEmbedding_TensorBatchFixedState* 660 TpuEmbeddingTensorBatchFixedState_Create( 661 TpuEmbedding_TensorBatchFixedState_Create_Params* params); 662 TFTPU_CAPI_EXPORT void TpuEmbeddingTensorBatchFixedState_Destroy( 663 TpuEmbedding_TensorBatchFixedState* fixed_state); 664 665 typedef struct TpuEmbeddingEngine_RecvActivationsComputation_Params { 666 int32_t struct_size; 667 void* priv; 668 669 size_t config_string_size; 670 XLA_Shape* deduplication_data_shape; 671 const XLA_TpuMeshState* tpu_mesh_state; 672 673 // out 674 TpuSerializedProto* xla_computation; 675 TF_Status* status; 676 } TpuEmbeddingEngine_RecvActivationsComputation_Params; 677 678 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_RecvActivationsComputation( 679 TpuEmbeddingEngine_RecvActivationsComputation_Params* params); 680 681 typedef struct 682 TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params { 683 int32_t struct_size; 684 void* priv; 685 686 const XLA_TpuMeshState* tpu_mesh_state; 687 // out 688 TpuSerializedProto* xla_computation; 689 TF_Status* status; 690 } TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params; 691 692 TFTPU_CAPI_EXPORT void 693 TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation( 694 TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params* 695 params); 696 697 typedef struct TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params { 698 int32_t struct_size; 699 void* priv; 700 701 int32_t num_inputs; 702 const XLA_TpuMeshState* tpu_mesh_state; 703 XLA_Shape* learning_rate_tuple_shape; 704 XLA_Shape* deduplication_data_shape; 705 XLA_Shape* gradient_tuple_shape; 706 // out 707 TpuSerializedProto* xla_computation; 708 TF_Status* status; 709 } TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params; 710 711 TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation( 712 TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params* params); 713 714 struct TfTpu_OpsApiFn { 715 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild); 716 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_XrtCompileAndBuild); 717 718 TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create); 719 TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free); 720 TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState); 721 722 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_Create); 723 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_Free); 724 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_GetState); 725 726 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create); 727 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Destroy); 728 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start); 729 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop); 730 TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData); 731 732 TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream); 733 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape); 734 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize); 735 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact); 736 TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw); 737 738 TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData); 739 740 TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork); 741 TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork); 742 TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork); 743 TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork); 744 TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork); 745 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray); 746 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array); 747 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState); 748 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost); 749 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit); 750 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes); 751 TFTPU_ADD_FN_IN_STRUCT( 752 TpuConfigurationApi_CompilationCacheServerAddressFromConfig); 753 TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort); 754 755 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New); 756 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free); 757 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray); 758 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray); 759 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy); 760 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize); 761 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary); 762 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo); 763 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo); 764 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata); 765 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables); 766 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding); 767 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram); 768 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable); 769 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata); 770 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto); 771 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetFingerprint); 772 TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DestroyFingerprint); 773 774 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled); 775 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation); 776 TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount); 777 TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort); 778 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey); 779 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey); 780 TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint); 781 782 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create); 783 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free); 784 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats); 785 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost); 786 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize); 787 TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CompactionSupported); 788 789 TFTPU_ADD_FN_IN_STRUCT(TfTpu_InitializeTpuModelServer); 790 791 TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Create); 792 TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Destroy); 793 TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_GetOrdinal); 794 TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_DequeueFromCoreSelector); 795 TFTPU_ADD_FN_IN_STRUCT(TfTpu_GetTpuPartitionedCallParams); 796 797 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ExecutePartitioner); 798 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConfigureMemory); 799 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_CollateMemory); 800 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConfigureHost); 801 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConnectHosts); 802 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_Finalize); 803 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_IsInitialized); 804 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_WriteParameters); 805 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ReadParameters); 806 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Create); 807 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Destroy); 808 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_EnqueueTensorBatch); 809 TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_RecvActivationsComputation); 810 TFTPU_ADD_FN_IN_STRUCT( 811 TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation); 812 TFTPU_ADD_FN_IN_STRUCT( 813 TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation); 814 }; 815 816 } // extern "C" 817 818 #endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_KERNELS_C_API_H_ 819