1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h> 12 13 #include <cstring> 14 15 namespace vkcompute { 16 17 // 18 // Staging Buffer <-> Tensor 19 // 20 21 void add_staging_to_tensor_node( 22 ComputeGraph& graph, 23 const ValueRef in_staging, 24 const ValueRef out_tensor); 25 26 void add_tensor_to_staging_node( 27 ComputeGraph& graph, 28 const ValueRef in_tensor, 29 const ValueRef out_staging); 30 31 // 32 // Standard Prepack 33 // 34 35 /* 36 * Given that `v` is a `TensorRef`, create a new `Tensor` value with the 37 * specified `storage_type` and `memory_layout`, and add a a prepacking node to 38 * transfer the `TensorRef` data to the new `Tensor` object via a staging to 39 * tensor shader. The created `Tensor` value is then returned. 40 * 41 * If `passthrough` is `true`, then `v` may be a `Tensor` as well. If `v` is a 42 * `Tensor`, then it is returned as-is. If `passthrough` is `false` (default), 43 * then an exception will be thrown. 44 */ 45 46 ValueRef prepack_standard( 47 ComputeGraph& graph, 48 const ValueRef tensor_data, 49 const utils::StorageType storage_type, 50 const utils::GPUMemoryLayout layout, 51 const bool passthrough = false); 52 53 /* 54 * Equivalent to `prepack_standard()` function, except the `storage_type` and 55 * `memory_layout` are set to match `to_copy`, which must be a `Tensor`. 56 */ 57 ValueRef prepack_standard_like( 58 ComputeGraph& graph, 59 const ValueRef tensor_data, 60 const ValueRef to_copy, 61 const bool passthrough = false); 62 63 // 64 // Direct buffer copy prepack 65 // 66 67 /* 68 * Given that `v` is a `TensorRef`, create a new `Tensor` value with buffer 69 * storage and `kWidthPacked` memory layout, and add a prepacking node to 70 * transfer the `TensorRef` data to the new `Tensor` object via a direct buffer 71 * to buffer copy shader. 72 */ 73 ValueRef prepack_direct_copy_buffer( 74 ComputeGraph& graph, 75 const ValueRef tensor_data); 76 77 } // namespace vkcompute 78