xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/instruction.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/jit/runtime/instruction.h>
3 #include <cstring>
4 #include <iostream>
5 
6 namespace torch::jit {
operator <<(std::ostream & out,OpCode op)7 static std::ostream& operator<<(std::ostream& out, OpCode op) {
8   switch (op) {
9 #define OP_STRING(x, _) \
10   case x:               \
11     return out << #x;
12     FORALL_OPCODES(OP_STRING)
13 #undef OP_STRING
14   }
15   return out;
16 }
17 
toString(OpCode op)18 char const* toString(OpCode op) {
19   switch (op) {
20 #define OP_STRING(x, _) \
21   case x:               \
22     return #x;
23     FORALL_OPCODES(OP_STRING)
24 #undef OP_STRING
25   }
26   return nullptr;
27 }
28 
OpInfo(OpCode op)29 static const char* OpInfo(OpCode op) {
30   switch (op) {
31 #define OP_INFO(x, info) \
32   case x:                \
33     return info;
34     // NOLINTNEXTLINE(bugprone-branch-clone)
35     FORALL_OPCODES(OP_INFO)
36 #undef OP_INFO
37   }
38   return nullptr;
39 }
40 
41 static constexpr size_t instruction_size = 8;
42 static_assert(
43     sizeof(Instruction) == instruction_size,
44     "Instructions should be 8 bytes");
operator <<(std::ostream & out,Instruction inst)45 std::ostream& operator<<(std::ostream& out, Instruction inst) {
46   // TODO: use op info to print out the op in a more user-friendly way
47   int nargs = std::strlen(OpInfo(inst.op));
48   out << inst.op;
49   if (nargs > 0) {
50     out << " " << inst.X;
51   }
52   if (nargs > 1) {
53     out << " " << inst.N;
54   }
55   return out;
56 }
57 
58 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
59 static constexpr const char* strOpCode[] = {
60 #define STR_OP(x, _) #x,
61     FORALL_OPCODES(STR_OP)
62 #undef STR_OP
63 };
64 
parseOpCode(const char * str)65 OpCode parseOpCode(const char* str) {
66   const int n = sizeof(strOpCode) / sizeof(strOpCode[0]);
67   for (const auto i : c10::irange(n)) {
68     if (strcmp(strOpCode[i], str) == 0)
69       return (OpCode)i;
70   }
71   return OP;
72 }
73 
isOpSupportedInMobile(OpCode op)74 bool isOpSupportedInMobile(OpCode op) {
75   // clang-format off
76   // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
77   static constexpr OpCode supported_ops_in_mobile[] {
78       OP, OPN, LOAD, MOVE, STOREN, STORE, DROP, DROPR, LOADC, JF, JMP, LOOP,
79       RET, GET_ATTR, SET_ATTR, LIST_CONSTRUCT, TUPLE_CONSTRUCT, WARN,
80       INTERFACE_CALL, LIST_UNPACK, TUPLE_SLICE, DICT_CONSTRUCT,
81       NAMED_TUPLE_CONSTRUCT, CREATE_OBJECT, ISINSTANCE, CALL,
82       RAISE_EXCEPTION, UNCHECKED_CAST, __IS__, UN_INITIALIZED,
83       __ISNOT__, FORMAT, DEVICE, DICT_INDEX,
84       DTYPE, TUPLE_INDEX, DIM, __NOT__,
85       TO_LIST, NUM_TO_TENSOR, IS_CUDA};
86   // clang-format on
87 
88   for (auto sop : supported_ops_in_mobile) {
89     if (op == sop)
90       return true;
91   }
92   return false;
93 }
94 
95 } // namespace torch::jit
96