1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2025-01-03 10:06:09 +01:00

Add qmckl_dgemm_safe proxy function

This commit is contained in:
q-posev 2022-08-22 12:05:28 +02:00
parent 0e8161ca1f
commit e05b589e79

View File

@ -1084,6 +1084,225 @@ qmckl_exit_code test_qmckl_dgemm(qmckl_context context);
assert(QMCKL_SUCCESS == test_qmckl_dgemm(context)); assert(QMCKL_SUCCESS == test_qmckl_dgemm(context));
#+end_src #+end_src
** ~qmckl_dgemm_safe~
"Size-safe" is a proxy function with the same functionality as ~qmckl_dgemm~
but with 3 additional arguments. These arguments ~size_max_M~ (where ~M~ is a matix)
are required primarily for the Python API, where compatibility with
NumPy arrays implies that sizes of the input and output arrays are provided.
#+NAME: qmckl_dgemm_safe_args
| Variable | Type | In/Out | Description |
|--------------+-----------------+--------+---------------------------------------|
| ~context~ | ~qmckl_context~ | in | Global state |
| ~TransA~ | ~char~ | in | 'T' is transposed |
| ~TransB~ | ~char~ | in | 'T' is transposed |
| ~m~ | ~int64_t~ | in | Number of rows of the input matrix |
| ~n~ | ~int64_t~ | in | Number of columns of the input matrix |
| ~k~ | ~int64_t~ | in | Number of columns of the input matrix |
| ~alpha~ | ~double~ | in | \alpha |
| ~A~ | ~double[][lda]~ | in | Array containing the matrix $A$ |
| ~size_max_A~ | ~int64_t~ | in | Size of the matrix $A$ |
| ~lda~ | ~int64_t~ | in | Leading dimension of array ~A~ |
| ~B~ | ~double[][ldb]~ | in | Array containing the matrix $B$ |
| ~size_max_B~ | ~int64_t~ | in | Size of the matrix $B$ |
| ~ldb~ | ~int64_t~ | in | Leading dimension of array ~B~ |
| ~beta~ | ~double~ | in | \beta |
| ~C~ | ~double[][ldc]~ | out | Array containing the matrix $C$ |
| ~size_max_C~ | ~int64_t~ | in | Size of the matrix $C$ |
| ~ldc~ | ~int64_t~ | in | Leading dimension of array ~C~ |
Requirements:
- ~context~ is not ~QMCKL_NULL_CONTEXT~
- ~m > 0~
- ~n > 0~
- ~k > 0~
- ~lda >= m~
- ~ldb >= n~
- ~ldc >= n~
- ~A~ is allocated with at least $m \times k \times 8$ bytes
- ~B~ is allocated with at least $k \times n \times 8$ bytes
- ~C~ is allocated with at least $m \times n \times 8$ bytes
- ~size_max_A >= m * k~
- ~size_max_B >= k * n~
- ~size_max_C >= m * n~
#+CALL: generate_c_header(table=qmckl_dgemm_safe_args,rettyp="qmckl_exit_code",fname="qmckl_dgemm_safe")
#+RESULTS:
#+BEGIN_src c :tangle (eval h_func) :comments org
qmckl_exit_code qmckl_dgemm_safe (
const qmckl_context context,
const char TransA,
const char TransB,
const int64_t m,
const int64_t n,
const int64_t k,
const double alpha,
const double* A,
const int64_t size_max_A,
const int64_t lda,
const double* B,
const int64_t size_max_B,
const int64_t ldb,
const double beta,
double* const C,
const int64_t size_max_C,
const int64_t ldc );
#+END_src
#+begin_src f90 :tangle (eval f) :exports none
integer function qmckl_dgemm_safe_f(context, TransA, TransB, &
m, n, k, alpha, A, size_A, LDA, B, size_B, LDB, beta, C, size_C, LDC) &
result(info)
use qmckl
implicit none
integer(qmckl_context), intent(in) :: context
character , intent(in) :: TransA, TransB
integer*8 , intent(in) :: m, n, k
double precision , intent(in) :: alpha, beta
integer*8 , intent(in) :: lda
integer*8 , intent(in) :: size_A
double precision , intent(in) :: A(lda,*)
integer*8 , intent(in) :: ldb
integer*8 , intent(in) :: size_B
double precision , intent(in) :: B(ldb,*)
integer*8 , intent(in) :: ldc
integer*8 , intent(in) :: size_C
double precision , intent(out) :: C(ldc,*)
info = QMCKL_SUCCESS
if (context == QMCKL_NULL_CONTEXT) then
info = QMCKL_INVALID_CONTEXT
return
endif
if (m <= 0_8) then
info = QMCKL_INVALID_ARG_4
return
endif
if (n <= 0_8) then
info = QMCKL_INVALID_ARG_5
return
endif
if (k <= 0_8) then
info = QMCKL_INVALID_ARG_6
return
endif
if (LDA <= 0) then
info = QMCKL_INVALID_ARG_10
return
endif
if (LDB <= 0) then
info = QMCKL_INVALID_ARG_13
return
endif
if (LDC <= 0) then
info = QMCKL_INVALID_ARG_17
return
endif
if (size_A <= 0) then
info = QMCKL_INVALID_ARG_9
return
endif
if (size_B <= 0) then
info = QMCKL_INVALID_ARG_12
return
endif
if (size_C <= 0) then
info = QMCKL_INVALID_ARG_16
return
endif
call dgemm(transA, transB, int(m,4), int(n,4), int(k,4), &
alpha, A, int(LDA,4), B, int(LDB,4), beta, C, int(LDC,4))
end function qmckl_dgemm_safe_f
#+end_src
*** C interface :noexport:
#+CALL: generate_c_interface(table=qmckl_dgemm_safe_args,rettyp="qmckl_exit_code",fname="qmckl_dgemm_safe")
#+RESULTS:
#+begin_src f90 :tangle (eval f) :comments org :exports none
integer(c_int32_t) function qmckl_dgemm_safe &
(context, TransA, TransB, m, n, k, alpha, A, size_A, lda, B, size_B, ldb, beta, C, size_C, ldc) &
bind(C) result(info)
use, intrinsic :: iso_c_binding
implicit none
integer (c_int64_t) , intent(in) , value :: context
character , intent(in) , value :: TransA
character , intent(in) , value :: TransB
integer (c_int64_t) , intent(in) , value :: m
integer (c_int64_t) , intent(in) , value :: n
integer (c_int64_t) , intent(in) , value :: k
real (c_double ) , intent(in) , value :: alpha
real (c_double ) , intent(in) :: A(lda,*)
integer (c_int64_t) , intent(in) , value :: size_A
integer (c_int64_t) , intent(in) , value :: lda
real (c_double ) , intent(in) :: B(ldb,*)
integer (c_int64_t) , intent(in) , value :: size_B
integer (c_int64_t) , intent(in) , value :: ldb
real (c_double ) , intent(in) , value :: beta
real (c_double ) , intent(out) :: C(ldc,*)
integer (c_int64_t) , intent(in) , value :: size_C
integer (c_int64_t) , intent(in) , value :: ldc
integer(c_int32_t), external :: qmckl_dgemm_safe_f
info = qmckl_dgemm_safe_f &
(context, TransA, TransB, m, n, k, alpha, A, size_A, lda, B, size_B, ldb, beta, C, size_C, ldc)
end function qmckl_dgemm_safe
#+end_src
#+CALL: generate_f_interface(table=qmckl_dgemm_safe_args,rettyp="qmckl_exit_code",fname="qmckl_dgemm_safe")
#+RESULTS:
#+begin_src f90 :tangle (eval fh_func) :comments org :exports none
interface
integer(c_int32_t) function qmckl_dgemm_safe &
(context, TransA, TransB, m, n, k, alpha, A, size_A, lda, B, size_B, ldb, beta, C, size_C, ldc) &
bind(C)
use, intrinsic :: iso_c_binding
import
implicit none
integer (c_int64_t) , intent(in) , value :: context
character , intent(in) , value :: TransA
character , intent(in) , value :: TransB
integer (c_int64_t) , intent(in) , value :: m
integer (c_int64_t) , intent(in) , value :: n
integer (c_int64_t) , intent(in) , value :: k
real (c_double ) , intent(in) , value :: alpha
real (c_double ) , intent(in) :: A(lda,*)
integer (c_int64_t) , intent(in) , value :: size_A
integer (c_int64_t) , intent(in) , value :: lda
real (c_double ) , intent(in) :: B(ldb,*)
integer (c_int64_t) , intent(in) , value :: size_B
integer (c_int64_t) , intent(in) , value :: ldb
real (c_double ) , intent(in) , value :: beta
real (c_double ) , intent(out) :: C(ldc,*)
integer (c_int64_t) , intent(in) , value :: size_C
integer (c_int64_t) , intent(in) , value :: ldc
end function qmckl_dgemm_safe
end interface
#+end_src
** ~qmckl_matmul~ ** ~qmckl_matmul~
Matrix multiplication using the =qmckl_matrix= data type: Matrix multiplication using the =qmckl_matrix= data type: