1 //
2 // Copyright (C) 2014 LunarG, Inc.
3 // Copyright (C) 2015-2018 Google, Inc.
4 //
5 // All rights reserved.
6 //
7 // Redistribution and use in source and binary forms, with or without
8 // modification, are permitted provided that the following conditions
9 // are met:
10 //
11 //    Redistributions of source code must retain the above copyright
12 //    notice, this list of conditions and the following disclaimer.
13 //
14 //    Redistributions in binary form must reproduce the above
15 //    copyright notice, this list of conditions and the following
16 //    disclaimer in the documentation and/or other materials provided
17 //    with the distribution.
18 //
19 //    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
20 //    contributors may be used to endorse or promote products derived
21 //    from this software without specific prior written permission.
22 //
23 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26 // FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27 // COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29 // BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30 // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32 // LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34 // POSSIBILITY OF SUCH DAMAGE.
35 
36 // SPIRV-IR
37 //
// Simple in-memory representation (IR) of SPIR-V, just for holding
// each function's CFG of blocks.  Has this hierarchy:
40 //  - Module, which is a list of
41 //    - Function, which is a list of
42 //      - Block, which is a list of
43 //        - Instruction
44 //
45 
46 #pragma once
47 #ifndef spvIR_H
48 #define spvIR_H
49 
#include "spirv.hpp"

#include <algorithm>
#include <cassert>
#include <functional>
#include <iostream>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <vector>
60 
61 namespace spv {
62 
63 class Block;
64 class Function;
65 class Module;
66 
67 const Id NoResult = 0;
68 const Id NoType = 0;
69 
70 const Decoration NoPrecision = DecorationMax;
71 
// Marks a declaration as possibly unused, silencing unused-variable warnings on
// GCC-compatible compilers; expands to nothing on other compilers.
#if defined(__GNUC__)
#   define POTENTIALLY_UNUSED __attribute__((unused))
#else
#   define POTENTIALLY_UNUSED
#endif
77 
78 POTENTIALLY_UNUSED
79 const MemorySemanticsMask MemorySemanticsAllMemory =
80                 (MemorySemanticsMask)(MemorySemanticsUniformMemoryMask |
81                                       MemorySemanticsWorkgroupMemoryMask |
82                                       MemorySemanticsAtomicCounterMemoryMask |
83                                       MemorySemanticsImageMemoryMask);
84 
// A single operand word, tagged with whether it is an <id> or a literal immediate.
struct IdImmediate {
    bool isId;      // true if word is an Id, false if word is an immediate
    unsigned word;  // the operand value itself
    IdImmediate(bool isIdWord, unsigned value) : isId{isIdWord}, word{value} {}
};
90 
91 //
92 // SPIR-V IR instruction.
93 //
94 
95 class Instruction {
96 public:
Instruction(Id resultId,Id typeId,Op opCode)97     Instruction(Id resultId, Id typeId, Op opCode) : resultId(resultId), typeId(typeId), opCode(opCode), block(nullptr) { }
Instruction(Op opCode)98     explicit Instruction(Op opCode) : resultId(NoResult), typeId(NoType), opCode(opCode), block(nullptr) { }
~Instruction()99     virtual ~Instruction() {}
addIdOperand(Id id)100     void addIdOperand(Id id) {
101         // ids can't be 0
102         assert(id);
103         operands.push_back(id);
104         idOperand.push_back(true);
105     }
addImmediateOperand(unsigned int immediate)106     void addImmediateOperand(unsigned int immediate) {
107         operands.push_back(immediate);
108         idOperand.push_back(false);
109     }
setImmediateOperand(unsigned idx,unsigned int immediate)110     void setImmediateOperand(unsigned idx, unsigned int immediate) {
111         assert(!idOperand[idx]);
112         operands[idx] = immediate;
113     }
114 
addStringOperand(const char * str)115     void addStringOperand(const char* str)
116     {
117         unsigned int word = 0;
118         unsigned int shiftAmount = 0;
119         char c;
120 
121         do {
122             c = *(str++);
123             word |= ((unsigned int)c) << shiftAmount;
124             shiftAmount += 8;
125             if (shiftAmount == 32) {
126                 addImmediateOperand(word);
127                 word = 0;
128                 shiftAmount = 0;
129             }
130         } while (c != 0);
131 
132         // deal with partial last word
133         if (shiftAmount > 0) {
134             addImmediateOperand(word);
135         }
136     }
isIdOperand(int op)137     bool isIdOperand(int op) const { return idOperand[op]; }
setBlock(Block * b)138     void setBlock(Block* b) { block = b; }
getBlock()139     Block* getBlock() const { return block; }
getOpCode()140     Op getOpCode() const { return opCode; }
getNumOperands()141     int getNumOperands() const
142     {
143         assert(operands.size() == idOperand.size());
144         return (int)operands.size();
145     }
getResultId()146     Id getResultId() const { return resultId; }
getTypeId()147     Id getTypeId() const { return typeId; }
getIdOperand(int op)148     Id getIdOperand(int op) const {
149         assert(idOperand[op]);
150         return operands[op];
151     }
getImmediateOperand(int op)152     unsigned int getImmediateOperand(int op) const {
153         assert(!idOperand[op]);
154         return operands[op];
155     }
156 
157     // Write out the binary form.
dump(std::vector<unsigned int> & out)158     void dump(std::vector<unsigned int>& out) const
159     {
160         // Compute the wordCount
161         unsigned int wordCount = 1;
162         if (typeId)
163             ++wordCount;
164         if (resultId)
165             ++wordCount;
166         wordCount += (unsigned int)operands.size();
167 
168         // Write out the beginning of the instruction
169         out.push_back(((wordCount) << WordCountShift) | opCode);
170         if (typeId)
171             out.push_back(typeId);
172         if (resultId)
173             out.push_back(resultId);
174 
175         // Write out the operands
176         for (int op = 0; op < (int)operands.size(); ++op)
177             out.push_back(operands[op]);
178     }
179 
180 protected:
181     Instruction(const Instruction&);
182     Id resultId;
183     Id typeId;
184     Op opCode;
185     std::vector<Id> operands;     // operands, both <id> and immediates (both are unsigned int)
186     std::vector<bool> idOperand;  // true for operands that are <id>, false for immediates
187     Block* block;
188 };
189 
190 //
191 // SPIR-V IR block.
192 //
193 
// Source position that Block uses to track the location of the last
// source-location marker instruction (see Block::updateDebugSourceLocation).
struct DebugSourceLocation {
    int line;
    int column;
    spv::Id fileId;  // presumably the <id> naming the source file -- TODO confirm
};
199 
200 class Block {
201 public:
202     Block(Id id, Function& parent);
~Block()203     virtual ~Block()
204     {
205     }
206 
getId()207     Id getId() { return instructions.front()->getResultId(); }
208 
getParent()209     Function& getParent() const { return parent; }
210     // Returns true if the source location is actually updated.
211     // Note we still need the builder to insert the line marker instruction. This is just a tracker.
updateDebugSourceLocation(int line,int column,spv::Id fileId)212     bool updateDebugSourceLocation(int line, int column, spv::Id fileId) {
213         if (currentSourceLoc && currentSourceLoc->line == line && currentSourceLoc->column == column &&
214             currentSourceLoc->fileId == fileId) {
215             return false;
216         }
217 
218         currentSourceLoc = DebugSourceLocation{line, column, fileId};
219         return true;
220     }
221     // Returns true if the scope is actually updated.
222     // Note we still need the builder to insert the debug scope instruction. This is just a tracker.
updateDebugScope(spv::Id scopeId)223     bool updateDebugScope(spv::Id scopeId) {
224         assert(scopeId);
225         if (currentDebugScope && *currentDebugScope == scopeId) {
226             return false;
227         }
228 
229         currentDebugScope = scopeId;
230         return true;
231     }
232     void addInstruction(std::unique_ptr<Instruction> inst);
addPredecessor(Block * pred)233     void addPredecessor(Block* pred) { predecessors.push_back(pred); pred->successors.push_back(this);}
addLocalVariable(std::unique_ptr<Instruction> inst)234     void addLocalVariable(std::unique_ptr<Instruction> inst) { localVariables.push_back(std::move(inst)); }
getPredecessors()235     const std::vector<Block*>& getPredecessors() const { return predecessors; }
getSuccessors()236     const std::vector<Block*>& getSuccessors() const { return successors; }
getInstructions()237     const std::vector<std::unique_ptr<Instruction> >& getInstructions() const {
238         return instructions;
239     }
getLocalVariables()240     const std::vector<std::unique_ptr<Instruction> >& getLocalVariables() const { return localVariables; }
setUnreachable()241     void setUnreachable() { unreachable = true; }
isUnreachable()242     bool isUnreachable() const { return unreachable; }
243     // Returns the block's merge instruction, if one exists (otherwise null).
getMergeInstruction()244     const Instruction* getMergeInstruction() const {
245         if (instructions.size() < 2) return nullptr;
246         const Instruction* nextToLast = (instructions.cend() - 2)->get();
247         switch (nextToLast->getOpCode()) {
248             case OpSelectionMerge:
249             case OpLoopMerge:
250                 return nextToLast;
251             default:
252                 return nullptr;
253         }
254         return nullptr;
255     }
256 
257     // Change this block into a canonical dead merge block.  Delete instructions
258     // as necessary.  A canonical dead merge block has only an OpLabel and an
259     // OpUnreachable.
rewriteAsCanonicalUnreachableMerge()260     void rewriteAsCanonicalUnreachableMerge() {
261         assert(localVariables.empty());
262         // Delete all instructions except for the label.
263         assert(instructions.size() > 0);
264         instructions.resize(1);
265         successors.clear();
266         addInstruction(std::unique_ptr<Instruction>(new Instruction(OpUnreachable)));
267     }
268     // Change this block into a canonical dead continue target branching to the
269     // given header ID.  Delete instructions as necessary.  A canonical dead continue
270     // target has only an OpLabel and an unconditional branch back to the corresponding
271     // header.
rewriteAsCanonicalUnreachableContinue(Block * header)272     void rewriteAsCanonicalUnreachableContinue(Block* header) {
273         assert(localVariables.empty());
274         // Delete all instructions except for the label.
275         assert(instructions.size() > 0);
276         instructions.resize(1);
277         successors.clear();
278         // Add OpBranch back to the header.
279         assert(header != nullptr);
280         Instruction* branch = new Instruction(OpBranch);
281         branch->addIdOperand(header->getId());
282         addInstruction(std::unique_ptr<Instruction>(branch));
283         successors.push_back(header);
284     }
285 
isTerminated()286     bool isTerminated() const
287     {
288         switch (instructions.back()->getOpCode()) {
289         case OpBranch:
290         case OpBranchConditional:
291         case OpSwitch:
292         case OpKill:
293         case OpTerminateInvocation:
294         case OpReturn:
295         case OpReturnValue:
296         case OpUnreachable:
297             return true;
298         default:
299             return false;
300         }
301     }
302 
dump(std::vector<unsigned int> & out)303     void dump(std::vector<unsigned int>& out) const
304     {
305         instructions[0]->dump(out);
306         for (int i = 0; i < (int)localVariables.size(); ++i)
307             localVariables[i]->dump(out);
308         for (int i = 1; i < (int)instructions.size(); ++i)
309             instructions[i]->dump(out);
310     }
311 
312 protected:
313     Block(const Block&);
314     Block& operator=(Block&);
315 
316     // To enforce keeping parent and ownership in sync:
317     friend Function;
318 
319     std::vector<std::unique_ptr<Instruction> > instructions;
320     std::vector<Block*> predecessors, successors;
321     std::vector<std::unique_ptr<Instruction> > localVariables;
322     Function& parent;
323 
324     // Track source location of the last source location marker instruction.
325     std::optional<DebugSourceLocation> currentSourceLoc;
326 
327     // Track scope of the last debug scope instruction.
328     std::optional<spv::Id> currentDebugScope;
329 
330     // track whether this block is known to be uncreachable (not necessarily
331     // true for all unreachable blocks, but should be set at least
332     // for the extraneous ones introduced by the builder).
333     bool unreachable;
334 };
335 
// Why the inReadableOrder traversal visited a block (passed to its callback).
enum ReachReason {
    ReachViaControlFlow = 0, // reachable from the entry block via branches
    ReachDeadContinue = 1,   // a continue target not reachable via control flow
    ReachDeadMerge = 2       // a merge block not reachable via control flow
};
345 
// Traverses the control-flow graph rooted at root in an order suited for
// readable code generation, invoking callback at every block in that order.
// Function::dump uses it to serialize a function's blocks.
// The callback arguments are:
// - the block,
// - the reason the block was reached (see ReachReason),
// - if that reason is ReachDeadContinue or ReachDeadMerge, the corresponding
//   header block (NOTE(review): presumably nullptr for ReachViaControlFlow --
//   confirm against the implementation, which is not in this header).
void inReadableOrder(Block* root, std::function<void(Block*, ReachReason, Block* header)> callback);
354 
355 //
356 // SPIR-V IR Function.
357 //
358 
359 class Function {
360 public:
361     Function(Id id, Id resultType, Id functionType, Id firstParam, LinkageType linkage, const std::string& name, Module& parent);
~Function()362     virtual ~Function()
363     {
364         for (int i = 0; i < (int)parameterInstructions.size(); ++i)
365             delete parameterInstructions[i];
366 
367         for (int i = 0; i < (int)blocks.size(); ++i)
368             delete blocks[i];
369     }
getId()370     Id getId() const { return functionInstruction.getResultId(); }
getParamId(int p)371     Id getParamId(int p) const { return parameterInstructions[p]->getResultId(); }
getParamType(int p)372     Id getParamType(int p) const { return parameterInstructions[p]->getTypeId(); }
373 
addBlock(Block * block)374     void addBlock(Block* block) { blocks.push_back(block); }
removeBlock(Block * block)375     void removeBlock(Block* block)
376     {
377         auto found = find(blocks.begin(), blocks.end(), block);
378         assert(found != blocks.end());
379         blocks.erase(found);
380         delete block;
381     }
382 
getParent()383     Module& getParent() const { return parent; }
getEntryBlock()384     Block* getEntryBlock() const { return blocks.front(); }
getLastBlock()385     Block* getLastBlock() const { return blocks.back(); }
getBlocks()386     const std::vector<Block*>& getBlocks() const { return blocks; }
387     void addLocalVariable(std::unique_ptr<Instruction> inst);
getReturnType()388     Id getReturnType() const { return functionInstruction.getTypeId(); }
getFuncId()389     Id getFuncId() const { return functionInstruction.getResultId(); }
getFuncTypeId()390     Id getFuncTypeId() const { return functionInstruction.getIdOperand(1); }
setReturnPrecision(Decoration precision)391     void setReturnPrecision(Decoration precision)
392     {
393         if (precision == DecorationRelaxedPrecision)
394             reducedPrecisionReturn = true;
395     }
getReturnPrecision()396     Decoration getReturnPrecision() const
397         { return reducedPrecisionReturn ? DecorationRelaxedPrecision : NoPrecision; }
398 
setDebugLineInfo(Id fileName,int line,int column)399     void setDebugLineInfo(Id fileName, int line, int column) {
400         lineInstruction = std::unique_ptr<Instruction>{new Instruction(OpLine)};
401         lineInstruction->addIdOperand(fileName);
402         lineInstruction->addImmediateOperand(line);
403         lineInstruction->addImmediateOperand(column);
404     }
hasDebugLineInfo()405     bool hasDebugLineInfo() const { return lineInstruction != nullptr; }
406 
setImplicitThis()407     void setImplicitThis() { implicitThis = true; }
hasImplicitThis()408     bool hasImplicitThis() const { return implicitThis; }
409 
addParamPrecision(unsigned param,Decoration precision)410     void addParamPrecision(unsigned param, Decoration precision)
411     {
412         if (precision == DecorationRelaxedPrecision)
413             reducedPrecisionParams.insert(param);
414     }
getParamPrecision(unsigned param)415     Decoration getParamPrecision(unsigned param) const
416     {
417         return reducedPrecisionParams.find(param) != reducedPrecisionParams.end() ?
418             DecorationRelaxedPrecision : NoPrecision;
419     }
420 
dump(std::vector<unsigned int> & out)421     void dump(std::vector<unsigned int>& out) const
422     {
423         // OpLine
424         if (lineInstruction != nullptr) {
425             lineInstruction->dump(out);
426         }
427 
428         // OpFunction
429         functionInstruction.dump(out);
430 
431         // OpFunctionParameter
432         for (int p = 0; p < (int)parameterInstructions.size(); ++p)
433             parameterInstructions[p]->dump(out);
434 
435         // Blocks
436         inReadableOrder(blocks[0], [&out](const Block* b, ReachReason, Block*) { b->dump(out); });
437         Instruction end(0, 0, OpFunctionEnd);
438         end.dump(out);
439     }
440 
getLinkType()441     LinkageType getLinkType() const { return linkType; }
getExportName()442     const char* getExportName() const { return exportName.c_str(); }
443 
444 protected:
445     Function(const Function&);
446     Function& operator=(Function&);
447 
448     Module& parent;
449     std::unique_ptr<Instruction> lineInstruction;
450     Instruction functionInstruction;
451     std::vector<Instruction*> parameterInstructions;
452     std::vector<Block*> blocks;
453     bool implicitThis;  // true if this is a member function expecting to be passed a 'this' as the first argument
454     bool reducedPrecisionReturn;
455     std::set<int> reducedPrecisionParams;  // list of parameter indexes that need a relaxed precision arg
456     LinkageType linkType;
457     std::string exportName;
458 };
459 
460 //
461 // SPIR-V IR Module.
462 //
463 
464 class Module {
465 public:
Module()466     Module() {}
~Module()467     virtual ~Module()
468     {
469         // TODO delete things
470     }
471 
addFunction(Function * fun)472     void addFunction(Function *fun) { functions.push_back(fun); }
473 
mapInstruction(Instruction * instruction)474     void mapInstruction(Instruction *instruction)
475     {
476         spv::Id resultId = instruction->getResultId();
477         // map the instruction's result id
478         if (resultId >= idToInstruction.size())
479             idToInstruction.resize(resultId + 16);
480         idToInstruction[resultId] = instruction;
481     }
482 
getInstruction(Id id)483     Instruction* getInstruction(Id id) const { return idToInstruction[id]; }
getFunctions()484     const std::vector<Function*>& getFunctions() const { return functions; }
getTypeId(Id resultId)485     spv::Id getTypeId(Id resultId) const {
486         return idToInstruction[resultId] == nullptr ? NoType : idToInstruction[resultId]->getTypeId();
487     }
getStorageClass(Id typeId)488     StorageClass getStorageClass(Id typeId) const
489     {
490         assert(idToInstruction[typeId]->getOpCode() == spv::OpTypePointer);
491         return (StorageClass)idToInstruction[typeId]->getImmediateOperand(0);
492     }
493 
dump(std::vector<unsigned int> & out)494     void dump(std::vector<unsigned int>& out) const
495     {
496         for (int f = 0; f < (int)functions.size(); ++f)
497             functions[f]->dump(out);
498     }
499 
500 protected:
501     Module(const Module&);
502     std::vector<Function*> functions;
503 
504     // map from result id to instruction having that result id
505     std::vector<Instruction*> idToInstruction;
506 
507     // map from a result id to its type id
508 };
509 
510 //
511 // Implementation (it's here due to circular type definitions).
512 //
513 
514 // Add both
515 // - the OpFunction instruction
516 // - all the OpFunctionParameter instructions
Function(Id id,Id resultType,Id functionType,Id firstParamId,LinkageType linkage,const std::string & name,Module & parent)517 __inline Function::Function(Id id, Id resultType, Id functionType, Id firstParamId, LinkageType linkage, const std::string& name, Module& parent)
518     : parent(parent), lineInstruction(nullptr),
519       functionInstruction(id, resultType, OpFunction), implicitThis(false),
520       reducedPrecisionReturn(false),
521       linkType(linkage)
522 {
523     // OpFunction
524     functionInstruction.addImmediateOperand(FunctionControlMaskNone);
525     functionInstruction.addIdOperand(functionType);
526     parent.mapInstruction(&functionInstruction);
527     parent.addFunction(this);
528 
529     // OpFunctionParameter
530     Instruction* typeInst = parent.getInstruction(functionType);
531     int numParams = typeInst->getNumOperands() - 1;
532     for (int p = 0; p < numParams; ++p) {
533         Instruction* param = new Instruction(firstParamId + p, typeInst->getIdOperand(p + 1), OpFunctionParameter);
534         parent.mapInstruction(param);
535         parameterInstructions.push_back(param);
536     }
537 
538     // If importing/exporting, save the function name (without the mangled parameters) for the linkage decoration
539     if (linkType != LinkageTypeMax) {
540         exportName = name.substr(0, name.find_first_of('('));
541     }
542 }
543 
addLocalVariable(std::unique_ptr<Instruction> inst)544 __inline void Function::addLocalVariable(std::unique_ptr<Instruction> inst)
545 {
546     Instruction* raw_instruction = inst.get();
547     blocks[0]->addLocalVariable(std::move(inst));
548     parent.mapInstruction(raw_instruction);
549 }
550 
Block(Id id,Function & parent)551 __inline Block::Block(Id id, Function& parent) : parent(parent), unreachable(false)
552 {
553     instructions.push_back(std::unique_ptr<Instruction>(new Instruction(id, NoType, OpLabel)));
554     instructions.back()->setBlock(this);
555     parent.getParent().mapInstruction(instructions.back().get());
556 }
557 
addInstruction(std::unique_ptr<Instruction> inst)558 __inline void Block::addInstruction(std::unique_ptr<Instruction> inst)
559 {
560     Instruction* raw_instruction = inst.get();
561     instructions.push_back(std::move(inst));
562     raw_instruction->setBlock(this);
563     if (raw_instruction->getResultId())
564         parent.getParent().mapInstruction(raw_instruction);
565 }
566 
567 }  // end spv namespace
568 
569 #endif // spvIR_H
570