xref: /aosp_15_r20/external/tensorflow/tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_PROFILER_LIB_SCOPED_MEMORY_DEBUG_ANNOTATION_H_
16 #define TENSORFLOW_CORE_PROFILER_LIB_SCOPED_MEMORY_DEBUG_ANNOTATION_H_
17 
18 #include <cstdint>
19 #include <functional>
20 #include <string>
21 #include <utility>
22 
23 namespace tensorflow {
24 namespace profiler {
25 
// Annotations used for memory profiling and debugging purposes.
// ScopedMemoryDebugAnnotation caches these annotations in thread-local memory,
// and some allocators try to tag allocations with them.
struct MemoryDebugAnnotation {
  // Op name attached to the pending allocation; nullptr when unset.
  const char* pending_op_name = nullptr;
  // Step id attached to the pending allocation; 0 when unset.
  int64_t pending_step_id = 0;
  // Region type label for the pending allocation; nullptr when unset.
  const char* pending_region_type = nullptr;
  // Data type of the pending tensor as a raw int32; 0 when unset.
  // NOTE(review): presumably a DataType enum value — confirm with callers.
  int32_t pending_data_type = 0;
  // Invoked lazily to build the string describing the pending tensor's shape.
  // The default produces an empty string.
  std::function<std::string()> pending_shape_func = [] {
    return std::string();
  };
};
39 
40 // Wrapper class of MemoryDebugAnnotation for RAII.
41 class ScopedMemoryDebugAnnotation {
42  public:
CurrentAnnotation()43   static const MemoryDebugAnnotation& CurrentAnnotation() {
44     return *ThreadMemoryDebugAnnotation();
45   }
46 
ScopedMemoryDebugAnnotation(const char * op_name)47   explicit ScopedMemoryDebugAnnotation(const char* op_name) {
48     MemoryDebugAnnotation* thread_local_annotation =
49         ThreadMemoryDebugAnnotation();
50     last_annotation_ = *thread_local_annotation;
51     *thread_local_annotation = MemoryDebugAnnotation();
52     thread_local_annotation->pending_op_name = op_name;
53   }
54 
ScopedMemoryDebugAnnotation(const char * op_name,int64_t step_id)55   explicit ScopedMemoryDebugAnnotation(const char* op_name, int64_t step_id) {
56     MemoryDebugAnnotation* thread_local_annotation =
57         ThreadMemoryDebugAnnotation();
58     last_annotation_ = *thread_local_annotation;
59     *thread_local_annotation = MemoryDebugAnnotation();
60     thread_local_annotation->pending_op_name = op_name;
61     thread_local_annotation->pending_step_id = step_id;
62   }
63 
64   // This constructor keeps the pending_op_name and pending_step_id from parent
65   // (if any).  Otherwise it overwrites with op_name.
ScopedMemoryDebugAnnotation(const char * op_name,const char * region_type,int32_t data_type,std::function<std::string ()> && pending_shape_func)66   explicit ScopedMemoryDebugAnnotation(
67       const char* op_name, const char* region_type, int32_t data_type,
68       std::function<std::string()>&& pending_shape_func) {
69     MemoryDebugAnnotation* thread_local_annotation =
70         ThreadMemoryDebugAnnotation();
71     last_annotation_ = *thread_local_annotation;
72     if (!thread_local_annotation->pending_op_name) {
73       thread_local_annotation->pending_op_name = op_name;
74     }
75     thread_local_annotation->pending_region_type = region_type;
76     thread_local_annotation->pending_data_type = data_type;
77     thread_local_annotation->pending_shape_func = std::move(pending_shape_func);
78   }
79 
ScopedMemoryDebugAnnotation(const char * op_name,int64_t step_id,const char * region_type,int32_t data_type,std::function<std::string ()> && pending_shape_func)80   explicit ScopedMemoryDebugAnnotation(
81       const char* op_name, int64_t step_id, const char* region_type,
82       int32_t data_type, std::function<std::string()>&& pending_shape_func) {
83     MemoryDebugAnnotation* thread_local_annotation =
84         ThreadMemoryDebugAnnotation();
85     last_annotation_ = *thread_local_annotation;
86     thread_local_annotation->pending_op_name = op_name;
87     thread_local_annotation->pending_step_id = step_id;
88     thread_local_annotation->pending_region_type = region_type;
89     thread_local_annotation->pending_data_type = data_type;
90     thread_local_annotation->pending_shape_func = std::move(pending_shape_func);
91   }
92 
~ScopedMemoryDebugAnnotation()93   ~ScopedMemoryDebugAnnotation() {
94     *ThreadMemoryDebugAnnotation() = last_annotation_;
95   }
96 
97  private:
98   // Returns a pointer to the MemoryDebugAnnotation for the current thread.
99   static MemoryDebugAnnotation* ThreadMemoryDebugAnnotation();
100 
101   // Stores the previous values in case the annotations are nested.
102   MemoryDebugAnnotation last_annotation_;
103 
104   ScopedMemoryDebugAnnotation(const ScopedMemoryDebugAnnotation&) = delete;
105   ScopedMemoryDebugAnnotation& operator=(const ScopedMemoryDebugAnnotation&) =
106       delete;
107 };
108 
109 }  // namespace profiler
110 }  // namespace tensorflow
111 
112 #endif  // TENSORFLOW_CORE_PROFILER_LIB_SCOPED_MEMORY_DEBUG_ANNOTATION_H_
113