1 // Copyright (c) 2016 The vulkano developers
2 // Licensed under the Apache License, Version 2.0
3 // <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT
5 // license <LICENSE-MIT or https://opensource.org/licenses/MIT>,
6 // at your option. All files in the project carrying such
7 // notice may not be copied, modified, or distributed except
8 // according to those terms.
9 
10 //! A pipeline that performs general-purpose operations.
11 //!
12 //! A compute pipeline takes buffers and/or images as both inputs and outputs. It operates
13 //! "standalone", with no additional infrastructure such as render passes or vertex input. Compute
14 //! pipelines can be used by themselves for performing work on the Vulkan device, but they can also
15 //! assist graphics operations by precalculating or postprocessing the operations from another kind
//! of pipeline. While it is theoretically possible to perform graphics operations entirely in a
17 //! compute pipeline, a graphics pipeline is better suited to that task.
18 //!
19 //! A compute pipeline is relatively simple to create, requiring only a pipeline layout and a single
20 //! shader, the *compute shader*. The compute shader is the actual program that performs the work.
21 //! Once created, you can execute a compute pipeline by *binding* it in a command buffer, binding
22 //! any descriptor sets and/or push constants that the pipeline needs, and then issuing a `dispatch`
23 //! command on the command buffer.
24 
25 use super::layout::PipelineLayoutCreateInfo;
26 use crate::{
27     descriptor_set::layout::{
28         DescriptorSetLayout, DescriptorSetLayoutCreateInfo, DescriptorSetLayoutCreationError,
29     },
30     device::{Device, DeviceOwned},
31     macros::impl_id_counter,
32     pipeline::{
33         cache::PipelineCache,
34         layout::{PipelineLayout, PipelineLayoutCreationError, PipelineLayoutSupersetError},
35         Pipeline, PipelineBindPoint,
36     },
37     shader::{DescriptorBindingRequirements, EntryPoint, SpecializationConstants},
38     DeviceSize, OomError, VulkanError, VulkanObject,
39 };
40 use ahash::HashMap;
41 use std::{
42     error::Error,
43     fmt::{Debug, Display, Error as FmtError, Formatter},
44     mem,
45     mem::MaybeUninit,
46     num::NonZeroU64,
47     ptr,
48     sync::Arc,
49 };
50 
51 /// A pipeline object that describes to the Vulkan implementation how it should perform compute
52 /// operations.
53 ///
54 /// The template parameter contains the descriptor set to use with this pipeline.
55 ///
56 /// Pass an optional `Arc` to a `PipelineCache` to enable pipeline caching. The vulkan
57 /// implementation will handle the `PipelineCache` and check if it is available.
58 /// Check the documentation of the `PipelineCache` for more information.
59 pub struct ComputePipeline {
60     handle: ash::vk::Pipeline,
61     device: Arc<Device>,
62     id: NonZeroU64,
63     layout: Arc<PipelineLayout>,
64     descriptor_binding_requirements: HashMap<(u32, u32), DescriptorBindingRequirements>,
65     num_used_descriptor_sets: u32,
66 }
67 
68 impl ComputePipeline {
69     /// Builds a new `ComputePipeline`.
70     ///
71     /// `func` is a closure that is given a mutable reference to the inferred descriptor set
72     /// definitions. This can be used to make changes to the layout before it's created, for example
73     /// to add dynamic buffers or immutable samplers.
new<Css, F>( device: Arc<Device>, shader: EntryPoint<'_>, specialization_constants: &Css, cache: Option<Arc<PipelineCache>>, func: F, ) -> Result<Arc<ComputePipeline>, ComputePipelineCreationError> where Css: SpecializationConstants, F: FnOnce(&mut [DescriptorSetLayoutCreateInfo]),74     pub fn new<Css, F>(
75         device: Arc<Device>,
76         shader: EntryPoint<'_>,
77         specialization_constants: &Css,
78         cache: Option<Arc<PipelineCache>>,
79         func: F,
80     ) -> Result<Arc<ComputePipeline>, ComputePipelineCreationError>
81     where
82         Css: SpecializationConstants,
83         F: FnOnce(&mut [DescriptorSetLayoutCreateInfo]),
84     {
85         let mut set_layout_create_infos = DescriptorSetLayoutCreateInfo::from_requirements(
86             shader.descriptor_binding_requirements(),
87         );
88         func(&mut set_layout_create_infos);
89         let set_layouts = set_layout_create_infos
90             .iter()
91             .map(|desc| DescriptorSetLayout::new(device.clone(), desc.clone()))
92             .collect::<Result<Vec<_>, _>>()?;
93 
94         let layout = PipelineLayout::new(
95             device.clone(),
96             PipelineLayoutCreateInfo {
97                 set_layouts,
98                 push_constant_ranges: shader
99                     .push_constant_requirements()
100                     .cloned()
101                     .into_iter()
102                     .collect(),
103                 ..Default::default()
104             },
105         )?;
106 
107         unsafe {
108             ComputePipeline::with_unchecked_pipeline_layout(
109                 device,
110                 shader,
111                 specialization_constants,
112                 layout,
113                 cache,
114             )
115         }
116     }
117 
118     /// Builds a new `ComputePipeline` with a specific pipeline layout.
119     ///
120     /// An error will be returned if the pipeline layout isn't a superset of what the shader
121     /// uses.
with_pipeline_layout<Css>( device: Arc<Device>, shader: EntryPoint<'_>, specialization_constants: &Css, layout: Arc<PipelineLayout>, cache: Option<Arc<PipelineCache>>, ) -> Result<Arc<ComputePipeline>, ComputePipelineCreationError> where Css: SpecializationConstants,122     pub fn with_pipeline_layout<Css>(
123         device: Arc<Device>,
124         shader: EntryPoint<'_>,
125         specialization_constants: &Css,
126         layout: Arc<PipelineLayout>,
127         cache: Option<Arc<PipelineCache>>,
128     ) -> Result<Arc<ComputePipeline>, ComputePipelineCreationError>
129     where
130         Css: SpecializationConstants,
131     {
132         let spec_descriptors = Css::descriptors();
133 
134         for (constant_id, reqs) in shader.specialization_constant_requirements() {
135             let map_entry = spec_descriptors
136                 .iter()
137                 .find(|desc| desc.constant_id == constant_id)
138                 .ok_or(ComputePipelineCreationError::IncompatibleSpecializationConstants)?;
139 
140             if map_entry.size as DeviceSize != reqs.size {
141                 return Err(ComputePipelineCreationError::IncompatibleSpecializationConstants);
142             }
143         }
144 
145         layout.ensure_compatible_with_shader(
146             shader.descriptor_binding_requirements(),
147             shader.push_constant_requirements(),
148         )?;
149 
150         unsafe {
151             ComputePipeline::with_unchecked_pipeline_layout(
152                 device,
153                 shader,
154                 specialization_constants,
155                 layout,
156                 cache,
157             )
158         }
159     }
160 
161     /// Same as `with_pipeline_layout`, but doesn't check whether the pipeline layout is a
162     /// superset of what the shader expects.
with_unchecked_pipeline_layout<Css>( device: Arc<Device>, shader: EntryPoint<'_>, specialization_constants: &Css, layout: Arc<PipelineLayout>, cache: Option<Arc<PipelineCache>>, ) -> Result<Arc<ComputePipeline>, ComputePipelineCreationError> where Css: SpecializationConstants,163     pub unsafe fn with_unchecked_pipeline_layout<Css>(
164         device: Arc<Device>,
165         shader: EntryPoint<'_>,
166         specialization_constants: &Css,
167         layout: Arc<PipelineLayout>,
168         cache: Option<Arc<PipelineCache>>,
169     ) -> Result<Arc<ComputePipeline>, ComputePipelineCreationError>
170     where
171         Css: SpecializationConstants,
172     {
173         let fns = device.fns();
174 
175         let handle = {
176             let spec_descriptors = Css::descriptors();
177             let specialization = ash::vk::SpecializationInfo {
178                 map_entry_count: spec_descriptors.len() as u32,
179                 p_map_entries: spec_descriptors.as_ptr() as *const _,
180                 data_size: mem::size_of_val(specialization_constants),
181                 p_data: specialization_constants as *const Css as *const _,
182             };
183 
184             let stage = ash::vk::PipelineShaderStageCreateInfo {
185                 flags: ash::vk::PipelineShaderStageCreateFlags::empty(),
186                 stage: ash::vk::ShaderStageFlags::COMPUTE,
187                 module: shader.module().handle(),
188                 p_name: shader.name().as_ptr(),
189                 p_specialization_info: if specialization.data_size == 0 {
190                     ptr::null()
191                 } else {
192                     &specialization
193                 },
194                 ..Default::default()
195             };
196 
197             let infos = ash::vk::ComputePipelineCreateInfo {
198                 flags: ash::vk::PipelineCreateFlags::empty(),
199                 stage,
200                 layout: layout.handle(),
201                 base_pipeline_handle: ash::vk::Pipeline::null(),
202                 base_pipeline_index: 0,
203                 ..Default::default()
204             };
205 
206             let cache_handle = match cache {
207                 Some(ref cache) => cache.handle(),
208                 None => ash::vk::PipelineCache::null(),
209             };
210 
211             let mut output = MaybeUninit::uninit();
212             (fns.v1_0.create_compute_pipelines)(
213                 device.handle(),
214                 cache_handle,
215                 1,
216                 &infos,
217                 ptr::null(),
218                 output.as_mut_ptr(),
219             )
220             .result()
221             .map_err(VulkanError::from)?;
222             output.assume_init()
223         };
224 
225         let descriptor_binding_requirements: HashMap<_, _> = shader
226             .descriptor_binding_requirements()
227             .map(|(loc, reqs)| (loc, reqs.clone()))
228             .collect();
229         let num_used_descriptor_sets = descriptor_binding_requirements
230             .keys()
231             .map(|loc| loc.0)
232             .max()
233             .map(|x| x + 1)
234             .unwrap_or(0);
235 
236         Ok(Arc::new(ComputePipeline {
237             handle,
238             device: device.clone(),
239             id: Self::next_id(),
240             layout,
241             descriptor_binding_requirements,
242             num_used_descriptor_sets,
243         }))
244     }
245 
246     /// Returns the `Device` this compute pipeline was created with.
247     #[inline]
device(&self) -> &Arc<Device>248     pub fn device(&self) -> &Arc<Device> {
249         &self.device
250     }
251 }
252 
253 impl Pipeline for ComputePipeline {
254     #[inline]
bind_point(&self) -> PipelineBindPoint255     fn bind_point(&self) -> PipelineBindPoint {
256         PipelineBindPoint::Compute
257     }
258 
259     #[inline]
layout(&self) -> &Arc<PipelineLayout>260     fn layout(&self) -> &Arc<PipelineLayout> {
261         &self.layout
262     }
263 
264     #[inline]
num_used_descriptor_sets(&self) -> u32265     fn num_used_descriptor_sets(&self) -> u32 {
266         self.num_used_descriptor_sets
267     }
268 
269     #[inline]
descriptor_binding_requirements( &self, ) -> &HashMap<(u32, u32), DescriptorBindingRequirements>270     fn descriptor_binding_requirements(
271         &self,
272     ) -> &HashMap<(u32, u32), DescriptorBindingRequirements> {
273         &self.descriptor_binding_requirements
274     }
275 }
276 
277 impl Debug for ComputePipeline {
fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError>278     fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
279         write!(f, "<Vulkan compute pipeline {:?}>", self.handle)
280     }
281 }
282 
// Presumably generates `ComputePipeline::next_id()` (called by the constructor to fill the `id`
// field) and related id-based impls; see `crate::macros::impl_id_counter` to confirm.
impl_id_counter!(ComputePipeline);
284 
285 unsafe impl VulkanObject for ComputePipeline {
286     type Handle = ash::vk::Pipeline;
287 
288     #[inline]
handle(&self) -> Self::Handle289     fn handle(&self) -> Self::Handle {
290         self.handle
291     }
292 }
293 
294 unsafe impl DeviceOwned for ComputePipeline {
295     #[inline]
device(&self) -> &Arc<Device>296     fn device(&self) -> &Arc<Device> {
297         self.device()
298     }
299 }
300 
301 impl Drop for ComputePipeline {
302     #[inline]
drop(&mut self)303     fn drop(&mut self) {
304         unsafe {
305             let fns = self.device.fns();
306             (fns.v1_0.destroy_pipeline)(self.device.handle(), self.handle, ptr::null());
307         }
308     }
309 }
310 
311 /// Error that can happen when creating a compute pipeline.
312 #[derive(Clone, Debug, PartialEq, Eq)]
313 pub enum ComputePipelineCreationError {
314     /// Not enough memory.
315     OomError(OomError),
316     /// Error while creating a descriptor set layout object.
317     DescriptorSetLayoutCreationError(DescriptorSetLayoutCreationError),
318     /// Error while creating the pipeline layout object.
319     PipelineLayoutCreationError(PipelineLayoutCreationError),
320     /// The pipeline layout is not compatible with what the shader expects.
321     IncompatiblePipelineLayout(PipelineLayoutSupersetError),
322     /// The provided specialization constants are not compatible with what the shader expects.
323     IncompatibleSpecializationConstants,
324 }
325 
326 impl Error for ComputePipelineCreationError {
source(&self) -> Option<&(dyn Error + 'static)>327     fn source(&self) -> Option<&(dyn Error + 'static)> {
328         match self {
329             Self::OomError(err) => Some(err),
330             Self::DescriptorSetLayoutCreationError(err) => Some(err),
331             Self::PipelineLayoutCreationError(err) => Some(err),
332             Self::IncompatiblePipelineLayout(err) => Some(err),
333             Self::IncompatibleSpecializationConstants => None,
334         }
335     }
336 }
337 
338 impl Display for ComputePipelineCreationError {
fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError>339     fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
340         write!(
341             f,
342             "{}",
343             match self {
344                 ComputePipelineCreationError::OomError(_) => "not enough memory available",
345                 ComputePipelineCreationError::DescriptorSetLayoutCreationError(_) => {
346                     "error while creating a descriptor set layout object"
347                 }
348                 ComputePipelineCreationError::PipelineLayoutCreationError(_) => {
349                     "error while creating the pipeline layout object"
350                 }
351                 ComputePipelineCreationError::IncompatiblePipelineLayout(_) => {
352                     "the pipeline layout is not compatible with what the shader expects"
353                 }
354                 ComputePipelineCreationError::IncompatibleSpecializationConstants => {
355                     "the provided specialization constants are not compatible with what the shader \
356                     expects"
357                 }
358             }
359         )
360     }
361 }
362 
363 impl From<OomError> for ComputePipelineCreationError {
from(err: OomError) -> ComputePipelineCreationError364     fn from(err: OomError) -> ComputePipelineCreationError {
365         Self::OomError(err)
366     }
367 }
368 
369 impl From<DescriptorSetLayoutCreationError> for ComputePipelineCreationError {
from(err: DescriptorSetLayoutCreationError) -> Self370     fn from(err: DescriptorSetLayoutCreationError) -> Self {
371         Self::DescriptorSetLayoutCreationError(err)
372     }
373 }
374 
375 impl From<PipelineLayoutCreationError> for ComputePipelineCreationError {
from(err: PipelineLayoutCreationError) -> Self376     fn from(err: PipelineLayoutCreationError) -> Self {
377         Self::PipelineLayoutCreationError(err)
378     }
379 }
380 
381 impl From<PipelineLayoutSupersetError> for ComputePipelineCreationError {
from(err: PipelineLayoutSupersetError) -> Self382     fn from(err: PipelineLayoutSupersetError) -> Self {
383         Self::IncompatiblePipelineLayout(err)
384     }
385 }
386 
387 impl From<VulkanError> for ComputePipelineCreationError {
from(err: VulkanError) -> ComputePipelineCreationError388     fn from(err: VulkanError) -> ComputePipelineCreationError {
389         match err {
390             err @ VulkanError::OutOfHostMemory => Self::OomError(OomError::from(err)),
391             err @ VulkanError::OutOfDeviceMemory => Self::OomError(OomError::from(err)),
392             _ => panic!("unexpected error: {:?}", err),
393         }
394     }
395 }
396 
#[cfg(test)]
mod tests {
    use crate::{
        buffer::{Buffer, BufferCreateInfo, BufferUsage},
        command_buffer::{
            allocator::StandardCommandBufferAllocator, AutoCommandBufferBuilder, CommandBufferUsage,
        },
        descriptor_set::{
            allocator::StandardDescriptorSetAllocator, PersistentDescriptorSet, WriteDescriptorSet,
        },
        memory::allocator::{AllocationCreateInfo, MemoryUsage, StandardMemoryAllocator},
        pipeline::{ComputePipeline, Pipeline, PipelineBindPoint},
        shader::{ShaderModule, SpecializationConstants, SpecializationMapEntry},
        sync::{now, GpuFuture},
    };

    // TODO: test for basic creation
    // TODO: test for pipeline layout error

    #[test]
    fn specialization_constants() {
        // This test checks whether specialization constants work.
        // It executes a single compute shader (one invocation) that writes the value of a spec.
        // constant to a buffer. The buffer content is then checked for the right value.

        let (device, queue) = gfx_dev_and_queue!();

        // Pre-compiled SPIR-V for the GLSL shown below. The bytes must stay exactly as they are;
        // the default value 0xdeadbeef appears little-endian as `239, 190, 173, 222`.
        let module = unsafe {
            /*
            #version 450

            layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

            layout(constant_id = 83) const int VALUE = 0xdeadbeef;

            layout(set = 0, binding = 0) buffer Output {
                int write;
            } write;

            void main() {
                write.write = VALUE;
            }
            */
            const MODULE: [u8; 480] = [
                3, 2, 35, 7, 0, 0, 1, 0, 1, 0, 8, 0, 14, 0, 0, 0, 0, 0, 0, 0, 17, 0, 2, 0, 1, 0, 0,
                0, 11, 0, 6, 0, 1, 0, 0, 0, 71, 76, 83, 76, 46, 115, 116, 100, 46, 52, 53, 48, 0,
                0, 0, 0, 14, 0, 3, 0, 0, 0, 0, 0, 1, 0, 0, 0, 15, 0, 5, 0, 5, 0, 0, 0, 4, 0, 0, 0,
                109, 97, 105, 110, 0, 0, 0, 0, 16, 0, 6, 0, 4, 0, 0, 0, 17, 0, 0, 0, 1, 0, 0, 0, 1,
                0, 0, 0, 1, 0, 0, 0, 3, 0, 3, 0, 2, 0, 0, 0, 194, 1, 0, 0, 5, 0, 4, 0, 4, 0, 0, 0,
                109, 97, 105, 110, 0, 0, 0, 0, 5, 0, 4, 0, 7, 0, 0, 0, 79, 117, 116, 112, 117, 116,
                0, 0, 6, 0, 5, 0, 7, 0, 0, 0, 0, 0, 0, 0, 119, 114, 105, 116, 101, 0, 0, 0, 5, 0,
                4, 0, 9, 0, 0, 0, 119, 114, 105, 116, 101, 0, 0, 0, 5, 0, 4, 0, 11, 0, 0, 0, 86,
                65, 76, 85, 69, 0, 0, 0, 72, 0, 5, 0, 7, 0, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0,
                0, 71, 0, 3, 0, 7, 0, 0, 0, 3, 0, 0, 0, 71, 0, 4, 0, 9, 0, 0, 0, 34, 0, 0, 0, 0, 0,
                0, 0, 71, 0, 4, 0, 9, 0, 0, 0, 33, 0, 0, 0, 0, 0, 0, 0, 71, 0, 4, 0, 11, 0, 0, 0,
                1, 0, 0, 0, 83, 0, 0, 0, 19, 0, 2, 0, 2, 0, 0, 0, 33, 0, 3, 0, 3, 0, 0, 0, 2, 0, 0,
                0, 21, 0, 4, 0, 6, 0, 0, 0, 32, 0, 0, 0, 1, 0, 0, 0, 30, 0, 3, 0, 7, 0, 0, 0, 6, 0,
                0, 0, 32, 0, 4, 0, 8, 0, 0, 0, 2, 0, 0, 0, 7, 0, 0, 0, 59, 0, 4, 0, 8, 0, 0, 0, 9,
                0, 0, 0, 2, 0, 0, 0, 43, 0, 4, 0, 6, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 50, 0, 4, 0,
                6, 0, 0, 0, 11, 0, 0, 0, 239, 190, 173, 222, 32, 0, 4, 0, 12, 0, 0, 0, 2, 0, 0, 0,
                6, 0, 0, 0, 54, 0, 5, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 248, 0, 2,
                0, 5, 0, 0, 0, 65, 0, 5, 0, 12, 0, 0, 0, 13, 0, 0, 0, 9, 0, 0, 0, 10, 0, 0, 0, 62,
                0, 3, 0, 13, 0, 0, 0, 11, 0, 0, 0, 253, 0, 1, 0, 56, 0, 1, 0,
            ];
            ShaderModule::from_bytes(device.clone(), &MODULE).unwrap()
        };

        // Host-side specialization data: a single 4-byte `int` at offset 0, mapped to
        // `constant_id = 83` in the shader. `#[repr(C)]` keeps the field at offset 0.
        #[derive(Debug, Copy, Clone)]
        #[allow(non_snake_case)]
        #[repr(C)]
        struct SpecConsts {
            VALUE: i32,
        }
        unsafe impl SpecializationConstants for SpecConsts {
            fn descriptors() -> &'static [SpecializationMapEntry] {
                static DESCRIPTORS: [SpecializationMapEntry; 1] = [SpecializationMapEntry {
                    constant_id: 83,
                    offset: 0,
                    size: 4,
                }];
                &DESCRIPTORS
            }
        }

        // Override the shader's default (0xdeadbeef) with 0x12345678.
        let pipeline = ComputePipeline::new(
            device.clone(),
            module.entry_point("main").unwrap(),
            &SpecConsts { VALUE: 0x12345678 },
            None,
            |_| {},
        )
        .unwrap();

        // Host-visible storage buffer holding a single i32, initialised to 0.
        let memory_allocator = StandardMemoryAllocator::new_default(device.clone());
        let data_buffer = Buffer::from_data(
            &memory_allocator,
            BufferCreateInfo {
                usage: BufferUsage::STORAGE_BUFFER,
                ..Default::default()
            },
            AllocationCreateInfo {
                usage: MemoryUsage::Upload,
                ..Default::default()
            },
            0,
        )
        .unwrap();

        // Bind the buffer at set 0, binding 0, matching the shader's `Output` block.
        let ds_allocator = StandardDescriptorSetAllocator::new(device.clone());
        let set = PersistentDescriptorSet::new(
            &ds_allocator,
            pipeline.layout().set_layouts().get(0).unwrap().clone(),
            [WriteDescriptorSet::buffer(0, data_buffer.clone())],
        )
        .unwrap();

        // Record a single 1x1x1 dispatch.
        let cb_allocator = StandardCommandBufferAllocator::new(device.clone(), Default::default());
        let mut cbb = AutoCommandBufferBuilder::primary(
            &cb_allocator,
            queue.queue_family_index(),
            CommandBufferUsage::OneTimeSubmit,
        )
        .unwrap();
        cbb.bind_pipeline_compute(pipeline.clone())
            .bind_descriptor_sets(
                PipelineBindPoint::Compute,
                pipeline.layout().clone(),
                0,
                set,
            )
            .dispatch([1, 1, 1])
            .unwrap();
        let cb = cbb.build().unwrap();

        // Submit and block until the GPU has finished.
        let future = now(device)
            .then_execute(queue, cb)
            .unwrap()
            .then_signal_fence_and_flush()
            .unwrap();
        future.wait(None).unwrap();

        // The shader must have written the specialized value, not its default.
        let data_buffer_content = data_buffer.read().unwrap();
        assert_eq!(*data_buffer_content, 0x12345678);
    }
}
542