mirror of
https://github.com/TREX-CoE/irpjast.git
synced 2025-01-10 13:08:29 +01:00
Prepared for recursive
This commit is contained in:
parent
c532c8b6d8
commit
412dba6b92
@ -10,36 +10,35 @@
|
||||
! dtmp_c:
|
||||
! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl}
|
||||
END_DOC
|
||||
integer :: k, l, m, icount
|
||||
integer*8 :: gemms(6*ncord*(ncord+1))
|
||||
integer :: k, l, m
|
||||
integer*8 :: tasks(100000), ntasks
|
||||
|
||||
icount = 0
|
||||
ntasks = 0_8
|
||||
! r_{ij}^k . R_{ja}^l -> tmp_c_{ia}^{kl}
|
||||
do k=0,ncord-1
|
||||
do l=0,ncord
|
||||
icount += 1
|
||||
call qmckl_dgemm('N','N', nelec, nnuc, nelec, 1.d0, &
|
||||
rescale_een_e(1,1,k), size(rescale_een_e,1), &
|
||||
rescale_een_n(1,1,l), size(rescale_een_n,1), 0.d0, &
|
||||
tmp_c(1,1,l,k), size(tmp_c,1), gemms(icount))
|
||||
tmp_c(1,1,l,k), size(tmp_c,1), tasks, ntasks)
|
||||
enddo
|
||||
enddo
|
||||
|
||||
! dr_{ij}^k . R_{ja}^l -> dtmp_c_{ia}^{kl}
|
||||
do k=0,ncord-1
|
||||
do l=0,ncord
|
||||
icount += 1
|
||||
call qmckl_dgemm('N','N', nelec_8*4, nnuc, nelec, 1.d0, &
|
||||
rescale_een_e_deriv_e(1,1,1,k), &
|
||||
size(rescale_een_e_deriv_e,1)*size(rescale_een_e_deriv_e,2), &
|
||||
rescale_een_n(1,1,l), &
|
||||
size(rescale_een_n,1), 0.d0, &
|
||||
dtmp_c(1,1,1,l,k), size(dtmp_c,1)*size(dtmp_c,2), &
|
||||
gemms(icount))
|
||||
tasks, ntasks)
|
||||
enddo
|
||||
enddo
|
||||
|
||||
call qmckl_tasks_run(gemms, icount)
|
||||
print *, ntasks, ' tasks'
|
||||
call qmckl_tasks_run(tasks, ntasks)
|
||||
|
||||
END_PROVIDER
|
||||
|
||||
|
@ -5,23 +5,24 @@ module qmckl_blas
|
||||
|
||||
interface
  !> Interface to the C routine `qmckl_dgemm`, which records one
  !> double-precision GEMM request in a task list.
  !>
  !> transa, transb : transpose flags for A and B, passed by value.
  !> m, n, k        : GEMM dimensions, passed by value.
  !> alpha, beta    : scalar factors, passed by value.
  !> A, B, C        : matrices with leading dimensions lda, ldb, ldc.
  !> tasks          : array receiving one 64-bit task handle per call.
  !> ntasks         : number of handles stored so far; updated by the call.
  !>
  !> NOTE(review): the C side appears to keep the addresses of A, B and C
  !> for later use by qmckl_tasks_run, so the actual arguments should be
  !> contiguous storage that outlives that call (no compiler temporaries)
  !> -- confirm against the C implementation.
  !> NOTE(review): the interface carries no capacity for `tasks`; the caller
  !> must size the array for the number of calls it makes.
  subroutine qmckl_dgemm(transa, transb, m, n, k, &
       alpha, A, lda, B, ldb, beta, C, ldc, tasks, ntasks) bind(C)
    use, intrinsic :: iso_c_binding, only: c_char, c_int, c_double, c_int64_t
    implicit none
    character(kind=c_char   ), value         :: transa, transb
    integer  (kind=c_int    ), value         :: m, n, k, lda, ldb, ldc
    real     (kind=c_double ), value         :: alpha, beta
    real     (kind=c_double ), intent(in)    :: A(lda,*), B(ldb,*)
    real     (kind=c_double ), intent(inout) :: C(ldc,*)
    integer  (kind=c_int64_t), intent(inout) :: tasks(*)
    integer  (kind=c_int64_t), intent(inout) :: ntasks
  end subroutine qmckl_dgemm
end interface
|
||||
|
||||
interface
  !> Interface to the C routine `qmckl_tasks_run`, which is handed the
  !> task handles previously collected by `qmckl_dgemm`.
  !>
  !> tasks  : array of 64-bit task handles.
  !> ntasks : number of valid entries in `tasks`, passed by value.
  !>
  !> NOTE(review): the C definition shown in this commit still reads
  !> `void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)`, i.e. a
  !> C `int` count, while this interface passes a 64-bit integer by value.
  !> The two declarations should be reconciled (e.g. `int64_t` on the C side).
  subroutine qmckl_tasks_run(tasks, ntasks) bind(C)
    use, intrinsic :: iso_c_binding, only: c_int64_t
    implicit none
    integer(kind=c_int64_t), value         :: ntasks
    integer(kind=c_int64_t), intent(inout) :: tasks(ntasks)
  end subroutine qmckl_tasks_run
end interface
|
||||
|
||||
|
@ -59,18 +59,16 @@ void qmckl_dgemm_cl(struct dgemm_args args, double* A, double* B, double* C) {
|
||||
args.beta, C, args.ldc);
|
||||
}
|
||||
|
||||
void qmckl_dgemm(char transa, char transb,
|
||||
static struct dgemm_args* qmckl_dgemm_to_struct(char transa, char transb,
|
||||
int m, int n, int k,
|
||||
double alpha,
|
||||
double* A, int lda,
|
||||
double* B, int ldb,
|
||||
double beta,
|
||||
double* C, int ldc,
|
||||
int64_t* result)
|
||||
double* C, int ldc)
|
||||
{
|
||||
struct dgemm_args* args = (struct dgemm_args*) malloc (sizeof(struct dgemm_args));
|
||||
assert (args != NULL);
|
||||
*result = (int64_t) args;
|
||||
|
||||
args->alpha = alpha;
|
||||
args->beta = beta ;
|
||||
@ -95,14 +93,29 @@ void qmckl_dgemm(char transa, char transb,
|
||||
} else {
|
||||
args->transb = CblasNoTrans;
|
||||
}
|
||||
return args;
|
||||
|
||||
}
|
||||
|
||||
/*
 * Record one DGEMM request in a caller-owned task list.
 *
 * The GEMM arguments are packed by qmckl_dgemm_to_struct() and the address
 * of the resulting descriptor is stored, as a 64-bit integer, in
 * tasks[*ntasks]; *ntasks is then incremented by one.
 *
 * tasks  : array of task handles owned by the caller.
 * ntasks : in/out count of handles already stored in `tasks`.
 *
 * The capacity of `tasks` is not part of this interface, so it cannot be
 * checked here: the caller must provide an array large enough for every
 * request it queues.
 */
void qmckl_dgemm(char transa, char transb,
                 int m, int n, int k,
                 double alpha,
                 double* A, int lda,
                 double* B, int ldb,
                 double beta,
                 double* C, int ldc,
                 int64_t* tasks, int64_t* ntasks)
{
  /* Guard the only things that can be validated locally before indexing. */
  assert (tasks  != NULL);
  assert (ntasks != NULL);
  assert (*ntasks >= 0);

  struct dgemm_args* descriptor =
    qmckl_dgemm_to_struct (transa, transb,
                           m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);

  tasks[*ntasks] = (int64_t) descriptor;
  *ntasks += 1;
}
|
||||
|
||||
void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
||||
{
|
||||
starpu_init(NULL);
|
||||
|
||||
int rc = starpu_init(NULL);
|
||||
assert (rc == 0);
|
||||
|
||||
starpu_data_handle_t matrix_handle[ngemms][3];
|
||||
for (int i=0 ; i<ngemms ; ++i)
|
||||
{
|
||||
@ -138,7 +151,8 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
||||
task->handles[0] = matrix_handle[i][0];
|
||||
task->handles[1] = matrix_handle[i][1];
|
||||
task->handles[2] = matrix_handle[i][2];
|
||||
starpu_task_submit(task);
|
||||
rc = starpu_task_submit(task);
|
||||
assert (rc == 0);
|
||||
}
|
||||
starpu_task_wait_for_all();
|
||||
|
||||
@ -150,3 +164,65 @@ void qmckl_tasks_run(struct dgemm_args** gemms, int ngemms)
|
||||
}
|
||||
starpu_shutdown();
|
||||
}
|
||||
|
||||
|
||||
#define MIN_SIZE 512
|
||||
static void qmckl_dgemm_rec(struct dgemm_args args) {
|
||||
|
||||
// printf("%5d %5d\n", args.m, args.n);
|
||||
|
||||
if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) {
|
||||
#pragma omp task
|
||||
{
|
||||
printf("BLAS %5d %5d %5d\n", args.m, args.n, args.k);
|
||||
cblas_dgemm(CblasColMajor, args.transa, args.transb,
|
||||
args.m, args.n, args.k, args.alpha,
|
||||
args.A, args.lda, args.B, args.ldb,
|
||||
args.beta, args.C, args.ldc);
|
||||
}
|
||||
} else {
|
||||
|
||||
int m1 = args.m / 2;
|
||||
int m2 = args.m - m1;
|
||||
int n1 = args.n / 2;
|
||||
int n2 = args.n - n1;
|
||||
|
||||
{
|
||||
struct dgemm_args args_1 = args;
|
||||
args_1.m = m1;
|
||||
args_1.n = n1;
|
||||
qmckl_dgemm_rec(args_1);
|
||||
}
|
||||
|
||||
{
|
||||
// TODO: assuming 'N', 'N' here
|
||||
struct dgemm_args args_2 = args;
|
||||
args_2.B = args.B + args.ldb*n1;
|
||||
args_2.C = args.C + args.ldc*n1;
|
||||
args_2.m = m1;
|
||||
args_2.n = n2;
|
||||
qmckl_dgemm_rec(args_2);
|
||||
}
|
||||
|
||||
{
|
||||
struct dgemm_args args_3 = args;
|
||||
args_3.A = args.A + m1;
|
||||
args_3.C = args.C + m1;
|
||||
args_3.m = m2;
|
||||
args_3.n = n1;
|
||||
qmckl_dgemm_rec(args_3);
|
||||
}
|
||||
|
||||
{
|
||||
struct dgemm_args args_4 = args;
|
||||
args_4.A = args.A + m1;
|
||||
args_4.B = args.B + args.ldb*n1;
|
||||
args_4.C = args.C + m1 + args.ldc*n1;
|
||||
args_4.m = m2;
|
||||
args_4.n = n2;
|
||||
qmckl_dgemm_rec(args_4);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user