From f4becac4c01e085e7c28b1b4b954b9ceaf709670 Mon Sep 17 00:00:00 2001 From: Pablo Oliveira Date: Thu, 6 May 2021 10:51:42 +0200 Subject: [PATCH 1/2] First implementation of SM4 SM4: Sherman Morrison, mix between SM3 + SM2 Leave zero denominators for later (SM3), and when none are left then split (SM2) --- include/SM_Standard.hpp | 5 +++ src/SM_Standard.cpp | 71 ++++++++++++++++++++++++++++++++++++++++- src/SM_mod.f90 | 7 ++++ tests/test_h5.cpp | 2 ++ tests/vfc_test_h5.cpp | 2 ++ 5 files changed, 86 insertions(+), 1 deletion(-) diff --git a/include/SM_Standard.hpp b/include/SM_Standard.hpp index 5541517..4bfabd0 100644 --- a/include/SM_Standard.hpp +++ b/include/SM_Standard.hpp @@ -10,3 +10,8 @@ void SM2(double *Slater_inv, unsigned int Dim, unsigned int N_updates, // Sherman Morrison, leaving zero denominators for later void SM3(double *Slater_inv, unsigned int Dim, unsigned int N_updates, double *Updates, unsigned int *Updates_index); + +// Sherman Morrison (SM3+SM2), leaving zero denominators for later (SM3), and +// when none are left falling back on Splitting (SM2) +void SM4(double *Slater_inv, unsigned int Dim, unsigned int N_updates, + double *Updates, unsigned int *Updates_index); diff --git a/src/SM_Standard.cpp b/src/SM_Standard.cpp index b188a2a..5314387 100644 --- a/src/SM_Standard.cpp +++ b/src/SM_Standard.cpp @@ -172,6 +172,70 @@ void SM3(double *Slater_inv, unsigned int Dim, unsigned int N_updates, } } +// Sherman Morrison, mix between SM3 + SM2 +// Leave zero denominators for later (SM3), and when none are left then split (SM2) +void SM4(double *Slater_inv, unsigned int Dim, unsigned int N_updates, + double *Updates, unsigned int *Updates_index) { + std::cerr << "Called SM4 with updates " << N_updates << std::endl; + double C[Dim]; + double D[Dim]; + + double later_updates[Dim * N_updates]; + unsigned int later_index[N_updates]; + unsigned int later = 0; + + unsigned int l = 0; + // For each update + while (l < N_updates) { + // C = A^{-1} x U_l + for (unsigned int i = 0; i < Dim; i++) { + C[i] = 0; + for (unsigned int j = 0; j < Dim; j++) { + C[i] += Slater_inv[i * Dim + j] * Updates[l * Dim + j]; + } + } + + // Denominator + double den = 1 + C[Updates_index[l] - 1]; + if (fabs(den) < threshold()) { + std::cerr << "Breakdown condition triggered at " << Updates_index[l] + << std::endl; + + for (unsigned int j = 0; j < Dim; j++) { + later_updates[later * Dim + j] = Updates[l * Dim + j]; + } + later_index[later] = Updates_index[l]; + later++; + l += 1; + continue; + } + double iden = 1 / den; + + // D = v^T x A^{-1} + for (unsigned int j = 0; j < Dim; j++) { + D[j] = Slater_inv[(Updates_index[l] - 1) * Dim + j]; + } + + // A^{-1} = A^{-1} - C x D / den + for (unsigned int i = 0; i < Dim; i++) { + for (unsigned int j = 0; j < Dim; j++) { + double update = C[i] * D[j] * iden; + Slater_inv[i * Dim + j] -= update; + } + } + l += 1; + } + + // If all the updates have failed, fall back on splitting (SM2) + if (later == N_updates) { + SM2(Slater_inv, Dim, later, later_updates, later_index); + } + // If some have failed, make a recursive call + else if (later > 0) { + SM4(Slater_inv, Dim, later, later_updates, later_index); + } +} + extern "C" { void SM1_f(double **linSlater_inv, unsigned int *Dim, unsigned int *N_updates, double **linUpdates, @@ -184,10 +248,15 @@ extern "C" { unsigned int **Updates_index) { SM2(*linSlater_inv, *Dim, *N_updates, *linUpdates, *Updates_index); } - + void SM3_f(double **linSlater_inv, unsigned int *Dim, unsigned int *N_updates, double **linUpdates, unsigned int **Updates_index) { SM3(*linSlater_inv, *Dim, *N_updates, *linUpdates, *Updates_index); } + void SM4_f(double **linSlater_inv, unsigned int *Dim, + unsigned int *N_updates, double **linUpdates, + unsigned int **Updates_index) { + SM4(*linSlater_inv, *Dim, *N_updates, *linUpdates, *Updates_index); + } } diff --git a/src/SM_mod.f90 b/src/SM_mod.f90 index f947458..bcf072c 100644 --- a/src/SM_mod.f90 +++ b/src/SM_mod.f90 @@ -36,5 +36,12 @@ module Sherman_Morrison real(c_double), dimension(:,:), allocatable, intent(in) :: Updates real(c_double), dimension(:,:), allocatable, intent(in out) :: Slater_inv end subroutine SM3 + subroutine SM4(Slater_inv, dim, n_updates, Updates, Updates_index) bind(C, name="SM4_f") + use, intrinsic :: iso_c_binding, only : c_int, c_double + integer(c_int), intent(in) :: dim, n_updates + integer(c_int), dimension(:), allocatable, intent(in) :: Updates_index + real(c_double), dimension(:,:), allocatable, intent(in) :: Updates + real(c_double), dimension(:,:), allocatable, intent(in out) :: Slater_inv + end subroutine SM4 end interface end module Sherman_Morrison diff --git a/tests/test_h5.cpp b/tests/test_h5.cpp index 8f5c5d2..005eca8 100644 --- a/tests/test_h5.cpp +++ b/tests/test_h5.cpp @@ -104,6 +104,8 @@ int test_cycle(H5File file, int cycle, std::string version, double tolerance) { SM2(slater_inverse, dim, nupdates, u, col_update_index); } else if (version == "sm3") { SM3(slater_inverse, dim, nupdates, u, col_update_index); + } else if (version == "sm4") { + SM4(slater_inverse, dim, nupdates, u, col_update_index); } else { std::cerr << "Unknown version " << version << std::endl; exit(1); diff --git a/tests/vfc_test_h5.cpp b/tests/vfc_test_h5.cpp index 52b2bb6..5fc8583 100644 --- a/tests/vfc_test_h5.cpp +++ b/tests/vfc_test_h5.cpp @@ -137,6 +137,8 @@ int test_cycle(H5File file, int cycle, std::string version, vfc_probes * probes) SM2(slater_inverse, dim, nupdates, u, col_update_index); } else if (version == "sm3") { SM3(slater_inverse, dim, nupdates, u, col_update_index); + } else if (version == "sm3") { + SM4(slater_inverse, dim, nupdates, u, col_update_index); } else { std::cerr << "Unknown version " << version << std::endl; exit(1); From ada8cd6888f4ac37b0214dc38f7e4b7a7b90ca31 Mon Sep 17 00:00:00 2001 From: Pablo Oliveira Date: Thu, 6 May 2021 11:01:07 +0200 Subject: [PATCH 2/2] Fix typo --- tests/vfc_test_h5.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/vfc_test_h5.cpp b/tests/vfc_test_h5.cpp index 5fc8583..7b63b5d 100644 --- a/tests/vfc_test_h5.cpp +++ b/tests/vfc_test_h5.cpp @@ -137,7 +137,7 @@ int test_cycle(H5File file, int cycle, std::string version, vfc_probes * probes) SM2(slater_inverse, dim, nupdates, u, col_update_index); } else if (version == "sm3") { SM3(slater_inverse, dim, nupdates, u, col_update_index); - } else if (version == "sm3") { + } else if (version == "sm4") { SM4(slater_inverse, dim, nupdates, u, col_update_index); } else { std::cerr << "Unknown version " << version << std::endl;