(*
  skeg - Sex, Kinematics, Elegance and Glory.
  Copyright (C) 2004 David Baelde and Samuel Mimram.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place - Suite 330,
  Boston, MA 02111-1307, USA.
*)


(** Operations on matrices.

@author David Baelde, Samuel Mimram *)



(* $Id: matrix.ml,v 1.15 2004/05/14 10:41:40 dbaelde Exp $ *)

type t = float array array

let copy m =
  Array.map Array.copy m

let mult a b =
  let m = Array.length a in
  let n = Array.length b.(0) in
  let p = Array.length a.(0) in
  let c = Array.make_matrix m n 0. in
    for i = 0 to m - 1 do
      for j = 0 to n - 1 do
        for k = 0 to p - 1 do
          c.(i).(j) <- c.(i).(j) +. a.(i).(k) *. b.(k).(j)
        done
      done
    done ;
    c

(**

Internal use

*)


exception Found of int

exception Error of float

let error a b =
  let e = ref 0. in
    for i = 0 to (Array.length a)-1 do
      e := !e +. (a.(i)-.b.(i))*.(a.(i)-.b.(i))
    done ;
    sqrt !e

(** print_sys m b displays the system m * x = b in a very verbose way. *)

let print_sys m b =
  Array.iteri
    (fun i e ->
       Array.iter (Printf.printf "(%f)") e ;
       Printf.printf "=(%f)\n" b.(i)) m

(** print_sys_silhouette m b does as print_sys m b but is more visual. It prints out a '*' for non-zero values and a ' ' for zero. *)

let print_sys_silhouette m b =
  let print_silhouette f =
    print_char (if f = 0. then ' ' else '*')
  in
    Array.iteri
      (fun i e ->
         Array.iter print_silhouette e ; print_char '|' ;
         print_silhouette b.(i) ; print_newline () ) m ;
    print_string ((String.make (Array.length m.(0)) '-')^"+-\n")

(** nearly_zero f tells whether f should be neglected or not. *)

let nearly_zero f =
  let c = classify_float f in
    c = FP_zero || c = FP_subnormal

(**

Conversions from and to vectors

*)


let of_vector_v =
  Array.map (fun (v:float) -> [| v |])

let of_vector_h x = [| x |]

let to_vector_v x =
  Array.map (fun l -> l.(0)) x

let to_vector_h x = x.(0)

(**

Solving linear systems

*)


(** gauss m b transforms m and b such that m is upper triangular, without changing the set of solutions of m*X=b. *)

let gauss a b =
  let m = Array.length a in
  let n = Array.length a.(0) in

  (* Elementary operation on lines *)
  let exchange i1 i2 =
    if i1 <> i2 then
      let (c,d) = (a.(i1),b.(i1)) in
        a.(i1) <- a.(i2) ;
        a.(i2) <- c ;
        b.(i1) <- b.(i2) ;
        b.(i2) <- d
  in
  let mult i c =
    for j = 0 to n -1 do
      a.(i).(j) <- c *. a.(i).(j)
    done ;
    b.(i) <- c *. b.(i)
  in
  let add i1 i2 c =
    for j = 0 to n - 1 do
      a.(i1).(j) <- a.(i1).(j) +. c *. a.(i2).(j)
    done ;
    b.(i1) <- b.(i1) +. c *. b.(i2)
  in

  let score i =
    (* The score of line i is the first j such that c_ij <> 0. *)
    try
      for j = 0 to n-1 do
        if a.(i).(j) <> 0. then raise (Found j)
      done ; n
    with
      | Found j -> j
  in
  let find i =
    (* We extract the line k>=i with the lowest score. *)
    let best_of (i,si) (k,sk) =
      if si<sk then
        (i,si)
      else
        (k,sk)
    in
    let rec find best k =
      if k = m then best else
        find (best_of best (k,(score k))) (k+1)
    in
      find (i,(score i)) i
  in

    for i = 0 to (min m n) - 1 do
      let (ii,col) = find i in
        if col <> n then
          begin
            (* We'll work with equation ii, pivoting with column col. *)
            exchange i ii ;
            
            (* Normalize *)
            mult i (1./.a.(i).(col)) ;

            (* We want c_{k>i,col} = 0. *)
            for k = i + 1 to m - 1 do
              add k i (-. a.(k).(col)) ;
              a.(k).(col) <- 0.
            done
          end
    done

exception Unsolvable

(** Solves a upper triangular system. @raise Unsolvable if the system cannot be solved. *)

let basic_solve a b =
  let m = Array.length a in
  let n = Array.length a.(0) in
  let j i = (* Gives the first j such that c_ij isn't zero *)
    try
      for j = 0 to n-1 do
        if not (nearly_zero a.(i).(j)) then raise (Found j)
      done ; -1
    with Found j -> j
  in
  let x = Array.make n 0. in
    for i = m-1 downto 0 do
      (* Dealing with equation i. We extract the pivot j. *)
      let j = j i in
        if j = -1 then
          (* No freedom: there must be nothing to do. *)
          ( if abs_float b.(i) > 0.1 then
              ( Printf.printf "Error at %d: 0. <> %f\n%!" i b.(i) ;
                raise Unsolvable ) else
                Printf.printf "Admitted error %f at %d.\n%!" b.(i) i )
        else
          ( x.(j) <- b.(i) /. a.(i).(j) ;
            (* Update the constant vector. *)
            for ii = 0 to i do
              b.(ii) <- b.(ii) -. x.(j) *. a.(ii).(j)
            done )
    done ;
    x

(** solve a b solves any system -- at least it tries. a and b can be modified since solve uses gauss. *)

let solve a b =
  (* (* This is the debug version, with a matrix multiplication to check. *)
  let _a = copy a in
  let _b = Array.copy b in
    print_sys_silhouette a b ;
    gauss a b ;
    print_sys_silhouette a b ;
    let x = basic_solve a b in
      Printf.printf "
        (error b
           (Array.make (Array.length b) 0.)) ;
      (* Below is a check of the corectness of the result. *)
      let vx = of_vector_v x in
      let b = _b in
      let ax = to_vector_v (mult _a vx) in
      let e = error ax b in
        Printf.printf " e ;
        if e>0.1 then raise (Error e) ;
        x
  *)

  gauss a b ;
  basic_solve a b