xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_opt_uniform_subgroup.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright 2023 Intel Corporation
3  * SPDX-License-Identifier: MIT
4  */
5 
6 /**
7  * \file
8  * Optimize subgroup operations with uniform sources.
9  */
10 
11 #include "nir/nir.h"
12 #include "nir/nir_builder.h"
13 
14 static bool
opt_uniform_subgroup_filter(const nir_instr * instr,const void * _state)15 opt_uniform_subgroup_filter(const nir_instr *instr, const void *_state)
16 {
17    if (instr->type != nir_instr_type_intrinsic)
18       return false;
19 
20    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
21 
22    switch (intrin->intrinsic) {
23    case nir_intrinsic_shuffle:
24    case nir_intrinsic_read_invocation:
25    case nir_intrinsic_read_first_invocation:
26    case nir_intrinsic_quad_broadcast:
27    case nir_intrinsic_quad_swap_horizontal:
28    case nir_intrinsic_quad_swap_vertical:
29    case nir_intrinsic_quad_swap_diagonal:
30    case nir_intrinsic_quad_swizzle_amd:
31    case nir_intrinsic_masked_swizzle_amd:
32    case nir_intrinsic_vote_all:
33    case nir_intrinsic_vote_any:
34       return !nir_src_is_divergent(intrin->src[0]);
35 
36    case nir_intrinsic_reduce:
37    case nir_intrinsic_exclusive_scan:
38    case nir_intrinsic_inclusive_scan: {
39       if (nir_src_is_divergent(intrin->src[0]))
40          return false;
41 
42       const nir_op reduction_op = (nir_op) nir_intrinsic_reduction_op(intrin);
43 
44       switch (reduction_op) {
45       case nir_op_iadd:
46       case nir_op_fadd:
47       case nir_op_ixor:
48          return true;
49 
50       case nir_op_imin:
51       case nir_op_umin:
52       case nir_op_fmin:
53       case nir_op_imax:
54       case nir_op_umax:
55       case nir_op_fmax:
56       case nir_op_iand:
57       case nir_op_ior:
58          return intrin->intrinsic != nir_intrinsic_exclusive_scan;
59 
60       default:
61          return false;
62       }
63    }
64 
65    default:
66       return false;
67    }
68 }
69 
70 static nir_def *
count_active_invocations(nir_builder * b,nir_def * value,bool inclusive,bool has_mbcnt_amd)71 count_active_invocations(nir_builder *b, nir_def *value, bool inclusive,
72                          bool has_mbcnt_amd)
73 {
74    /* For the non-inclusive case, the two paths are functionally the same.
75     * For the inclusive case, the are similar but very subtly different.
76     *
77     * The bit_count path will mask "value" with the subgroup LE mask instead
78     * of the subgroup LT mask. This is the definition of the inclusive count.
79     *
80     * AMD's mbcnt instruction always uses the subgroup LT mask. To perform the
81     * inclusive count using mbcnt, two assumptions are made. First, trivially,
82     * the current invocation is active. Second, the bit for the current
83     * invocation in "value" is set.  Since "value" is assumed to be the result
84     * of ballot(true), the second condition will also be met.
85     *
86     * When those conditions are met, the inclusive count is the exclusive
87     * count plus one.
88     */
89    if (has_mbcnt_amd) {
90       return nir_mbcnt_amd(b, value, nir_imm_int(b, (int) inclusive));
91    } else {
92       nir_def *mask = inclusive
93          ? nir_load_subgroup_le_mask(b, 1, 32)
94          : nir_load_subgroup_lt_mask(b, 1, 32);
95 
96       return nir_bit_count(b, nir_iand(b, value, mask));
97    }
98 }
99 
100 static nir_def *
opt_uniform_subgroup_instr(nir_builder * b,nir_instr * instr,void * _state)101 opt_uniform_subgroup_instr(nir_builder *b, nir_instr *instr, void *_state)
102 {
103    const nir_lower_subgroups_options *options = (nir_lower_subgroups_options *) _state;
104    nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
105 
106    if (intrin->intrinsic == nir_intrinsic_reduce ||
107        intrin->intrinsic == nir_intrinsic_inclusive_scan ||
108        intrin->intrinsic == nir_intrinsic_exclusive_scan) {
109       const nir_op reduction_op = (nir_op) nir_intrinsic_reduction_op(intrin);
110 
111       if (reduction_op == nir_op_iadd ||
112           reduction_op == nir_op_fadd ||
113           reduction_op == nir_op_ixor) {
114          nir_def *count;
115 
116          nir_def *ballot = nir_ballot(b, options->ballot_components,
117                                       options->ballot_bit_size, nir_imm_true(b));
118 
119          if (intrin->intrinsic == nir_intrinsic_reduce) {
120             count = nir_bit_count(b, ballot);
121          } else {
122             count = count_active_invocations(b, ballot,
123                                              intrin->intrinsic == nir_intrinsic_inclusive_scan,
124                                              false);
125          }
126 
127          const unsigned bit_size = intrin->src[0].ssa->bit_size;
128 
129          if (reduction_op == nir_op_iadd) {
130             return nir_imul(b,
131                             nir_u2uN(b, count, bit_size),
132                             intrin->src[0].ssa);
133          } else if (reduction_op == nir_op_fadd) {
134             return nir_fmul(b,
135                             nir_u2fN(b, count, bit_size),
136                             intrin->src[0].ssa);
137          } else {
138             return nir_imul(b,
139                             nir_u2uN(b,
140                                      nir_iand(b, count, nir_imm_int(b, 1)),
141                                      bit_size),
142                             intrin->src[0].ssa);
143          }
144       }
145    }
146 
147    return intrin->src[0].ssa;
148 }
149 
150 bool
nir_opt_uniform_subgroup(nir_shader * shader,const nir_lower_subgroups_options * options)151 nir_opt_uniform_subgroup(nir_shader *shader,
152                          const nir_lower_subgroups_options *options)
153 {
154    bool progress = nir_shader_lower_instructions(shader,
155                                                  opt_uniform_subgroup_filter,
156                                                  opt_uniform_subgroup_instr,
157                                                  (void *) options);
158 
159    return progress;
160 }
161 
162