1 #include <test/cpp/jit/test_custom_class_registrations.h>
2
3 #include <torch/custom_class.h>
4 #include <torch/script.h>
5
6 #include <iostream>
7 #include <string>
8 #include <vector>
9
10 using namespace torch::jit;
11
12 namespace {
13
14 struct DefaultArgs : torch::CustomClassHolder {
15 int x;
DefaultArgs__anonb99f12230111::DefaultArgs16 DefaultArgs(int64_t start = 3) : x(start) {}
increment__anonb99f12230111::DefaultArgs17 int64_t increment(int64_t val = 1) {
18 x += val;
19 return x;
20 }
decrement__anonb99f12230111::DefaultArgs21 int64_t decrement(int64_t val = 1) {
22 x += val;
23 return x;
24 }
scale_add__anonb99f12230111::DefaultArgs25 int64_t scale_add(int64_t add, int64_t scale = 1) {
26 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
27 x = scale * x + add;
28 return x;
29 }
divide__anonb99f12230111::DefaultArgs30 int64_t divide(std::optional<int64_t> factor) {
31 if (factor) {
32 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
33 x = x / *factor;
34 }
35 return x;
36 }
37 };
38
39 struct Foo : torch::CustomClassHolder {
40 int x, y;
Foo__anonb99f12230111::Foo41 Foo() : x(0), y(0) {}
Foo__anonb99f12230111::Foo42 Foo(int x_, int y_) : x(x_), y(y_) {}
info__anonb99f12230111::Foo43 int64_t info() {
44 return this->x * this->y;
45 }
add__anonb99f12230111::Foo46 int64_t add(int64_t z) {
47 return (x + y) * z;
48 }
add_tensor__anonb99f12230111::Foo49 at::Tensor add_tensor(at::Tensor z) {
50 return (x + y) * z;
51 }
increment__anonb99f12230111::Foo52 void increment(int64_t z) {
53 this->x += z;
54 this->y += z;
55 }
combine__anonb99f12230111::Foo56 int64_t combine(c10::intrusive_ptr<Foo> b) {
57 return this->info() + b->info();
58 }
eq__anonb99f12230111::Foo59 bool eq(c10::intrusive_ptr<Foo> other) {
60 return this->x == other->x && this->y == other->y;
61 }
62 std::tuple<std::tuple<std::string, int64_t>, std::tuple<std::string, int64_t>>
__obj_flatten____anonb99f12230111::Foo63 __obj_flatten__() {
64 return std::tuple(std::tuple("x", this->x), std::tuple("y", this->y));
65 }
66 };
67
68 struct _StaticMethod : torch::CustomClassHolder {
69 // NOLINTNEXTLINE(modernize-use-equals-default)
_StaticMethod__anonb99f12230111::_StaticMethod70 _StaticMethod() {}
staticMethod__anonb99f12230111::_StaticMethod71 static int64_t staticMethod(int64_t input) {
72 return 2 * input;
73 }
74 };
75
76 struct FooGetterSetter : torch::CustomClassHolder {
FooGetterSetter__anonb99f12230111::FooGetterSetter77 FooGetterSetter() : x(0), y(0) {}
FooGetterSetter__anonb99f12230111::FooGetterSetter78 FooGetterSetter(int64_t x_, int64_t y_) : x(x_), y(y_) {}
79
getX__anonb99f12230111::FooGetterSetter80 int64_t getX() {
81 // to make sure this is not just attribute lookup
82 return x + 2;
83 }
setX__anonb99f12230111::FooGetterSetter84 void setX(int64_t z) {
85 // to make sure this is not just attribute lookup
86 x = z + 2;
87 }
88
getY__anonb99f12230111::FooGetterSetter89 int64_t getY() {
90 // to make sure this is not just attribute lookup
91 return y + 4;
92 }
93
94 private:
95 int64_t x, y;
96 };
97
98 struct FooGetterSetterLambda : torch::CustomClassHolder {
99 int64_t x;
FooGetterSetterLambda__anonb99f12230111::FooGetterSetterLambda100 FooGetterSetterLambda() : x(0) {}
FooGetterSetterLambda__anonb99f12230111::FooGetterSetterLambda101 FooGetterSetterLambda(int64_t x_) : x(x_) {}
102 };
103
104 struct FooReadWrite : torch::CustomClassHolder {
105 int64_t x;
106 const int64_t y;
FooReadWrite__anonb99f12230111::FooReadWrite107 FooReadWrite() : x(0), y(0) {}
FooReadWrite__anonb99f12230111::FooReadWrite108 FooReadWrite(int64_t x_, int64_t y_) : x(x_), y(y_) {}
109 };
110
111 struct LambdaInit : torch::CustomClassHolder {
112 int x, y;
LambdaInit__anonb99f12230111::LambdaInit113 LambdaInit(int x_, int y_) : x(x_), y(y_) {}
diff__anonb99f12230111::LambdaInit114 int64_t diff() {
115 return this->x - this->y;
116 }
117 };
118
119 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
120 struct NoInit : torch::CustomClassHolder {
121 int64_t x;
122 };
123
124 struct PickleTester : torch::CustomClassHolder {
PickleTester__anonb99f12230111::PickleTester125 PickleTester(std::vector<int64_t> vals) : vals(std::move(vals)) {}
126 std::vector<int64_t> vals;
127 };
128
129 // Thread-safe Tensor Queue
130 struct TensorQueue : torch::CustomClassHolder {
TensorQueue__anonb99f12230111::TensorQueue131 explicit TensorQueue(at::Tensor t) : init_tensor_(t) {}
132
TensorQueue__anonb99f12230111::TensorQueue133 explicit TensorQueue(c10::Dict<std::string, at::Tensor> dict) {
134 init_tensor_ = dict.at(std::string("init_tensor"));
135 const std::string key = "queue";
136 at::Tensor size_tensor;
137 size_tensor = dict.at(std::string(key + "/size")).cpu();
138 const auto* size_tensor_acc = size_tensor.const_data_ptr<int64_t>();
139 int64_t queue_size = size_tensor_acc[0];
140
141 for (const auto index : c10::irange(queue_size)) {
142 at::Tensor val;
143 queue_[index] = dict.at(key + "/" + std::to_string(index));
144 queue_.push_back(val);
145 }
146 }
147
serialize__anonb99f12230111::TensorQueue148 c10::Dict<std::string, at::Tensor> serialize() const {
149 c10::Dict<std::string, at::Tensor> dict;
150 dict.insert(std::string("init_tensor"), init_tensor_);
151 const std::string key = "queue";
152 dict.insert(
153 key + "/size", torch::tensor(static_cast<int64_t>(queue_.size())));
154 for (const auto index : c10::irange(queue_.size())) {
155 dict.insert(key + "/" + std::to_string(index), queue_[index]);
156 }
157 return dict;
158 }
159 // Push the element to the rear of queue.
160 // Lock is added for thread safe.
push__anonb99f12230111::TensorQueue161 void push(at::Tensor x) {
162 std::lock_guard<std::mutex> guard(mutex_);
163 queue_.push_back(x);
164 }
165 // Pop the front element of queue and return it.
166 // If empty, return init_tensor_.
167 // Lock is added for thread safe.
pop__anonb99f12230111::TensorQueue168 at::Tensor pop() {
169 std::lock_guard<std::mutex> guard(mutex_);
170 if (!queue_.empty()) {
171 auto val = queue_.front();
172 queue_.pop_front();
173 return val;
174 } else {
175 return init_tensor_;
176 }
177 }
178 // Return front element of queue, read-only.
179 // We might further optimize with read-write lock.
top__anonb99f12230111::TensorQueue180 at::Tensor top() {
181 std::lock_guard<std::mutex> guard(mutex_);
182 if (!queue_.empty()) {
183 auto val = queue_.front();
184 return val;
185 } else {
186 return init_tensor_;
187 }
188 }
size__anonb99f12230111::TensorQueue189 int64_t size() {
190 return queue_.size();
191 }
192
is_empty__anonb99f12230111::TensorQueue193 bool is_empty() {
194 std::lock_guard<std::mutex> guard(mutex_);
195 return queue_.empty();
196 }
197
float_size__anonb99f12230111::TensorQueue198 double float_size() {
199 return 1. * queue_.size();
200 }
201
clone_queue__anonb99f12230111::TensorQueue202 std::vector<at::Tensor> clone_queue() {
203 std::lock_guard<std::mutex> guard(mutex_);
204 std::vector<at::Tensor> ret;
205 for (const auto& t : queue_) {
206 ret.push_back(t.clone());
207 }
208 return ret;
209 }
get_raw_queue__anonb99f12230111::TensorQueue210 std::vector<at::Tensor> get_raw_queue() {
211 std::vector<at::Tensor> raw_queue(queue_.begin(), queue_.end());
212 return raw_queue;
213 }
214
__obj_flatten____anonb99f12230111::TensorQueue215 std::tuple<std::tuple<std::string, std::vector<at::Tensor>>> __obj_flatten__() {
216 return std::tuple(std::tuple("queue", this->get_raw_queue()));
217 }
218
219 private:
220 std::deque<at::Tensor> queue_;
221 std::mutex mutex_;
222 at::Tensor init_tensor_;
223 };
224
225 struct ConstantTensorContainer : torch::CustomClassHolder {
ConstantTensorContainer__anonb99f12230111::ConstantTensorContainer226 explicit ConstantTensorContainer(at::Tensor x) : x_(x) {}
227
get__anonb99f12230111::ConstantTensorContainer228 at::Tensor get() {
229 return x_;
230 }
231
tracing_mode__anonb99f12230111::ConstantTensorContainer232 std::string tracing_mode() {
233 return "real";
234 }
235
236 private:
237 at::Tensor x_;
238 };
239
take_an_instance(const c10::intrusive_ptr<PickleTester> & instance)240 at::Tensor take_an_instance(const c10::intrusive_ptr<PickleTester>& instance) {
241 return torch::zeros({instance->vals.back(), 4});
242 }
243
244 struct ElementwiseInterpreter : torch::CustomClassHolder {
245 using InstructionType = std::tuple<
246 std::string /*op*/,
247 std::vector<std::string> /*inputs*/,
248 std::string /*output*/>;
249
250 // NOLINTNEXTLINE(modernize-use-equals-default)
ElementwiseInterpreter__anonb99f12230111::ElementwiseInterpreter251 ElementwiseInterpreter() {}
252
253 // Load a list of instructions into the interpreter. As specified above,
254 // instructions specify the operation (currently support "add" and "mul"),
255 // the names of the input values, and the name of the single output value
256 // from this instruction
setInstructions__anonb99f12230111::ElementwiseInterpreter257 void setInstructions(std::vector<InstructionType> instructions) {
258 instructions_ = std::move(instructions);
259 }
260
261 // Add a constant. The interpreter maintains a set of constants across
262 // calls. They are keyed by name, and constants can be referenced in
263 // Instructions by the name specified
addConstant__anonb99f12230111::ElementwiseInterpreter264 void addConstant(const std::string& name, at::Tensor value) {
265 constants_.insert_or_assign(name, std::move(value));
266 }
267
268 // Set the string names for the positional inputs to the function this
269 // interpreter represents. When invoked, the interpreter will assign
270 // the positional inputs to the names in the corresponding position in
271 // input_names.
setInputNames__anonb99f12230111::ElementwiseInterpreter272 void setInputNames(std::vector<std::string> input_names) {
273 input_names_ = std::move(input_names);
274 }
275
276 // Specify the output name for the function this interpreter represents. This
277 // should match the "output" field of one of the instructions in the
278 // instruction list, typically the last instruction.
setOutputName__anonb99f12230111::ElementwiseInterpreter279 void setOutputName(std::string output_name) {
280 output_name_ = std::move(output_name);
281 }
282
283 // Invoke this interpreter. This takes a list of positional inputs and returns
284 // a single output. Currently, inputs and outputs must all be Tensors.
__call____anonb99f12230111::ElementwiseInterpreter285 at::Tensor __call__(std::vector<at::Tensor> inputs) {
286 // Environment to hold local variables
287 std::unordered_map<std::string, at::Tensor> environment;
288
289 // Load inputs according to the specified names
290 if (inputs.size() != input_names_.size()) {
291 std::stringstream err;
292 err << "Expected " << input_names_.size() << " inputs, but got "
293 << inputs.size() << "!";
294 throw std::runtime_error(err.str());
295 }
296 for (size_t i = 0; i < inputs.size(); ++i) {
297 environment[input_names_[i]] = inputs[i];
298 }
299
300 for (InstructionType& instr : instructions_) {
301 // Retrieve all input values for this op
302 std::vector<at::Tensor> inputs;
303 for (const auto& input_name : std::get<1>(instr)) {
304 // Operator output values shadow constants.
305 // Imagine all constants are defined in statements at the beginning
306 // of a function (a la K&R C). Any definition of an output value must
307 // necessarily come after constant definition in textual order. Thus,
308 // We look up values in the environment first then the constant table
309 // second to implement this shadowing behavior
310 if (environment.find(input_name) != environment.end()) {
311 inputs.push_back(environment.at(input_name));
312 } else if (constants_.find(input_name) != constants_.end()) {
313 inputs.push_back(constants_.at(input_name));
314 } else {
315 std::stringstream err;
316 err << "Instruction referenced unknown value " << input_name << "!";
317 throw std::runtime_error(err.str());
318 }
319 }
320
321 // Run the specified operation
322 at::Tensor result;
323 const auto& op = std::get<0>(instr);
324 if (op == "add") {
325 if (inputs.size() != 2) {
326 throw std::runtime_error("Unexpected number of inputs for add op!");
327 }
328 result = inputs[0] + inputs[1];
329 } else if (op == "mul") {
330 if (inputs.size() != 2) {
331 throw std::runtime_error("Unexpected number of inputs for mul op!");
332 }
333 result = inputs[0] * inputs[1];
334 } else {
335 std::stringstream err;
336 err << "Unknown operator " << op << "!";
337 throw std::runtime_error(err.str());
338 }
339
340 // Write back result into environment
341 const auto& output_name = std::get<2>(instr);
342 environment[output_name] = std::move(result);
343 }
344
345 if (!output_name_) {
346 throw std::runtime_error("Output name not specified!");
347 }
348
349 return environment.at(*output_name_);
350 }
351
352 // Ser/De infrastructure. See
353 // https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html#defining-serialization-deserialization-methods-for-custom-c-classes
354 // for more info.
355
356 // This is the type we will use to marshall information on disk during
357 // ser/de. It is a simple tuple composed of primitive types and simple
358 // collection types like vector, optional, and dict.
359 using SerializationType = std::tuple<
360 std::vector<std::string> /*input_names_*/,
361 std::optional<std::string> /*output_name_*/,
362 c10::Dict<std::string, at::Tensor> /*constants_*/,
363 std::vector<InstructionType> /*instructions_*/
364 >;
365
366 // This function yields the SerializationType instance for `this`.
__getstate____anonb99f12230111::ElementwiseInterpreter367 SerializationType __getstate__() const {
368 return SerializationType{
369 input_names_, output_name_, constants_, instructions_};
370 }
371
372 // This function will create an instance of `ElementwiseInterpreter` given
373 // an instance of `SerializationType`.
__setstate____anonb99f12230111::ElementwiseInterpreter374 static c10::intrusive_ptr<ElementwiseInterpreter> __setstate__(
375 SerializationType state) {
376 auto instance = c10::make_intrusive<ElementwiseInterpreter>();
377 std::tie(
378 instance->input_names_,
379 instance->output_name_,
380 instance->constants_,
381 instance->instructions_) = std::move(state);
382 return instance;
383 }
384
385 // Class members
386 std::vector<std::string> input_names_;
387 std::optional<std::string> output_name_;
388 c10::Dict<std::string, at::Tensor> constants_;
389 std::vector<InstructionType> instructions_;
390 };
391
392 struct ReLUClass : public torch::CustomClassHolder {
run__anonb99f12230111::ReLUClass393 at::Tensor run(const at::Tensor& t) {
394 return t.relu();
395 }
396 };
397
398 struct FlattenWithTensorOp : public torch::CustomClassHolder {
FlattenWithTensorOp__anonb99f12230111::FlattenWithTensorOp399 explicit FlattenWithTensorOp(at::Tensor t) : t_(t) {}
400
get__anonb99f12230111::FlattenWithTensorOp401 at::Tensor get() {
402 return t_;
403 }
404
__obj_flatten____anonb99f12230111::FlattenWithTensorOp405 std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
406 return std::tuple(std::tuple("t", this->t_.sin()));
407 }
408
409 private:
410 at::Tensor t_;
411 ;
412 };
413
414 struct ContainsTensor : public torch::CustomClassHolder {
ContainsTensor__anonb99f12230111::ContainsTensor415 explicit ContainsTensor(at::Tensor t) : t_(t) {}
416
get__anonb99f12230111::ContainsTensor417 at::Tensor get() {
418 return t_;
419 }
420
__obj_flatten____anonb99f12230111::ContainsTensor421 std::tuple<std::tuple<std::string, at::Tensor>> __obj_flatten__() {
422 return std::tuple(std::tuple("t", this->t_));
423 }
424
425 at::Tensor t_;
426 };
427
// Registers the `_TorchScriptTesting` custom classes and operator schemas.
// Class registrations are kept ahead of the schemas that name them as
// `__torch__.torch.classes._TorchScriptTesting.<Class>`; presumably this is
// needed for schema parsing to resolve the class type (TODO confirm).
TORCH_LIBRARY(_TorchScriptTesting, m) {
  m.impl_abstract_pystub("torch.testing._internal.torchbind_impls");
  m.class_<ScalarTypeClass>("_ScalarTypeClass")
      .def(torch::init<at::ScalarType>())
      .def_pickle(
          [](const c10::intrusive_ptr<ScalarTypeClass>& self) {
            return std::make_tuple(self->scalar_type_);
          },
          [](std::tuple<at::ScalarType> s) {
            return c10::make_intrusive<ScalarTypeClass>(std::get<0>(s));
          });

  m.class_<ReLUClass>("_ReLUClass")
      .def(torch::init<>())
      .def("run", &ReLUClass::run);

  m.class_<_StaticMethod>("_StaticMethod")
      .def(torch::init<>())
      .def_static("staticMethod", &_StaticMethod::staticMethod);

  m.class_<DefaultArgs>("_DefaultArgs")
      .def(torch::init<int64_t>(), "", {torch::arg("start") = 3})
      .def("increment", &DefaultArgs::increment, "", {torch::arg("val") = 1})
      .def("decrement", &DefaultArgs::decrement, "", {torch::arg("val") = 1})
      .def(
          "scale_add",
          &DefaultArgs::scale_add,
          "",
          {torch::arg("add"), torch::arg("scale") = 1})
      .def(
          "divide",
          &DefaultArgs::divide,
          "",
          {torch::arg("factor") = torch::arg::none()});

  m.class_<Foo>("_Foo")
      .def(torch::init<int64_t, int64_t>())
      // .def(torch::init<>())
      .def("info", &Foo::info)
      .def("increment", &Foo::increment)
      .def("add", &Foo::add)
      .def("add_tensor", &Foo::add_tensor)
      .def("__eq__", &Foo::eq)
      .def("combine", &Foo::combine)
      .def("__obj_flatten__", &Foo::__obj_flatten__)
      .def_pickle(
          [](c10::intrusive_ptr<Foo> self) { // __getstate__
            return std::vector<int64_t>{self->x, self->y};
          },
          [](std::vector<int64_t> state) { // __setstate__
            // Guard the state[0]/state[1] reads against malformed state.
            if (state.size() != 2) {
              throw std::runtime_error(
                  "_Foo.__setstate__: expected 2 values, got " +
                  std::to_string(state.size()));
            }
            return c10::make_intrusive<Foo>(state[0], state[1]);
          });

  m.class_<FlattenWithTensorOp>("_FlattenWithTensorOp")
      .def(torch::init<at::Tensor>())
      .def("get", &FlattenWithTensorOp::get)
      .def("__obj_flatten__", &FlattenWithTensorOp::__obj_flatten__);

  m.class_<ConstantTensorContainer>("_ConstantTensorContainer")
      .def(torch::init<at::Tensor>())
      .def("get", &ConstantTensorContainer::get)
      .def("tracing_mode", &ConstantTensorContainer::tracing_mode);

  m.def(
      "takes_foo(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
  m.def(
      "takes_foo_python_meta(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor");
  m.def(
      "takes_foo_list_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor[]");
  m.def(
      "takes_foo_tuple_return(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> (Tensor, Tensor)");

  m.class_<FooGetterSetter>("_FooGetterSetter")
      .def(torch::init<int64_t, int64_t>())
      .def_property("x", &FooGetterSetter::getX, &FooGetterSetter::setX)
      .def_property("y", &FooGetterSetter::getY);

  m.class_<FooGetterSetterLambda>("_FooGetterSetterLambda")
      .def(torch::init<int64_t>())
      .def_property(
          "x",
          [](const c10::intrusive_ptr<FooGetterSetterLambda>& self) {
            return self->x;
          },
          [](const c10::intrusive_ptr<FooGetterSetterLambda>& self,
             int64_t val) { self->x = val; });

  m.class_<FooReadWrite>("_FooReadWrite")
      .def(torch::init<int64_t, int64_t>())
      .def_readwrite("x", &FooReadWrite::x)
      .def_readonly("y", &FooReadWrite::y);

  m.class_<LambdaInit>("_LambdaInit")
      .def(torch::init([](int64_t x, int64_t y, bool swap) {
        if (swap) {
          return c10::make_intrusive<LambdaInit>(y, x);
        } else {
          return c10::make_intrusive<LambdaInit>(x, y);
        }
      }))
      .def("diff", &LambdaInit::diff);

  m.class_<NoInit>("_NoInit").def(
      "get_x", [](const c10::intrusive_ptr<NoInit>& self) { return self->x; });

  m.class_<MyStackClass<std::string>>("_StackString")
      .def(torch::init<std::vector<std::string>>())
      .def("push", &MyStackClass<std::string>::push)
      .def("pop", &MyStackClass<std::string>::pop)
      .def("clone", &MyStackClass<std::string>::clone)
      .def("merge", &MyStackClass<std::string>::merge)
      .def_pickle(
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
            return self->stack_;
          },
          // __setstate__ deliberately ignores the saved state and returns a
          // fixed stack.
          [](std::vector<std::string> /*state*/) {
            return c10::make_intrusive<MyStackClass<std::string>>(
                std::vector<std::string>{"i", "was", "deserialized"});
          })
      .def("return_a_tuple", &MyStackClass<std::string>::return_a_tuple)
      .def(
          "top",
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
              -> std::string {
            // back() on an empty vector is undefined behavior.
            if (self->stack_.empty()) {
              throw std::runtime_error("_StackString.top: stack is empty");
            }
            return self->stack_.back();
          })
      .def(
          "__str__",
          [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
            std::stringstream ss;
            ss << "[";
            for (size_t i = 0; i < self->stack_.size(); ++i) {
              ss << self->stack_[i];
              if (i != self->stack_.size() - 1) {
                ss << ", ";
              }
            }
            ss << "]";
            return ss.str();
          });
  // clang-format off
        // The following will fail with a static assert telling you you have to
        // take an intrusive_ptr<MyStackClass> as the first argument.
        // .def("foo", [](int64_t a) -> int64_t{ return 3;});
  // clang-format on

  m.class_<PickleTester>("_PickleTester")
      .def(torch::init<std::vector<int64_t>>())
      .def_pickle(
          // __getstate__ deliberately ignores the instance and returns a fixed
          // payload.
          [](c10::intrusive_ptr<PickleTester> /*self*/) { // __getstate__
            return std::vector<int64_t>{1, 3, 3, 7};
          },
          [](std::vector<int64_t> state) { // __setstate__
            return c10::make_intrusive<PickleTester>(std::move(state));
          })
      .def(
          "top",
          [](const c10::intrusive_ptr<PickleTester>& self) {
            if (self->vals.empty()) {
              throw std::runtime_error("_PickleTester.top: no values");
            }
            return self->vals.back();
          })
      .def("pop", [](const c10::intrusive_ptr<PickleTester>& self) {
        if (self->vals.empty()) {
          throw std::runtime_error("_PickleTester.pop: no values");
        }
        auto val = self->vals.back();
        self->vals.pop_back();
        return val;
      });

  m.def(
      "take_an_instance(__torch__.torch.classes._TorchScriptTesting._PickleTester x) -> Tensor Y",
      take_an_instance);
  // test that schema inference is ok too
  m.def("take_an_instance_inferred", take_an_instance);

  m.class_<ElementwiseInterpreter>("_ElementwiseInterpreter")
      .def(torch::init<>())
      .def("set_instructions", &ElementwiseInterpreter::setInstructions)
      .def("add_constant", &ElementwiseInterpreter::addConstant)
      .def("set_input_names", &ElementwiseInterpreter::setInputNames)
      .def("set_output_name", &ElementwiseInterpreter::setOutputName)
      .def("__call__", &ElementwiseInterpreter::__call__)
      .def_pickle(
          /* __getstate__ */
          [](const c10::intrusive_ptr<ElementwiseInterpreter>& self) {
            return self->__getstate__();
          },
          /* __setstate__ */
          [](ElementwiseInterpreter::SerializationType state) {
            return ElementwiseInterpreter::__setstate__(std::move(state));
          });

  m.class_<ContainsTensor>("_ContainsTensor")
      .def(torch::init<at::Tensor>())
      .def("get", &ContainsTensor::get)
      .def("__obj_flatten__", &ContainsTensor::__obj_flatten__)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<ContainsTensor>& self) -> at::Tensor {
            return self->t_;
          },
          // __setstate__
          [](at::Tensor data) -> c10::intrusive_ptr<ContainsTensor> {
            return c10::make_intrusive<ContainsTensor>(std::move(data));
          });
  m.class_<TensorQueue>("_TensorQueue")
      .def(torch::init<at::Tensor>())
      .def("push", &TensorQueue::push)
      .def("pop", &TensorQueue::pop)
      .def("top", &TensorQueue::top)
      .def("is_empty", &TensorQueue::is_empty)
      .def("float_size", &TensorQueue::float_size)
      .def("size", &TensorQueue::size)
      .def("clone_queue", &TensorQueue::clone_queue)
      .def("get_raw_queue", &TensorQueue::get_raw_queue)
      .def("__obj_flatten__", &TensorQueue::__obj_flatten__)
      .def_pickle(
          // __getstate__
          [](const c10::intrusive_ptr<TensorQueue>& self)
              -> c10::Dict<std::string, at::Tensor> {
            return self->serialize();
          },
          // __setstate__
          [](c10::Dict<std::string, at::Tensor> data)
              -> c10::intrusive_ptr<TensorQueue> {
            return c10::make_intrusive<TensorQueue>(std::move(data));
          });
}
651
takes_foo(c10::intrusive_ptr<Foo> foo,at::Tensor x)652 at::Tensor takes_foo(c10::intrusive_ptr<Foo> foo, at::Tensor x) {
653 return foo->add_tensor(x);
654 }
655
takes_foo_list_return(c10::intrusive_ptr<Foo> foo,at::Tensor x)656 std::vector<at::Tensor> takes_foo_list_return(
657 c10::intrusive_ptr<Foo> foo,
658 at::Tensor x) {
659 std::vector<at::Tensor> result;
660 result.reserve(3);
661 auto a = foo->add_tensor(x);
662 auto b = foo->add_tensor(a);
663 auto c = foo->add_tensor(b);
664 result.push_back(a);
665 result.push_back(b);
666 result.push_back(c);
667 return result;
668 }
669
takes_foo_tuple_return(c10::intrusive_ptr<Foo> foo,at::Tensor x)670 std::tuple<at::Tensor, at::Tensor> takes_foo_tuple_return(
671 c10::intrusive_ptr<Foo> foo,
672 at::Tensor x) {
673 auto a = foo->add_tensor(x);
674 auto b = foo->add_tensor(a);
675 return std::make_tuple(a, b);
676 }
677
queue_push(c10::intrusive_ptr<TensorQueue> tq,at::Tensor x)678 void queue_push(c10::intrusive_ptr<TensorQueue> tq, at::Tensor x) {
679 tq->push(x);
680 }
681
queue_pop(c10::intrusive_ptr<TensorQueue> tq)682 at::Tensor queue_pop(c10::intrusive_ptr<TensorQueue> tq) {
683 return tq->pop();
684 }
685
queue_size(c10::intrusive_ptr<TensorQueue> tq)686 int64_t queue_size(c10::intrusive_ptr<TensorQueue> tq) {
687 return tq->size();
688 }
689
// Library fragment adding the CompositeImplicitAutograd `takes_foo_cia`
// schema and the TensorQueue operator schemas to `_TorchScriptTesting`.
TORCH_LIBRARY_FRAGMENT(_TorchScriptTesting, m) {
  m.impl_abstract_pystub("torch.testing._internal.torchbind_impls");
  m.def(
       "takes_foo_cia(__torch__.torch.classes._TorchScriptTesting._Foo foo, Tensor x) -> Tensor")
      .def(
          "queue_pop(__torch__.torch.classes._TorchScriptTesting._TensorQueue foo) -> Tensor")
      .def(
          "queue_push(__torch__.torch.classes._TorchScriptTesting._TensorQueue foo, Tensor x) -> ()")
      .def(
          "queue_size(__torch__.torch.classes._TorchScriptTesting._TensorQueue foo) -> int");
}
701
// CPU kernels for the Foo- and TensorQueue-taking operators.
TORCH_LIBRARY_IMPL(_TorchScriptTesting, CPU, m) {
  m.impl("takes_foo", &takes_foo)
      .impl("takes_foo_list_return", &takes_foo_list_return)
      .impl("takes_foo_tuple_return", &takes_foo_tuple_return)
      .impl("queue_push", &queue_push)
      .impl("queue_pop", &queue_pop)
      .impl("queue_size", &queue_size);
}
710
// Meta kernels: reuse the CPU implementations of the takes_foo* operators.
TORCH_LIBRARY_IMPL(_TorchScriptTesting, Meta, m) {
  m.impl("takes_foo", &takes_foo)
      .impl("takes_foo_list_return", &takes_foo_list_return)
      .impl("takes_foo_tuple_return", &takes_foo_tuple_return);
}
716
// `takes_foo_cia` is implemented as a CompositeImplicitAutograd kernel
// backed by `takes_foo`.
TORCH_LIBRARY_IMPL(_TorchScriptTesting, CompositeImplicitAutograd, m) {
  m.impl("takes_foo_cia", &takes_foo);
}
720
721 // Need to implement BackendSelect because these two operators don't have tensor
722 // inputs.
TORCH_LIBRARY_IMPL(_TorchScriptTesting,BackendSelect,m)723 TORCH_LIBRARY_IMPL(_TorchScriptTesting, BackendSelect, m) {
724 m.impl("queue_pop", queue_pop);
725 m.impl("queue_size", queue_size);
726 }
727
728 } // namespace
729