/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <c10/util/Flags.h>

#include <fstream>
#include <iostream>
#include <string>
#include <unordered_set>

namespace torch {
namespace jit {
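// Recursively walk a TorchScript module: for every method, compile its graph
// to interpreter code, record the name (and overload, if any) of each operator
// it calls, then repeat for every submodule.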
void dump_opnames(const Module& m, std::unordered_set<std::string>& opnames) {
  auto methods = m.get_methods();
  for (const auto& method : methods) {
    const auto& func = method.function();
    std::cout << "function name: " << func.name() << std::endl;
    auto graph = toGraphFunction(func).graph()->copy();
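    // Lower the copied graph to interpreter instructions so the emitted
    // instruction stream can be scanned for operator calls.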
    torch::jit::Code code(graph, "");
    for (size_t i = 0; i < code.instructions().size(); ++i) {
      auto ins = code.instructions()[i];
      auto node = code.instructions_source()[i];
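      // Only OP instructions invoke a registered operator; its schema carries
      // the operator name and the optional overload name.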
      if (ins.op == OpCode::OP) {
        auto opname = node->schema().operator_name();
        std::string namestr = opname.name;
        if (!opname.overload_name.empty()) {
          namestr += "." + opname.overload_name;
        }
        std::cout << " " << namestr << std::endl;
        opnames.emplace(namestr);
      }
    }
  }
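  // Recurse into submodules so the operators they use are collected as well.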
  for (const auto& sub_m : m.children()) {
    std::cout << "sub module name: " << sub_m.type()->name()->qualifiedName()
              << std::endl;
    dump_opnames(sub_m, opnames);
  }
}
} // namespace jit
} // namespace torch

C10_DEFINE_string(model, "", "Path to the TorchScript model file.");
C10_DEFINE_string(output, "", "Path to the output YAML file for the operator list.");

int main(int argc, char** argv) {
  c10::SetUsageMessage(
      "Dump operators in a script module and its submodules.\n"
      "Example usage:\n"
      "./dump_operator_names"
      " --model=<model_file>"
      " --output=<output.yaml>");

  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
    std::cerr << "Failed to parse command line flags!" << std::endl;
    return 1;
  }

  CAFFE_ENFORCE_GT(FLAGS_model.size(), 0, "Model file must be specified.");
  CAFFE_ENFORCE_GT(FLAGS_output.size(), 0, "Output YAML file must be specified.");

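  // Load the TorchScript module from disk and collect the names of all
  // operators referenced by its methods and submodules.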
  auto m = torch::jit::load(FLAGS_model);
  std::unordered_set<std::string> opnames;
  torch::jit::dump_opnames(m, opnames);
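  // Print the final de-duplicated list and write it to the output file as a
  // YAML sequence (one "- <op name>" entry per line).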
  std::ofstream ofile(FLAGS_output);
  std::cout << "-- Final List --" << std::endl;
  for (const auto& name : opnames) {
    std::cout << name << std::endl;
    ofile << "- " << name << std::endl;
  }
  ofile.close();
  return 0;
}