mirror of
https://github.com/TREX-CoE/irpjast.git
synced 2025-01-07 03:43:28 +01:00
Recursive dgemm OK
This commit is contained in:
parent
a3dbb458fe
commit
b3f287b8fb
6
Makefile
6
Makefile
@ -13,9 +13,11 @@ LIB= -mkl=sequential
|
|||||||
-include irpf90.make
|
-include irpf90.make
|
||||||
export
|
export
|
||||||
|
|
||||||
irpf90.make: qmckl_blas_f.o
|
#irpf90.make: IRPF90_temp/qmckl_blas_f.o
|
||||||
irpf90.make: $(filter-out IRPF90_temp/%, $(wildcard */*.irp.f)) $(wildcard *.irp.f) $(wildcard *.inc.f) Makefile
|
irpf90.make: $(filter-out IRPF90_temp/%, $(wildcard */*.irp.f)) $(wildcard *.irp.f) $(wildcard *.inc.f) Makefile
|
||||||
$(IRPF90)
|
$(IRPF90)
|
||||||
|
|
||||||
|
IRPF90_temp/%.c: %.c
|
||||||
|
|
||||||
IRPF90_temp/%.o: %.c
|
IRPF90_temp/%.o: %.c
|
||||||
$(CC) -c $< -o $@
|
$(CC) -g -c $< -o $@
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
BEGIN_PROVIDER [ double precision, tmp_c, (nelec_8,nnuc,0:ncord,0:ncord-1) ]
|
BEGIN_PROVIDER [ double precision, tmp_c, (nelec_8,nnuc,0:ncord,0:ncord-1) ]
|
||||||
&BEGIN_PROVIDER [ double precision, dtmp_c, (nelec_8,4,nnuc,0:ncord,0:ncord-1) ]
|
&BEGIN_PROVIDER [ double precision, dtmp_c, (nelec_8,4,nnuc,0:ncord,0:ncord-1) ]
|
||||||
|
use qmckl_blas
|
||||||
implicit none
|
implicit none
|
||||||
BEGIN_DOC
|
BEGIN_DOC
|
||||||
! Calculate the intermediate buffers
|
! Calculate the intermediate buffers
|
||||||
|
@ -19,12 +19,53 @@ struct dgemm_args {
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#define MIN_SIZE 512
|
||||||
|
#include<stdio.h>
|
||||||
|
|
||||||
static void qmckl_dgemm_rec(struct dgemm_args args) {
|
static void qmckl_dgemm_rec(struct dgemm_args args) {
|
||||||
|
|
||||||
|
// printf("%5d %5d\n", args.m, args.n);
|
||||||
|
|
||||||
|
if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) {
|
||||||
cblas_dgemm(CblasColMajor, args.transa, args.transb,
|
cblas_dgemm(CblasColMajor, args.transa, args.transb,
|
||||||
args.m, args.n, args.k, args.alpha,
|
args.m, args.n, args.k, args.alpha,
|
||||||
args.A, args.lda, args.B, args.ldb,
|
args.A, args.lda, args.B, args.ldb,
|
||||||
args.beta, args.C, args.ldc);
|
args.beta, args.C, args.ldc);
|
||||||
|
} else {
|
||||||
|
|
||||||
|
int m1 = args.m / 2;
|
||||||
|
int m2 = args.m - m1;
|
||||||
|
int n1 = args.n / 2;
|
||||||
|
int n2 = args.n - n1;
|
||||||
|
|
||||||
|
struct dgemm_args args_1 = args;
|
||||||
|
args_1.m = m1;
|
||||||
|
args_1.n = n1;
|
||||||
|
qmckl_dgemm_rec(args_1);
|
||||||
|
|
||||||
|
// TODO: assuming 'N', 'N' here
|
||||||
|
struct dgemm_args args_2 = args;
|
||||||
|
args_2.B = args.B + args.ldb*n1;
|
||||||
|
args_2.C = args.C + args.ldc*n1;
|
||||||
|
args_2.m = m1;
|
||||||
|
args_2.n = n2;
|
||||||
|
qmckl_dgemm_rec(args_2);
|
||||||
|
|
||||||
|
struct dgemm_args args_3 = args;
|
||||||
|
args_3.A = args.A + m1;
|
||||||
|
args_3.C = args.C + m1;
|
||||||
|
args_3.m = m2;
|
||||||
|
args_3.n = n1;
|
||||||
|
qmckl_dgemm_rec(args_3);
|
||||||
|
|
||||||
|
struct dgemm_args args_4 = args;
|
||||||
|
args_4.A = args.A + m1;
|
||||||
|
args_4.B = args.B + args.ldb*n1;
|
||||||
|
args_4.C = args.C + m1 + args.ldc*n1;
|
||||||
|
args_4.m = m2;
|
||||||
|
args_4.n = n2;
|
||||||
|
qmckl_dgemm_rec(args_4);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ module qmckl_blas
|
|||||||
end module qmckl_blas
|
end module qmckl_blas
|
||||||
#+END_SRC
|
#+END_SRC
|
||||||
|
|
||||||
* C code
|
* TODO C code
|
||||||
To avoid passing too many arguments to recursive subroutines, we put
|
To avoid passing too many arguments to recursive subroutines, we put
|
||||||
all the arguments in a struct.
|
all the arguments in a struct.
|
||||||
|
|
||||||
@ -100,12 +100,53 @@ void qmckl_dgemm(char transa, char transb,
|
|||||||
|
|
||||||
#+NAME: dgemm_rec
|
#+NAME: dgemm_rec
|
||||||
#+BEGIN_SRC c
|
#+BEGIN_SRC c
|
||||||
|
#define MIN_SIZE 512
|
||||||
|
#include<stdio.h>
|
||||||
|
|
||||||
static void qmckl_dgemm_rec(struct dgemm_args args) {
|
static void qmckl_dgemm_rec(struct dgemm_args args) {
|
||||||
|
|
||||||
|
// printf("%5d %5d\n", args.m, args.n);
|
||||||
|
|
||||||
|
if ( (args.m <= MIN_SIZE) || (args.n <= MIN_SIZE)) {
|
||||||
cblas_dgemm(CblasColMajor, args.transa, args.transb,
|
cblas_dgemm(CblasColMajor, args.transa, args.transb,
|
||||||
args.m, args.n, args.k, args.alpha,
|
args.m, args.n, args.k, args.alpha,
|
||||||
args.A, args.lda, args.B, args.ldb,
|
args.A, args.lda, args.B, args.ldb,
|
||||||
args.beta, args.C, args.ldc);
|
args.beta, args.C, args.ldc);
|
||||||
|
} else {
|
||||||
|
|
||||||
|
int m1 = args.m / 2;
|
||||||
|
int m2 = args.m - m1;
|
||||||
|
int n1 = args.n / 2;
|
||||||
|
int n2 = args.n - n1;
|
||||||
|
|
||||||
|
struct dgemm_args args_1 = args;
|
||||||
|
args_1.m = m1;
|
||||||
|
args_1.n = n1;
|
||||||
|
qmckl_dgemm_rec(args_1);
|
||||||
|
|
||||||
|
// TODO: assuming 'N', 'N' here
|
||||||
|
struct dgemm_args args_2 = args;
|
||||||
|
args_2.B = args.B + args.ldb*n1;
|
||||||
|
args_2.C = args.C + args.ldc*n1;
|
||||||
|
args_2.m = m1;
|
||||||
|
args_2.n = n2;
|
||||||
|
qmckl_dgemm_rec(args_2);
|
||||||
|
|
||||||
|
struct dgemm_args args_3 = args;
|
||||||
|
args_3.A = args.A + m1;
|
||||||
|
args_3.C = args.C + m1;
|
||||||
|
args_3.m = m2;
|
||||||
|
args_3.n = n1;
|
||||||
|
qmckl_dgemm_rec(args_3);
|
||||||
|
|
||||||
|
struct dgemm_args args_4 = args;
|
||||||
|
args_4.A = args.A + m1;
|
||||||
|
args_4.B = args.B + args.ldb*n1;
|
||||||
|
args_4.C = args.C + m1 + args.ldc*n1;
|
||||||
|
args_4.m = m2;
|
||||||
|
args_4.n = n2;
|
||||||
|
qmckl_dgemm_rec(args_4);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
#+END_SRC
|
#+END_SRC
|
||||||
|
Loading…
Reference in New Issue
Block a user