xref: /aosp_15_r20/external/angle/third_party/spirv-tools/src/source/val/validate_function.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 
17 #include "source/enum_string_mapping.h"
18 #include "source/opcode.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22 
23 namespace spvtools {
24 namespace val {
25 namespace {
26 
27 // Returns true if |a| and |b| are instructions defining pointers that point to
28 // types logically match and the decorations that apply to |b| are a subset
29 // of the decorations that apply to |a|.
DoPointeesLogicallyMatch(val::Instruction * a,val::Instruction * b,ValidationState_t & _)30 bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
31                               ValidationState_t& _) {
32   if (a->opcode() != spv::Op::OpTypePointer ||
33       b->opcode() != spv::Op::OpTypePointer) {
34     return false;
35   }
36 
37   const auto& dec_a = _.id_decorations(a->id());
38   const auto& dec_b = _.id_decorations(b->id());
39   for (const auto& dec : dec_b) {
40     if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
41       return false;
42     }
43   }
44 
45   uint32_t a_type = a->GetOperandAs<uint32_t>(2);
46   uint32_t b_type = b->GetOperandAs<uint32_t>(2);
47 
48   if (a_type == b_type) {
49     return true;
50   }
51 
52   Instruction* a_type_inst = _.FindDef(a_type);
53   Instruction* b_type_inst = _.FindDef(b_type);
54 
55   return _.LogicallyMatch(a_type_inst, b_type_inst, true);
56 }
57 
ValidateFunction(ValidationState_t & _,const Instruction * inst)58 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
59   const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
60   const auto function_type = _.FindDef(function_type_id);
61   if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) {
62     return _.diag(SPV_ERROR_INVALID_ID, inst)
63            << "OpFunction Function Type <id> " << _.getIdName(function_type_id)
64            << " is not a function type.";
65   }
66 
67   const auto return_id = function_type->GetOperandAs<uint32_t>(1);
68   if (return_id != inst->type_id()) {
69     return _.diag(SPV_ERROR_INVALID_ID, inst)
70            << "OpFunction Result Type <id> " << _.getIdName(inst->type_id())
71            << " does not match the Function Type's return type <id> "
72            << _.getIdName(return_id) << ".";
73   }
74 
75   const std::vector<spv::Op> acceptable = {
76       spv::Op::OpGroupDecorate,
77       spv::Op::OpDecorate,
78       spv::Op::OpEnqueueKernel,
79       spv::Op::OpEntryPoint,
80       spv::Op::OpExecutionMode,
81       spv::Op::OpExecutionModeId,
82       spv::Op::OpFunctionCall,
83       spv::Op::OpGetKernelNDrangeSubGroupCount,
84       spv::Op::OpGetKernelNDrangeMaxSubGroupSize,
85       spv::Op::OpGetKernelWorkGroupSize,
86       spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
87       spv::Op::OpGetKernelLocalSizeForSubgroupCount,
88       spv::Op::OpGetKernelMaxNumSubgroups,
89       spv::Op::OpName,
90       spv::Op::OpCooperativeMatrixPerElementOpNV,
91       spv::Op::OpCooperativeMatrixReduceNV,
92       spv::Op::OpCooperativeMatrixLoadTensorNV};
93   for (auto& pair : inst->uses()) {
94     const auto* use = pair.first;
95     if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
96             acceptable.end() &&
97         !use->IsNonSemantic() && !use->IsDebugInfo()) {
98       return _.diag(SPV_ERROR_INVALID_ID, use)
99              << "Invalid use of function result id " << _.getIdName(inst->id())
100              << ".";
101     }
102   }
103 
104   return SPV_SUCCESS;
105 }
106 
ValidateFunctionParameter(ValidationState_t & _,const Instruction * inst)107 spv_result_t ValidateFunctionParameter(ValidationState_t& _,
108                                        const Instruction* inst) {
109   // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
110   size_t param_index = 0;
111   size_t inst_num = inst->LineNum() - 1;
112   if (inst_num == 0) {
113     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
114            << "Function parameter cannot be the first instruction.";
115   }
116 
117   auto func_inst = &_.ordered_instructions()[inst_num];
118   while (--inst_num) {
119     func_inst = &_.ordered_instructions()[inst_num];
120     if (func_inst->opcode() == spv::Op::OpFunction) {
121       break;
122     } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) {
123       ++param_index;
124     }
125   }
126 
127   if (func_inst->opcode() != spv::Op::OpFunction) {
128     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
129            << "Function parameter must be preceded by a function.";
130   }
131 
132   const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
133   const auto function_type = _.FindDef(function_type_id);
134   if (!function_type) {
135     return _.diag(SPV_ERROR_INVALID_ID, func_inst)
136            << "Missing function type definition.";
137   }
138   if (param_index >= function_type->words().size() - 3) {
139     return _.diag(SPV_ERROR_INVALID_ID, inst)
140            << "Too many OpFunctionParameters for " << func_inst->id()
141            << ": expected " << function_type->words().size() - 3
142            << " based on the function's type";
143   }
144 
145   const auto param_type =
146       _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
147   if (!param_type || inst->type_id() != param_type->id()) {
148     return _.diag(SPV_ERROR_INVALID_ID, inst)
149            << "OpFunctionParameter Result Type <id> "
150            << _.getIdName(inst->type_id())
151            << " does not match the OpTypeFunction parameter "
152               "type of the same index.";
153   }
154 
155   // Validate that PhysicalStorageBuffer have one of Restrict, Aliased,
156   // RestrictPointer, or AliasedPointer.
157   auto param_nonarray_type_id = param_type->id();
158   while (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypeArray) {
159     param_nonarray_type_id =
160         _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
161   }
162   if (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypePointer ||
163       _.GetIdOpcode(param_nonarray_type_id) ==
164           spv::Op::OpTypeUntypedPointerKHR) {
165     auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
166     if (param_nonarray_type->GetOperandAs<spv::StorageClass>(1u) ==
167         spv::StorageClass::PhysicalStorageBuffer) {
168       // check for Aliased or Restrict
169       const auto& decorations = _.id_decorations(inst->id());
170 
171       bool foundAliased = std::any_of(
172           decorations.begin(), decorations.end(), [](const Decoration& d) {
173             return spv::Decoration::Aliased == d.dec_type();
174           });
175 
176       bool foundRestrict = std::any_of(
177           decorations.begin(), decorations.end(), [](const Decoration& d) {
178             return spv::Decoration::Restrict == d.dec_type();
179           });
180 
181       if (!foundAliased && !foundRestrict) {
182         return _.diag(SPV_ERROR_INVALID_ID, inst)
183                << "OpFunctionParameter " << inst->id()
184                << ": expected Aliased or Restrict for PhysicalStorageBuffer "
185                   "pointer.";
186       }
187       if (foundAliased && foundRestrict) {
188         return _.diag(SPV_ERROR_INVALID_ID, inst)
189                << "OpFunctionParameter " << inst->id()
190                << ": can't specify both Aliased and Restrict for "
191                   "PhysicalStorageBuffer pointer.";
192       }
193     } else if (param_nonarray_type->opcode() == spv::Op::OpTypePointer) {
194       const auto pointee_type_id =
195           param_nonarray_type->GetOperandAs<uint32_t>(2);
196       const auto pointee_type = _.FindDef(pointee_type_id);
197       if (spv::Op::OpTypePointer == pointee_type->opcode() &&
198           pointee_type->GetOperandAs<spv::StorageClass>(1u) ==
199               spv::StorageClass::PhysicalStorageBuffer) {
200         // check for AliasedPointer/RestrictPointer
201         const auto& decorations = _.id_decorations(inst->id());
202 
203         bool foundAliased = std::any_of(
204             decorations.begin(), decorations.end(), [](const Decoration& d) {
205               return spv::Decoration::AliasedPointer == d.dec_type();
206             });
207 
208         bool foundRestrict = std::any_of(
209             decorations.begin(), decorations.end(), [](const Decoration& d) {
210               return spv::Decoration::RestrictPointer == d.dec_type();
211             });
212 
213         if (!foundAliased && !foundRestrict) {
214           return _.diag(SPV_ERROR_INVALID_ID, inst)
215                  << "OpFunctionParameter " << inst->id()
216                  << ": expected AliasedPointer or RestrictPointer for "
217                     "PhysicalStorageBuffer pointer.";
218         }
219         if (foundAliased && foundRestrict) {
220           return _.diag(SPV_ERROR_INVALID_ID, inst)
221                  << "OpFunctionParameter " << inst->id()
222                  << ": can't specify both AliasedPointer and "
223                     "RestrictPointer for PhysicalStorageBuffer pointer.";
224         }
225       }
226     }
227   }
228 
229   return SPV_SUCCESS;
230 }
231 
ValidateFunctionCall(ValidationState_t & _,const Instruction * inst)232 spv_result_t ValidateFunctionCall(ValidationState_t& _,
233                                   const Instruction* inst) {
234   const auto function_id = inst->GetOperandAs<uint32_t>(2);
235   const auto function = _.FindDef(function_id);
236   if (!function || spv::Op::OpFunction != function->opcode()) {
237     return _.diag(SPV_ERROR_INVALID_ID, inst)
238            << "OpFunctionCall Function <id> " << _.getIdName(function_id)
239            << " is not a function.";
240   }
241 
242   auto return_type = _.FindDef(function->type_id());
243   if (!return_type || return_type->id() != inst->type_id()) {
244     return _.diag(SPV_ERROR_INVALID_ID, inst)
245            << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id())
246            << "s type does not match Function <id> "
247            << _.getIdName(return_type->id()) << "s return type.";
248   }
249 
250   const auto function_type_id = function->GetOperandAs<uint32_t>(3);
251   const auto function_type = _.FindDef(function_type_id);
252   if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
253     return _.diag(SPV_ERROR_INVALID_ID, inst)
254            << "Missing function type definition.";
255   }
256 
257   const auto function_call_arg_count = inst->words().size() - 4;
258   const auto function_param_count = function_type->words().size() - 3;
259   if (function_param_count != function_call_arg_count) {
260     return _.diag(SPV_ERROR_INVALID_ID, inst)
261            << "OpFunctionCall Function <id>'s parameter count does not match "
262               "the argument count.";
263   }
264 
265   for (size_t argument_index = 3, param_index = 2;
266        argument_index < inst->operands().size();
267        argument_index++, param_index++) {
268     const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
269     const auto argument = _.FindDef(argument_id);
270     if (!argument) {
271       return _.diag(SPV_ERROR_INVALID_ID, inst)
272              << "Missing argument " << argument_index - 3 << " definition.";
273     }
274 
275     const auto argument_type = _.FindDef(argument->type_id());
276     if (!argument_type) {
277       return _.diag(SPV_ERROR_INVALID_ID, inst)
278              << "Missing argument " << argument_index - 3
279              << " type definition.";
280     }
281 
282     const auto parameter_type_id =
283         function_type->GetOperandAs<uint32_t>(param_index);
284     const auto parameter_type = _.FindDef(parameter_type_id);
285     if (!parameter_type || argument_type->id() != parameter_type->id()) {
286       if (!parameter_type || !_.options()->before_hlsl_legalization ||
287           !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
288         return _.diag(SPV_ERROR_INVALID_ID, inst)
289                << "OpFunctionCall Argument <id> " << _.getIdName(argument_id)
290                << "s type does not match Function <id> "
291                << _.getIdName(parameter_type_id) << "s parameter type.";
292       }
293     }
294 
295     if (_.addressing_model() == spv::AddressingModel::Logical) {
296       if ((parameter_type->opcode() == spv::Op::OpTypePointer ||
297            parameter_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) &&
298           !_.options()->relax_logical_pointer) {
299         spv::StorageClass sc =
300             parameter_type->GetOperandAs<spv::StorageClass>(1u);
301         // Validate which storage classes can be pointer operands.
302         switch (sc) {
303           case spv::StorageClass::UniformConstant:
304           case spv::StorageClass::Function:
305           case spv::StorageClass::Private:
306           case spv::StorageClass::Workgroup:
307           case spv::StorageClass::AtomicCounter:
308             // These are always allowed.
309             break;
310           case spv::StorageClass::StorageBuffer:
311             if (!_.features().variable_pointers) {
312               return _.diag(SPV_ERROR_INVALID_ID, inst)
313                      << "StorageBuffer pointer operand "
314                      << _.getIdName(argument_id)
315                      << " requires a variable pointers capability";
316             }
317             break;
318           default:
319             return _.diag(SPV_ERROR_INVALID_ID, inst)
320                    << "Invalid storage class for pointer operand "
321                    << _.getIdName(argument_id);
322         }
323 
324         // Validate memory object declaration requirements.
325         if (argument->opcode() != spv::Op::OpVariable &&
326             argument->opcode() != spv::Op::OpUntypedVariableKHR &&
327             argument->opcode() != spv::Op::OpFunctionParameter) {
328           const bool ssbo_vptr =
329               _.HasCapability(spv::Capability::VariablePointersStorageBuffer) &&
330               sc == spv::StorageClass::StorageBuffer;
331           const bool wg_vptr =
332               _.HasCapability(spv::Capability::VariablePointers) &&
333               sc == spv::StorageClass::Workgroup;
334           const bool uc_ptr = sc == spv::StorageClass::UniformConstant;
335           if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
336             return _.diag(SPV_ERROR_INVALID_ID, inst)
337                    << "Pointer operand " << _.getIdName(argument_id)
338                    << " must be a memory object declaration";
339           }
340         }
341       }
342     }
343   }
344   return SPV_SUCCESS;
345 }
346 
ValidateCooperativeMatrixPerElementOp(ValidationState_t & _,const Instruction * inst)347 spv_result_t ValidateCooperativeMatrixPerElementOp(ValidationState_t& _,
348                                                    const Instruction* inst) {
349   const auto function_id = inst->GetOperandAs<uint32_t>(3);
350   const auto function = _.FindDef(function_id);
351   if (!function || spv::Op::OpFunction != function->opcode()) {
352     return _.diag(SPV_ERROR_INVALID_ID, inst)
353            << "OpCooperativeMatrixPerElementOpNV Function <id> "
354            << _.getIdName(function_id) << " is not a function.";
355   }
356 
357   const auto matrix_id = inst->GetOperandAs<uint32_t>(2);
358   const auto matrix = _.FindDef(matrix_id);
359   const auto matrix_type_id = matrix->type_id();
360   if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) {
361     return _.diag(SPV_ERROR_INVALID_ID, inst)
362            << "OpCooperativeMatrixPerElementOpNV Matrix <id> "
363            << _.getIdName(matrix_id) << " is not a cooperative matrix.";
364   }
365 
366   const auto result_type_id = inst->GetOperandAs<uint32_t>(0);
367   if (matrix_type_id != result_type_id) {
368     return _.diag(SPV_ERROR_INVALID_ID, inst)
369            << "OpCooperativeMatrixPerElementOpNV Result Type <id> "
370            << _.getIdName(result_type_id) << " must match matrix type <id> "
371            << _.getIdName(matrix_type_id) << ".";
372   }
373 
374   const auto matrix_comp_type_id =
375       _.FindDef(matrix_type_id)->GetOperandAs<uint32_t>(1);
376   const auto function_type_id = function->GetOperandAs<uint32_t>(3);
377   const auto function_type = _.FindDef(function_type_id);
378   auto return_type_id = function_type->GetOperandAs<uint32_t>(1);
379   if (return_type_id != matrix_comp_type_id) {
380     return _.diag(SPV_ERROR_INVALID_ID, inst)
381            << "OpCooperativeMatrixPerElementOpNV function return type <id> "
382            << _.getIdName(return_type_id)
383            << " must match matrix component type <id> "
384            << _.getIdName(matrix_comp_type_id) << ".";
385   }
386 
387   if (function_type->operands().size() < 5) {
388     return _.diag(SPV_ERROR_INVALID_ID, inst)
389            << "OpCooperativeMatrixPerElementOpNV function type <id> "
390            << _.getIdName(function_type_id)
391            << " must have a least three parameters.";
392   }
393 
394   const auto param0_id = function_type->GetOperandAs<uint32_t>(2);
395   const auto param1_id = function_type->GetOperandAs<uint32_t>(3);
396   const auto param2_id = function_type->GetOperandAs<uint32_t>(4);
397   if (!_.IsIntScalarType(param0_id) || _.GetBitWidth(param0_id) != 32) {
398     return _.diag(SPV_ERROR_INVALID_ID, inst)
399            << "OpCooperativeMatrixPerElementOpNV function type first parameter "
400               "type <id> "
401            << _.getIdName(param0_id) << " must be a 32-bit integer.";
402   }
403 
404   if (!_.IsIntScalarType(param1_id) || _.GetBitWidth(param1_id) != 32) {
405     return _.diag(SPV_ERROR_INVALID_ID, inst)
406            << "OpCooperativeMatrixPerElementOpNV function type second "
407               "parameter type <id> "
408            << _.getIdName(param1_id) << " must be a 32-bit integer.";
409   }
410 
411   if (param2_id != matrix_comp_type_id) {
412     return _.diag(SPV_ERROR_INVALID_ID, inst)
413            << "OpCooperativeMatrixPerElementOpNV function type third parameter "
414               "type <id> "
415            << _.getIdName(param2_id) << " must match matrix component type.";
416   }
417 
418   return SPV_SUCCESS;
419 }
420 
421 }  // namespace
422 
FunctionPass(ValidationState_t & _,const Instruction * inst)423 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
424   switch (inst->opcode()) {
425     case spv::Op::OpFunction:
426       if (auto error = ValidateFunction(_, inst)) return error;
427       break;
428     case spv::Op::OpFunctionParameter:
429       if (auto error = ValidateFunctionParameter(_, inst)) return error;
430       break;
431     case spv::Op::OpFunctionCall:
432       if (auto error = ValidateFunctionCall(_, inst)) return error;
433       break;
434     case spv::Op::OpCooperativeMatrixPerElementOpNV:
435       if (auto error = ValidateCooperativeMatrixPerElementOp(_, inst))
436         return error;
437       break;
438     default:
439       break;
440   }
441 
442   return SPV_SUCCESS;
443 }
444 
445 }  // namespace val
446 }  // namespace spvtools
447