xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/python/python_list.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <ATen/core/Dict.h>
#include <ATen/core/List.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <pybind11/detail/common.h>
#include <torch/csrc/utils/pybind.h>
#include <cstddef>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
12 
13 namespace torch::jit {
14 
// Defined out of line. Presumably registers the pybind11 bindings for
// ScriptList / ScriptListIterator on `module` — confirm in the .cpp.
void initScriptListBindings(PyObject* module);
16 
17 /// An iterator over the elements of ScriptList. This is used to support
18 /// __iter__(), .
19 class ScriptListIterator final {
20  public:
ScriptListIterator(c10::impl::GenericList::iterator iter,c10::impl::GenericList::iterator end)21   ScriptListIterator(
22       c10::impl::GenericList::iterator iter,
23       c10::impl::GenericList::iterator end)
24       : iter_(iter), end_(end) {}
25   at::IValue next();
26   bool done() const;
27 
28  private:
29   c10::impl::GenericList::iterator iter_;
30   c10::impl::GenericList::iterator end_;
31 };
32 
33 /// A wrapper around c10::List that can be exposed in Python via pybind
34 /// with an API identical to the Python list class. This allows
35 /// lists to have reference semantics across the Python/TorchScript
36 /// boundary.
37 class ScriptList final {
38  public:
39   // TODO: Do these make sense?
40   using size_type = size_t;
41   using diff_type = ptrdiff_t;
42   using ssize_t = Py_ssize_t;
43 
44   // Constructor for empty lists created during slicing, extending, etc.
ScriptList(const at::TypePtr & type)45   ScriptList(const at::TypePtr& type) : list_(at::AnyType::get()) {
46     auto list_type = type->expect<at::ListType>();
47     list_ = c10::impl::GenericList(list_type);
48   }
49 
50   // Constructor for instances based on existing lists (e.g. a
51   // Python instance or a list nested inside another).
ScriptList(const at::IValue & data)52   ScriptList(const at::IValue& data) : list_(at::AnyType::get()) {
53     TORCH_INTERNAL_ASSERT(data.isList());
54     list_ = data.toList();
55   }
56 
type()57   at::ListTypePtr type() const {
58     return at::ListType::create(list_.elementType());
59   }
60 
61   // Return a string representation that can be used
62   // to reconstruct the instance.
repr()63   std::string repr() const {
64     std::ostringstream s;
65     s << '[';
66     bool f = false;
67     for (auto const& elem : list_) {
68       if (f) {
69         s << ", ";
70       }
71       s << at::IValue(elem);
72       f = true;
73     }
74     s << ']';
75     return s.str();
76   }
77 
78   // Return an iterator over the elements of the list.
iter()79   ScriptListIterator iter() const {
80     auto begin = list_.begin();
81     auto end = list_.end();
82     return ScriptListIterator(begin, end);
83   }
84 
85   // Interpret the list as a boolean; empty means false, non-empty means
86   // true.
toBool()87   bool toBool() const {
88     return !(list_.empty());
89   }
90 
91   // Get the value for the given index.
getItem(diff_type idx)92   at::IValue getItem(diff_type idx) {
93     idx = wrap_index(idx);
94     return list_.get(idx);
95   };
96 
97   // Set the value corresponding to the given index.
setItem(diff_type idx,const at::IValue & value)98   void setItem(diff_type idx, const at::IValue& value) {
99     idx = wrap_index(idx);
100     return list_.set(idx, value);
101   }
102 
103   // Check whether the list contains the given value.
contains(const at::IValue & value)104   bool contains(const at::IValue& value) {
105     for (const auto& elem : list_) {
106       if (elem == value) {
107         return true;
108       }
109     }
110 
111     return false;
112   }
113 
114   // Delete the item at the given index from the list.
delItem(diff_type idx)115   void delItem(diff_type idx) {
116     idx = wrap_index(idx);
117     auto iter = list_.begin() + idx;
118     list_.erase(iter);
119   }
120 
121   // Get the size of the list.
len()122   ssize_t len() const {
123     return list_.size();
124   }
125 
126   // Count the number of times a value appears in the list.
count(const at::IValue & value)127   ssize_t count(const at::IValue& value) const {
128     ssize_t total = 0;
129 
130     for (const auto& elem : list_) {
131       if (elem == value) {
132         ++total;
133       }
134     }
135 
136     return total;
137   }
138 
139   // Remove the first occurrence of a value from the list.
remove(const at::IValue & value)140   void remove(const at::IValue& value) {
141     auto list = list_;
142 
143     int64_t idx = -1, i = 0;
144 
145     for (const auto& elem : list) {
146       if (elem == value) {
147         idx = i;
148         break;
149       }
150 
151       ++i;
152     }
153 
154     if (idx == -1) {
155       throw py::value_error();
156     }
157 
158     list.erase(list.begin() + idx);
159   }
160 
161   // Append a value to the end of the list.
append(const at::IValue & value)162   void append(const at::IValue& value) {
163     list_.emplace_back(value);
164   }
165 
166   // Clear the contents of the list.
clear()167   void clear() {
168     list_.clear();
169   }
170 
171   // Append the contents of an iterable to the list.
extend(const at::IValue & iterable)172   void extend(const at::IValue& iterable) {
173     list_.append(iterable.toList());
174   }
175 
176   // Remove and return the element at the specified index from the list. If no
177   // index is passed, the last element is removed and returned.
178   at::IValue pop(std::optional<size_type> idx = std::nullopt) {
179     at::IValue ret;
180 
181     if (idx) {
182       idx = wrap_index(*idx);
183       ret = list_.get(*idx);
184       list_.erase(list_.begin() + *idx);
185     } else {
186       ret = list_.get(list_.size() - 1);
187       list_.pop_back();
188     }
189 
190     return ret;
191   }
192 
193   // Insert a value before the given index.
insert(const at::IValue & value,diff_type idx)194   void insert(const at::IValue& value, diff_type idx) {
195     // wrap_index cannot be used; idx == len() is allowed
196     if (idx < 0) {
197       idx += len();
198     }
199 
200     if (idx < 0 || idx > len()) {
201       throw std::out_of_range("list index out of range");
202     }
203 
204     list_.insert(list_.begin() + idx, value);
205   }
206 
207   // A c10::List instance that holds the actual data.
208   c10::impl::GenericList list_;
209 
210  private:
211   // Wrap an index so that it can safely be used to access
212   // the list. For list of size sz, this function can successfully
213   // wrap indices in the range [-sz, sz-1]
wrap_index(diff_type idx)214   diff_type wrap_index(diff_type idx) {
215     auto sz = len();
216     if (idx < 0) {
217       idx += sz;
218     }
219 
220     if (idx < 0 || idx >= sz) {
221       throw std::out_of_range("list index out of range");
222     }
223 
224     return idx;
225   }
226 };
227 
228 } // namespace torch::jit
229