import numpy as np
import numpy.linalg as nl

def shrinkage(v, p):
    '''
    Elementwise soft-thresholding (shrinkage) operator Shrinkage_p(v).

    Computes max(v - p, 0) - max(-v - p, 0) entrywise, which for p >= 0 is
    sign(v) * max(|v| - p, 0).  `v` and `p` may be scalars or arrays; they
    are combined with NumPy broadcasting, so `p` can be a per-entry matrix.

    np.maximum is used instead of multiplying by a boolean mask: the mask
    form evaluates inf * 0 = nan whenever v or p is infinite, while this
    form returns the expected value (e.g. shrinkage(inf, 1) == inf and
    shrinkage(5, inf) == 0).  For finite inputs the result is unchanged.

    Parameters
    ----------
    v : scalar or ndarray
        Values to shrink.
    p : scalar or ndarray
        Threshold(s); intended to be > 0.

    Returns
    -------
    scalar or ndarray
        Shrunk values with the broadcast shape of `v` and `p`.
    '''
    return np.maximum(v - p, 0) - np.maximum(-v - p, 0)

def admmSSSC(X, Theta, lmbda, alpha, beta, rho, mu=1.0, abstol=1e-6,
             max_iter=1000, verbose=True):
    '''
    Use an ADMM approach to solve the SSSC-AMF optimization problem.

    Each view k keeps a coefficient matrix C[k], an auxiliary matrix J[k],
    an error matrix E[k] and a consensus matrix Z[k].  The constraints
    enforced through the Lagrange multipliers are:
        Y_1:  X[k] = X[k] C[k] + E[k]
        Y_2:  C[k] = J[k] - diag(J[k])
        Y_3:  J[k] = Z[k]
        Y_4:  C[k]^T 1 = 1
    The Z[k] of all views are coupled by a group shrinkage that acts on the
    vector (over views) of each matrix entry, with weight beta.

    Parameters
    ----------
    X : sequence of K ndarrays, X[k] of shape (n_k, N)
        Data of each view; all views share the same number N of columns.
    Theta : sequence of K ndarrays of shape (N, N)
        Per-view weights on the l1 penalty of J[k] (scaled by alpha).
    lmbda : sequence of K scalars (or arrays broadcastable to X[k].shape)
        Per-view weights of the l1 penalty on E[k].
    alpha : float
        Weight of the Theta-weighted term in the J[k] threshold.
    beta : float
        Weight of the cross-view group shrinkage on Z.
    rho : float
        Factor applied to the penalty parameter mu after every iteration.
    mu : float, optional
        Initial penalty parameter (default 1.0).
    abstol : float, optional
        Stop once ||X[k] - X[k] C[k] - E[k]||_inf < abstol for every view
        in the same iteration (default 1e-6).
    max_iter : int, optional
        Maximum number of ADMM iterations (default 1000).
    verbose : bool, optional
        If True (default), print the largest residual every iteration.

    Returns
    -------
    list of K ndarrays of shape (N, N)
        The consensus matrices Z[k].
    '''
    K = len(X)
    N = X[0].shape[1]  # number of data points (columns), shared by all views

    # Primal variables, one per view.
    J = [np.zeros((N, N)) for _ in range(K)]
    C = [np.zeros((N, N)) for _ in range(K)]
    Z = [np.zeros((N, N)) for _ in range(K)]
    E = [np.zeros((Xk.shape[0], N)) for Xk in X]
    # Lagrange multipliers, one per constraint and view.
    Y_1 = [np.zeros((Xk.shape[0], N)) for Xk in X]
    Y_2 = [np.zeros((N, N)) for _ in range(K)]
    Y_3 = [np.zeros((N, N)) for _ in range(K)]
    Y_4 = [np.zeros((N, 1)) for _ in range(K)]
    # Row k holds the flattened J[k] + Y_3[k]/mu for the group shrinkage.
    QQ = np.zeros((K, N * N))

    # Loop-invariant quantities.
    one = np.ones((N, 1))
    ones = np.ones((N, N))
    off_diag = ~np.eye(N, dtype=bool)
    # The C-update system matrix X^T X + I + 1 1^T does not depend on the
    # iteration, so it is inverted once per view.
    P_m = [nl.inv(np.dot(Xk.T, Xk) + np.eye(N) + ones) for Xk in X]

    converged = np.zeros((K, ), dtype=bool)
    count = 0
    while count < max_iter and not converged.all():
        inv_mu = 1. / mu

        for k in range(K):
            # Update J: the off-diagonal entries average the C/Y_2 and Z/Y_3
            # targets; the diagonal is updated from the Z/Y_3 terms only.
            U = C[k] + inv_mu * Y_2[k]
            J_off = 0.5 * shrinkage(U + Z[k] - inv_mu * Y_3[k],
                                    inv_mu * (1. + alpha * Theta[k]))
            J_diag = shrinkage(np.diag(Z[k]) - inv_mu * np.diag(Y_3[k]),
                               inv_mu * (1. + alpha * np.diag(Theta[k])))
            J[k] = J_off * off_diag + np.diag(J_diag)

            # Update C by solving the linear system via the cached inverse.
            rhs = (np.dot(X[k].T, X[k] - E[k] + inv_mu * Y_1[k])
                   + J[k] - np.diag(np.diag(J[k])) - inv_mu * Y_2[k]
                   - inv_mu * np.dot(one, Y_4[k].T) + ones)
            C[k] = np.dot(P_m[k], rhs)

            # Update E with elementwise soft-thresholding.
            V = X[k] - np.dot(X[k], C[k]) + inv_mu * Y_1[k]
            E[k] = shrinkage(V, lmbda[k] / mu)

        # Update Z jointly across views: shrink the l2 norm (over views)
        # of every entry of J[k] + Y_3[k]/mu by beta/mu.
        coef = beta / mu
        for k in range(K):
            QQ[k, :] = np.ravel(J[k] + Y_3[k] / mu)
        QQ_norm = nl.norm(QQ, ord=2, axis=0)
        QQ_norm_ = QQ_norm - coef
        # Columns with zero norm give 0/0; nan_to_num maps those NaNs to 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            ZZ = (QQ_norm_ > 0) * QQ_norm_ / QQ_norm * QQ
        ZZ = np.nan_to_num(ZZ)
        for k in range(K):
            Z[k] = np.reshape(ZZ[k, :], (N, N))

        # Dual ascent on Y_1, Y_2, Y_3 and Y_4.
        for k in range(K):
            Y_1[k] += mu * (X[k] - np.dot(X[k], C[k]) - E[k])
            Y_2[k] += mu * (C[k] - J[k] + np.diag(np.diag(J[k])))
            Y_3[k] += mu * (J[k] - Z[k])
            Y_4[k] += mu * (np.dot(C[k].T, one) - one)

        mu *= rho
        count += 1

        # Re-evaluate convergence of every view each iteration so that a
        # view whose residual has grown again is not counted as converged.
        errs = [nl.norm(X[k] - np.dot(X[k], C[k]) - E[k], ord=np.inf)
                for k in range(K)]
        converged = np.array(errs) < abstol
        if verbose:
            print('err1: ' + str(max(errs)) + ', iter: ' + str(count) + '\n')
    return Z
    