1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
#include <vector>

#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Optional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h> // @manual=//caffe2:torch_extension
#include <torch/torch.h> // @manual=//caffe2:torch-cpp-cpu
21*523fa7a6SAndroid Build Coastguard Worker
22*523fa7a6SAndroid Build Coastguard Worker namespace exir {
23*523fa7a6SAndroid Build Coastguard Worker namespace {
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Worker class DataBuffer {
26*523fa7a6SAndroid Build Coastguard Worker private:
27*523fa7a6SAndroid Build Coastguard Worker void* buffer_ = nullptr;
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Worker public:
DataBuffer(pybind11::bytes data,int64_t len)30*523fa7a6SAndroid Build Coastguard Worker DataBuffer(pybind11::bytes data, int64_t len) {
31*523fa7a6SAndroid Build Coastguard Worker // allocate buffer
32*523fa7a6SAndroid Build Coastguard Worker buffer_ = malloc(len);
33*523fa7a6SAndroid Build Coastguard Worker // convert data to std::string and copy to buffer
34*523fa7a6SAndroid Build Coastguard Worker std::memcpy(buffer_, (std::string{data}).data(), len);
35*523fa7a6SAndroid Build Coastguard Worker }
~DataBuffer()36*523fa7a6SAndroid Build Coastguard Worker ~DataBuffer() {
37*523fa7a6SAndroid Build Coastguard Worker if (buffer_) {
38*523fa7a6SAndroid Build Coastguard Worker free(buffer_);
39*523fa7a6SAndroid Build Coastguard Worker }
40*523fa7a6SAndroid Build Coastguard Worker }
41*523fa7a6SAndroid Build Coastguard Worker DataBuffer(const DataBuffer&) = delete;
42*523fa7a6SAndroid Build Coastguard Worker DataBuffer& operator=(const DataBuffer&) = delete;
43*523fa7a6SAndroid Build Coastguard Worker
get()44*523fa7a6SAndroid Build Coastguard Worker void* get() {
45*523fa7a6SAndroid Build Coastguard Worker return buffer_;
46*523fa7a6SAndroid Build Coastguard Worker }
47*523fa7a6SAndroid Build Coastguard Worker };
48*523fa7a6SAndroid Build Coastguard Worker } // namespace
49*523fa7a6SAndroid Build Coastguard Worker
/**
 * Defines the Python extension module `bindings`.
 *
 * Exposes:
 *  - `DataBuffer(data: bytes, len: int)`: wraps the C++ DataBuffer class.
 *  - `convert_to_tensor(data_buffer, scalar_type, sizes, strides)`: builds a
 *    tensor over the memory held by `data_buffer` via torch::from_blob, using
 *    `scalar_type` as the integer value of an at::ScalarType and the given
 *    sizes/strides as the tensor geometry.
 *
 * The tensor aliases the DataBuffer's memory instead of copying it.
 * NOTE(review): `sizes`/`strides` are not checked against the amount of memory
 * held by `data_buffer` (DataBuffer does not record its length), so callers
 * must pass a geometry that fits inside the buffer.
 */
PYBIND11_MODULE(bindings, m) {
  pybind11::class_<DataBuffer>(m, "DataBuffer")
      .def(pybind11::init<pybind11::bytes, int64_t>());
  m.def(
      "convert_to_tensor",
      // Nothing from the enclosing scope is used, and the callable is stored
      // by pybind11 beyond this function, so capture nothing.
      [](DataBuffer& data_buffer,
         const int64_t scalar_type,
         const std::vector<int64_t>& sizes,
         const std::vector<int64_t>& strides) {
        // at::ScalarType has a narrow underlying type; reject integers outside
        // the enum's range instead of letting the cast silently wrap onto a
        // different dtype.
        if (scalar_type < 0 ||
            scalar_type >= static_cast<int64_t>(at::ScalarType::NumOptions)) {
          throw std::invalid_argument(
              "convert_to_tensor: scalar_type is not a valid ScalarType value");
        }
        const auto dtype = static_cast<at::ScalarType>(scalar_type);
        const auto opts = torch::TensorOptions().dtype(dtype);

        // View the buffer's memory as a tensor using the supplied metadata.
        return torch::from_blob(data_buffer.get(), sizes, strides, opts);
      },
      // The returned tensor borrows data_buffer's memory: keep argument 1
      // (the DataBuffer) alive for as long as the return value (index 0) is.
      pybind11::keep_alive<0, 1>());
}
68*523fa7a6SAndroid Build Coastguard Worker } // namespace exir
69