10
1
mirror of https://gitlab.com/scemama/QCaml.git synced 2025-01-18 00:21:40 +01:00

map optimizations

This commit is contained in:
Anthony Scemama 2020-03-26 17:43:11 +01:00
parent 9660eed87f
commit ffcccba188
17 changed files with 103 additions and 87 deletions

View File

@ -72,24 +72,34 @@ let test_case name t =
let check_matrix title a r = let check_matrix title a r =
let a = Mat.to_array a in let a = Mat.to_array a in
Array.iteri (fun i x -> Mat.to_array r
let message = |> Array.iteri (fun i x ->
Printf.sprintf "%s line %d" title i let message =
in Printf.sprintf "%s line %d" title i
Alcotest.(check (array (float 1.e-10))) message a.(i) x in
) (Mat.to_array r) Alcotest.(check (array (float 1.e-10))) message a.(i) x
)
in in
let check_eri title a r = let check_eri title a r =
let f { ERI.i_r1 ; j_r2 ; k_r1 ; l_r2 ; value } = let f { ERI.i_r1 ; j_r2 ; k_r1 ; l_r2 ; value } =
(i_r1, (j_r2, (k_r1, (l_r2, value)))) (i_r1, (j_r2, (k_r1, (l_r2, value))))
in in
let a = ERI.to_list a |> List.map f let a = ERI.to_list a |> List.rev_map f |> List.rev in
and r = ERI.to_list r |> List.map f let r = ERI.to_list r |> List.rev_map f |> List.rev in
in Printf.eprintf "test \n%!";
Alcotest.(check (list (pair int (pair int (pair int (pair int (float 1.e-12))))))) "ERI" a r Alcotest.(check (list (pair int (pair int (pair int (pair int (float 1.e-12))))))) "ERI" a r
in in
let check_eri_lr title a r =
let f { ERI_lr.i_r1 ; j_r2 ; k_r1 ; l_r2 ; value } =
(i_r1, (j_r2, (k_r1, (l_r2, value))))
in
let a = ERI_lr.to_list a |> List.rev_map f |> List.rev in
let r = ERI_lr.to_list r |> List.rev_map f |> List.rev in
Alcotest.(check (list (pair int (pair int (pair int (pair int (float 1.e-12))))))) "ERI_lr" a r
in
let test_overlap () = let test_overlap () =
let reference = let reference =
sym_matrix_of_file ("test_files/"^name^"_overlap.ref") sym_matrix_of_file ("test_files/"^name^"_overlap.ref")
@ -133,13 +143,13 @@ let test_case name t =
let test_ee_lr_ints () = let test_ee_lr_ints () =
let reference = let reference =
ERI.of_file ("test_files/"^name^"_eri_lr.ref") ~sparsity:`Dense ERI_lr.of_file ("test_files/"^name^"_eri_lr.ref") ~sparsity:`Dense
~size:(Basis.size t.basis) ~size:(Basis.size t.basis)
in in
let ee_ints = let ee_lr_ints =
Lazy.force t.ee_ints Lazy.force t.ee_lr_ints
in in
check_eri "ee_lr_ints" ee_ints reference check_eri_lr "ee_lr_ints" ee_lr_ints reference
in in
[ [

View File

@ -26,7 +26,7 @@ let make ?(cutoff=Constants.epsilon) atomic_shell_a atomic_shell_b =
in in
let contracted_shell_pairs = let contracted_shell_pairs =
List.map (fun s_a -> List.concat_map (fun s_a ->
List.map (fun s_b -> List.map (fun s_b ->
if Cs.index s_b <= Cs.index s_a then if Cs.index s_b <= Cs.index s_a then
Csp.make ~cutoff s_a s_b Csp.make ~cutoff s_a s_b
@ -34,7 +34,6 @@ let make ?(cutoff=Constants.epsilon) atomic_shell_a atomic_shell_b =
None None
) l_b ) l_b
) l_a ) l_a
|> List.concat
|> list_some |> list_some
in in
match contracted_shell_pairs with match contracted_shell_pairs with

View File

@ -28,12 +28,11 @@ let make ?(cutoff=Constants.epsilon) atomic_shell_pair_p atomic_shell_pair_q =
and atomic_shell_d = Asp.atomic_shell_b atomic_shell_pair_q and atomic_shell_d = Asp.atomic_shell_b atomic_shell_pair_q
in in
let contracted_shell_pair_couples = let contracted_shell_pair_couples =
List.map (fun ap_ab -> List.concat_map (fun ap_ab ->
List.map (fun ap_cd -> List.map (fun ap_cd ->
ContractedShellPairCouple.make ~cutoff ap_ab ap_cd ContractedShellPairCouple.make ~cutoff ap_ab ap_cd
) (Asp.contracted_shell_pairs atomic_shell_pair_q) ) (Asp.contracted_shell_pairs atomic_shell_pair_q)
) (Asp.contracted_shell_pairs atomic_shell_pair_p) ) (Asp.contracted_shell_pairs atomic_shell_pair_p)
|> List.concat
|> list_some |> list_some
in in
match contracted_shell_pair_couples with match contracted_shell_pair_couples with

View File

@ -29,14 +29,13 @@ let make ?(cutoff=Constants.epsilon) shell_pair_p shell_pair_q =
in in
let cutoff = 1.e-3 *. cutoff in let cutoff = 1.e-3 *. cutoff in
let coefs_and_shell_pair_couples = let coefs_and_shell_pair_couples =
List.map (fun (c_ab, sp_ab) -> List.concat_map (fun (c_ab, sp_ab) ->
List.map (fun (c_cd, sp_cd) -> List.map (fun (c_cd, sp_cd) ->
let coef_prod = c_ab *. c_cd in let coef_prod = c_ab *. c_cd in
if abs_float coef_prod < cutoff then None if abs_float coef_prod < cutoff then None
else Some (coef_prod, Pspc.make sp_ab sp_cd) else Some (coef_prod, Pspc.make sp_ab sp_cd)
) (Csp.coefs_and_shell_pairs shell_pair_q) ) (Csp.coefs_and_shell_pairs shell_pair_q)
) (Csp.coefs_and_shell_pairs shell_pair_p) ) (Csp.coefs_and_shell_pairs shell_pair_p)
|> List.concat
|> list_some |> list_some
in in
match coefs_and_shell_pair_couples with match coefs_and_shell_pair_couples with

View File

@ -24,7 +24,7 @@ module Make(T : TwoEI_structure) = struct
let class_of_contracted_shell_pair_couple = T.class_of_contracted_shell_pair_couple let class_of_contracted_shell_pair_couple = T.class_of_contracted_shell_pair_couple
let filter_contracted_shell_pairs ?(cutoff=integrals_cutoff) shell_pairs = let filter_contracted_shell_pairs ?(cutoff=integrals_cutoff) shell_pairs =
List.map (fun pair -> List.rev_map (fun pair ->
match Cspc.make ~cutoff pair pair with match Cspc.make ~cutoff pair pair with
| Some cspc -> | Some cspc ->
let cls = class_of_contracted_shell_pair_couple cspc in let cls = class_of_contracted_shell_pair_couple cspc in
@ -33,20 +33,20 @@ module Make(T : TwoEI_structure) = struct
| None -> (pair, -1.) | None -> (pair, -1.)
) shell_pairs ) shell_pairs
|> List.filter (fun (_, schwartz_p_max) -> schwartz_p_max >= cutoff) |> List.filter (fun (_, schwartz_p_max) -> schwartz_p_max >= cutoff)
|> List.map fst |> List.rev_map fst
(* TODO (* TODO
let filter_contracted_shell_pair_couples let filter_contracted_shell_pair_couples
?(cutoff=integrals_cutoff) shell_pair_couples = ?(cutoff=integrals_cutoff) shell_pair_couples =
List.map (fun pair -> List.rev_map (fun pair ->
let cls = let cls =
class_of_contracted_shell_pairs pair pair class_of_contracted_shell_pairs pair pair
in in
(pair, Zmap.fold (fun key value accu -> max (abs_float value) accu) cls 0. ) (pair, Zmap.fold (fun key value accu -> max (abs_float value) accu) cls 0. )
) shell_pairs ) shell_pairs
|> List.filter (fun (_, schwartz_p_max) -> schwartz_p_max >= cutoff) |> List.filter (fun (_, schwartz_p_max) -> schwartz_p_max >= cutoff)
|> List.map fst |> List.rev_map fst
*) *)

View File

@ -97,10 +97,12 @@ let create_matrix_arbitrary f det_space =
in in
let singles = let singles =
List.filter (fun (i,d,det_j) -> d < 2) doubles List.filter (fun (i,d,det_j) -> d < 2) doubles
|> List.map (fun (i,_,det_j) -> (i,det_j)) |> List.rev_map (fun (i,_,det_j) -> (i,det_j))
|> List.rev
in in
let doubles = let doubles =
List.map (fun (i,_,det_j) -> (i,det_j)) doubles List.rev_map (fun (i,_,det_j) -> (i,det_j)) doubles
|> List.rev
in in
(singles, doubles) (singles, doubles)
) det_beta ) det_beta
@ -262,10 +264,12 @@ let create_matrix_spin ?(nmax=2) f det_space =
in in
let singles = let singles =
List.filter (fun (i,d,det_j) -> d < 2) doubles List.filter (fun (i,d,det_j) -> d < 2) doubles
|> List.map (fun (i,_,det_j) -> (i,det_j)) |> List.rev_map (fun (i,_,det_j) -> (i,det_j))
|> List.rev
in in
let doubles = let doubles =
List.map (fun (i,_,det_j) -> (i,det_j)) doubles List.rev_map (fun (i,_,det_j) -> (i,det_j)) doubles
|> List.rev
in in
(singles, doubles, triples) (singles, doubles, triples)
) b ) b
@ -292,13 +296,16 @@ let create_matrix_spin ?(nmax=2) f det_space =
in in
let triples = let triples =
List.map (fun (i,_,det_j) -> (i,det_j)) triples List.rev_map (fun (i,_,det_j) -> (i,det_j)) triples
|> List.rev
in in
let doubles = let doubles =
List.map (fun (i,_,det_j) -> (i,det_j)) doubles List.rev_map (fun (i,_,det_j) -> (i,det_j)) doubles
|> List.rev
in in
let singles = let singles =
List.map (fun (i,_,det_j) -> (i,det_j)) singles List.rev_map (fun (i,_,det_j) -> (i,det_j)) singles
|> List.rev
in in
(singles, doubles, triples) (singles, doubles, triples)
) b ) b
@ -769,7 +776,8 @@ let second_order_sum { det_space ; m_H ; m_S2 ; eigensystem ; n_states }
in in
let psi_filtered = let psi_filtered =
List.map (fun i -> psi0.(i)) psi_filtered_idx List.rev_map (fun i -> psi0.(i)) psi_filtered_idx
|> List.rev
in in
let psi_h_alfa alfa = let psi_h_alfa alfa =
@ -896,39 +904,34 @@ let second_order_sum2 { det_space ; m_H ; m_S2 ; eigensystem ; n_states }
Ds.determinants_array det_space Ds.determinants_array det_space
|> Array.to_list |> Array.to_list
|> List.map (fun det_i -> |> List.concat_map (fun det_i ->
[ Spin.Alfa ; Spin.Beta ] [ Spin.Alfa ; Spin.Beta ]
|> List.map (fun spin -> |> List.concat_map (fun spin ->
List.map (fun particle -> List.concat_map (fun particle ->
List.map (fun hole -> List.map (fun hole ->
[ [ Determinant.single_excitation spin hole particle det_i ] ; [ [ Determinant.single_excitation spin hole particle det_i ] ;
List.map (fun particle' -> List.concat_map (fun particle' ->
List.map (fun hole' -> List.map (fun hole' ->
Determinant.double_excitation Determinant.double_excitation
spin hole particle spin hole particle
spin hole' particle' det_i spin hole' particle' det_i
) list_holes ) list_holes
) list_particles ) list_particles
|> List.concat
; ;
List.map (fun particle' -> List.concat_map (fun particle' ->
List.map (fun hole' -> List.map (fun hole' ->
Determinant.double_excitation Determinant.double_excitation
spin hole particle spin hole particle
(Spin.other spin) hole' particle' det_i (Spin.other spin) hole' particle' det_i
) list_holes ) list_holes
) list_particles ) list_particles
|> List.concat
] ]
|> List.concat |> List.concat
) list_holes ) list_holes
) list_particles ) list_particles
|> List.concat
) )
|> List.concat
) )
|> List.concat |> List.concat
|> List.concat
|> List.filter (fun alfa -> not (Determinant.is_none alfa)) |> List.filter (fun alfa -> not (Determinant.is_none alfa))
|> List.sort_uniq compare |> List.sort_uniq compare
in in

View File

@ -323,10 +323,11 @@ let fci_f12_of_mo_basis mo_basis ~frozen_core mo_num =
in in
{ r with mo_class = { r with mo_class =
MOClass.to_list r.mo_class MOClass.to_list r.mo_class
|> List.map (fun i -> |> List.rev_map (fun i ->
match i with match i with
| MOClass.Virtual i when i > mo_num -> MOClass.Auxiliary i | MOClass.Virtual i when i > mo_num -> MOClass.Auxiliary i
| i -> i) | i -> i)
|> List.rev
|> MOClass.of_list } |> MOClass.of_list }
@ -339,10 +340,11 @@ let cas_f12_of_mo_basis mo_basis ~frozen_core n m mo_num =
in in
{ r with mo_class = { r with mo_class =
MOClass.to_list r.mo_class MOClass.to_list r.mo_class
|> List.map (fun i -> |> List.rev_map (fun i ->
match i with match i with
| MOClass.Virtual i when i > mo_num -> MOClass.Auxiliary i | MOClass.Virtual i when i > mo_num -> MOClass.Auxiliary i
| i -> i) | i -> i)
|> List.rev
|> MOClass.of_list |> MOClass.of_list
} }

View File

@ -76,7 +76,7 @@ let multiple_of_spindet t t' =
else else
Phase.Neg Phase.Neg
in in
(phase, List.map2 (fun hole particle -> (hole, particle)) holes (List.rev particles) ) (phase, List.rev @@ List.rev_map2 (fun hole particle -> (hole, particle)) holes (List.rev particles) )
let double_of_spindet t t' = let double_of_spindet t t' =
@ -99,8 +99,8 @@ let multiple_of_det t t' =
in in
let phase = Phase.add pa pb in let phase = Phase.add pa pb in
Multiple (phase, List.concat [ Multiple (phase, List.concat [
List.map (fun (hole, particle) -> { hole ; particle ; spin=Spin.Alfa }) a ; List.rev @@ List.rev_map (fun (hole, particle) -> { hole ; particle ; spin=Spin.Alfa }) a ;
List.map (fun (hole, particle) -> { hole ; particle ; spin=Spin.Beta }) b ]) List.rev @@ List.rev_map (fun (hole, particle) -> { hole ; particle ; spin=Spin.Beta }) b ])
let double_of_det t t' = let double_of_det t t' =

View File

@ -93,7 +93,8 @@ let holes_particles_of t t' =
let holes = Bitstring.logand (bitstring t) x |> Bitstring.to_list let holes = Bitstring.logand (bitstring t) x |> Bitstring.to_list
and particles = Bitstring.logand (bitstring t') x |> Bitstring.to_list and particles = Bitstring.logand (bitstring t') x |> Bitstring.to_list
in in
List.map2 (fun h p -> (h,p)) holes particles List.rev_map2 (fun h p -> (h,p)) holes particles
|> List.rev
let set_phase p = function let set_phase p = function

View File

@ -30,8 +30,8 @@ let fci_of_mo_basis ~frozen_core mo_basis elec_num =
let spin_determinants = let spin_determinants =
Bitstring.permtutations elec_num mo_num Bitstring.permtutations elec_num mo_num
|> List.filter (fun b -> Bitstring.logand neg_active_mask b = occ_mask) |> List.filter (fun b -> Bitstring.logand neg_active_mask b = occ_mask)
|> List.map (fun b -> Spindeterminant.of_bitstring b)
|> Array.of_list |> Array.of_list
|> Array.map (fun b -> Spindeterminant.of_bitstring b)
in in
{ elec_num ; mo_basis ; mo_class ; spin_determinants } { elec_num ; mo_basis ; mo_class ; spin_determinants }
@ -54,8 +54,8 @@ let cas_of_mo_basis mo_basis ~frozen_core elec_num n m =
let spin_determinants = let spin_determinants =
Bitstring.permtutations elec_num mo_num Bitstring.permtutations elec_num mo_num
|> List.filter (fun b -> Bitstring.logand neg_active_mask b = occ_mask) |> List.filter (fun b -> Bitstring.logand neg_active_mask b = occ_mask)
|> List.map (fun b -> Spindeterminant.of_bitstring b)
|> Array.of_list |> Array.of_list
|> Array.map (fun b -> Spindeterminant.of_bitstring b)
in in
{ elec_num ; mo_basis ; mo_class ; spin_determinants } { elec_num ; mo_basis ; mo_class ; spin_determinants }

View File

@ -67,10 +67,11 @@ let array_4_init d1 d2 d3 d4 fx =
SharedMemory.create Bigarray.Float64 [| d1;d2;d3;d4 |] SharedMemory.create Bigarray.Float64 [| d1;d2;d3;d4 |]
in in
Util.list_range 1 d4 Util.list_range 1 d4
|> List.map (fun l -> |> List.rev_map (fun l ->
Util.list_range 1 d3 Util.list_range 1 d3
|> List.map (fun k -> (k,l)) ) |> List.rev_map (fun k -> (k,l)) )
|> List.concat |> List.concat
|> List.rev
|> Stream.of_list |> Stream.of_list
|> Farm.run ~f ~ordered:false |> Farm.run ~f ~ordered:false
|> Stream.iter (fun (k,l,x) -> |> Stream.iter (fun (k,l,x) ->
@ -133,10 +134,11 @@ let array_5_init d1 d2 d3 d4 d5 fx =
SharedMemory.create Bigarray.Float64 [| d1;d2;d3;d4;d5 |] SharedMemory.create Bigarray.Float64 [| d1;d2;d3;d4;d5 |]
in in
Util.list_range 1 d5 Util.list_range 1 d5
|> List.map (fun m -> |> List.rev_map (fun m ->
Util.list_range 1 d4 Util.list_range 1 d4
|> List.map (fun l -> (l,m)) ) |> List.rev_map (fun l -> (l,m)) )
|> List.concat |> List.concat
|> List.rev
|> Stream.of_list |> Stream.of_list
|> Farm.run ~f ~ordered:false |> Farm.run ~f ~ordered:false
|> Stream.iter (fun (l,m,x) -> |> Stream.iter (fun (l,m,x) ->

View File

@ -34,51 +34,45 @@ let to_list t = t
let core_mos t = let core_mos t =
List.map (fun x -> List.filter_map (fun x ->
match x with match x with
| Core i -> Some i | Core i -> Some i
| _ -> None) t | _ -> None) t
|> Util.list_some
let inactive_mos t = let inactive_mos t =
List.map (fun x -> List.filter_map (fun x ->
match x with match x with
| Inactive i -> Some i | Inactive i -> Some i
| _ -> None ) t | _ -> None ) t
|> Util.list_some
let active_mos t = let active_mos t =
List.map (fun x -> List.filter_map (fun x ->
match x with match x with
| Active i -> Some i | Active i -> Some i
| _ -> None ) t | _ -> None ) t
|> Util.list_some
let virtual_mos t = let virtual_mos t =
List.map (fun x -> List.filter_map (fun x ->
match x with match x with
| Virtual i -> Some i | Virtual i -> Some i
| _ -> None ) t | _ -> None ) t
|> Util.list_some
let deleted_mos t = let deleted_mos t =
List.map (fun x -> List.filter_map (fun x ->
match x with match x with
| Deleted i -> Some i | Deleted i -> Some i
| _ -> None ) t | _ -> None ) t
|> Util.list_some
let auxiliary_mos t = let auxiliary_mos t =
List.map (fun x -> List.filter_map (fun x ->
match x with match x with
| Auxiliary i -> Some i | Auxiliary i -> Some i
| _ -> None ) t | _ -> None ) t
|> Util.list_some
let mo_class_array t = let mo_class_array t =

View File

@ -107,19 +107,19 @@ let zkey_array a =
begin begin
match a with match a with
| Singlet l1 -> | Singlet l1 ->
List.map (fun x -> Zkey.of_powers_three x) (keys_1d @@ to_int l1) List.rev_map (fun x -> Zkey.of_powers_three x) (keys_1d @@ to_int l1)
| Doublet (l1, l2) -> | Doublet (l1, l2) ->
List.map (fun a -> List.rev_map (fun a ->
List.map (fun b -> Zkey.of_powers_six a b) (keys_1d @@ to_int l2) List.rev_map (fun b -> Zkey.of_powers_six a b) (keys_1d @@ to_int l2)
) (keys_1d @@ to_int l1) ) (keys_1d @@ to_int l1)
|> List.concat |> List.concat
| Triplet (l1, l2, l3) -> | Triplet (l1, l2, l3) ->
List.map (fun a -> List.rev_map (fun a ->
List.map (fun b -> List.rev_map (fun b ->
List.map (fun c -> List.rev_map (fun c ->
Zkey.of_powers_nine a b c) (keys_1d @@ to_int l3) Zkey.of_powers_nine a b c) (keys_1d @@ to_int l3)
) (keys_1d @@ to_int l2) ) (keys_1d @@ to_int l2)
|> List.concat |> List.concat
@ -128,10 +128,10 @@ let zkey_array a =
| Quartet (l1, l2, l3, l4) -> | Quartet (l1, l2, l3, l4) ->
List.map (fun a -> List.rev_map (fun a ->
List.map (fun b -> List.rev_map (fun b ->
List.map (fun c -> List.rev_map (fun c ->
List.map (fun d -> List.rev_map (fun d ->
Zkey.of_powers_twelve a b c d) (keys_1d @@ to_int l4) Zkey.of_powers_twelve a b c d) (keys_1d @@ to_int l4)
) (keys_1d @@ to_int l3) ) (keys_1d @@ to_int l3)
|> List.concat |> List.concat
@ -140,6 +140,7 @@ let zkey_array a =
) (keys_1d @@ to_int l1) ) (keys_1d @@ to_int l1)
|> List.concat |> List.concat
end end
|> List.rev
|> Array.of_list |> Array.of_list
in in
Hashtbl.add zkey_array_memo a result; Hashtbl.add zkey_array_memo a result;

View File

@ -117,7 +117,7 @@ let make
in in
let residual_norms = List.map nrm2 u_proposed in let residual_norms = List.rev @@ List.rev_map nrm2 u_proposed in
let residual_norm = let residual_norm =
List.fold_left (fun accu i -> accu +. i *. i) 0. residual_norms List.fold_left (fun accu i -> accu +. i *. i) 0. residual_norms
|> sqrt |> sqrt

View File

@ -80,11 +80,12 @@ let sparse_of_computed ?(threshold=epsilon) = function
| Computed {m ; n ; f} -> | Computed {m ; n ; f} ->
Sparse { m ; n ; v=Array.init n (fun j -> Sparse { m ; n ; v=Array.init n (fun j ->
Util.list_range 1 m Util.list_range 1 m
|> List.map (fun i -> |> List.rev_map (fun i ->
let x = f i (j+1) in let x = f i (j+1) in
if abs_float x > threshold then Some (i, x) if abs_float x > threshold then Some (i, x)
else None) else None)
|> Util.list_some |> Util.list_some
|> List.rev
|> Vector.sparse_of_assoc_list m |> Vector.sparse_of_assoc_list m
) } ) }
| _ -> invalid_arg "Expected a computed matrix" | _ -> invalid_arg "Expected a computed matrix"
@ -176,7 +177,7 @@ let outer_product ?(threshold=epsilon) v1 v2 =
in in
let v = let v =
Array.init (Vector.dim v2) (fun j -> Array.init (Vector.dim v2) (fun j ->
List.map (fun (i, x) -> List.rev_map (fun (i, x) ->
let z = x *. v'.{j+1} in let z = x *. v'.{j+1} in
if abs_float z < threshold then if abs_float z < threshold then
None None
@ -184,6 +185,7 @@ let outer_product ?(threshold=epsilon) v1 v2 =
Some (i, z) Some (i, z)
) v ) v
|> Util.list_some |> Util.list_some
|> List.rev
|> Vector.sparse_of_assoc_list (Vector.dim v1) |> Vector.sparse_of_assoc_list (Vector.dim v1)
) )
in in
@ -500,22 +502,23 @@ let split_cols nrows = function
Mat.to_col_vecs a Mat.to_col_vecs a
|> Array.to_list |> Array.to_list
|> Util.list_pack nrows |> Util.list_pack nrows
|> List.map (fun l -> |> List.rev_map (fun l ->
Dense (Mat.of_col_vecs @@ Array.of_list l) ) Dense (Mat.of_col_vecs @@ Array.of_list l) )
|> List.rev
end end
| Sparse a -> | Sparse a ->
begin begin
Array.to_list a.v Array.to_list a.v
|> Util.list_pack nrows |> Util.list_pack nrows
|> List.map Array.of_list |> List.rev_map Array.of_list
|> List.map (fun v -> Sparse { m=a.m ; n= Array.length v ; v }) |> List.rev_map (fun v -> Sparse { m=a.m ; n= Array.length v ; v })
end end
| Computed a -> | Computed a ->
begin begin
Util.list_range 0 (a.n-1) Util.list_range 0 (a.n-1)
|> Util.list_pack nrows |> Util.list_pack nrows
|> List.map Array.of_list |> List.rev_map Array.of_list
|> List.map (fun v -> Computed { m=a.m ; n= Array.length v ; f = (fun i j -> a.f i (j+v.(0)) ) }) |> List.rev_map (fun v -> Computed { m=a.m ; n= Array.length v ; f = (fun i j -> a.f i (j+v.(0)) ) })
end end
@ -534,7 +537,9 @@ let join_cols l =
| [] -> Sparse { m=0 ; n=0 ; v=[| |] } | [] -> Sparse { m=0 ; n=0 ; v=[| |] }
| (Dense a) :: rest -> aux_dense [] ((Dense a) :: rest) | (Dense a) :: rest -> aux_dense [] ((Dense a) :: rest)
| (Sparse a) :: rest -> aux_sparse 0 0 [] ((Sparse a) :: rest) | (Sparse a) :: rest -> aux_sparse 0 0 [] ((Sparse a) :: rest)
| (Computed a) :: rest -> aux_sparse 0 0 [] (List.map sparse_of_computed ( (Computed a) :: rest )) | (Computed a) :: rest -> aux_sparse 0 0 []
(List.rev_map sparse_of_computed ( (Computed a) :: rest )
|> List.rev)
in aux (List.rev l) in aux (List.rev l)

View File

@ -244,7 +244,8 @@ let of_some = function
let list_some l = let list_some l =
List.filter (function None -> false | _ -> true) l List.filter (function None -> false | _ -> true) l
|> List.map (function Some x -> x | _ -> assert false) |> List.rev_map (function Some x -> x | _ -> assert false)
|> List.rev
let list_range first last = let list_range first last =

View File

@ -91,8 +91,8 @@ let sparse_of_vec ?(threshold=epsilon) v =
let sparse_of_assoc_list n v = let sparse_of_assoc_list n v =
Sparse { n ; Sparse { n ;
v = List.map (fun (index, value) -> {index ; value}) v v = Array.of_list v
|> Array.of_list |> Array.map (fun (index, value) -> {index ; value})
} }