/*! \file
    \brief Utilities for gcd algorithms
*/

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>

#include <assert.h>

#include "gmp.h"
#include "gmp-impl.h"

#include "gcd_utils.h"

#define DEBUG_UTILS 0

#if COUNT_QEQ1 == 1
/* Statistics updated by euclidean_div_rem() when COUNT_QEQ1 is 1:
   nbqeq1     -- calls where the quotient was found to be 1 by subtraction;
   nbdeltagt1 -- calls with an-bn > 1;
   nbdeltaeq1 -- calls with an-bn == 1;
   nbdeltaeq0 -- remaining calls (an-bn == 0);
   nbxor      -- calls with an-bn == 1 that nevertheless took the
                 "try q = 1 first" path. */
long nbqeq1 = 0, nbdeltagt1 = 0, nbdeltaeq1 = 0, nbdeltaeq0 = 0, nbxor = 0;
#endif
 
#if defined(hp) || defined(i486)
#include <sys/time.h>
#include <sys/times.h>
#include <sys/param.h>
#else
#include <sys/time.h>
#include <sys/resource.h>
#endif

/** Return the user CPU time consumed so far, in milliseconds.
    On hp/i486 builds the value comes from times(), elsewhere from
    getrusage(). */
double runtime(void) 
{
#if defined(hp) || defined(i486)
        struct tms cpu;

        times(&cpu);
        /* clock ticks -> milliseconds */
        return(cpu.tms_utime*1000./HZ);
#else
	struct rusage usage;
	double millis;

	/* who = 0: presumably RUSAGE_SELF on the targeted platforms */
	getrusage(0, &usage);
	/* whole seconds first, then the microsecond part */
	millis = 1000*usage.ru_utime.tv_sec;
	millis += usage.ru_utime.tv_usec/1000.;
	return(millis);
#endif
}

/** Print the natural number {a, an} in decimal on stdout.

    Anti-clobbering (slow) code: mpn_get_str() may destroy its input when
    the base is not a power of 2, so the conversion is run on a private
    copy of the limbs.

    Size of the digit buffer: a < B^an with B = 2^GMP_NUMB_BITS, hence a
    has at most an*GMP_NUMB_BITS*log_10(2) + 1 decimal digits.  The code
    uses the natural logarithm log(2) > log_10(2), which over-estimates
    this bound and leaves room for the extra character mpn_get_str() asks
    for.

    If a buffer cannot be allocated, a message is written to stderr and
    nothing is printed on stdout.
*/
void MPN_PRINT(mp_ptr a, mp_size_t an)
{
    unsigned char small[10000];
    unsigned char *digits = small, *heap = NULL;
    mp_ptr tmp;
    size_t ndig = (size_t)(1 + an * log(2.0) * GMP_NUMB_BITS);
    size_t j, nc;

    /* never request 0 bytes: malloc(0) may legitimately return NULL */
    tmp = (mp_ptr)malloc((an > 0 ? (size_t)an : 1) * sizeof(mp_limb_t));
    if(tmp == NULL){
	fprintf(stderr, "MPN_PRINT: out of memory\n");
	return;
    }
    MPN_COPY(tmp, a, an);
    /* small numbers are converted on the stack, large ones on the heap */
    if(ndig > sizeof(small)){
	heap = (unsigned char*)malloc(ndig * sizeof(unsigned char));
	if(heap == NULL){
	    fprintf(stderr, "MPN_PRINT: out of memory\n");
	    free(tmp);
	    return;
	}
	digits = heap;
    }
    nc = mpn_get_str(digits, 10, tmp, an);
    assert(nc <= ndig);
    if(nc == 0)
	printf("0");
    else
	/* mpn_get_str yields digit values, not ASCII characters */
	for(j = 0; j < nc; j++)
	    printf("%c", '0'+digits[j]);
    free(heap);
    free(tmp);
}

/** Print {a, an} in binary on stdout, from a[an-1] down to a[0].
    Every limb is written as exactly GMP_NUMB_BITS binary digits (padded
    with leading zeros) followed by ':'.
*/
MAYBE_UNUSED
void MPN_PRINT2(mp_ptr a, mp_size_t an)
{
    /* mpn_get_str needs room for the largest possible value plus one
       extra character */
    unsigned char buf[GMP_NUMB_BITS + 1];
    mp_size_t i;
    size_t j, nc;

    for(i = an-1; i >= 0; i--){
	/* mpn_get_str requires a non-zero most significant limb, so a
	   zero limb is handled here: it has no significant digit.
	   No copy of the limb is made: per the GMP manual the operand is
	   left unchanged when the base is a power of 2. */
	if(a[i] == 0)
	    nc = 0;
	else
	    nc = mpn_get_str(buf, 2, a+i, 1);
	/* left padding up to a full limb */
	for(j = nc; j < (size_t)GMP_NUMB_BITS; j++)
	    printf("0");
	for(j = 0; j < nc; j++)
	    printf("%c", '0'+buf[j]);
	printf(":");
    }
}

/** Print the limbs of {a, an} on stdout as a bracketed list of unsigned
    decimal values in index order: "[a[0], a[1], ..., a[an-1]]".
    An empty operand (an == 0) is printed as "[]".
*/
void MPN_PRINT_BASE(mp_ptr a, mp_size_t an)
{
    mp_size_t i;

    printf("[");
    for(i = 0; i < an; i++){
	/* mp_limb_t is not unsigned long on every platform: widen it
	   explicitly so that the conversion specifier always matches */
	printf("%llu", (unsigned long long)a[i]);
	if(i < an-1)
	    printf(", ");
    }
    printf("]");
}

/********** borrowed from CADO: begin **********/

/** Set DST to the non-negative integer stored in the limbs {SRC, NLIMBS}.
    SRC does not need to be normalized: high zero limbs are stripped from
    the size of DST after the copy.
*/
void MPZ_SET_MPN(mpz_ptr DST, const mp_limb_t * SRC, size_t NLIMBS)
{
    /* Grow DST through GMP itself rather than with a raw realloc():
       the limbs belong to GMP's allocator (which can be replaced by the
       application), and the limb pointer is not necessarily heap memory
       while ALLOC(DST) is 0. */
    if ((size_t) ALLOC(DST) < NLIMBS) {
        _mpz_realloc(DST, (mp_size_t) NLIMBS);
    }
    if (NLIMBS > 0) {
        memcpy(PTR(DST), SRC, NLIMBS * sizeof(mp_limb_t));
    }
    SIZ(DST) = (int) NLIMBS;
    /* drop high zero limbs so that DST is a valid (normalized) mpz */
    MPN_NORMALIZE(PTR(DST), SIZ(DST));
}

/** Store |SRC| into the NLIMBS limbs at DST.
    The low min(|SIZ(SRC)|, NLIMBS) limbs of SRC are copied and the rest
    of DST is zero-filled; the sign of SRC is ignored and a value that
    does not fit is truncated to its low NLIMBS limbs.
*/
void MPN_SET_MPZ(mp_limb_t * DST, size_t NLIMBS, mpz_srcptr SRC)
{
    size_t r = MIN((size_t) ABS(SIZ(SRC)), NLIMBS);

    if (r > 0) {
        memcpy(DST, PTR(SRC), r * sizeof(mp_limb_t));
    }
    /* Pad from DST + r, not DST + |SIZ(SRC)|: the latter points outside
       of DST as soon as SRC has more than NLIMBS limbs. */
    if (NLIMBS > r) {
        memset(DST + r, 0, (NLIMBS - r) * sizeof(mp_limb_t));
    }
}

/********** borrowed from CADO: end **********/

/** Exchange the contents of {a, size} and {b, size}, one limb at a time. */
void MPN_SWAP(mp_ptr a, mp_ptr b, mp_size_t size)
{
    mp_size_t k = 0;

    while(k < size){
	const mp_limb_t saved = a[k];

	a[k] = b[k];
	b[k] = saved;
	k++;
    }
}

/* x <- x + y, in place.
   {x, *xn} is updated and *xn receives the new limb count.  x must have
   room for one limb more than the longer operand, since a final carry is
   stored at x[*xn] (resp. x[yn]). */
void gcd_incr(mp_ptr x, mp_size_t *xn, mp_ptr y, mp_size_t yn)
{
    mp_size_t len = *xn;

    if(len >= yn){
	/* x is the longer operand: add y into it directly */
	if(mpn_add(x, x, len, y, yn) == 1){
	    x[len] = 1;
	    len++;
	}
	else{
	    MPN_NORMALIZE(x, len);
	}
	*xn = len;
	return;
    }
    if(len == 0){
	/* x = 0: the sum is just a copy of y */
	MPN_COPY(x, y, yn);
	*xn = yn;
	return;
    }
    /* x is shorter than y: zero-extend x to yn limbs before adding */
    MPN_ZERO(x+len, yn-len);
    if(mpn_add(x, x, yn, y, yn) == 1){
	x[yn] = 1;
	yn++;
    }
    *xn = yn;
}

/* tmp <- y * z, dispatching on the size of z.
   INPUT: yn >= zn; tmp has tmpn limbs (tmpn > yn is needed when zn == 1,
   because tmp[yn] is written).
   zn == 0: only the low yn limbs of tmp are cleared.
   zn == 1: a plain copy when z[0] == 1, mpn_mul_1 otherwise.
   else   : general mpn_mul. */
static inline void mpn_mul_aux(mp_ptr tmp, mp_size_t tmpn,
			       mp_ptr y, mp_size_t yn,
			       mp_ptr z, mp_size_t zn)
{
    if(zn == 0){
	MPN_ZERO(tmp, yn);
	return;
    }
    if(zn != 1){
	mpn_mul(tmp, y, yn, z, zn);
	return;
    }
    /* single-limb multiplier: clear the limbs above y first */
    MPN_ZERO(tmp+yn, tmpn-yn);
    if(z[0] != 1){
	tmp[yn] = mpn_mul_1(tmp, y, yn, z[0]);
    }
    else{
	MPN_COPY(tmp, y, yn);
    }
}

/** x += y * z.
    INPUT: x must have room for more than max(*xn, yn+zn) limbs, since
    the final carry of the addition is stored one limb above the longer
    operand.
    OUTPUT: {x, *xn} holds the sum and *xn its new limb count.
    The product is built in a temporary heap buffer of yn+zn limbs; the
    program is aborted if that buffer cannot be allocated.
    TODO: z = 1, z small.
*/
void gcd_addmul(mp_ptr x, mp_size_t *xn,
		mp_ptr y, mp_size_t yn,
		mp_ptr z, mp_size_t zn)
{
    mp_ptr tmp;
    mp_size_t tmpn = yn+zn;

#if DEBUG_UTILS >= 2
    printf("[addmul]\nx:="); MPN_PRINT(x, *xn);
    printf(";\ny:="); MPN_PRINT(y, yn);
    printf(";\nz:="); MPN_PRINT(z, zn);
    printf(";\n");
#endif
    if(yn == 0 || zn == 0){
	/* y*z = 0: nothing to add.  This also avoids touching tmp[-1]
	   in a 0-limb buffer when both operands are empty.  x is still
	   normalized, as adding an empty operand does in gcd_incr. */
#if DEBUG_UTILS >= 2
	printf("yz:=0;\n");
#endif
	MPN_NORMALIZE(x, *xn);
	return;
    }
    /* tmp <- y*z */
    tmp = (mp_ptr)malloc(tmpn * sizeof(mp_limb_t));
    if(tmp == NULL){
	fprintf(stderr, "gcd_addmul: cannot allocate %lu limbs\n",
		(unsigned long)tmpn);
	abort();
    }
    /* mpn_mul_aux wants the longer operand first */
    if(zn >= yn)
	mpn_mul_aux(tmp, tmpn, z, zn, y, yn);
    else
	mpn_mul_aux(tmp, tmpn, y, yn, z, zn);
    /* the top limb of the product may be zero */
    MPN_NORMALIZE(tmp, tmpn);
#if DEBUG_UTILS >= 2
    printf("yz:="); MPN_PRINT(tmp, tmpn); printf(";\n");
#endif
    gcd_incr(x, xn, tmp, tmpn);
    free(tmp);
}

/* (q, r) <- (a div b, a mod b).
   INPUT: {a, an} and {b, bn} with an >= bn >= 1.
          NOTE(review): the code also appears to assume a >= b and a
          non-zero top limb in both operands -- confirm against callers.
          q must have room for an-bn+2 limbs (the fast path stores a carry
          at q[qn]); r must have room for an limbs.
          tp and tp_alloc are not used.
   OUTPUT: qn, the number of limbs of the quotient stored in q.

   When the operands have the same or almost the same size, the quotient
   is very often 1: in that case a-b is computed first, and a real
   division is only performed (on a-b) if the difference is still >= b.
*/
mp_size_t euclidean_div_rem(mp_ptr q, mp_ptr r, mp_ptr a, mp_size_t an,
			    mp_ptr b, mp_size_t bn,
			    mp_ptr tp, mp_size_t tp_alloc)
{
    mp_size_t qn, rn;
    mp_size_t delta = an-bn;
    int cond;

    /* scratch space is currently not needed */
    (void)tp;
    (void)tp_alloc;

    assert(bn >= 1 && delta >= 0);

    /* few cases with an=bn+1: 
       a[an-1]=1 and b[an-2] has msb quite large (63)? */
#if 1
    if(delta > 1){
	cond = 0;
#if COUNT_QEQ1 == 1
	nbdeltagt1++;
#endif
    }
    else if(delta == 1){
	/* a has one more limb than b: the quotient can only be small when
	   the top limb of a is 1 and the top bit of b is set */
	mp_limb_t mask = ((mp_limb_t)1) << (mp_bits_per_limb-1);
	cond = (a[an-1] == (mp_limb_t)1) && (b[bn-1] & mask);
#if COUNT_QEQ1 == 1
	nbdeltaeq1++;
	if(cond) nbxor++;
#endif
    }
    else /* delta == 0 */{
	cond = (a[an-1] ^ b[an-1]) < b[an-1]; /* if yes, msb ranks are equal */
#if COUNT_QEQ1 == 1
	nbdeltaeq0++;
#endif
    }
#else
    cond = (delta == 0);
#endif
#if DEBUG_UTILS >= 1
    printf("a:="); MPN_PRINT(a, an);
    printf(";\nb:="); MPN_PRINT(b, bn);
    printf(";\n");
#endif
    if(cond != 0){
	/* here delta is 0 or 1 */
	qn = delta+1;
	/* perhaps q = 1: subtract and see what happens.
	   mpn_sub only reads the bn limbs b really has, whereas
	   mpn_sub_n over an limbs would read b[bn] when an = bn+1. */
	mpn_sub(r, a, an, b, bn);
	/* r < b? r has an limbs, b only bn: the extra top limb of r has
	   to be zero before the low bn limbs can be compared */
	if((delta == 0 || r[an-1] == 0) && mpn_cmp(r, b, bn) < 0){
	    /* newa=a-b < b => q=1 */
#if COUNT_QEQ1 == 1
	    nbqeq1++;
#endif
	    q[0] = 1;
	    if(qn == 2){
		q[1] = 0;
		qn = 1;
	    }
#if DEBUG_UTILS >= 1
	    printf("q:=1;\n");
#endif
	}
	else{
	    /* a-b >= b: divide a-b by b and add 1 to the quotient */
	    rn = an;
	    MPN_NORMALIZE(r, rn);
	    qn = rn-bn+1; /* <= 2 */
	    /* the numerator may overlap the remainder in mpn_tdiv_qr */
	    mpn_tdiv_qr(q, r, 0, r, rn, b, bn);
	    if(qn == 2 && q[qn-1] == 0)
		qn--;
	    /* q <- q+1 over all qn limbs, so that a carry out of q[0]
	       reaches q[1] when qn = 2; the final carry goes to q[qn] */
	    q[qn] = mpn_add_1(q, q, qn, 1);
	    if(q[qn] != 0)
		qn++;
#if DEBUG_UTILS >= 1
	    printf("q:="); MPN_PRINT(q, qn); printf(";\n");
#endif
	}
    }
    else{
	/* general case: plain division */
	mpn_tdiv_qr(q, r, 0, a, an, b, bn);
	qn = an-bn+1;
	MPN_NORMALIZE(q, qn);
	assert(qn != 0);
#if DEBUG_UTILS >= 1
	printf("q:="); MPN_PRINT(q, qn); printf(";\n");
#endif
    }
    return qn;
}

