// xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_custom_class_registrations.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <test/cpp/jit/test_custom_class_registrations.h>

#include <torch/custom_class.h>
#include <torch/script.h>

#include <deque>
#include <iostream>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

using namespace torch::jit;

namespace {

struct DefaultArgs : torch::CustomClassHolder {
  int x;
  DefaultArgs(int64_t start = 3) : x(start) {}
  int64_t increment(int64_t val = 1) {
    x += val;
    return x;
  }
  int64_t decrement(int64_t val = 1) {
    x -= val;
    return x;
  }
  int64_t scale_add(int64_t add, int64_t scale = 1) {
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    x = scale * x + add;
    return x;
  }
  int64_t divide(std::optional<int64_t> factor) {
    if (factor) {
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      x = x / *factor;
    }
    return x;
  }
};
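
// A minimal usage sketch (hypothetical helper, not part of the registered
// tests): it exercises DefaultArgs directly in C++ to show the defaults that
// the torch::arg bindings in TORCH_LIBRARY below mirror.
[[maybe_unused]] int64_t demo_default_args() {
  auto d = c10::make_intrusive<DefaultArgs>(); // start defaults to 3
  d->increment();                 // x == 4
  d->scale_add(2);                // x == 1 * 4 + 2 == 6
  d->decrement();                 // x == 5
  return d->divide(std::nullopt); // factor absent: x stays 5
}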

struct Foo : torch::CustomClassHolder {
  int x, y;
  Foo() : x(0), y(0) {}
  Foo(int x_, int y_) : x(x_), y(y_) {}
  int64_t info() {
    return this->x * this->y;
  }
  int64_t add(int64_t z) {
    return (x + y) * z;
  }
  at::Tensor add_tensor(at::Tensor z) {
    return (x + y) * z;
  }
  void increment(int64_t z) {
    this->x += z;
    this->y += z;
  }
  int64_t combine(c10::intrusive_ptr<Foo> b) {
    return this->info() + b->info();
  }
  bool eq(c10::intrusive_ptr<Foo> other) {
    return this->x == other->x && this->y == other->y;
  }
  std::tuple<std::tuple<std::string, int64_t>, std::tuple<std::string, int64_t>>
  __obj_flatten__() {
    return std::tuple(std::tuple("x", this->x), std::tuple("y", this->y));
  }
};
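
// A brief sketch (hypothetical, illustration only): Foo objects are passed
// around as c10::intrusive_ptr, which is also how combine() and eq() receive
// their arguments once the class is registered as _Foo below.
[[maybe_unused]] bool demo_foo() {
  auto a = c10::make_intrusive<Foo>(2, 3);
  auto b = c10::make_intrusive<Foo>(2, 3);
  int64_t combined = a->combine(b); // info() + info() == 2*3 + 2*3 == 12
  return a->eq(b) && combined == 12;
}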

struct _StaticMethod : torch::CustomClassHolder {
  // NOLINTNEXTLINE(modernize-use-equals-default)
  _StaticMethod() {}
  static int64_t staticMethod(int64_t input) {
    return 2 * input;
  }
};

struct FooGetterSetter : torch::CustomClassHolder {
  FooGetterSetter() : x(0), y(0) {}
  FooGetterSetter(int64_t x_, int64_t y_) : x(x_), y(y_) {}

  int64_t getX() {
    // to make sure this is not just attribute lookup
    return x + 2;
  }
  void setX(int64_t z) {
    // to make sure this is not just attribute lookup
    x = z + 2;
  }

  int64_t getY() {
    // to make sure this is not just attribute lookup
    return y + 4;
  }

 private:
  int64_t x, y;
};
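
// Hypothetical check (not part of the test suite) of why the accessors add
// offsets: a setX/getX round trip does not return the value written, letting
// tests tell real getter/setter dispatch apart from plain attribute lookup.
[[maybe_unused]] int64_t demo_getter_setter() {
  auto f = c10::make_intrusive<FooGetterSetter>();
  f->setX(1);       // stores 1 + 2 == 3
  return f->getX(); // reads 3 + 2 == 5
}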

struct FooGetterSetterLambda : torch::CustomClassHolder {
  int64_t x;
  FooGetterSetterLambda() : x(0) {}
  FooGetterSetterLambda(int64_t x_) : x(x_) {}
};

struct FooReadWrite : torch::CustomClassHolder {
  int64_t x;
  const int64_t y;
  FooReadWrite() : x(0), y(0) {}
  FooReadWrite(int64_t x_, int64_t y_) : x(x_), y(y_) {}
};

struct LambdaInit : torch::CustomClassHolder {
  int x, y;
  LambdaInit(int x_, int y_) : x(x_), y(y_) {}
  int64_t diff() {
    return this->x - this->y;
  }
};

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct NoInit : torch::CustomClassHolder {
  int64_t x;
};

struct PickleTester : torch::CustomClassHolder {
  PickleTester(std::vector<int64_t> vals) : vals(std::move(vals)) {}
  std::vector<int64_t> vals;
};

// Thread-safe Tensor Queue
struct TensorQueue : torch::CustomClassHolder {
  explicit TensorQueue(at::Tensor t) : init_tensor_(t) {}

  explicit TensorQueue(c10::Dict<std::string, at::Tensor> dict) {
    init_tensor_ = dict.at(std::string("init_tensor"));
    const std::string key = "queue";
    at::Tensor size_tensor = dict.at(key + "/size").cpu();
    const auto* size_tensor_acc = size_tensor.const_data_ptr<int64_t>();
    int64_t queue_size = size_tensor_acc[0];

    for (const auto index : c10::irange(queue_size)) {
      queue_.push_back(dict.at(key + "/" + std::to_string(index)));
    }
  }

  c10::Dict<std::string, at::Tensor> serialize() const {
    c10::Dict<std::string, at::Tensor> dict;
    dict.insert(std::string("init_tensor"), init_tensor_);
    const std::string key = "queue";
    dict.insert(
        key + "/size", torch::tensor(static_cast<int64_t>(queue_.size())));
    for (const auto index : c10::irange(queue_.size())) {
      dict.insert(key + "/" + std::to_string(index), queue_[index]);
    }
    return dict;
  }
  // Push an element onto the rear of the queue.
  // A lock is held for thread safety.
  void push(at::Tensor x) {
    std::lock_guard<std::mutex> guard(mutex_);
    queue_.push_back(x);
  }
  // Pop the front element of the queue and return it.
  // If the queue is empty, return init_tensor_.
  // A lock is held for thread safety.
  at::Tensor pop() {
    std::lock_guard<std::mutex> guard(mutex_);
    if (!queue_.empty()) {
      auto val = queue_.front();
      queue_.pop_front();
      return val;
    } else {
      return init_tensor_;
    }
  }
  // Return the front element of the queue without removing it.
  // We might further optimize this with a read-write lock.
  at::Tensor top() {
    std::lock_guard<std::mutex> guard(mutex_);
    if (!queue_.empty()) {
      auto val = queue_.front();
      return val;
    } else {
      return init_tensor_;
    }
  }
  int64_t size() {
    std::lock_guard<std::mutex> guard(mutex_);
    return static_cast<int64_t>(queue_.size());
  }

  bool is_empty() {
    std::lock_guard<std::mutex> guard(mutex_);
    return queue_.empty();
  }

  double float_size() {
    std::lock_guard<std::mutex> guard(mutex_);
    return 1. * queue_.size();
  }

  std::vector<at::Tensor> clone_queue() {
    std::lock_guard<std::mutex> guard(mutex_);
    std::vector<at::Tensor> ret;
    for (const auto& t : queue_) {
      ret.push_back(t.clone());
    }
    return ret;
  }
  std::vector<at::Tensor> get_raw_queue() {
    std::lock_guard<std::mutex> guard(mutex_);
    return std::vector<at::Tensor>(queue_.begin(), queue_.end());
  }

  std::tuple<std::tuple<std::string, std::vector<at::Tensor>>>
  __obj_flatten__() {
    return std::tuple(std::tuple("queue", this->get_raw_queue()));
  }

 private:
  std::deque<at::Tensor> queue_;
  std::mutex mutex_;
  at::Tensor init_tensor_;
};
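
// A single-threaded sketch (hypothetical, illustration only): pop() falls
// back to the tensor captured at construction once the queue drains, so
// callers never observe an "empty" result.
[[maybe_unused]] void demo_tensor_queue() {
  auto q = c10::make_intrusive<TensorQueue>(torch::zeros({2}));
  q->push(torch::ones({2}));
  at::Tensor first = q->pop();    // the pushed ones-tensor
  at::Tensor fallback = q->pop(); // queue now empty: returns init_tensor_
  TORCH_INTERNAL_ASSERT(q->is_empty());
  (void)first;
  (void)fallback;
}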

struct ConstantTensorContainer : torch::CustomClassHolder {
  explicit ConstantTensorContainer(at::Tensor x) : x_(x) {}

  at::Tensor get() {
    return x_;
  }

  std::string tracing_mode() {
    return "real";
  }

 private:
  at::Tensor x_;
};

at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) {
  return torch::zeros({instance->vals.back(), 4});
}

struct ElementwiseInterpreter : torch::CustomClassHolder {
  using InstructionType = std::tuple<
      std::string /*op*/,
      std::vector<std::string> /*inputs*/,
      std::string /*output*/>;

  // NOLINTNEXTLINE(modernize-use-equals-default)
  ElementwiseInterpreter() {}

  // Load a list of instructions into the interpreter. Each instruction
  // specifies the operation (currently "add" and "mul" are supported),
  // the names of its input values, and the name of its single output value.
  void setInstructions(std::vector<InstructionType> instructions) {
    instructions_ = std::move(instructions);
  }

  // Add a constant. The interpreter maintains a set of constants across
  // calls. They are keyed by name, and instructions can reference constants
  // by the name specified here.
  void addConstant(const std::string& name, at::Tensor value) {
    constants_.insert_or_assign(name, std::move(value));
  }

  // Set the string names for the positional inputs to the function this
  // interpreter represents. When invoked, the interpreter will assign
  // the positional inputs to the names in the corresponding position in
  // input_names.
  void setInputNames(std::vector<std::string> input_names) {
    input_names_ = std::move(input_names);
  }

  // Specify the output name for the function this interpreter represents. This
  // should match the "output" field of one of the instructions in the
  // instruction list, typically the last instruction.
  void setOutputName(std::string output_name) {
    output_name_ = std::move(output_name);
  }

  // Invoke this interpreter. This takes a list of positional inputs and
  // returns a single output. Currently, inputs and outputs must all be
  // Tensors.
  at::Tensor __call__(std::vector<at::Tensor> inputs) {
    // Environment to hold local variables
    std::unordered_map<std::string, at::Tensor> environment;

    // Load inputs according to the specified names
    if (inputs.size() != input_names_.size()) {
      std::stringstream err;
      err << "Expected " << input_names_.size() << " inputs, but got "
          << inputs.size() << "!";
      throw std::runtime_error(err.str());
    }
    for (size_t i = 0; i < inputs.size(); ++i) {
      environment[input_names_[i]] = inputs[i];
    }

    for (InstructionType& instr : instructions_) {
      // Retrieve all input values for this op
      std::vector<at::Tensor> op_inputs;
      for (const auto& input_name : std::get<1>(instr)) {
        // Operator output values shadow constants.
        // Imagine all constants are defined in statements at the beginning
        // of a function (a la K&R C). Any definition of an output value must
        // necessarily come after the constant definitions in textual order.
        // Thus, we look up values in the environment first and the constant
        // table second to implement this shadowing behavior.
        if (environment.find(input_name) != environment.end()) {
          op_inputs.push_back(environment.at(input_name));
        } else if (constants_.find(input_name) != constants_.end()) {
          op_inputs.push_back(constants_.at(input_name));
        } else {
          std::stringstream err;
          err << "Instruction referenced unknown value " << input_name << "!";
          throw std::runtime_error(err.str());
        }
      }

      // Run the specified operation
      at::Tensor result;
      const auto& op = std::get<0>(instr);
      if (op == "add") {
        if (op_inputs.size() != 2) {
          throw std::runtime_error("Unexpected number of inputs for add op!");
        }
        result = op_inputs[0] + op_inputs[1];
      } else if (op == "mul") {
        if (op_inputs.size() != 2) {
          throw std::runtime_error("Unexpected number of inputs for mul op!");
        }
        result = op_inputs[0] * op_inputs[1];
      } else {
        std::stringstream err;
        err << "Unknown operator " << op << "!";
        throw std::runtime_error(err.str());
      }

      // Write back result into environment
      const auto& output_name = std::get<2>(instr);
      environment[output_name] = std::move(result);
    }

    if (!output_name_) {
      throw std::runtime_error("Output name not specified!");
    }

    return environment.at(*output_name_);
  }

  // Ser/De infrastructure. See
  // https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html#defining-serialization-deserialization-methods-for-custom-c-classes
  // for more info.

  // This is the type we will use to marshal information on disk during
  // ser/de. It is a simple tuple composed of primitive types and simple
  // collection types like vector, optional, and dict.
  using SerializationType = std::tuple<
      std::vector<std::string> /*input_names_*/,
      std::optional<std::string> /*output_name_*/,
      c10::Dict<std::string, at::Tensor> /*constants_*/,
      std::vector<InstructionType> /*instructions_*/
      >;

  // This function yields the SerializationType instance for `this`.
  SerializationType __getstate__() const {
    return SerializationType{
        input_names_, output_name_, constants_, instructions_};
  }

  // This function creates an instance of `ElementwiseInterpreter` given
  // an instance of `SerializationType`.
  static c10::intrusive_ptr<ElementwiseInterpreter> __setstate__(
      SerializationType state) {
    auto instance = c10::make_intrusive<ElementwiseInterpreter>();
    std::tie(
        instance->input_names_,
        instance->output_name_,
        instance->constants_,
        instance->instructions_) = std::move(state);
    return instance;
  }

  // Class members
  std::vector<std::string> input_names_;
  std::optional<std::string> output_name_;
  c10::Dict<std::string, at::Tensor> constants_;
  std::vector<InstructionType> instructions_;
};
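
// A short end-to-end sketch (hypothetical helper): programs the interpreter
// to compute (x + c) * x via the same setter API that is registered on
// _ElementwiseInterpreter below.
[[maybe_unused]] at::Tensor demo_interpreter() {
  auto interp = c10::make_intrusive<ElementwiseInterpreter>();
  interp->setInputNames({"x"});
  interp->addConstant("c", torch::ones({2, 2}));
  std::vector<ElementwiseInterpreter::InstructionType> program;
  program.emplace_back("add", std::vector<std::string>{"x", "c"}, "tmp");
  program.emplace_back("mul", std::vector<std::string>{"tmp", "x"}, "out");
  interp->setInstructions(std::move(program));
  interp->setOutputName("out");
  return interp->__call__({torch::ones({2, 2})});
}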

struct ReLUClass : public torch::CustomClassHolder {
  at::Tensor run(const at::Tensor& t) {
    return t.relu();
  }
};

struct FlattenWithTensorOp : public torch::CustomClassHolder {
  explicit FlattenWithTensorOp(at::Tensor t) : t_(t) {}

  at::Tensor get() {
    return t_;
  }

  std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
    return std::tuple(std::tuple("t", this->t_.sin()));
  }

 private:
  at::Tensor t_;
};

struct ContainsTensor : public torch::CustomClassHolder {
  explicit ContainsTensor(at::Tensor t) : t_(t) {}

  at::Tensor get() {
    return t_;
  }

  std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
    return std::tuple(std::tuple("t", this->t_));
  }

  at::Tensor t_;
};

TORCH_LIBRARY(_TorchScriptTesting, m) {
  m.impl_abstract_pystub("torch.testing._internal.torchbind_impls");
  m.class_<ScalarTypeClass>("_ScalarTypeClass")
      .def(torch::init<at::ScalarType>())
      .def_pickle(
          [](const c10::intrusive_ptr<ScalarTypeClass>& self) {
            return std::make_tuple(self->scalar_type_);
          },
          [](std::tuple<at::ScalarType> s) {
            return c10::make_intrusive<ScalarTypeClass>(std::get<0>(s));
          });

  m.class_<ReLUClass>("_ReLUClass")
      .def(torch::init<>())
      .def("run", &ReLUClass::run);

  m.class_<_StaticMethod>("_StaticMethod")
      .def(torch::init<>())
      .def_static("staticMethod", &_StaticMethod::staticMethod);

  m.class_<DefaultArgs>("_DefaultArgs")
      .def(torch::init<int64_t>(), "", {torch::arg("start") = 3})
      .def("increment", &DefaultArgs::increment, "", {torch::arg("val") = 1})
      .def("decrement", &DefaultArgs::decrement, "", {torch::arg("val") = 1})
      .def(
          "scale_add",
          &DefaultArgs::scale_add,
          "",
          {torch::arg("add"), torch::arg("scale") = 1})
      .def(
          "divide",
          &DefaultArgs::divide,
          "",
          {torch::arg("factor") = torch::arg::none()});

  m.class_<Foo>("_Foo")
      .def(torch::init<int64_t, int64_t>())
      // .def(torch::init<>())
      .def("info", &Foo::info)
      .def("increment", &Foo::increment)
      .def("add", &Foo::add)
      .def("add_tensor", &Foo::add_tensor)
      .def("__eq__", &Foo::eq)
      .def("combine", &Foo::combine)
      .def("__obj_flatten__", &Foo::__obj_flatten__)
      .def_pickle(
          [](c10::intrusive_ptr<Foo> self) { // __getstate__
            return std::vector<int64_t>{self->x, self->y};
          },
          [](std::vector<int64_t> state) { // __setstate__
            return c10::make_intrusive<Foo>(state[0], state[1]);
          });

  m.class_<FlattenWithTensorOp>("_FlattenWithTensorOp")
      .def(torch::init<at::Tensor>())
      .def("get", &FlattenWithTensorOp::get)
      .def("__obj_flatten__", &FlattenWithTensorOp::__obj_flatten__);

  m.class_<ConstantTensorContainer>("_ConstantTensorContainer")
      .def(torch::init<at::Tensor>())
      .def("get", &ConstantTensorContainer::get)
      .def("tracing_mode", &ConstantTensorContainer::tracing_mode);

  m.def(
      "takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
  m.def(
      "takes_foo_python_meta(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
  m.def(
      "takes_foo_list_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor[]");
  m.def(
      "takes_foo_tuple_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> (Tensor, Tensor)");

  m.class_<FooGetterSetter>("_FooGetterSetter")
      .def(torch::init<int64_t, int64_t>())
      .def_property("x", &FooGetterSetter::getX, &FooGetterSetter::setX)
      .def_property("y", &FooGetterSetter::getY);

  m.class_<FooGetterSetterLambda>("_FooGetterSetterLambda")
      .def(torch::init<int64_t>())
      .def_property(
          "x",
          [](const c10::intrusive_ptr<FooGetterSetterLambda>& self) {
            return self->x;
          },
          [](const c10::intrusive_ptr<FooGetterSetterLambda>& self,
             int64_t val) { self->x = val; });

  m.class_<FooReadWrite>("_FooReadWrite")
      .def(torch::init<int64_t, int64_t>())
      .def_readwrite("x", &FooReadWrite::x)
      .def_readonly("y", &FooReadWrite::y);

  m.class_<LambdaInit>("_LambdaInit")
      .def(torch::init([](int64_t x, int64_t y, bool swap) {
        if (swap) {
          return c10::make_intrusive<LambdaInit>(y, x);
        } else {
          return c10::make_intrusive<LambdaInit>(x, y);
        }
      }))
      .def("diff", &LambdaInit::diff);

  m.class_<NoInit>("_NoInit").def(
      "get_x", [](const c10::intrusive_ptr<NoInit>& self) { return self->x; });

  m.class_<MyStackClass<std::string>>("_StackString")
      .def(torch::init<std::vector<std::string>>())
      .def("push", &MyStackClass<std::string>::push)
      .def("pop", &MyStackClass<std::string>::pop)
      .def("clone", &MyStackClass<std::string>::clone)
      .def("merge", &MyStackClass<std::string>::merge)
      .def_pickle(
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
            return self->stack_;
          },
          [](std::vector<std::string> state) { // __setstate__
            return c10::make_intrusive<MyStackClass<std::string>>(
                std::vector<std::string>{"i", "was", "deserialized"});
          })
      .def("return_a_tuple", &MyStackClass<std::string>::return_a_tuple)
      .def(
          "top",
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
              -> std::string { return self->stack_.back(); })
      .def(
          "__str__",
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
            std::stringstream ss;
            ss << "[";
            for (size_t i = 0; i < self->stack_.size(); ++i) {
              ss << self->stack_[i];
              if (i != self->stack_.size() - 1) {
                ss << ", ";
              }
            }
            ss << "]";
            return ss.str();
          });
  // clang-format off
  // The following would fail with a static assert telling you that you must
  // take an intrusive_ptr<MyStackClass> as the first argument.
  // .def("foo", [](int64_t a) -> int64_t{ return 3;});
  // clang-format on

  m.class_<PickleTester>("_PickleTester")
      .def(torch::init<std::vector<int64_t>>())
      .def_pickle(
          [](c10::intrusive_ptr<PickleTester> self) { // __getstate__
            return std::vector<int64_t>{1, 3, 3, 7};
          },
          [](std::vector<int64_t> state) { // __setstate__
            return c10::make_intrusive<PickleTester>(std::move(state));
          })
      .def(
          "top",
          [](const c10::intrusive_ptr<PickleTester>& self) {
            return self->vals.back();
          })
      .def("pop", [](const c10::intrusive_ptr<PickleTester>& self) {
        auto val = self->vals.back();
        self->vals.pop_back();
        return val;
      });

  m.def(
      "take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y",
      take_an_instance);
  // test that schema inference is ok too
  m.def("take_an_instance_inferred", take_an_instance);

  m.class_<ElementwiseInterpreter>("_ElementwiseInterpreter")
      .def(torch::init<>())
      .def("set_instructions", &ElementwiseInterpreter::setInstructions)
      .def("add_constant", &ElementwiseInterpreter::addConstant)
      .def("set_input_names", &ElementwiseInterpreter::setInputNames)
      .def("set_output_name", &ElementwiseInterpreter::setOutputName)
      .def("__call__", &ElementwiseInterpreter::__call__)
      .def_pickle(
          /* __getstate__ */
          [](const c10::intrusive_ptr<ElementwiseInterpreter>& self) {
            return self->__getstate__();
          },
          /* __setstate__ */
          [](ElementwiseInterpreter::SerializationType state) {
            return ElementwiseInterpreter::__setstate__(std::move(state));
          });

  m.class_<ContainsTensor>("_ContainsTensor")
      .def(torch::init<at::Tensor>())
      .def("get", &ContainsTensor::get)
      .def("__obj_flatten__", &ContainsTensor::__obj_flatten__)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<ContainsTensor>& self) -> at::Tensor {
            return self->t_;
          },
          // __setstate__
          [](at::Tensor data) -> c10::intrusive_ptr<ContainsTensor> {
            return c10::make_intrusive<ContainsTensor>(std::move(data));
          });
  m.class_<TensorQueue>("_TensorQueue")
      .def(torch::init<at::Tensor>())
      .def("push", &TensorQueue::push)
      .def("pop", &TensorQueue::pop)
      .def("top", &TensorQueue::top)
      .def("is_empty", &TensorQueue::is_empty)
      .def("float_size", &TensorQueue::float_size)
      .def("size", &TensorQueue::size)
      .def("clone_queue", &TensorQueue::clone_queue)
      .def("get_raw_queue", &TensorQueue::get_raw_queue)
      .def("__obj_flatten__", &TensorQueue::__obj_flatten__)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<TensorQueue>& self)
              -> c10::Dict<std::string, at::Tensor> {
            return self->serialize();
          },
          // __setstate__
          [](c10::Dict<std::string, at::Tensor> data)
              -> c10::intrusive_ptr<TensorQueue> {
            return c10::make_intrusive<TensorQueue>(std::move(data));
          });
}
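
// A hedged round-trip sketch (hypothetical, not itself registered): the
// _TensorQueue pickle lambdas above depend on exactly this contract, namely
// that serialize() emits a dict the dict-constructor can rebuild from.
[[maybe_unused]] c10::intrusive_ptr<TensorQueue> demo_tensor_queue_roundtrip() {
  auto q = c10::make_intrusive<TensorQueue>(torch::zeros({2}));
  q->push(torch::ones({2}));
  c10::Dict<std::string, at::Tensor> state = q->serialize(); // __getstate__
  return c10::make_intrusive<TensorQueue>(std::move(state)); // __setstate__
}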

at::Tensor takes_foo(c10::intrusive_ptr<Foo> foo, at::Tensor x) {
  return foo->add_tensor(x);
}

std::vector<at::Tensor> takes_foo_list_return(
    c10::intrusive_ptr<Foo> foo,
    at::Tensor x) {
  std::vector<at::Tensor> result;
  result.reserve(3);
  auto a = foo->add_tensor(x);
  auto b = foo->add_tensor(a);
  auto c = foo->add_tensor(b);
  result.push_back(a);
  result.push_back(b);
  result.push_back(c);
  return result;
}

std::tuple<at::Tensor, at::Tensor> takes_foo_tuple_return(
    c10::intrusive_ptr<Foo> foo,
    at::Tensor x) {
  auto a = foo->add_tensor(x);
  auto b = foo->add_tensor(a);
  return std::make_tuple(a, b);
}

void queue_push(c10::intrusive_ptr<TensorQueue> tq, at::Tensor x) {
  tq->push(x);
}

at::Tensor queue_pop(c10::intrusive_ptr<TensorQueue> tq) {
  return tq->pop();
}

int64_t queue_size(c10::intrusive_ptr<TensorQueue> tq) {
  return tq->size();
}

TORCH_LIBRARY_FRAGMENT(_TorchScriptTesting, m) {
  m.impl_abstract_pystub("torch.testing._internal.torchbind_impls");
  m.def(
      "takes_foo_cia(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
  m.def(
      "queue_pop(__torch__.torch.classes._TorchScriptTesting._TensorQueue foo) -> Tensor");
  m.def(
      "queue_push(__torch__.torch.classes._TorchScriptTesting._TensorQueue foo, Tensor x) -> ()");
  m.def(
      "queue_size(__torch__.torch.classes._TorchScriptTesting._TensorQueue foo) -> int");
}

TORCH_LIBRARY_IMPL(_TorchScriptTesting, CPU, m) {
  m.impl("takes_foo", takes_foo);
  m.impl("takes_foo_list_return", takes_foo_list_return);
  m.impl("takes_foo_tuple_return", takes_foo_tuple_return);
  m.impl("queue_push", queue_push);
  m.impl("queue_pop", queue_pop);
  m.impl("queue_size", queue_size);
}

TORCH_LIBRARY_IMPL(_TorchScriptTesting, Meta, m) {
  m.impl("takes_foo", takes_foo);
  m.impl("takes_foo_list_return", takes_foo_list_return);
  m.impl("takes_foo_tuple_return", takes_foo_tuple_return);
}

TORCH_LIBRARY_IMPL(_TorchScriptTesting, CompositeImplicitAutograd, m) {
  m.impl("takes_foo_cia", takes_foo);
}

// Need to implement BackendSelect because these two operators don't have
// tensor inputs.
TORCH_LIBRARY_IMPL(_TorchScriptTesting, BackendSelect, m) {
  m.impl("queue_pop", queue_pop);
  m.impl("queue_size", queue_size);
}

} // namespace