diff --git a/Makefile b/Makefile index ca6b8a2..4231fe4 100644 --- a/Makefile +++ b/Makefile @@ -13,9 +13,11 @@ LIB= -mkl=sequential -include irpf90.make export -irpf90.make: qmckl_blas_f.o +#irpf90.make: IRPF90_temp/qmckl_blas_f.o irpf90.make: $(filter-out IRPF90_temp/%, $(wildcard */*.irp.f)) $(wildcard *.irp.f) $(wildcard *.inc.f) Makefile $(IRPF90) +IRPF90_temp/%.c: %.c + IRPF90_temp/%.o: %.c - $(CC) -c $< -o $@ + $(CC) -g -c $< -o $@ diff --git a/el_nuc_el_blas.irp.f b/el_nuc_el_blas.irp.f index 482e8cb..c369ba2 100644 --- a/el_nuc_el_blas.irp.f +++ b/el_nuc_el_blas.irp.f @@ -1,5 +1,6 @@ BEGIN_PROVIDER [ double precision, tmp_c, (nelec_8,nnuc,0:ncord,0:ncord-1) ] &BEGIN_PROVIDER [ double precision, dtmp_c, (nelec_8,4,nnuc,0:ncord,0:ncord-1) ] + use qmckl_blas implicit none BEGIN_DOC ! Calculate the intermediate buffers diff --git a/qmckl_dgemm.c b/qmckl_dgemm.c index 2fdf25c..2935807 100644 --- a/qmckl_dgemm.c +++ b/qmckl_dgemm.c @@ -3,7 +3,7 @@ #include struct dgemm_args { - double alpha; + double alpha; double beta; double* A; double* B; @@ -16,27 +16,68 @@ struct dgemm_args { int ldc; CBLAS_LAYOUT transa; CBLAS_LAYOUT transb; -}; - +}; + + +#define MIN_SIZE 512 +#include static void qmckl_dgemm_rec(struct dgemm_args args) { - cblas_dgemm(CblasColMajor, args.transa, args.transb, - args.m, args.n, args.k, args.alpha, - args.A, args.lda, args.B, args.ldb, - args.beta, args.C, args.ldc); +// printf("%5d %5d\n", args.m, args.n); + + if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) { + cblas_dgemm(CblasColMajor, args.transa, args.transb, + args.m, args.n, args.k, args.alpha, + args.A, args.lda, args.B, args.ldb, + args.beta, args.C, args.ldc); + } else { + + int m1 = args.m / 2; + int m2 = args.m - m1; + int n1 = args.n / 2; + int n2 = args.n - n1; + + struct dgemm_args args_1 = args; + args_1.m = m1; + args_1.n = n1; + qmckl_dgemm_rec(args_1); + + // TODO: assuming 'N', 'N' here + struct dgemm_args args_2 = args; + args_2.B = args.B + args.ldb*n1; + args_2.C = args.C + args.ldc*n1; + args_2.m = m1; + args_2.n = n2; + qmckl_dgemm_rec(args_2); + + struct dgemm_args args_3 = args; + args_3.A = args.A + m1; + args_3.C = args.C + m1; + args_3.m = m2; + args_3.n = n1; + qmckl_dgemm_rec(args_3); + + struct dgemm_args args_4 = args; + args_4.A = args.A + m1; + args_4.B = args.B + args.ldb*n1; + args_4.C = args.C + m1 + args.ldc*n1; + args_4.m = m2; + args_4.n = n2; + qmckl_dgemm_rec(args_4); + } } void qmckl_dgemm(char transa, char transb, int m, int n, int k, - double alpha, + double alpha, double* A, int lda, double* B, int ldb, double beta, double* C, int ldc) { - struct dgemm_args args; + struct dgemm_args args; args.alpha = alpha; args.beta = beta ; @@ -51,16 +92,16 @@ void qmckl_dgemm(char transa, char transb, args.ldc = ldc; if (transa == 'T' || transa == 't') { - args.transa = CblasTrans; + args.transa = CblasTrans; } else { - args.transa = CblasNoTrans; + args.transa = CblasNoTrans; } CBLAS_LAYOUT tb; if (transa == 'T' || transa == 't') { - args.transb = CblasTrans; + args.transb = CblasTrans; } else { - args.transb = CblasNoTrans; + args.transb = CblasNoTrans; } qmckl_dgemm_rec(args); diff --git a/qmckl_dgemm.org b/qmckl_dgemm.org index 5f53d81..ee90d5c 100644 --- a/qmckl_dgemm.org +++ b/qmckl_dgemm.org @@ -28,14 +28,14 @@ module qmckl_blas end module qmckl_blas #+END_SRC -* C code +* TODO C code To avoid passing too many arguments to recursive subroutines, we put all the arguments in a struct. #+NAME: dgemm_args #+BEGIN_SRC c struct dgemm_args { - double alpha; + double alpha; double beta; double* A; double* B; @@ -48,24 +48,24 @@ struct dgemm_args { int ldc; CBLAS_LAYOUT transa; CBLAS_LAYOUT transb; -}; - +}; + #+END_SRC The driver routine packs the arguments in the struct and calls the recursive routine. #+NAME: dgemm - #+BEGIN_SRC c + #+BEGIN_SRC c void qmckl_dgemm(char transa, char transb, int m, int n, int k, - double alpha, + double alpha, double* A, int lda, double* B, int ldb, double beta, double* C, int ldc) { - struct dgemm_args args; + struct dgemm_args args; args.alpha = alpha; args.beta = beta ; @@ -80,16 +80,16 @@ void qmckl_dgemm(char transa, char transb, args.ldc = ldc; if (transa == 'T' || transa == 't') { - args.transa = CblasTrans; + args.transa = CblasTrans; } else { - args.transa = CblasNoTrans; + args.transa = CblasNoTrans; } CBLAS_LAYOUT tb; if (transa == 'T' || transa == 't') { - args.transb = CblasTrans; + args.transb = CblasTrans; } else { - args.transb = CblasNoTrans; + args.transb = CblasNoTrans; } qmckl_dgemm_rec(args); @@ -99,13 +99,54 @@ void qmckl_dgemm(char transa, char transb, #+NAME: dgemm_rec - #+BEGIN_SRC c + #+BEGIN_SRC c +#define MIN_SIZE 512 +#include + static void qmckl_dgemm_rec(struct dgemm_args args) { - cblas_dgemm(CblasColMajor, args.transa, args.transb, - args.m, args.n, args.k, args.alpha, - args.A, args.lda, args.B, args.ldb, - args.beta, args.C, args.ldc); +// printf("%5d %5d\n", args.m, args.n); + + if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) { + cblas_dgemm(CblasColMajor, args.transa, args.transb, + args.m, args.n, args.k, args.alpha, + args.A, args.lda, args.B, args.ldb, + args.beta, args.C, args.ldc); + } else { + + int m1 = args.m / 2; + int m2 = args.m - m1; + int n1 = args.n / 2; + int n2 = args.n - n1; + + struct dgemm_args args_1 = args; + args_1.m = m1; + args_1.n = n1; + qmckl_dgemm_rec(args_1); + + // TODO: assuming 'N', 'N' here + struct dgemm_args args_2 = args; + args_2.B = args.B + args.ldb*n1; + args_2.C = args.C + args.ldc*n1; + args_2.m = m1; + args_2.n = n2; + qmckl_dgemm_rec(args_2); + + struct dgemm_args args_3 = args; + args_3.A = args.A + m1; + args_3.C = args.C + m1; + args_3.m = m2; + args_3.n = n1; + qmckl_dgemm_rec(args_3); + + struct dgemm_args args_4 = args; + args_4.A = args.A + m1; + args_4.B = args.B + args.ldb*n1; + args_4.C = args.C + m1 + args.ldc*n1; + args_4.m = m2; + args_4.n = n2; + qmckl_dgemm_rec(args_4); + } } #+END_SRC