diff --git a/Utils/FourIdxStorage.ml b/Utils/FourIdxStorage.ml index cea66ad..191f702 100644 --- a/Utils/FourIdxStorage.ml +++ b/Utils/FourIdxStorage.ml @@ -44,6 +44,13 @@ let dense_index i j size = (j-1)*size + i +let sym_index i j = + if i < j then + (j*(j-1))/2 + i + else + (i*(i-1))/2 + j + + let unsafe_get_four_index ~r1 ~r2 t = let open Bigarray.Array2 in @@ -78,7 +85,7 @@ let unsafe_get_four_index ~r1 ~r2 t = else match t.four_index with - | Dense a -> unsafe_get a (dense_index i j t.size) (dense_index k l t.size) + | Dense a -> unsafe_get a (dense_index i k t.size) (sym_index j l) | Sparse a -> let key = key_of_indices ~r1 ~r2 in try Hashtbl.find a key with Not_found -> 0. @@ -152,24 +159,18 @@ let unsafe_set_four_index ~r1 ~r2 ~value t = else match t.four_index with - | Dense a -> let ij = (dense_index i j t.size) - and kl = (dense_index k l t.size) - and il = (dense_index i l t.size) - and kj = (dense_index k j t.size) - and ji = (dense_index j i t.size) - and lk = (dense_index l k t.size) - and li = (dense_index l i t.size) - and jk = (dense_index j k t.size) + | Dense a -> let ik = (dense_index i k t.size) + and jl = (dense_index j l t.size) + and ki = (dense_index k i t.size) + and lj = (dense_index l j t.size) + and ik_s = (sym_index i k) + and jl_s = (sym_index j l) in begin - unsafe_set a ij kl value; - unsafe_set a kj il value; - unsafe_set a il kj value; - unsafe_set a kl ij value; - unsafe_set a ji lk value; - unsafe_set a li jk value; - unsafe_set a jk li value; - unsafe_set a lk ji value + unsafe_set a ik jl_s value; + unsafe_set a ki jl_s value; + unsafe_set a jl ik_s value; + unsafe_set a lj ik_s value; end | Sparse a -> let key = key_of_indices ~r1 ~r2 in Hashtbl.replace a key value @@ -233,7 +234,7 @@ let create ~size ?(temp_dir="/dev/shm") sparsity = let four_index = match sparsity with | `Dense -> let result = - SharedMemory.create ~temp_dir Float64 [| size*size ; size*size |] + SharedMemory.create ~temp_dir Float64 [| size*size ; (size*(size+1))/2 |] |> Bigarray.array2_of_genarray in Dense result