/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h"

#include <algorithm>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/profiler/protobuf/memory_profile.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_schema.h"
#include "tensorflow/core/profiler/utils/xplane_utils.h"
#include "tensorflow/core/profiler/utils/xplane_visitor.h"

namespace tensorflow {
namespace profiler {

namespace {

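// Sentinel step id assigned when the step that owns a memory event is not yet
// known; UpdateStepId() later replaces it with a real step id.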
constexpr int64_t kInvalidStepId = -1;

// Index of the time-sorted memory_profile_snapshots list, and the
// MemoryActivityMetadata proto it contains.
using IndexMetaPair =
    std::pair<int64_t /*index*/, const MemoryActivityMetadata*>;

bool IsMemoryAllocation(int64_t event_type) {
  return event_type == HostEventType::kMemoryAllocation;
}

bool IsMemoryDeallocation(int64_t event_type) {
  return event_type == HostEventType::kMemoryDeallocation;
}

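// Updates the per-allocator profile summary: tracks the lifetime peak bytes in
// use reported by the allocator and, whenever stack + heap usage reaches a new
// peak within the profiling window, records the stats, the time offset, and
// the memory capacity at that peak.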
void UpdateProfileSummary(const MemoryAggregationStats& stats,
                          int64_t time_offset_ps,
                          MemoryProfileSummary* summary) {
  // Update the peak memory usage over allocator's lifetime.
  summary->set_peak_bytes_usage_lifetime(stats.peak_bytes_in_use());
  MemoryAggregationStats* peak_stats = summary->mutable_peak_stats();
  // If we reach (or stay at) peak memory usage within the profiling window,
  // update memory profile summary.
  if (stats.stack_reserved_bytes() + stats.heap_allocated_bytes() >=
      peak_stats->peak_bytes_in_use()) {
    *peak_stats = stats;
    peak_stats->set_peak_bytes_in_use(stats.stack_reserved_bytes() +
                                      stats.heap_allocated_bytes());
    summary->set_peak_stats_time_ps(time_offset_ps);
    summary->set_memory_capacity(stats.stack_reserved_bytes() +
                                 stats.heap_allocated_bytes() +
                                 stats.free_memory_bytes());
  }
}

// Generate memory profile proto by processing host trace XPlane.
MemoryProfile GenerateMemoryProfile(const XPlane* host_trace) {
  XPlaneVisitor plane = CreateTfXPlaneVisitor(host_trace);
  MemoryProfile memory_profile;
  // Iterate over all XEvents in the XPlane, and add the XStats to a new
  // MemoryProfileSnapshot if the EventType is kMemoryAllocation or
  // kMemoryDeallocation.
  plane.ForEachLine([&](const XLineVisitor& line) {
    line.ForEachEvent([&](const XEventVisitor& event) {
      int64_t event_type = event.Type().value_or(kUnknownHostEventType);
      if (!(IsMemoryAllocation(event_type) ||
            IsMemoryDeallocation(event_type))) {
        return;
      }

      MemoryAggregationStats stats;
      MemoryActivityMetadata metadata;
      if (IsMemoryAllocation(event_type)) {
        metadata.set_memory_activity(ALLOCATION);
      } else if (IsMemoryDeallocation(event_type)) {
        metadata.set_memory_activity(DEALLOCATION);
      }
      metadata.set_step_id(kInvalidStepId);

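      // <memory_id> identifies the allocator for this event: the host index or
      // device ordinal rendered as a string, or the allocator name, depending
      // on which stat is present on the event.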
      std::string memory_id;
      event.ForEachStat([&](const XStatVisitor& stat) {
        if (!stat.Type().has_value()) return;
        switch (stat.Type().value()) {
          case StatType::kIndexOnHost:
          case StatType::kDeviceOrdinal:
            memory_id = absl::StrCat(stat.IntValue());
            break;
          case StatType::kAllocatorName:
            memory_id = std::string(stat.StrOrRefValue());
            break;
          case StatType::kBytesReserved:
            stats.set_stack_reserved_bytes(stat.IntValue());
            break;
          case StatType::kBytesAllocated:
            stats.set_heap_allocated_bytes(stat.IntValue());
            break;
          case StatType::kBytesAvailable:
            stats.set_free_memory_bytes(stat.IntValue());
            break;
          case StatType::kFragmentation:
            stats.set_fragmentation(stat.DoubleValue());
            break;
          case StatType::kPeakBytesInUse:
            stats.set_peak_bytes_in_use(stat.IntValue());
            break;
          case StatType::kRequestedBytes:
            metadata.set_requested_bytes(stat.IntValue());
            break;
          case StatType::kAllocationBytes:
            metadata.set_allocation_bytes(stat.IntValue());
            break;
          case StatType::kAddress:
            metadata.set_address(stat.IntValue());
            break;
          case StatType::kTfOp:
            metadata.set_tf_op_name(std::string(stat.StrOrRefValue()));
            break;
          case StatType::kGroupId:
            metadata.set_step_id(stat.IntValue());
            break;
          case StatType::kRegionType:
            metadata.set_region_type(std::string(stat.StrOrRefValue()));
            break;
          case StatType::kDataType:
            metadata.set_data_type(tensorflow::DataTypeString(
                static_cast<tensorflow::DataType>(stat.IntValue())));
            break;
          case StatType::kTensorShapes:
            metadata.set_tensor_shape(std::string(stat.StrOrRefValue()));
            break;
        }
      });

      MemoryProfileSummary* summary =
          (*memory_profile.mutable_memory_profile_per_allocator())[memory_id]
              .mutable_profile_summary();
      UpdateProfileSummary(stats, event.OffsetPs(), summary);

      MemoryProfileSnapshot* snapshot =
          (*memory_profile.mutable_memory_profile_per_allocator())[memory_id]
              .add_memory_profile_snapshots();
      snapshot->set_time_offset_ps(event.OffsetPs());
      *snapshot->mutable_aggregation_stats() = std::move(stats);
      *snapshot->mutable_activity_metadata() = std::move(metadata);
    });
  });
  return memory_profile;
}

// Fix invalid step ids of snapshots at the beginning/end of the profile or at
// the step boundaries. The snapshots with invalid step ids at the beginning get
// 0 for their step ids. Those at the step boundaries or at the end get the
// previous snapshot's step id + 1.
void UpdateStepId(PerAllocatorMemoryProfile* memory_profile) {
  int64_t last_valid_step_id = -1;
  // Snapshots are already sorted in time.
  for (auto& snapshot : *memory_profile->mutable_memory_profile_snapshots()) {
    DCHECK(snapshot.has_activity_metadata());
    if (snapshot.mutable_activity_metadata()->step_id() == kInvalidStepId) {
      snapshot.mutable_activity_metadata()->set_step_id(last_valid_step_id + 1);
    } else {
      last_valid_step_id = snapshot.mutable_activity_metadata()->step_id();
    }
  }
}

// Update the MemoryActivityMetadata for each deallocation event by copying
// from the matching allocation.
void UpdateDeallocation(PerAllocatorMemoryProfile* memory_profile) {
  absl::flat_hash_map<uint64 /*address*/, const MemoryActivityMetadata*>
      addr_metadata_map;
  for (auto& snapshot : *memory_profile->mutable_memory_profile_snapshots()) {
    // Match the deallocation with previous allocation based on address.
    uint64 address = snapshot.activity_metadata().address();
    if (snapshot.activity_metadata().memory_activity() == DEALLOCATION) {
      if (addr_metadata_map.contains(address)) {
        const MemoryActivityMetadata* alloc_meta = addr_metadata_map[address];
        snapshot.mutable_activity_metadata()->set_tf_op_name(
            alloc_meta->tf_op_name());
        snapshot.mutable_activity_metadata()->set_region_type(
            alloc_meta->region_type());
        snapshot.mutable_activity_metadata()->set_data_type(
            alloc_meta->data_type());
        snapshot.mutable_activity_metadata()->set_tensor_shape(
            alloc_meta->tensor_shape());
        // In case of following (unexpected) deallocations to the same chunk
        // address, leave the metadata as it is (empty or already captured).
        addr_metadata_map.erase(address);
      } else {
        VLOG(2)
            << "Can't find matching memory allocation for this deallocation: "
            << snapshot.DebugString();
      }
    } else if (!addr_metadata_map.contains(address)) {  // Allocation.
      addr_metadata_map[address] = &snapshot.activity_metadata();
    } else {
      VLOG(2) << "There are two allocations recorded for the same address: "
              << address
              << ". The later allocation event is: " << snapshot.DebugString();
    }
  }
  VLOG(2) << "Number of allocations that cannot find matching deallocations: "
          << addr_metadata_map.size();
}

// Return the step id for the peak memory usage data point.
int64_t GetPeakMemoryStep(int64_t peak_bytes_profile,
                          const PerAllocatorMemoryProfile* memory_profile) {
  int64_t peak_bytes_profile_step_id = 0;
  for (const auto& snapshot : memory_profile->memory_profile_snapshots()) {
    // Get the step id of the peak memory usage.
    if (peak_bytes_profile ==
        snapshot.aggregation_stats().heap_allocated_bytes() +
            snapshot.aggregation_stats().stack_reserved_bytes()) {
      DCHECK(snapshot.has_activity_metadata());
      peak_bytes_profile_step_id = snapshot.activity_metadata().step_id();
    }
  }
  return peak_bytes_profile_step_id;
}

// Functor that compares (index, metadata) pairs to sort in the order of
// allocation bytes and requested bytes (descending), as well as TF Op name,
// region type, data type, and tensor shape (ascending).
struct MetadataComparator {
  bool operator()(const IndexMetaPair& a, const IndexMetaPair& b) const {
    const MemoryActivityMetadata* a_meta = a.second;
    const MemoryActivityMetadata* b_meta = b.second;
    DCHECK_NE(a_meta, nullptr);
    DCHECK_NE(b_meta, nullptr);

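    // Negating the byte counts makes the lexicographic tuple comparison sort
    // them in descending order while the string fields remain ascending.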
    auto lhs =
        std::make_tuple(-a_meta->allocation_bytes(), -a_meta->requested_bytes(),
                        a_meta->tf_op_name(), a_meta->region_type(),
                        a_meta->data_type(), a_meta->tensor_shape());
    auto rhs =
        std::make_tuple(-b_meta->allocation_bytes(), -b_meta->requested_bytes(),
                        b_meta->tf_op_name(), b_meta->region_type(),
                        b_meta->data_type(), b_meta->tensor_shape());
    return lhs < rhs;
  }
};

// If applicable, add items into active_allocs vector and special_allocations
// proto for the unmapped memory usage (in heap) and stack reservation at peak.
void InsertSpecialAllocations(int64_t unmapped_allocation_bytes,
                              int64_t step_id,
                              PerAllocatorMemoryProfile* memory_profile,
                              std::vector<IndexMetaPair>* active_allocs) {
  int index = 0;
  if (unmapped_allocation_bytes > 0) {
    MemoryActivityMetadata* special_allocation =
        memory_profile->add_special_allocations();
    special_allocation->set_memory_activity(ALLOCATION);
    special_allocation->set_requested_bytes(unmapped_allocation_bytes);
    special_allocation->set_allocation_bytes(unmapped_allocation_bytes);
    special_allocation->set_address(0);
    special_allocation->set_tf_op_name("unused preallocated device memory");
    special_allocation->set_step_id(step_id);
    special_allocation->set_region_type("persist/dynamic");
    special_allocation->set_data_type(
        tensorflow::DataTypeString(static_cast<tensorflow::DataType>(0)));
    special_allocation->set_tensor_shape("unknown");
    active_allocs->push_back({--index, special_allocation});
  }
  int64_t stack_bytes =
      memory_profile->profile_summary().peak_stats().stack_reserved_bytes();
  if (stack_bytes > 0) {
    MemoryActivityMetadata* special_allocation =
        memory_profile->add_special_allocations();
    special_allocation->set_memory_activity(ALLOCATION);
    special_allocation->set_requested_bytes(stack_bytes);
    special_allocation->set_allocation_bytes(stack_bytes);
    special_allocation->set_address(0);
    special_allocation->set_tf_op_name("stack");
    special_allocation->set_step_id(step_id);
    special_allocation->set_region_type("stack");
    special_allocation->set_data_type(
        tensorflow::DataTypeString(static_cast<tensorflow::DataType>(0)));
    special_allocation->set_tensor_shape("unknown");
    active_allocs->push_back({--index, special_allocation});
  }
}

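// Two active allocations are considered identical (and merged into a single
// row with an occurrence count in ProcessActiveAllocations) when all metadata
// fields match; the snapshot index is intentionally ignored.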
bool operator==(const IndexMetaPair& a, const IndexMetaPair& b) {
  const MemoryActivityMetadata* a_meta = a.second;
  const MemoryActivityMetadata* b_meta = b.second;
  return a_meta->allocation_bytes() == b_meta->allocation_bytes() &&
         a_meta->requested_bytes() == b_meta->requested_bytes() &&
         a_meta->tf_op_name() == b_meta->tf_op_name() &&
         a_meta->region_type() == b_meta->region_type() &&
         a_meta->data_type() == b_meta->data_type() &&
         a_meta->tensor_shape() == b_meta->tensor_shape();
}

// Generate the memory breakdown table of active allocations at the peak usage
// (within profiling window) and fill each ActiveAllocation proto (i.e. a row).
void ProcessActiveAllocations(int64_t peak_bytes_profile_step_id,
                              PerAllocatorMemoryProfile* memory_profile) {
  int64_t unmapped_allocation_bytes =
      memory_profile->profile_summary().peak_stats().heap_allocated_bytes();
  int64_t unmapped_deallocation_bytes = 0;
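  // <unmapped_allocation_bytes> starts at the heap bytes in use at the peak;
  // allocations made during the peak step that are still live at the peak are
  // subtracted out below, so what remains is heap memory carried over from
  // before that step. It is reported via InsertSpecialAllocations().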
  absl::flat_hash_map<int64_t /*address*/, IndexMetaPair> active_alloc_map;
  // Only account for the memory activities in the step that includes peak
  // memory usage.
  for (int i = 0; i < memory_profile->memory_profile_snapshots_size(); i++) {
    const auto& snapshot = memory_profile->memory_profile_snapshots().at(i);
    DCHECK(snapshot.has_activity_metadata());
    const MemoryActivityMetadata& metadata = snapshot.activity_metadata();
    if (snapshot.time_offset_ps() >
        memory_profile->profile_summary().peak_stats_time_ps())
      break;
    if (metadata.step_id() != peak_bytes_profile_step_id) continue;

    if (metadata.memory_activity() == ALLOCATION) {
      active_alloc_map[metadata.address()] = {i, &metadata};
      unmapped_allocation_bytes -= metadata.allocation_bytes();
    } else {
      DCHECK_EQ(metadata.memory_activity(), DEALLOCATION);
      if (active_alloc_map.contains(metadata.address())) {
        active_alloc_map.erase(metadata.address());
      } else {
        unmapped_deallocation_bytes += metadata.allocation_bytes();
      }
      unmapped_allocation_bytes += metadata.allocation_bytes();
    }
  }
  // This separates the persistent memory from the memory freed in this step
  // that was allocated in previous steps.
  unmapped_allocation_bytes -= unmapped_deallocation_bytes;

  VLOG(2) << "unmapped_allocation_bytes=" << unmapped_allocation_bytes
          << ", unmapped_deallocation_bytes=" << unmapped_deallocation_bytes;

  // Using pair of (index, MemoryActivityMetadata*) so that we can sort by the
  // metadata, and fetch metadata by indexing the time-sorted snapshots in the
  // frontend.
  std::vector<IndexMetaPair> active_allocs;
  for (const auto& address_and_index_meta : active_alloc_map) {
    active_allocs.push_back(address_and_index_meta.second);
  }

  InsertSpecialAllocations(unmapped_allocation_bytes,
                           peak_bytes_profile_step_id, memory_profile,
                           &active_allocs);

  std::sort(active_allocs.begin(), active_allocs.end(), MetadataComparator());

  // Fill the sorted active_allocations proto messages at peak memory usage.
  // Merge identical allocations and show occurrences.
  for (int i = 0, end = active_allocs.size(); i < end; i++) {
    ActiveAllocation* allocation = memory_profile->add_active_allocations();
    allocation->set_snapshot_index(active_allocs[i].first);
    if (active_allocs[i].first < 0) {
      allocation->set_special_index(-active_allocs[i].first - 1);
    } else {
      allocation->set_special_index(-1);
    }
    allocation->set_num_occurrences(1);
    const int last_alloc = active_allocs.size() - 1;
    while (i < last_alloc && active_allocs[i] == active_allocs[i + 1]) {
      allocation->set_num_occurrences(allocation->num_occurrences() + 1);
      i++;
    }
  }

  VLOG(2) << "Distinct active allocation count="
          << memory_profile->active_allocations_size();
}

// Keep only the MemoryProfileSnapshots referenced by <active_allocations>, and
// rewrite each allocation's snapshot index to point into the pruned list.
void SaveActiveAllocationSnapshots(
    protobuf::RepeatedPtrField<MemoryProfileSnapshot>* snapshots,
    protobuf::RepeatedPtrField<ActiveAllocation>* active_allocations) {
  std::vector<MemoryProfileSnapshot*> samples;
  // Puts the snapshots referenced by active_allocations in <samples>.
  for (const auto& allocation : *active_allocations) {
    auto orig_index = allocation.snapshot_index();
    if (orig_index < 0) continue;
    samples.push_back(&(*snapshots)[orig_index]);
  }

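  // The loop above and the loop below visit <active_allocations> in the same
  // order and both skip special (negative-index) entries, so the k-th kept
  // allocation ends up referencing samples[k], its position in the rewritten
  // snapshot list.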
  // Change the reference index in <active_allocations>.
  int new_index = 0;
  for (auto& allocation : *active_allocations) {
    int64_t origin_index = allocation.snapshot_index();
    if (origin_index < 0) continue;
    allocation.set_snapshot_index(new_index);
    new_index++;
  }

  protobuf::RepeatedPtrField<MemoryProfileSnapshot> new_snapshots;
  new_snapshots.Reserve(samples.size());
  for (const auto& sample : samples) {
    *new_snapshots.Add() = std::move(*sample);
  }
  *snapshots = std::move(new_snapshots);
}

// Sample <max_num_snapshots> memory profile snapshots from the original memory
// profile data.
void SampleMemoryProfileTimeline(int64_t max_num_snapshots,
                                 PerAllocatorMemoryProfile* memory_profile) {
  const protobuf::RepeatedPtrField<MemoryProfileSnapshot>& original_snapshots =
      memory_profile->memory_profile_snapshots();
  protobuf::RepeatedPtrField<MemoryProfileSnapshot>* timeline_snapshots =
      memory_profile->mutable_sampled_timeline_snapshots();
  int64_t snapshot_count = original_snapshots.size();
  if (snapshot_count > max_num_snapshots) {
    // When there are more memory profile data points than <max_num_snapshots>,
    // we sample the original data using a max box filter. Filter width is
    // <filter_width>, collect <count> samples starting from the <start> index
    // in the original snapshots.
    auto max_box_filter = [&](int filter_width, int count, int start) {
      for (int i = 0; i < count; i++) {
        // Use a max function to get the MemoryProfileSnapshot with the largest
        // memory usage in the box filter.
        const MemoryProfileSnapshot* max_snapshot =
            &original_snapshots[start + filter_width * i];
        int64_t max_bytes =
            max_snapshot->aggregation_stats().heap_allocated_bytes() +
            max_snapshot->aggregation_stats().stack_reserved_bytes();
        for (int index = start + filter_width * i + 1;
             index < start + filter_width * (i + 1); index++) {
          int64_t bytes = original_snapshots[index]
                              .aggregation_stats()
                              .heap_allocated_bytes() +
                          original_snapshots[index]
                              .aggregation_stats()
                              .stack_reserved_bytes();
          if (bytes > max_bytes) {
            max_snapshot = &original_snapshots[index];
            max_bytes = bytes;
          }
        }
        *timeline_snapshots->Add() = *max_snapshot;
      }
    };

    int width = snapshot_count / max_num_snapshots;
    int count1 = max_num_snapshots * (width + 1) - snapshot_count;
    int count2 = max_num_snapshots - count1;

    // Collect <count1> samples with box filter width <width>, then collect
    // <count2> samples with box filter width <width + 1>; the total number of
    // samples collected will be <max_num_snapshots>.
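    // For example, with 10 snapshots and max_num_snapshots = 4: width = 2,
    // count1 = 4 * 3 - 10 = 2, and count2 = 2, so two windows of width 2
    // followed by two windows of width 3 cover all 2 * 2 + 2 * 3 = 10
    // snapshots exactly.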
    max_box_filter(width, count1, 0);
    max_box_filter(width + 1, count2, width * count1);
  } else {
    // When the number of original snapshots is smaller than
    // <max_num_snapshots>, just copy all the data points to the timeline.
    *timeline_snapshots = original_snapshots;
  }
}

// Post-process the memory profile to correctly update proto fields, and break
// down peak memory usage for each allocator.
void ProcessMemoryProfileProto(int64_t max_num_snapshots,
                               MemoryProfile* memory_profile) {
  memory_profile->set_num_hosts(1);
  // Add sorted memory ids within memory profile data to the selection list.
  for (const auto& id_and_allocator_profile :
       memory_profile->memory_profile_per_allocator()) {
    if (!id_and_allocator_profile.second.memory_profile_snapshots().empty()) {
      memory_profile->add_memory_ids(id_and_allocator_profile.first);
    }
  }
  absl::c_sort(*memory_profile->mutable_memory_ids());

  for (auto& id_and_allocator_profile :
       *memory_profile->mutable_memory_profile_per_allocator()) {
    PerAllocatorMemoryProfile* allocator_memory_profile =
        &id_and_allocator_profile.second;
    protobuf::RepeatedPtrField<MemoryProfileSnapshot>* snapshots =
        allocator_memory_profile->mutable_memory_profile_snapshots();
    // Sort the memory_profile_snapshots by time_offset_ps (ascending) in proto.
    absl::c_sort(*snapshots, [](const MemoryProfileSnapshot& a,
                                const MemoryProfileSnapshot& b) {
      return a.time_offset_ps() < b.time_offset_ps();
    });

    UpdateStepId(allocator_memory_profile);
    UpdateDeallocation(allocator_memory_profile);

    // Sample a subset of MemoryProfileSnapshots to display in the frontend
    // memory timeline graph.
    SampleMemoryProfileTimeline(max_num_snapshots, allocator_memory_profile);

    int64_t peak_step_id =
        GetPeakMemoryStep(allocator_memory_profile->profile_summary()
                              .peak_stats()
                              .peak_bytes_in_use(),
                          allocator_memory_profile);
    ProcessActiveAllocations(peak_step_id, allocator_memory_profile);
    SaveActiveAllocationSnapshots(
        snapshots, allocator_memory_profile->mutable_active_allocations());
  }
}

template <typename Proto>
Status ConvertProtoToJson(const Proto& proto_output, std::string* json_output) {
  protobuf::util::JsonPrintOptions json_options;
  json_options.always_print_primitive_fields = true;
  auto status = protobuf::util::MessageToJsonString(proto_output, json_output,
                                                    json_options);
  if (!status.ok()) {
    // Convert error_msg google::protobuf::StringPiece (or absl::string_view) to
    // tensorflow::StringPiece.
    auto error_msg = status.message();
    return errors::Internal(
        "Could not convert proto to JSON string: ",
        absl::string_view(error_msg.data(), error_msg.length()));
  }
  return OkStatus();
}

}  // namespace

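// Builds the memory profile from the host-threads XPlane: raw snapshots are
// extracted per allocator, then post-processed (sorted by time, step ids fixed
// up, deallocations matched to allocations, the timeline sampled, and the
// peak-usage breakdown computed).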
MemoryProfile ConvertXPlaneToMemoryProfile(const XPlane& host_plane,
                                           int64_t max_num_snapshots) {
  MemoryProfile memory_profile = GenerateMemoryProfile(&host_plane);
  ProcessMemoryProfileProto(max_num_snapshots, &memory_profile);
  // Default version number is 0, set version number to 1 here due to the new
  // memory profile sampling algorithm.
  memory_profile.set_version(1);
  return memory_profile;
}

Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace,
                                        std::string* json_output) {
  if (const XPlane* host_plane =
          FindPlaneWithName(xspace, kHostThreadsPlaneName)) {
    MemoryProfile memory_profile = ConvertXPlaneToMemoryProfile(*host_plane);
    TF_RETURN_IF_ERROR(ConvertProtoToJson(memory_profile, json_output));
  }
  return OkStatus();
}

}  // namespace profiler
}  // namespace tensorflow