xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_custom_class_registrations.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <tuple>
#include <utility>
#include <vector>

#include <torch/custom_class.h>
#include <torch/script.h>
3 
4 namespace torch {
5 namespace jit {
6 
7 struct ScalarTypeClass : public torch::CustomClassHolder {
ScalarTypeClassScalarTypeClass8   ScalarTypeClass(at::ScalarType s) : scalar_type_(s) {}
9   at::ScalarType scalar_type_;
10 };
11 
12 template <class T>
13 struct MyStackClass : torch::CustomClassHolder {
14   std::vector<T> stack_;
MyStackClassMyStackClass15   MyStackClass(std::vector<T> init) : stack_(init.begin(), init.end()) {}
16 
pushMyStackClass17   void push(T x) {
18     stack_.push_back(x);
19   }
popMyStackClass20   T pop() {
21     auto val = stack_.back();
22     stack_.pop_back();
23     return val;
24   }
25 
cloneMyStackClass26   c10::intrusive_ptr<MyStackClass> clone() const {
27     return c10::make_intrusive<MyStackClass>(stack_);
28   }
29 
mergeMyStackClass30   void merge(const c10::intrusive_ptr<MyStackClass>& c) {
31     for (auto& elem : c->stack_) {
32       push(elem);
33     }
34   }
35 
return_a_tupleMyStackClass36   std::tuple<double, int64_t> return_a_tuple() const {
37     return std::make_tuple(1337.0f, 123);
38   }
39 };
40 } // namespace jit
41 } // namespace torch
42