1 // Copyright 2018 The Amber Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/vulkan/compute_pipeline.h"
16 
17 #include "src/vulkan/command_pool.h"
18 #include "src/vulkan/device.h"
19 
20 namespace amber {
21 namespace vulkan {
22 
ComputePipeline(Device * device,uint32_t fence_timeout_ms,bool pipeline_runtime_layer_enabled,const std::vector<VkPipelineShaderStageCreateInfo> & shader_stage_info)23 ComputePipeline::ComputePipeline(
24     Device* device,
25     uint32_t fence_timeout_ms,
26     bool pipeline_runtime_layer_enabled,
27     const std::vector<VkPipelineShaderStageCreateInfo>& shader_stage_info)
28     : Pipeline(PipelineType::kCompute,
29                device,
30                fence_timeout_ms,
31                pipeline_runtime_layer_enabled,
32                shader_stage_info) {}
33 
34 ComputePipeline::~ComputePipeline() = default;
35 
Initialize(CommandPool * pool)36 Result ComputePipeline::Initialize(CommandPool* pool) {
37   return Pipeline::Initialize(pool);
38 }
39 
CreateVkComputePipeline(const VkPipelineLayout & pipeline_layout,VkPipeline * pipeline)40 Result ComputePipeline::CreateVkComputePipeline(
41     const VkPipelineLayout& pipeline_layout,
42     VkPipeline* pipeline) {
43   auto shader_stage_info = GetVkShaderStageInfo();
44   if (shader_stage_info.size() != 1) {
45     return Result(
46         "Vulkan::CreateVkComputePipeline number of shaders given to compute "
47         "pipeline is not 1");
48   }
49 
50   if (shader_stage_info[0].stage != VK_SHADER_STAGE_COMPUTE_BIT)
51     return Result("Vulkan: Non compute shader for compute pipeline");
52 
53   shader_stage_info[0].pName = GetEntryPointName(VK_SHADER_STAGE_COMPUTE_BIT);
54 
55   VkComputePipelineCreateInfo pipeline_info = VkComputePipelineCreateInfo();
56   pipeline_info.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
57   pipeline_info.stage = shader_stage_info[0];
58   pipeline_info.layout = pipeline_layout;
59 
60   if (device_->GetPtrs()->vkCreateComputePipelines(
61           device_->GetVkDevice(), VK_NULL_HANDLE, 1, &pipeline_info, nullptr,
62           pipeline) != VK_SUCCESS) {
63     return Result("Vulkan::Calling vkCreateComputePipelines Fail");
64   }
65 
66   return {};
67 }
68 
Compute(uint32_t x,uint32_t y,uint32_t z)69 Result ComputePipeline::Compute(uint32_t x, uint32_t y, uint32_t z) {
70   Result r = SendDescriptorDataToDeviceIfNeeded();
71   if (!r.IsSuccess())
72     return r;
73 
74   VkPipelineLayout pipeline_layout = VK_NULL_HANDLE;
75   r = CreateVkPipelineLayout(&pipeline_layout);
76   if (!r.IsSuccess())
77     return r;
78 
79   VkPipeline pipeline = VK_NULL_HANDLE;
80   r = CreateVkComputePipeline(pipeline_layout, &pipeline);
81   if (!r.IsSuccess())
82     return r;
83 
84   // Note that a command updating a descriptor set and a command using
85   // it must be submitted separately, because using a descriptor set
86   // while updating it is not safe.
87   UpdateDescriptorSetsIfNeeded();
88 
89   {
90     CommandBufferGuard guard(GetCommandBuffer());
91     if (!guard.IsRecording())
92       return guard.GetResult();
93 
94     BindVkDescriptorSets(pipeline_layout);
95 
96     r = RecordPushConstant(pipeline_layout);
97     if (!r.IsSuccess())
98       return r;
99 
100     device_->GetPtrs()->vkCmdBindPipeline(command_->GetVkCommandBuffer(),
101                                           VK_PIPELINE_BIND_POINT_COMPUTE,
102                                           pipeline);
103     device_->GetPtrs()->vkCmdDispatch(command_->GetVkCommandBuffer(), x, y, z);
104 
105     r = guard.Submit(GetFenceTimeout(), GetPipelineRuntimeLayerEnabled());
106     if (!r.IsSuccess())
107       return r;
108   }
109 
110   r = ReadbackDescriptorsToHostDataQueue();
111   if (!r.IsSuccess())
112     return r;
113 
114   device_->GetPtrs()->vkDestroyPipeline(device_->GetVkDevice(), pipeline,
115                                         nullptr);
116   device_->GetPtrs()->vkDestroyPipelineLayout(device_->GetVkDevice(),
117                                               pipeline_layout, nullptr);
118 
119   return {};
120 }
121 
122 }  // namespace vulkan
123 }  // namespace amber
124