From 17398059d5e257c359060d1369ee86ba71e80272 Mon Sep 17 00:00:00 2001 From: Francois Coppens Date: Wed, 1 Mar 2023 18:12:42 +0100 Subject: [PATCH] Woodbury 2x2 doc kernel passed test. --- org/qmckl_sherman_morrison_woodbury.org | 112 +++++++++++++++++++----- 1 file changed, 91 insertions(+), 21 deletions(-) diff --git a/org/qmckl_sherman_morrison_woodbury.org b/org/qmckl_sherman_morrison_woodbury.org index a10cdc6..37785f8 100644 --- a/org/qmckl_sherman_morrison_woodbury.org +++ b/org/qmckl_sherman_morrison_woodbury.org @@ -1423,7 +1423,7 @@ with Sherman-Morrison and update splitting. Please look at the performance recco *** Introduction The Woodbury 2x2 kernel. It is used to apply two rank-1 updates at once. The formula used in -this algorithm is called the Woodbury Matrix Identity +this algorithm is called the Woodbury Matrix Id \[ (S + U V)^{-1} = S^{-1} - C B^{-1} D \] @@ -1485,12 +1485,15 @@ integer function qmckl_woodbury_2x2_doc_f(& real*8 , intent(inout) :: s_inv(dim * lds) real*8 , intent(inout) :: determinant - real*8 , dimension(lds, 2) :: Updates - real*8 , dimension(dim, lds) :: Inverse - real*8 , dimension(dim, 2) :: C - real*8 , dimension(2, dim) :: D - real*8 :: denominator, idenominator, update - integer*8 :: i, j, l, row + integer*8 , dimension(2, dim) :: V + integer*8 , dimension(2, 2) :: Id + real*8 , dimension(dim, dim) :: Inverse + real*8 , dimension(dim, 2) :: Updates, C + real*8 , dimension(2, 2) :: D, invD + real*8 , dimension(2, dim) :: E, F + + real*8 :: detD, idenominator, update + integer*8 :: i, j, k, l info = QMCKL_FAILURE @@ -1499,19 +1502,86 @@ integer function qmckl_woodbury_2x2_doc_f(& return endif + ! Construct V(2, dim) matrix + V = 0 + V(1, updates_index(1)) = 1 + V(2, updates_index(2)) = 1 + + ! Construct Id(2, 2) matrix + Id = 0 + Id(1, 1) = 1 + Id(2, 2) = 1 + ! Convert 'upds' and 's_inv' into the more easily readable Fortran ! matrices 'Updates' and 'Inverse'. call convert(upds, s_inv, Updates, Inverse, int(2,8), lds, dim) - ! Compute C(dim,2) = Inverse(dim,dim) x Updates(dim,2) + ! Compute C(dim, 2) = Inverse(dim, dim) x Updates(dim, 2) + C = 0 do i = 1, dim - do j = 1, dim - C(i,1) = C(i,1) + Inverse(1,j) * Updates(j,1) - C(i,2) = C(i,1) + Inverse(1,j) * Updates(j,2) + do j = 1, 2 + do k = 1, dim + C(i, j) = C(i, j) + Inverse(i, k) * Updates(k, j) + end do end do end do + ! Construct matrix D(2, 2) := I(2, 2) + V(2, dim) x C(dim, 2) + D = 0 + do i = 1, 2 + do j = 1, 2 + do k = 2, dim + D(i, j) = D(i, j) + V(i, k) * C(k, j) + end do + end do + end do + D = Id + D + ! Compute determinant := det(D) explicitly + detD = D(1,1) * D(2,2) - D(1,2) * D(2,1) + + ! Return early if det(D) is too small + if (abs(detD) < breakdown) return + + ! Update det(S) + determinant = determinant * detD + + ! Compute inv(D) explicitly + invD(1,1) = D(2,2) + invD(1,2) = - D(1,2) + invD(2,1) = - D(2,1) + invD(2,2) = D(1,1) + invD = invD / detD + + ! Compute E(2, dim) := V(2, dim) x Inverse(dim, dim) + E = 0 + do i = 1, 2 + do j = 1, dim + do k = 1, dim + E(i, j) = E(i, j) + V(i, k) * Inverse(k, j) + end do + end do + end do + + ! Compute F(2, dim) := invD(2, 2) x E(2, dim) + F = 0 + do i = 1, 2 + do j = 1, dim + do k = 1, 2 + F(i, j) = F(i, j) + invD(i, k) * E(k, j) + end do + end do + end do + + ! Compute Inverse(dim, dim) := Inverse(dim, dim) - C(dim, 2) x F(2, dim) + do i = 1, dim + do j = 1, dim + do k = 1, 2 + Inverse(i, j) = Inverse(i, j) - C(i, k) * F(k, j) + end do + end do + end do + ! Copy updated inverse and later updates ! back to s_inv and later_upds call copy_back_inv(Inverse, s_inv, lds, dim) @@ -1830,16 +1900,7 @@ qmckl_exit_code qmckl_woodbury_2x2(const qmckl_context context, determinant); } #else - // return qmckl_woodbury_2x2_doc( - // context, - // LDS, - // Dim, - // Updates, - // Updates_index, - // breakdown, - // Slater_inv, - // determinant); - return qmckl_woodbury_2x2_hpc( + return qmckl_woodbury_2x2_doc( context, LDS, Dim, @@ -1848,6 +1909,15 @@ qmckl_exit_code qmckl_woodbury_2x2(const qmckl_context context, breakdown, Slater_inv, determinant); + // return qmckl_woodbury_2x2_hpc( + // context, + // LDS, + // Dim, + // Updates, + // Updates_index, + // breakdown, + // Slater_inv, + // determinant); #endif return QMCKL_FAILURE;