1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/lite/metrics/types_util.h"
16
17 #include <string>
18
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
21
22 namespace mlir {
23 namespace TFL {
24 namespace {
25
26 // Extracts information from mlir::FileLineColLoc to the proto message
27 // tflite::metrics::ConverterErrorData::FileLoc.
ExtractFileLine(const FileLineColLoc & loc,tflite::metrics::ConverterErrorData::FileLoc * fileline)28 void ExtractFileLine(const FileLineColLoc& loc,
29 tflite::metrics::ConverterErrorData::FileLoc* fileline) {
30 fileline->set_filename(loc.getFilename().str());
31 fileline->set_line(loc.getLine());
32 fileline->set_column(loc.getColumn());
33 }
34
35 // Defines a child class of Location to access its protected members.
36 class LocationExtractor : public Location {
37 public:
LocationExtractor(const Location & loc)38 explicit LocationExtractor(const Location& loc) : Location(loc) {}
39
Extract(tflite::metrics::ConverterErrorData * error_data)40 void Extract(tflite::metrics::ConverterErrorData* error_data) {
41 using tflite::metrics::ConverterErrorData;
42 auto mutable_location = error_data->mutable_location();
43
44 llvm::TypeSwitch<LocationAttr>(impl)
45 .Case<OpaqueLoc>([&](OpaqueLoc loc) {
46 LocationExtractor(loc.getFallbackLocation()).Extract(error_data);
47 })
48 .Case<UnknownLoc>([&](UnknownLoc loc) {
49 mutable_location->set_type(ConverterErrorData::UNKNOWNLOC);
50 })
51 .Case<FileLineColLoc>([&](FileLineColLoc loc) {
52 if (!mutable_location->has_type()) {
53 mutable_location->set_type(ConverterErrorData::CALLSITELOC);
54 }
55 auto new_call = mutable_location->mutable_call()->Add();
56 ExtractFileLine(loc, new_call->mutable_source());
57 })
58 .Case<NameLoc>([&](NameLoc loc) {
59 if (!mutable_location->has_type()) {
60 mutable_location->set_type(ConverterErrorData::NAMELOC);
61 }
62
63 auto new_call = mutable_location->mutable_call()->Add();
64 new_call->set_name(loc.getName().str());
65 // Add child as the source location.
66 auto child_loc = loc.getChildLoc();
67 if (child_loc.isa<FileLineColLoc>()) {
68 auto typed_child_loc = child_loc.dyn_cast<FileLineColLoc>();
69 ExtractFileLine(typed_child_loc, new_call->mutable_source());
70 }
71 })
72 .Case<CallSiteLoc>([&](CallSiteLoc loc) {
73 mutable_location->set_type(ConverterErrorData::CALLSITELOC);
74 LocationExtractor(loc.getCallee()).Extract(error_data);
75 LocationExtractor(loc.getCaller()).Extract(error_data);
76 })
77 .Case<FusedLoc>([&](FusedLoc loc) {
78 auto locations = loc.getLocations();
79 size_t num_locs = locations.size();
80 // Skip the first location if it stores information for propagating
81 // op_type metadata.
82 if (num_locs > 0) {
83 if (auto name_loc = locations[0].dyn_cast<mlir::NameLoc>()) {
84 if (name_loc.getName().strref().endswith(":")) {
85 if (num_locs == 2) {
86 return LocationExtractor(locations[1]).Extract(error_data);
87 } else if (num_locs > 2) {
88 locations = {locations.begin() + 1, locations.end()};
89 }
90 }
91 }
92 }
93
94 mutable_location->set_type(ConverterErrorData::FUSEDLOC);
95 llvm::interleave(
96 locations,
97 [&](Location l) { LocationExtractor(l).Extract(error_data); },
98 [&]() {});
99 });
100 }
101 };
102 } // namespace
103
NewConverterErrorData(const std::string & pass_name,const std::string & error_message,tflite::metrics::ConverterErrorData::ErrorCode error_code,const std::string & op_name,const Location & location)104 tflite::metrics::ConverterErrorData NewConverterErrorData(
105 const std ::string& pass_name, const std::string& error_message,
106 tflite::metrics::ConverterErrorData::ErrorCode error_code,
107 const std::string& op_name, const Location& location) {
108 using tflite::metrics::ConverterErrorData;
109 ConverterErrorData error;
110 if (!pass_name.empty()) {
111 error.set_subcomponent(pass_name);
112 }
113
114 if (!error_message.empty()) {
115 error.set_error_message(error_message);
116 }
117
118 if (!op_name.empty()) {
119 error.mutable_operator_()->set_name(op_name);
120 }
121
122 error.set_error_code(error_code);
123 LocationExtractor(location).Extract(&error);
124 return error;
125 }
126
127 } // namespace TFL
128 } // namespace mlir
129