10
1
mirror of https://github.com/pfloos/quack synced 2025-01-03 01:55:57 +01:00

optim A x D x A.T and A x Dinv x A.T on GPU

This commit is contained in:
AbdAmmar 2024-11-30 00:38:32 +01:00
parent fcb11d662c
commit da4f8df0f8
10 changed files with 360 additions and 69 deletions

View File

@ -2,12 +2,18 @@
#define MY_LINALG
extern void A_plus_B_in_A(int n, double *A, double *B);
extern void A_minus_twoB_in_B(int n, double *A, double *B);
extern void A_D_At(int n, double *A, double *D, double *R);
extern void A_Dinv_At(int n, double *A, double *D, double *R);
extern void A_D_inplace(int n, double *A, double *D);
extern void A_Dinv_inplace(int n, double *A, double *D);
extern void A_D_in_B(int n, double *A, double *D, double *B);
extern void A_Dinv_in_B(int n, double *A, double *D, double *B);
extern void elementwise_dsqrt(int nS, double *A, double *A_Sq);
extern void elementwise_dsqrt_inplace(int nS, double *A);

57
src/cuda/src/a_d_in_b.cu Normal file
View File

@ -0,0 +1,57 @@
#include <stdio.h>
__global__ void A_D_in_B_kernel(int n, double *A, double *D, double *B) {
int i, j;
int in, ji;
double tmp;
i = blockIdx.x * blockDim.x + threadIdx.x;
j = blockIdx.y * blockDim.y + threadIdx.y;
while(i < n) {
in = i * n;
tmp = D[i];
while(j < n) {
ji = in + j;
B[ji] = A[ji] * tmp;
j += blockDim.y * gridDim.y;
} // j
i += blockDim.x * gridDim.x;
} // i
}
extern "C" void A_D_in_B(int n, double *A, double *D, double *B) {
int sBlocks = 32;
int nBlocks = (n + sBlocks - 1) / sBlocks;
dim3 dimGrid(nBlocks, nBlocks, 1);
dim3 dimBlock(sBlocks, sBlocks, 1);
printf("lunching A_D_in_B_kernel with %dx%d blocks and %dx%d threads/block\n",
nBlocks, nBlocks, sBlocks, sBlocks);
A_D_in_B_kernel<<<dimGrid, dimBlock>>>(n, A, D, B);
}

View File

@ -0,0 +1,57 @@
#include <stdio.h>
__global__ void A_Dinv_in_B_kernel(int n, double *A, double *D, double *B) {
int i, j;
int in, ji;
double tmp;
i = blockIdx.x * blockDim.x + threadIdx.x;
j = blockIdx.y * blockDim.y + threadIdx.y;
while(i < n) {
in = i * n;
tmp = 1.0 / D[i];
while(j < n) {
ji = in + j;
B[ji] = A[ji] * tmp;
j += blockDim.y * gridDim.y;
} // j
i += blockDim.x * gridDim.x;
} // i
}
extern "C" void A_Dinv_in_B(int n, double *A, double *D, double *B) {
int sBlocks = 32;
int nBlocks = (n + sBlocks - 1) / sBlocks;
dim3 dimGrid(nBlocks, nBlocks, 1);
dim3 dimBlock(sBlocks, sBlocks, 1);
printf("lunching A_Dinv_in_B_kernel with %dx%d blocks and %dx%d threads/block\n",
nBlocks, nBlocks, sBlocks, sBlocks);
A_Dinv_in_B_kernel<<<dimGrid, dimBlock>>>(n, A, D, B);
}

View File

@ -0,0 +1,52 @@
#include <stdio.h>
__global__ void A_minus_twoB_in_B_kernel(int n, double *A, double *B) {
int i, j;
int in, ji;
i = blockIdx.x * blockDim.x + threadIdx.x;
j = blockIdx.y * blockDim.y + threadIdx.y;
while(i < n) {
in = i * n;
while(j < n) {
ji = in + j;
B[ji] = A[ji] - 2.0 * B[ji];
j += blockDim.y * gridDim.y;
} // j
i += blockDim.x * gridDim.x;
} // i
}
extern "C" void A_minus_twoB_in_B(int n, double *A, double *B) {
int sBlocks = 32;
int nBlocks = (n + sBlocks - 1) / sBlocks;
dim3 dimGrid(nBlocks, nBlocks, 1);
dim3 dimBlock(sBlocks, sBlocks, 1);
printf("lunching A_minus_twoB_in_B_kernel with %dx%d blocks and %dx%d threads/block\n",
nBlocks, nBlocks, sBlocks, sBlocks);
A_minus_twoB_in_B_kernel<<<dimGrid, dimBlock>>>(n, A, B);
}

View File

@ -0,0 +1,52 @@
#include <stdio.h>
__global__ void A_plus_B_in_A_kernel(int n, double *A, double *B) {
int i, j;
int in, ji;
i = blockIdx.x * blockDim.x + threadIdx.x;
j = blockIdx.y * blockDim.y + threadIdx.y;
while(i < n) {
in = i * n;
while(j < n) {
ji = in + j;
A[ji] = A[ji] + B[ji];
j += blockDim.y * gridDim.y;
} // j
i += blockDim.x * gridDim.x;
} // i
}
extern "C" void A_plus_B_in_A(int n, double *A, double *B) {
int sBlocks = 32;
int nBlocks = (n + sBlocks - 1) / sBlocks;
dim3 dimGrid(nBlocks, nBlocks, 1);
dim3 dimBlock(sBlocks, sBlocks, 1);
printf("lunching A_plus_B_in_A_kernel with %dx%d blocks and %dx%d threads/block\n",
nBlocks, nBlocks, sBlocks, sBlocks);
A_plus_B_in_A_kernel<<<dimGrid, dimBlock>>>(n, A, B);
}

View File

@ -5,17 +5,18 @@ __global__ void ph_dRPA_A_sing_kernel(int nO, int nV, int nBas, int nS, double *
int i, j, a, b;
int aa, bb;
int nVS;
int nBas2, nBas3;
int i_A0, i_A1, i_A2;
int i_I0, i_I1, i_I2;
long long nVS;
long long nBas2, nBas3;
long long i_A0, i_A1, i_A2, i_A3;
long long i_I0, i_I1, i_I2, i_I3;
bool a_eq_b;
nVS = nV * nS;
nVS = (long long) nV * (long long) nS;
nBas2 = nBas * nBas;
nBas3 = nBas2 * nBas;
nBas2 = (long long) nBas * (long long) nBas;
nBas3 = nBas2 * (long long) nBas;
aa = blockIdx.x * blockDim.x + threadIdx.x;
bb = blockIdx.y * blockDim.y + threadIdx.y;
@ -23,29 +24,32 @@ __global__ void ph_dRPA_A_sing_kernel(int nO, int nV, int nBas, int nS, double *
while(aa < nV) {
a = aa + nO;
i_A0 = aa * nS;
i_I0 = a * nBas2;
i_A0 = (long long) aa * (long long) nS;
i_I0 = (long long) a * nBas2;
while(bb < nV) {
b = bb + nO;
a_eq_b = a == b;
i_A1 = i_A0 + bb;
i_I1 = i_I0 + b * nBas;
i_A1 = i_A0 + (long long) bb;
i_I1 = i_I0 + (long long) b * (long long) nBas;
i = 0;
while(i < nO) {
i_A2 = i_A1 + i * nVS;
i_I2 = i_I1 + i;
i_A2 = i_A1 + (long long) i * nVS;
i_I2 = i_I1 + (long long) i;
j = 0;
while(j < nO) {
A[i_A2 + j * nV] = 2.0 * ERI[i_I2 + j * nBas3];
i_A3 = i_A2 + (long long) j * (long long) nV;
i_I3 = i_I2 + (long long) j * nBas3;
A[i_A3] = 2.0 * ERI[i_I3];
if(a_eq_b && (i==j)) {
A[i_A2 + j * nV] += eps[a] - eps[i];
A[i_A3] += eps[a] - eps[i];
}
j ++;

View File

@ -1,22 +1,26 @@
#include <stdio.h>
__global__ void ph_dRPA_AmB_sing_kernel(int nO, int nV, int nBas, int nS, double *eps, double *ERI, double *AmB) {
__global__ void ph_dRPA_AmB_sing_kernel(int nO, int nV, int nBas, int nS,
double *eps, double *ERI, double *AmB) {
int i, j, a, b;
int aa, bb;
int nVS;
int nBas2, nBas3;
int i_A0, i_A1, i_A2;
int i_I0, i_I1, i_I2;
int i_J1, i_J2;
long long i_A0, i_A1, i_A2, i_A3;
long long i_I0, i_I1, i_I2, i_I3;
long long i_J1, i_J2, i_J3;
long long nVS;
long long nBas2, nBas3;
bool a_eq_b;
nVS = nV * nS;
nVS = (long long) nV * (long long) nS;
nBas2 = nBas * nBas;
nBas3 = nBas2 * nBas;
nBas2 = (long long) nBas * (long long) nBas;
nBas3 = nBas2 * (long long) nBas;
aa = blockIdx.x * blockDim.x + threadIdx.x;
bb = blockIdx.y * blockDim.y + threadIdx.y;
@ -24,31 +28,35 @@ __global__ void ph_dRPA_AmB_sing_kernel(int nO, int nV, int nBas, int nS, double
while(aa < nV) {
a = aa + nO;
i_A0 = aa * nS;
i_I0 = a * nBas2;
i_A0 = (long long) aa * (long long) nS;
i_I0 = (long long) a * nBas2;
while(bb < nV) {
b = bb + nO;
a_eq_b = a == b;
i_A1 = i_A0 + bb;
i_I1 = i_I0 + b * nBas;
i_J1 = i_I0 + b * nBas3;
i_A1 = i_A0 + (long long) bb;
i_I1 = i_I0 + (long long) b * (long long) nBas;
i_J1 = i_I0 + (long long) b * nBas3;
i = 0;
while(i < nO) {
i_A2 = i_A1 + i * nVS;
i_I2 = i_I1 + i;
i_J2 = i_J1 + i;
i_A2 = i_A1 + (long long) i * nVS;
i_I2 = i_I1 + (long long) i;
i_J2 = i_J1 + (long long) i;
j = 0;
while(j < nO) {
AmB[i_A2 + j * nV] = 2.0 * (ERI[i_I2 + j * nBas3] - ERI[i_J2 + j * nBas]);
i_A3 = i_A2 + (long long) j * nV;
i_I3 = i_I2 + (long long) j * nBas3;
i_J3 = i_J2 + (long long) j * (long long) nBas;
AmB[i_A3] = 2.0 * (ERI[i_I3] - ERI[i_J3]);
if(a_eq_b && (i==j)) {
AmB[i_A2 + j * nV] += eps[a] - eps[i];
AmB[i_A3] += eps[a] - eps[i];
}
j ++;

View File

@ -1,22 +1,29 @@
#include <stdio.h>
__global__ void ph_dRPA_ApB_sing_kernel(int nO, int nV, int nBas, int nS, double *eps, double *ERI, double *ApB) {
__global__ void ph_dRPA_ApB_sing_kernel(int nO, int nV, int nBas, int nS,
double *eps, double *ERI, double *ApB) {
int i, j, a, b;
int aa, bb;
int nVS;
int nBas2, nBas3;
int i_A0, i_A1, i_A2;
long i, j, a, b;
long aa, bb;
int i_A0, i_A1, i_A2, i_A3;
int i_I0, i_I1, i_I2;
int i_J1, i_J2;
int nVS;
int nBas2;
long long i_I3, i_J3;
long long nBas3;
bool a_eq_b;
nVS = nV * nS;
nBas2 = nBas * nBas;
nBas3 = nBas2 * nBas;
nBas3 = (long long) nBas2 * (long long) nBas;
aa = blockIdx.x * blockDim.x + threadIdx.x;
bb = blockIdx.y * blockDim.y + threadIdx.y;
@ -34,21 +41,25 @@ __global__ void ph_dRPA_ApB_sing_kernel(int nO, int nV, int nBas, int nS, double
i_A1 = i_A0 + bb;
i_I1 = i_I0 + b * nBas;
i_J1 = i_I0 + b * nBas3;
i_J1 = a + b * nBas;
i = 0;
while(i < nO) {
i_A2 = i_A1 + i * nVS;
i_I2 = i_I1 + i;
i_J2 = i_J1 + i;
i_J2 = i_J1 + i * nBas2;
j = 0;
while(j < nO) {
ApB[i_A2 + j * nV] = 2.0 * (ERI[i_I2 + j * nBas3] + ERI[i_J2 + j * nBas]);
i_A3 = i_A2 + j * nV;
i_I3 = i_I2 + (long long) j * nBas3;
i_J3 = i_J2 + (long long) j * nBas3;
ApB[i_A3] = 2.0 * (ERI[i_I3] + ERI[i_J3]);
if(a_eq_b && (i==j)) {
ApB[i_A2 + j * nV] += eps[a] - eps[i];
ApB[i_A3] += eps[a] - eps[i];
}
j ++;

View File

@ -5,15 +5,17 @@ __global__ void ph_dRPA_B_sing_kernel(int nO, int nV, int nBas, int nS, double *
int i, j, a, b;
int aa, bb;
int nVS;
int nBas2, nBas3;
int i_B0, i_B1, i_B2;
int i_I0, i_I1, i_I2;
nVS = nV * nS;
long long nVS;
long long nBas2, nBas3;
long long i_B0, i_B1, i_B2, i_B3;
long long i_I0, i_I1, i_I2, i_I3;
nBas2 = nBas * nBas;
nBas3 = nBas2 * nBas;
nVS = (long long) nV * (long long) nS;
nBas2 = (long long) nBas * (long long) nBas;
nBas3 = nBas2 * (long long) nBas;
aa = blockIdx.x * blockDim.x + threadIdx.x;
bb = blockIdx.y * blockDim.y + threadIdx.y;
@ -21,25 +23,28 @@ __global__ void ph_dRPA_B_sing_kernel(int nO, int nV, int nBas, int nS, double *
while(aa < nV) {
a = aa + nO;
i_B0 = aa * nS;
i_I0 = a * nBas2;
i_B0 = (long long) aa * (long long) nS;
i_I0 = (long long) a * nBas2;
while(bb < nV) {
b = bb + nO;
i_B1 = i_B0 + bb;
i_I1 = i_I0 + b * nBas3;
i_B1 = i_B0 + (long long) bb;
i_I1 = i_I0 + (long long) b * nBas3;
i = 0;
while(i < nO) {
i_B2 = i_B1 + i * nVS;
i_I2 = i_I1 + i;
i_B2 = i_B1 + (long long) i * nVS;
i_I2 = i_I1 + (long long) i;
j = 0;
while(j < nO) {
B[i_B2 + j * nV] = 2.0 * ERI[i_I2 + j * nBas];
i_B3 = i_B2 + (long long) j * (long long) nV;
i_I3 = i_I2 + (long long) j * (long long) nBas;
B[i_B3] = 2.0 * ERI[i_I3];
j ++;
} // j

View File

@ -65,6 +65,12 @@ void ph_drpa_sing(int nO, int nBas, int nS, double *h_eps, double *h_ERI,
cudaEventRecord(start, 0);
ph_dRPA_ApB_sing(nO, nV, nBas, nS, d_eps, d_ERI, d_ApB);
ph_dRPA_AmB_sing(nO, nV, nBas, nS, d_eps, d_ERI, d_AmB);
//ph_dRPA_A_sing(nO, nV, nBas, nS, d_eps, d_ERI, d_ApB);
//ph_dRPA_B_sing(nO, nV, nBas, nS, d_ERI, d_AmB);
//check_Cuda_Errors(cudaDeviceSynchronize(), "cudaDeviceSynchronize", __FILE__, __LINE__);
//A_plus_B_in_A(nS, d_ApB, d_AmB);
//check_Cuda_Errors(cudaDeviceSynchronize(), "cudaDeviceSynchronize", __FILE__, __LINE__);
//A_minus_twoB_in_B(nS, d_ApB, d_AmB);
check_Cuda_Errors(cudaGetLastError(), "cudaGetLastError", __FILE__, __LINE__);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
@ -73,6 +79,7 @@ void ph_drpa_sing(int nO, int nBas, int nS, double *h_eps, double *h_ERI,
// free memory
check_Cuda_Errors(cudaDeviceSynchronize(), "cudaDeviceSynchronize", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_eps), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_ERI), "cudaFree", __FILE__, __LINE__);
@ -105,30 +112,62 @@ void ph_drpa_sing(int nO, int nBas, int nS, double *h_eps, double *h_ERI,
// d_AmBSq = d_AmB (d_Omega)^{+0.5} (d_AmB)^T
double *d_AmBSq = NULL;
check_Cuda_Errors(cudaMalloc((void**)&d_AmBSq, nS2 * sizeof(double)),
"cudaMalloc", __FILE__, __LINE__);
// d_AmBSqInv = d_AmB (d_Omega)^{-0.5} (d_AmB)^T
double *d_AmBSq = NULL;
double *d_AmBSqInv = NULL;
check_Cuda_Errors(cudaMalloc((void**)&d_AmBSqInv, nS2 * sizeof(double)),
"cudaMalloc", __FILE__, __LINE__);
double *d_tmp1 = NULL;
double *d_tmp2 = NULL;
check_Cuda_Errors(cudaMalloc((void**)&d_AmBSq, nS2 * sizeof(double)), "cudaMalloc", __FILE__, __LINE__);
check_Cuda_Errors(cudaMalloc((void**)&d_AmBSqInv, nS2 * sizeof(double)), "cudaMalloc", __FILE__, __LINE__);
check_Cuda_Errors(cudaMalloc((void**)&d_tmp1, nS2 * sizeof(double)), "cudaMalloc", __FILE__, __LINE__);
check_Cuda_Errors(cudaMalloc((void**)&d_tmp2, nS2 * sizeof(double)), "cudaMalloc", __FILE__, __LINE__);
check_Cublas_Errors(cublasCreate(&handle), "cublasCreate", __FILE__, __LINE__);
cudaEventRecord(start, 0);
A_D_At(nS, d_AmB, d_Omega, d_AmBSq);
A_Dinv_At(nS, d_AmB, d_Omega, d_AmBSqInv);
// naive way
//A_D_At(nS, d_AmB, d_Omega, d_AmBSq);
//A_Dinv_At(nS, d_AmB, d_Omega, d_AmBSqInv);
A_D_in_B(nS, d_AmB, d_Omega, d_tmp1);
A_Dinv_in_B(nS, d_AmB, d_Omega, d_tmp2);
check_Cuda_Errors(cudaDeviceSynchronize(), "cudaDeviceSynchronize", __FILE__, __LINE__);
check_Cublas_Errors(cublasDgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_T,
nS, nS, nS,
&alpha,
d_tmp1, nS,
d_AmB, nS,
&beta,
d_AmBSq, nS),
"cublasDgemm", __FILE__, __LINE__);
check_Cublas_Errors(cublasDgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_T,
nS, nS, nS,
&alpha,
d_tmp2, nS,
d_AmB, nS,
&beta,
d_AmBSqInv, nS),
"cublasDgemm", __FILE__, __LINE__);
check_Cuda_Errors(cudaGetLastError(), "cudaGetLastError", __FILE__, __LINE__);
cudaEventRecord(stop, 0);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Time elapsed on d_AmBSq & d_AmBSqInv = %f msec\n", elapsedTime);
check_Cuda_Errors(cudaDeviceSynchronize(), "cudaDeviceSynchronize", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_tmp1), "cudaFree", __FILE__, __LINE__);
check_Cuda_Errors(cudaFree(d_tmp2), "cudaFree", __FILE__, __LINE__);
// Dgemm
cudaEventRecord(start, 0);
check_Cublas_Errors(cublasCreate(&handle), "cublasCreate", __FILE__, __LINE__);
// X + Y
check_Cublas_Errors(cublasDgemm(handle,