xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/metrics/types_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/metrics/types_util.h"
16 
17 #include <string>
18 
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
21 
22 namespace mlir {
23 namespace TFL {
24 namespace {
25 
26 // Extracts information from mlir::FileLineColLoc to the proto message
27 // tflite::metrics::ConverterErrorData::FileLoc.
ExtractFileLine(const FileLineColLoc & loc,tflite::metrics::ConverterErrorData::FileLoc * fileline)28 void ExtractFileLine(const FileLineColLoc& loc,
29                      tflite::metrics::ConverterErrorData::FileLoc* fileline) {
30   fileline->set_filename(loc.getFilename().str());
31   fileline->set_line(loc.getLine());
32   fileline->set_column(loc.getColumn());
33 }
34 
35 // Defines a child class of Location to access its protected members.
36 class LocationExtractor : public Location {
37  public:
LocationExtractor(const Location & loc)38   explicit LocationExtractor(const Location& loc) : Location(loc) {}
39 
Extract(tflite::metrics::ConverterErrorData * error_data)40   void Extract(tflite::metrics::ConverterErrorData* error_data) {
41     using tflite::metrics::ConverterErrorData;
42     auto mutable_location = error_data->mutable_location();
43 
44     llvm::TypeSwitch<LocationAttr>(impl)
45         .Case<OpaqueLoc>([&](OpaqueLoc loc) {
46           LocationExtractor(loc.getFallbackLocation()).Extract(error_data);
47         })
48         .Case<UnknownLoc>([&](UnknownLoc loc) {
49           mutable_location->set_type(ConverterErrorData::UNKNOWNLOC);
50         })
51         .Case<FileLineColLoc>([&](FileLineColLoc loc) {
52           if (!mutable_location->has_type()) {
53             mutable_location->set_type(ConverterErrorData::CALLSITELOC);
54           }
55           auto new_call = mutable_location->mutable_call()->Add();
56           ExtractFileLine(loc, new_call->mutable_source());
57         })
58         .Case<NameLoc>([&](NameLoc loc) {
59           if (!mutable_location->has_type()) {
60             mutable_location->set_type(ConverterErrorData::NAMELOC);
61           }
62 
63           auto new_call = mutable_location->mutable_call()->Add();
64           new_call->set_name(loc.getName().str());
65           // Add child as the source location.
66           auto child_loc = loc.getChildLoc();
67           if (child_loc.isa<FileLineColLoc>()) {
68             auto typed_child_loc = child_loc.dyn_cast<FileLineColLoc>();
69             ExtractFileLine(typed_child_loc, new_call->mutable_source());
70           }
71         })
72         .Case<CallSiteLoc>([&](CallSiteLoc loc) {
73           mutable_location->set_type(ConverterErrorData::CALLSITELOC);
74           LocationExtractor(loc.getCallee()).Extract(error_data);
75           LocationExtractor(loc.getCaller()).Extract(error_data);
76         })
77         .Case<FusedLoc>([&](FusedLoc loc) {
78           auto locations = loc.getLocations();
79           size_t num_locs = locations.size();
80           // Skip the first location if it stores information for propagating
81           // op_type metadata.
82           if (num_locs > 0) {
83             if (auto name_loc = locations[0].dyn_cast<mlir::NameLoc>()) {
84               if (name_loc.getName().strref().endswith(":")) {
85                 if (num_locs == 2) {
86                   return LocationExtractor(locations[1]).Extract(error_data);
87                 } else if (num_locs > 2) {
88                   locations = {locations.begin() + 1, locations.end()};
89                 }
90               }
91             }
92           }
93 
94           mutable_location->set_type(ConverterErrorData::FUSEDLOC);
95           llvm::interleave(
96               locations,
97               [&](Location l) { LocationExtractor(l).Extract(error_data); },
98               [&]() {});
99         });
100   }
101 };
102 }  // namespace
103 
NewConverterErrorData(const std::string & pass_name,const std::string & error_message,tflite::metrics::ConverterErrorData::ErrorCode error_code,const std::string & op_name,const Location & location)104 tflite::metrics::ConverterErrorData NewConverterErrorData(
105     const std ::string& pass_name, const std::string& error_message,
106     tflite::metrics::ConverterErrorData::ErrorCode error_code,
107     const std::string& op_name, const Location& location) {
108   using tflite::metrics::ConverterErrorData;
109   ConverterErrorData error;
110   if (!pass_name.empty()) {
111     error.set_subcomponent(pass_name);
112   }
113 
114   if (!error_message.empty()) {
115     error.set_error_message(error_message);
116   }
117 
118   if (!op_name.empty()) {
119     error.mutable_operator_()->set_name(op_name);
120   }
121 
122   error.set_error_code(error_code);
123   LocationExtractor(location).Extract(&error);
124   return error;
125 }
126 
127 }  // namespace TFL
128 }  // namespace mlir
129