#!/usr/bin/env python3

## FILL IN WHERE IT SAYS "MISSING CODE"

## compressed sensing on costly communication channels

import sys
import os
import math
import numpy as np
import cvxpy as cp
from math import sqrt
import time

################### configurable params #####################

# Numerical-zero threshold. Presumably magnitudes at or below this are meant
# to count as 0 (e.g. when extracting a support) -- confirm against its users.
myZero = 1.0e-9

######################### functions #########################

def normalmatrix(m, n):
    """Return an m-by-n array whose entries are i.i.d. standard normal N(0,1)."""
    shape = (m, n)
    return np.random.normal(loc=0.0, scale=1.0, size=shape)

def uniformmatrix01(m, n):
    """Return an m-by-n array whose entries are i.i.d. Uniform[0, 1)."""
    return np.random.random_sample(size=(m, n))

def uniformmatrix(m, n):
    """Return an m-by-n array whose entries are i.i.d. Uniform[-1, 1)."""
    unit = np.random.random_sample(size=(m, n))  # Uniform[0, 1)
    # affine map [0, 1) -> [-1, 1)
    return unit * 2 - 1

def AAIeval(AA, m):
    """Print and return how far the square matrix AA is from the identity.

    Parameters
    ----------
    AA : ndarray, shape (nA, nA)
        Matrix to compare with the identity. In this script it is
        ``As.T @ As`` with ``As = A / sqrt(m)``.
    m : int
        Sample size; only used to print the 1-norm errors scaled by
        ``1 / sqrt(m)``.

    Returns
    -------
    tuple of float
        ``(ell1err, ell1errdiag, ell2err, ell2errdiag)`` where
        ell1err     = sum_ij |AA - I|_ij / nA**2  (mean absolute entry error),
        ell1errdiag = sum_i |AA_ii - 1| / nA      (mean absolute diagonal error),
        ell2err     = ||AA - I||_F / nA**2,
        ell2errdiag = ||diag(AA) - 1||_2 / nA.
    """
    ## check orthogonality
    print("checking orthogonality")
    nA = AA.shape[0]
    dev = AA - np.eye(nA)      # deviation of the whole matrix from I
    ddev = np.diag(AA) - 1.0   # deviation of the diagonal from all-ones
    # Entrywise l1 norm: np.linalg.norm(M, 1) on a 2-D array is the induced
    # norm (max absolute column sum), which is not an average over the
    # nA**2 entries, so sum the absolute values explicitly.
    ell1err = np.abs(dev).sum() / nA**2
    ell2err = np.linalg.norm(dev) / nA**2  # Frobenius norm for 2-D input
    ell1errdiag = np.linalg.norm(ddev, 1) / nA
    ell2errdiag = np.linalg.norm(ddev) / nA
    print("  avg ||AA - I||_1 = ", ell1err)
    print("  avg ||AA - I||_2 = ", ell2err)
    print("  avg ||diag(AA) - diag(I)||_1 =", ell1errdiag)
    print("  avg ||diag(AA) - diag(I)||_2 = ", ell2errdiag)
    print("  ell1err/sqrt(m)={0:.4g}, ell2err={1:.4g}".format(ell1err/sqrt(m), ell2err))
    print("  ell1errdiag/sqrt(m)={0:.4g}, ell2errdiag={1:.4g}".format(ell1errdiag / sqrt(m), ell2errdiag))
    return (ell1err, ell1errdiag, ell2err, ell2errdiag)

def supp(x, tol=None):
    """Return the support of x: sorted indices of its nonzero components.

    Parameters
    ----------
    x : array_like
        Signal vector. It is flattened, so an (n, 1) column yields the same
        indices as a length-n vector.
    tol : float, optional
        A component counts as nonzero only when its magnitude is strictly
        greater than ``tol``. Defaults to the module-level ``myZero``. A
        threshold is used instead of ``!= 0`` because solver output may
        contain tiny floating-point residues.

    Returns
    -------
    list of int
        Indices i with |x_i| > tol, in increasing order.
    """
    if tol is None:
        tol = myZero
    magnitudes = np.abs(np.asarray(x, dtype=float)).ravel()
    return [int(i) for i in np.flatnonzero(magnitudes > tol)]

####################### MAIN #######################

## read command line
if len(sys.argv) < 5:
    print("syntax: " + sys.argv[0] + " m n dens distrib")
    print("  m = sample size")
    print("  n = signal size")
    print("  dens = signal density")
    print("  distrib = encoder distribution (uniform, normal)")
    # sys.exit rather than the bare exit(), which only exists when the
    # site module is loaded
    sys.exit(1)

t0 = time.time()

## parse command line arguments
m = int(sys.argv[1])
n = int(sys.argv[2])
dens = float(sys.argv[3])
distrib = sys.argv[4]
print("read m =", m, "n =", n, "density =", dens, distrib, "distribution")
if m <= 0 or n <= 0 or not 0.0 <= dens <= 1.0:
    print("error: need m > 0, n > 0 and 0 <= dens <= 1")
    sys.exit(1)

## generate LP data
print("making instance data")
# generate xhat with support size round(dens*n): support_size distinct
# random indices receive N(0,1) values, every other component is zero
support_size = int(round(dens * n))
xhat = np.zeros(n)
xhat[np.random.choice(n, support_size, replace=False)] = np.random.normal(size=support_size)
S = supp(xhat)  # support of the planted signal, reported at the end
print("  generated signal (xhat) with support size", support_size)

# generate A
if distrib == "uniform":
    ubnd = -2.0*m / 3.0
    # NOTE(review): scaling A by a constant also scales b = A xhat, so the
    # basis pursuit LP below has the same solutions; ubnd only changes the
    # AAIeval statistics. Confirm the intended value of ubnd.
    A = ubnd * uniformmatrix(m, n)
    print("  generated encoding matrix (A) from", ubnd, "* U(-1,1)")
else:
    A = normalmatrix(m, n)
    print("  generated encoding matrix (A) from N(0,1)")

# compute b: the sample is the encoding of the planted signal
b = A @ xhat
print("  computed sample (b) with ||b||_2 =", np.linalg.norm(b))

## evaluate A
As = A / math.sqrt(m)
AA = As.T @ As
# evaluate orthogonality
errs = AAIeval(AA, m)

## formulate and solve LP
print("solving basis pursuit LP")
x = cp.Variable(n)
obj = cp.Minimize(cp.norm(x, 1))
constrs = [A @ x == b]  # '@': cvxpy deprecates '*' for matrix products
bp = cp.Problem(obj, constrs)
bp.solve(solver=cp.GLPK)
# x.value is None when the solver fails; report instead of crashing below
if bp.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE) or x.value is None:
    print("  LP solver failed with status", bp.status)
    sys.exit(1)
objfunval = bp.value
# ravel so an (n,1) result cannot broadcast against xhat into an (n,n) array
xstar = np.array(x.value).ravel()
T = supp(xstar)
decoderr1 = np.linalg.norm(xstar - xhat, 1)
decoderr2 = np.linalg.norm(xstar - xhat)
print("  ||xhat||_1={0:.4g}, optobjfunval={1:.4g}".format(np.linalg.norm(xhat, 1), objfunval))
print("  |supp(xhat)|={0:d}, |supp(xstar)|={1:d}".format(len(S), len(T)))
print("  ||xstar - xhat||_1={0:.4g}, ||xstar - xhat||_2={1:.4g}".format(decoderr1, decoderr2))
print("  elapsed time {0:.3g}s".format(time.time() - t0))