/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SHARDED_DEVICE_ARRAY_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_SHARDED_DEVICE_ARRAY_H_

#include <optional>
#include <utility>
#include <vector>

#include "absl/types/variant.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/types.h"

// TODO(jblespiau): The current implementation moves the Python logic to C++,
// as a preliminary step to executing the `pmap` execution path from C++.
// It implements the current Python behavior (thus, it may not be optimal, and
// we will be able to modify it later).

namespace jax {

// High-level introduction.
//
// pmap and other parallel computation functions distribute some computation on
// several devices. As of December 2020, the device mesh (i.e. the
// N-dimensional array of devices on which we map the computation) is defined
// by the user.
//
// We describe how to shard the inputs, and how to map them to the mesh of
// devices, using `ShardingSpec`. It is mainly based on two components:
// - `sharding`, which specifies how to shard the inputs.
// - `mesh_mapping`, which specifies how to map shards to devices.
//
// The following three structs define how to shard one dimension of an ndarray.
//
// `NoSharding` (`None` in Python) means no sharding.
struct NoSharding {
  bool operator==(const NoSharding& other) const { return true; }
  bool operator!=(const NoSharding& other) const { return false; }
};

template <typename H>
H AbslHashValue(H h, const NoSharding& key) {
  return h;
}

// `Chunked` means that the dimension is split into np.prod(chunks) chunks
// and the split dimension itself is preserved inside the map.
// Those chunks are distributed over `len(chunks)` `ShardedAxis` axes
// (major-to-minor).
// For example, for a tensor `t` of shape [N] sharded using [Chunked([p])]
// (with p dividing N, let S = N // p), the tensor will be split into p chunks
// of shape [S], such that sharded_t[k] = t[k * S: (k+1)*S] (left included,
// right excluded) for k in {0, ..., p-1}.
struct Chunked {
 public:
  explicit Chunked(std::vector<int> chunks_) : chunks(std::move(chunks_)) {}
  // The number of chunks per axis.
  std::vector<int> chunks;

  bool operator==(const Chunked& other) const { return chunks == other.chunks; }
  bool operator!=(const Chunked& other) const { return chunks != other.chunks; }
};

template <typename H>
H AbslHashValue(H h, const Chunked& key) {
  h = H::combine(std::move(h), key.chunks);
  return h;
}

// `Unstacked` means that the dimension is split into chunks of size 1, and
// doesn't appear inside the map. `size` is always the dimension size.
// For example, a tensor `t` of shape [N] will be sharded into N tensors of
// shape [], when using `Unstacked(N)`.
struct Unstacked {
 public:
  explicit Unstacked(int sz) : size(sz) {}
  int size;

  bool operator==(const Unstacked& other) const { return size == other.size; }
  bool operator!=(const Unstacked& other) const { return size != other.size; }
};

template <typename H>
H AbslHashValue(H h, const Unstacked& key) {
  h = H::combine(std::move(h), key.size);
  return h;
}

using AvalDimSharding = std::variant<NoSharding, Chunked, Unstacked>;
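
// As an illustration, the following is a hypothetical sketch (not taken from
// any call site) of a per-dimension sharding for a rank-3 array: the first
// dimension is unstacked into 8 pieces of size 1, the second dimension is not
// sharded, and the third dimension is split into 2 chunks.
//
//   std::vector<AvalDimSharding> sharding = {
//       Unstacked(8), NoSharding(), Chunked({2})};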

// Assigns sharded axes to mesh dimensions.
//
// Each mesh dimension is assigned either one of the sharded axes of the input
// (`ShardedAxis`) or replication (`Replicated`); when no sharded axis is
// assigned to a mesh dimension, the data is replicated along it.
// As indices are 0-indexed, `ShardedAxis(1)` refers to the second actually
// sharded axis (i.e. counting as if the `None` dimensions of the sharding were
// filtered out).
// For example, given the sharding `[Unstacked(n), None, Chunked(m)]`, an entry
// of `ShardedAxis(1)` refers to the `Chunked(m)` axis, not the `None`.

struct ShardedAxis {
  int axis;
  bool operator==(const ShardedAxis& other) const { return axis == other.axis; }
  bool operator!=(const ShardedAxis& other) const { return axis != other.axis; }
};

template <typename H>
H AbslHashValue(H h, const ShardedAxis& key) {
  h = H::combine(std::move(h), key.axis);
  return h;
}

struct Replicated {
  int replicas;
  bool operator==(const Replicated& other) const {
    return replicas == other.replicas;
  }
  bool operator!=(const Replicated& other) const {
    return replicas != other.replicas;
  }
};

template <typename H>
H AbslHashValue(H h, const Replicated& key) {
  h = H::combine(std::move(h), key.replicas);
  return h;
}

using MeshDimAssignment = std::variant<ShardedAxis, Replicated>;

// Describes how each axis is sharded (if it is), and how it is mapped to the
// device mesh. See the JAX pxla.py module for the full documentation.
//
// ShardingSpec is shared across pmap, pjit and xmap. For pmap, an input
// `sharding` is composed of `NoSharding` and at most one `Unstacked`.
// If `axis_size=None`, at least one of the inputs has a dimension associated
// to `Unstacked`.
//
// Examples:
//
// 1. For pmap, with a tensor of shape [8, 2, 2], to unstack along the first
//    dimension into [8] devices:
//
//    sharding = [Unstacked(8), NoSharding, NoSharding]
//    mesh_mapping = [ShardedAxis(0)]
//
// 2. With an input array of shape [6] that we want to chunk into [2, 3],
//    assuming a device mesh of shape [3, 4, 2], we will have:
//
//    sharding = [Chunked([2, 3])]
//    mesh_mapping = [ShardedAxis(1), Replicated, ShardedAxis(0)]
//
//    In particular, in the above example, the ShardedAxis entries refer to
//    indices of the sharded shape [2, 3] (only a `Chunked` sharding can
//    produce more than one sharded dimension).
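//
// As a non-authoritative illustration, the two examples above could be
// expressed with the C++ types declared in this file roughly as follows
// (a hypothetical sketch, not taken from actual call sites; the replica count
// of 4 in the second example is assumed to match the mesh dimension of size 4):
//
//   // Example 1: unstack a [8, 2, 2] input over 8 devices.
//   ShardingSpec pmap_spec(
//       /*sharding=*/{Unstacked(8), NoSharding(), NoSharding()},
//       /*mesh_mapping=*/{ShardedAxis{0}});
//
//   // Example 2: chunk a [6] input into [2, 3] over a [3, 4, 2] device mesh,
//   // replicating over the mesh dimension of size 4.
//   ShardingSpec chunked_spec(
//       /*sharding=*/{Chunked({2, 3})},
//       /*mesh_mapping=*/{ShardedAxis{1}, Replicated{4}, ShardedAxis{0}});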
class ShardingSpec {
 public:
  ShardingSpec(std::vector<AvalDimSharding> sharding,
               std::vector<MeshDimAssignment> mesh_mapping)
      : sharding_(std::move(sharding)),
        mesh_mapping_(std::move(mesh_mapping)) {}
  ShardingSpec(pybind11::iterable py_sharding,
               pybind11::iterable py_mesh_mapping)
      : sharding_(xla::IterableToVector<AvalDimSharding>(py_sharding)),
        mesh_mapping_(
            xla::IterableToVector<MeshDimAssignment>(py_mesh_mapping)) {}

  const std::vector<AvalDimSharding>& GetSharding() const { return sharding_; }
  const std::vector<MeshDimAssignment>& GetMeshMapping() const {
    return mesh_mapping_;
  }

  bool operator==(const ShardingSpec& other) const {
    return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_;
  }

  bool operator!=(const ShardingSpec& other) const { return !(*this == other); }

  template <typename H>
  friend H AbslHashValue(H h, const ShardingSpec& key);

 private:
  //  `sharding` specifies how the array is supposed to get partitioned into
  //  chunks. Its length matches the rank of the array. See the docstring
  //  of `AvalDimSharding` for the supported partitioning schemes.
  std::vector<AvalDimSharding> sharding_;
  //  `mesh_mapping` describes an assignment of the array chunks created by
  //  `sharding` to a logical device mesh. The length of the tuple is equal to
  //  the rank of the mesh. Each mesh dimension can either get partitions of
  //  data varying along one of the sharded dimensions, or the data can be
  //  replicated.
  std::vector<MeshDimAssignment> mesh_mapping_;
};

template <typename H>
H AbslHashValue(H h, const ShardingSpec& key) {
  h = H::combine(std::move(h), key.sharding_);
  h = H::combine(std::move(h), key.mesh_mapping_);
  return h;
}
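
// The AbslHashValue overloads above make ShardingSpec (and the per-dimension
// sharding types) usable as keys of Abseil hash containers. A minimal,
// hypothetical sketch, assuming "absl/container/flat_hash_set.h" is included
// by the user (this header does not include it):
//
//   absl::flat_hash_set<ShardingSpec> seen;
//   seen.insert(ShardingSpec(/*sharding=*/{NoSharding()},
//                            /*mesh_mapping=*/{Replicated{2}}));
//   bool found = seen.contains(ShardingSpec(/*sharding=*/{NoSharding()},
//                                           /*mesh_mapping=*/{Replicated{2}}));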

// A ShardedDeviceArray is an ndarray sharded across devices.
//
// The purpose of a ShardedDeviceArray is to reduce the number of transfers when
// executing replicated computations, by allowing results to persist on the
// devices that produced them. That way dispatching a similarly replicated
// computation that consumes the same sharded memory layout does not incur any
// transfers.

// A ShardedDeviceArray represents one logical ndarray value, and simulates the
// behavior of an ndarray so that it can be treated by user code as an ndarray;
// that is, it is only an optimization to reduce transfers.

// Design note: we move to C++ only what needs to be accessed from C++ in order
// to execute a pmap computation. A large part of the logic is still in Python.
class ShardedDeviceArray {
 public:
  ShardedDeviceArray(const ShardedDeviceArray&) = delete;
  ShardedDeviceArray& operator=(const ShardedDeviceArray&) = delete;
  ShardedDeviceArray(ShardedDeviceArray&&) = default;
  ShardedDeviceArray& operator=(ShardedDeviceArray&&) = default;

  // Delete all the underlying buffers (freeing memory on device).
  // The Numpy value on the host, if it exists, will also be deleted.
  void Delete();
  const ShardingSpec& GetShardingSpec() const { return sharding_spec_; }
  // Returns an error status iff the object has been deleted.
  xla::StatusOr<absl::Span<xla::PjRtBuffer* const>> GetPjRtBuffers();

  bool is_deleted() const { return is_deleted_; }
  bool weak_type() const { return weak_type_; }
  std::optional<pybind11::list> device_buffers() const {
    return device_buffers_;
  }
  pybind11::object aval() const { return aval_; }
  pybind11::object indices() const { return indices_; }

  std::optional<pybind11::object> npy_value() const { return npy_value_; }
  void set_npy_value(pybind11::object npy_value) { npy_value_ = npy_value; }

  std::optional<pybind11::object> one_replica_buffer_indices() const {
    return one_replica_buffer_indices_;
  }
  void set_one_replica_buffer_indices(pybind11::object obj) {
    one_replica_buffer_indices_ = obj;
  }

  // Python-wrapper definitions.

  // pybind11::object typed subclass for ShardedDeviceArray objects.
  class pyobject : public pybind11::object {
   public:
    PYBIND11_OBJECT(pyobject,  // NOLINT
                    pybind11::object, ShardedDeviceArray::IsShardedDeviceArray);
    pyobject() = default;
    ShardedDeviceArray* sda() const {
      return ShardedDeviceArray::AsShardedDeviceArrayUnchecked(*this);
    }
  };
  using object = pyobject;

  // Returns true if `handle` is a ShardedDeviceArray.
  static bool IsShardedDeviceArray(pybind11::handle handle);
  // Converts `handle` to a ShardedDeviceArray*. Does not do any checking.
  static ShardedDeviceArray* AsShardedDeviceArrayUnchecked(
      pybind11::handle handle);
  // Converts `handle` to a ShardedDeviceArray*. Returns an error status if
  // !IsShardedDeviceArray(handle).
  static xla::StatusOr<ShardedDeviceArray*> AsShardedDeviceArray(
      pybind11::handle handle);

  // Gets a Python handle to an existing ShardedDeviceArray. Assumes the
  // PyObject was allocated on the Python heap, which is the case if Make() was
  // used.
  pybind11::handle AsHandle();

  static object Make(pybind11::object aval, ShardingSpec sharding_spec,
                     pybind11::list device_buffers, pybind11::object indices,
                     bool weak_type);
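
  // As a non-authoritative illustration, C++ callers would typically check a
  // Python handle before converting it and reading its buffers. The sketch
  // below is hypothetical (not taken from actual call sites); `some_handle` is
  // an assumed pybind11::handle received from Python, and the TF_ASSIGN_OR_RETURN
  // macro is assumed to be available in the enclosing function:
  //
  //   if (ShardedDeviceArray::IsShardedDeviceArray(some_handle)) {
  //     TF_ASSIGN_OR_RETURN(ShardedDeviceArray* sda,
  //                         ShardedDeviceArray::AsShardedDeviceArray(some_handle));
  //     TF_ASSIGN_OR_RETURN(absl::Span<xla::PjRtBuffer* const> buffers,
  //                         sda->GetPjRtBuffers());
  //   }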

  static xla::Status RegisterTypes(pybind11::module& m);
  static PyObject* base_type() { return base_type_; }
  static PyObject* type() { return type_; }

 private:
  // Buffers are expected to be xla::PyBuffer objects, but as there are
  // alternative backend implementations, this may not be guaranteed.
  // TODO(jblespiau): As soon as PjRtBuffer is supported by all
  // implementations, we should be able to store this with the C++ objects.
  ShardedDeviceArray(pybind11::object aval, ShardingSpec sharding_spec,
                     pybind11::list device_buffers, pybind11::object indices,
                     bool weak_type)
      : aval_(std::move(aval)),
        sharding_spec_(std::move(sharding_spec)),
        indices_(std::move(indices)),
        device_buffers_(std::move(device_buffers)),
        weak_type_(weak_type) {}
  static PyObject* base_type_;
  static PyObject* type_;

  // A ShapedArray indicating the shape and dtype of this array.
  pybind11::object aval_;
  // Describes how this array is sharded across `device_buffers`.
  ShardingSpec sharding_spec_;
  // The `indices` used to slice the numpy array into the underlying list of
  // buffers. See the Python pxla.py:spec_to_indices function.
  pybind11::object indices_;
  // The buffers containing the data for this array. Each buffer is the same
  // shape and on a different device. Buffers are in row-major order, with
  // replication treated as an extra innermost dimension.
  std::optional<pybind11::list> device_buffers_;

  std::optional<pybind11::object> npy_value_ = std::nullopt;
  std::optional<pybind11::object> one_replica_buffer_indices_ = std::nullopt;

  // The device_buffers as a C++ object. As this is what we consume from C++
  // and this is also what we generate from C++, cache the result so that
  // we don't have to perform casts.
  // TODO(jblespiau): Make this the default, and have `device_buffers_` be the
  // optional Python value if it's accessed from Python.
  std::optional<std::vector<xla::PjRtBuffer*>> cpp_device_buffers_ =
      std::nullopt;

  // Caches the value of `aval_.weak_type`, to avoid accessing that Python
  // attribute, which is significantly slower.
  bool weak_type_;
  bool is_deleted_ = false;
};

}  // namespace jax

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_SHARDED_DEVICE_ARRAY_H_