xref: /aosp_15_r20/external/executorch/exir/verification/bindings.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
#include <vector>

#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Optional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h> // @manual=//caffe2:torch_extension
#include <torch/torch.h> // @manual=//caffe2:torch-cpp-cpu
21*523fa7a6SAndroid Build Coastguard Worker 
22*523fa7a6SAndroid Build Coastguard Worker namespace exir {
23*523fa7a6SAndroid Build Coastguard Worker namespace {
24*523fa7a6SAndroid Build Coastguard Worker 
25*523fa7a6SAndroid Build Coastguard Worker class DataBuffer {
26*523fa7a6SAndroid Build Coastguard Worker  private:
27*523fa7a6SAndroid Build Coastguard Worker   void* buffer_ = nullptr;
28*523fa7a6SAndroid Build Coastguard Worker 
29*523fa7a6SAndroid Build Coastguard Worker  public:
DataBuffer(pybind11::bytes data,int64_t len)30*523fa7a6SAndroid Build Coastguard Worker   DataBuffer(pybind11::bytes data, int64_t len) {
31*523fa7a6SAndroid Build Coastguard Worker     // allocate buffer
32*523fa7a6SAndroid Build Coastguard Worker     buffer_ = malloc(len);
33*523fa7a6SAndroid Build Coastguard Worker     // convert data to std::string and copy to buffer
34*523fa7a6SAndroid Build Coastguard Worker     std::memcpy(buffer_, (std::string{data}).data(), len);
35*523fa7a6SAndroid Build Coastguard Worker   }
~DataBuffer()36*523fa7a6SAndroid Build Coastguard Worker   ~DataBuffer() {
37*523fa7a6SAndroid Build Coastguard Worker     if (buffer_) {
38*523fa7a6SAndroid Build Coastguard Worker       free(buffer_);
39*523fa7a6SAndroid Build Coastguard Worker     }
40*523fa7a6SAndroid Build Coastguard Worker   }
41*523fa7a6SAndroid Build Coastguard Worker   DataBuffer(const DataBuffer&) = delete;
42*523fa7a6SAndroid Build Coastguard Worker   DataBuffer& operator=(const DataBuffer&) = delete;
43*523fa7a6SAndroid Build Coastguard Worker 
get()44*523fa7a6SAndroid Build Coastguard Worker   void* get() {
45*523fa7a6SAndroid Build Coastguard Worker     return buffer_;
46*523fa7a6SAndroid Build Coastguard Worker   }
47*523fa7a6SAndroid Build Coastguard Worker };
48*523fa7a6SAndroid Build Coastguard Worker } // namespace
49*523fa7a6SAndroid Build Coastguard Worker 
PYBIND11_MODULE(bindings, m) {
  // Hold DataBuffer through std::shared_ptr so that tensors returned by
  // convert_to_tensor can share ownership of it via their storage deleter.
  pybind11::class_<DataBuffer, std::shared_ptr<DataBuffer>>(m, "DataBuffer")
      .def(pybind11::init<pybind11::bytes, int64_t>());

  // convert_to_tensor(data_buffer, scalar_type, sizes, strides) -> Tensor
  //
  // Creates a tensor that aliases `data_buffer`'s memory (no copy) with the
  // given dtype, sizes and strides. `scalar_type` is the integer value of a
  // c10::ScalarType.
  //
  // The tensor's storage co-owns the DataBuffer, so the memory stays valid
  // for as long as the tensor (or any view sharing its storage) is alive,
  // even if the Python DataBuffer object is dropped first.
  //
  // NOTE(review): nothing here checks that sizes/strides/dtype fit within the
  // buffer's byte length — callers must guarantee it.
  m.def(
      "convert_to_tensor",
      [](const std::shared_ptr<DataBuffer>& data_buffer,
         const int64_t scalar_type,
         const std::vector<int64_t>& sizes,
         const std::vector<int64_t>& strides) {
        if (!data_buffer) {
          throw std::invalid_argument(
              "convert_to_tensor: data_buffer must not be None");
        }
        // Casting an out-of-range integer to the enum would be undefined.
        if (scalar_type < 0 ||
            scalar_type >= static_cast<int64_t>(at::ScalarType::NumOptions)) {
          throw std::invalid_argument(
              "convert_to_tensor: scalar_type is not a valid ScalarType");
        }
        const auto type_option = static_cast<at::ScalarType>(scalar_type);
        const auto opts = torch::TensorOptions().dtype(type_option);

        // The deleter does not free the memory directly: it drops the
        // storage's reference, and DataBuffer frees the bytes once its last
        // owner (Python object or tensor storage) is gone.
        return torch::from_blob(
            data_buffer->get(),
            sizes,
            strides,
            [owner = data_buffer](void* /*unused*/) mutable { owner.reset(); },
            opts);
      });
}
68*523fa7a6SAndroid Build Coastguard Worker } // namespace exir
69