package cern.colt.matrix.tfloat.algo.solver;

import cern.colt.list.tint.IntArrayList;
import cern.colt.matrix.tfloat.FloatMatrix1D;
import cern.colt.matrix.tfloat.FloatMatrix2D;
import cern.colt.matrix.tfloat.algo.DenseFloatAlgebra;
import cern.colt.matrix.tfloat.algo.solver.preconditioner.FloatIdentity;
import cern.jet.math.tfloat.FloatFunctions;

/* loaded from: input_file:parallelcolt-0.9.4.jar:cern/colt/matrix/tfloat/algo/solver/FloatMRNSD.class */
/**
 * Iterative solver based on MRNSD (modified residual norm steepest descent) for the
 * nonnegatively constrained least-squares problem {@code min ||M (b - A x)||} subject to
 * {@code x >= 0}, where {@code M} is the solver's preconditioner (the identity by default).
 * <p>
 * Each iteration moves along the scaled steepest-descent direction {@code d = -x .* g}, with
 * {@code g = -A^T M^T M (b - A x)}. The step length is the exact line-search minimiser
 * along {@code d}, capped so that no component of {@code x} becomes negative.
 */
public class FloatMRNSD extends AbstractFloatIterativeSolver {
    private static final DenseFloatAlgebra ALGEBRA = DenseFloatAlgebra.DEFAULT;

    /**
     * Small positive constant equal to {@code sqrt(2^-52)} (about 1.49e-8).
     * <p>
     * It is used as the offset that makes an initial guess with negative entries strictly
     * positive. It is also the factor applied to {@code ||A^T b||} to derive the default
     * relative tolerance.
     * <p>
     * NOTE(review): 2^-52 is the double-precision machine epsilon. Single precision has
     * epsilon 2^-23, so a tolerance derived from this value may be tighter than float
     * arithmetic can reach. Confirm this is intended.
     */
    public static final float sqrteps = (float) Math.sqrt(Math.pow(2.0d, -52.0d));

    /** Relative-tolerance sentinel meaning "derive the tolerance from the right-hand side". */
    private static final float DERIVE_TOLERANCE = -1.0f;

    /** Creates a solver whose monitor derives its relative tolerance on the first solve. */
    public FloatMRNSD() {
        this.iter = newDefaultMonitor();
    }

    /** Builds an MRNSD monitor whose relative tolerance is the "derive it" sentinel. */
    private static MRNSDFloatIterationMonitor newDefaultMonitor() {
        MRNSDFloatIterationMonitor monitor = new MRNSDFloatIterationMonitor();
        monitor.setRelativeTolerance(DERIVE_TOLERANCE);
        return monitor;
    }

    /**
     * Solves {@code A x = b} in the (preconditioned) least-squares sense subject to
     * {@code x >= 0}.
     * <p>
     * {@code x} is the initial guess and is updated in place; the same instance is returned.
     * If the guess has a negative entry, every entry is first shifted by the same constant so
     * that the minimum becomes (approximately) {@link #sqrteps}.
     * <p>
     * If the monitor is not an MRNSD monitor, it is replaced by a fresh one. If its relative
     * tolerance is still the sentinel {@code -1}, it is set to {@code sqrteps * ||A^T b||}.
     * That derived value stays on the monitor and is reused by later calls.
     *
     * @param A system matrix
     * @param b right-hand side
     * @param x initial guess; overwritten with the computed solution
     * @return {@code x}
     * @throws IterativeSolverFloatNotConvergedException if the iteration monitor (or
     *         preconditioner) signals failure
     */
    @Override
    public FloatMatrix1D solve(FloatMatrix2D A, FloatMatrix1D b, FloatMatrix1D x)
            throws IterativeSolverFloatNotConvergedException {
        if (!(this.iter instanceof MRNSDFloatIterationMonitor)) {
            this.iter = newDefaultMonitor();
        }
        MRNSDFloatIterationMonitor monitor = (MRNSDFloatIterationMonitor) this.iter;

        // The direction d = -x .* g leaves a zero component unchanged, so negative entries in
        // the guess are lifted to strictly positive values before iterating.
        float minX = x.getMinLocation()[0];
        if (minX < 0.0f) {
            x.assign(FloatFunctions.plus(-minX + sqrteps));
        }

        if (monitor.getRelativeTolerance() == DERIVE_TOLERANCE) {
            monitor.setRelativeTolerance(sqrteps * ALGEBRA.norm2(A.zMult(b, null, 1.0f, 0.0f, true)));
        }

        boolean preconditioned = !(this.M instanceof FloatIdentity);

        // r = b - A x
        FloatMatrix1D r = A.zMult(x, (FloatMatrix1D) null);
        r.assign(b, FloatFunctions.plusMultFirst(-1.0f));

        // g = -A^T M^T M r  (just -A^T r without a preconditioner)
        FloatMatrix1D g = preconditioned
                ? A.zMult(this.M.transApply(this.M.apply(r, null), null), null, 1.0f, 0.0f, true)
                : A.zMult(r, null, 1.0f, 0.0f, true);
        g.assign(FloatFunctions.neg);

        // gamma = sum_i x_i * g_i^2
        float gamma = x.aggregate(g, FloatFunctions.plus, FloatFunctions.multSquare);
        // NOTE(review): the unpreconditioned path measures convergence with sqrt(gamma),
        // the preconditioned path with ||g||. Confirm the asymmetry is intended.
        float residualMeasure = preconditioned ? ALGEBRA.norm2(g) : (float) Math.sqrt(gamma);

        IntArrayList negativeIndices = new IntArrayList((int) b.size());
        this.iter.setFirst();
        while (!this.iter.converged(residualMeasure, x)) {
            // Scaled steepest-descent direction d = -x .* g.
            FloatMatrix1D d = x.copy();
            d.assign(g, FloatFunctions.multNeg);

            // Ad = M A d (just A d without a preconditioner).
            FloatMatrix1D Ad = A.zMult(d, (FloatMatrix1D) null);
            if (preconditioned) {
                Ad = this.M.apply(Ad, null);
            }

            // Exact line-search step along d, capped to keep x nonnegative.
            float theta = gamma / Ad.aggregate(FloatFunctions.plus, FloatFunctions.square);
            float step = Math.min(theta, maxFeasibleStep(x, d, negativeIndices));
            x.assign(d, FloatFunctions.plusMultSecond(step));

            // Incremental gradient update: g += step * A^T M^T (M A d).
            FloatMatrix1D gradientChange = preconditioned
                    ? A.zMult(this.M.transApply(Ad, null), null, 1.0f, 0.0f, true)
                    : A.zMult(Ad, null, 1.0f, 0.0f, true);
            g.assign(gradientChange, FloatFunctions.plusMultSecond(step));

            gamma = x.aggregate(g, FloatFunctions.plus, FloatFunctions.multSquare);
            residualMeasure = preconditioned ? ALGEBRA.norm2(g) : (float) Math.sqrt(gamma);
            this.iter.next();
        }
        return x;
    }

    /**
     * Returns the largest step {@code t} such that {@code x + t d} stays nonnegative.
     * <p>
     * This is the minimum of {@code -x[i] / d[i]} over the indices where {@code d[i] < 0}.
     * It is {@link Float#POSITIVE_INFINITY} when no component of {@code d} is negative, because
     * the constraint then does not bound the step.
     *
     * @param x current iterate
     * @param d search direction
     * @param negativeIndices scratch list, refilled with the indices of negative entries of d
     */
    private static float maxFeasibleStep(FloatMatrix1D x, FloatMatrix1D d, IntArrayList negativeIndices) {
        d.getNegativeValues(negativeIndices, null);
        if (negativeIndices.size() == 0) {
            // Minimum over an empty set: do not let an indexed aggregate read a stale index.
            return Float.POSITIVE_INFINITY;
        }
        FloatMatrix1D ratios = x.copy();
        ratios.assign(d, FloatFunctions.divNeg, negativeIndices);
        return ratios.aggregate(FloatFunctions.min, FloatFunctions.identity, negativeIndices);
    }
}
