Skip to content

Commit

Permalink
modifying jit lu decomp
Browse files Browse the repository at this point in the history
  • Loading branch information
K20shores committed Sep 20, 2024
1 parent 4193382 commit e8e08ef
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docker/Dockerfile.llvm
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
FROM fedora:37
# We need llvm version 15.0.7

RUN dnf -y update \
&& dnf -y install \
Expand All @@ -9,7 +10,7 @@ RUN dnf -y update \
git \
make \
zlib-devel \
llvm-devel \
llvm-devel \
openmpi-devel \
valgrind \
&& dnf clean all
Expand Down
20 changes: 18 additions & 2 deletions include/micm/jit/solver/jit_lu_decomposition.inl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ namespace micm
auto lki_nkj = lki_nkj_.begin();
auto lkj_uji = lkj_uji_.begin();
auto uii = uii_.begin();

llvm::Value *sentinel_value =
llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, std::numeric_limits<std::size_t>::max()));

for (auto &inLU : niLU_)
{
// Upper triangular matrix
Expand All @@ -83,10 +87,16 @@ namespace micm
llvm::Value *A_ptr_index[1];
A_ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, iAf);
llvm::Value *A_val = func.GetArrayElement(func.arguments_[0], A_ptr_index, JitType::Double);

// Check for sentinel value
auto check_sentinel = func.builder_->CreateICmpEQ(iAf, sentinel_value);
auto A_val_not_zero = func.builder_->CreateSelect(
check_sentinel, llvm::ConstantFP::get(*(func.context_), llvm::APFloat(0.0)), A_val);

llvm::Value *iUf = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, uik_nkj->first));
llvm::Value *U_ptr_index[1];
U_ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, iUf);
func.SetArrayElement(func.arguments_[2], U_ptr_index, JitType::Double, A_val);
func.SetArrayElement(func.arguments_[2], U_ptr_index, JitType::Double, A_val_not_zero);
func.EndLoop(loop);
}
for (std::size_t ikj = 0; ikj < uik_nkj->second; ++ikj)
Expand Down Expand Up @@ -131,10 +141,16 @@ namespace micm
llvm::Value *A_ptr_index[1];
A_ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, iAf);
llvm::Value *A_val = func.GetArrayElement(func.arguments_[0], A_ptr_index, JitType::Double);

// Check for sentinel value
auto check_sentinel = func.builder_->CreateICmpEQ(iAf, sentinel_value);
auto A_val_not_zero = func.builder_->CreateSelect(
check_sentinel, llvm::ConstantFP::get(*(func.context_), llvm::APFloat(0.0)), A_val);

llvm::Value *iLf = llvm::ConstantInt::get(*(func.context_), llvm::APInt(64, lki_nkj->first));
llvm::Value *L_ptr_index[1];
L_ptr_index[0] = func.builder_->CreateNSWAdd(loop.index_, iLf);
func.SetArrayElement(func.arguments_[1], L_ptr_index, JitType::Double, A_val);
func.SetArrayElement(func.arguments_[1], L_ptr_index, JitType::Double, A_val_not_zero);
func.EndLoop(loop);
}
for (std::size_t ikj = 0; ikj < lki_nkj->second; ++ikj)
Expand Down

0 comments on commit e8e08ef

Please sign in to comment.