1 #pragma once 2 3 #include <torch/csrc/python_headers.h> 4 5 #include <ATen/core/ivalue.h> 6 #include <ATen/core/symbol.h> 7 #include <c10/util/irange.h> 8 #include <torch/csrc/DynamicTypes.h> 9 #include <torch/csrc/THP.h> 10 #include <torch/csrc/autograd/variable.h> 11 #include <torch/csrc/jit/frontend/tracer.h> 12 #include <torch/csrc/jit/python/pybind_utils.h> 13 #include <torch/csrc/utils/pybind.h> 14 15 #include <pybind11/functional.h> 16 #include <pybind11/pybind11.h> 17 #include <pybind11/stl.h> 18 19 namespace py = pybind11; 20 21 namespace torch::jit { 22 23 // This is a variant of shared_ptr that "sees through" a wrapper. 24 // We use it to convert Value, Node, Block and node to "wrapped" Python 25 // values. When we destruct the C++ object, the wrapper's pointer will 26 // be set to 0 and any future dereferencing will throw. We need this 27 // because the Python objects may hang around after the C++ object 28 // has already been destroyed. 29 // This also needs the magic type_caster below, which is from the 30 // workaround offered in https://github.com/pybind/pybind11/issues/2751 31 template <typename T> 32 class unwrapping_shared_ptr { 33 static_assert( 34 std::is_same_v<T, torch::jit::Value> || 35 std::is_same_v<T, torch::jit::Node> || 36 std::is_same_v<T, torch::jit::Block>, 37 "unwrapping type only defined for Graph object types"); 38 39 private: 40 std::shared_ptr<torch::jit::Wrap<T>> impl; 41 42 public: unwrapping_shared_ptr()43 unwrapping_shared_ptr() : impl({}) {} unwrapping_shared_ptr(T * p)44 explicit unwrapping_shared_ptr(T* p) : impl(p->wrap()) { 45 impl->clear_cb = &clear_registered_instances; 46 } get()47 T* get() const { 48 if (!impl->elem) { 49 throw std::logic_error("has been invalidated"); 50 } 51 return impl->elem; 52 } 53 // we need to disable the overloaded & for PyBind11 < 2.3 due. 54 // see https://github.com/pybind/pybind11/pull/1435 55 #if (PYBIND11_VERSION_MAJOR > 2) || \ 56 ((PYBIND11_VERSION_MAJOR == 2) && (PYBIND11_VERSION_MINOR >= 3)) 57 T** operator&() { 58 if (!impl->elem) { 59 throw std::logic_error("has been invalidated"); 60 } 61 return &(impl->elem); 62 } 63 #endif 64 }; 65 66 } // namespace torch::jit 67 68 PYBIND11_DECLARE_HOLDER_TYPE(T, torch::jit::unwrapping_shared_ptr<T>, true); 69 70 namespace pybind11::detail { 71 72 #define CREATE_UNWRAPPING_CASTER(Class) \ 73 template <> \ 74 struct type_caster<Class> : public type_caster_base<Class> { \ 75 public: \ 76 using type = Class; \ 77 using holder_type = torch::jit::unwrapping_shared_ptr<Class>; \ 78 \ 79 bool load(handle src, bool convert) { \ 80 return load_impl<type_caster<Class>>(src, convert); \ 81 } \ 82 \ 83 explicit operator type*() { \ 84 return static_cast<type*>(value); \ 85 } \ 86 explicit operator type&() { \ 87 return *static_cast<type*>(value); \ 88 } \ 89 \ 90 protected: \ 91 friend class type_caster_generic; \ 92 \ 93 bool load_value(const value_and_holder& v_h) { \ 94 if (v_h.holder_constructed()) { \ 95 value = v_h.template holder<holder_type>().get(); \ 96 return true; \ 97 } else { \ 98 throw cast_error( \ 99 "Unable to cast from non-held to held instance (#Class& to Holder<#Class>)"); \ 100 } \ 101 } \ 102 } 103 104 CREATE_UNWRAPPING_CASTER(torch::jit::Node); 105 CREATE_UNWRAPPING_CASTER(torch::jit::Value); 106 CREATE_UNWRAPPING_CASTER(torch::jit::Block); 107 108 #undef CREATE_UNWRAPPING_CASTER 109 110 template <> 111 struct type_caster<torch::jit::IValue> { 112 public: 113 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 114 PYBIND11_TYPE_CASTER(torch::jit::IValue, _("IValue")); 115 116 bool load(handle src, bool) { 117 try { 118 value = torch::jit::toTypeInferredIValue(src); 119 return true; 120 } catch (std::exception& e) { 121 return false; 122 } 123 } 124 125 static handle cast( 126 torch::jit::IValue src, 127 return_value_policy /* policy */, 128 handle /* parent */) { 129 return torch::jit::toPyObject(std::move(src)).release(); 130 } 131 }; 132 133 template <> 134 struct type_caster<torch::jit::Symbol> { 135 public: 136 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 137 PYBIND11_TYPE_CASTER(torch::jit::Symbol, _("Symbol")); 138 139 bool load(handle src, bool) { 140 // TODO: Is there a way to py::cast that doesn't raise an exception on 141 // failure? Can we catch pybind11::cast_error here instead? 142 std::string src_str; 143 try { 144 src_str = py::cast<std::string>(src); 145 } catch (std::exception& e) { 146 return false; 147 } 148 value = torch::jit::Symbol::fromQualString(src_str); 149 return true; 150 } 151 152 static handle cast( 153 torch::jit::Symbol src, 154 return_value_policy /* policy */, 155 handle /* parent */) { 156 return py::cast(std::string(src.toQualString()), return_value_policy::copy) 157 .release(); 158 } 159 }; 160 161 template <> 162 struct type_caster<torch::jit::AttributeKind> { 163 public: 164 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 165 PYBIND11_TYPE_CASTER(torch::jit::AttributeKind, _("AttributeKind")); 166 167 bool load(handle src, bool) { 168 return false; 169 } 170 171 static handle cast( 172 torch::jit::AttributeKind src, 173 return_value_policy /* policy */, 174 handle /* parent */) { 175 return py::cast( 176 std::string(torch::jit::toString(src)), 177 return_value_policy::copy) 178 .release(); 179 } 180 }; 181 182 // See https://github.com/pybind/pybind11/issues/637 183 using ListCasterBase = pybind11::detail:: 184 list_caster<std::vector<torch::jit::Node*>, torch::jit::Node*>; 185 template <> 186 struct type_caster<std::vector<torch::jit::Node*>> : ListCasterBase { 187 static handle cast( 188 const std::vector<torch::jit::Node*>& src, 189 return_value_policy, 190 handle parent) { 191 return ListCasterBase::cast(src, return_value_policy::reference, parent); 192 } 193 static handle cast( 194 const std::vector<torch::jit::Node*>* src, 195 return_value_policy pol, 196 handle parent) { 197 return cast(*src, pol, parent); 198 } 199 }; 200 201 } // namespace pybind11::detail 202 203 namespace torch::jit { 204 205 static inline py::tuple tuple_tail(const py::tuple& tup) { 206 py::tuple r(tup.size() - 1); 207 for (const auto i : c10::irange(1, tup.size())) { 208 r[i - 1] = tup[i]; 209 } 210 return r; 211 } 212 213 } // namespace torch::jit 214