import copy
import numpy as np
import numpy.linalg as nl
import admmSSSC
from spl_clustering import spl_clustering
from mapLabels import mapLabels

def SSSC(X, alpha, beta, rho, n_clusters, maxInter=3):
    """Run the iterative SSSC loop: alternate ``admmSSSC`` and clustering.

    Each outer iteration solves for ``Z`` with ``admmSSSC.admmSSSC`` given the
    current per-view matrices ``Theta``, clusters every ``Z[k]`` with
    ``spl_clustering``, and rebuilds ``Theta[k]`` so that ``Theta[k][i, j]``
    is 1 when samples i and j received different labels (0 otherwise).  The
    loop stops as soon as no view's ``Theta`` changed during the latest
    iteration, or after ``maxInter`` iterations.

    Args:
        X: Sequence of K 2-D arrays, one per view.  The sample count N is
            read from the number of columns of ``X[0]`` and is assumed to be
            the same for every view.
        alpha, beta, rho: Forwarded unchanged to ``admmSSSC.admmSSSC``
            (their meaning is defined there).
        n_clusters: Number of clusters passed to ``spl_clustering``.
        maxInter: Maximum number of outer iterations; must be positive.

    Returns:
        The ``Z`` returned by the last ``admmSSSC.admmSSSC`` call.

    Raises:
        ValueError: If ``X`` is empty or ``maxInter`` is not positive (the
            loop would never run and there would be no ``Z`` to return).
    """
    if len(X) == 0:
        raise ValueError("X must contain at least one view")
    if maxInter <= 0:
        raise ValueError("maxInter must be positive")

    N = X[0].shape[1]
    K = len(X)
    alpha_e = 100  # scale factor for each lmbda[k]; should be greater than 1

    # Per-view balancing parameters lmbda[k] and initial (all-zero) Theta[k].
    lmbda = np.zeros((K,))
    Theta = []
    for k in range(K):
        # Every row of T holds the column sums of |X[k]|; the diagonal is
        # then zeroed so that a column is not compared with itself.
        T = np.dot(np.ones((N, X[k].shape[0])), np.abs(X[k]))
        T = T - np.diag(np.diag(T))
        mu_e = np.min(np.max(T, axis=1), axis=0)
        # NOTE(review): SSC-style code usually derives this value from the
        # off-diagonal entries of |X^T X| instead of column sums of |X| —
        # confirm which form is intended.
        if mu_e <= 0:
            # Degenerate view (all-zero data, or N == 1 where T is a single
            # zeroed entry): avoid dividing by zero, which would make
            # lmbda[k] infinite.
            mu_e = 0.1
        lmbda[k] = alpha_e * 1. / mu_e
        Theta.append(np.zeros((N, N)))

    converged = np.zeros((K,), dtype=bool)
    count = 0
    # K >= 1, so `converged.all()` is False initially and the loop runs at
    # least once, guaranteeing that Z is assigned.
    while count < maxInter and not converged.all():
        Theta_old = copy.deepcopy(Theta)
        Z = admmSSSC.admmSSSC(X, Theta, lmbda, alpha, beta, rho)
        for k in range(K):
            labels = np.reshape(spl_clustering(Z[k], n_clusters), (N, 1))
            # 1 where the pair of samples was assigned to different clusters.
            Theta[k] = (labels != labels.T) * 1
            # Entries of the difference lie in {-1, 0, 1}, so an inf-norm
            # below 1 means Theta[k] is unchanged.  The flag is recomputed on
            # every iteration so that a view which changes again is no longer
            # treated as converged.
            difference = nl.norm(Theta[k] - Theta_old[k], ord=np.inf)
            converged[k] = difference < 1
        count += 1
    return Z
