diff --git a/llvm/include/llvm/IR/IntrinsicsRISCV.td b/llvm/include/llvm/IR/IntrinsicsRISCV.td
index eb4381b301b22ef2a4182b6f765e13a881aa72e7..dd140ff56171a1381a1d1fe59384d9799a621956 100644
--- a/llvm/include/llvm/IR/IntrinsicsRISCV.td
+++ b/llvm/include/llvm/IR/IntrinsicsRISCV.td
@@ -74,6 +74,8 @@ let TargetPrefix = "riscv" in {
 
   // Vortex extension
 
+  def int_riscv_vx_uniform : Intrinsic<[llvm_any_ty], [llvm_any_ty], [IntrNoMem]>;
+
   class VortexCSRIntrinsicsImpl<LLVMType itype>
       : Intrinsic<[itype], [], [IntrNoMem, IntrWillReturn]>;
 
diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
index 7491641fa78aafaf6fce0430e26e149575f2dc50..6b1e76192e72859d74015a13fff6c038cf16f861 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -40,6 +40,7 @@
 #include "llvm/Transforms/IPO.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
+#include "VortexBranchDivergence.h"
 #include <optional>
 using namespace llvm;
 
@@ -446,6 +447,8 @@ bool RISCVPassConfig::addRegAssignAndRewriteOptimized() {
 }
 
 void RISCVPassConfig::addIRPasses() {
+  //insertPass(Annotation2MetadataPass::ID(), &VortexBranchDivergence0ID);
+
   addPass(createAtomicExpandLegacyPass());
 
   if (getOptLevel() != CodeGenOptLevel::None) {
@@ -476,13 +479,13 @@ bool RISCVPassConfig::addPreISel() {
 
   if (TM->getTargetFeatureString().contains("vortex")) {
     if (VortexBranchDivergenceMode != 0) {
+      addPass(createLowerSwitchPass());
       addPass(createCFGSimplificationPass());
+      addPass(createFlattenCFGPass());
       addPass(createLoopSimplifyPass());
-      addPass(createFixIrreduciblePass());
       addPass(createUnifyLoopExitsPass());
+      addPass(createFixIrreduciblePass());
       addPass(createSinkingPass());
-      addPass(createLowerSwitchPass());
-      addPass(createFlattenCFGPass());
       addPass(createVortexBranchDivergence0Pass());
       addPass(createStructurizeCFGPass(true, (VortexBranchDivergenceMode == 1)));
       addPass(createVortexBranchDivergence1Pass(VortexBranchDivergenceMode));
@@ -637,6 +640,12 @@ void RISCVTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
                                                  OptimizationLevel Level) {
     LPM.addPass(LoopIdiomVectorizePass(LoopIdiomVectorizeStyle::Predicated));
   });
+  PB.registerPipelineStartEPCallback(
+    [this](ModulePassManager &PM, OptimizationLevel Level) {
+      FunctionPassManager FPM;
+      FPM.addPass(vortex::UniformAnnotationPass());
+      PM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
+    });
 }
 
 yaml::MachineFunctionInfo *
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index bd78f3e3dbb053a127d9447c6ae1c79c5bc2e0a9..912f5eaafade03d8e5ecddfc11d937940748e375 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -1974,5 +1974,9 @@ bool RISCVTTIImpl::hasBranchDivergence(const Function *F) {
 }
 
 bool RISCVTTIImpl::isSourceOfDivergence(const Value *V) {
-  return divergence_tracker_.eval(V);
+  return divergence_tracker_.isSourceOfDivergence(V);
+}
+
+bool RISCVTTIImpl::isAlwaysUniform(const Value *V) {
+  return divergence_tracker_.isAlwaysUniform(V);
 }
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 9fd0df19c1f92491eae8a61210091a95977a2857..fc51df136a509786819cbbb869253eb61aa7e078 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -103,6 +103,7 @@ public:
 
   // Vortex extension
   bool isSourceOfDivergence(const Value *V);
+  bool isAlwaysUniform(const Value *V);
   bool hasBranchDivergence(const Function *F);
 
   bool shouldExpandReduction(const IntrinsicInst *II) const;
diff --git a/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp b/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp
index 7763dae3c9aa7abb7a1592b8dbb0bed28b21a8b9..1bd51712acabc4157594a1c632b0174a10a1dd5d 100644
--- a/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp
+++ b/llvm/lib/Target/RISCV/VortexBranchDivergence.cpp
@@ -415,9 +415,6 @@ bool VortexBranchDivergence0::runOnFunction(Function &F) {
   const auto &TM = TPC.getTM<TargetMachine>();
   const auto &ST = TM.getSubtarget<RISCVSubtarget>(F);
 
-  // Check if the Vortex extension is enabled
-  assert(ST.hasVendorXVortex());
-
   LLVM_DEBUG(dbgs() << "*** Vortex Divergent Branch Handling Pass0 ***\n");
 
   LLVM_DEBUG(dbgs() << "*** before Pass0 changes!\n" << F << "\n");
@@ -426,7 +423,9 @@ bool VortexBranchDivergence0::runOnFunction(Function &F) {
 
   bool changed = false;
 
-  {
+  bool hasStdExtZicond = ST.hasStdExtZicond();
+
+  if (!hasStdExtZicond) {
     // Lower Select instructions into standard if-then-else branches
     SmallVector <SelectInst*, 4> selects;
 
@@ -453,7 +452,7 @@ bool VortexBranchDivergence0::runOnFunction(Function &F) {
     }
   }
 
-  {
+  if (!hasStdExtZicond) {
     // Lower Min/Max intrinsics into standard if-then-else branches
     SmallVector <MinMaxIntrinsic*, 4> MMs;
 
@@ -686,13 +685,12 @@ bool VortexBranchDivergence1::runOnFunction(Function &F) {
   const auto &TM = TPC.getTM<TargetMachine>();
   const auto &ST = TM.getSubtarget<RISCVSubtarget>(F);
 
-  // Check if the Vortex extension is enabled
-  assert(ST.hasVendorXVortex());
-
   this->initialize(F, ST);
 
   auto &Context = F.getContext();
 
+  LLVM_DEBUG(UA_->print(dbgs()));
+
   LLVM_DEBUG(dbgs() << "*** Region info:\n");
   LLVM_DEBUG(RI_->getTopLevelRegion()->dump());
   LLVM_DEBUG(dbgs() << "\n");
@@ -784,6 +782,18 @@ bool VortexBranchDivergence1::runOnFunction(Function &F) {
     LLVM_DEBUG(dbgs() << "*** after changes!\n" << F << "\n");
   }
 
+  // remove uniform intrinsics
+  for (auto iter = inst_begin(F), iterE = inst_end(F); iter != iterE;) {
+    auto& I = *iter++;
+    if (auto II = dyn_cast<IntrinsicInst>(&I)) {
+      if (II->getIntrinsicID() == Intrinsic::riscv_vx_uniform) {
+        auto src = II->getOperand(0);
+        II->replaceAllUsesWith(src);
+        II->eraseFromParent();
+      }
+    }
+  }
+
   return changed;
 }
 
@@ -924,176 +934,6 @@ bool VortexBranchDivergence1::isUniform(Instruction *I) {
 
 ///////////////////////////////////////////////////////////////////////////////
 
-DivergenceTracker::DivergenceTracker(const Function &function)
-  : function_(&function)
-  , initialized_(false)
-{}
-
-void DivergenceTracker::initialize() {
-  LLVM_DEBUG(dbgs() << "*** DivergenceTracker::initialize(): " << function_->getName() << "\n");
-
-  initialized_ = true;
-
-  DenseSet<const Value *> dv_annotations;
-  DenseSet<const Value *> uv_annotations;
-
-  auto module = function_->getParent();
-
-  // Mark all TLS globals as divergent
-  for (auto& GV : module->globals()) {
-    if (GV.isThreadLocal()) {
-      if (dv_nodes_.insert(&GV).second) {
-        LLVM_DEBUG(dbgs() << "*** divergent global variable: " << GV.getName() << "\n");
-      }
-    }
-  }
-
-  for (auto& BB : *function_) {
-    for (auto& I : BB) {
-      //LLVM_DEBUG(dbgs() << "*** instruction: opcode=" << I.getOpcodeName() << ", name=" << I.getName() << "\n");
-      if (I.getMetadata("vortex.uniform") != NULL) {
-        uv_annotations.insert(&I);
-        uv_nodes_.insert(&I);
-        LLVM_DEBUG(dbgs() << "*** uniform annotation: " << I.getName() << "\n");
-      } else
-      if (I.getMetadata("vortex.divergent") != NULL) {
-        dv_annotations.insert(&I);
-        dv_nodes_.insert(&I);
-        LLVM_DEBUG(dbgs() << "*** divergent annotation: " << I.getName() << "\n");
-      } else
-      if (auto II = dyn_cast<llvm::IntrinsicInst>(&I)) {
-        if (II->getIntrinsicID() == llvm::Intrinsic::var_annotation) {
-          auto gv  = dyn_cast<GlobalVariable>(II->getOperand(1));
-          auto cda = dyn_cast<ConstantDataArray>(gv->getInitializer());
-          if (cda->getAsCString() == "vortex.uniform") {
-            Value* var_src = nullptr;
-            auto var = II->getOperand(0);
-            if (auto AI = dyn_cast<AllocaInst>(var)) {
-              var_src = AI;
-              LLVM_DEBUG(dbgs() << "*** uniform annotation: " << AI->getName() << ".src(" << var_src << ")\n");
-            } else
-            if (auto CI = dyn_cast<CastInst>(var)) {
-              var_src = CI->getOperand(0);
-              LLVM_DEBUG(dbgs() << "*** uniform annotation: " << CI->getName() << ".src(" << var_src << ")\n");
-            }
-            uv_annotations.insert(var_src);
-            uv_nodes_.insert(var_src);
-          } else
-          if (cda->getAsCString() == "vortex.divergent") {
-            Value* var_src = nullptr;
-            auto var = II->getOperand(0);
-            if (auto AI = dyn_cast<AllocaInst>(var)) {
-              var_src = AI;
-              LLVM_DEBUG(dbgs() << "*** divergent annotation: " << AI->getName() << ".src(" << var_src << "\n");
-            } else
-            if (auto CI = dyn_cast<CastInst>(var)) {
-              var_src = CI->getOperand(0);
-              LLVM_DEBUG(dbgs() << "*** divergent annotation: " << CI->getName() << ".src(" << var_src << "\n");
-            }
-            dv_annotations.insert(var_src);
-            dv_nodes_.insert(var_src);
-          }
-        }
-      }
-    }
-  }
-
-  // Mark the value of divergent stores as divergent
-  for (auto& BB : *function_) {
-    for (auto& I : BB) {
-      if (auto GE = dyn_cast<GetElementPtrInst>(&I)) {
-        auto addr = GE->getPointerOperand();
-        if (uv_annotations.count(addr) != 0) {
-          LLVM_DEBUG(dbgs() << "*** uniform annotation: " << GE->getName() << "\n");
-          uv_nodes_.insert(GE);
-        } else
-        if (dv_annotations.count(addr) != 0) {
-          LLVM_DEBUG(dbgs() << "*** divergent annotation: " << GE->getName() << "\n");
-          dv_nodes_.insert(GE);
-        }
-      } else
-      if (auto SI = dyn_cast<StoreInst>(&I)) {
-        auto addr = SI->getPointerOperand();
-        if (uv_annotations.count(addr) != 0) {
-          auto value = SI->getValueOperand();
-          if (auto CI = dyn_cast<CastInst>(value)) {
-            LLVM_DEBUG(dbgs() << "*** uniform annotation: " << CI->getName() << ".src\n");
-            auto src = CI->getOperand(0);
-            uv_nodes_.insert(src);
-          } else {
-            LLVM_DEBUG(dbgs() << "*** uniform annotation: " << SI->getName() << ".value\n");
-            uv_nodes_.insert(value);
-          }
-        } else
-        if (dv_annotations.count(addr) != 0) {
-          auto value = SI->getValueOperand();
-          if (auto CI = dyn_cast<CastInst>(value)) {
-            LLVM_DEBUG(dbgs() << "*** divergent annotation: " << CI->getName() << ".src\n");
-            auto src = CI->getOperand(0);
-            dv_nodes_.insert(src);
-          } else {
-            LLVM_DEBUG(dbgs() << "*** divergent annotation: " << SI->getName() << ".value\n");
-            dv_nodes_.insert(value);
-          }
-        }
-      }
-    }
-  }
-}
-
-bool DivergenceTracker::eval(const Value *V) {
-  if (!initialized_) {
-    this->initialize();
-  }
-
-  // Mark annotated uniform variables
-  if (uv_nodes_.count(V) != 0) {
-    LLVM_DEBUG(dbgs() << "*** uniform annotated variable: " << V->getName() << "\n");
-    return false;
-  }
-
-  // Mark annotated divergent variables
-  if (dv_nodes_.count(V) != 0) {
-    LLVM_DEBUG(dbgs() << "*** divergent annotated variable: " << V->getName() << "\n");
-    return true;
-  }
-
-  // We conservatively assume all function arguments to potentially be divergent
-  if (isa<Argument>(V)) {
-    LLVM_DEBUG(dbgs() << "*** divergent function argument: " << V->getName() << "\n");
-    return true;
-  }
-
-  // We conservatively assume function return values are divergent
-  if (isa<CallInst>(V)) {
-    LLVM_DEBUG(dbgs() << "*** divergent return variable: " << V->getName() << "\n");
-    return true;
-  }
-
-  // Atomics are divergent because they are executed sequentially: when an
-  // atomic operation refers to the same address in each thread, then each
-  // thread after the first sees the value written by the previous thread as
-  // original value.
-  if (isa<AtomicRMWInst>(V)
-   || isa<AtomicCmpXchgInst>(V)) {
-    LLVM_DEBUG(dbgs() << "*** divergent atomic variable: " << V->getName() << "\n");
-    return true;
-  }
-
-  // Mark loads from divergent addresses as divergent
-  if (auto LD = dyn_cast<LoadInst>(V)) {
-    auto addr = LD->getPointerOperand();
-    if (dv_nodes_.count(addr) != 0) {
-      LLVM_DEBUG(dbgs() << "*** divergent load variable: " << V->getName() << "\n");
-      return true;
-    }
-  }
-
-  return false;
-}
-
-///////////////////////////////////////////////////////////////////////////////
-
 VortexBranchDivergence2::VortexBranchDivergence2(int PassMode)
   : MachineFunctionPass(ID)
   , PassMode_(PassMode)
@@ -1120,9 +960,6 @@ bool VortexBranchDivergence2::runOnMachineFunction(MachineFunction &MF) {
   auto TII = ST.getInstrInfo();
   auto& MRI = MF.getRegInfo();
 
-  // Check if the Vortex extension is enabled
-  assert(ST.hasVendorXVortex());
-
   bool Changed = false;
 
   switch (PassMode_) {
@@ -1222,4 +1059,160 @@ StringRef VortexBranchDivergence2::getPassName() const {
 
 char VortexBranchDivergence2::ID = 0;
 
+///////////////////////////////////////////////////////////////////////////////
+
+PreservedAnalyses UniformAnnotationPass::run(Function &F, FunctionAnalysisManager &AM) {
+  bool changed = false;
+  for (auto& BB : F) {
+    for (auto& I : BB) {
+      // handle uniform metadata
+      if (I.getMetadata("vortex.uniform") != nullptr) {
+        IRBuilder<> Builder(I.getNextNode());
+        auto ValueType = I.getType();
+        auto IntrinsicFunc = Intrinsic::getDeclaration(F.getParent(), Intrinsic::riscv_vx_uniform, {ValueType, ValueType});
+        auto CallInst = Builder.CreateCall(IntrinsicFunc, {&I}, ".uniform");
+        I.replaceAllUsesWith(CallInst);
+        changed = true;
+        break;
+      }
+      // handle uniform annotations
+      if (auto II = dyn_cast<IntrinsicInst>(&I)) {
+        if (II->getIntrinsicID() == Intrinsic::var_annotation) {
+          auto gv  = dyn_cast<GlobalVariable>(II->getOperand(1));
+          auto cda = dyn_cast<ConstantDataArray>(gv->getInitializer());
+          if (cda->getAsCString() == "vortex.uniform") {
+            auto AnnotatedValue = dyn_cast<AllocaInst>(II->getOperand(0));
+            if (AnnotatedValue == nullptr)
+              continue;
+            std::vector<LoadInst*> loadsToReplace;
+            StoreInst* STore = nullptr;
+            for (auto User : AnnotatedValue->users()) {
+              if (auto LI = dyn_cast<LoadInst>(User))
+                loadsToReplace.push_back(LI);
+              if (auto SI = dyn_cast<StoreInst>(User))
+                STore = SI;
+            }
+            if (STore != nullptr) {
+              IRBuilder<> Builder(STore->getNextNode());
+              auto LoadedValue = Builder.CreateLoad(AnnotatedValue->getAllocatedType(), AnnotatedValue);
+              auto ValueType = LoadedValue->getType();
+              auto IntrinsicFunc = Intrinsic::getDeclaration(F.getParent(), Intrinsic::riscv_vx_uniform, {ValueType, ValueType});
+              auto CallInst = Builder.CreateCall(IntrinsicFunc, {LoadedValue}, ".uniform");
+              for (auto LI : loadsToReplace) {
+                LI->replaceAllUsesWith(CallInst);
+                LI->eraseFromParent();
+              }
+              changed = true;
+              break;
+            }
+          }
+        }
+      }
+    }
+  }
+  return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
+
+///////////////////////////////////////////////////////////////////////////////
+
+DivergenceTracker::DivergenceTracker(const Function &function)
+  : function_(&function)
+  , initialized_(false)
+{}
+
+void DivergenceTracker::initialize() {
+  LLVM_DEBUG(dbgs() << "*** DivergenceTracker::initialize(): " << function_->getName() << "\n");
+
+  initialized_ = true;
+
+  auto module = function_->getParent();
+
+  // Mark all TLS globals as divergent
+  for (auto& GV : module->globals()) {
+    if (GV.isThreadLocal()) {
+      if (dv_nodes_.insert(&GV).second) {
+        LLVM_DEBUG(dbgs() << "*** divergent global variable: " << GV.getName() << "\n");
+      }
+    }
+  }
+
+  // Mark uniform intrinsic calls as uniform
+  for (auto& BB : *function_) {
+    for (auto& I : BB) {
+      if (auto II = dyn_cast<IntrinsicInst>(&I)) {
+        if (II->getIntrinsicID() == Intrinsic::riscv_vx_uniform) {
+          LLVM_DEBUG(dbgs() << "*** uniform intrinsic variable: " << I.getName() << "\n");
+          uv_nodes_.insert(&I);
+        }
+      }
+    }
+  }
+}
+
+bool DivergenceTracker::isSourceOfDivergence(const Value *V) {
+  if (!initialized_) {
+    this->initialize();
+  }
+
+  // check if node always uniform
+  if (this->isAlwaysUniform(V))
+    return false;
+
+  // Mark annotated divergent variables
+  if (dv_nodes_.count(V) != 0) {
+    LLVM_DEBUG(dbgs() << "*** divergent annotated variable: " << V->getName() << "\n");
+    return true;
+  }
+
+  // Atomics are divergent because they are executed sequentially: when an
+  // atomic operation refers to the same address in each thread, then each
+  // thread after the first sees the value written by the previous thread as
+  // original value.
+  if (isa<AtomicRMWInst>(V)
+   || isa<AtomicCmpXchgInst>(V)) {
+    LLVM_DEBUG(dbgs() << "*** divergent atomic variable: " << V->getName() << "\n");
+    return true;
+  }
+
+  // We conservatively assume all function arguments to potentially be divergent
+  if (isa<Argument>(V)) {
+    LLVM_DEBUG(dbgs() << "*** divergent function argument: " << V->getName() << "\n");
+    return true;
+  }
+
+  // We are certain about intrinsic calls
+  if (isa<IntrinsicInst>(V))
+    return false;
+
+  // We conservatively assume function return values are divergent
+  if (isa<CallInst>(V)) {
+    LLVM_DEBUG(dbgs() << "*** divergent return variable: " << V->getName() << "\n");
+    return true;
+  }
+
+  // We conservatively assume function return values are divergent
+  if (isa<InvokeInst>(V)) {
+    LLVM_DEBUG(dbgs() << "*** divergent return variable: " << V->getName() << "\n");
+    return true;
+  }
+
+  // are are not certain about the rest!
+  return false;
+}
+
+bool DivergenceTracker::isAlwaysUniform(const Value *V) {
+  if (!initialized_) {
+    this->initialize();
+  }
+
+  // Mark annotated uniform variables
+  if (uv_nodes_.count(V) != 0) {
+    LLVM_DEBUG(dbgs() << "*** uniform annotated variable: " << V->getName() << "\n");
+    return true;
+  }
+
+  // are are not certain about the rest!
+  return false;
+}
+
 } // vortex
diff --git a/llvm/lib/Target/RISCV/VortexBranchDivergence.h b/llvm/lib/Target/RISCV/VortexBranchDivergence.h
index f0848d9987669243c788fbdb1118262b0985ea81..b7b7f179faef9cfc2caa2b8d8eed252398f89035 100644
--- a/llvm/lib/Target/RISCV/VortexBranchDivergence.h
+++ b/llvm/lib/Target/RISCV/VortexBranchDivergence.h
@@ -1,22 +1,31 @@
+#pragma once
+
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/IR/Value.h"
+#include "llvm/Passes/PassBuilder.h"
 
 namespace vortex {
 using namespace llvm;
 
+struct UniformAnnotationPass : PassInfoMixin<UniformAnnotationPass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
 class DivergenceTracker {
 public:
-    DivergenceTracker(const Function &function);
+  DivergenceTracker(const Function &function);
+
+  bool isSourceOfDivergence(const Value *V);
 
-    bool eval(const Value *V);
+  bool isAlwaysUniform(const Value *V);
 
 private:
-    void initialize();
+  void initialize();
 
-    DenseSet<const Value *> dv_nodes_;
-    DenseSet<const Value *> uv_nodes_;
-    const Function* function_;
-    bool initialized_;
+  DenseSet<const Value *> dv_nodes_;
+  DenseSet<const Value *> uv_nodes_;
+  const Function* function_;
+  bool initialized_;
 };
 
 }
\ No newline at end of file