1 #pragma once 2 3 #include <ATen/core/Dict.h> 4 #include <ATen/core/List.h> 5 #include <ATen/core/ivalue.h> 6 #include <ATen/core/jit_type.h> 7 #include <pybind11/detail/common.h> 8 #include <torch/csrc/utils/pybind.h> 9 #include <cstddef> 10 #include <optional> 11 #include <stdexcept> 12 13 namespace torch::jit { 14 15 void initScriptListBindings(PyObject* module); 16 17 /// An iterator over the elements of ScriptList. This is used to support 18 /// __iter__(), . 19 class ScriptListIterator final { 20 public: ScriptListIterator(c10::impl::GenericList::iterator iter,c10::impl::GenericList::iterator end)21 ScriptListIterator( 22 c10::impl::GenericList::iterator iter, 23 c10::impl::GenericList::iterator end) 24 : iter_(iter), end_(end) {} 25 at::IValue next(); 26 bool done() const; 27 28 private: 29 c10::impl::GenericList::iterator iter_; 30 c10::impl::GenericList::iterator end_; 31 }; 32 33 /// A wrapper around c10::List that can be exposed in Python via pybind 34 /// with an API identical to the Python list class. This allows 35 /// lists to have reference semantics across the Python/TorchScript 36 /// boundary. 37 class ScriptList final { 38 public: 39 // TODO: Do these make sense? 40 using size_type = size_t; 41 using diff_type = ptrdiff_t; 42 using ssize_t = Py_ssize_t; 43 44 // Constructor for empty lists created during slicing, extending, etc. ScriptList(const at::TypePtr & type)45 ScriptList(const at::TypePtr& type) : list_(at::AnyType::get()) { 46 auto list_type = type->expect<at::ListType>(); 47 list_ = c10::impl::GenericList(list_type); 48 } 49 50 // Constructor for instances based on existing lists (e.g. a 51 // Python instance or a list nested inside another). ScriptList(const at::IValue & data)52 ScriptList(const at::IValue& data) : list_(at::AnyType::get()) { 53 TORCH_INTERNAL_ASSERT(data.isList()); 54 list_ = data.toList(); 55 } 56 type()57 at::ListTypePtr type() const { 58 return at::ListType::create(list_.elementType()); 59 } 60 61 // Return a string representation that can be used 62 // to reconstruct the instance. repr()63 std::string repr() const { 64 std::ostringstream s; 65 s << '['; 66 bool f = false; 67 for (auto const& elem : list_) { 68 if (f) { 69 s << ", "; 70 } 71 s << at::IValue(elem); 72 f = true; 73 } 74 s << ']'; 75 return s.str(); 76 } 77 78 // Return an iterator over the elements of the list. iter()79 ScriptListIterator iter() const { 80 auto begin = list_.begin(); 81 auto end = list_.end(); 82 return ScriptListIterator(begin, end); 83 } 84 85 // Interpret the list as a boolean; empty means false, non-empty means 86 // true. toBool()87 bool toBool() const { 88 return !(list_.empty()); 89 } 90 91 // Get the value for the given index. getItem(diff_type idx)92 at::IValue getItem(diff_type idx) { 93 idx = wrap_index(idx); 94 return list_.get(idx); 95 }; 96 97 // Set the value corresponding to the given index. setItem(diff_type idx,const at::IValue & value)98 void setItem(diff_type idx, const at::IValue& value) { 99 idx = wrap_index(idx); 100 return list_.set(idx, value); 101 } 102 103 // Check whether the list contains the given value. contains(const at::IValue & value)104 bool contains(const at::IValue& value) { 105 for (const auto& elem : list_) { 106 if (elem == value) { 107 return true; 108 } 109 } 110 111 return false; 112 } 113 114 // Delete the item at the given index from the list. delItem(diff_type idx)115 void delItem(diff_type idx) { 116 idx = wrap_index(idx); 117 auto iter = list_.begin() + idx; 118 list_.erase(iter); 119 } 120 121 // Get the size of the list. len()122 ssize_t len() const { 123 return list_.size(); 124 } 125 126 // Count the number of times a value appears in the list. count(const at::IValue & value)127 ssize_t count(const at::IValue& value) const { 128 ssize_t total = 0; 129 130 for (const auto& elem : list_) { 131 if (elem == value) { 132 ++total; 133 } 134 } 135 136 return total; 137 } 138 139 // Remove the first occurrence of a value from the list. remove(const at::IValue & value)140 void remove(const at::IValue& value) { 141 auto list = list_; 142 143 int64_t idx = -1, i = 0; 144 145 for (const auto& elem : list) { 146 if (elem == value) { 147 idx = i; 148 break; 149 } 150 151 ++i; 152 } 153 154 if (idx == -1) { 155 throw py::value_error(); 156 } 157 158 list.erase(list.begin() + idx); 159 } 160 161 // Append a value to the end of the list. append(const at::IValue & value)162 void append(const at::IValue& value) { 163 list_.emplace_back(value); 164 } 165 166 // Clear the contents of the list. clear()167 void clear() { 168 list_.clear(); 169 } 170 171 // Append the contents of an iterable to the list. extend(const at::IValue & iterable)172 void extend(const at::IValue& iterable) { 173 list_.append(iterable.toList()); 174 } 175 176 // Remove and return the element at the specified index from the list. If no 177 // index is passed, the last element is removed and returned. 178 at::IValue pop(std::optional<size_type> idx = std::nullopt) { 179 at::IValue ret; 180 181 if (idx) { 182 idx = wrap_index(*idx); 183 ret = list_.get(*idx); 184 list_.erase(list_.begin() + *idx); 185 } else { 186 ret = list_.get(list_.size() - 1); 187 list_.pop_back(); 188 } 189 190 return ret; 191 } 192 193 // Insert a value before the given index. insert(const at::IValue & value,diff_type idx)194 void insert(const at::IValue& value, diff_type idx) { 195 // wrap_index cannot be used; idx == len() is allowed 196 if (idx < 0) { 197 idx += len(); 198 } 199 200 if (idx < 0 || idx > len()) { 201 throw std::out_of_range("list index out of range"); 202 } 203 204 list_.insert(list_.begin() + idx, value); 205 } 206 207 // A c10::List instance that holds the actual data. 208 c10::impl::GenericList list_; 209 210 private: 211 // Wrap an index so that it can safely be used to access 212 // the list. For list of size sz, this function can successfully 213 // wrap indices in the range [-sz, sz-1] wrap_index(diff_type idx)214 diff_type wrap_index(diff_type idx) { 215 auto sz = len(); 216 if (idx < 0) { 217 idx += sz; 218 } 219 220 if (idx < 0 || idx >= sz) { 221 throw std::out_of_range("list index out of range"); 222 } 223 224 return idx; 225 } 226 }; 227 228 } // namespace torch::jit 229