mirror of
https://github.com/triqs/dft_tools
synced 2025-01-12 05:58:18 +01:00
arrays: h5 read/write for arrays of complex types
- array of complex type (not fundamental) can now be saved/loaded to h5 - with a test with array<gf<...>>
This commit is contained in:
parent
7aedaef945
commit
1c9d6dacfa
40
test/triqs/gfs/array_gf.cpp
Normal file
40
test/triqs/gfs/array_gf.cpp
Normal file
@ -0,0 +1,40 @@
|
||||
#include <triqs/gfs.hpp>
|
||||
#define TEST(X) std::cout << BOOST_PP_STRINGIZE((X)) << " ---> " << (X) << std::endl << std::endl;
|
||||
|
||||
using namespace triqs::gfs;
|
||||
using namespace triqs::arrays;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
|
||||
try {
|
||||
|
||||
triqs::clef::placeholder<0> w_;
|
||||
auto agf = array<gf<imfreq>, 2>{2, 3};
|
||||
auto bgf = array<gf<imfreq>, 2>{2, 3};
|
||||
agf() = gf<imfreq>{{10.0, Fermion}, {1, 1}};
|
||||
agf(0, 0)(w_) << 1 / (w_ + 2);
|
||||
|
||||
array<double,2> A(2,2); A()=0;A(0,0) = 1.3; A(1,1) = -8.2;
|
||||
array<array<double,2>, 1> aa(2), bb;
|
||||
aa(0) = A;
|
||||
aa(1) = 2 * A;
|
||||
bb = aa;
|
||||
|
||||
{
|
||||
H5::H5File file("ess_array_gf.h5", H5F_ACC_TRUNC);
|
||||
h5_write(file, "Agf", agf);
|
||||
h5_write(file, "aa", aa);
|
||||
}
|
||||
{
|
||||
H5::H5File file("ess_array_gf.h5", H5F_ACC_RDONLY);
|
||||
h5_read(file, "Agf", bgf);
|
||||
h5_read(file, "aa", bb);
|
||||
}
|
||||
{
|
||||
H5::H5File file("ess_array_gf2.h5", H5F_ACC_TRUNC);
|
||||
h5_write(file, "Agf", bgf);
|
||||
h5_write(file, "aa", bb);
|
||||
}
|
||||
}
|
||||
TRIQS_CATCH_AND_ABORT;
|
||||
}
|
@ -92,7 +92,7 @@ int main() {
|
||||
C = extract<decltype(C)>(P["B"]);
|
||||
std::cout << "C" << C << std::endl;
|
||||
|
||||
array<array<int,2>, 1> aa(3);
|
||||
array<array<int,2>, 1> aa(2);
|
||||
aa(0) = A; aa(1) = 2*A;
|
||||
P["aa"] = aa;
|
||||
|
||||
@ -143,6 +143,7 @@ int main() {
|
||||
|
||||
parameters P4;
|
||||
parameters::register_type<triqs::arrays::array<double,1>>();
|
||||
//parameters::register_type<triqs::arrays::array<int,2>>();
|
||||
std::cout << "P4 before : "<< P4<< std::endl ;
|
||||
{
|
||||
H5::H5File file( "ess.h5", H5F_ACC_RDONLY );
|
||||
|
@ -147,7 +147,7 @@ namespace arrays {
|
||||
V = res;
|
||||
}
|
||||
|
||||
} // namespace h5impl
|
||||
} // namespace h5_impl
|
||||
|
||||
// a trait to detect if A::value_type exists and is a scalar or a string
|
||||
// used to exclude array<array<..>>
|
||||
@ -192,8 +192,89 @@ namespace arrays {
|
||||
template <typename ArrayType>
|
||||
ENABLE_IFC(is_amv_value_or_view_class<ArrayType>::value&& has_scalar_or_string_value_type<ArrayType>::value)
|
||||
h5_write(h5::group g, std::string const& name, ArrayType const& A) {
|
||||
if (A.is_empty()) TRIQS_RUNTIME_ERROR << " Can not save an empty array into hdf5";
|
||||
h5_impl::write_array(g, name, array_const_view<typename ArrayType::value_type, ArrayType::rank>(A));
|
||||
}
|
||||
|
||||
// details for generic save/read of arrays.
|
||||
namespace h5_impl {
|
||||
inline std::string _h5_name() { return ""; }
|
||||
|
||||
template <typename T0, typename... Ts> std::string _h5_name(T0 const& t0, Ts const&... ts) {
|
||||
auto r = std::to_string(t0);
|
||||
auto r1 = _h5_name(ts...);
|
||||
if (r1 != "") r += "_" + r1;
|
||||
return r;
|
||||
}
|
||||
|
||||
#ifndef __cpp_generic_lambdas
|
||||
template <typename ArrayType> struct _save_lambda {
|
||||
ArrayType const& a;
|
||||
h5::group g;
|
||||
template <typename... Is> void operator()(Is const&... is) const { h5_write(g, _h5_name(is...), a(is...)); }
|
||||
};
|
||||
|
||||
template <typename ArrayType> struct _load_lambda {
|
||||
ArrayType& a;
|
||||
h5::group g;
|
||||
template <typename... Is> void operator()(Is const&... is) { h5_read(g, _h5_name(is...), a(is...)); }
|
||||
};
|
||||
#endif
|
||||
} // details
|
||||
|
||||
/*
|
||||
* Write an array or a view into an hdf5 file when type is not fundamental
|
||||
* ArrayType The type of the array/matrix/vector, etc..
|
||||
* g The h5 group
|
||||
* name The name of the hdf5 array in the file/group where the stack will be stored
|
||||
* A The array to be stored
|
||||
* The HDF5 exceptions will be caught and rethrown as TRIQS_RUNTIME_ERROR (with a full stackstrace, cf triqs doc).
|
||||
*/
|
||||
template <typename ArrayType>
|
||||
std::c14::enable_if_t<is_amv_value_or_view_class<ArrayType>::value && !has_scalar_or_string_value_type<ArrayType>::value>
|
||||
h5_write(h5::group gr, std::string name, ArrayType const& a) {
|
||||
if (a.is_empty()) TRIQS_RUNTIME_ERROR << " Can not save an empty array into hdf5";
|
||||
auto gr2 = gr.create_group(name);
|
||||
gr2.write_triqs_hdf5_data_scheme(a);
|
||||
// save the shape
|
||||
array<int, 1> sha(ArrayType::rank);
|
||||
for (int u = 0; u < ArrayType::rank; ++u) sha(u) = a.shape()[u];
|
||||
h5_write(gr2, "shape", sha);
|
||||
#ifndef __cpp_generic_lambdas
|
||||
foreach(a, h5_impl::_save_lambda<ArrayType>{a, gr2});
|
||||
#else
|
||||
foreach(a, [&](auto... is) { h5_write(gr2, h5_impl::_h5_name(is...), a(is...)); });
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Read an array or a view from an hdf5 file when type is not fundamental
|
||||
* ArrayType The type of the array/matrix/vector, etc..
|
||||
* g The h5 group
|
||||
* name The name of the hdf5 array in the file/group where the stack will be stored
|
||||
* A The array to be stored
|
||||
* The HDF5 exceptions will be caught and rethrown as TRIQS_RUNTIME_ERROR (with a full stackstrace, cf triqs doc).
|
||||
*/
|
||||
template <typename ArrayType>
|
||||
std::c14::enable_if_t<is_amv_value_or_view_class<ArrayType>::value && !has_scalar_or_string_value_type<ArrayType>::value>
|
||||
h5_read(h5::group gr, std::string name, ArrayType& a) {
|
||||
static_assert(!std::is_const<ArrayType>::value, "Can not read in const object");
|
||||
auto gr2 = gr.open_group(name);
|
||||
// TODO checking scheme...
|
||||
// load the shape
|
||||
auto sha2 = a.shape();
|
||||
array<int, 1> sha;
|
||||
h5_read(gr2, "shape", sha);
|
||||
if (first_dim(sha) != sha2.size())
|
||||
TRIQS_RUNTIME_ERROR << " array<array<...>> load : rank mismatch. Expected " << sha2.size()<< " Got " << first_dim(sha);
|
||||
for (int u = 0; u < sha2.size(); ++u) sha2[u] = sha(u);
|
||||
if (a.shape() != sha2) a.resize(sha2);
|
||||
#ifndef __cpp_generic_lambdas
|
||||
foreach(a, h5_impl::_load_lambda<ArrayType>{a, gr2});
|
||||
#else
|
||||
foreach(a, [&](auto... is) { h5_read(gr2, h5_impl::_h5_name(is...), a(is...)); });
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
@ -48,7 +48,8 @@ namespace triqs { namespace utility {
|
||||
size_t type_hash_element = type_hash;
|
||||
auto it2= _object::code_element_rank_to_code_array.find(std::make_pair(type_hash_element,rank));
|
||||
if (it2 == _object::code_element_rank_to_code_array.end())
|
||||
TRIQS_RUNTIME_ERROR << " code_element_rank_to_code_array : type not found" << rank;
|
||||
TRIQS_RUNTIME_ERROR << " code_element_rank_to_code_array : type not found : " << name << " " << type_hash_element << " "
|
||||
<< rank;
|
||||
type_hash = it2->second;
|
||||
}
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ namespace triqs { namespace utility {
|
||||
public :
|
||||
|
||||
typedef T value_type;
|
||||
|
||||
|
||||
mini_vector(){init();}
|
||||
|
||||
#define AUX(z,p,unused) _data[p] = x_##p;
|
||||
@ -69,6 +69,8 @@ namespace triqs { namespace utility {
|
||||
|
||||
friend void swap(mini_vector & a, mini_vector & b) { std::swap(a._data, b._data);}
|
||||
|
||||
int size() const { return Rank;}
|
||||
|
||||
template<typename T2>
|
||||
mini_vector & operator=(const mini_vector<T2,Rank> & x){ for (int i=0;i<Rank; ++i) _data[i] = x[i]; return *this;}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user