/*! \file
    \brief Plain Euclid algorithm with 2x2 matrices
*/

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <assert.h>

#include "gmp.h"
#include "gmp-impl.h"

#include "gcd_utils.h"
#include "gcd_matrix.h"
#include "qseq.h"
#include "regular.h"

#include "gcd_euclid.h"

#define DEBUG_EUCLID   0
#define DEBUG_PTR      0
#define TIMINGS        0
#define COUNT_STEPS    0

/* Test whether {a, an} is exactly BASE^k, i.e. every limb is zero except
   the most significant one, which must be 1.
   OUTPUT: k = an-1 when it is, -1 otherwise (an == 0 also yields -1). */
int mpn_is_power_of_base(mp_ptr a, mp_size_t an)
{
    mp_size_t top = an - 1, i = 0;

    if(an == 0)
	return -1;
    /* every limb below the leading one has to vanish */
    while(i < top){
	if(a[i] != 0)
	    return -1;
	i++;
    }
    if(a[top] != 1)
	return -1;
    return top;
}

/* OUTPUT: non-zero iff a < BASE^m, for a = {a, an}.
   - a == BASE^k: true iff k < m;
   - otherwise:   true iff an <= m (a is then below BASE^an; assumes the
     top limb of a is non-zero -- TODO confirm callers normalize). */
int mpn_gcd_is_norm_lt(mp_ptr a, mp_size_t an, mp_size_t m)
{
    int expo;

#if DEBUG_EUCLID >= 3
    printf("norm_lt: a:="); MPN_PRINT(a, an); printf("; m=%lu\n", m);
#endif
    expo = mpn_is_power_of_base(a, an);
    if(expo == -1)
	return an <= m;
    /* exact power of the base: compare exponents (covers BASE^0 = 1) */
    return expo < m;
}

/* OUTPUT: non-zero iff a <= BASE^m, for a = {a, an}.
   - a == BASE^k: true iff k <= m;
   - otherwise:   true iff an <= m (assumes the top limb of a is
     non-zero -- TODO confirm callers normalize). */
int mpn_gcd_is_norm_le(mp_ptr a, mp_size_t an, mp_size_t m)
{
    int expo;

#if DEBUG_EUCLID >= 3
    printf("norm_le: a:="); MPN_PRINT(a, an); printf("; m=%lu\n", m);
#endif
    expo = mpn_is_power_of_base(a, an);
    if(expo == -1)
	return an <= m;
    /* exact power of the base: compare exponents (covers BASE^0 = 1) */
    return expo <= m;
}

/* OUTPUT: ceiling(log_base(a)) for a = {a, an}: the exponent k when
   a == BASE^k, and an otherwise (since then a < BASE^an). */
mp_size_t mpn_gcd_ceiling_norm(mp_ptr a, mp_size_t an)
{
    int expo = mpn_is_power_of_base(a, an);

    /* exact powers are singled out so that BASE^k yields k, not k+1 */
    return (expo == -1 ? an : (mp_size_t)expo);
}

/* PRECONDITION: R->Q == R->lq
   OUTPUT: non-zero iff ||R->Q|| exceeds the bound, where ||Q|| is taken
           to be the entry R->Q->p[0][0] and the bound is
             - emin[0..eminn[ when m == 0,
             - BASE^m         when m > 0.
   R is not modified; only a local copy of the limb count is normalized.
*/
static int is_norm_Q_gt(regular_t R, mp_size_t m, mp_ptr emin, mp_size_t eminn)
{
    mp_size_t qn = R->Q->n, k;

#if DEBUG_EUCLID >= 3
    printf("Q=");
    hgcd_matrix_print(R->Q);
    printf("\n");
#endif
    /* strip high zero limbs to get the true length of Q[0][0] */
    MPN_NORMALIZE(R->Q->p[0][0], qn);

    if(m == 0){
	/* plain comparison against emin: length first, then limbs */
	if(qn != eminn)
	    return qn > eminn;
	return mpn_cmp(R->Q->p[0][0], emin, eminn) > 0;
    }

    /* comparison against BASE^m, which occupies exactly m+1 limbs */
    if(qn <= m)
	return 0;
    /* more than m+1 limbs, or m+1 limbs with a leading limb above 1 */
    if(qn > m+1 || R->Q->p[0][0][m] > 1)
	return 1;
    /* leading limb is 1: strictly greater iff some lower limb is set */
    k = 0;
    while(k < m){
	if(R->Q->p[0][0][k] != 0)
	    return 1;
	k++;
    }
    /* Q[0][0] == BASE^m exactly */
    return 0;
}

/* Decide whether the Euclidean loop must stop.

   qgcd == 0 (test on the remainder b = {b, bn}):
       m == 0: stop when b < emin; emin must be non-empty (eminn > 0);
       m >  0: stop when b <= BASE^m (use_eq != 0) or b < BASE^m
               (use_eq == 0); emin == NULL and eminn == 0 are accepted.
   qgcd != 0 (test on the matrix held in R):
       stop when bn == 0, or when ||R->Q|| > emin (m == 0),
       resp. ||R->Q|| > BASE^m (m > 0).

   OUTPUT: non-zero to stop, 0 to continue.

   NOTE(review): declared `inline` without `static`/`extern`; under C99
   rules this yields an external definition only if another declaration
   in the translation unit (e.g. in gcd_euclid.h) omits `inline` --
   confirm against the header.
*/
inline int gcd_stopping_condition(mp_ptr b, mp_size_t bn, mp_size_t m,
				  mp_ptr emin, mp_size_t eminn,
				  int use_eq, regular_t R, int qgcd)
{
    assert(m != 0 || eminn != 0);

    if(qgcd != 0){
	/* matrix-driven mode */
#if DEBUG_EUCLID >= 1
	printf("bn:=%lu;\n", bn);
	printf("||Q||[%lu]:=", R->Q->n); MPN_PRINT(R->Q->p[0][0], R->Q->n);
	printf(";\n");
	printf("emin[%lu]:=", eminn); MPN_PRINT(emin, eminn); printf(";\n");
#endif
	return (bn == 0 || is_norm_Q_gt(R, m, emin, eminn));
    }

    if(m != 0){
	/* remainder compared against BASE^m */
	if(use_eq)
	    return mpn_gcd_is_norm_le(b, bn, m);
	return mpn_gcd_is_norm_lt(b, bn, m);
    }

    /* remainder compared against emin: shorter wins, else limbwise */
    assert(eminn > 0); /* FIXME: why? */
    if(bn != eminn)
	return bn < eminn;
    return mpn_cmp(b, emin, eminn) < 0;
}

/* Stopping test used when walking the quotient sequence backwards.
   m == 0: stop when a > rmin[0..rminn[ (strict comparison);
   m >  0: stop when mpn_gcd_is_norm_ge(a, an, m) holds (defined
           elsewhere; presumably ||a|| >= m -- confirm).
   OUTPUT: non-zero to stop, 0 to continue. */
int gcd_stopping_condition_ge(mp_ptr a, mp_size_t an, mp_size_t m,
			      mp_ptr rmin, mp_size_t rminn)
{
#if DEBUG_EUCLID >= 2
    printf("   a[%lu]:=", an); MPN_PRINT(a, an); printf(";\n");
    printf("rmin[%lu]:=", rminn); MPN_PRINT(rmin, rminn); printf(";\n");
#endif
    if(m != 0)
	return mpn_gcd_is_norm_ge(a, an, m);
    /* compare lengths first; only equal lengths need a limb comparison */
    if(an != rminn)
	return an > rminn;
    return mpn_cmp(a, rmin, rminn) > 0;
}

/* One Euclidean division step whose quotient is recorded in lq.

   INPUT: an >= bn; size(q), size(r) >= an, large enough;
          b[bn..an[ should be 0;
          tp/tp_alloc is scratch space handed to euclidean_div_rem.
   OUTPUT: qn, the limb count of the quotient returned by
           euclidean_div_rem.
   SIDE-EFFECT: a = q*b+r, 0 <= r < b (computed by euclidean_div_rem);
                q[0..qn[ is appended to lq via qseq_add_last_mpn.
                No matrix is touched here; any accumulation is left to
                the caller. TODO: possible batch.
   REM: stands r = a, but not q = tp because of mpn_hgcd_matrix_update_q...!!!
   IDEA: try q = 1 beforehand (41% chance of success in the context).

   FIXME: check gmp for an equivalent function.

   NOTE(review): declared `inline` without `static`/`extern`; whether an
   external definition is emitted depends on the declaration in the
   header -- confirm.
 */
inline mp_size_t euclidean_step(qseq_t lq, mp_ptr q, mp_ptr r,
				mp_ptr a, mp_size_t an,
				mp_ptr b, mp_size_t bn,
				mp_ptr tp, mp_size_t tp_alloc)
{
    mp_size_t qn;

    /* (q, r) <- (a div b, a mod b) */
    qn = euclidean_div_rem(q, r, a, an, b, bn, tp, tp_alloc);
    /* remember the partial quotient */
    qseq_add_last_mpn(lq, q, qn);
    return qn;
}

/* In-place division step: a is replaced by a mod b.

   INPUT: q must have size >= an; tp/tp_alloc is scratch space.
   PRECONDITION: an >= bn >= rminn
   OUTPUT: limb count of the quotient stored in q.
   SIDE-EFFECT: (q, a) <- (a div b, a mod b);
                b[bn..an[ and a[bn..an[ are cleared;
                lq receives the quotient (through euclidean_step).
*/
static inline mp_size_t euclid_gcd_aux(qseq_t lq,
				       mp_ptr a, mp_size_t an,
				       mp_ptr b, mp_size_t bn,
				       mp_ptr q,
				       mp_ptr tp, mp_size_t tp_alloc)
{
    const mp_size_t gap = an - bn; /* limbs of a lying above b's length */
    mp_size_t quot_len;

    assert(an >= bn);
    /* euclidean_step wants b zero-padded up to an limbs */
    MPN_ZERO(b + bn, gap);
    /* remainder goes straight back into a */
    quot_len = euclidean_step(lq, q, a, a, an, b, bn, tp, tp_alloc);
    /* only a[0..bn[ was written: wipe the stale high part */
    MPN_ZERO(a + bn, gap);
#if DEBUG_EUCLID >= 1
    printf("r:="); MPN_PRINT(a, bn); printf(";\n");
    printf("printf \"CHECK: %%o\\n\", a-b*q-r;\n");
    printf("Q:=Matrix([[q, 1], [1, 0]]);\n");
#endif
    return quot_len;
}

/** Plain Euclidean algorithm on (a, b); every partial quotient is
    appended to R->lq and R is flushed/restored before returning.

    INPUT: a > b with an+1 words, 
           size(tp) >= an+1 (should be enough, not checked);
    PRECONDITION:
           R->Q == R->lq
	   if qgcd == 0:
	       a > b >= emin > 0
	   elif qgcd == 1:
	       ||R->Q|| <= emin or base^m
    SIDE-EFFECT: with n the value of an on entry, a[n] (resp. b[n])
                 receives the final number of digits of a (resp. b);
		 R is updated by accumulation.
    POSTCONDITION:
           R->Q == R->lq
           if qgcd == 0:
	       a >= emin[0..eminn[ > b; 
	       or ||a|| >= m > ||b||
	   elif qgcd == 1:
	       ||R->Q = Q_i|| <= emin or base^m < ||Q_{i+1}|| <= base^(m+1)
    MEMORY: the quotient buffer q and, when qgcd == 1, the roll-back
            copies (M, xim1) are private scratch; all of them are
            released on every path out of the function.

    NOTE(review): inside the loop only R->lq is visibly extended; no
    statement here updates R->Q before regular_flush_lq -- confirm that
    the qgcd == 1 stopping test sees the matrix it is meant to see.
*/
void euclid_gcd(regular_t R,
		mp_ptr a, mp_size_t an,
		mp_ptr b, mp_size_t bn,
		mp_size_t m, mp_ptr emin, mp_size_t eminn,
		mp_ptr tp, mp_size_t tp_alloc, int qgcd)
{
    struct hgcd_matrix M;
    mp_ptr xim1 = NULL, q = NULL;
    mp_size_t n = an, tmpn, xim1n = 0;
    int nsteps = 0; /* for checking PRECONDITION a posteriori */
    int Rfirst, Rfirst_large, det;

#if DEBUG_EUCLID >= 1
    int lqfirst = R->lq->first;

    /* restrict the final qseq_check to the quotients added by this call */
    R->lq->first = R->lq->last;
    if(m == 0){
	printf("emin:="); MPN_PRINT(emin, eminn); printf(";\n");
    }
    else
	printf("m:=%lu;\n", m);
    mp_size_t Aorign = an, Borign = bn;
    mp_ptr Aorig = (mp_ptr)malloc(an * sizeof(mp_limb_t));
    mp_ptr Borig = (mp_ptr)malloc(bn * sizeof(mp_limb_t));
    MPN_COPY(Aorig, a, an);
    MPN_COPY(Borig, b, bn);
#endif
    if(bn == 0){
	/* stop at once */
	b[n] = 0;
	a[n] = an;
#if DEBUG_EUCLID >= 1
	/* undo the debug bookkeeping on this early exit as well */
	R->lq->first = lqfirst;
	free(Aorig);
	free(Borig);
#endif
	return;
    }
    regular_save(&Rfirst, &Rfirst_large, &det, R);
    q = (mp_ptr)malloc((an+1) * sizeof(mp_limb_t));
    /* qgcd == 0: while a > b >= rmin 
       qgcd == 1: while ||Q|| <= emin or base^m
     */
    if(qgcd == 1){
	/* FIXME */
	/* scratch used to roll back the last step if it overshoots */
	hgcd_matrix_init(&M, R->Q->alloc);
	xim1 = (mp_ptr)malloc(an * sizeof(mp_limb_t));
    }
    ASSERT(R->det == qseq_determinant(R->lq));
    while(1){
	/* unrolling to keep pointers a and b safe */
	ASSERT(R->det == qseq_determinant(R->lq));
#if DEBUG_EUCLID >= 1
	printf("x%d:=", nsteps); MPN_PRINT(a, an); printf(";\n");
	printf("x%d:=", nsteps+1); MPN_PRINT(b, bn); printf(";\n");
	printf("Q%d:=", nsteps); hgcd_matrix_print(R->Q); printf(";\n");
#endif
#if DEBUG_EUCLID >= 2
	/* warning: cannot work with gcd_small if Rorig non empty */
	assert(regular_are_norms_correct(R));
	assert(regular_check_inequality(R, a, an, Aorig, Aorign));
#endif
	if(gcd_stopping_condition(b, bn, m, emin, eminn, 0, R, qgcd) != 0)
	    break;
	nsteps++;
	/* (tp, a) <- (a div b, a mod b) */
	if(qgcd == 1){
	    /* FIXME */
	    /* copy for future use */
	    hgcd_matrix_set(&M, R->Q);
	    MPN_COPY(xim1, a, an);
	    xim1n = an;
	}
	euclid_gcd_aux(R->lq, a, an, b, bn, q, tp, tp_alloc);
	MPN_NORMALIZE(a, an);
#if DEBUG_EUCLID >= 1
	printf("x%d:=", nsteps); MPN_PRINT(b, bn); printf(";\n");
	printf("x%d:=", nsteps+1); MPN_PRINT(a, an); printf(";\n");
	printf("Q%d:=", nsteps); hgcd_matrix_print(R->Q); printf("\n");
#endif
#if DEBUG_EUCLID >= 2
	/* warning: cannot work with gcd_small if Rorig non empty */
	assert(regular_are_norms_correct(R));
	assert(regular_check_inequality(R, b, bn, Aorig, Aorign));
#endif
	if(gcd_stopping_condition(a, an, m, emin, eminn, 0, R, qgcd) != 0){
	    /* we have an <= bn */
	    /* roles are swapped in this half-step: put the larger
	       value back into a before leaving */
	    MPN_SWAP(a, b, bn);
	    tmpn = an; an = bn; bn = tmpn;
	    break;
	}
	nsteps++;
	/* (tp, b) <- (b div a, b mod a) */
	if(qgcd == 1){
	    hgcd_matrix_set(&M, R->Q);
	    MPN_COPY(xim1, b, bn);
	    xim1n = bn;
	}
	euclid_gcd_aux(R->lq, b, bn, a, an, q, tp, tp_alloc);
	MPN_NORMALIZE(b, bn);
    }
    ASSERT(R->det == qseq_determinant(R->lq));
    /* we could have b = 0, in which case we have to stop */
    /* checking precondition a posteriori: really useful (when b = 0)? */
    /* if qgcd == 0: a >= rmin > b, postcondition met */
    if(qgcd == 1){
	if(is_norm_Q_gt(R, m, emin, eminn)){
	    /* we have ||Q|| > emin or base^m, we must roll back */
	    assert(nsteps > 0); /* we cannot roll back before the call */
#if DEBUG_EUCLID >= 1
	    printf("Qi_qeuc:="); hgcd_matrix_print(&M); printf(";\n");
	    printf("Qip1_qeuc:="); hgcd_matrix_print(R->Q); printf("\n");
#endif
	    hgcd_matrix_set(R->Q, &M);
	    // DET	R->det = -R->det;
	    if(qseq_is_used(R->lq) != 0){
		qseq_remove_last(R->lq);
		ASSERT(R->det == qseq_determinant(R->lq));
	    }
	    /* xim1 -> a -> b; make (a, b) <- (xim1, a) */
	    MPN_COPY(b, a, an);
	    bn = an;
	    MPN_COPY(a, xim1, xim1n);
	    an = xim1n;
	}
	/* the roll-back copies were allocated whenever qgcd == 1, so they
	   are released whether or not a roll-back took place (the loop
	   can also end on bn == 0 without overshooting the bound) */
	hgcd_matrix_clear(&M);
	free(xim1);
    }
    /* now flush: R[0..R->lq->large[ *= lq[Rfirst..R->lq->large[ */
    regular_flush_lq(R, R->lq->first);
    regular_restore(R, Rfirst, Rfirst_large, det);
#if DEBUG_EUCLID >= 100
    check_Q_from_lq(R->Q, R->lq);
#endif
    /* report the final lengths in the extra limb of each operand */
    a[n] = an;
    b[n] = bn;
    free(q);
#if DEBUG_EUCLID >= 1
    if(R->lq != NULL && qseq_is_empty(R->lq) == 0)
	qseq_check(R->lq, Aorig, Aorign, Borig, Borign, a, an, b, bn);
    R->lq->first = lqfirst;
    free(Aorig);
    free(Borig);
#endif
}

/* Undo the most recent Euclidean step recorded in R.

   INPUT: R->lq = <q1, ..., qk>; 
          tp/tp_alloc is scratch space; it must hold the new value of a
          (this is only asserted after the value has been computed).
   SIDE-EFFECT: (a, b) <- (qk * a + b, a), with *an and *bn updated to
                the new limb counts;
                (Q, lq) <- (Q / <qk>, lq minus <qk>), delegated to
                regular_oslash.

   FIXME: too many MPN_COPY.

*/
void euclid_step_back(regular_t R,
		      mp_ptr a, mp_size_t *an,
		      mp_ptr b, mp_size_t *bn,
		      mp_ptr tp, mp_size_t tp_alloc)
{
    mp_size_t tpn = *bn;
    
    /* start the accumulation from b */
    MPN_COPY(tp, b, tpn);
    /* tp <- b + a * q */
    mpn_qseq_addmul_last(tp, &tpn, a, *an, R->lq);
    assert(tpn <= tp_alloc);
    /* the old a becomes the new b ... */
    MPN_COPY(b, a, *an);
    *bn = *an;
    /* ... and the accumulated value becomes the new a */
    MPN_COPY(a, tp, tpn);
    *an = tpn;
    /* finally drop qk from R, reusing tp as scratch */
    regular_oslash(R, tp, tp_alloc);
}

/** Walk the quotient sequence held in R backwards, starting from (a, b).

    INPUT: R contains the partial quotients that should be applied to (a*, b*)
           ending at (a, b).
    OUTPUT: size of a* which is >= size of b*.
    SIDE-EFFECT: at each step (a, b) <- (q*a+b, a) and R is updated; both
        are delegated to regular_oslash3.
        The loop ends when R is empty or as soon as
        gcd_stopping_condition_ge(a, an, m, emin, eminn) holds.
        b[bn..an[ is cleared on exit.
    POSTCONDITION: 
        if qgcd == 0:
	    a >= rmin > b (so a == a* and b == b*) if rmin != NULL
	    or ||a|| >= m > ||b|| if m > 0
	elif qgcd == 1:
	    ||R->Q|| <= emin or base^m
    The number of steps taken must not exceed imax (asserted).

    FIXME: pass a buffer tp/tp_alloc?
    NOTE(review): the scratch buffer is sized from the value of an on
    entry while a grows at every step -- confirm regular_oslash3 never
    needs more than 2*an+3 limbs.
*/
mp_size_t euclid_gcd_backwards(regular_t R,
			       mp_ptr a, mp_size_t an,
			       mp_ptr b, mp_size_t bn,
			       mp_size_t m, mp_ptr emin, mp_size_t eminn,
			       int imax)
{
#if TIMINGS >= 1
    double tt = runtime();
#endif
    mp_size_t scratchn = 2*an+3; /* FIXME: overshooting? */
    mp_ptr scratch = (mp_ptr)malloc(scratchn * sizeof(mp_limb_t));
    int i = -1; /* index of the step in progress; -1 if none was taken */

    while(regular_is_empty(R) == 0){
	i++;
#if DEBUG_EUCLID >= 1
	printf("a:="); MPN_PRINT(a, an); 
	printf(";\nb:="); MPN_PRINT(b, bn);
	printf(";\n");
	if(emin != NULL){
	    printf("emin:="); MPN_PRINT(emin, eminn);
	    printf(";\n");
	}
	printf("gcd_backup: q%d:=", i);
	qseq_print_cell(R->lq, R->lq->last-1);
	printf(";\n");
#endif
	/* One backward step. The source used to carry a disabled inline
	   variant that rebuilt q*a+b in a per-step reallocated buffer and
	   then removed the last quotient; the helper below is the one
	   actually compiled. */
	regular_oslash3(R, a, &an, b, &bn, scratch, scratchn);
	/* we want a >= emin or ||a|| >= m */
	if(gcd_stopping_condition_ge(a, an, m, emin, eminn) != 0)
	    break;
    }
    assert(i < imax);
    /* be very cautious */
    MPN_ZERO(b+bn, an-bn);
#if TIMINGS >= 1
    fprintf(stderr, "#T# backup: %d %lf\n", i, runtime()-tt);
#endif
    /* FIXME: could it be that lq is empty at the end?? */
#if DEBUG_EUCLID >= 1
    printf("final values As:="); MPN_PRINT(a, an); 
    printf(";\nBs:="); MPN_PRINT(b, bn);
    printf(";\n");
    if(emin != NULL){
	printf("emin:="); MPN_PRINT(emin, eminn);
	printf(";\n");
    }
#endif
    free(scratch);
    return an;
}

