diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp index 6b1e76192e72859d74015a13fff6c038cf16f861..a45c64500d5b0fe063c0d9bec049499f35694336 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp @@ -128,13 +128,13 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() { initializeRISCVPreLegalizerCombinerPass(*PR); initializeRISCVPostLegalizerCombinerPass(*PR); initializeKCFIPass(*PR); - if (VortexBranchDivergenceMode != 0) { - gVortexBranchDivergenceMode = 1; + gVortexBranchDivergenceMode = VortexBranchDivergenceMode; + if (gVortexBranchDivergenceMode != 0) { initializeVortexBranchDivergence0Pass(*PR); initializeVortexBranchDivergence1Pass(*PR); initializeVortexBranchDivergence2Pass(*PR); } - if (VortexKernelSchedulerMode != 0) { + if (gVortexBranchDivergenceMode != 0) { initializeVortexIntrinsicFuncLoweringPass(*PR); } initializeRISCVDeadRegisterDefinitionsPass(*PR); @@ -193,8 +193,8 @@ RISCVTargetMachine::RISCVTargetMachine(const Target &T, const Triple &TT, setMachineOutliner(true); setSupportsDefaultOutlining(true); - auto isVortex = FS.contains("vortex"); - if (isVortex) { + if (FS.contains("vortex") + && gVortexBranchDivergenceMode != 0) { setRequiresStructuredCFG(true); } @@ -269,7 +269,7 @@ RISCVTargetMachine::getSubtargetImpl(const Function &F) const { auto TargetABI = RISCVABI::getTargetABI(ABIName); if (TargetABI != RISCVABI::ABI_Unknown && ModuleTargetABI->getString() != ABIName) { - report_fatal_error("-target-abi option != target-abi module flag"); + report_fatal_error("-target-abi option != target-abi module flag: " + ModuleTargetABI->getString() + " vs " + ABIName); } ABIName = ModuleTargetABI->getString(); } @@ -478,7 +478,7 @@ bool RISCVPassConfig::addPreISel() { } if (TM->getTargetFeatureString().contains("vortex")) { - if (VortexBranchDivergenceMode != 0) { + if (gVortexBranchDivergenceMode != 0) { addPass(createLowerSwitchPass()); addPass(createCFGSimplificationPass()); addPass(createFlattenCFGPass()); @@ -487,8 +487,8 @@ bool RISCVPassConfig::addPreISel() { addPass(createFixIrreduciblePass()); addPass(createSinkingPass()); addPass(createVortexBranchDivergence0Pass()); - addPass(createStructurizeCFGPass(true, (VortexBranchDivergenceMode == 1))); - addPass(createVortexBranchDivergence1Pass(VortexBranchDivergenceMode)); + addPass(createStructurizeCFGPass(true, (gVortexBranchDivergenceMode == 1))); + addPass(createVortexBranchDivergence1Pass(gVortexBranchDivergenceMode)); } if (VortexKernelSchedulerMode != 0) { addPass(createVortexIntrinsicFuncLoweringPass()); @@ -582,7 +582,7 @@ void RISCVPassConfig::addPreEmitPass2() { })); if (TM->getTargetFeatureString().contains("vortex") - && VortexBranchDivergenceMode != 0) { + && gVortexBranchDivergenceMode != 0) { addPass(createVortexBranchDivergence2Pass(1)); } } @@ -609,7 +609,7 @@ void RISCVPassConfig::addPreRegAlloc() { addPass(createRISCVInsertWriteVXRMPass()); if (TM->getTargetFeatureString().contains("vortex") - && VortexBranchDivergenceMode != 0) { + && gVortexBranchDivergenceMode != 0) { addPass(createVortexBranchDivergence2Pass(0)); } @@ -640,12 +640,14 @@ void RISCVTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) { OptimizationLevel Level) { LPM.addPass(LoopIdiomVectorizePass(LoopIdiomVectorizeStyle::Predicated)); }); - PB.registerPipelineStartEPCallback( - [this](ModulePassManager &PM, OptimizationLevel Level) { - FunctionPassManager FPM; - FPM.addPass(vortex::UniformAnnotationPass()); - PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM))); - }); + if (gVortexBranchDivergenceMode != 0) { + PB.registerPipelineStartEPCallback( + [this](ModulePassManager &PM, OptimizationLevel Level) { + FunctionPassManager FPM; + FPM.addPass(vortex::UniformAnnotationPass()); + PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM))); + }); + } } yaml::MachineFunctionInfo * diff --git a/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp b/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp index 1bd51712acabc4157594a1c632b0174a10a1dd5d..998dc96147c7a190df20959eab20aae9c4951f60 100644 --- a/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp +++ b/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp @@ -804,6 +804,7 @@ void VortexBranchDivergence1::processLoops(LLVMContext* context, Function* funct for (auto it = loops_.rbegin(), ite = loops_.rend(); it != ite; ++it) { auto loop = *it; auto header = loop->getHeader(); + assert(header); auto preheader = loop->getLoopPreheader(); assert(preheader); @@ -1063,53 +1064,68 @@ char VortexBranchDivergence2::ID = 0; PreservedAnalyses UniformAnnotationPass::run(Function &F, FunctionAnalysisManager &AM) { bool changed = false; + + std::vector<std::pair<Instruction*, bool>> uniformInsts; + for (auto& BB : F) { for (auto& I : BB) { - // handle uniform metadata + // find uniform metadata if (I.getMetadata("vortex.uniform") != nullptr) { - IRBuilder<> Builder(I.getNextNode()); - auto ValueType = I.getType(); - auto IntrinsicFunc = Intrinsic::getDeclaration(F.getParent(), Intrinsic::riscv_vx_uniform, {ValueType, ValueType}); - auto CallInst = Builder.CreateCall(IntrinsicFunc, {&I}, ".uniform"); - I.replaceAllUsesWith(CallInst); - changed = true; - break; + uniformInsts.push_back({&I, false}); + continue; } - // handle uniform annotations + // find uniform annotations if (auto II = dyn_cast<IntrinsicInst>(&I)) { if (II->getIntrinsicID() == Intrinsic::var_annotation) { auto gv = dyn_cast<GlobalVariable>(II->getOperand(1)); auto cda = dyn_cast<ConstantDataArray>(gv->getInitializer()); if (cda->getAsCString() == "vortex.uniform") { auto AnnotatedValue = dyn_cast<AllocaInst>(II->getOperand(0)); - if (AnnotatedValue == nullptr) - continue; - std::vector<LoadInst*> loadsToReplace; - StoreInst* STore = nullptr; - for (auto User : AnnotatedValue->users()) { - if (auto LI = dyn_cast<LoadInst>(User)) - loadsToReplace.push_back(LI); - if (auto SI = dyn_cast<StoreInst>(User)) - STore = SI; - } - if (STore != nullptr) { - IRBuilder<> Builder(STore->getNextNode()); - auto LoadedValue = Builder.CreateLoad(AnnotatedValue->getAllocatedType(), AnnotatedValue); - auto ValueType = LoadedValue->getType(); - auto IntrinsicFunc = Intrinsic::getDeclaration(F.getParent(), Intrinsic::riscv_vx_uniform, {ValueType, ValueType}); - auto CallInst = Builder.CreateCall(IntrinsicFunc, {LoadedValue}, ".uniform"); - for (auto LI : loadsToReplace) { - LI->replaceAllUsesWith(CallInst); - LI->eraseFromParent(); - } - changed = true; - break; + if (AnnotatedValue) { + uniformInsts.push_back({AnnotatedValue, true}); } + continue; } } } } } + + for (auto Instr : uniformInsts) { + if (Instr.second) { + auto AnnotatedValue = reinterpret_cast<AllocaInst*>(Instr.first); + std::vector<LoadInst*> loadsToReplace; + StoreInst* Store = nullptr; + for (auto User : AnnotatedValue->users()) { + if (auto LI = dyn_cast<LoadInst>(User)) + loadsToReplace.push_back(LI); + if (auto SI = dyn_cast<StoreInst>(User)) + Store = SI; + } + if (Store != nullptr) { + IRBuilder<> Builder(Store->getNextNode()); + auto LoadedValue = Builder.CreateLoad(AnnotatedValue->getAllocatedType(), AnnotatedValue, AnnotatedValue->getName() + ".loaded"); + auto ValueType = LoadedValue->getType(); + auto IntrinsicFunc = Intrinsic::getDeclaration(F.getParent(), Intrinsic::riscv_vx_uniform, {ValueType, ValueType}); + auto CallInst = Builder.CreateCall(IntrinsicFunc, {LoadedValue}, LoadedValue->getName() + ".uniform"); + for (auto LI : loadsToReplace) { + LI->replaceAllUsesWith(CallInst); + LI->eraseFromParent(); + } + changed = true; + } + } else { + auto I = Instr.first; + IRBuilder<> Builder(I->getNextNode()); + auto ValueType = I->getType(); + auto IntrinsicFunc = Intrinsic::getDeclaration(F.getParent(), Intrinsic::riscv_vx_uniform, {ValueType, ValueType}); + auto CallInst = Builder.CreateCall(IntrinsicFunc, {llvm::UndefValue::get(ValueType)}, I->getName() + ".uniform"); + I->replaceAllUsesWith(CallInst); + CallInst->setArgOperand(0, I); + changed = true; + } + } + return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); } @@ -1180,10 +1196,6 @@ bool DivergenceTracker::isSourceOfDivergence(const Value *V) { return true; } - // We are certain about intrinsic calls - if (isa<IntrinsicInst>(V)) - return false; - // We conservatively assume function return values are divergent if (isa<CallInst>(V)) { LLVM_DEBUG(dbgs() << "*** divergent return variable: " << V->getName() << "\n");