#include #include #include #include namespace at { Tensor custom_add_impl(Tensor t1, Tensor t2) { return t1 + t2; } Tensor fn_with_all_inputs_impl( const Tensor& tensor, const c10::List& tensors, const c10::List>& optional_tensors, const bool b8, const c10::List& b8s, const int64_t i64, const c10::List& i64s, const int64_t& symint, const IntArrayRef symints, const double f64, const c10::List& f64s, const at::Scalar& scalar, at::ArrayRef scalars, const std::string& string, const std::vector& strings, // const c10::ScalarType& dtype, // const MemoryFormat& memory_format, // const Layout& layout, const Device& device, // optional const std::optional& o_tensor, const std::optional>& o_tensors, const std::optional& o_b8, const std::optional>& o_b8s, const std::optional& o_i64, const std::optional>& o_i64s, const std::optional& o_symint, const std::optional& o_symints, const std::optional& o_f64, const std::optional>& o_f64s, const std::optional& o_scalar, const std::optional>& o_scalars, const std::optional& o_string, const std::optional>& o_strings, // const std::optional& o_dtype, // const std::optional& o_memory_format, // const std::optional& o_layout, const std::optional& o_device) { std::cout << "tensor shape: " << tensor.sizes() << std::endl; std::cout << "tensors shape: "; for (auto t : tensors) { std::cout << t.get().toTensor().sizes() << ", "; } std::cout << std::endl; std::cout << "optional tensors shape: "; for (auto t : optional_tensors) { if (t.get().toOptional().has_value()) { std::cout << t.get().toTensor().sizes() << ", "; } else { std::cout << "None, "; } } std::cout << std::endl; std::cout << "b8 " << c10::IValue(b8) << std::endl; std::cout << "b8s " << c10::IValue(b8s) << std::endl; std::cout << "i64 " << c10::IValue(i64) << std::endl; std::cout << "i64s " << c10::IValue(i64s) << std::endl; std::cout << "symint " << c10::IValue(symint) << std::endl; std::cout << "symints " << c10::IValue(symints) << std::endl; std::cout << "f64 " << c10::IValue(f64) << std::endl; std::cout << "f64s " << c10::IValue(f64s) << std::endl; std::cout << "scalar " << c10::IValue(scalar) << std::endl; std::cout << "scalars " << c10::IValue(scalars) << std::endl; std::cout << "string " << c10::IValue(string) << std::endl; std::cout << "strings " << c10::IValue(strings) << std::endl; // std::cout << "dtype " << c10::IValue(dtype) << std::endl; // std::cout << "memory_format " << c10::IValue(memory_format) << std::endl; // std::cout << "layout " << c10::IValue(layout) << std::endl; std::cout << "device " << c10::IValue(device) << std::endl; std::cout << "o_tensor " << (o_tensor.has_value() ? c10::IValue(o_tensor.value().sizes()) : "None") << std::endl; std::cout << "o_tensors shape: "; if (o_tensors.has_value()) { for (auto t : o_tensors.value()) { std::cout << t.get().toTensor().sizes() << ", "; } } else { std::cout << "None"; } std::cout << std::endl; std::cout << "o_b8 " << (o_b8.has_value() ? c10::IValue(o_b8.value()) : "None") << std::endl; std::cout << "o_b8s " << (o_b8s.has_value() ? c10::IValue(o_b8s.value()) : "None") << std::endl; std::cout << "o_i64 " << (o_i64.has_value() ? c10::IValue(o_i64.value()) : "None") << std::endl; std::cout << "o_i64s " << (o_i64s.has_value() ? c10::IValue(o_i64s.value()) : "None") << std::endl; std::cout << "o_symint " << (o_symint.has_value() ? c10::IValue(o_symint.value()) : "None") << std::endl; std::cout << "o_symints " << (o_symints.has_value() ? c10::IValue(o_symints.value()) : "None") << std::endl; std::cout << "o_f64 " << (o_f64.has_value() ? c10::IValue(o_f64.value()) : "None") << std::endl; std::cout << "o_f64s " << (o_f64s.has_value() ? c10::IValue(o_f64s.value()) : "None") << std::endl; std::cout << "o_scalar " << (o_scalar.has_value() ? c10::IValue(o_scalar.value()) : "None") << std::endl; std::cout << "o_scalars " << (o_scalars.has_value() ? c10::IValue(o_scalars.value()) : "None") << std::endl; std::cout << "o_string " << (o_string.has_value() ? c10::IValue(o_string.value()) : "None") << std::endl; std::cout << "o_strings " << (o_strings.has_value() ? c10::IValue(o_strings.value()) : "None") << std::endl; // std::cout << "o_dtype " // << (o_dtype.has_value() ? c10::IValue(o_dtype.value()) : "None") // << std::endl; // std::cout << "o_memory_format " // << (o_memory_format.has_value() // ? c10::IValue(o_memory_format.value()) // : "None") // << std::endl; // std::cout << "o_layout " // << (o_layout.has_value() ? c10::IValue(o_layout.value()) : "None") // << std::endl; std::cout << "o_device " << (o_device.has_value() ? c10::IValue(o_device.value()) : "None") << std::endl; int64_t int_hash = 0; int_hash ^= i64; for (auto i : i64s) { int_hash ^= i; } if (o_i64.has_value()) { int_hash ^= o_i64.value(); } if (o_i64s.has_value()) { for (auto i : o_i64s.value()) { int_hash ^= i; } } int_hash ^= symint; for (auto i : symints) { int_hash ^= i; } if (o_symint.has_value()) { int_hash ^= o_symint.value(); } if (o_symints.has_value()) { for (auto i : o_symints.value()) { int_hash ^= i; } } return tensor + int_hash; } Tensor fn_with_default_input_impl(const Tensor& tensor, const int64_t i64) { return tensor + i64; } std::tuple fn_with_tuple_output_impl( const Tensor& tensor, const int64_t i64) { return {tensor + i64, tensor - i64}; } std::vector fn_with_list_output_impl( TensorList tensors, const int64_t i64) { std::vector outputs; for (auto& t : tensors) { outputs.emplace_back(t + i64); } return outputs; } std::tuple> fn_with_mix_outputs_impl( const Tensor& tensor, TensorList tensors) { std::vector outputs; for (auto& t : tensors) { outputs.emplace_back(t + 2); } return {tensor + 1, outputs}; } std::tuple fn_with_input_mutation_impl( Tensor& t0, const Tensor& t1, Tensor& t2) { t0.add_(1); t2.sub_(1); return {t1 + 1, t1 + 2}; } // NOLINTBEGIN(clang-diagnostic-unused-parameter) Tensor fn_with_all_inputs_meta( const Tensor& tensor, const c10::List& tensors, const c10::List>& optional_tensors, const bool b8, const c10::List& b8s, const int64_t i64, const c10::List& i64s, const c10::SymInt& symint, c10::SymIntArrayRef symints, const double f64, const c10::List& f64s, const at::Scalar& scalar, at::ArrayRef scalars, const std::string& string, const std::vector& strings, // const c10::ScalarType& dtype, // const MemoryFormat& memory_format, // const Layout& layout, const Device& device, // optional const std::optional& o_tensor, const std::optional>& o_tensors, const std::optional& o_b8, const std::optional>& o_b8s, const std::optional& o_i64, const std::optional>& o_i64s, const std::optional& o_symint, at::OptionalSymIntArrayRef o_symints, const std::optional& o_f64, const std::optional>& o_f64s, const std::optional& o_scalar, const std::optional>& o_scalars, const std::optional& o_string, const std::optional>& o_strings, // const std::optional& o_dtype, // const std::optional& o_memory_format, // const std::optional& o_layout, const std::optional& o_device) { return tensor; } Tensor fn_with_default_input_meta(const Tensor& tensor, const int64_t i64) { return tensor.clone(); } std::tuple fn_with_tuple_output_meta( const Tensor& tensor, const int64_t i64) { return {tensor.clone(), tensor.clone()}; } std::vector fn_with_list_output_meta( TensorList tensors, const int64_t i64) { std::vector outputs; for (auto& t : tensors) { outputs.push_back(t.clone()); } return outputs; } std::tuple> fn_with_mix_outputs_meta( const Tensor& tensor, TensorList tensors) { std::vector outputs; for (auto& t : tensors) { outputs.push_back(t.clone()); } return {tensor.clone(), outputs}; } std::tuple fn_with_input_mutation_meta( Tensor& t0, const Tensor& t1, Tensor& t2) { return {t1.clone(), t1.clone()}; } } // namespace at TORCH_LIBRARY(aoti_custom_ops, m) { m.def("custom_add(Tensor t1, Tensor t2) -> Tensor"); m.def( "fn_with_all_inputs(Tensor tensor, " "Tensor[] tensors, " "Tensor?[] optional_tensors, " "bool b8, bool[] b8s, " "int i64, int[] i64s, " "SymInt symint, SymInt[] symints, " "float f64, float[] f64s, " "Scalar scalar, Scalar[] scalars, " "str string, str[] strings, " // "ScalarType dtype, " // "MemoryFormat memory_format, " // "Layout layout, " "Device device, " "*, " "Tensor? o_tensor, Tensor[]? o_tensors, " "bool? o_b8, bool[]? o_b8s, " "int? o_i64, int[]? o_i64s, " "SymInt? o_symint, SymInt[]? o_symints, " "float? o_f64, float[]? o_f64s, " "Scalar? o_scalar, Scalar[]? o_scalars, " "str? o_string, str[]? o_strings, " // "ScalarType? o_dtype, " // "MemoryFormat? o_memory_format, " // "Layout? o_layout, " "Device? o_device) -> Tensor"); m.def("fn_with_default_input(Tensor t, int i=3) -> Tensor"); m.def("fn_with_tuple_output(Tensor t, int i) -> (Tensor, Tensor)"); m.def("fn_with_list_output(Tensor[] tensors, int i) -> Tensor[]"); m.def( "fn_with_mix_outputs(Tensor t, Tensor[] tensors) -> (Tensor, Tensor[])"); m.def( "fn_with_input_mutation(Tensor(a!) t0, Tensor t1, Tensor(b!) t2) -> (Tensor, Tensor)"); } TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) { m.impl("custom_add", at::custom_add_impl); m.impl("fn_with_all_inputs", at::fn_with_all_inputs_impl); m.impl("fn_with_default_input", at::fn_with_default_input_impl); m.impl("fn_with_tuple_output", at::fn_with_tuple_output_impl); m.impl("fn_with_list_output", at::fn_with_list_output_impl); m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_impl); m.impl("fn_with_input_mutation", at::fn_with_input_mutation_impl); } TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) { m.impl("fn_with_all_inputs", at::fn_with_all_inputs_meta); m.impl("fn_with_default_input", at::fn_with_default_input_meta); m.impl("fn_with_tuple_output", at::fn_with_tuple_output_meta); m.impl("fn_with_list_output", at::fn_with_list_output_meta); m.impl("fn_with_mix_outputs", at::fn_with_mix_outputs_meta); m.impl("fn_with_input_mutation", at::fn_with_input_mutation_meta); }