/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_SHARDED_DEVICE_ARRAY_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_SHARDED_DEVICE_ARRAY_H_

#include <optional>
#include <utility>
#include <variant>
#include <vector>

#include "absl/types/variant.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/types.h"

// TODO(jblespiau): The current implementation moves the Python logic to C++,
// as a preliminary step to executing the `pmap` execution path from C++.
// It implements the current Python behavior (thus, it may not be optimal, and
// we will be able to modify it later).

namespace jax {

// High level introduction.
//
// pmap and other parallel computation functions distribute some computation
// over several devices. As of December 2020, the device mesh (i.e. the
// N-dimensional array of devices on which we map the computation) is defined
// by the user.
//
// We describe how to shard the inputs, and how to map them to the mesh of
// devices, using a `ShardingSpec`. It is mainly based on 2 components:
// - `sharding`, which specifies how to shard the inputs.
// - `mesh_mapping`, which specifies how to map shards to devices.
//
// The following three structs define how to shard one dimension of an ndarray.
//
// `NoSharding` (`None` in Python) means no sharding.
struct NoSharding {
  bool operator==(const NoSharding& other) const { return true; }
  bool operator!=(const NoSharding& other) const { return false; }
};

template <typename H>
H AbslHashValue(H h, const NoSharding& key) {
  return h;
}

// `Chunked` means that the dimension is split into np.prod(chunks) chunks
// and the split dimension itself is preserved inside the map.
// Those chunks are distributed over `len(chunks)` ShardedAxes axes
// (major-to-minor).
// For example, for a tensor `t` of shape [N] sharded using [Chunked([p])]
// (with p dividing N, let S = N // p), the tensor is split into p chunks of
// shape [S], such that sharded_t[k] = t[k * S: (k+1) * S] (left included,
// right excluded) for k in {0, ..., p-1}.
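//
// As a concrete instance of the formula above (illustrative only): with N = 6
// and Chunked([2]), p = 2 and S = 3, so sharded_t[0] = t[0:3] and
// sharded_t[1] = t[3:6].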
struct Chunked {
 public:
  explicit Chunked(std::vector<int> chunks_) : chunks(std::move(chunks_)) {}
  // The number of chunks per axis.
  std::vector<int> chunks;

  bool operator==(const Chunked& other) const { return chunks == other.chunks; }
  bool operator!=(const Chunked& other) const { return chunks != other.chunks; }
};

template <typename H>
H AbslHashValue(H h, const Chunked& key) {
  h = H::combine(std::move(h), key.chunks);
  return h;
}

// `Unstacked` means that the dimension is split into chunks of size 1, and
// doesn't appear inside the map. `size` is always the dimension size.
// For example, a tensor `t` of shape [N] will be sharded into N tensors of
// shape [], when using `Unstacked(N)`.
struct Unstacked {
 public:
  explicit Unstacked(int sz) : size(sz) {}
  int size;

  bool operator==(const Unstacked& other) const { return size == other.size; }
  bool operator!=(const Unstacked& other) const { return size != other.size; }
};

template <typename H>
H AbslHashValue(H h, const Unstacked& key) {
  h = H::combine(std::move(h), key.size);
  return h;
}

using AvalDimSharding = std::variant<NoSharding, Chunked, Unstacked>;

// Assigns sharded axes to mesh dimensions.
//
// Each mesh dimension is assigned either one of the sharded axes produced by
// `sharding` (via `ShardedAxis`) or `Replicated`. When no sharded axis is
// assigned to a mesh dimension, the data is replicated along it.
// As indices are 0-indexed, `ShardedAxis(1)` refers to the second actually
// sharded axis (i.e. counting as if the `NoSharding` dimensions of the
// sharding were filtered out).
// For example, given the sharding `[Unstacked(n), None, Chunked(m)]`, an entry
// of `ShardedAxis(1)` refers to the `Chunked(m)` axis, not the `None`.

struct ShardedAxis {
  int axis;
  bool operator==(const ShardedAxis& other) const { return axis == other.axis; }
  bool operator!=(const ShardedAxis& other) const { return axis != other.axis; }
};

template <typename H>
H AbslHashValue(H h, const ShardedAxis& key) {
  h = H::combine(std::move(h), key.axis);
  return h;
}

struct Replicated {
  int replicas;
  bool operator==(const Replicated& other) const {
    return replicas == other.replicas;
  }
  bool operator!=(const Replicated& other) const {
    return replicas != other.replicas;
  }
};

template <typename H>
H AbslHashValue(H h, const Replicated& key) {
  h = H::combine(std::move(h), key.replicas);
  return h;
}

using MeshDimAssignment = std::variant<ShardedAxis, Replicated>;

// Describes how each axis is sharded (if it is), and how it is mapped to the
// device mesh. See JAX's pxla.py for the full documentation.
//
// ShardingSpec is shared across pmap, pjit and xpmap. For pmap, an input
// `sharding` is composed of `NoSharding` and at most one `Unstacked`.
// If `axis_size=None`, at least one of the inputs has a dimension associated
// to `Unstacked`.
//
// Examples:
//
// 1. For pmap, with a tensor of shape [8, 2, 2], to unstack along the first
//    dimension into [8] devices:
//
//    sharding = [Unstacked(8), NoSharding, NoSharding]
//    mesh_mapping = [ShardedAxis(0)]
//
// 2. With an input array of shape [6] that we want to chunk into [2, 3],
//    assuming a device mesh of shape [3, 4, 2], we will have:
//
//    sharding = [Chunked([2, 3])]
//    mesh_mapping = [ShardedAxis(1), Replicated, ShardedAxis(0)]
//
//    In particular, in the above example, the ShardedAxis refers to indices
//    of the sharded shape [2, 3] (only the `Chunked` sharding can produce
//    more than one dimension).
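//
// As a sketch only (this mirrors example 2 above; nothing in this header
// requires constructing a ShardingSpec this way), the C++ equivalent would be:
//
//   jax::ShardingSpec spec(
//       /*sharding=*/{jax::Chunked({2, 3})},
//       /*mesh_mapping=*/{jax::ShardedAxis{1}, jax::Replicated{4},
//                         jax::ShardedAxis{0}});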
class ShardingSpec {
 public:
  ShardingSpec(std::vector<AvalDimSharding> sharding,
               std::vector<MeshDimAssignment> mesh_mapping)
      : sharding_(std::move(sharding)),
        mesh_mapping_(std::move(mesh_mapping)) {}
  ShardingSpec(pybind11::iterable py_sharding,
               pybind11::iterable py_mesh_mapping)
      : sharding_(xla::IterableToVector<AvalDimSharding>(py_sharding)),
        mesh_mapping_(
            xla::IterableToVector<MeshDimAssignment>(py_mesh_mapping)) {}

  const std::vector<AvalDimSharding>& GetSharding() const { return sharding_; }
  const std::vector<MeshDimAssignment>& GetMeshMapping() const {
    return mesh_mapping_;
  }

  bool operator==(const ShardingSpec& other) const {
    return sharding_ == other.sharding_ && mesh_mapping_ == other.mesh_mapping_;
  }

  bool operator!=(const ShardingSpec& other) const { return !(*this == other); }

  template <typename H>
  friend H AbslHashValue(H h, const ShardingSpec& key);

 private:
  // `sharding` specifies how the array is supposed to get partitioned into
  // chunks. Its length matches the rank of the array. See the docstring
  // of `AvalDimSharding` for the supported partitioning schemes.
  std::vector<AvalDimSharding> sharding_;
  // `mesh_mapping` describes an assignment of the array chunks created by
  // `sharding` to a logical device mesh. The length of the tuple is equal to
  // the rank of the mesh. Each mesh dimension can either get partitions of
  // data varying along one of the sharded dimensions, or the data can be
  // replicated.
  std::vector<MeshDimAssignment> mesh_mapping_;
};

template <typename H>
H AbslHashValue(H h, const ShardingSpec& key) {
  h = H::combine(std::move(h), key.sharding_);
  h = H::combine(std::move(h), key.mesh_mapping_);
  return h;
}
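
// These AbslHashValue overloads make ShardingSpec (and its components) usable
// as keys in absl hash containers. A minimal sketch (the container and its
// value type are illustrative, not part of this header):
//
//   absl::flat_hash_map<jax::ShardingSpec, int> specs;
//   specs[jax::ShardingSpec({jax::NoSharding()}, {jax::Replicated{8}})] = 0;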

// A ShardedDeviceArray is an ndarray sharded across devices.
//
// The purpose of a ShardedDeviceArray is to reduce the number of transfers
// when executing replicated computations, by allowing results to persist on
// the devices that produced them. That way dispatching a similarly replicated
// computation that consumes the same sharded memory layout does not incur any
// transfers.
//
// A ShardedDeviceArray represents one logical ndarray value, and simulates the
// behavior of an ndarray so that it can be treated by user code as an ndarray;
// that is, it is only an optimization to reduce transfers.
//
// Design note: we move to C++ only what needs to be accessed from C++ to
// execute a pmap computation. A large part of the logic is still in Python.
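//
// For orientation only, a sketch of how a ShardedDeviceArray is created (the
// argument names are illustrative; the real call sites live in the pmap C++
// path and in Python):
//
//   ShardedDeviceArray::object sda = ShardedDeviceArray::Make(
//       aval, std::move(sharding_spec), device_buffers, indices,
//       /*weak_type=*/false);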
class ShardedDeviceArray {
 public:
  ShardedDeviceArray(const ShardedDeviceArray&) = delete;
  ShardedDeviceArray& operator=(const ShardedDeviceArray&) = delete;
  ShardedDeviceArray(ShardedDeviceArray&&) = default;
  ShardedDeviceArray& operator=(ShardedDeviceArray&&) = default;

  // Deletes all the underlying buffers (freeing memory on device).
  // The numpy value on the host, if it exists, will also be deleted.
  void Delete();
  const ShardingSpec& GetShardingSpec() const { return sharding_spec_; }
  // Returns an error status iff the object has been deleted.
  xla::StatusOr<absl::Span<xla::PjRtBuffer* const>> GetPjRtBuffers();

  bool is_deleted() const { return is_deleted_; }
  bool weak_type() const { return weak_type_; }
  std::optional<pybind11::list> device_buffers() const {
    return device_buffers_;
  }
  pybind11::object aval() const { return aval_; }
  pybind11::object indices() const { return indices_; }

  std::optional<pybind11::object> npy_value() const { return npy_value_; }
  void set_npy_value(pybind11::object npy_value) { npy_value_ = npy_value; }

  std::optional<pybind11::object> one_replica_buffer_indices() const {
    return one_replica_buffer_indices_;
  }
  void set_one_replica_buffer_indices(pybind11::object obj) {
    one_replica_buffer_indices_ = obj;
  }

  // Python-wrapper definitions.

  // pybind11::object typed subclass for ShardedDeviceArray objects.
  class pyobject : public pybind11::object {
   public:
    PYBIND11_OBJECT(pyobject,  // NOLINT
                    pybind11::object, ShardedDeviceArray::IsShardedDeviceArray);
    pyobject() = default;
    ShardedDeviceArray* sda() const {
      return ShardedDeviceArray::AsShardedDeviceArrayUnchecked(*this);
    }
  };
  using object = pyobject;

  // Returns true if `handle` is a ShardedDeviceArray.
  static bool IsShardedDeviceArray(pybind11::handle handle);
  // Converts `handle` to a ShardedDeviceArray*. Does not do any checking.
  static ShardedDeviceArray* AsShardedDeviceArrayUnchecked(
      pybind11::handle handle);
  // Converts `handle` to a ShardedDeviceArray*. Returns an error status if
  // !IsShardedDeviceArray(handle).
  static xla::StatusOr<ShardedDeviceArray*> AsShardedDeviceArray(
      pybind11::handle handle);
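  // Example use of the checked conversion (a sketch; `handle` is any
  // pybind11::handle that may or may not wrap a ShardedDeviceArray):
  //
  //   xla::StatusOr<ShardedDeviceArray*> maybe_sda =
  //       ShardedDeviceArray::AsShardedDeviceArray(handle);
  //   if (maybe_sda.ok()) {
  //     ShardedDeviceArray* sda = *maybe_sda;
  //     // ... use sda ...
  //   }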

  // Gets a Python handle to an existing ShardedDeviceArray. Assumes the
  // PyObject was allocated on the Python heap, which is the case if Make() was
  // used.
  pybind11::handle AsHandle();

  static object Make(pybind11::object aval, ShardingSpec sharding_spec,
                     pybind11::list device_buffers, pybind11::object indices,
                     bool weak_type);

  static xla::Status RegisterTypes(pybind11::module& m);
  static PyObject* base_type() { return base_type_; }
  static PyObject* type() { return type_; }

 private:
  // Buffers are expected to be xla::PyBuffer objects, but as there are
  // alternative backend implementations, this may not be guaranteed.
  // TODO(jblespiau): As soon as PjRtBuffer is supported by all
  // implementations, we should be able to store this with the C++ objects.
  ShardedDeviceArray(pybind11::object aval, ShardingSpec sharding_spec,
                     pybind11::list device_buffers, pybind11::object indices,
                     bool weak_type)
      : aval_(std::move(aval)),
        sharding_spec_(std::move(sharding_spec)),
        indices_(std::move(indices)),
        device_buffers_(std::move(device_buffers)),
        weak_type_(weak_type) {}
  static PyObject* base_type_;
  static PyObject* type_;

  // A ShapedArray indicating the shape and dtype of this array.
  pybind11::object aval_;
  // Describes how this array is sharded across `device_buffers`.
  ShardingSpec sharding_spec_;
  // The `indices` used to slice the numpy array into the underlying list of
  // buffers. See the Python pxla.py:spec_to_indices function.
  pybind11::object indices_;
  // The buffers containing the data for this array. Each buffer is the same
  // shape and on a different device. Buffers are in row-major order, with
  // replication treated as an extra innermost dimension.
  std::optional<pybind11::list> device_buffers_;

  std::optional<pybind11::object> npy_value_ = std::nullopt;
  std::optional<pybind11::object> one_replica_buffer_indices_ = std::nullopt;

  // The device_buffers as a C++ object. As this is what we consume from C++,
  // and this is also what we generate from C++, cache the result so that
  // we don't have to perform casts.
  // TODO(jblespiau): Make this the default, and have `device_buffers_` be
  // the optional Python value if it's accessed from Python.
  std::optional<std::vector<xla::PjRtBuffer*>> cpp_device_buffers_ =
      std::nullopt;

  // The weak_type to prevent accessing the "aval_.weak_type" attribute, which
  // is significantly slower.
  bool weak_type_;
  bool is_deleted_ = false;
};

}  // namespace jax

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_SHARDED_DEVICE_ARRAY_H_