xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/stream_executor/scratch_allocator.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_
17 #define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_
18 
19 #include <memory>
20 #include <utility>
21 
22 #include "absl/container/inlined_vector.h"
23 #include "tensorflow/compiler/xla/stream_executor/device_memory.h"
24 #include "tensorflow/compiler/xla/stream_executor/device_memory_allocator.h"
25 #include "tensorflow/compiler/xla/stream_executor/lib/statusor.h"
26 #include "tensorflow/compiler/xla/stream_executor/platform/port.h"
27 #include "tensorflow/compiler/xla/stream_executor/temporary_device_memory.h"
28 
29 namespace stream_executor {
30 
31 class Stream;
32 
33 // Interface for "scratch" allocator for device memory, which deallocates all
34 // buffers it has allocated at destruction. Returned memory pointers are not
35 // owning.
36 //
37 // Used by stream operations (e.g. Stream::ThenConvolveWithScratch) to
38 // optionally request scratch space to speed up the operation.
39 class ScratchAllocator {
40  public:
~ScratchAllocator()41   virtual ~ScratchAllocator() {}
42 
43   // Returns a limit of memory this scratch allocator wants to produce, in
44   // bytes. This information may be used to help select an algorithm.
45   //
46   // Returns values < 0 to indicate that there is no recommended limit.
47   virtual int64_t GetMemoryLimitInBytes() = 0;
48 
49   // Returns an allocation on byte_size bytes for use in an operation on stream.
50   //
51   // This is a temporary allocation, and the caller is responsible for
52   // deallocating at some known-safe point. See the class comment above.
53   virtual port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
54       int64_t byte_size) = 0;
55 };
56 
57 // Allocates a single temporary memory allocation -- this memory is deallocated
58 // at the next stream synchronization point after this object has gone out of
59 // scope. This satisfies the lifetime and deallocation properties given in the
60 // class comment above.
61 //
62 // Thread-compatible, but not thread-safe (use in scenarios where only one
63 // thread will request the scratch allocation).
64 class OneTimeScratchAllocator : public ScratchAllocator {
65  public:
OneTimeScratchAllocator(Stream * stream)66   explicit OneTimeScratchAllocator(Stream* stream) : stream_(stream) {}
67 
GetMemoryLimitInBytes()68   int64_t GetMemoryLimitInBytes() override { return -1; }
69 
70   port::StatusOr<DeviceMemory<uint8>> AllocateBytes(int64_t byte_size) override;
71 
72  private:
73   std::unique_ptr<TemporaryDeviceMemory<uint8>> temporary_;
74   Stream* stream_;
75 
76   SE_DISALLOW_COPY_AND_ASSIGN(OneTimeScratchAllocator);
77 };
78 
79 // Can allocate several times -- this memory is deallocated when the scratch
80 // allocator is destroyed.
81 //
82 // Thread-compatible, but not thread-safe (use in scenarios where only one
83 // thread will request the scratch allocation).
84 template <size_t N = 1>
85 class OwningScratchAllocator : public ScratchAllocator {
86  public:
OwningScratchAllocator(int device_ordinal,DeviceMemoryAllocator * allocator)87   OwningScratchAllocator(int device_ordinal, DeviceMemoryAllocator* allocator)
88       : device_ordinal_(device_ordinal), allocator_(allocator) {}
89 
GetMemoryLimitInBytes()90   int64_t GetMemoryLimitInBytes() override { return -1; }
91 
AllocateBytes(int64_t byte_size)92   port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
93       int64_t byte_size) override {
94     TF_ASSIGN_OR_RETURN(OwningDeviceMemory buffer,
95                         allocator_->Allocate(device_ordinal_, byte_size,
96                                              /*retry_on_failure=*/false));
97     buffers_.push_back(std::move(buffer));
98     return *buffers_.back();
99   }
100 
101  private:
102   int device_ordinal_;
103   DeviceMemoryAllocator* allocator_;
104   absl::InlinedVector<OwningDeviceMemory, N> buffers_;
105 
106   SE_DISALLOW_COPY_AND_ASSIGN(OwningScratchAllocator);
107 };
108 
109 }  // namespace stream_executor
110 
111 #endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_
112