1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <algorithm>
16
17 #include "source/enum_string_mapping.h"
18 #include "source/opcode.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22
23 namespace spvtools {
24 namespace val {
25 namespace {
26
27 // Returns true if |a| and |b| are instructions defining pointers that point to
28 // types logically match and the decorations that apply to |b| are a subset
29 // of the decorations that apply to |a|.
DoPointeesLogicallyMatch(val::Instruction * a,val::Instruction * b,ValidationState_t & _)30 bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
31 ValidationState_t& _) {
32 if (a->opcode() != spv::Op::OpTypePointer ||
33 b->opcode() != spv::Op::OpTypePointer) {
34 return false;
35 }
36
37 const auto& dec_a = _.id_decorations(a->id());
38 const auto& dec_b = _.id_decorations(b->id());
39 for (const auto& dec : dec_b) {
40 if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
41 return false;
42 }
43 }
44
45 uint32_t a_type = a->GetOperandAs<uint32_t>(2);
46 uint32_t b_type = b->GetOperandAs<uint32_t>(2);
47
48 if (a_type == b_type) {
49 return true;
50 }
51
52 Instruction* a_type_inst = _.FindDef(a_type);
53 Instruction* b_type_inst = _.FindDef(b_type);
54
55 return _.LogicallyMatch(a_type_inst, b_type_inst, true);
56 }
57
ValidateFunction(ValidationState_t & _,const Instruction * inst)58 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
59 const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
60 const auto function_type = _.FindDef(function_type_id);
61 if (!function_type || spv::Op::OpTypeFunction != function_type->opcode()) {
62 return _.diag(SPV_ERROR_INVALID_ID, inst)
63 << "OpFunction Function Type <id> " << _.getIdName(function_type_id)
64 << " is not a function type.";
65 }
66
67 const auto return_id = function_type->GetOperandAs<uint32_t>(1);
68 if (return_id != inst->type_id()) {
69 return _.diag(SPV_ERROR_INVALID_ID, inst)
70 << "OpFunction Result Type <id> " << _.getIdName(inst->type_id())
71 << " does not match the Function Type's return type <id> "
72 << _.getIdName(return_id) << ".";
73 }
74
75 const std::vector<spv::Op> acceptable = {
76 spv::Op::OpGroupDecorate,
77 spv::Op::OpDecorate,
78 spv::Op::OpEnqueueKernel,
79 spv::Op::OpEntryPoint,
80 spv::Op::OpExecutionMode,
81 spv::Op::OpExecutionModeId,
82 spv::Op::OpFunctionCall,
83 spv::Op::OpGetKernelNDrangeSubGroupCount,
84 spv::Op::OpGetKernelNDrangeMaxSubGroupSize,
85 spv::Op::OpGetKernelWorkGroupSize,
86 spv::Op::OpGetKernelPreferredWorkGroupSizeMultiple,
87 spv::Op::OpGetKernelLocalSizeForSubgroupCount,
88 spv::Op::OpGetKernelMaxNumSubgroups,
89 spv::Op::OpName,
90 spv::Op::OpCooperativeMatrixPerElementOpNV,
91 spv::Op::OpCooperativeMatrixReduceNV,
92 spv::Op::OpCooperativeMatrixLoadTensorNV};
93 for (auto& pair : inst->uses()) {
94 const auto* use = pair.first;
95 if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
96 acceptable.end() &&
97 !use->IsNonSemantic() && !use->IsDebugInfo()) {
98 return _.diag(SPV_ERROR_INVALID_ID, use)
99 << "Invalid use of function result id " << _.getIdName(inst->id())
100 << ".";
101 }
102 }
103
104 return SPV_SUCCESS;
105 }
106
ValidateFunctionParameter(ValidationState_t & _,const Instruction * inst)107 spv_result_t ValidateFunctionParameter(ValidationState_t& _,
108 const Instruction* inst) {
109 // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
110 size_t param_index = 0;
111 size_t inst_num = inst->LineNum() - 1;
112 if (inst_num == 0) {
113 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
114 << "Function parameter cannot be the first instruction.";
115 }
116
117 auto func_inst = &_.ordered_instructions()[inst_num];
118 while (--inst_num) {
119 func_inst = &_.ordered_instructions()[inst_num];
120 if (func_inst->opcode() == spv::Op::OpFunction) {
121 break;
122 } else if (func_inst->opcode() == spv::Op::OpFunctionParameter) {
123 ++param_index;
124 }
125 }
126
127 if (func_inst->opcode() != spv::Op::OpFunction) {
128 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
129 << "Function parameter must be preceded by a function.";
130 }
131
132 const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
133 const auto function_type = _.FindDef(function_type_id);
134 if (!function_type) {
135 return _.diag(SPV_ERROR_INVALID_ID, func_inst)
136 << "Missing function type definition.";
137 }
138 if (param_index >= function_type->words().size() - 3) {
139 return _.diag(SPV_ERROR_INVALID_ID, inst)
140 << "Too many OpFunctionParameters for " << func_inst->id()
141 << ": expected " << function_type->words().size() - 3
142 << " based on the function's type";
143 }
144
145 const auto param_type =
146 _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
147 if (!param_type || inst->type_id() != param_type->id()) {
148 return _.diag(SPV_ERROR_INVALID_ID, inst)
149 << "OpFunctionParameter Result Type <id> "
150 << _.getIdName(inst->type_id())
151 << " does not match the OpTypeFunction parameter "
152 "type of the same index.";
153 }
154
155 // Validate that PhysicalStorageBuffer have one of Restrict, Aliased,
156 // RestrictPointer, or AliasedPointer.
157 auto param_nonarray_type_id = param_type->id();
158 while (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypeArray) {
159 param_nonarray_type_id =
160 _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
161 }
162 if (_.GetIdOpcode(param_nonarray_type_id) == spv::Op::OpTypePointer ||
163 _.GetIdOpcode(param_nonarray_type_id) ==
164 spv::Op::OpTypeUntypedPointerKHR) {
165 auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
166 if (param_nonarray_type->GetOperandAs<spv::StorageClass>(1u) ==
167 spv::StorageClass::PhysicalStorageBuffer) {
168 // check for Aliased or Restrict
169 const auto& decorations = _.id_decorations(inst->id());
170
171 bool foundAliased = std::any_of(
172 decorations.begin(), decorations.end(), [](const Decoration& d) {
173 return spv::Decoration::Aliased == d.dec_type();
174 });
175
176 bool foundRestrict = std::any_of(
177 decorations.begin(), decorations.end(), [](const Decoration& d) {
178 return spv::Decoration::Restrict == d.dec_type();
179 });
180
181 if (!foundAliased && !foundRestrict) {
182 return _.diag(SPV_ERROR_INVALID_ID, inst)
183 << "OpFunctionParameter " << inst->id()
184 << ": expected Aliased or Restrict for PhysicalStorageBuffer "
185 "pointer.";
186 }
187 if (foundAliased && foundRestrict) {
188 return _.diag(SPV_ERROR_INVALID_ID, inst)
189 << "OpFunctionParameter " << inst->id()
190 << ": can't specify both Aliased and Restrict for "
191 "PhysicalStorageBuffer pointer.";
192 }
193 } else if (param_nonarray_type->opcode() == spv::Op::OpTypePointer) {
194 const auto pointee_type_id =
195 param_nonarray_type->GetOperandAs<uint32_t>(2);
196 const auto pointee_type = _.FindDef(pointee_type_id);
197 if (spv::Op::OpTypePointer == pointee_type->opcode() &&
198 pointee_type->GetOperandAs<spv::StorageClass>(1u) ==
199 spv::StorageClass::PhysicalStorageBuffer) {
200 // check for AliasedPointer/RestrictPointer
201 const auto& decorations = _.id_decorations(inst->id());
202
203 bool foundAliased = std::any_of(
204 decorations.begin(), decorations.end(), [](const Decoration& d) {
205 return spv::Decoration::AliasedPointer == d.dec_type();
206 });
207
208 bool foundRestrict = std::any_of(
209 decorations.begin(), decorations.end(), [](const Decoration& d) {
210 return spv::Decoration::RestrictPointer == d.dec_type();
211 });
212
213 if (!foundAliased && !foundRestrict) {
214 return _.diag(SPV_ERROR_INVALID_ID, inst)
215 << "OpFunctionParameter " << inst->id()
216 << ": expected AliasedPointer or RestrictPointer for "
217 "PhysicalStorageBuffer pointer.";
218 }
219 if (foundAliased && foundRestrict) {
220 return _.diag(SPV_ERROR_INVALID_ID, inst)
221 << "OpFunctionParameter " << inst->id()
222 << ": can't specify both AliasedPointer and "
223 "RestrictPointer for PhysicalStorageBuffer pointer.";
224 }
225 }
226 }
227 }
228
229 return SPV_SUCCESS;
230 }
231
ValidateFunctionCall(ValidationState_t & _,const Instruction * inst)232 spv_result_t ValidateFunctionCall(ValidationState_t& _,
233 const Instruction* inst) {
234 const auto function_id = inst->GetOperandAs<uint32_t>(2);
235 const auto function = _.FindDef(function_id);
236 if (!function || spv::Op::OpFunction != function->opcode()) {
237 return _.diag(SPV_ERROR_INVALID_ID, inst)
238 << "OpFunctionCall Function <id> " << _.getIdName(function_id)
239 << " is not a function.";
240 }
241
242 auto return_type = _.FindDef(function->type_id());
243 if (!return_type || return_type->id() != inst->type_id()) {
244 return _.diag(SPV_ERROR_INVALID_ID, inst)
245 << "OpFunctionCall Result Type <id> " << _.getIdName(inst->type_id())
246 << "s type does not match Function <id> "
247 << _.getIdName(return_type->id()) << "s return type.";
248 }
249
250 const auto function_type_id = function->GetOperandAs<uint32_t>(3);
251 const auto function_type = _.FindDef(function_type_id);
252 if (!function_type || function_type->opcode() != spv::Op::OpTypeFunction) {
253 return _.diag(SPV_ERROR_INVALID_ID, inst)
254 << "Missing function type definition.";
255 }
256
257 const auto function_call_arg_count = inst->words().size() - 4;
258 const auto function_param_count = function_type->words().size() - 3;
259 if (function_param_count != function_call_arg_count) {
260 return _.diag(SPV_ERROR_INVALID_ID, inst)
261 << "OpFunctionCall Function <id>'s parameter count does not match "
262 "the argument count.";
263 }
264
265 for (size_t argument_index = 3, param_index = 2;
266 argument_index < inst->operands().size();
267 argument_index++, param_index++) {
268 const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
269 const auto argument = _.FindDef(argument_id);
270 if (!argument) {
271 return _.diag(SPV_ERROR_INVALID_ID, inst)
272 << "Missing argument " << argument_index - 3 << " definition.";
273 }
274
275 const auto argument_type = _.FindDef(argument->type_id());
276 if (!argument_type) {
277 return _.diag(SPV_ERROR_INVALID_ID, inst)
278 << "Missing argument " << argument_index - 3
279 << " type definition.";
280 }
281
282 const auto parameter_type_id =
283 function_type->GetOperandAs<uint32_t>(param_index);
284 const auto parameter_type = _.FindDef(parameter_type_id);
285 if (!parameter_type || argument_type->id() != parameter_type->id()) {
286 if (!parameter_type || !_.options()->before_hlsl_legalization ||
287 !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
288 return _.diag(SPV_ERROR_INVALID_ID, inst)
289 << "OpFunctionCall Argument <id> " << _.getIdName(argument_id)
290 << "s type does not match Function <id> "
291 << _.getIdName(parameter_type_id) << "s parameter type.";
292 }
293 }
294
295 if (_.addressing_model() == spv::AddressingModel::Logical) {
296 if ((parameter_type->opcode() == spv::Op::OpTypePointer ||
297 parameter_type->opcode() == spv::Op::OpTypeUntypedPointerKHR) &&
298 !_.options()->relax_logical_pointer) {
299 spv::StorageClass sc =
300 parameter_type->GetOperandAs<spv::StorageClass>(1u);
301 // Validate which storage classes can be pointer operands.
302 switch (sc) {
303 case spv::StorageClass::UniformConstant:
304 case spv::StorageClass::Function:
305 case spv::StorageClass::Private:
306 case spv::StorageClass::Workgroup:
307 case spv::StorageClass::AtomicCounter:
308 // These are always allowed.
309 break;
310 case spv::StorageClass::StorageBuffer:
311 if (!_.features().variable_pointers) {
312 return _.diag(SPV_ERROR_INVALID_ID, inst)
313 << "StorageBuffer pointer operand "
314 << _.getIdName(argument_id)
315 << " requires a variable pointers capability";
316 }
317 break;
318 default:
319 return _.diag(SPV_ERROR_INVALID_ID, inst)
320 << "Invalid storage class for pointer operand "
321 << _.getIdName(argument_id);
322 }
323
324 // Validate memory object declaration requirements.
325 if (argument->opcode() != spv::Op::OpVariable &&
326 argument->opcode() != spv::Op::OpUntypedVariableKHR &&
327 argument->opcode() != spv::Op::OpFunctionParameter) {
328 const bool ssbo_vptr =
329 _.HasCapability(spv::Capability::VariablePointersStorageBuffer) &&
330 sc == spv::StorageClass::StorageBuffer;
331 const bool wg_vptr =
332 _.HasCapability(spv::Capability::VariablePointers) &&
333 sc == spv::StorageClass::Workgroup;
334 const bool uc_ptr = sc == spv::StorageClass::UniformConstant;
335 if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
336 return _.diag(SPV_ERROR_INVALID_ID, inst)
337 << "Pointer operand " << _.getIdName(argument_id)
338 << " must be a memory object declaration";
339 }
340 }
341 }
342 }
343 }
344 return SPV_SUCCESS;
345 }
346
ValidateCooperativeMatrixPerElementOp(ValidationState_t & _,const Instruction * inst)347 spv_result_t ValidateCooperativeMatrixPerElementOp(ValidationState_t& _,
348 const Instruction* inst) {
349 const auto function_id = inst->GetOperandAs<uint32_t>(3);
350 const auto function = _.FindDef(function_id);
351 if (!function || spv::Op::OpFunction != function->opcode()) {
352 return _.diag(SPV_ERROR_INVALID_ID, inst)
353 << "OpCooperativeMatrixPerElementOpNV Function <id> "
354 << _.getIdName(function_id) << " is not a function.";
355 }
356
357 const auto matrix_id = inst->GetOperandAs<uint32_t>(2);
358 const auto matrix = _.FindDef(matrix_id);
359 const auto matrix_type_id = matrix->type_id();
360 if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) {
361 return _.diag(SPV_ERROR_INVALID_ID, inst)
362 << "OpCooperativeMatrixPerElementOpNV Matrix <id> "
363 << _.getIdName(matrix_id) << " is not a cooperative matrix.";
364 }
365
366 const auto result_type_id = inst->GetOperandAs<uint32_t>(0);
367 if (matrix_type_id != result_type_id) {
368 return _.diag(SPV_ERROR_INVALID_ID, inst)
369 << "OpCooperativeMatrixPerElementOpNV Result Type <id> "
370 << _.getIdName(result_type_id) << " must match matrix type <id> "
371 << _.getIdName(matrix_type_id) << ".";
372 }
373
374 const auto matrix_comp_type_id =
375 _.FindDef(matrix_type_id)->GetOperandAs<uint32_t>(1);
376 const auto function_type_id = function->GetOperandAs<uint32_t>(3);
377 const auto function_type = _.FindDef(function_type_id);
378 auto return_type_id = function_type->GetOperandAs<uint32_t>(1);
379 if (return_type_id != matrix_comp_type_id) {
380 return _.diag(SPV_ERROR_INVALID_ID, inst)
381 << "OpCooperativeMatrixPerElementOpNV function return type <id> "
382 << _.getIdName(return_type_id)
383 << " must match matrix component type <id> "
384 << _.getIdName(matrix_comp_type_id) << ".";
385 }
386
387 if (function_type->operands().size() < 5) {
388 return _.diag(SPV_ERROR_INVALID_ID, inst)
389 << "OpCooperativeMatrixPerElementOpNV function type <id> "
390 << _.getIdName(function_type_id)
391 << " must have a least three parameters.";
392 }
393
394 const auto param0_id = function_type->GetOperandAs<uint32_t>(2);
395 const auto param1_id = function_type->GetOperandAs<uint32_t>(3);
396 const auto param2_id = function_type->GetOperandAs<uint32_t>(4);
397 if (!_.IsIntScalarType(param0_id) || _.GetBitWidth(param0_id) != 32) {
398 return _.diag(SPV_ERROR_INVALID_ID, inst)
399 << "OpCooperativeMatrixPerElementOpNV function type first parameter "
400 "type <id> "
401 << _.getIdName(param0_id) << " must be a 32-bit integer.";
402 }
403
404 if (!_.IsIntScalarType(param1_id) || _.GetBitWidth(param1_id) != 32) {
405 return _.diag(SPV_ERROR_INVALID_ID, inst)
406 << "OpCooperativeMatrixPerElementOpNV function type second "
407 "parameter type <id> "
408 << _.getIdName(param1_id) << " must be a 32-bit integer.";
409 }
410
411 if (param2_id != matrix_comp_type_id) {
412 return _.diag(SPV_ERROR_INVALID_ID, inst)
413 << "OpCooperativeMatrixPerElementOpNV function type third parameter "
414 "type <id> "
415 << _.getIdName(param2_id) << " must match matrix component type.";
416 }
417
418 return SPV_SUCCESS;
419 }
420
421 } // namespace
422
FunctionPass(ValidationState_t & _,const Instruction * inst)423 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
424 switch (inst->opcode()) {
425 case spv::Op::OpFunction:
426 if (auto error = ValidateFunction(_, inst)) return error;
427 break;
428 case spv::Op::OpFunctionParameter:
429 if (auto error = ValidateFunctionParameter(_, inst)) return error;
430 break;
431 case spv::Op::OpFunctionCall:
432 if (auto error = ValidateFunctionCall(_, inst)) return error;
433 break;
434 case spv::Op::OpCooperativeMatrixPerElementOpNV:
435 if (auto error = ValidateCooperativeMatrixPerElementOp(_, inst))
436 return error;
437 break;
438 default:
439 break;
440 }
441
442 return SPV_SUCCESS;
443 }
444
445 } // namespace val
446 } // namespace spvtools
447