/*
 * Copyright 2023 Intel Corporation
 * SPDX-License-Identifier: MIT
 */

/**
 * \file
 * Optimize subgroup operations with uniform sources.
 */

#include "nir/nir.h"
#include "nir/nir_builder.h"

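/* The core observation: when the source of a subgroup operation is uniform
 * across the subgroup, the operation can be replaced with simple ALU code.
 * For example, with a uniform value x,
 *
 *    reduce(iadd, x)  ->  bit_count(ballot(true)) * x
 *
 * and most of the other handled operations (shuffles, broadcasts, votes,
 * min/max reductions) simply produce x unchanged.
 */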
static bool
opt_uniform_subgroup_filter(const nir_instr *instr, const void *_state)
{
   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   switch (intrin->intrinsic) {
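   /* For these operations, a uniform source means the result is the source
    * value itself.
    */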
   case nir_intrinsic_shuffle:
   case nir_intrinsic_read_invocation:
   case nir_intrinsic_read_first_invocation:
   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
   case nir_intrinsic_quad_swizzle_amd:
   case nir_intrinsic_masked_swizzle_amd:
   case nir_intrinsic_vote_all:
   case nir_intrinsic_vote_any:
      return !nir_src_is_divergent(intrin->src[0]);

   case nir_intrinsic_reduce:
   case nir_intrinsic_exclusive_scan:
   case nir_intrinsic_inclusive_scan: {
      if (nir_src_is_divergent(intrin->src[0]))
         return false;

      const nir_op reduction_op = (nir_op) nir_intrinsic_reduction_op(intrin);

      switch (reduction_op) {
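      /* Even an exclusive scan of a uniform value can be rewritten as
       * arithmetic on the active-invocation count; see
       * opt_uniform_subgroup_instr.
       */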
      case nir_op_iadd:
      case nir_op_fadd:
      case nir_op_ixor:
         return true;

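      /* A reduction or inclusive scan of a uniform value is the value
       * itself, but an exclusive scan is not: the first active invocation
       * must produce the operation's identity element (e.g., INT32_MAX for
       * imin) instead of the value.
       */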
      case nir_op_imin:
      case nir_op_umin:
      case nir_op_fmin:
      case nir_op_imax:
      case nir_op_umax:
      case nir_op_fmax:
      case nir_op_iand:
      case nir_op_ior:
         return intrin->intrinsic != nir_intrinsic_exclusive_scan;

      default:
         return false;
      }
   }

   default:
      return false;
   }
}

static nir_def *
count_active_invocations(nir_builder *b, nir_def *value, bool inclusive,
                         bool has_mbcnt_amd)
{
   /* For the non-inclusive case, the two paths are functionally the same.
    * For the inclusive case, they are similar but very subtly different.
    *
    * The bit_count path masks "value" with the subgroup LE mask instead of
    * the subgroup LT mask. This is the definition of the inclusive count.
    *
    * AMD's mbcnt instruction always uses the subgroup LT mask. To perform
    * the inclusive count using mbcnt, two assumptions are made. First,
    * trivially, the current invocation is active. Second, the bit for the
    * current invocation in "value" is set. Since "value" is assumed to be
    * the result of ballot(true), the second condition will also be met.
    *
    * When those conditions are met, the inclusive count is the exclusive
    * count plus one.
    */
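   /* mbcnt_amd's second source is an addend, so passing 1 yields the
    * exclusive count plus one, i.e., the inclusive count.
    */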
   if (has_mbcnt_amd) {
      return nir_mbcnt_amd(b, value, nir_imm_int(b, (int) inclusive));
   } else {
      nir_def *mask = inclusive
         ? nir_load_subgroup_le_mask(b, 1, 32)
         : nir_load_subgroup_lt_mask(b, 1, 32);

      return nir_bit_count(b, nir_iand(b, value, mask));
   }
}

static nir_def *
opt_uniform_subgroup_instr(nir_builder *b, nir_instr *instr, void *_state)
{
   const nir_lower_subgroups_options *options =
      (nir_lower_subgroups_options *) _state;
   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

   if (intrin->intrinsic == nir_intrinsic_reduce ||
       intrin->intrinsic == nir_intrinsic_inclusive_scan ||
       intrin->intrinsic == nir_intrinsic_exclusive_scan) {
      const nir_op reduction_op = (nir_op) nir_intrinsic_reduction_op(intrin);

      if (reduction_op == nir_op_iadd ||
          reduction_op == nir_op_fadd ||
          reduction_op == nir_op_ixor) {
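         /* With a uniform source x, the result depends only on the number
          * of invocations counted:
          *
          *    iadd: count * x
          *    fadd: float(count) * x
          *    ixor: (count & 1) ? x : 0
          */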
         nir_def *count;

         nir_def *ballot = nir_ballot(b, options->ballot_components,
                                      options->ballot_bit_size,
                                      nir_imm_true(b));

         if (intrin->intrinsic == nir_intrinsic_reduce) {
            count = nir_bit_count(b, ballot);
         } else {
            count = count_active_invocations(b, ballot,
                                             intrin->intrinsic == nir_intrinsic_inclusive_scan,
                                             false);
         }

         const unsigned bit_size = intrin->src[0].ssa->bit_size;

         if (reduction_op == nir_op_iadd) {
            return nir_imul(b,
                            nir_u2uN(b, count, bit_size),
                            intrin->src[0].ssa);
         } else if (reduction_op == nir_op_fadd) {
            return nir_fmul(b,
                            nir_u2fN(b, count, bit_size),
                            intrin->src[0].ssa);
         } else {
            return nir_imul(b,
                            nir_u2uN(b,
                                     nir_iand(b, count, nir_imm_int(b, 1)),
                                     bit_size),
                            intrin->src[0].ssa);
         }
      }
   }

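   /* For every other filtered intrinsic (shuffles, broadcasts, quad swaps,
    * and votes), the operation on a uniform source is an identity: just
    * return the source.
    */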
   return intrin->src[0].ssa;
}

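/* A hypothetical call site (helper names assumed, not part of this file)
 * would run divergence analysis first so that nir_src_is_divergent() has
 * up-to-date results:
 *
 *    nir_divergence_analysis(shader);
 *    progress |= nir_opt_uniform_subgroup(shader, &subgroup_options);
 */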
bool
nir_opt_uniform_subgroup(nir_shader *shader,
                         const nir_lower_subgroups_options *options)
{
   bool progress = nir_shader_lower_instructions(shader,
                                                 opt_uniform_subgroup_filter,
                                                 opt_uniform_subgroup_instr,
                                                 (void *) options);

   return progress;
}