diff --git a/CI/CI.ml b/CI/CI.ml index 0f425b2..0bd935f 100644 --- a/CI/CI.ml +++ b/CI/CI.ml @@ -490,7 +490,7 @@ let make ?(n_states=1) ?(algo=`Direct) det_space = in let matrix_prod psi = let result = - Matrix.parallel_mm ~transa:`T psi m_H + Matrix.parallel_mm ~transa:`T ~transb:`T psi m_H |> Matrix.transpose in Parallel.broadcast (lazy result) diff --git a/Utils/Matrix.ml b/Utils/Matrix.ml index 53f6c85..0dbfad4 100644 --- a/Utils/Matrix.ml +++ b/Utils/Matrix.ml @@ -201,7 +201,8 @@ let rec mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b = | `N, `T -> dim2, dim2 in if f a <> f' b then - invalid_arg "Inconsistent dimensions"; + Printf.sprintf "%d %d : Inconsistent dimensions" (f a) (f' b) + |> invalid_arg; (* Dense x sparse *) let mmsp transa transb a b = @@ -560,16 +561,21 @@ let parallel_mm ?(transa=`N) ?(transb=`N) ?(threshold=epsilon) a b = | `N -> dim2 a | `T -> dim1 a in - let n = n / (Parallel.size * 4) in + let n = n / (Parallel.size * 7) in + let b = + match transb with + | `T -> transpose b + | `N -> b + in split_cols n b |> Stream.of_list |> Farm.run ~ordered:true ~f:(fun b -> match a, b with | Computed _, Computed _ -> - mm ~transa ~transb ~threshold a b + mm ~transa ~threshold a b |> sparse_of_computed ~threshold | _ -> - mm ~transa ~transb ~threshold a b + mm ~transa ~threshold a b ) |> Util.stream_to_list |> join_cols