1
0
mirror of https://github.com/TREX-CoE/qmckl.git synced 2024-11-19 20:42:50 +01:00

Improved tensors in qmcalk_blas:

This commit is contained in:
Anthony Scemama 2023-03-16 16:52:01 +01:00
parent c0131d5da4
commit e10c7584ff

View File

@ -52,7 +52,10 @@
#include "config.h"
#endif
#include "qmckl_memory_private_type.h"
#include "qmckl_blas_private_type.h"
#include "qmckl_memory_private_func.h"
#include "qmckl_blas_private_func.h"
int main() {
@ -406,6 +409,7 @@ qmckl_tensor_of_vector(const qmckl_vector vector,
}
assert (prod_size == vector.size);
result.order = order;
result.data = vector.data;
return result;
@ -604,7 +608,7 @@ qmckl_tensor
qmckl_tensor_set(qmckl_tensor tensor, double value)
{
qmckl_vector vector = qmckl_vector_of_tensor(tensor);
for (int32_t i=0 ; i<vector.size ; ++i) {
for (int64_t i=0 ; i<vector.size ; ++i) {
qmckl_vec(vector, i) = value;
}
return qmckl_tensor_of_vector(vector, tensor.order, tensor.size);
@ -788,6 +792,74 @@ qmckl_tensor_of_double(const qmckl_context context,
}
#+end_src
** Allocate and copy to ~double*~
#+begin_src c :comments org :tangle (eval h_private_func)
double* qmckl_alloc_double_of_vector(const qmckl_context context,
const qmckl_vector vector);
#+end_src
#+begin_src c :comments org :tangle (eval c) :exports none
double* qmckl_alloc_double_of_vector(const qmckl_context context,
const qmckl_vector vector)
{
/* Always true by construction */
assert (qmckl_context_check(context) != QMCKL_NULL_CONTEXT);
assert (vector.size > (int64_t) 0);
qmckl_memory_info_struct mem_info = qmckl_memory_info_struct_zero;
mem_info.size = vector.size * sizeof(double);
double* target = (double*) qmckl_malloc(context, mem_info);
if (target == NULL) {
return NULL;
}
qmckl_exit_code rc;
rc = qmckl_double_of_vector(context, vector, target, vector.size);
assert (rc == QMCKL_SUCCESS);
if (rc != QMCKL_SUCCESS) {
rc = qmckl_free(context, target);
target = NULL;
}
return target;
}
#+end_src
#+begin_src c :comments org :tangle (eval h_private_func)
double* qmckl_alloc_double_of_matrix(const qmckl_context context,
const qmckl_matrix matrix);
#+end_src
#+begin_src c :comments org :tangle (eval c) :exports none
double* qmckl_alloc_double_of_matrix(const qmckl_context context,
const qmckl_matrix matrix)
{
qmckl_vector vector = qmckl_vector_of_matrix(matrix);
return qmckl_alloc_double_of_vector(context, vector);
}
#+end_src
#+begin_src c :comments org :tangle (eval h_private_func)
double* qmckl_alloc_double_of_tensor(const qmckl_context context,
const qmckl_tensor tensor);
#+end_src
#+begin_src c :comments org :tangle (eval c) :exports none
double* qmckl_alloc_double_of_tensor(const qmckl_context context,
const qmckl_tensor tensor)
{
qmckl_vector vector = qmckl_vector_of_tensor(tensor);
return qmckl_alloc_double_of_vector(context, vector);
}
#+end_src
** Tests :noexport:
#+begin_src c :comments link :tangle (eval c_test) :exports none
@ -803,6 +875,8 @@ qmckl_tensor_of_double(const qmckl_context context,
for (int64_t i=0 ; i<p ; ++i)
assert( vec.data[i] == (double) i );
printf("qmckl_vector ok\n");
qmckl_matrix mat = qmckl_matrix_of_vector(vec, m, n);
assert (mat.size[0] == m);
assert (mat.size[1] == n);
@ -812,13 +886,28 @@ qmckl_tensor_of_double(const qmckl_context context,
for (int64_t i=0 ; i<m ; ++i)
assert ( qmckl_mat(mat, i, j) == qmckl_vec(vec, i+j*m)) ;
printf("qmckl_matrix_of_vector ok\n");
qmckl_vector vec2 = qmckl_vector_of_matrix(mat);
assert (vec2.size == p);
assert (vec2.data == vec.data);
for (int64_t i=0 ; i<p ; ++i)
assert ( qmckl_vec(vec2, i) == qmckl_vec(vec, i) ) ;
printf("qmckl_vector_of_matrix ok\n");
double* dbl = qmckl_alloc_double_of_matrix(context, mat);
for (int64_t i=0 ; i<p ; ++i)
assert ( dbl[i] == qmckl_vec(vec, i) ) ;
printf("qmckl_double_of_matrix ok\n");
qmckl_exit_code rc = qmckl_free(context, dbl);
assert (rc == QMCKL_SUCCESS);
printf("qmckl_free ok\n");
qmckl_vector_free(context, &vec);
printf("qmckl_vector_free ok\n");
}
#+end_src
@ -981,7 +1070,7 @@ integer function qmckl_dgemm_f(context, TransA, TransB, &
do j=1,n
do i=1,m
C(i,j) = alpha*C1(j,i) + beta*C(i,j)
transpose C(i,j) = alpha*C1(j,i) + beta*C(i,j)
end do
end do
@ -1105,7 +1194,8 @@ integer(qmckl_exit_code) function test_qmckl_dgemm(context) bind(C)
end do
end do
test_qmckl_dgemm = qmckl_dgemm(context, TransA, TransB, m, n, k, alpha, A, LDA, B, LDB, beta, C, LDC)
test_qmckl_dgemm = qmckl_dgemm(context, TransA, TransB, m, n, k, &
alpha, A, LDA, B, LDB, beta, C, LDC)
if (test_qmckl_dgemm /= QMCKL_SUCCESS) return
@ -1133,6 +1223,7 @@ end function test_qmckl_dgemm
#+begin_src c :comments link :tangle (eval c_test) :exports none
qmckl_exit_code test_qmckl_dgemm(qmckl_context context);
assert(QMCKL_SUCCESS == test_qmckl_dgemm(context));
printf("qmckl_dgemm ok\n");
#+end_src
** ~qmckl_dgemm_safe~
@ -1586,28 +1677,32 @@ print(C.T)
58., 136., 214.,
59., 141., 223. };
double cnew[15];
qmckl_exit_code rc;
qmckl_matrix A = qmckl_matrix_alloc(context, 3, 4);
rc = qmckl_matrix_of_double(context, a, 12, &A);
assert(rc == QMCKL_SUCCESS);
printf("A ok\n");
qmckl_matrix B = qmckl_matrix_alloc(context, 4, 5);
rc = qmckl_matrix_of_double(context, b, 20, &B);
assert(rc == QMCKL_SUCCESS);
printf("B ok\n");
qmckl_matrix C = qmckl_matrix_alloc(context, 3, 5);
rc = qmckl_matmul(context, 'N', 'N', 0.5, A, B, 0., &C);
printf("C ok\n");
assert(rc == QMCKL_SUCCESS);
rc = qmckl_double_of_matrix(context, C, cnew, 15);
double cnew[15];
rc = qmckl_double_of_matrix(context, C, &(cnew[0]), 15);
assert(rc == QMCKL_SUCCESS);
printf("cnew ok\n");
for (int i=0 ; i<15 ; ++i) {
printf("%f %f\n", cnew[i], c[i]);
assert (c[i] == cnew[i]);
}
printf("qmckl_matmul ok\n");
}
#+end_src
** ~qmckl_adjugate~
@ -2491,6 +2586,7 @@ end function test_qmckl_adjugate
#+begin_src c :comments link :tangle (eval c_test)
qmckl_exit_code test_qmckl_adjugate(qmckl_context context);
assert(QMCKL_SUCCESS == test_qmckl_adjugate(context));
printf("qmckl_adjugate ok\n");
#+end_src
** ~qmckl_adjugate_safe~
@ -2724,6 +2820,7 @@ qmckl_transpose (qmckl_context context,
for (int i=0 ; i<2 ; ++i)
assert (qmckl_mat(A, i, j) == qmckl_mat(At, j, i));
printf("qmckl_transpose ok\n");
qmckl_matrix_free(context, &A);
qmckl_matrix_free(context, &At);
}