10
1
mirror of https://github.com/pfloos/quack synced 2024-12-22 04:13:52 +01:00

dRPA (with no TDA) on GPU: V0

This commit is contained in:
AbdAmmar 2024-11-29 20:32:19 +01:00
parent fd4dc5b77e
commit 542cce2da9
10 changed files with 337 additions and 56 deletions

View File

@ -0,0 +1,16 @@
#ifndef MY_LINALG
#define MY_LINALG
extern void A_D_At(int n, double *A, double *D, double *R);
extern void A_Dinv_At(int n, double *A, double *D, double *R);
extern void A_D_inplace(int n, double *A, double *D);
extern void A_Dinv_inplace(int n, double *A, double *D);
extern void elementwise_dsqrt(int nS, double *A, double *A_Sq);
extern void elementwise_dsqrt_inplace(int nS, double *A);
extern void diag_dn_dsyevd(int n, int *info, double *W, double *A);
#endif

View File

@ -5,6 +5,7 @@
extern void ph_dRPA_A_sing(int nO, int nV, int nBas, int nS, double *eps, double *ERI, double *A);
extern void ph_dRPA_B_sing(int nO, int nV, int nBas, int nS, double *ERI, double *B);
extern void diag_dn_dsyevd(int n, int *info, double *W, double *A);
extern void ph_dRPA_ApB_sing(int nO, int nV, int nBas, int nS, double *eps, double *ERI, double *ApB);
extern void ph_dRPA_AmB_sing(int nO, int nV, int nBas, int nS, double *eps, double *ERI, double *AmB);
#endif

View File

@ -25,7 +25,7 @@ __global__ void A_D_At_kernel(int n, double *A, double *D, double *R) {
while(k < n) {
kn = k * n;
R[ij] += D[k] * U[i + kn] * U[j + kn];
R[ij] += D[k] * A[i + kn] * A[j + kn];
k ++;
} // k

View File

@ -0,0 +1,57 @@
#include <stdio.h>
__global__ void A_D_inplace_kernel(int n, double *A, double *D) {
int i, j;
int in, ji;
double tmp;
i = blockIdx.x * blockDim.x + threadIdx.x;
j = blockIdx.y * blockDim.y + threadIdx.y;
while(i < n) {
in = i * n;
tmp = D[i];
while(j < n) {
ji = in + j;
A[ji] = A[ji] * tmp;
j += blockDim.y * gridDim.y;
} // j
i += blockDim.x * gridDim.x;
} // i
}
extern "C" void A_D_inplace(int n, double *A, double *D) {
int sBlocks = 32;
int nBlocks = (n + sBlocks - 1) / sBlocks;
dim3 dimGrid(nBlocks, nBlocks, 1);
dim3 dimBlock(sBlocks, sBlocks, 1);
printf("lunching A_D_inplace_kernel with %dx%d blocks and %dx%d threads/block\n",
nBlocks, nBlocks, sBlocks, sBlocks);
A_D_inplace_kernel<<<dimGrid, dimBlock>>>(n, A, D);
}

View File

@ -25,7 +25,7 @@ __global__ void A_Dinv_At_kernel(int n, double *A, double *D, double *R) {
while(k < n) {
kn = k * n;
R[ij] += D[k] * U[i + kn] * U[j + kn] / (D[k] + 1e-12);
R[ij] += D[k] * A[i + kn] * A[j + kn] / (D[k] + 1e-12);
k ++;
} // k

View File

@ -0,0 +1,57 @@
#include <stdio.h>
__global__ void A_Dinv_inplace_kernel(int n, double *A, double *D) {
int i, j;
int in, ji;
double tmp;
i = blockIdx.x * blockDim.x + threadIdx.x;
j = blockIdx.y * blockDim.y + threadIdx.y;
while(i < n) {
in = i * n;
tmp = 1.0 / (1e-12 + D[i]);
while(j < n) {
ji = in + j;
A[ji] = A[ji] * tmp;
j += blockDim.y * gridDim.y;
} // j
i += blockDim.x * gridDim.x;
} // i
}
extern "C" void A_Dinv_inplace(int n, double *A, double *D) {
int sBlocks = 32;
int nBlocks = (n + sBlocks - 1) / sBlocks;
dim3 dimGrid(nBlocks, nBlocks, 1);
dim3 dimBlock(sBlocks, sBlocks, 1);
printf("lunching A_Dinv_inplace_kernel with %dx%d blocks and %dx%d threads/block\n",
nBlocks, nBlocks, sBlocks, sBlocks);
A_Dinv_inplace_kernel<<<dimGrid, dimBlock>>>(n, A, D);
}

View File

@ -0,0 +1,51 @@
#include <stdio.h>
#include <math.h>
__global__ void elementwise_dsqrt_kernel(int nS, double *A, double *A_Sq) {
int i;
i = blockIdx.x * blockDim.x + threadIdx.x;
while(i < nS) {
if(A[i] > 0.0) {
A_Sq[i] = sqrt(A[i]);
} else {
A_Sq[i] = sqrt(-A[i]);
}
i += blockDim.x * gridDim.x;
} // i
}
extern "C" void elementwise_dsqrt(int nS, double *A, double *A_Sq) {
int sBlocks = 32;
int nBlocks = (nS + sBlocks - 1) / sBlocks;
dim3 dimGrid(nBlocks, 1, 1);
dim3 dimBlock(sBlocks, 1, 1);
printf("lunching elementwise_dsqrt_kernel with %d blocks and %d threads/block\n",
nBlocks, sBlocks);
elementwise_dsqrt_kernel<<<dimGrid, dimBlock>>>(nS, A, A_Sq);
}

View File

@ -2,13 +2,12 @@
#include <math.h>
__global__ void elementwise_dsqrt_inplace_kernel(int nS, double *A, int *nb_neg_sqrt) {
__global__ void elementwise_dsqrt_inplace_kernel(int nS, double *A) {
int i;
i = blockIdx.x * blockDim.x + threadIdx.x;
nb_neg_sqrt = 0;
while(i < nS) {
@ -31,7 +30,7 @@ __global__ void elementwise_dsqrt_inplace_kernel(int nS, double *A, int *nb_neg_
extern "C" void elementwise_dsqrt_inplace(int nS, double *A, int *nb_neg_sqrt) {
extern "C" void elementwise_dsqrt_inplace(int nS, double *A) {
int sBlocks = 32;
int nBlocks = (nS + sBlocks - 1) / sBlocks;
@ -43,7 +42,7 @@ extern "C" void elementwise_dsqrt_inplace(int nS, double *A, int *nb_neg_sqrt) {
nBlocks, sBlocks);
elementwise_dsqrt_inplace_kernel<<<dimGrid, dimBlock>>>(nS, A, nb_neg_sqrt);
elementwise_dsqrt_inplace_kernel<<<dimGrid, dimBlock>>>(nS, A);
}

View File

@ -8,9 +8,15 @@
#include "utils.h"
#include "ph_rpa.h"
#include "my_linalg.h"
void ph_drpa_sing(int nO, int nBas, int nS, double *h_eps, double *h_ERI,
double *h_Omega, double *h_XpY, double *h_XmY) {
double *h_Omega, double *h_XpY, double *h_XmY) {
double *d_eps = NULL;
double *d_ERI = NULL;
@ -23,18 +29,23 @@ void ph_drpa_sing(int nO, int nBas, int nS, double *h_eps, double *h_ERI,
long long nS_long = (long long) nS;
long long nS2 = nS_long * nS_long;
cublasHandle_t handle;
const double alpha=1.0, beta=0.0;
float elapsedTime;
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);
check_Cuda_Errors(cudaMalloc((void**)&d_eps, nBas * sizeof(double)),
"cudaMalloc", __FILE__, __LINE__);
check_Cuda_Errors(cudaMalloc((void**)&d_ERI, nBas4 * sizeof(double)),
"cudaMalloc", __FILE__, __LINE__);
printf("CPU->GPU transfer..\n");
cudaEventRecord(start, 0);
check_Cuda_Errors(cudaMemcpy(d_eps, h_eps, nBas * sizeof(double), cudaMemcpyHostToDevice),
"cudaMemcpy", __FILE__, __LINE__);
@ -67,15 +78,12 @@ void ph_drpa_sing(int nO, int nBas, int nS, double *h_eps, double *h_ERI,
// diagonalize A-B
int *d_info = NULL;
int *d_info1 = NULL;
check_Cuda_Errors(cudaMalloc((void**)&d_info1, sizeof(int)), "cudaMalloc", __FILE__, __LINE__);
double *d_Omega = NULL;
check_Cuda_Errors(cudaMalloc((void**)&d_info, sizeof(int)),
"cudaMalloc", __FILE__, __LINE__);
check_Cuda_Errors(cudaMalloc((void**)&d_Omega, nS * sizeof(double)),
"cudaMalloc", __FILE__, __LINE__);
check_Cuda_Errors(cudaMalloc((void**)&d_Omega, nS * sizeof(double)), "cudaMalloc", __FILE__, __LINE__);
cudaEventRecord(start, 0);
diag_dn_dsyevd(nS, d_info, d_Omega, d_AmB);
diag_dn_dsyevd(nS, d_info1, d_Omega, d_AmB);
check_Cuda_Errors(cudaGetLastError(), "cudaGetLastError", __FILE__, __LINE__);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
@ -84,31 +92,24 @@ void ph_drpa_sing(int nO, int nBas, int nS, double *h_eps, double *h_ERI,
// d_Omega <-- d_Omega^{0.5}
// TODO: nb of <= 0 elements
cudaEventRecord(start, 0);
elementwise_dsqrt_inplace(nS, d_Omega);
// TODO
//int *d_nb_neg_sqrt = NULL;
//check_Cuda_Errors(cudaMalloc((void**)&d_nb_neg_sqrt, sizeof(int)),
// "cudaMalloc", __FILE__, __LINE__);
//int nb_neg_sqrt = 0;
//check_Cuda_Errors(cudaMemcpy(&nb_neg_sqrt, d_nb_neg_sqrt, sizeof(int), cudaMemcpyDeviceToHost),
// "cudaMemcpy", __FILE__, __LINE__);
//if (nb_neg_sqrt > 0) {
// printf("You may have instabilities in linear response: A-B is not positive definite!!\n");
// printf("nb of <= 0 elements = %d\n", nb_neg_sqrt);
//}
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Time elapsed on elementwise_dsqrt_inplace %f msec\n", elapsedTime);
// TODO
// d_AmB (d_Omega)^{+0.5} (d_AmB)^T
// d_AmB (d_Omega)^{-0.5} (d_AmB)^T
// d_AmBSq = d_AmB (d_Omega)^{+0.5} (d_AmB)^T
// d_AmBSqInv = d_AmB (d_Omega)^{-0.5} (d_AmB)^T
cudaEventRecord(start, 0);
double *d_AmBSq = NULL;
check_Cuda_Errors(cudaMalloc((void**)&d_AmBSq, nS * sizeof(double)),
"cudaMalloc", __FILE__, __LINE__);
double *d_AmBSqInv = NULL;
check_Cuda_Errors(cudaMalloc((void**)&d_AmBSqInv, nS * sizeof(double)),
"cudaMalloc", __FILE__, __LINE__);
cudaEventRecord(start, 0);
A_D_At(nS, d_AmB, d_Omega, d_AmBSq);
A_Dinv_At(nS, d_AmB, d_Omega, d_AmBSqInv);
check_Cuda_Errors(cudaGetLastError(), "cudaGetLastError", __FILE__, __LINE__);
@ -118,35 +119,128 @@ void ph_drpa_sing(int nO, int nBas, int nS, double *h_eps, double *h_ERI,
printf("Time elapsed on d_AmBSq & d_AmBSqInv = %f msec\n", elapsedTime);
// TODO
//call dgemm('N','N',nS,nS,nS,1d0,ApB,size(ApB,1),AmBSq,size(AmBSq,1),0d0,tmp,size(tmp,1))
//call dgemm('N','N',nS,nS,nS,1d0,AmBSq,size(AmBSq,1),tmp,size(tmp,1),0d0,Z,size(Z,1))
//call diagonalize_matrix(nS,Z,Om)
//if(minval(Om) < 0d0) &
// call print_warning('You may have instabilities in linear response: negative excitations!!')
//Om = sqrt(Om)
//call dgemm('T','N',nS,nS,nS,1d0,Z,size(Z,1),AmBSq,size(AmBSq,1),0d0,XpY,size(XpY,1))
//call DA(nS,1d0/dsqrt(Om),XpY)
//call dgemm('T','N',nS,nS,nS,1d0,Z,size(Z,1),AmBIv,size(AmBIv,1),0d0,XmY,size(XmY,1))
//call DA(nS,1d0*dsqrt(Om),XmY)
// Dgemm
cudaEventRecord(start, 0);
check_Cublas_Errors(cublasCreate(&handle), "cublasCreate", __FILE__, __LINE__);
// X + Y
check_Cublas_Errors(cublasDgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
nS, nS, nS,
&alpha,
d_ApB, nS,
d_AmBSq, nS,
&beta,
d_AmB, nS),
"cublasDgemm", __FILE__, __LINE__);
check_Cuda_Errors(cudaDeviceSynchronize(), "cudaDeviceSynchronize", __FILE__, __LINE__);
// X - Y
check_Cublas_Errors(cublasDgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
nS, nS, nS,
&alpha,
d_AmBSq, nS,
d_AmB, nS,
&beta,
d_ApB, nS),
"cublasDgemm", __FILE__, __LINE__);
check_Cublas_Errors(cublasDestroy(handle), "cublasDestroy", __FILE__, __LINE__);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Time elapsed on cublasDgemm = %f msec\n", elapsedTime);
// diagonalize
int *d_info2 = NULL;
check_Cuda_Errors(cudaMalloc((void**)&d_info2, sizeof(int)), "cudaMalloc", __FILE__, __LINE__);
cudaEventRecord(start, 0);
diag_dn_dsyevd(nS, d_info2, d_Omega, d_ApB);
check_Cuda_Errors(cudaGetLastError(), "cudaGetLastError", __FILE__, __LINE__);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Time elapsed on diag ApB = %f msec\n", elapsedTime);
// d_Omega <-- d_Omega^{0.5}
// TODO: nb of <= 0 elements
cudaEventRecord(start, 0);
elementwise_dsqrt_inplace(nS, d_Omega);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Time elapsed on elementwise_dsqrt_inplace %f msec\n", elapsedTime);
// Dgemm
cudaEventRecord(start, 0);
check_Cublas_Errors(cublasCreate(&handle), "cublasCreate", __FILE__, __LINE__);
// X + Y
check_Cublas_Errors(cublasDgemm(handle,
CUBLAS_OP_T, CUBLAS_OP_N,
nS, nS, nS,
&alpha,
d_ApB, nS,
d_AmBSq, nS,
&beta,
d_AmB, nS),
"cublasDgemm", __FILE__, __LINE__);
check_Cuda_Errors(cudaDeviceSynchronize(), "cudaDeviceSynchronize", __FILE__, __LINE__);
// X - Y
check_Cublas_Errors(cublasDgemm(handle,
CUBLAS_OP_T, CUBLAS_OP_N,
nS, nS, nS,
&alpha,
d_ApB, nS,
d_AmBSqInv, nS,
&beta,
d_AmBSq, nS),
"cublasDgemm", __FILE__, __LINE__);
check_Cublas_Errors(cublasDestroy(handle), "cublasDestroy", __FILE__, __LINE__);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Time elapsed on cublasDgemm = %f msec\n", elapsedTime);
cudaEventRecord(start, 0);
elementwise_dsqrt(nS, d_Omega, d_AmBSq); // avoid addition memory allocation
A_Dinv_inplace(nS, d_AmB, d_AmBSq); // X + Y
A_D_inplace(nS, d_ApB, d_AmBSq); // X - Y
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Time elapsed on final X+Y and X-Y trans = %f msec\n", elapsedTime);
// transfer data to CPU
cudaEventRecord(start, 0);
//int info_gpu = 0;
//check_Cuda_Errors(cudaMemcpy(&info_gpu, d_info, sizeof(int), cudaMemcpyDeviceToHost),
// "cudaMemcpy", __FILE__, __LINE__);
//if (info_gpu != 0) {
// printf("Error: diag_dn_dsyevd returned error code %d\n", info_gpu);
// exit(EXIT_FAILURE);
//}
check_Cuda_Errors(cudaMemcpy(h_XpY, d_, nS2 * sizeof(double), cudaMemcpyDeviceToHost),
check_Cuda_Errors(cudaMemcpy(h_XpY, d_AmB, nS2 * sizeof(double), cudaMemcpyDeviceToHost),
"cudaMemcpy", __FILE__, __LINE__);
check_Cuda_Errors(cudaMemcpy(h_XmY, d_, nS2 * sizeof(double), cudaMemcpyDeviceToHost),
check_Cuda_Errors(cudaMemcpy(h_XmY, d_ApB, nS2 * sizeof(double), cudaMemcpyDeviceToHost),
"cudaMemcpy", __FILE__, __LINE__);
check_Cuda_Errors(cudaMemcpy(h_Omega, d_Omega, nS * sizeof(double), cudaMemcpyDeviceToHost),
"cudaMemcpy", __FILE__, __LINE__);
@ -155,9 +249,13 @@ void ph_drpa_sing(int nO, int nBas, int nS, double *h_eps, double *h_ERI,
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Time elapsed on GPU -> CPU transfer = %f msec\n", elapsedTime);
check_Cuda_Errors(cudaFree(d_info), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_A), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_B), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_info1), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_info2), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_ApB), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_AmB), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_AmBSq), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_AmBSqInv), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_Omega), "cudaFree", __FILE__, __LINE__);

View File

@ -8,6 +8,9 @@
#include "utils.h"
#include "ph_rpa.h"
#include "my_linalg.h"
/*
*
@ -42,7 +45,6 @@ void ph_drpa_tda_sing(int nO, int nBas, int nS, double *h_eps, double *h_ERI,
check_Cuda_Errors(cudaMalloc((void**)&d_ERI, nBas4 * sizeof(double)),
"cudaMalloc", __FILE__, __LINE__);
printf("CPU->GPU transfer..\n");
cudaEventRecord(start, 0);
check_Cuda_Errors(cudaMemcpy(d_eps, h_eps, nBas * sizeof(double), cudaMemcpyHostToDevice),
"cudaMemcpy", __FILE__, __LINE__);