xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/transfer_guard_lib.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_TRANSFER_GUARD_LIB_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_TRANSFER_GUARD_LIB_H_
18 
19 #include <optional>
20 #include <string>
21 
22 #include "pybind11/pybind11.h"
23 #include "tensorflow/compiler/xla/status.h"
24 
25 namespace jax {
26 
// Level of transfer guarding selected by user code. The comment next to each
// level states how explicit and implicit transfers are treated at that level.
enum class TransferGuardLevel {
  kAllow,             // explicit: allow,    implicit: allow
  kLog,               // explicit: allow,    implicit: log
  kDisallow,          // explicit: allow,    implicit: disallow
  kLogExplicit,       // explicit: log,      implicit: log
  kDisallowExplicit,  // explicit: disallow, implicit: disallow
};
45 
46 // Flags for transfer guard levels are controlled by:
47 // - a global flag value,
48 //   e.g., associated to --jax_transfer_guard_device_to_host
49 //   which defaults to TransferGuardLevel::kAllow.
50 // - possibly a thread-local value, which initially is std::nullopt and
51 //   overrides the global value if set. The thread-local state is used to
52 //   implement context managers that locally override the global state.
53 //
54 // Explicit device_put/device_get contexts are tracked by context managers.
55 struct TransferGuardState {
56   std::optional<TransferGuardLevel> host_to_device;
57   std::optional<TransferGuardLevel> device_to_device;
58   std::optional<TransferGuardLevel> device_to_host;
59   bool explicit_device_put = false;
60   bool explicit_device_get = false;
61 };
62 
// Action taken for a transfer. It is determined by the transfer guard level
// together with the type of the transfer.
enum class TransferGuardAction {
  kAllow,     // The transfer is allowed silently.
  kLog,       // The transfer is logged and allowed.
  kDisallow,  // The transfer is disallowed.
};
73 
74 // Guards a host-to-device transfer. formatter is called to describe the
75 // transfer in a log message or error status.
76 // REQUIRES: Python GIL.
77 xla::Status ApplyTransferGuardToHostToDevice(
78     absl::FunctionRef<std::string()> formatter);
79 
80 // Guards a device-to-device transfer. formatter is called to describe the
81 // transfer in a log message or error status.
82 // REQUIRES: Python GIL.
83 xla::Status ApplyTransferGuardToDeviceToDevice(
84     absl::FunctionRef<std::string()> formatter);
85 
86 // Guards a device-to-host transfer. formatter is called to describe the
87 // transfer in a log message or error status.
88 // REQUIRES: Python GIL.
89 xla::Status ApplyTransferGuardToDeviceToHost(
90     absl::FunctionRef<std::string()> formatter);
91 
92 // The function to call in `xla.cc` to add the bindings for this module.
93 void BuildTransferGuardSubmodule(pybind11::module& m);
94 
95 }  // namespace jax
96 
97 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_TRANSFER_GUARD_LIB_H_
98