xref: /aosp_15_r20/external/angle/third_party/spirv-tools/src/source/val/validate_derivatives.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates correctness of derivative SPIR-V instructions.
16 
17 #include <string>
18 
19 #include "source/opcode.h"
20 #include "source/val/instruction.h"
21 #include "source/val/validate.h"
22 #include "source/val/validation_state.h"
23 
24 namespace spvtools {
25 namespace val {
26 
27 // Validates correctness of derivative instructions.
DerivativesPass(ValidationState_t & _,const Instruction * inst)28 spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
29   const spv::Op opcode = inst->opcode();
30   const uint32_t result_type = inst->type_id();
31 
32   switch (opcode) {
33     case spv::Op::OpDPdx:
34     case spv::Op::OpDPdy:
35     case spv::Op::OpFwidth:
36     case spv::Op::OpDPdxFine:
37     case spv::Op::OpDPdyFine:
38     case spv::Op::OpFwidthFine:
39     case spv::Op::OpDPdxCoarse:
40     case spv::Op::OpDPdyCoarse:
41     case spv::Op::OpFwidthCoarse: {
42       if (!_.IsFloatScalarOrVectorType(result_type)) {
43         return _.diag(SPV_ERROR_INVALID_DATA, inst)
44                << "Expected Result Type to be float scalar or vector type: "
45                << spvOpcodeString(opcode);
46       }
47       if (!_.ContainsSizedIntOrFloatType(result_type, spv::Op::OpTypeFloat,
48                                          32)) {
49         return _.diag(SPV_ERROR_INVALID_DATA, inst)
50                << "Result type component width must be 32 bits";
51       }
52 
53       const uint32_t p_type = _.GetOperandTypeId(inst, 2);
54       if (p_type != result_type) {
55         return _.diag(SPV_ERROR_INVALID_DATA, inst)
56                << "Expected P type and Result Type to be the same: "
57                << spvOpcodeString(opcode);
58       }
59       _.function(inst->function()->id())
60           ->RegisterExecutionModelLimitation([opcode](spv::ExecutionModel model,
61                                                       std::string* message) {
62             if (model != spv::ExecutionModel::Fragment &&
63                 model != spv::ExecutionModel::GLCompute &&
64                 model != spv::ExecutionModel::MeshEXT &&
65                 model != spv::ExecutionModel::TaskEXT) {
66               if (message) {
67                 *message =
68                     std::string(
69                         "Derivative instructions require Fragment, GLCompute, "
70                         "MeshEXT or TaskEXT execution model: ") +
71                     spvOpcodeString(opcode);
72               }
73               return false;
74             }
75             return true;
76           });
77       _.function(inst->function()->id())
78           ->RegisterLimitation([opcode](const ValidationState_t& state,
79                                         const Function* entry_point,
80                                         std::string* message) {
81             const auto* models = state.GetExecutionModels(entry_point->id());
82             const auto* modes = state.GetExecutionModes(entry_point->id());
83             if (models &&
84                 (models->find(spv::ExecutionModel::GLCompute) !=
85                      models->end() ||
86                  models->find(spv::ExecutionModel::MeshEXT) != models->end() ||
87                  models->find(spv::ExecutionModel::TaskEXT) != models->end()) &&
88                 (!modes ||
89                  (modes->find(spv::ExecutionMode::DerivativeGroupLinearKHR) ==
90                       modes->end() &&
91                   modes->find(spv::ExecutionMode::DerivativeGroupQuadsKHR) ==
92                       modes->end()))) {
93               if (message) {
94                 *message =
95                     std::string(
96                         "Derivative instructions require "
97                         "DerivativeGroupQuadsKHR "
98                         "or DerivativeGroupLinearKHR execution mode for "
99                         "GLCompute, MeshEXT or TaskEXT execution model: ") +
100                     spvOpcodeString(opcode);
101               }
102               return false;
103             }
104             return true;
105           });
106       break;
107     }
108 
109     default:
110       break;
111   }
112 
113   return SPV_SUCCESS;
114 }
115 
116 }  // namespace val
117 }  // namespace spvtools
118