1
0
mirror of https://github.com/TREX-CoE/irpjast.git synced 2025-01-05 19:09:06 +01:00

Recursive dgemm OK

This commit is contained in:
Anthony Scemama 2021-04-23 14:18:59 +02:00
parent a3dbb458fe
commit b3f287b8fb
4 changed files with 116 additions and 31 deletions

View File

@ -13,9 +13,11 @@ LIB= -mkl=sequential
-include irpf90.make -include irpf90.make
export export
irpf90.make: qmckl_blas_f.o #irpf90.make: IRPF90_temp/qmckl_blas_f.o
irpf90.make: $(filter-out IRPF90_temp/%, $(wildcard */*.irp.f)) $(wildcard *.irp.f) $(wildcard *.inc.f) Makefile irpf90.make: $(filter-out IRPF90_temp/%, $(wildcard */*.irp.f)) $(wildcard *.irp.f) $(wildcard *.inc.f) Makefile
$(IRPF90) $(IRPF90)
IRPF90_temp/%.c: %.c
IRPF90_temp/%.o: %.c IRPF90_temp/%.o: %.c
$(CC) -c $< -o $@ $(CC) -g -c $< -o $@

View File

@ -1,5 +1,6 @@
BEGIN_PROVIDER [ double precision, tmp_c, (nelec_8,nnuc,0:ncord,0:ncord-1) ] BEGIN_PROVIDER [ double precision, tmp_c, (nelec_8,nnuc,0:ncord,0:ncord-1) ]
&BEGIN_PROVIDER [ double precision, dtmp_c, (nelec_8,4,nnuc,0:ncord,0:ncord-1) ] &BEGIN_PROVIDER [ double precision, dtmp_c, (nelec_8,4,nnuc,0:ncord,0:ncord-1) ]
use qmckl_blas
implicit none implicit none
BEGIN_DOC BEGIN_DOC
! Calculate the intermediate buffers ! Calculate the intermediate buffers

View File

@ -19,12 +19,53 @@ struct dgemm_args {
}; };
#define MIN_SIZE 512
#include<stdio.h>
static void qmckl_dgemm_rec(struct dgemm_args args) { static void qmckl_dgemm_rec(struct dgemm_args args) {
// printf("%5d %5d\n", args.m, args.n);
if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) {
cblas_dgemm(CblasColMajor, args.transa, args.transb, cblas_dgemm(CblasColMajor, args.transa, args.transb,
args.m, args.n, args.k, args.alpha, args.m, args.n, args.k, args.alpha,
args.A, args.lda, args.B, args.ldb, args.A, args.lda, args.B, args.ldb,
args.beta, args.C, args.ldc); args.beta, args.C, args.ldc);
} else {
int m1 = args.m / 2;
int m2 = args.m - m1;
int n1 = args.n / 2;
int n2 = args.n - n1;
struct dgemm_args args_1 = args;
args_1.m = m1;
args_1.n = n1;
qmckl_dgemm_rec(args_1);
// TODO: assuming 'N', 'N' here
struct dgemm_args args_2 = args;
args_2.B = args.B + args.ldb*n1;
args_2.C = args.C + args.ldc*n1;
args_2.m = m1;
args_2.n = n2;
qmckl_dgemm_rec(args_2);
struct dgemm_args args_3 = args;
args_3.A = args.A + m1;
args_3.C = args.C + m1;
args_3.m = m2;
args_3.n = n1;
qmckl_dgemm_rec(args_3);
struct dgemm_args args_4 = args;
args_4.A = args.A + m1;
args_4.B = args.B + args.ldb*n1;
args_4.C = args.C + m1 + args.ldc*n1;
args_4.m = m2;
args_4.n = n2;
qmckl_dgemm_rec(args_4);
}
} }

View File

@ -28,7 +28,7 @@ module qmckl_blas
end module qmckl_blas end module qmckl_blas
#+END_SRC #+END_SRC
* C code * TODO C code
To avoid passing too many arguments to recursive subroutines, we put To avoid passing too many arguments to recursive subroutines, we put
all the arguments in a struct. all the arguments in a struct.
@ -100,12 +100,53 @@ void qmckl_dgemm(char transa, char transb,
#+NAME: dgemm_rec #+NAME: dgemm_rec
#+BEGIN_SRC c #+BEGIN_SRC c
#define MIN_SIZE 512
#include<stdio.h>
static void qmckl_dgemm_rec(struct dgemm_args args) { static void qmckl_dgemm_rec(struct dgemm_args args) {
// printf("%5d %5d\n", args.m, args.n);
if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) {
cblas_dgemm(CblasColMajor, args.transa, args.transb, cblas_dgemm(CblasColMajor, args.transa, args.transb,
args.m, args.n, args.k, args.alpha, args.m, args.n, args.k, args.alpha,
args.A, args.lda, args.B, args.ldb, args.A, args.lda, args.B, args.ldb,
args.beta, args.C, args.ldc); args.beta, args.C, args.ldc);
} else {
int m1 = args.m / 2;
int m2 = args.m - m1;
int n1 = args.n / 2;
int n2 = args.n - n1;
struct dgemm_args args_1 = args;
args_1.m = m1;
args_1.n = n1;
qmckl_dgemm_rec(args_1);
// TODO: assuming 'N', 'N' here
struct dgemm_args args_2 = args;
args_2.B = args.B + args.ldb*n1;
args_2.C = args.C + args.ldc*n1;
args_2.m = m1;
args_2.n = n2;
qmckl_dgemm_rec(args_2);
struct dgemm_args args_3 = args;
args_3.A = args.A + m1;
args_3.C = args.C + m1;
args_3.m = m2;
args_3.n = n1;
qmckl_dgemm_rec(args_3);
struct dgemm_args args_4 = args;
args_4.A = args.A + m1;
args_4.B = args.B + args.ldb*n1;
args_4.C = args.C + m1 + args.ldc*n1;
args_4.m = m2;
args_4.n = n2;
qmckl_dgemm_rec(args_4);
}
} }
#+END_SRC #+END_SRC