1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering of LLVM calls to machine code calls for
10 // GlobalISel.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "SPIRVCallLowering.h"
15 #include "MCTargetDesc/SPIRVBaseInfo.h"
16 #include "SPIRV.h"
17 #include "SPIRVBuiltins.h"
18 #include "SPIRVGlobalRegistry.h"
19 #include "SPIRVISelLowering.h"
20 #include "SPIRVRegisterInfo.h"
21 #include "SPIRVSubtarget.h"
22 #include "SPIRVUtils.h"
23 #include "llvm/CodeGen/FunctionLoweringInfo.h"
24 #include "llvm/Support/ModRef.h"
25
26 using namespace llvm;
27
SPIRVCallLowering(const SPIRVTargetLowering & TLI,SPIRVGlobalRegistry * GR)28 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
29 SPIRVGlobalRegistry *GR)
30 : CallLowering(&TLI), GR(GR) {}
31
lowerReturn(MachineIRBuilder & MIRBuilder,const Value * Val,ArrayRef<Register> VRegs,FunctionLoweringInfo & FLI,Register SwiftErrorVReg) const32 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
33 const Value *Val, ArrayRef<Register> VRegs,
34 FunctionLoweringInfo &FLI,
35 Register SwiftErrorVReg) const {
36 // Currently all return types should use a single register.
37 // TODO: handle the case of multiple registers.
38 if (VRegs.size() > 1)
39 return false;
40 if (Val) {
41 const auto &STI = MIRBuilder.getMF().getSubtarget();
42 return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
43 .addUse(VRegs[0])
44 .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
45 *STI.getRegBankInfo());
46 }
47 MIRBuilder.buildInstr(SPIRV::OpReturn);
48 return true;
49 }
50
51 // Based on the LLVM function attributes, get a SPIR-V FunctionControl.
getFunctionControl(const Function & F)52 static uint32_t getFunctionControl(const Function &F) {
53 MemoryEffects MemEffects = F.getMemoryEffects();
54
55 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
56
57 if (F.hasFnAttribute(Attribute::AttrKind::NoInline))
58 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
59 else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline))
60 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
61
62 if (MemEffects.doesNotAccessMemory())
63 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
64 else if (MemEffects.onlyReadsMemory())
65 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
66
67 return FuncControl;
68 }
69
getConstInt(MDNode * MD,unsigned NumOp)70 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
71 if (MD->getNumOperands() > NumOp) {
72 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp));
73 if (CMeta)
74 return dyn_cast<ConstantInt>(CMeta->getValue());
75 }
76 return nullptr;
77 }
78
79 // This code restores function args/retvalue types for composite cases
80 // because the final types should still be aggregate whereas they're i32
81 // during the translation to cope with aggregate flattening etc.
getOriginalFunctionType(const Function & F)82 static FunctionType *getOriginalFunctionType(const Function &F) {
83 auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
84 if (NamedMD == nullptr)
85 return F.getFunctionType();
86
87 Type *RetTy = F.getFunctionType()->getReturnType();
88 SmallVector<Type *, 4> ArgTypes;
89 for (auto &Arg : F.args())
90 ArgTypes.push_back(Arg.getType());
91
92 auto ThisFuncMDIt =
93 std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
94 return isa<MDString>(N->getOperand(0)) &&
95 cast<MDString>(N->getOperand(0))->getString() == F.getName();
96 });
97 // TODO: probably one function can have numerous type mutations,
98 // so we should support this.
99 if (ThisFuncMDIt != NamedMD->op_end()) {
100 auto *ThisFuncMD = *ThisFuncMDIt;
101 MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
102 assert(MD && "MDNode operand is expected");
103 ConstantInt *Const = getConstInt(MD, 0);
104 if (Const) {
105 auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
106 assert(CMeta && "ConstantAsMetadata operand is expected");
107 assert(Const->getSExtValue() >= -1);
108 // Currently -1 indicates return value, greater values mean
109 // argument numbers.
110 if (Const->getSExtValue() == -1)
111 RetTy = CMeta->getType();
112 else
113 ArgTypes[Const->getSExtValue()] = CMeta->getType();
114 }
115 }
116
117 return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
118 }
119
getKernelArgAttribute(const Function & KernelFunction,unsigned ArgIdx,const StringRef AttributeName)120 static MDString *getKernelArgAttribute(const Function &KernelFunction,
121 unsigned ArgIdx,
122 const StringRef AttributeName) {
123 assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
124 "Kernel attributes are attached/belong only to kernel functions");
125
126 // Lookup the argument attribute in metadata attached to the kernel function.
127 MDNode *Node = KernelFunction.getMetadata(AttributeName);
128 if (Node && ArgIdx < Node->getNumOperands())
129 return cast<MDString>(Node->getOperand(ArgIdx));
130
131 // Sometimes metadata containing kernel attributes is not attached to the
132 // function, but can be found in the named module-level metadata instead.
133 // For example:
134 // !opencl.kernels = !{!0}
135 // !0 = !{void ()* @someKernelFunction, !1, ...}
136 // !1 = !{!"kernel_arg_addr_space", ...}
137 // In this case the actual index of searched argument attribute is ArgIdx + 1,
138 // since the first metadata node operand is occupied by attribute name
139 // ("kernel_arg_addr_space" in the example above).
140 unsigned MDArgIdx = ArgIdx + 1;
141 NamedMDNode *OpenCLKernelsMD =
142 KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
143 if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
144 return nullptr;
145
146 // KernelToMDNodeList contains kernel function declarations followed by
147 // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
148 // to the currently lowered kernel function.
149 MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
150 bool FoundLoweredKernelFunction = false;
151 for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
152 ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
153 if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
154 KernelFunction.getName()) {
155 FoundLoweredKernelFunction = true;
156 continue;
157 }
158 if (MaybeValue && FoundLoweredKernelFunction)
159 return nullptr;
160
161 MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
162 if (FoundLoweredKernelFunction && MaybeNode &&
163 cast<MDString>(MaybeNode->getOperand(0))->getString() ==
164 AttributeName &&
165 MDArgIdx < MaybeNode->getNumOperands())
166 return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
167 }
168 return nullptr;
169 }
170
171 static SPIRV::AccessQualifier::AccessQualifier
getArgAccessQual(const Function & F,unsigned ArgIdx)172 getArgAccessQual(const Function &F, unsigned ArgIdx) {
173 if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
174 return SPIRV::AccessQualifier::ReadWrite;
175
176 MDString *ArgAttribute =
177 getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
178 if (!ArgAttribute)
179 return SPIRV::AccessQualifier::ReadWrite;
180
181 if (ArgAttribute->getString().compare("read_only") == 0)
182 return SPIRV::AccessQualifier::ReadOnly;
183 if (ArgAttribute->getString().compare("write_only") == 0)
184 return SPIRV::AccessQualifier::WriteOnly;
185 return SPIRV::AccessQualifier::ReadWrite;
186 }
187
188 static std::vector<SPIRV::Decoration::Decoration>
getKernelArgTypeQual(const Function & KernelFunction,unsigned ArgIdx)189 getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
190 MDString *ArgAttribute =
191 getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
192 if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
193 return {SPIRV::Decoration::Volatile};
194 return {};
195 }
196
getArgType(const Function & F,unsigned ArgIdx)197 static Type *getArgType(const Function &F, unsigned ArgIdx) {
198 Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
199 if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
200 isSpecialOpaqueType(OriginalArgType))
201 return OriginalArgType;
202
203 MDString *MDKernelArgType =
204 getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
205 if (!MDKernelArgType || !MDKernelArgType->getString().endswith("_t"))
206 return OriginalArgType;
207
208 std::string KernelArgTypeStr = "opencl." + MDKernelArgType->getString().str();
209 Type *ExistingOpaqueType =
210 StructType::getTypeByName(F.getContext(), KernelArgTypeStr);
211 return ExistingOpaqueType
212 ? ExistingOpaqueType
213 : StructType::create(F.getContext(), KernelArgTypeStr);
214 }
215
lowerFormalArguments(MachineIRBuilder & MIRBuilder,const Function & F,ArrayRef<ArrayRef<Register>> VRegs,FunctionLoweringInfo & FLI) const216 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
217 const Function &F,
218 ArrayRef<ArrayRef<Register>> VRegs,
219 FunctionLoweringInfo &FLI) const {
220 assert(GR && "Must initialize the SPIRV type registry before lowering args.");
221 GR->setCurrentFunc(MIRBuilder.getMF());
222
223 // Assign types and names to all args, and store their types for later.
224 FunctionType *FTy = getOriginalFunctionType(F);
225 SmallVector<SPIRVType *, 4> ArgTypeVRegs;
226 if (VRegs.size() > 0) {
227 unsigned i = 0;
228 for (const auto &Arg : F.args()) {
229 // Currently formal args should use single registers.
230 // TODO: handle the case of multiple registers.
231 if (VRegs[i].size() > 1)
232 return false;
233 SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
234 getArgAccessQual(F, i);
235 auto *SpirvTy = GR->assignTypeToVReg(getArgType(F, i), VRegs[i][0],
236 MIRBuilder, ArgAccessQual);
237 ArgTypeVRegs.push_back(SpirvTy);
238
239 if (Arg.hasName())
240 buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
241 if (Arg.getType()->isPointerTy()) {
242 auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
243 if (DerefBytes != 0)
244 buildOpDecorate(VRegs[i][0], MIRBuilder,
245 SPIRV::Decoration::MaxByteOffset, {DerefBytes});
246 }
247 if (Arg.hasAttribute(Attribute::Alignment)) {
248 auto Alignment = static_cast<unsigned>(
249 Arg.getAttribute(Attribute::Alignment).getValueAsInt());
250 buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
251 {Alignment});
252 }
253 if (Arg.hasAttribute(Attribute::ReadOnly)) {
254 auto Attr =
255 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
256 buildOpDecorate(VRegs[i][0], MIRBuilder,
257 SPIRV::Decoration::FuncParamAttr, {Attr});
258 }
259 if (Arg.hasAttribute(Attribute::ZExt)) {
260 auto Attr =
261 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
262 buildOpDecorate(VRegs[i][0], MIRBuilder,
263 SPIRV::Decoration::FuncParamAttr, {Attr});
264 }
265 if (Arg.hasAttribute(Attribute::NoAlias)) {
266 auto Attr =
267 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
268 buildOpDecorate(VRegs[i][0], MIRBuilder,
269 SPIRV::Decoration::FuncParamAttr, {Attr});
270 }
271
272 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
273 std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
274 getKernelArgTypeQual(F, i);
275 for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
276 buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
277 }
278
279 MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
280 if (Node && i < Node->getNumOperands() &&
281 isa<MDNode>(Node->getOperand(i))) {
282 MDNode *MD = cast<MDNode>(Node->getOperand(i));
283 for (const MDOperand &MDOp : MD->operands()) {
284 MDNode *MD2 = dyn_cast<MDNode>(MDOp);
285 assert(MD2 && "Metadata operand is expected");
286 ConstantInt *Const = getConstInt(MD2, 0);
287 assert(Const && "MDOperand should be ConstantInt");
288 auto Dec =
289 static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
290 std::vector<uint32_t> DecVec;
291 for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
292 ConstantInt *Const = getConstInt(MD2, j);
293 assert(Const && "MDOperand should be ConstantInt");
294 DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
295 }
296 buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
297 }
298 }
299 ++i;
300 }
301 }
302
303 // Generate a SPIR-V type for the function.
304 auto MRI = MIRBuilder.getMRI();
305 Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
306 MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
307 if (F.isDeclaration())
308 GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
309 SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
310 SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
311 FTy, RetTy, ArgTypeVRegs, MIRBuilder);
312
313 // Build the OpTypeFunction declaring it.
314 uint32_t FuncControl = getFunctionControl(F);
315
316 MIRBuilder.buildInstr(SPIRV::OpFunction)
317 .addDef(FuncVReg)
318 .addUse(GR->getSPIRVTypeID(RetTy))
319 .addImm(FuncControl)
320 .addUse(GR->getSPIRVTypeID(FuncTy));
321
322 // Add OpFunctionParameters.
323 int i = 0;
324 for (const auto &Arg : F.args()) {
325 assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
326 MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
327 MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
328 .addDef(VRegs[i][0])
329 .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
330 if (F.isDeclaration())
331 GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
332 i++;
333 }
334 // Name the function.
335 if (F.hasName())
336 buildOpName(FuncVReg, F.getName(), MIRBuilder);
337
338 // Handle entry points and function linkage.
339 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
340 auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
341 .addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel))
342 .addUse(FuncVReg);
343 addStringImm(F.getName(), MIB);
344 } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
345 F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
346 auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import
347 : SPIRV::LinkageType::Export;
348 buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
349 {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
350 }
351
352 return true;
353 }
354
lowerCall(MachineIRBuilder & MIRBuilder,CallLoweringInfo & Info) const355 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
356 CallLoweringInfo &Info) const {
357 // Currently call returns should have single vregs.
358 // TODO: handle the case of multiple registers.
359 if (Info.OrigRet.Regs.size() > 1)
360 return false;
361 MachineFunction &MF = MIRBuilder.getMF();
362 GR->setCurrentFunc(MF);
363 FunctionType *FTy = nullptr;
364 const Function *CF = nullptr;
365
366 // Emit a regular OpFunctionCall. If it's an externally declared function,
367 // be sure to emit its type and function declaration here. It will be hoisted
368 // globally later.
369 if (Info.Callee.isGlobal()) {
370 CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
371 // TODO: support constexpr casts and indirect calls.
372 if (CF == nullptr)
373 return false;
374 FTy = getOriginalFunctionType(*CF);
375 }
376
377 Register ResVReg =
378 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
379 std::string FuncName = Info.Callee.getGlobal()->getName().str();
380 std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
381 const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
382 // TODO: check that it's OCL builtin, then apply OpenCL_std.
383 if (!DemangledName.empty() && CF && CF->isDeclaration() &&
384 ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
385 const Type *OrigRetTy = Info.OrigRet.Ty;
386 if (FTy)
387 OrigRetTy = FTy->getReturnType();
388 SmallVector<Register, 8> ArgVRegs;
389 for (auto Arg : Info.OrigArgs) {
390 assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
391 ArgVRegs.push_back(Arg.Regs[0]);
392 SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
393 GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
394 }
395 if (auto Res = SPIRV::lowerBuiltin(
396 DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
397 ResVReg, OrigRetTy, ArgVRegs, GR))
398 return *Res;
399 }
400 if (CF && CF->isDeclaration() &&
401 !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
402 // Emit the type info and forward function declaration to the first MBB
403 // to ensure VReg definition dependencies are valid across all MBBs.
404 MachineIRBuilder FirstBlockBuilder;
405 FirstBlockBuilder.setMF(MF);
406 FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
407
408 SmallVector<ArrayRef<Register>, 8> VRegArgs;
409 SmallVector<SmallVector<Register, 1>, 8> ToInsert;
410 for (const Argument &Arg : CF->args()) {
411 if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
412 continue; // Don't handle zero sized types.
413 ToInsert.push_back(
414 {MIRBuilder.getMRI()->createGenericVirtualRegister(LLT::scalar(32))});
415 VRegArgs.push_back(ToInsert.back());
416 }
417 // TODO: Reuse FunctionLoweringInfo
418 FunctionLoweringInfo FuncInfo;
419 lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
420 }
421
422 // Make sure there's a valid return reg, even for functions returning void.
423 if (!ResVReg.isValid())
424 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
425 SPIRVType *RetType =
426 GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
427
428 // Emit the OpFunctionCall and its args.
429 auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
430 .addDef(ResVReg)
431 .addUse(GR->getSPIRVTypeID(RetType))
432 .add(Info.Callee);
433
434 for (const auto &Arg : Info.OrigArgs) {
435 // Currently call args should have single vregs.
436 if (Arg.Regs.size() > 1)
437 return false;
438 MIB.addUse(Arg.Regs[0]);
439 }
440 return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
441 *ST->getRegBankInfo());
442 }
443