diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 46d0a66d59c3753bb8ea41fa494f9d99aabd9422..113f77661d97296c2c188d8d7f51edf24037b794 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -2050,6 +2050,14 @@ def Convergent : InheritableAttr { let SimpleHandler = 1; } +// Vortex attributes for marking variables as uniform +def VortexUniform : InheritableAttr { + let Spellings = [GNU<"uniform">, Declspec<"__uniform__">]; + let Subjects = SubjectList<[Var]>; + let LangOpts = [CPlusPlus]; + let Documentation = [Undocumented]; +} + def NoInline : DeclOrStmtAttr { let Spellings = [CustomKeyword<"__noinline__">, GCC<"noinline">, CXX11<"clang", "noinline">, C23<"clang", "noinline">, diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp index 1064507f34616a5ba9e33588f7d0b184fd8bcde7..5ec7fe42e09b0bb7e20d8607e5bbb7b34918997c 100644 --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -12316,6 +12316,8 @@ bool ASTContext::DeclMustBeEmitted(const Decl *D) { if (VD->getDescribedVarTemplate() || isa<VarTemplatePartialSpecializationDecl>(VD)) return false; + if (VD->hasAttr<VortexUniformAttr>()) + return false; } else if (const auto *FD = dyn_cast<FunctionDecl>(D)) { // We never need to emit an uninstantiated function template. if (FD->getTemplatedKind() == FunctionDecl::TK_FunctionTemplate) diff --git a/clang/lib/CodeGen/CGDecl.cpp b/clang/lib/CodeGen/CGDecl.cpp index c3251bb5ab5657978203191fa1989b0464cc383a..9ce8c13bd95848366c58e5d1d4f5e61526af4122 100644 --- a/clang/lib/CodeGen/CGDecl.cpp +++ b/clang/lib/CodeGen/CGDecl.cpp @@ -1684,6 +1684,14 @@ CodeGenFunction::EmitAutoVarAlloca(const VarDecl &D) { if (D.hasAttr<AnnotateAttr>() && HaveInsertPoint()) EmitVarAnnotations(&D, address.emitRawPointer(*this)); + if (D.hasAttr<VortexUniformAttr>() && HaveInsertPoint()) { + if (auto I = dyn_cast<llvm::Instruction>(address.getPointer())) { + auto &Context = I->getContext(); + auto MD = llvm::MDNode::get(Context, llvm::MDString::get(Context, "Uniform Variable")); + I->setMetadata("vortex.uniform", MD); + } + } + // Make sure we call @llvm.lifetime.end. if (emission.useLifetimeMarkers()) EHStack.pushCleanup<CallLifetimeEnd>(NormalEHLifetimeMarker, @@ -2756,6 +2764,14 @@ void CodeGenFunction::EmitParmDecl(const VarDecl &D, ParamValue Arg, if (D.hasAttr<AnnotateAttr>()) EmitVarAnnotations(&D, DeclPtr.emitRawPointer(*this)); + if (D.hasAttr<VortexUniformAttr>()) { + if (auto I = dyn_cast<llvm::Instruction>(DeclPtr.getPointer())) { + auto &Context = I->getContext(); + auto MD = llvm::MDNode::get(Context, llvm::MDString::get(Context, "Uniform Variable")); + I->setMetadata("vortex.uniform", MD); + } + } + // We can only check return value nullability if all arguments to the // function satisfy their nullability preconditions. This makes it necessary // to emit null checks for args in the function body itself. 
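Taken together, the clang-side changes expose the attribute and tag the variable's alloca with `vortex.uniform` metadata that the backend passes further down can key off. A minimal usage sketch (the kernel itself is hypothetical; only the `uniform` spelling, the variable-only subject, and the C++-only restriction come from the Attr.td entry above):

    // C++ source, compiled with a vortex-enabled clang
    void scale(float *data, int n) {
      // Hint: 'limit' holds the same value in every thread of a warp,
      // so branches that depend on it need no split/join handling.
      int limit __attribute__((uniform)) = n;
      for (int i = 0; i < limit; ++i)
        data[i] *= 2.0f;
    }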
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index e2eada24f9fccbfa1e67bc52d617b4bf85c38248..d41fee8e34985e956966dde84994fb5c11a64c71 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -4563,6 +4563,10 @@ static void handleConstantAttr(Sema &S, Decl *D, const ParsedAttr &AL) { D->addAttr(::new (S.Context) CUDAConstantAttr(S.Context, AL)); } +static void handleVortexUniformAttr(Sema &S, Decl *D, const ParsedAttr &AL) { + handleSimpleAttribute<VortexUniformAttr>(S, D, AL); +} + static void handleSharedAttr(Sema &S, Decl *D, const ParsedAttr &AL) { const auto *VD = cast<VarDecl>(D); // extern __shared__ is only allowed on arrays with no length (e.g. @@ -6469,6 +6473,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_CUDAConstant: handleConstantAttr(S, D, AL); break; + case ParsedAttr::AT_VortexUniform: + handleVortexUniformAttr(S, D, AL); + break; case ParsedAttr::AT_PassObjectSize: handlePassObjectSizeAttr(S, D, AL); break; diff --git a/llvm/include/llvm/IR/IntrinsicsRISCV.td b/llvm/include/llvm/IR/IntrinsicsRISCV.td index 2da154c300344c635d6746e251f7efe7833eadfe..eb4381b301b22ef2a4182b6f765e13a881aa72e7 100644 --- a/llvm/include/llvm/IR/IntrinsicsRISCV.td +++ b/llvm/include/llvm/IR/IntrinsicsRISCV.td @@ -72,6 +72,57 @@ let TargetPrefix = "riscv" in { // ptr addr, ixlen cmpval, ixlen newval, ixlen mask, ixlenimm ordering) defm int_riscv_masked_cmpxchg : MaskedAtomicRMWFiveArgIntrinsics; + // Vortex extension + + class VortexCSRIntrinsicsImpl<LLVMType itype> + : Intrinsic<[itype], [], [IntrNoMem, IntrWillReturn]>; + + multiclass VortexCSRIntrinsics { + def _i32 : VortexCSRIntrinsicsImpl<llvm_i32_ty>; + def _i64 : VortexCSRIntrinsicsImpl<llvm_i64_ty>; + } + + class VortexD0S1IntrinsicsImpl<LLVMType itype> + : Intrinsic<[], [itype], [IntrNoMem, IntrHasSideEffects, IntrConvergent, IntrWillReturn]>; + + multiclass VortexD0S1Intrinsics { + def _i32 : VortexD0S1IntrinsicsImpl<llvm_i32_ty>; + def _i64 : VortexD0S1IntrinsicsImpl<llvm_i64_ty>; + } + + class VortexD0S2IntrinsicsImpl<LLVMType itype> + : Intrinsic<[], [itype, itype], [IntrNoMem, IntrHasSideEffects, IntrConvergent, IntrWillReturn]>; + + multiclass VortexD0S2Intrinsics { + def _i32 : VortexD0S2IntrinsicsImpl<llvm_i32_ty>; + def _i64 : VortexD0S2IntrinsicsImpl<llvm_i64_ty>; + } + + class VortexD1S1IntrinsicsImpl<LLVMType itype> + : Intrinsic<[itype], [itype], [IntrNoMem, IntrHasSideEffects, IntrConvergent, IntrWillReturn]>; + + multiclass VortexD1S1Intrinsics { + def _i32 : VortexD1S1IntrinsicsImpl<llvm_i32_ty>; + def _i64 : VortexD1S1IntrinsicsImpl<llvm_i64_ty>; + } + + defm int_riscv_vx_tmask : VortexCSRIntrinsics; + defm int_riscv_vx_tid : VortexCSRIntrinsics; + defm int_riscv_vx_wid : VortexCSRIntrinsics; + defm int_riscv_vx_cid : VortexCSRIntrinsics; + defm int_riscv_vx_nt : VortexCSRIntrinsics; + defm int_riscv_vx_nw : VortexCSRIntrinsics; + defm int_riscv_vx_nc : VortexCSRIntrinsics; + + defm int_riscv_vx_tmc : VortexD0S1Intrinsics; + defm int_riscv_vx_pred : VortexD0S2Intrinsics; + defm int_riscv_vx_pred_n : VortexD0S2Intrinsics; + defm int_riscv_vx_split : VortexD1S1Intrinsics; + defm int_riscv_vx_split_n : VortexD1S1Intrinsics; + defm int_riscv_vx_join : VortexD0S1Intrinsics; + defm int_riscv_vx_mov : VortexD1S1Intrinsics; + defm int_riscv_vx_bar : VortexD0S2Intrinsics; + } // TargetPrefix = "riscv" //===----------------------------------------------------------------------===// @@ -649,7 +700,7 @@ let TargetPrefix 
= "riscv" in { class RISCVClassifyMasked : DefaultAttrsIntrinsic<[LLVMVectorOfBitcastsToInt<0>], [LLVMVectorOfBitcastsToInt<0>, llvm_anyvector_ty, - LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, llvm_anyint_ty, LLVMMatchType<1>], [IntrNoMem, ImmArg<ArgIndex<4>>]>, RISCVVIntrinsic { let VLOperand = 3; diff --git a/llvm/include/llvm/Transforms/Scalar.h b/llvm/include/llvm/Transforms/Scalar.h index f74a49785e11b7732091df8c06f77f061308b01c..40bbee2ca461932455e08482ffb2bf94397d587f 100644 --- a/llvm/include/llvm/Transforms/Scalar.h +++ b/llvm/include/llvm/Transforms/Scalar.h @@ -91,9 +91,12 @@ FunctionPass *createFlattenCFGPass(); // CFG Structurization - Remove irreducible control flow // /// -/// When \p SkipUniformRegions is true the structizer will not structurize +/// When \p SkipUniformRegions is true the structurizer will not structurize /// regions that only contain uniform branches. -Pass *createStructurizeCFGPass(bool SkipUniformRegions = false); +/// When \p SkipRegionalBranches is true the structurizer will only structurize +// regions that contain divergent branchs that do not form a region. +Pass *createStructurizeCFGPass(bool SkipUniformRegions = false, + bool SkipRegionalBranches = false); //===----------------------------------------------------------------------===// // diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt index f28a7092e3cec173bfe6694fe185f6bd0ff473ea..20f51c6beb0f25a22bdae690689d0fce9b304e7a 100644 --- a/llvm/lib/Target/RISCV/CMakeLists.txt +++ b/llvm/lib/Target/RISCV/CMakeLists.txt @@ -55,6 +55,8 @@ add_llvm_target(RISCVCodeGen RISCVTargetObjectFile.cpp RISCVTargetTransformInfo.cpp RISCVVectorPeephole.cpp + VortexBranchDivergence.cpp + VortexIntrinsicFunc.cpp GISel/RISCVCallLowering.cpp GISel/RISCVInstructionSelector.cpp GISel/RISCVLegalizerInfo.cpp diff --git a/llvm/lib/Target/RISCV/RISCV.h b/llvm/lib/Target/RISCV/RISCV.h index 0d2473c7c5de1cd525e5c40e72bbb78971c54602..20f62f26919e5ba218f60425cfd59a9f44a2ae40 100644 --- a/llvm/lib/Target/RISCV/RISCV.h +++ b/llvm/lib/Target/RISCV/RISCV.h @@ -18,6 +18,7 @@ #include "llvm/Target/TargetMachine.h" namespace llvm { +class ModulePass; class FunctionPass; class InstructionSelector; class PassRegistry; @@ -91,6 +92,18 @@ void initializeRISCVPostLegalizerCombinerPass(PassRegistry &); FunctionPass *createRISCVO0PreLegalizerCombiner(); void initializeRISCVO0PreLegalizerCombinerPass(PassRegistry &); +FunctionPass *createVortexBranchDivergence0Pass(); +void initializeVortexBranchDivergence0Pass(PassRegistry&); + +FunctionPass *createVortexBranchDivergence1Pass(int divergenceMode = 0); +void initializeVortexBranchDivergence1Pass(PassRegistry&); + +FunctionPass *createVortexBranchDivergence2Pass(int PassMode); +void initializeVortexBranchDivergence2Pass(PassRegistry&); + +ModulePass *createVortexIntrinsicFuncLoweringPass(); +void initializeVortexIntrinsicFuncLoweringPass(PassRegistry&); + FunctionPass *createRISCVPreLegalizerCombiner(); void initializeRISCVPreLegalizerCombinerPass(PassRegistry &); } // namespace llvm diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp index d93709ac03420e9131b81cefad2d9ad09eff8416..34fc6dc9daa52391744ea14a20d97ff31a0bf8ff 100644 --- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp +++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp @@ -186,6 +186,9 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB, auto CC = 
diff --git a/llvm/include/llvm/Transforms/Scalar.h b/llvm/include/llvm/Transforms/Scalar.h
index f74a49785e11b7732091df8c06f77f061308b01c..40bbee2ca461932455e08482ffb2bf94397d587f 100644
--- a/llvm/include/llvm/Transforms/Scalar.h
+++ b/llvm/include/llvm/Transforms/Scalar.h
@@ -91,9 +91,12 @@ FunctionPass *createFlattenCFGPass();
 // CFG Structurization - Remove irreducible control flow
 //
 ///
-/// When \p SkipUniformRegions is true the structizer will not structurize
+/// When \p SkipUniformRegions is true the structurizer will not structurize
 /// regions that only contain uniform branches.
-Pass *createStructurizeCFGPass(bool SkipUniformRegions = false);
+/// When \p SkipRegionalBranches is true the structurizer will only structurize
+/// regions that contain divergent branches that do not form a region.
+Pass *createStructurizeCFGPass(bool SkipUniformRegions = false,
+                               bool SkipRegionalBranches = false);
 
 //===----------------------------------------------------------------------===//
 //
diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt
index f28a7092e3cec173bfe6694fe185f6bd0ff473ea..20f51c6beb0f25a22bdae690689d0fce9b304e7a 100644
--- a/llvm/lib/Target/RISCV/CMakeLists.txt
+++ b/llvm/lib/Target/RISCV/CMakeLists.txt
@@ -55,6 +55,8 @@ add_llvm_target(RISCVCodeGen
   RISCVTargetObjectFile.cpp
   RISCVTargetTransformInfo.cpp
   RISCVVectorPeephole.cpp
+  VortexBranchDivergence.cpp
+  VortexIntrinsicFunc.cpp
   GISel/RISCVCallLowering.cpp
   GISel/RISCVInstructionSelector.cpp
   GISel/RISCVLegalizerInfo.cpp
diff --git a/llvm/lib/Target/RISCV/RISCV.h b/llvm/lib/Target/RISCV/RISCV.h
index 0d2473c7c5de1cd525e5c40e72bbb78971c54602..20f62f26919e5ba218f60425cfd59a9f44a2ae40 100644
--- a/llvm/lib/Target/RISCV/RISCV.h
+++ b/llvm/lib/Target/RISCV/RISCV.h
@@ -18,6 +18,7 @@
 #include "llvm/Target/TargetMachine.h"
 
 namespace llvm {
+class ModulePass;
 class FunctionPass;
 class InstructionSelector;
 class PassRegistry;
@@ -91,6 +92,18 @@ void initializeRISCVPostLegalizerCombinerPass(PassRegistry &);
 FunctionPass *createRISCVO0PreLegalizerCombiner();
 void initializeRISCVO0PreLegalizerCombinerPass(PassRegistry &);
 
+FunctionPass *createVortexBranchDivergence0Pass();
+void initializeVortexBranchDivergence0Pass(PassRegistry&);
+
+FunctionPass *createVortexBranchDivergence1Pass(int divergenceMode = 0);
+void initializeVortexBranchDivergence1Pass(PassRegistry&);
+
+FunctionPass *createVortexBranchDivergence2Pass(int PassMode);
+void initializeVortexBranchDivergence2Pass(PassRegistry&);
+
+ModulePass *createVortexIntrinsicFuncLoweringPass();
+void initializeVortexIntrinsicFuncLoweringPass(PassRegistry&);
+
 FunctionPass *createRISCVPreLegalizerCombiner();
 void initializeRISCVPreLegalizerCombinerPass(PassRegistry &);
 } // namespace llvm
diff --git a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
index d93709ac03420e9131b81cefad2d9ad09eff8416..34fc6dc9daa52391744ea14a20d97ff31a0bf8ff 100644
--- a/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
+++ b/llvm/lib/Target/RISCV/RISCVExpandPseudoInsts.cpp
@@ -186,6 +186,9 @@ bool RISCVExpandPseudo::expandCCOp(MachineBasicBlock &MBB,
   auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm());
   CC = RISCVCC::getOppositeBranchCondition(CC);
 
+  llvm::errs() << "error: unimplemented divergent codegen found!\n";
+  std::abort();
+
   // Insert branch instruction.
   BuildMI(MBB, MBBI, DL, TII->getBrCond(CC))
       .addReg(MI.getOperand(1).getReg())
diff --git a/llvm/lib/Target/RISCV/RISCVFeatures.td b/llvm/lib/Target/RISCV/RISCVFeatures.td
index 3c868dbbf8b3a1512f2fd5df624ad921f7d533f2..ed448293f6f328cc809c8bbb0f7d4f6602d367c0 100644
--- a/llvm/lib/Target/RISCV/RISCVFeatures.td
+++ b/llvm/lib/Target/RISCV/RISCVFeatures.td
@@ -37,7 +37,7 @@ class RISCVExtension<string name, int major, int minor, string desc,
   bit Experimental = false;
 }
 
-// The groupID/bitmask of RISCVExtension is used to retrieve a specific bit value 
+// The groupID/bitmask of RISCVExtension is used to retrieve a specific bit value
 // from __riscv_feature_bits based on the groupID and bitmask.
 // groupID - groupID of extension
 // bitPos - bit position of extension bitmask
@@ -1202,6 +1202,15 @@ def HasVendorXSfcease
                        AssemblerPredicate<(all_of FeatureVendorXSfcease),
                        "'XSfcease' (SiFive sf.cease Instruction)">;
 
+// Vortex extensions
+
+def FeatureVendorVortex
+    : SubtargetFeature<"vortex", "HasVendorXVortex", "true",
+                       "'Vortex' (Vortex ISA Extension)">;
+def HasVendorXVortex : Predicate<"Subtarget->HasVendorXVortex()">,
+                       AssemblerPredicate<(all_of FeatureVendorVortex),
+                       "'Vortex' (Vortex ISA Extension)">;
+
 // Core-V Extensions
 
 def FeatureVendorXCVelw
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 823fb428472ef34a55afbe5707b1dc2af3318f4e..5cbd2a1cbd3f9a3fa926e633280755eb056c31a4 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -50,8 +50,14 @@ using namespace llvm;
 
 #define DEBUG_TYPE "riscv-lower"
 
+#ifndef NDEBUG
+#undef LLVM_DEBUG
+#define LLVM_DEBUG(x) do {x;} while (false)
+#endif
+
 STATISTIC(NumTailCalls, "Number of tail calls");
 
+extern int gVortexBranchDivergenceMode;
+
 static cl::opt<unsigned> ExtensionMaxWebSize(
     DEBUG_TYPE "-ext-max-web-size", cl::Hidden,
     cl::desc("Give the maximum size (in number of nodes) of the web of "
@@ -18145,6 +18151,9 @@ static MachineBasicBlock *emitReadCounterWidePseudo(MachineInstr &MI,
       .addImm(HiCounter)
       .addReg(RISCV::X0);
 
+  llvm::errs() << "error: unimplemented divergent codegen found!\n";
+  std::abort();
+
   BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
       .addReg(HiReg)
       .addReg(ReadAgainReg)
@@ -18283,6 +18292,96 @@ static MachineBasicBlock *emitQuietFCMP(MachineInstr &MI, MachineBasicBlock *BB,
   return BB;
 }
 
+static Register BrCondToNECodeGen(RISCVCC::CondCode CC,
+                                  Register LHS,
+                                  Register RHS,
+                                  MachineBasicBlock &MBB,
+                                  const MachineBasicBlock::iterator& loc,
+                                  const DebugLoc& DL,
+                                  const TargetInstrInfo& TII) {
+  auto Result = MBB.getParent()->getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);
+  switch (CC) {
+  default:
+    llvm_unreachable("Unknown condition code!");
+  case RISCVCC::COND_EQ: {
+    auto Result1 = MBB.getParent()->getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);
+    BuildMI(MBB, loc, DL, TII.get(RISCV::XOR), Result1)
+      .addReg(LHS)
+      .addReg(RHS);
+    BuildMI(MBB, loc, DL, TII.get(RISCV::SLTIU), Result)
+      .addReg(Result1)
+      .addImm(1);
+  } break;
+  case RISCVCC::COND_NE: {
+    auto Result1 = MBB.getParent()->getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);
+    BuildMI(MBB, loc, DL, TII.get(RISCV::XOR), Result1)
+      .addReg(LHS)
+      .addReg(RHS);
+    BuildMI(MBB, loc, DL, TII.get(RISCV::SLTU), Result)
+      .addReg(RISCV::X0)
+      .addReg(Result1);
+  } break;
+  case RISCVCC::COND_LT:
+    BuildMI(MBB, loc, DL, TII.get(RISCV::SLT), Result)
+      .addReg(LHS)
+      .addReg(RHS);
+    break;
+  case RISCVCC::COND_GE:
+    BuildMI(MBB, loc, DL, TII.get(RISCV::SLT), Result)
+      .addReg(RHS)
+      .addReg(LHS);
+    break;
+  case RISCVCC::COND_LTU:
+    BuildMI(MBB, loc, DL, TII.get(RISCV::SLTU), Result)
+      .addReg(LHS)
+      .addReg(RHS);
+    break;
+  case RISCVCC::COND_GEU:
+    BuildMI(MBB, loc, DL, TII.get(RISCV::SLTU), Result)
+      .addReg(RHS)
+      .addReg(LHS);
+    break;
+  }
+  return Result;
+}
+
+static void InsertVXSplit(Register* CondReg,
+                          Register* DVReg,
+                          RISCVCC::CondCode CC,
+                          Register LHS,
+                          Register RHS,
+                          MachineBasicBlock &MBB,
+                          const MachineBasicBlock::iterator& loc,
+                          const DebugLoc& DL,
+                          const TargetInstrInfo& TII) {
+  auto _CondReg2 = MBB.getParent()->getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);
+  auto _DVReg = MBB.getParent()->getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);
+
+  // convert the condition to NE
+  auto _CondReg1 = BrCondToNECodeGen(CC, LHS, RHS, MBB, loc, DL, TII);
+
+  // insert VX_MOV
+  BuildMI(MBB, loc, DL, TII.get(RISCV::VX_MOV), _CondReg2)
+    .addReg(_CondReg1);
+
+  // insert VX_SPLIT
+  BuildMI(MBB, loc, DL, TII.get(RISCV::VX_SPLIT), _DVReg)
+    .addReg(_CondReg2);
+
+  *CondReg = _CondReg2;
+  *DVReg = _DVReg;
+}
+
+static void InsertVXJoin(Register DVReg,
+                         MachineBasicBlock &MBB,
+                         const MachineBasicBlock::iterator& loc,
+                         const DebugLoc& DL,
+                         const TargetInstrInfo& TII) {
+  // insert VX_JOIN
+  BuildMI(MBB, loc, DL, TII.get(RISCV::VX_JOIN))
+    .addReg(DVReg);
+}
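+
+// Taken together, the two helpers above rewrite one divergent conditional
+// into an explicit split/join region. Roughly (an illustrative sketch, not
+// the verbatim emitted MIR):
+//
+//   %cc  = <SLT/SLTU/XOR+SLTIU>   ; BrCondToNECodeGen: normalize CC to 0/1
+//   %cc2 = VX_MOV %cc             ; keep the predicate live through codegen
+//   %dv  = VX_SPLIT %cc2          ; push divergence state
+//   BNE %cc2, $x0, <taken-block>
+//   ...
+//   VX_JOIN %dv                   ; pop divergence state at the IPDOM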
 static MachineBasicBlock *
 EmitLoweredCascadedSelect(MachineInstr &First, MachineInstr &Second,
                           MachineBasicBlock *ThisMBB,
@@ -18340,46 +18439,89 @@ EmitLoweredCascadedSelect(MachineInstr &First, MachineInstr &Second,
                    ThisMBB->end());
   SinkMBB->transferSuccessorsAndUpdatePHIs(ThisMBB);
 
-  // Fallthrough block for ThisMBB.
-  ThisMBB->addSuccessor(FirstMBB);
-  // Fallthrough block for FirstMBB.
-  FirstMBB->addSuccessor(SecondMBB);
-  ThisMBB->addSuccessor(SinkMBB);
-  FirstMBB->addSuccessor(SinkMBB);
-  // This is fallthrough.
-  SecondMBB->addSuccessor(SinkMBB);
-
   auto FirstCC = static_cast<RISCVCC::CondCode>(First.getOperand(3).getImm());
   Register FLHS = First.getOperand(1).getReg();
   Register FRHS = First.getOperand(2).getReg();
-  // Insert appropriate branch.
-  BuildMI(FirstMBB, DL, TII.getBrCond(FirstCC))
-      .addReg(FLHS)
-      .addReg(FRHS)
-      .addMBB(SinkMBB);
-
-  Register SLHS = Second.getOperand(1).getReg();
-  Register SRHS = Second.getOperand(2).getReg();
   Register Op1Reg4 = First.getOperand(4).getReg();
   Register Op1Reg5 = First.getOperand(5).getReg();
 
   auto SecondCC = static_cast<RISCVCC::CondCode>(Second.getOperand(3).getImm());
-  // Insert appropriate branch.
-  BuildMI(ThisMBB, DL, TII.getBrCond(SecondCC))
-      .addReg(SLHS)
-      .addReg(SRHS)
-      .addMBB(SinkMBB);
-
+  Register SLHS = Second.getOperand(1).getReg();
+  Register SRHS = Second.getOperand(2).getReg();
   Register DestReg = Second.getOperand(0).getReg();
   Register Op2Reg4 = Second.getOperand(4).getReg();
-  BuildMI(*SinkMBB, SinkMBB->begin(), DL, TII.get(RISCV::PHI), DestReg)
-      .addReg(Op2Reg4)
-      .addMBB(ThisMBB)
-      .addReg(Op1Reg4)
-      .addMBB(FirstMBB)
-      .addReg(Op1Reg5)
-      .addMBB(SecondMBB);
+
+  if (Subtarget.hasVendorXVortex() && gVortexBranchDivergenceMode != 0) {
+    MachineBasicBlock *SinkMBB1 = F->CreateMachineBasicBlock(LLVM_BB);
+    F->insert(It, SinkMBB1);
+
+    ThisMBB->addSuccessor(FirstMBB);
+    FirstMBB->addSuccessor(SecondMBB);
+    ThisMBB->addSuccessor(SinkMBB);
+    FirstMBB->addSuccessor(SinkMBB1);
+    SecondMBB->addSuccessor(SinkMBB1);
+    SinkMBB1->addSuccessor(SinkMBB);
+
+    Register CCReg1, DVReg1;
+    InsertVXSplit(&CCReg1, &DVReg1, FirstCC, FLHS, FRHS, *FirstMBB, FirstMBB->end(), DL, TII);
+
+    BuildMI(FirstMBB, DL, TII.getBrCond(RISCVCC::COND_NE))
+      .addReg(CCReg1)
+      .addReg(RISCV::X0)
+      .addMBB(SinkMBB1);
+
+    Register CCReg2, DVReg2;
+    InsertVXSplit(&CCReg2, &DVReg2, SecondCC, SLHS, SRHS, *ThisMBB, ThisMBB->end(), DL, TII);
+
+    // CCReg2 already holds the NE-normalized 0/1 predicate, so branch on
+    // COND_NE against X0 (matching the first branch above).
+    BuildMI(ThisMBB, DL, TII.getBrCond(RISCVCC::COND_NE))
+      .addReg(CCReg2)
+      .addReg(RISCV::X0)
+      .addMBB(SinkMBB);
+
+    auto DestReg1 = F->getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);
+    BuildMI(*SinkMBB1, SinkMBB1->begin(), DL, TII.get(RISCV::PHI), DestReg1)
+      .addReg(Op1Reg4)
+      .addMBB(FirstMBB)
+      .addReg(Op1Reg5)
+      .addMBB(SecondMBB);
+
+    BuildMI(*SinkMBB, SinkMBB->begin(), DL, TII.get(RISCV::PHI), DestReg)
+      .addReg(Op2Reg4)
+      .addMBB(ThisMBB)
+      .addReg(DestReg1)
+      .addMBB(SinkMBB1);
+
+    InsertVXJoin(DVReg1, *SinkMBB1, SinkMBB1->begin(), DL, TII);
+    InsertVXJoin(DVReg2, *SinkMBB, SinkMBB->begin(), DL, TII);
+
+    LLVM_DEBUG(dbgs() << "*** Vortex: EmitLoweredCascadedSelect\n" <<
+      *ThisMBB << "\n" << *FirstMBB << "\n" << *SecondMBB << "\n" << *SinkMBB1 << "\n" << *SinkMBB << "\n");
+    LLVM_DEBUG(F->dump());
+  } else {
+    ThisMBB->addSuccessor(FirstMBB);
+    FirstMBB->addSuccessor(SecondMBB);
+    ThisMBB->addSuccessor(SinkMBB);
+    FirstMBB->addSuccessor(SinkMBB);
+    SecondMBB->addSuccessor(SinkMBB);
+
+    BuildMI(FirstMBB, DL, TII.getBrCond(FirstCC))
+      .addReg(FLHS)
+      .addReg(FRHS)
+      .addMBB(SinkMBB);
+
+    BuildMI(ThisMBB, DL, TII.getBrCond(SecondCC))
+      .addReg(SLHS)
+      .addReg(SRHS)
+      .addMBB(SinkMBB);
+
+    BuildMI(*SinkMBB, SinkMBB->begin(), DL, TII.get(RISCV::PHI), DestReg)
+      .addReg(Op2Reg4)
+      .addMBB(ThisMBB)
+      .addReg(Op1Reg4)
+      .addMBB(FirstMBB)
+      .addReg(Op1Reg5)
+      .addMBB(SecondMBB);
+  }
 
   // Now remove the Select_FPRX_s.
   First.eraseFromParent();
   Second.eraseFromParent();
@@ -18498,17 +18640,32 @@ static MachineBasicBlock *emitSelectPseudo(MachineInstr &MI,
   HeadMBB->addSuccessor(IfFalseMBB);
   HeadMBB->addSuccessor(TailMBB);
 
-  // Insert appropriate branch.
- if (MI.getOperand(2).isImm()) - BuildMI(HeadMBB, DL, TII.getBrCond(CC, MI.getOperand(2).isImm())) - .addReg(LHS) - .addImm(MI.getOperand(2).getImm()) - .addMBB(TailMBB); - else - BuildMI(HeadMBB, DL, TII.getBrCond(CC)) + if (Subtarget.hasVendorXVortex() && gVortexBranchDivergenceMode != 0) { + Register CCReg, DVReg; + InsertVXSplit(&CCReg, &DVReg, CC, LHS, RHS, *HeadMBB, HeadMBB->end(), DL, TII); + + BuildMI(HeadMBB, DL, TII.getBrCond(RISCVCC::COND_NE)) + .addReg(CCReg) + .addReg(RISCV::X0) + .addMBB(TailMBB); + + InsertVXJoin(DVReg, *TailMBB, TailMBB->begin(), DL, TII); + + LLVM_DEBUG(dbgs() << "*** Vortex: emitSelectPseudo\n" << *HeadMBB << "\n" << *TailMBB << "\n"); + LLVM_DEBUG(F->dump()); + } else { + // Insert appropriate branch. + if (MI.getOperand(2).isImm()) + BuildMI(HeadMBB, DL, TII.getBrCond(CC, MI.getOperand(2).isImm())) + .addReg(LHS) + .addImm(MI.getOperand(2).getImm()) + .addMBB(TailMBB); + else + BuildMI(HeadMBB, DL, TII.getBrCond(CC)) .addReg(LHS) .addReg(RHS) .addMBB(TailMBB); + } // IfFalseMBB just falls through to TailMBB. IfFalseMBB->addSuccessor(TailMBB); @@ -18708,6 +18865,9 @@ static MachineBasicBlock *emitFROUND(MachineInstr &MI, MachineBasicBlock *MBB, if (MI.getFlag(MachineInstr::MIFlag::NoFPExcept)) MIB->setFlag(MachineInstr::MIFlag::NoFPExcept); + llvm::errs() << "error: unimplemented divergent codegen found!\n"; + std::abort(); + // Insert branch. BuildMI(MBB, DL, TII.get(RISCV::BEQ)) .addReg(CmpReg) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td index 04054d2c3feeed79c392b2e1246acb4d0df3db06..c3d113481342502b9ed7d120af227f2db2d0b7ea 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -2079,6 +2079,7 @@ include "RISCVInstrInfoXVentana.td" include "RISCVInstrInfoXTHead.td" include "RISCVInstrInfoXSf.td" include "RISCVInstrInfoSFB.td" +include "RISCVInstrInfoVX.td" include "RISCVInstrInfoXCV.td" include "RISCVInstrInfoXwch.td" diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVX.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVX.td new file mode 100644 index 0000000000000000000000000000000000000000..36433496dc28e68684b7097408d37f987dc22be8 --- /dev/null +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVX.td @@ -0,0 +1,113 @@ +// Vortex instructions definitions + +def RISCV_CUSTOM0 : RISCVOpcode<"CUSTOM0", 0b0001011>; // 0x0B +def RISCV_CUSTOM1 : RISCVOpcode<"CUSTOM1", 0b0101011>; // 0x2B +def RISCV_CUSTOM2 : RISCVOpcode<"CUSTOM2", 0b1011011>; // 0x5B +def RISCV_CUSTOM3 : RISCVOpcode<"CUSTOM3", 0b1111011>; // 0x7B + +let hasSideEffects = 1, mayStore = 0 , mayLoad = 0 in { + +def VX_TMC : RVInstR<0, 0, RISCV_CUSTOM0, (outs), (ins GPR:$rs1), "vx_tmc", "$rs1">, Sched<[]> { + let rd = 0; + let rs2 = 0; +} + +def VX_WSPAWN : RVInstR<0, 1, RISCV_CUSTOM0, (outs), (ins GPR:$rs1, GPR:$rs2), "vx_wspawn", "$rs1, $rs2">, Sched<[]> { + let rd = 0; +} + +def VX_SPLIT : RVInstR<0, 2, RISCV_CUSTOM0, (outs GPR:$rd), (ins GPR:$rs1), "vx_split", "$rd, $rs1">, Sched<[]> { + let rs2 = 0; +} + +def VX_SPLIT_N : RVInstR<0, 2, RISCV_CUSTOM0, (outs GPR:$rd), (ins GPR:$rs1), "vx_split_n", "$rd, $rs1">, Sched<[]> { + let rs2 = 1; +} + +def VX_JOIN : RVInstR<0, 3, RISCV_CUSTOM0, (outs), (ins GPR:$rs1), "vx_join", "$rs1">, Sched<[]> { + let rd = 0; + let rs2 = 0; +} + +def VX_BAR : RVInstR<0, 4, RISCV_CUSTOM0, (outs), (ins GPR:$rs1, GPR:$rs2), "vx_bar", "$rs1, $rs2">, Sched<[]> { + let rd = 0; + let isBarrier = 1; +} + +def VX_PRED : RVInstR<0, 5, RISCV_CUSTOM0, (outs), (ins GPR:$rs1, GPR:$rs2), "vx_pred", 
"$rs1, $rs2">, Sched<[]> { + let rd = 0; +} + +def VX_PRED_N : RVInstR<0, 5, RISCV_CUSTOM0, (outs), (ins GPR:$rs1, GPR:$rs2), "vx_pred_n", "$rs1, $rs2">, Sched<[]> { + let rd = 1; +} + +def VX_RAST : RVInstR<1, 0, RISCV_CUSTOM0, (outs GPR:$rd), (ins), "vx_rast", "">, Sched<[]> { + let rs1 = 0; + let rs2 = 0; +} + +def VX_TEX : RVInstR4<0, 0, RISCV_CUSTOM1, (outs GPR:$rd), (GPR:$rs1, GPR:$rs2, GPR:$rs3), "vx_tex", "$rd, $rs1, $rs2, $rs3">, Sched<[]> {} + +def VX_ROP : RVInstR4<0, 1, RISCV_CUSTOM1, (outs), (ins GPR:$rs1, GPR:$rs2, GPR:$rs3), "vx_rop", "$rs1, $rs2, $rs3">, Sched<[]> { + let rd = 0; +} + +} + +def VX_MOV : Pseudo<(outs GPR:$dst), (ins GPR:$src), [], "vx_mov", "$dst, $src"> {} + +def CSR_NT : SysReg<"nt", 0xFC0>; +def CSR_NW : SysReg<"nw", 0xFC1>; +def CSR_NC : SysReg<"nw", 0xFC2>; +def CSR_NG : SysReg<"nw", 0xFC3>; +def CSR_TID : SysReg<"tid", 0xCC0>; +def CSR_WID : SysReg<"wid", 0xCC1>; +def CSR_CID : SysReg<"cid", 0xCC2>; +def CSR_GID : SysReg<"gid", 0xCC3>; +def CSR_TMASK : SysReg<"tmask", 0xCC4>; + +def : Pat<(int_riscv_vx_tid_i32), (CSRRS CSR_TID.Encoding, (XLenVT X0))>; +def : Pat<(int_riscv_vx_tid_i64), (CSRRS CSR_TID.Encoding, (XLenVT X0))>; + +def : Pat<(int_riscv_vx_wid_i32), (CSRRS CSR_WID.Encoding, (XLenVT X0))>; +def : Pat<(int_riscv_vx_wid_i64), (CSRRS CSR_WID.Encoding, (XLenVT X0))>; + +def : Pat<(int_riscv_vx_cid_i32), (CSRRS CSR_CID.Encoding, (XLenVT X0))>; +def : Pat<(int_riscv_vx_cid_i64), (CSRRS CSR_CID.Encoding, (XLenVT X0))>; + +def : Pat<(int_riscv_vx_nt_i32), (CSRRS CSR_NT.Encoding, (XLenVT X0))>; +def : Pat<(int_riscv_vx_nt_i64), (CSRRS CSR_NT.Encoding, (XLenVT X0))>; + +def : Pat<(int_riscv_vx_nw_i32), (CSRRS CSR_NW.Encoding, (XLenVT X0))>; +def : Pat<(int_riscv_vx_nw_i64), (CSRRS CSR_NW.Encoding, (XLenVT X0))>; + +def : Pat<(int_riscv_vx_nc_i32), (CSRRS CSR_NC.Encoding, (XLenVT X0))>; +def : Pat<(int_riscv_vx_nc_i64), (CSRRS CSR_NC.Encoding, (XLenVT X0))>; + +def : Pat<(int_riscv_vx_tmask_i32), (CSRRS CSR_TMASK.Encoding, (XLenVT X0))>; +def : Pat<(int_riscv_vx_tmask_i64), (CSRRS CSR_TMASK.Encoding, (XLenVT X0))>; + +def : Pat<(int_riscv_vx_tmc_i32 GPR:$rs1), (VX_TMC GPR:$rs1)>; +def : Pat<(int_riscv_vx_tmc_i64 GPR:$rs1), (VX_TMC GPR:$rs1)>; + +def : Pat<(int_riscv_vx_pred_i32 GPR:$rs1, GPR:$rs2), (VX_PRED GPR:$rs1, GPR:$rs2)>; +def : Pat<(int_riscv_vx_pred_i64 GPR:$rs1, GPR:$rs2), (VX_PRED GPR:$rs1, GPR:$rs2)>; + +def : Pat<(int_riscv_vx_pred_n_i32 GPR:$rs1, GPR:$rs2), (VX_PRED_N GPR:$rs1, GPR:$rs2)>; +def : Pat<(int_riscv_vx_pred_n_i64 GPR:$rs1, GPR:$rs2), (VX_PRED_N GPR:$rs1, GPR:$rs2)>; + +def : Pat<(int_riscv_vx_split_i32 GPR:$rs1), (VX_SPLIT GPR:$rs1)>; +def : Pat<(int_riscv_vx_split_i64 GPR:$rs1), (VX_SPLIT GPR:$rs1)>; + +def : Pat<(int_riscv_vx_split_n_i32 GPR:$rs1), (VX_SPLIT_N GPR:$rs1)>; +def : Pat<(int_riscv_vx_split_n_i64 GPR:$rs1), (VX_SPLIT_N GPR:$rs1)>; + +def : Pat<(int_riscv_vx_join_i32 GPR:$rs1), (VX_JOIN GPR:$rs1)>; +def : Pat<(int_riscv_vx_join_i64 GPR:$rs1), (VX_JOIN GPR:$rs1)>; + +def : Pat<(int_riscv_vx_mov_i32 GPR:$src), (VX_MOV GPR:$src)>; +def : Pat<(int_riscv_vx_mov_i64 GPR:$src), (VX_MOV GPR:$src)>; + +def : Pat<(int_riscv_vx_bar_i32 GPR:$rs1, GPR:$rs2), (VX_BAR GPR:$rs1, GPR:$rs2)>; +def : Pat<(int_riscv_vx_bar_i64 GPR:$rs1, GPR:$rs2), (VX_BAR GPR:$rs1, GPR:$rs2)>; diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp index 21fbf47875e682ed2c19907e0a51254bb61e2e4f..7491641fa78aafaf6fce0430e26e149575f2dc50 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp +++ 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
index 21fbf47875e682ed2c19907e0a51254bb61e2e4f..7491641fa78aafaf6fce0430e26e149575f2dc50 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -36,6 +36,7 @@
 #include "llvm/Passes/PassBuilder.h"
 #include "llvm/Support/FormattedStream.h"
 #include "llvm/Target/TargetOptions.h"
+#include "llvm/Transforms/Utils.h"
 #include "llvm/Transforms/IPO.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
@@ -47,6 +48,20 @@ static cl::opt<bool> EnableRedundantCopyElimination(
     cl::desc("Enable the redundant copy elimination pass"), cl::init(true),
     cl::Hidden);
 
+// 0: disable Vortex branch divergence handling
+// 1: enable it, structurizing only divergent branches that do not form a region
+// 2: enable it, structurizing all divergent branches first
+static cl::opt<int> VortexBranchDivergenceMode(
+    "vortex-branch-divergence",
+    cl::desc("Set Vortex Branch Divergence Mode"),
+    cl::init(1));
+int gVortexBranchDivergenceMode = 0;
+
+static cl::opt<int> VortexKernelSchedulerMode(
+    "vortex-kernel-scheduler",
+    cl::desc("Set Vortex Kernel Scheduler Mode"),
+    cl::init(0));
+
 // FIXME: Unify control over GlobalMerge.
 static cl::opt<cl::boolOrDefault>
     EnableGlobalMerge("riscv-enable-global-merge", cl::Hidden,
@@ -112,6 +127,15 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
   initializeRISCVPreLegalizerCombinerPass(*PR);
   initializeRISCVPostLegalizerCombinerPass(*PR);
   initializeKCFIPass(*PR);
+  if (VortexBranchDivergenceMode != 0) {
+    gVortexBranchDivergenceMode = VortexBranchDivergenceMode;
+    initializeVortexBranchDivergence0Pass(*PR);
+    initializeVortexBranchDivergence1Pass(*PR);
+    initializeVortexBranchDivergence2Pass(*PR);
+  }
+  if (VortexKernelSchedulerMode != 0) {
+    initializeVortexIntrinsicFuncLoweringPass(*PR);
+  }
   initializeRISCVDeadRegisterDefinitionsPass(*PR);
   initializeRISCVMakeCompressibleOptPass(*PR);
   initializeRISCVGatherScatterLoweringPass(*PR);
@@ -168,6 +192,11 @@ RISCVTargetMachine::RISCVTargetMachine(const Target &T, const Triple &TT,
   setMachineOutliner(true);
   setSupportsDefaultOutlining(true);
 
+  auto isVortex = FS.contains("vortex");
+  if (isVortex) {
+    setRequiresStructuredCFG(true);
+  }
+
   if (TT.isOSFuchsia() && !TT.isArch64Bit())
     report_fatal_error("Fuchsia is only supported for 64-bit");
 }
@@ -445,6 +474,23 @@ bool RISCVPassConfig::addPreISel() {
                               /* MergeExternalByDefault */ true));
   }
 
+  if (TM->getTargetFeatureString().contains("vortex")) {
+    if (VortexBranchDivergenceMode != 0) {
+      addPass(createCFGSimplificationPass());
+      addPass(createLoopSimplifyPass());
+      addPass(createFixIrreduciblePass());
+      addPass(createUnifyLoopExitsPass());
+      addPass(createSinkingPass());
+      addPass(createLowerSwitchPass());
+      addPass(createFlattenCFGPass());
+      addPass(createVortexBranchDivergence0Pass());
+      addPass(createStructurizeCFGPass(true, (VortexBranchDivergenceMode == 1)));
+      addPass(createVortexBranchDivergence1Pass(VortexBranchDivergenceMode));
+    }
+    if (VortexKernelSchedulerMode != 0) {
+      addPass(createVortexIntrinsicFuncLoweringPass());
+    }
+  }
   return false;
 }
 
@@ -531,6 +577,11 @@ void RISCVPassConfig::addPreEmitPass2() {
   addPass(createUnpackMachineBundles([&](const MachineFunction &MF) {
     return MF.getFunction().getParent()->getModuleFlag("kcfi");
   }));
+
+  if (TM->getTargetFeatureString().contains("vortex")
+   && VortexBranchDivergenceMode != 0) {
+    addPass(createVortexBranchDivergence2Pass(1));
+  }
 }
 
 void RISCVPassConfig::addMachineSSAOptimization() {
@@ -554,6 +605,11 @@ void RISCVPassConfig::addPreRegAlloc() {
   addPass(createRISCVInsertReadWriteCSRPass());
   addPass(createRISCVInsertWriteVXRMPass());
 
+  if
(TM->getTargetFeatureString().contains("vortex") + && VortexBranchDivergenceMode != 0) { + addPass(createVortexBranchDivergence2Pass(0)); + } + // Run RISCVInsertVSETVLI after PHI elimination. On O1 and above do it after // register coalescing so needVSETVLIPHI doesn't need to look through COPYs. if (!EnableVSETVLIAfterRVVRegAlloc) { diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 5a92d6bab31a97114b353f831603a8bd8ec27e61..bd78f3e3dbb053a127d9447c6ae1c79c5bc2e0a9 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -1968,3 +1968,11 @@ bool RISCVTTIImpl::areInlineCompatible(const Function *Caller, // target-features. return (CallerBits & CalleeBits) == CalleeBits; } + +bool RISCVTTIImpl::hasBranchDivergence(const Function *F) { + return hasBranchDivergence_; +} + +bool RISCVTTIImpl::isSourceOfDivergence(const Value *V) { + return divergence_tracker_.eval(V); +} diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index 9c37a4f6ec2d04f86df38c7c87e7ac0307402243..9fd0df19c1f92491eae8a61210091a95977a2857 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -23,6 +23,7 @@ #include "llvm/CodeGen/BasicTTIImpl.h" #include "llvm/IR/Function.h" #include <optional> +#include "VortexBranchDivergence.h" namespace llvm { @@ -34,6 +35,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> { const RISCVSubtarget *ST; const RISCVTargetLowering *TLI; + vortex::DivergenceTracker divergence_tracker_; + bool hasBranchDivergence_; const RISCVSubtarget *getST() const { return ST; } const RISCVTargetLowering *getTLI() const { return TLI; } @@ -58,7 +61,9 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> { public: explicit RISCVTTIImpl(const RISCVTargetMachine *TM, const Function &F) : BaseT(TM, F.getDataLayout()), ST(TM->getSubtargetImpl(F)), - TLI(ST->getTargetLowering()) {} + TLI(ST->getTargetLowering()) + , divergence_tracker_(F) + , hasBranchDivergence_(ST->hasVendorXVortex()) {} bool areInlineCompatible(const Function *Caller, const Function *Callee) const; @@ -96,6 +101,10 @@ public: TargetTransformInfo::PopcntSupportKind getPopcntSupport(unsigned TyWidth); + // Vortex extension + bool isSourceOfDivergence(const Value *V); + bool hasBranchDivergence(const Function *F); + bool shouldExpandReduction(const IntrinsicInst *II) const; bool supportsScalableVectors() const { return ST->hasVInstructions(); } bool enableOrderedReductions() const { return true; } diff --git a/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp b/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95dddd477998370975c57a40ca849a69041809f1 --- /dev/null +++ b/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp @@ -0,0 +1,1225 @@ +#include "VortexBranchDivergence.h" + +#include "llvm/Support/Debug.h" +#include "RISCV.h" +#include "RISCVSubtarget.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/CodeGen/TargetPassConfig.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" + +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IntrinsicsRISCV.h" +#include "llvm/IR/PatternMatch.h" +#include 
"llvm/IR/InstIterator.h" +#include "llvm/IR/ModuleSlotTracker.h" +#include "llvm/IR/IRBuilder.h" + + +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/SSAUpdater.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" + + +#include "llvm/Analysis/UniformityAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/RegionInfo.h" +#include "llvm/Analysis/RegionIterator.h" +#include "llvm/Analysis/RegionPass.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/PostDominators.h" + +#include "llvm/CodeGen/MachineBasicBlock.h" +#include "llvm/CodeGen/MachineConstantPool.h" +#include "llvm/CodeGen/MachineDominators.h" +#include "llvm/CodeGen/MachineFrameInfo.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineOperand.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" + +#include <iostream> + +using namespace vortex; +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "vortex-branch-divergence" + +#ifndef NDEBUG +#define LLVM_DEBUG(x) do {x;} while (false) +#endif + +namespace vortex { + +class NamePrinter { +private: + std::unique_ptr<ModuleSlotTracker> MST_; + +public: + void init(Function* function) { + auto module = function->getParent(); + MST_ = std::make_unique<ModuleSlotTracker>(module); + MST_->incorporateFunction(*function); + } + + std::string ValueName(llvm::Value* V) { + std::string str("V."); + if (V->hasName()) { + str += std::string(V->getName().data(), V->getName().size()); + } else { + auto slot = MST_->getLocalSlot(V); + str += std::to_string(slot); + } + return str; + } + + std::string BBName(llvm::BasicBlock* BB) { + std::string str("BB."); + if (BB->hasName()) { + str += std::string(BB->getName().data(), BB->getName().size()); + } else { + auto slot = MST_->getLocalSlot(&BB->front()); + if (slot > 0) { + str += std::to_string(slot - 1); + } else { + str = ""; + } + } + return str; + } +}; + +static void FindSuccessor(DenseSet<BasicBlock *>& visited, + BasicBlock* current, + BasicBlock* target, + std::vector<BasicBlock*>& out) { + visited.insert(current); + auto branch = dyn_cast<BranchInst>(current->getTerminator()); + if (!branch) + return; + for (auto succ : branch->successors()) { + if (succ == target) { + out.push_back(current); + } else { + if (visited.count(succ) == 0) { + FindSuccessor(visited, succ, target, out); + } + } + } +} + +static void FindSuccessor(BasicBlock* start, BasicBlock* target, std::vector<BasicBlock*>& out) { + DenseSet<BasicBlock *> visited; + FindSuccessor(visited, start, target, out); +} + +class ReplaceSuccessor { +private: + DenseMap<std::pair<PHINode*, BasicBlock*>, PHINode*> phi_table_; + NamePrinter namePrinter_; + +public: + + void init(Function* function) { + namePrinter_.init(function); + phi_table_.clear(); + } + + bool replaceSuccessor(BasicBlock* BB, BasicBlock* oldSucc, BasicBlock* newSucc) { + auto branch = dyn_cast<BranchInst>(BB->getTerminator()); + if (branch) { + for (unsigned i = 0, n = branch->getNumSuccessors(); i < n; ++i) { + auto succ = branch->getSuccessor(i); + if (succ == oldSucc) { + LLVM_DEBUG(dbgs() << "****** replace " << namePrinter_.BBName(BB) << ".succ[" << i << "]: " << namePrinter_.BBName(oldSucc) << " with " << 
+          LLVM_DEBUG(dbgs() << "****** replace " << namePrinter_.BBName(BB) << ".succ[" << i << "]: "
+                            << namePrinter_.BBName(oldSucc) << " with " << namePrinter_.BBName(newSucc) << "\n");
+          branch->setSuccessor(i, newSucc);
+          this->replacePhiDefs(oldSucc, BB, newSucc);
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  void replacePhiDefs(BasicBlock* block, BasicBlock* oldPred, BasicBlock* newPred) {
+    // process all phi nodes in old successor
+    for (auto II = block->begin(), IE = block->end(); II != IE; ++II) {
+      PHINode *phi = dyn_cast<PHINode>(II);
+      if (!phi)
+        continue;
+
+      for (unsigned op = 0, nOps = phi->getNumOperands(); op != nOps; ++op) {
+        if (phi->getIncomingBlock(op) != oldPred)
+          continue;
+
+        PHINode* phi_stub;
+        auto key = std::make_pair(phi, newPred);
+        auto entry = phi_table_.find(key);
+        if (entry != phi_table_.end()) {
+          phi_stub = entry->second;
+        } else {
+          // create corresponding Phi node in new block
+          phi_stub = PHINode::Create(phi->getType(), 1, phi->getName(), &newPred->front());
+          phi_table_[key] = phi_stub;
+
+          // add new phi to successor's phi node
+          phi->addIncoming(phi_stub, newPred);
+        }
+
+        // move phi's operand into new phi node
+        Value *del_value = phi->removeIncomingValue(op);
+        phi_stub->addIncoming(del_value, oldPred);
+      }
+    }
+  }
+};
+
+static void InsertBasicBlock(const std::vector<BasicBlock*> BBs, BasicBlock* succBB, BasicBlock* newBB) {
+  DenseMap<std::pair<PHINode*, BasicBlock*>, PHINode*> phi_table;
+  for (auto BB : BBs) {
+    auto TI = BB->getTerminator();
+    TI->replaceSuccessorWith(succBB, newBB);
+    for (auto& I : *succBB) {
+      auto phi = dyn_cast<PHINode>(&I);
+      if (!phi)
+        continue;
+      for (unsigned op = 0, n = phi->getNumOperands(); op != n; ++op) {
+        if (phi->getIncomingBlock(op) != BB)
+          continue;
+        PHINode* phi_stub;
+        auto key = std::make_pair(phi, newBB);
+        auto entry = phi_table.find(key);
+        if (entry != phi_table.end()) {
+          phi_stub = entry->second;
+        } else {
+          // create corresponding Phi node in new block
+          phi_stub = PHINode::Create(phi->getType(), 1, phi->getName(), &newBB->front());
+          phi_table[key] = phi_stub;
+          // add new phi to successor's phi node
+          phi->addIncoming(phi_stub, newBB);
+        }
+        // move phi's operand into new phi node
+        auto value = phi->removeIncomingValue(op);
+        phi_stub->addIncoming(value, BB);
+      }
+    }
+  }
+}
+
+static BasicBlock* SplitBasicBlockBefore(BasicBlock* BB, BasicBlock::iterator I, const Twine &BBName) {
+  assert(BB->getTerminator() &&
+         "Can't use splitBasicBlockBefore on degenerate BB!");
+  assert(I != BB->end() &&
+         "Trying to get me to create degenerate basic block!");
+
+  assert((!isa<PHINode>(*I) || BB->getSinglePredecessor()) &&
+         "cannot split on multi incoming phis");
+
+  auto New = BasicBlock::Create(BB->getContext(), BBName, BB->getParent(), BB);
+  // Save DebugLoc of split point before invalidating iterator.
+  auto Loc = I->getDebugLoc();
+  // Move all of the specified instructions from the original basic block into
+  // the new basic block.
+  New->splice(New->end(), BB, I);
+
+  // Loop through all of the predecessors of the original block (which will be
+  // the predecessors of the New block), replace the specified successor with
+  // the New block, and update any PHI nodes in the original block. If there
+  // were PHI nodes in the original block, they are updated to reflect that
+  // the incoming branches now come from the New block rather than from the
+  // original block's predecessors.
+ SmallVector<BasicBlock *, 32> preds(predecessors(BB)); + for (auto Pred : preds) { + auto TI = Pred->getTerminator(); + TI->replaceSuccessorWith(BB, New); + BB->replacePhiUsesWith(Pred, New); + } + // Add a branch instruction from "New" to "this" Block. + auto BI = BranchInst::Create(BB, New); + BI->setDebugLoc(Loc); + + return New; +} + +/////////////////////////////////////////////////////////////////////////////// + +struct VortexBranchDivergence0 : public FunctionPass { +private: + UniformityInfo *UA_; + +public: + + static char ID; + + VortexBranchDivergence0(); + + StringRef getPassName() const override; + + void getAnalysisUsage(AnalysisUsage &AU) const override; + + bool runOnFunction(Function &F) override; +}; + +/////////////////////////////////////////////////////////////////////////////// + +class VortexBranchDivergence1 : public FunctionPass { +private: + + using StackEntry = std::pair<BasicBlock *, Value *>; + using StackVector = SmallVector<StackEntry, 16>; + + int divergenceMode_; + + ReplaceSuccessor replaceSuccessor_; + NamePrinter namePrinter_; + + std::vector<BasicBlock*> div_blocks_; + DenseSet<BasicBlock*> div_blocks_set_; + + std::vector<Loop*> loops_; + DenseSet<Loop*> loops_set_; + + UniformityInfo *UA_; + DominatorTree *DT_; + PostDominatorTree *PDT_; + LoopInfo *LI_; + RegionInfo *RI_; + + Type* SizeTTy_; + + Function *tmask_func_; + Function *pred_func_; + Function *pred_n_func_; + Function *tmc_func_; + Function *split_func_; + Function *split_n_func_; + Function *join_func_; + Function *mov_func_; + + void initialize(Function &F, const RISCVSubtarget &ST); + + void processBranches(LLVMContext* context, Function* function); + + void processLoops(LLVMContext* context, Function* function); + + bool isUniform(Instruction *T); + +public: + + static char ID; + + VortexBranchDivergence1(int divergenceMode = 0); + + StringRef getPassName() const override; + + void getAnalysisUsage(AnalysisUsage &AU) const override; + + bool runOnFunction(Function &F) override; +}; + +/////////////////////////////////////////////////////////////////////////////// + +struct VortexBranchDivergence2 : public MachineFunctionPass { +private: + int PassMode_; + +public: + static char ID; + VortexBranchDivergence2(int PassMode); + + bool runOnMachineFunction(MachineFunction &MF) override; + + StringRef getPassName() const override; +}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +namespace llvm { + +void initializeVortexBranchDivergence0Pass(PassRegistry &); +void initializeVortexBranchDivergence1Pass(PassRegistry &); + +FunctionPass *createVortexBranchDivergence0Pass() { + return new VortexBranchDivergence0(); +} + +FunctionPass *createVortexBranchDivergence1Pass(int divergenceMode) { + return new VortexBranchDivergence1(divergenceMode); +} + +FunctionPass *createVortexBranchDivergence2Pass(int PassMode) { + return new VortexBranchDivergence2(PassMode); +} + +} + +INITIALIZE_PASS_BEGIN(VortexBranchDivergence0, "vortex-branch-divergence-0", + "Vortex Branch Divergence Pre-Pass", false, false) +INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass) +INITIALIZE_PASS_END(VortexBranchDivergence0, "vortex-branch-divergence-0", + "Vortex Branch Divergence Pre-Pass", false, false) + +INITIALIZE_PASS_BEGIN(VortexBranchDivergence1, "vortex-branch-divergence-1", + "Vortex Branch Divergence", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(RegionInfoPass) 
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) +INITIALIZE_PASS_END(VortexBranchDivergence1, "vortex-branch-divergence-1", + "Vortex Branch Divergence", false, false) + +INITIALIZE_PASS(VortexBranchDivergence2, "VortexBranchDivergence-2", + "Vortex Branch Divergence Post-Pass", false, false) + +namespace vortex { + +char VortexBranchDivergence0::ID = 0; + +StringRef VortexBranchDivergence0::getPassName() const { + return "Vortex Unify Function Exit Nodes"; +} + +VortexBranchDivergence0::VortexBranchDivergence0() : FunctionPass(ID) { + initializeVortexBranchDivergence0Pass(*PassRegistry::getPassRegistry()); +} + +void VortexBranchDivergence0::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addPreservedID(BreakCriticalEdgesID); + AU.addPreservedID(LowerSwitchID); + AU.addRequired<UniformityInfoWrapperPass>(); + AU.addRequired<TargetPassConfig>(); + FunctionPass::getAnalysisUsage(AU); +} + +bool VortexBranchDivergence0::runOnFunction(Function &F) { + auto &Context = F.getContext(); + const auto &TPC = getAnalysis<TargetPassConfig>(); + const auto &TM = TPC.getTM<TargetMachine>(); + const auto &ST = TM.getSubtarget<RISCVSubtarget>(F); + + // Check if the Vortex extension is enabled + assert(ST.hasVendorXVortex()); + + LLVM_DEBUG(dbgs() << "*** Vortex Divergent Branch Handling Pass0 ***\n"); + + LLVM_DEBUG(dbgs() << "*** before Pass0 changes!\n" << F << "\n"); + + UA_ = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo(); + + bool changed = false; + + { + // Lower Select instructions into standard if-then-else branches + SmallVector <SelectInst*, 4> selects; + + for (auto I = inst_begin(F), E = inst_end(F); I != E; ++I) { + if (auto SI = dyn_cast<SelectInst>(&*I)) { + if (UA_->isUniform(SI)) + continue; + selects.emplace_back(SI); + } + } + + for (auto SI : selects) { + auto BB = SI->getParent(); + LLVM_DEBUG(dbgs() << "*** unswitching divergent select instruction: " << *SI << "\n"); + SplitBlockAndInsertIfThen(SI->getCondition(), SI, false); + auto CondBr = cast<BranchInst>(BB->getTerminator()); + auto ThenBB = CondBr->getSuccessor(0); + auto Phi = PHINode::Create(SI->getType(), 2, "unswitched.select", SI); + Phi->addIncoming(SI->getTrueValue(), ThenBB); + Phi->addIncoming(SI->getFalseValue(), BB); + SI->replaceAllUsesWith(Phi); + SI->eraseFromParent(); + changed = true; + } + } + + { + // Lower Min/Max intrinsics into standard if-then-else branches + SmallVector <MinMaxIntrinsic*, 4> MMs; + + for (auto I = inst_begin(F), E = inst_end(F); I != E; ++I) { + if (auto MMI = dyn_cast<MinMaxIntrinsic>(&*I)) { + if (UA_->isUniform(MMI)) + continue; + auto ID = MMI->getIntrinsicID(); + if (ID == Intrinsic::smin + || ID == Intrinsic::smax + || ID == Intrinsic::umin + || ID == Intrinsic::umax) { + MMs.emplace_back(MMI); + } + } + } + + for (auto MMI : MMs) { + LLVM_DEBUG(dbgs() << "*** unswitching divergent min/max instruction: " << *MMI << "\n"); + auto BB = MMI->getParent(); + auto ID = MMI->getIntrinsicID(); + auto LHS = MMI->getArgOperand(0); + auto RHS = MMI->getArgOperand(1); + IRBuilder<> Builder(MMI); + auto Cond = (ID == Intrinsic::smin || ID == Intrinsic::smax) + ? 
Builder.CreateICmpSLT(LHS, RHS)
+        : Builder.CreateICmpULT(LHS, RHS);
+      SplitBlockAndInsertIfThen(Cond, MMI, false);
+      auto CondBr = cast<BranchInst>(BB->getTerminator());
+      auto ThenBB = CondBr->getSuccessor(0);
+      auto Phi = PHINode::Create(MMI->getType(), 2, "unswitched.minmax", MMI);
+      if (ID == Intrinsic::smin || ID == Intrinsic::umin) {
+        Phi->addIncoming(LHS, ThenBB);
+        Phi->addIncoming(RHS, BB);
+      } else {
+        Phi->addIncoming(RHS, ThenBB);
+        Phi->addIncoming(LHS, BB);
+      }
+      MMI->replaceAllUsesWith(Phi);
+      MMI->eraseFromParent();
+      changed = true;
+    }
+  }
+
+  {
+    std::vector<BasicBlock*> ReturningBlocks;
+    std::vector<BasicBlock*> UnreachableBlocks;
+
+    for (auto& BB : F) {
+      if (isa<ReturnInst>(BB.getTerminator()))
+        ReturningBlocks.push_back(&BB);
+      else if (isa<UnreachableInst>(BB.getTerminator()))
+        UnreachableBlocks.push_back(&BB);
+    }
+
+    //
+    // Handle return blocks
+    //
+    BasicBlock* ReturnBlock = nullptr;
+    if (ReturningBlocks.empty()) {
+      ReturnBlock = nullptr;
+    } else if (ReturningBlocks.size() == 1) {
+      ReturnBlock = ReturningBlocks.front();
+    } else {
+      // Otherwise, fold all returns into a single exit block.
+      // We need to insert a new basic block into the function, add PHI
+      // nodes (if the function returns values), and convert all of the return
+      // instructions into unconditional branches.
+      BasicBlock *NewRetBlock = BasicBlock::Create(Context, "UnifiedReturnBlock", &F);
+
+      PHINode *PN = nullptr;
+      if (F.getReturnType()->isVoidTy()) {
+        ReturnInst::Create(Context, nullptr, NewRetBlock);
+      } else {
+        // If the function doesn't return void... add a PHI node to the block...
+        PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(), "UnifiedRetVal");
+        PN->insertInto(NewRetBlock, NewRetBlock->end());
+        ReturnInst::Create(Context, PN, NewRetBlock);
+      }
+
+      // Loop over all of the blocks, replacing the return instruction with an
+      // unconditional branch.
+      for (auto BB : ReturningBlocks) {
+        // Add an incoming element to the PHI node for every return instruction
+        // that is merging into this new block...
+        if (PN)
+          PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
+
+        BB->back().eraseFromParent(); // Remove the return insn
+        BranchInst::Create(NewRetBlock, BB);
+      }
+      ReturnBlock = NewRetBlock;
+      changed = true;
+    }
+
+    //
+    // Handle unreachable blocks
+    //
+    BasicBlock* UnreachableBlock = nullptr;
+    if (UnreachableBlocks.empty()) {
+      UnreachableBlock = nullptr;
+    } else if (UnreachableBlocks.size() == 1) {
+      UnreachableBlock = UnreachableBlocks.front();
+    } else {
+      UnreachableBlock = BasicBlock::Create(Context, "UnifiedUnreachableBlock", &F);
+      new UnreachableInst(Context, UnreachableBlock);
+      for (BasicBlock *BB : UnreachableBlocks) {
+        BB->back().eraseFromParent(); // Remove the unreachable inst.
+        auto Br = BranchInst::Create(UnreachableBlock, BB);
+        Br->setMetadata("Unreachable", MDNode::get(Context, std::nullopt));
+      }
+      changed = true;
+    }
+
+    // Ensure a single exit block
+    if (UnreachableBlock && ReturnBlock) {
+      auto NewRetBlock = BasicBlock::Create(Context, "UnifiedReturnAndUnreachableBlock", &F);
+      auto RetType = F.getReturnType();
+      PHINode* PN = nullptr;
+
+      if (!RetType->isVoidTy()) {
+        // Need to insert a PHI node to merge return values from incoming blocks
+        PN = PHINode::Create(RetType, ReturningBlocks.size(), "UnifiedReturnAndUnreachableVal");
+        PN->insertInto(NewRetBlock, NewRetBlock->end());
+
+        auto DummyRetValue = llvm::Constant::getNullValue(RetType);
+        PN->addIncoming(DummyRetValue, UnreachableBlock);
+
+        PN->addIncoming(ReturnBlock->getTerminator()->getOperand(0), ReturnBlock);
+      }
+
+      ReturnInst::Create(Context, PN, NewRetBlock);
+
+      UnreachableBlock->back().eraseFromParent();
+      auto Br = BranchInst::Create(NewRetBlock, UnreachableBlock);
+      Br->setMetadata("Unreachable", llvm::MDNode::get(Context, std::nullopt));
+
+      ReturnBlock->back().eraseFromParent();
+      BranchInst::Create(NewRetBlock, ReturnBlock);
+
+      ReturnBlock = NewRetBlock;
+      changed = true;
+    }
+  }
+
+  if (changed) {
+    LLVM_DEBUG(dbgs() << "*** after Pass0 changes!\n" << F << "\n");
+  }
+
+  return changed;
+}
+
+///////////////////////////////////////////////////////////////////////////////
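+
+// For illustration, the unswitching in runOnFunction above turns a divergent
+//   %r = select i1 %c, i32 %a, i32 %b
+// into an explicit diamond that StructurizeCFG and the split/join insertion
+// below can handle (sketch):
+//
+//   entry:  br i1 %c, label %then, label %tail
+//   then:   br label %tail
+//   tail:   %r = phi i32 [ %a, %then ], [ %b, %entry ]
+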
+char VortexBranchDivergence1::ID = 0;
+
+VortexBranchDivergence1::VortexBranchDivergence1(int divergenceMode)
+  : FunctionPass(ID)
+  , divergenceMode_(divergenceMode) {
+  initializeVortexBranchDivergence1Pass(*PassRegistry::getPassRegistry());
+}
+
+StringRef VortexBranchDivergence1::getPassName() const {
+  return "Vortex Branch Divergence";
+}
+
+void VortexBranchDivergence1::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<LoopInfoWrapperPass>();
+  AU.addRequired<RegionInfoPass>();
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addRequired<PostDominatorTreeWrapperPass>();
+  AU.addRequired<UniformityInfoWrapperPass>();
+  AU.addRequired<TargetPassConfig>();
+  FunctionPass::getAnalysisUsage(AU);
+}
+
+void VortexBranchDivergence1::initialize(Function &F, const RISCVSubtarget &ST) {
+  auto& M = *F.getParent();
+  auto& Context = M.getContext();
+
+  auto sizeTSize = M.getDataLayout().getPointerSizeInBits();
+  switch (sizeTSize) {
+  case 128: SizeTTy_ = llvm::Type::getInt128Ty(Context); break;
+  case 64:  SizeTTy_ = llvm::Type::getInt64Ty(Context); break;
+  case 32:  SizeTTy_ = llvm::Type::getInt32Ty(Context); break;
+  case 16:  SizeTTy_ = llvm::Type::getInt16Ty(Context); break;
+  case 8:   SizeTTy_ = llvm::Type::getInt8Ty(Context); break;
+  default:
+    LLVM_DEBUG(dbgs() << "Error: invalid pointer size: " << sizeTSize << "\n");
+  }
+
+  if (sizeTSize == 64) {
+    tmask_func_   = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_tmask_i64);
+    pred_func_    = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_pred_i64);
+    pred_n_func_  = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_pred_n_i64);
+    tmc_func_     = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_tmc_i64);
+    split_func_   = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_split_i64);
+    split_n_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_split_n_i64);
+    join_func_    = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_join_i64);
+    mov_func_     = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_mov_i64);
+  } else {
+    assert(sizeTSize == 32);
+    tmask_func_   = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_tmask_i32);
+    pred_func_    = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_pred_i32);
+    pred_n_func_  = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_pred_n_i32);
+    tmc_func_     = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_tmc_i32);
+    split_func_   = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_split_i32);
+    split_n_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_split_n_i32);
+    join_func_    = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_join_i32);
+    mov_func_     = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_mov_i32);
+  }
+
+  RI_ = &getAnalysis<RegionInfoPass>().getRegionInfo();
+  LI_ = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  UA_ = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
+  DT_ = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  PDT_= &getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
+
+  namePrinter_.init(&F);
+  replaceSuccessor_.init(&F);
+  div_blocks_.clear();
+  div_blocks_set_.clear();
+  loops_.clear();
+  loops_set_.clear();
+}
+
+bool VortexBranchDivergence1::runOnFunction(Function &F) {
+  const auto &TPC = getAnalysis<TargetPassConfig>();
+  const auto &TM = TPC.getTM<TargetMachine>();
+  const auto &ST = TM.getSubtarget<RISCVSubtarget>(F);
+
+  // Check if the Vortex extension is enabled
+  assert(ST.hasVendorXVortex());
+
+  this->initialize(F, ST);
+
+  auto &Context = F.getContext();
+
+  LLVM_DEBUG(dbgs() << "*** Region info:\n");
+  LLVM_DEBUG(RI_->getTopLevelRegion()->dump());
+  LLVM_DEBUG(dbgs() << "\n");
+
+  bool changed = false;
+
+  for (auto I = df_begin(&F.getEntryBlock()),
+            E = df_end(&F.getEntryBlock()); I != E; ++I) {
+    auto BB = *I;
+
+    auto Br = dyn_cast<BranchInst>(BB->getTerminator());
+    if (!Br)
+      continue;
+
+    // only process conditional branches
+    if (Br->isUnconditional()) {
+      LLVM_DEBUG(dbgs() << "*** skip non-conditional branch: " << namePrinter_.BBName(BB) << "\n");
+      continue;
+    }
+
+    // only process divergent branches
+    if (this->isUniform(Br)) {
+      LLVM_DEBUG(dbgs() << "*** skip uniform branch: " << namePrinter_.BBName(BB) << "\n");
+      continue;
+    }
+
+    auto loop = LI_->getLoopFor(BB);
+    if (loop) {
+      auto ipdom = PDT_->findNearestCommonDominator(Br->getSuccessor(0), Br->getSuccessor(1));
+      if (ipdom && loop->contains(ipdom)) {
+        if (div_blocks_set_.insert(BB).second) {
+          // add new branch to the list
+          LLVM_DEBUG(dbgs() << "*** divergent branch: " << namePrinter_.BBName(BB) << "\n");
+          div_blocks_.push_back(BB);
+        }
+      } else {
+        if (loops_set_.insert(loop).second) {
+          // add new loop to the list
+          LLVM_DEBUG(dbgs() << "*** divergent loop: " << namePrinter_.BBName(loop->getHeader()) << "\n");
+          loops_.push_back(loop);
+        }
+      }
+    } else {
+      auto ipdom = PDT_->findNearestCommonDominator(Br->getSuccessor(0), Br->getSuccessor(1));
+      if (ipdom == nullptr) {
+        llvm::errs() << "Warning: divergent branch with no IPDOM: " << namePrinter_.BBName(BB) << " --- skipping.\n";
+        continue;
+      }
+      bool has_unreachable = false;
+      for (auto succ : Br->successors()) {
+        if (succ->back().getMetadata("Unreachable") != nullptr) {
+          has_unreachable = true;
+          break;
+        }
+      }
+      if (has_unreachable) {
+        llvm::errs() << "Warning: divergent branch with unreachable successor: " << namePrinter_.BBName(BB) << " --- skipping.\n";
+        continue;
+      }
+      if (div_blocks_set_.insert(BB).second) {
+        // add new branch to the list
+        LLVM_DEBUG(dbgs() << "*** divergent branch: " << namePrinter_.BBName(BB) << ", IPDOM=" << namePrinter_.BBName(ipdom) << "\n");
+        div_blocks_.push_back(BB);
+      }
+    }
+  }
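+
+  // At this point every divergent conditional branch is classified: a branch
+  // whose immediate post-dominator stays inside its loop goes to div_blocks_
+  // and gets split/join treatment; a branch that leaves its loop marks the
+  // whole loop in loops_ and is handled with thread predicates in
+  // processLoops() below.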
!div_blocks_.empty()) { + LLVM_DEBUG(dbgs() << "*** before changes!\n" << F << "\n"); + + // process the loops + // This must be done first so that the loop analysis is not tampered with + if (!loops_.empty()) { + this->processLoops(&Context, &F); + loops_.clear(); + // update PDT + PDT_->recalculate(F); + } + + // process branches + if (!div_blocks_.empty()) { + this->processBranches(&Context, &F); + div_blocks_.clear(); + } + + changed = true; + + LLVM_DEBUG(dbgs() << "*** after changes!\n" << F << "\n"); + } + + return changed; +} + +void VortexBranchDivergence1::processLoops(LLVMContext* context, Function* function) { + DenseSet<const BasicBlock *> stub_blocks; + + // traverse the list in reverse order + for (auto it = loops_.rbegin(), ite = loops_.rend(); it != ite; ++it) { + auto loop = *it; + auto header = loop->getHeader(); + + auto preheader = loop->getLoopPreheader(); + assert(preheader); + + auto preheader_term = preheader->getTerminator(); + assert(preheader_term); + + auto preheader_br = dyn_cast<BranchInst>(preheader_term); + assert(preheader_br); + + LLVM_DEBUG(dbgs() << "*** process loop: " << namePrinter_.BBName(header) << "\n"); + + // save the current thread mask in the preheader + auto tmask = CallInst::Create(tmask_func_, "tmask", preheader_br); + LLVM_DEBUG(dbgs() << "*** backup thread mask '" << namePrinter_.ValueName(tmask) << "' before loop preheader branch: " << namePrinter_.BBName(preheader) << "\n"); + + // restore the thread mask at loop exit blocks + { + SmallVector<BasicBlock *, 8> exiting_blocks; + loop->getExitingBlocks(exiting_blocks); // blocks inside the loop with an exit edge + + for (auto exiting_block : exiting_blocks) { + int exit_edges = 0; + auto branch = dyn_cast<BranchInst>(exiting_block->getTerminator()); + for (auto succ : branch->successors()) { + // Stub block insertion can generate invalid exiting blocks; + // we just need to exclude those new blocks. 
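+ // Note: the thread mask captured in the preheader is passed to vx_pred / + // vx_pred_n below as the restore mask, so the full mask is reinstated + // once every thread has taken the exit edge.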
+ if (loop->contains(succ) + || stub_blocks.count(succ) != 0) + continue; + + if (branch->isUnconditional()) + continue; + + // ensure only one exit edge + assert(exit_edges == 0); + ++exit_edges; + + // insert a predicate instruction to mask out threads that are exiting the loop + + IRBuilder<> ir_builder(branch); + auto succ0 = branch->getSuccessor(0); + auto cond_orig = branch->getCondition(); + auto cond_orig_i1 = ir_builder.CreateICmpNE(cond_orig, ConstantInt::get(cond_orig->getType(), 0), namePrinter_.ValueName(cond_orig) + ".to.i1"); + auto cond_orig_i32 = ir_builder.CreateIntCast(cond_orig_i1, SizeTTy_, false, namePrinter_.ValueName(cond_orig_i1) + ".to.i32"); + + // insert a custom mov instruction to prevent the branch condition from being optimized away during codegen + auto cond = CallInst::Create(mov_func_, cond_orig_i32, namePrinter_.ValueName(cond_orig_i32) + ".mov", branch); + + LLVM_DEBUG(dbgs() << "*** insert thread predicate '" << namePrinter_.ValueName(cond) << "' before exiting block: " << namePrinter_.BBName(exiting_block) << "\n"); + if (!loop->contains(succ0)) { + CallInst::Create(pred_n_func_, {cond, tmask}, "", branch); + } else { + CallInst::Create(pred_func_, {cond, tmask}, "", branch); + } + LLVM_DEBUG(dbgs() << "*** after predicate change!\n" << *function << "\n"); + + // change the branch condition + auto cond_i1 = ir_builder.CreateICmpNE(cond, ConstantInt::get(SizeTTy_, 0), namePrinter_.ValueName(cond) + ".to.i1"); + branch->setCondition(cond_i1); + } + } + } + } +} + +void VortexBranchDivergence1::processBranches(LLVMContext* context, Function* function) { + std::unordered_map<BasicBlock*, BasicBlock*> ipdoms; + + // pre-gather ipdoms for divergent branches + for (auto BI = div_blocks_.rbegin(), BIE = div_blocks_.rend(); BI != BIE; ++BI) { + auto block = *BI; + auto branch = dyn_cast<BranchInst>(block->getTerminator()); + assert(branch); + auto ipdom = PDT_->findNearestCommonDominator(branch->getSuccessor(0), branch->getSuccessor(1)); + if (ipdom == nullptr) { + llvm::errs() << "error: divergent branch with no IPDOM: " << namePrinter_.BBName(block) << "\n"; + std::abort(); + } + ipdoms[block] = ipdom; + } + + // traverse the list in reverse order + for (auto BI = div_blocks_.rbegin(), BIE = div_blocks_.rend(); BI != BIE; ++BI) { + auto block = *BI; + auto ipdom = ipdoms[block]; + auto branch = dyn_cast<BranchInst>(block->getTerminator()); + assert(branch); +#ifndef NDEBUG + auto region = RI_->getRegionFor(block); + LLVM_DEBUG(dbgs() << "*** process branch " << namePrinter_.BBName(block) << ", region=" << region->getNameStr() << "\n"); +#endif + // insert a mov instruction before the split + IRBuilder<> ir_builder(branch); + auto cond_orig = branch->getCondition(); + auto cond_orig_i1 = ir_builder.CreateICmpNE(cond_orig, ConstantInt::get(cond_orig->getType(), 0), namePrinter_.ValueName(cond_orig) + ".to.i1"); + auto cond_orig_i32 = ir_builder.CreateIntCast(cond_orig_i1, SizeTTy_, false, namePrinter_.ValueName(cond_orig_i1) + ".to.i32"); + auto cond = CallInst::Create(mov_func_, cond_orig_i32, namePrinter_.ValueName(cond_orig_i32) + ".mov", branch); + + // insert a split instruction before the divergent branch + LLVM_DEBUG(dbgs() << "*** insert split '" << namePrinter_.ValueName(cond) << "' before " << namePrinter_.BBName(block) << "'s branch.\n"); + auto stack_ptr = CallInst::Create(split_func_, cond, "", branch); + + // change the branch condition + auto cond_i1 = ir_builder.CreateICmpNE(cond, ConstantInt::get(SizeTTy_, 0), namePrinter_.ValueName(cond) + ".to.i1"); + 
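// vx_split returns a stack token that the matching vx_join inserted at the + // IPDOM consumes to reconverge the divergent threads. + 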
branch->setCondition(cond_i1); + + // insert a join stub block before the ipdom + auto stub = BasicBlock::Create(*context, "join_stub", function, ipdom); + LLVM_DEBUG(dbgs() << "*** insert join stub '" << stub->getName() << "' before " << namePrinter_.BBName(ipdom) << "\n"); + auto stub_br = BranchInst::Create(ipdom, stub); + CallInst::Create(join_func_, stack_ptr, "", stub_br); + std::vector<BasicBlock*> preds; + FindSuccessor(block, ipdom, preds); + for (auto pred : preds) { + bool found = replaceSuccessor_.replaceSuccessor(pred, ipdom, stub); + if (!found) { + llvm::errs() << "error: could not rewire predecessor to join stub!\n"; + std::abort(); + } + } + } +} + +bool VortexBranchDivergence1::isUniform(Instruction *I) { + return UA_->isUniform(I) + || (I->getMetadata("structurizecfg.uniform") != nullptr); +} + +/////////////////////////////////////////////////////////////////////////////// + +DivergenceTracker::DivergenceTracker(const Function &function) + : function_(&function) + , initialized_(false) +{} + +void DivergenceTracker::initialize() { + LLVM_DEBUG(dbgs() << "*** DivergenceTracker::initialize(): " << function_->getName() << "\n"); + + initialized_ = true; + + DenseSet<const Value *> dv_annotations; + DenseSet<const Value *> uv_annotations; + + auto module = function_->getParent(); + + // Mark all TLS globals as divergent + for (auto& GV : module->globals()) { + if (GV.isThreadLocal()) { + if (dv_nodes_.insert(&GV).second) { + LLVM_DEBUG(dbgs() << "*** divergent global variable: " << GV.getName() << "\n"); + } + } + } + + for (auto& BB : *function_) { + for (auto& I : BB) { + //LLVM_DEBUG(dbgs() << "*** instruction: opcode=" << I.getOpcodeName() << ", name=" << I.getName() << "\n"); + if (I.getMetadata("vortex.uniform") != nullptr) { + uv_annotations.insert(&I); + uv_nodes_.insert(&I); + LLVM_DEBUG(dbgs() << "*** uniform annotation: " << I.getName() << "\n"); + } else + if (I.getMetadata("vortex.divergent") != nullptr) { + dv_annotations.insert(&I); + dv_nodes_.insert(&I); + LLVM_DEBUG(dbgs() << "*** divergent annotation: " << I.getName() << "\n"); + } else + if (auto II = dyn_cast<llvm::IntrinsicInst>(&I)) { + if (II->getIntrinsicID() == llvm::Intrinsic::var_annotation) { + auto gv = dyn_cast<GlobalVariable>(II->getOperand(1)); + if (gv == nullptr || !gv->hasInitializer()) + continue; + auto cda = dyn_cast<ConstantDataArray>(gv->getInitializer()); + if (cda == nullptr) + continue; + if (cda->getAsCString() == "vortex.uniform") { + Value* var_src = nullptr; + auto var = II->getOperand(0); + if (auto AI = dyn_cast<AllocaInst>(var)) { + var_src = AI; + LLVM_DEBUG(dbgs() << "*** uniform annotation: " << AI->getName() << ".src(" << var_src << ")\n"); + } else + if (auto CI = dyn_cast<CastInst>(var)) { + var_src = CI->getOperand(0); + LLVM_DEBUG(dbgs() << "*** uniform annotation: " << CI->getName() << ".src(" << var_src << ")\n"); + } + if (var_src) { + uv_annotations.insert(var_src); + uv_nodes_.insert(var_src); + } + } else + if (cda->getAsCString() == "vortex.divergent") { + Value* var_src = nullptr; + auto var = II->getOperand(0); + if (auto AI = dyn_cast<AllocaInst>(var)) { + var_src = AI; + LLVM_DEBUG(dbgs() << "*** divergent annotation: " << AI->getName() << ".src(" << var_src << ")\n"); + } else + if (auto CI = dyn_cast<CastInst>(var)) { + var_src = CI->getOperand(0); + LLVM_DEBUG(dbgs() << "*** divergent annotation: " << CI->getName() << ".src(" << var_src << ")\n"); + } + if (var_src) { + dv_annotations.insert(var_src); + dv_nodes_.insert(var_src); + } + } + } + } + } + } + + // Propagate annotations through GEPs and to the values of annotated stores + for (auto& BB : *function_) { + for (auto& I : BB) { + if (auto GE = dyn_cast<GetElementPtrInst>(&I)) { + auto addr = GE->getPointerOperand(); + if 
(uv_annotations.count(addr) != 0) { + LLVM_DEBUG(dbgs() << "*** uniform annotation: " << GE->getName() << "\n"); + uv_nodes_.insert(GE); + } else + if (dv_annotations.count(addr) != 0) { + LLVM_DEBUG(dbgs() << "*** divergent annotation: " << GE->getName() << "\n"); + dv_nodes_.insert(GE); + } + } else + if (auto SI = dyn_cast<StoreInst>(&I)) { + auto addr = SI->getPointerOperand(); + if (uv_annotations.count(addr) != 0) { + auto value = SI->getValueOperand(); + if (auto CI = dyn_cast<CastInst>(value)) { + LLVM_DEBUG(dbgs() << "*** uniform annotation: " << CI->getName() << ".src\n"); + auto src = CI->getOperand(0); + uv_nodes_.insert(src); + } else { + LLVM_DEBUG(dbgs() << "*** uniform annotation: " << SI->getName() << ".value\n"); + uv_nodes_.insert(value); + } + } else + if (dv_annotations.count(addr) != 0) { + auto value = SI->getValueOperand(); + if (auto CI = dyn_cast<CastInst>(value)) { + LLVM_DEBUG(dbgs() << "*** divergent annotation: " << CI->getName() << ".src\n"); + auto src = CI->getOperand(0); + dv_nodes_.insert(src); + } else { + LLVM_DEBUG(dbgs() << "*** divergent annotation: " << SI->getName() << ".value\n"); + dv_nodes_.insert(value); + } + } + } + } + } +} + +bool DivergenceTracker::eval(const Value *V) { + if (!initialized_) { + this->initialize(); + } + + // Mark annotated uniform variables + if (uv_nodes_.count(V) != 0) { + LLVM_DEBUG(dbgs() << "*** uniform annotated variable: " << V->getName() << "\n"); + return false; + } + + // Mark annotated divergent variables + if (dv_nodes_.count(V) != 0) { + LLVM_DEBUG(dbgs() << "*** divergent annotated variable: " << V->getName() << "\n"); + return true; + } + + // We conservatively assume that all function arguments are potentially divergent + if (isa<Argument>(V)) { + LLVM_DEBUG(dbgs() << "*** divergent function argument: " << V->getName() << "\n"); + return true; + } + + // We conservatively assume that function return values are divergent + if (isa<CallInst>(V)) { + LLVM_DEBUG(dbgs() << "*** divergent return variable: " << V->getName() << "\n"); + return true; + } + + // Atomics are divergent because they execute sequentially: when each thread + // performs an atomic operation on the same address, every thread after the + // first observes the value written by the previous one, so the returned + // values differ across threads. 
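+ // e.g., every thread executing 'old = atomicAdd(&counter, 1)' on the same + // counter receives a different 'old' value.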
+ if (isa<AtomicRMWInst>(V) + || isa<AtomicCmpXchgInst>(V)) { + LLVM_DEBUG(dbgs() << "*** divergent atomic variable: " << V->getName() << "\n"); + return true; + } + + // Mark loads from divergent addresses as divergent + if (auto LD = dyn_cast<LoadInst>(V)) { + auto addr = LD->getPointerOperand(); + if (dv_nodes_.count(addr) != 0) { + LLVM_DEBUG(dbgs() << "*** divergent load variable: " << V->getName() << "\n"); + return true; + } + } + + return false; +} + +/////////////////////////////////////////////////////////////////////////////// + +VortexBranchDivergence2::VortexBranchDivergence2(int PassMode) + : MachineFunctionPass(ID) + , PassMode_(PassMode) +{} + +static bool FindNextJoin(MachineBasicBlock::iterator* out, + const MachineBasicBlock::iterator& start, + const MachineBasicBlock& curMBB) { + for (auto it = start; it != curMBB.end(); ++it) { + if (it->getOpcode() == RISCV::VX_JOIN) { + *out = it; + return true; + } + } + if (curMBB.succ_size() == 1) { + auto succMBB = *curMBB.succ_begin(); + return FindNextJoin(out, succMBB->begin(), *succMBB); + } + return false; +} + +bool VortexBranchDivergence2::runOnMachineFunction(MachineFunction &MF) { + const auto &ST = MF.getSubtarget<RISCVSubtarget>(); + auto TII = ST.getInstrInfo(); + auto& MRI = MF.getRegInfo(); + + // The Vortex extension must be enabled + assert(ST.hasVendorXVortex()); + + bool Changed = false; + + switch (PassMode_) { + case 0: + for (auto& MBB : MF) { + for (auto _MII = MBB.instr_begin(), MIIEnd = MBB.instr_end(); _MII != MIIEnd;) { + auto MII = _MII++; + auto& MI = *MII; + if (MI.getOpcode() == RISCV::VX_MOV) { + auto DestReg = MI.getOperand(0).getReg(); + auto SrcReg = MI.getOperand(1).getReg(); + MRI.replaceRegWith(DestReg, SrcReg); + MI.eraseFromParent(); + Changed = true; + } + } + } + break; + + case 1: + for (auto& MBB : MF) { + for (auto _MII = MBB.instr_begin(), MIIEnd = MBB.instr_end(); _MII != MIIEnd;) { + auto MII = _MII++; + auto& MI = *MII; + if (!(MI.getOpcode() == RISCV::VX_SPLIT + || MI.getOpcode() == RISCV::VX_SPLIT_N)) + continue; + + // find the corresponding branch instruction + auto MII_br = MII; + for (; MII_br != MIIEnd; ++MII_br) { + if (MII_br->isBranch()) + break; + } + + if (MII_br == MIIEnd + || MII_br->getOpcode() == RISCV::PseudoBR) { + // If a join instruction is found in the same or a succeeding fall-through block, + // the protected branch was removed by earlier optimization passes, + // so we can safely remove the leftover split and join instructions. + MachineBasicBlock::iterator MII_join; + if (FindNextJoin(&MII_join, std::next(MII), MBB)) { + if (_MII == MII_join) { + ++_MII; + } + MII_join->eraseFromParent(); + MI.eraseFromParent(); + LLVM_DEBUG(dbgs() << "*** Vortex: cleanup removed branches!\n"); + Changed = true; + continue; + } + + llvm::errs() << "error: missing divergent branch!\n" << MBB << "\n"; + std::abort(); + } + + // ensure the branch is BEQ/BNE rs, x0 + if (!(MII_br->getOpcode() == RISCV::BEQ + || MII_br->getOpcode() == RISCV::BNE) + || !MII_br->getOperand(0).isReg() + || !MII_br->getOperand(1).isReg() + || MII_br->getOperand(1).getReg() != RISCV::X0) { + llvm::errs() << "error: unsupported divergent branch!\n" << MBB << "\n"; + std::abort(); + } + + // match the split polarity to the branch opcode + if (MII_br->getOpcode() == RISCV::BEQ) { + switch (MI.getOpcode()) { + case RISCV::VX_SPLIT: + MI.setDesc(TII->get(RISCV::VX_SPLIT_N)); + break; + case RISCV::VX_SPLIT_N: + MI.setDesc(TII->get(RISCV::VX_SPLIT)); + break; + } + LLVM_DEBUG(dbgs() << "*** Vortex: fixed predicate opcode!\n"); + 
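// note: BEQ rs, x0 takes the branch when the condition is zero, the + // inverse of the sense the IR-level split was emitted for, hence the + // polarity swap above. + 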
Changed = true; + continue; + } + } + } + break; + } + + if (Changed) { + LLVM_DEBUG(dbgs() << "*** after changes!\n" << MF.getName() << "\n"); + LLVM_DEBUG(MF.dump()); + } + + return Changed; +} + +StringRef VortexBranchDivergence2::getPassName() const { + return "VortexBranchDivergence2"; +} + +char VortexBranchDivergence2::ID = 0; + +} // namespace vortex diff --git a/llvm/lib/Target/RISCV/VortexBranchDivergence.h b/llvm/lib/Target/RISCV/VortexBranchDivergence.h new file mode 100644 index 0000000000000000000000000000000000000000..f0848d9987669243c788fbdb1118262b0985ea81 --- /dev/null +++ b/llvm/lib/Target/RISCV/VortexBranchDivergence.h @@ -0,0 +1,24 @@ +#pragma once + +#include "llvm/ADT/DenseSet.h" +#include "llvm/IR/Value.h" + +namespace vortex { +using namespace llvm; + +class DivergenceTracker { +public: + DivergenceTracker(const Function &function); + + bool eval(const Value *V); + +private: + void initialize(); + + DenseSet<const Value *> dv_nodes_; + DenseSet<const Value *> uv_nodes_; + const Function* function_; + bool initialized_; +}; + +} // namespace vortex \ No newline at end of file diff --git a/llvm/lib/Target/RISCV/VortexIntrinsicFunc.cpp b/llvm/lib/Target/RISCV/VortexIntrinsicFunc.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ede9f1f23d909dc56a8dda8743f80f29e27fa41 --- /dev/null +++ b/llvm/lib/Target/RISCV/VortexIntrinsicFunc.cpp @@ -0,0 +1,246 @@ +#include "RISCV.h" +#include "RISCVSubtarget.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsRISCV.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Value.h" + +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/InitializePasses.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/ValueTypes.h" +#include "llvm/ADT/APInt.h" + +#include <iostream> +#include <set> +#include <string> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "vortex-intrinsic-lowering" + +void printIR(llvm::Module *module_) { + std::string module_str; + llvm::raw_string_ostream ostream{module_str}; + module_->print(ostream, nullptr, false); + std::cout << module_str << std::endl; +} + +class VortexIntrinsicFuncLowering final : public ModulePass { + + bool runOnModule(Module &M) override; + bool Modified; + + public: + static char ID; + VortexIntrinsicFuncLowering(); +}; + +namespace llvm { + + void initializeVortexIntrinsicFuncLoweringPass(PassRegistry &); + ModulePass *createVortexIntrinsicFuncLoweringPass() { + return new VortexIntrinsicFuncLowering(); + } +} // End namespace llvm + +INITIALIZE_PASS(VortexIntrinsicFuncLowering, DEBUG_TYPE, + "Lower Vortex intrinsic functions", false, false) + +char VortexIntrinsicFuncLowering::ID = 0; + +VortexIntrinsicFuncLowering::VortexIntrinsicFuncLowering() : ModulePass(ID) { + initializeVortexIntrinsicFuncLoweringPass(*PassRegistry::getPassRegistry()); +} + +int CheckFTarget(const std::vector<StringRef> &FTargets, StringRef fname) { + for (size_t i = 0; i < FTargets.size(); i++) { + if (FTargets[i].equals(fname)) { + return static_cast<int>(i); + } + } + return -1; +} + +bool VortexIntrinsicFuncLowering::runOnModule(Module &M) { + Modified = false; + LLVM_DEBUG(dbgs() << "VORTEX Intrinsic Func pass\n"); + + 
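// Collect the declarations and call sites to rewrite up front; they are + // erased only after the module walk below so that the iterators stay valid. + 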
std::set<llvm::Function *> DeclToRemove; + std::set<llvm::Instruction *> CallToRemove; + std::set<llvm::Instruction *> vxBarCallToRemove; + std::vector<StringRef> FTargets = { + "vx_barrier", + "vx_num_threads", "vx_num_warps", "vx_num_cores", + "vx_thread_id", "vx_warp_id", "vx_core_id", + "vx_thread_mask", "vx_tmc"}; + + auto sizeTSize = M.getDataLayout().getPointerSizeInBits(); + + Function *bar_func_; + Function *tid_func_; + Function *wid_func_; + Function *cid_func_; + Function *nt_func_; + Function *nw_func_; + Function *nc_func_; + Function *tmask_func_; + Function *tmc_func_; + + if (sizeTSize == 64) { + bar_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_bar_i64); + tid_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_tid_i64); + wid_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_wid_i64); + cid_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_cid_i64); + nt_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_nt_i64); + nw_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_nw_i64); + nc_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_nc_i64); + tmask_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_tmask_i64); + tmc_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_tmc_i64); + } else { + assert(sizeTSize == 32); + bar_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_bar_i32); + tid_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_tid_i32); + wid_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_wid_i32); + cid_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_cid_i32); + nt_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_nt_i32); + nw_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_nw_i32); + nc_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_nc_i32); + tmask_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_tmask_i32); + tmc_func_ = Intrinsic::getDeclaration(&M, Intrinsic::riscv_vx_tmc_i32); + } + + // Find the target vx intrinsic calls and rewrite them + for (llvm::Module::iterator I = M.begin(), E = M.end(); I != E; ++I) { + llvm::Function *F = &*I; + if (F->isDeclaration()) { + int check = CheckFTarget(FTargets, F->getName()); + if (check != -1) + DeclToRemove.insert(F); + continue; + } + for (Function::iterator FI = F->begin(), FE = F->end(); FI != FE; ++FI) { + for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; + ++BI) { + Instruction *Instr = &*BI; + CallInst *CallInstr = dyn_cast<CallInst>(Instr); + if (CallInstr == nullptr) + continue; + Function *Callee = CallInstr->getCalledFunction(); + if (Callee == nullptr) + continue; + + int check = CheckFTarget(FTargets, Callee->getName()); + + if (check == 0) { // vx_barrier: deferred; rewritten with a unique ID below + vxBarCallToRemove.insert(Instr); + + } else if (check == 1) { // vx_num_threads + auto ntinst = CallInst::Create(nt_func_, "nt", Instr); + Instr->replaceAllUsesWith(ntinst); + CallToRemove.insert(Instr); + + } else if (check == 2) { // vx_num_warps + auto nwinst = CallInst::Create(nw_func_, "nw", Instr); + Instr->replaceAllUsesWith(nwinst); + CallToRemove.insert(Instr); + + } else if 
(check == 3) { // vx_num_cores + auto ncinst = CallInst::Create(nc_func_, "nc", Instr); + Instr->replaceAllUsesWith(ncinst); + CallToRemove.insert(Instr); + + } else if (check == 4) { // vx_thread_id + auto tidinst = CallInst::Create(tid_func_, "tid", Instr); + Instr->replaceAllUsesWith(tidinst); + CallToRemove.insert(Instr); + + } else if (check == 5) { // vx_warp_id + auto widinst = CallInst::Create(wid_func_, "wid", Instr); + Instr->replaceAllUsesWith(widinst); + CallToRemove.insert(Instr); + + } else if (check == 6) { // vx_core_id + auto cidinst = CallInst::Create(cid_func_, "cid", Instr); + Instr->replaceAllUsesWith(cidinst); + CallToRemove.insert(Instr); + + } else if (check == 7) { // vx_thread_mask + auto tmaskinst = CallInst::Create(tmask_func_, "tmask", Instr); + Instr->replaceAllUsesWith(tmaskinst); + CallToRemove.insert(Instr); + + } else if (check == 8) { // vx_tmc + auto tmask = CallInstr->getArgOperand(0); + auto tmcinst = CallInst::Create(tmc_func_, {tmask}, "", Instr); + Instr->replaceAllUsesWith(tmcinst); + CallToRemove.insert(Instr); + } + } // end of BB loop + } // end of F loop + } // end of M loop + + // Rewrite vx_barrier(id, count) calls, giving each call site a unique barrier ID + if (!vxBarCallToRemove.empty()) { + int barCnt = 1; + for (auto B : vxBarCallToRemove) { + + CallInst *Callinst = dyn_cast<CallInst>(B); + LLVMContext &context = M.getContext(); + auto barID = + llvm::ConstantInt::get(context, llvm::APInt(32, (barCnt++), false)); + auto barSize = Callinst->getArgOperand(1); + auto barinst = CallInst::Create(bar_func_, {barID, barSize}, "", B); + B->replaceAllUsesWith(barinst); + } + Modified = true; + } + + for (auto B : vxBarCallToRemove) { + B->eraseFromParent(); + } + + if (!CallToRemove.empty()) + Modified = true; + + for (auto B : CallToRemove) { + B->eraseFromParent(); + } + + for (auto F : DeclToRemove) { + F->eraseFromParent(); + } + LLVM_DEBUG(printIR(&M)); + return Modified; +} diff --git a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp index 9c711ec183821ff8d9319f2035f4a434f2225e48..f7ef7f83070e7bac912c1b5a045c4a4be2435fca 100644 --- a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -54,6 +54,12 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "structurizecfg" +// Force-enable LLVM_DEBUG output for this file in asserts builds +#ifndef NDEBUG +#undef LLVM_DEBUG +#define LLVM_DEBUG(x) do {x;} while (false) +#endif + // The name for newly created blocks. 
const char FlowBlockName[] = "Flow"; @@ -321,16 +325,20 @@ public: void init(Region *R); bool run(Region *R, DominatorTree *DT); bool makeUniformRegion(Region *R, UniformityInfo &UA); + bool skipRegionalBranches(Region *R, UniformityInfo &UA); }; class StructurizeCFGLegacyPass : public RegionPass { bool SkipUniformRegions; + bool SkipRegionalBranches; public: static char ID; - explicit StructurizeCFGLegacyPass(bool SkipUniformRegions_ = false) - : RegionPass(ID), SkipUniformRegions(SkipUniformRegions_) { + explicit StructurizeCFGLegacyPass(bool SkipUniformRegions_ = false, + bool SkipRegionalBranches_ = false) + : RegionPass(ID), SkipUniformRegions(SkipUniformRegions_) + , SkipRegionalBranches(SkipRegionalBranches_) { if (ForceSkipUniformRegions.getNumOccurrences()) SkipUniformRegions = ForceSkipUniformRegions.getValue(); initializeStructurizeCFGLegacyPassPass(*PassRegistry::getPassRegistry()); @@ -345,6 +353,11 @@ public: if (SCFG.makeUniformRegion(R, UA)) return false; } + if (SkipRegionalBranches) { + auto &UA = getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo(); + if (SCFG.skipRegionalBranches(R, UA)) + return false; + } DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); return SCFG.run(R, DT); } @@ -352,7 +365,7 @@ public: StringRef getPassName() const override { return "Structurize control flow"; } void getAnalysisUsage(AnalysisUsage &AU) const override { - if (SkipUniformRegions) + if (SkipUniformRegions || SkipRegionalBranches) AU.addRequired<UniformityInfoWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); @@ -1164,6 +1177,37 @@ bool StructurizeCFG::makeUniformRegion(Region *R, UniformityInfo &UA) { return false; } +bool StructurizeCFG::skipRegionalBranches(Region *R, UniformityInfo &UA) { + // Only structurize regions that contain divergent, non-regional branches + for (auto E : R->elements()) { + if (E->isSubRegion()) + continue; + + auto BB = E->getEntry(); + auto Br = dyn_cast<BranchInst>(BB->getTerminator()); + + // Skip blocks that do not end in a conditional branch + if (Br == nullptr || !Br->isConditional()) + continue; + + // Skip uniform conditional branches + if (UA.isUniform(Br)) + continue; + + // Skip divergent branches that are regional + if (BB == R->getEntry()) { + LLVM_DEBUG(dbgs() << "*** structurize: skip divergent complete region " << *R << "\n"); + continue; + } + + // This basic block ends in a divergent, non-regional branch + LLVM_DEBUG(dbgs() << "*** structurize: divergent non-regional block: " << *R << "\n"); + return false; + } + + return true; +} + /// Run the transformation for each region found bool StructurizeCFG::run(Region *R, DominatorTree *DT) { if (R->isTopLevelRegion()) @@ -1202,8 +1246,9 @@ bool StructurizeCFG::run(Region *R, DominatorTree *DT) { return true; } -Pass *llvm::createStructurizeCFGPass(bool SkipUniformRegions) { - return new StructurizeCFGLegacyPass(SkipUniformRegions); +Pass *llvm::createStructurizeCFGPass(bool SkipUniformRegions, + bool SkipRegionalBranches) { + return new StructurizeCFGLegacyPass(SkipUniformRegions, SkipRegionalBranches); } static void addRegionIntoQueue(Region &R, std::vector<Region *> &Regions) {