From 7fe952c13c5571019afcfa58850b8fdb8568657a Mon Sep 17 00:00:00 2001
From: Blaise Tine <tinebp@yahoo.com>
Date: Mon, 10 Jun 2024 03:24:33 -0700
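Subject: [PATCH] Vortex: gate passes on gVortexBranchDivergenceMode; defer uniform rewrites

- Copy VortexBranchDivergenceMode into gVortexBranchDivergenceMode at target
  initialization and test the global when building the pass pipeline.
- Require a structured CFG, and register UniformAnnotationPass, only when
  branch divergence handling is enabled.
- Report both ABI names when -target-abi disagrees with the module flag.
- UniformAnnotationPass: collect the annotated instructions first and rewrite
  them afterwards, instead of stopping at the first match in each block.
- DivergenceTracker: no longer treat every intrinsic call as non-divergent.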

---
 llvm/lib/Target/RISCV/RISCVTargetMachine.cpp  | 36 ++++----
 .../Target/RISCV/VortexBranchDivergence.cpp   | 82 +++++++++++--------
 2 files changed, 66 insertions(+), 52 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
index 6b1e76192e72..a45c64500d5b 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -128,13 +128,13 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() {
   initializeRISCVPreLegalizerCombinerPass(*PR);
   initializeRISCVPostLegalizerCombinerPass(*PR);
   initializeKCFIPass(*PR);
-  if (VortexBranchDivergenceMode != 0) {
-    gVortexBranchDivergenceMode = 1;
+  gVortexBranchDivergenceMode = VortexBranchDivergenceMode;
+  if (gVortexBranchDivergenceMode != 0) {
     initializeVortexBranchDivergence0Pass(*PR);
     initializeVortexBranchDivergence1Pass(*PR);
     initializeVortexBranchDivergence2Pass(*PR);
   }
-  if (VortexKernelSchedulerMode != 0) {
+  if (VortexKernelSchedulerMode != 0) { // same guard addPreISel() uses to create this pass
     initializeVortexIntrinsicFuncLoweringPass(*PR);
   }
   initializeRISCVDeadRegisterDefinitionsPass(*PR);
@@ -193,8 +193,8 @@ RISCVTargetMachine::RISCVTargetMachine(const Target &T, const Triple &TT,
   setMachineOutliner(true);
   setSupportsDefaultOutlining(true);
 
-  auto isVortex = FS.contains("vortex");
-  if (isVortex) {
+  if (FS.contains("vortex")
+   && gVortexBranchDivergenceMode != 0) {
    setRequiresStructuredCFG(true);
   }
 
@@ -269,7 +269,7 @@ RISCVTargetMachine::getSubtargetImpl(const Function &F) const {
       auto TargetABI = RISCVABI::getTargetABI(ABIName);
       if (TargetABI != RISCVABI::ABI_Unknown &&
           ModuleTargetABI->getString() != ABIName) {
-        report_fatal_error("-target-abi option != target-abi module flag");
+        report_fatal_error("-target-abi option != target-abi module flag: " + ModuleTargetABI->getString() + " vs " + ABIName);
       }
       ABIName = ModuleTargetABI->getString();
     }
@@ -478,7 +478,7 @@ bool RISCVPassConfig::addPreISel() {
   }
 
   if (TM->getTargetFeatureString().contains("vortex")) {
-    if (VortexBranchDivergenceMode != 0) {
+    if (gVortexBranchDivergenceMode != 0) {
       addPass(createLowerSwitchPass());
       addPass(createCFGSimplificationPass());
       addPass(createFlattenCFGPass());
@@ -487,8 +487,8 @@ bool RISCVPassConfig::addPreISel() {
       addPass(createFixIrreduciblePass());
       addPass(createSinkingPass());
       addPass(createVortexBranchDivergence0Pass());
-      addPass(createStructurizeCFGPass(true, (VortexBranchDivergenceMode == 1)));
-      addPass(createVortexBranchDivergence1Pass(VortexBranchDivergenceMode));
+      addPass(createStructurizeCFGPass(true, (gVortexBranchDivergenceMode == 1)));
+      addPass(createVortexBranchDivergence1Pass(gVortexBranchDivergenceMode));
     }
     if (VortexKernelSchedulerMode != 0) {
       addPass(createVortexIntrinsicFuncLoweringPass());
@@ -582,7 +582,7 @@ void RISCVPassConfig::addPreEmitPass2() {
   }));
 
   if (TM->getTargetFeatureString().contains("vortex")
-   && VortexBranchDivergenceMode != 0) {
+   && gVortexBranchDivergenceMode != 0) {
     addPass(createVortexBranchDivergence2Pass(1));
   }
 }
@@ -609,7 +609,7 @@ void RISCVPassConfig::addPreRegAlloc() {
   addPass(createRISCVInsertWriteVXRMPass());
 
   if (TM->getTargetFeatureString().contains("vortex")
-   && VortexBranchDivergenceMode != 0) {
+   && gVortexBranchDivergenceMode != 0) {
     addPass(createVortexBranchDivergence2Pass(0));
   }
 
@@ -640,12 +640,14 @@ void RISCVTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
                                                  OptimizationLevel Level) {
     LPM.addPass(LoopIdiomVectorizePass(LoopIdiomVectorizeStyle::Predicated));
   });
-  PB.registerPipelineStartEPCallback(
-    [this](ModulePassManager &PM, OptimizationLevel Level) {
-      FunctionPassManager FPM;
-      FPM.addPass(vortex::UniformAnnotationPass());
-      PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
-    });
+  if (gVortexBranchDivergenceMode != 0) {
+    PB.registerPipelineStartEPCallback(
+      [this](ModulePassManager &PM, OptimizationLevel Level) {
+        FunctionPassManager FPM;
+        FPM.addPass(vortex::UniformAnnotationPass());
+        PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
+      });
+  }
 }
 
 yaml::MachineFunctionInfo *
diff --git a/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp b/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp
index 1bd51712acab..998dc96147c7 100644
--- a/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp
+++ b/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp
@@ -804,6 +804,7 @@ void VortexBranchDivergence1::processLoops(LLVMContext* context, Function* funct
   for (auto it = loops_.rbegin(), ite = loops_.rend(); it != ite; ++it) {
     auto loop = *it;
     auto header = loop->getHeader();
+    assert(header);
 
     auto preheader = loop->getLoopPreheader();
     assert(preheader);
@@ -1063,53 +1064,68 @@ char VortexBranchDivergence2::ID = 0;
 
 PreservedAnalyses UniformAnnotationPass::run(Function &F, FunctionAnalysisManager &AM) {
   bool changed = false;
+  // Phase 1 only records candidates; all IR insertion/erasure happens in phase 2, after the walk over F is finished.
+  std::vector<std::pair<Instruction*, bool>> uniformInsts; // second == true: annotated alloca; false: instruction carrying !vortex.uniform
+
   for (auto& BB : F) {
     for (auto& I : BB) {
-      // handle uniform metadata
+      // record instructions tagged with "vortex.uniform" metadata
       if (I.getMetadata("vortex.uniform") != nullptr) {
-        IRBuilder<> Builder(I.getNextNode());
-        auto ValueType = I.getType();
-        auto IntrinsicFunc = Intrinsic::getDeclaration(F.getParent(), Intrinsic::riscv_vx_uniform, {ValueType, ValueType});
-        auto CallInst = Builder.CreateCall(IntrinsicFunc, {&I}, ".uniform");
-        I.replaceAllUsesWith(CallInst);
-        changed = true;
-        break;
+        uniformInsts.push_back({&I, false});
+        continue;
       }
-      // handle uniform annotations
+      // record allocas named by a llvm.var.annotation call whose string is "vortex.uniform"
       if (auto II = dyn_cast<IntrinsicInst>(&I)) {
         if (II->getIntrinsicID() == Intrinsic::var_annotation) {
           auto gv  = dyn_cast<GlobalVariable>(II->getOperand(1));
-          auto cda = dyn_cast<ConstantDataArray>(gv->getInitializer());
-          if (cda->getAsCString() == "vortex.uniform") {
+          auto cda = (gv && gv->hasInitializer()) ? dyn_cast<ConstantDataArray>(gv->getInitializer()) : nullptr; // either dyn_cast may fail; never dereference a null result
+          if (cda && cda->isCString() && cda->getAsCString() == "vortex.uniform") {
             auto AnnotatedValue = dyn_cast<AllocaInst>(II->getOperand(0));
-            if (AnnotatedValue == nullptr)
-              continue;
-            std::vector<LoadInst*> loadsToReplace;
-            StoreInst* STore = nullptr;
-            for (auto User : AnnotatedValue->users()) {
-              if (auto LI = dyn_cast<LoadInst>(User))
-                loadsToReplace.push_back(LI);
-              if (auto SI = dyn_cast<StoreInst>(User))
-                STore = SI;
-            }
-            if (STore != nullptr) {
-              IRBuilder<> Builder(STore->getNextNode());
-              auto LoadedValue = Builder.CreateLoad(AnnotatedValue->getAllocatedType(), AnnotatedValue);
-              auto ValueType = LoadedValue->getType();
-              auto IntrinsicFunc = Intrinsic::getDeclaration(F.getParent(), Intrinsic::riscv_vx_uniform, {ValueType, ValueType});
-              auto CallInst = Builder.CreateCall(IntrinsicFunc, {LoadedValue}, ".uniform");
-              for (auto LI : loadsToReplace) {
-                LI->replaceAllUsesWith(CallInst);
-                LI->eraseFromParent();
-              }
-              changed = true;
-              break;
+            if (AnnotatedValue) {
+              uniformInsts.push_back({AnnotatedValue, true});
             }
+            continue;
           }
         }
       }
     }
   }
+  // Phase 2: route each recorded value through the riscv_vx_uniform intrinsic.
+  for (const auto& Instr : uniformInsts) {
+    if (Instr.second) {
+      auto AnnotatedValue = cast<AllocaInst>(Instr.first); // checked downcast: phase 1 only pairs 'true' with an AllocaInst
+      std::vector<LoadInst*> loadsToReplace;
+      StoreInst* Store = nullptr;
+      for (auto User : AnnotatedValue->users()) {
+        if (auto LI = dyn_cast<LoadInst>(User))
+          loadsToReplace.push_back(LI);
+        if (auto SI = dyn_cast<StoreInst>(User))
+          Store = SI; // keeps whichever store is visited last in the use list
+      }
+      if (Store != nullptr) {
+        IRBuilder<> Builder(Store->getNextNode()); // reload right after the chosen store
+        auto LoadedValue = Builder.CreateLoad(AnnotatedValue->getAllocatedType(), AnnotatedValue, AnnotatedValue->getName() + ".loaded");
+        auto ValueType = LoadedValue->getType();
+        auto IntrinsicFunc = Intrinsic::getDeclaration(F.getParent(), Intrinsic::riscv_vx_uniform, {ValueType, ValueType});
+        auto CallInst = Builder.CreateCall(IntrinsicFunc, {LoadedValue}, LoadedValue->getName() + ".uniform");
+        for (auto LI : loadsToReplace) { // NOTE(review): every load is redirected to the one call above - confirm that call dominates all of them
+          LI->replaceAllUsesWith(CallInst);
+          LI->eraseFromParent(); // NOTE(review): a load also queued via !vortex.uniform would be left dangling in uniformInsts
+        }
+        changed = true;
+      }
+    } else {
+      auto I = Instr.first;
+      IRBuilder<> Builder(I->getNextNode());
+      auto ValueType = I->getType();
+      auto IntrinsicFunc = Intrinsic::getDeclaration(F.getParent(), Intrinsic::riscv_vx_uniform, {ValueType, ValueType});
+      auto CallInst = Builder.CreateCall(IntrinsicFunc, {llvm::UndefValue::get(ValueType)}, I->getName() + ".uniform"); // placeholder operand for now
+      I->replaceAllUsesWith(CallInst); // the call does not use I yet, so its own argument is not rewritten here
+      CallInst->setArgOperand(0, I);
+      changed = true;
+    }
+  }
+
   return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
 }
 
@@ -1180,10 +1196,6 @@ bool DivergenceTracker::isSourceOfDivergence(const Value *V) {
     return true;
   }
 
-  // We are certain about intrinsic calls
-  if (isa<IntrinsicInst>(V))
-    return false;
-
   // We conservatively assume function return values are divergent
   if (isa<CallInst>(V)) {
     LLVM_DEBUG(dbgs() << "*** divergent return variable: " << V->getName() << "\n");
-- 
GitLab