/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.utils.DMLCompressionStatistics;
import org.apache.sysds.utils.stats.Timing;

/**
 * Static routines that decompress a {@link CompressedMatrixBlock} either into a freshly allocated
 * uncompressed {@link MatrixBlock} ({@link #decompress}) or into an existing block at a given
 * row/column offset ({@link #decompressTo}).
 */
public final class CLALibDecompress {
    private static final Log LOG = LogFactory.getLog(CLALibDecompress.class.getName());

    /** Row count above which decompressTo into a dense block is parallelized (when k > 1). */
    private static final int PARALLEL_DECOMPRESS_TO_MIN_ROWS = 1000;

    /** Lower bound on the number of rows handed to one parallel task. */
    private static final int MIN_TASK_ROWS = 512;

    /** Target cell count of the inner row blocks processed by the dense tasks. */
    private static final int TASK_BLOCK_CELLS = 32768;

    /** Lower bound on the number of rows of an inner row block in the dense tasks. */
    private static final int MIN_TASK_BLOCK_ROWS = 128;

    private CLALibDecompress() {
        // static utility class, not instantiable
    }

    /**
     * Decompresses the given block into a new uncompressed MatrixBlock.
     *
     * @param cmb the compressed block to decompress
     * @param k   the degree of parallelism to use
     * @return a new uncompressed block; when the only group is a full-size uncompressed group, its
     *         backing MatrixBlock is returned directly rather than copied
     */
    public static MatrixBlock decompress(CompressedMatrixBlock cmb, int k) {
        final Timing time = new Timing(true);
        final MatrixBlock ret = decompressExecute(cmb, k);
        if (DMLScript.STATISTICS) {
            final double t = time.stop();
            DMLCompressionStatistics.addDecompressTime(t, k);
            if (LOG.isTraceEnabled()) {
                LOG.trace("decompressed block w/ k=" + k + " in " + t + "ms.");
            }
        }
        return ret;
    }

    /**
     * Decompresses {@code cmb} into {@code ret} at the given offsets without resetting the target.
     *
     * @see #decompressTo(CompressedMatrixBlock, MatrixBlock, int, int, int, boolean, boolean)
     */
    public static void decompressTo(CompressedMatrixBlock cmb, MatrixBlock ret, int rowOffset, int colOffset, int k, boolean countNNz) {
        decompressTo(cmb, ret, rowOffset, colOffset, k, countNNz, false);
    }

    /**
     * Decompresses {@code cmb} into the already allocated block {@code ret}, placing cell (0,0) of
     * {@code cmb} at ({@code rowOffset}, {@code colOffset}) of {@code ret}.
     *
     * @param cmb       the compressed source block
     * @param ret       the target block (sparse or dense)
     * @param rowOffset row in {@code ret} where the first row of {@code cmb} is written
     * @param colOffset column in {@code ret} where the first column of {@code cmb} is written
     * @param k         the degree of parallelism to use
     * @param countNNz  whether to recompute the non-zero count of {@code ret} afterwards
     * @param reset     whether to zero the written target rows of a dense {@code ret} first
     * @throws DMLCompressionException if {@code ret} is sparse and {@code cmb} is overlapping or
     *                                 {@code reset} is requested
     */
    public static void decompressTo(CompressedMatrixBlock cmb, MatrixBlock ret, int rowOffset, int colOffset, int k, boolean countNNz, boolean reset) {
        final Timing time = new Timing(true);
        final boolean exceedsCols = cmb.getNumColumns() + colOffset > ret.getNumColumns();
        final boolean exceedsRows = cmb.getNumRows() + rowOffset > ret.getNumRows();
        if (exceedsCols || exceedsRows) {
            LOG.warn("Slow slicing off excess parts for decompressTo because decompression into is implemented for fitting blocks");
            // Keep only the part of cmb that fits inside ret. The slice always starts at row/column 0
            // of cmb; negative offsets are not handled by this path.
            final int rowEnd = Math.min(cmb.getNumRows(), ret.getNumRows() - rowOffset) - 1;
            final int colEnd = Math.min(cmb.getNumColumns(), ret.getNumColumns() - colOffset) - 1;
            final MatrixBlock mbSliced = cmb.slice(0, rowEnd, 0, colEnd);
            mbSliced.putInto(ret, rowOffset, colOffset, false);
            return;
        }
        final boolean outSparse = ret.isInSparseFormat();
        if (!cmb.isEmpty()) {
            if (outSparse && (cmb.isOverlapping() || reset)) {
                throw new DMLCompressionException("Not supported decompression into sparse block from overlapping state");
            }
            if (outSparse) {
                decompressToSparseBlock(cmb, ret, rowOffset, colOffset);
            } else {
                decompressToDenseBlock(cmb, ret.getDenseBlock(), rowOffset, colOffset, k, reset);
            }
        }
        if (DMLScript.STATISTICS) {
            final double t = time.stop();
            DMLCompressionStatistics.addDecompressToBlockTime(t, k);
            if (LOG.isTraceEnabled()) {
                LOG.trace("decompressed block w/ k=" + k + " in " + t + "ms.");
            }
        }
        if (countNNz) {
            ret.recomputeNonZeros(k);
        }
    }

    /** Single-threaded decompression of all groups of {@code cmb} into the sparse block of {@code ret}. */
    private static void decompressToSparseBlock(CompressedMatrixBlock cmb, MatrixBlock ret, int rowOffset, int colOffset) {
        final List<AColGroup> groups = cmb.getColGroups();
        final int nRows = cmb.getNumRows();
        if (CLALibUtils.shouldPreFilter(groups)) {
            // Groups needing pre-filtering are not decompressed directly into sparse; go through a
            // temporary uncompressed block instead.
            final MatrixBlock tmp = cmb.getUncompressed("Decompression to put into Sparse Block");
            tmp.putInto(ret, rowOffset, colOffset, false);
        } else {
            final SparseBlock sb = ret.getSparseBlock();
            for (AColGroup g : groups) {
                g.decompressToSparseBlock(sb, 0, nRows, rowOffset, colOffset);
            }
        }
        // Re-fetch the sparse block so the sort applies to whatever block ret holds now.
        ret.getSparseBlock().sort();
        ret.checkSparseRows();
    }

    /** Decompresses all groups of {@code cmb} into the dense block {@code ret}, in parallel for large inputs. */
    private static void decompressToDenseBlock(CompressedMatrixBlock cmb, DenseBlock ret, int rowOffset, int colOffset, int k, boolean reset) {
        List<AColGroup> groups = cmb.getColGroups();
        final int nRows = cmb.getNumRows();
        final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
        if (shouldFilter && !CLALibUtils.alreadyPreFiltered(groups, cmb.getNumColumns())) {
            final double[] constV = new double[cmb.getNumColumns()];
            final List<AColGroup> filtered = CLALibUtils.filterGroups(groups, constV);
            // filterGroups may hand back the input list unchanged (decompressExecute treats that
            // case as "constV not applicable"). Only append the constant group when filtering
            // actually produced a new list, and never append to cmb's own group list.
            if (filtered != groups) {
                groups = new ArrayList<>(filtered);
                groups.add(ColGroupConst.create(constV));
            }
        }
        if (k > 1 && nRows > PARALLEL_DECOMPRESS_TO_MIN_ROWS) {
            decompressToDenseBlockParallel(ret, groups, rowOffset, colOffset, nRows, k, reset);
        } else {
            decompressToDenseBlockSingleThread(ret, groups, rowOffset, colOffset, nRows, reset);
        }
    }

    private static void decompressToDenseBlockSingleThread(DenseBlock ret, List<AColGroup> groups, int rowOffset, int colOffset, int nRows, boolean reset) {
        decompressToDenseBlockBlock(ret, groups, rowOffset, colOffset, 0, nRows, reset);
    }

    /**
     * Decompresses source rows [rl, ru) of all groups into {@code ret}, writing target rows
     * [rl + rowOffset, ru + rowOffset).
     */
    private static void decompressToDenseBlockBlock(DenseBlock ret, List<AColGroup> groups, int rowOffset, int colOffset, int rl, int ru, boolean reset) {
        if (reset) {
            // Clear the target rows that the groups are about to write, which are shifted by rowOffset.
            resetRows(ret, rl + rowOffset, ru + rowOffset);
        }
        for (AColGroup g : groups) {
            g.decompressToDenseBlock(ret, rl, ru, rowOffset, colOffset);
        }
    }

    /** Sets every cell of rows [rl, ru) of the dense block to 0. */
    private static void resetRows(DenseBlock db, int rl, int ru) {
        final int nCol = db.getDim(1);
        if (db.isContiguous()) {
            db.fillBlock(0, rl * nCol, ru * nCol, 0.0);
        } else {
            for (int row = rl; row < ru; row++) {
                final int off = db.pos(row);
                Arrays.fill(db.values(row), off, off + nCol, 0.0);
            }
        }
    }

    private static void decompressToDenseBlockParallel(DenseBlock ret, List<AColGroup> groups, int rowOffset, int colOffset, int nRows, int k, boolean reset) {
        final int blklen = Math.max(nRows / k, MIN_TASK_ROWS);
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<Future<?>> tasks = new ArrayList<>(nRows / blklen + 1);
            for (int r = 0; r < nRows; r += blklen) {
                final int start = r;
                final int end = Math.min(nRows, r + blklen);
                tasks.add(pool.submit(() -> decompressToDenseBlockBlock(ret, groups, rowOffset, colOffset, start, end, reset)));
            }
            for (Future<?> task : tasks) {
                task.get();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLCompressionException("Failed parallel decompress to", e);
        }
        catch (ExecutionException e) {
            throw new DMLCompressionException("Failed parallel decompress to", e);
        }
        finally {
            pool.shutdown();
        }
    }

    /** Implementation of {@link #decompress}: allocates the output and dispatches by format and parallelism. */
    private static MatrixBlock decompressExecute(CompressedMatrixBlock cmb, int k) {
        if (cmb.isEmpty()) {
            return new MatrixBlock(cmb.getNumRows(), cmb.getNumColumns(), true);
        }
        final ArrayList<AColGroup> groups = new ArrayList<>(cmb.getColGroups());
        final int nRows = cmb.getNumRows();
        final int nCols = cmb.getNumColumns();
        final boolean overlapping = cmb.isOverlapping();
        final long nonZeros = cmb.getNonZeros();

        MatrixBlock ret = getUncompressedColGroupAndRemoveFromListOfColGroups(groups, overlapping, nRows, nCols);
        if (ret != null && groups.isEmpty()) {
            // The single uncompressed group already is the full result.
            ret.setNonZeros(ret.recomputeNonZeros(k));
            return ret;
        }

        final boolean shouldFilter = CLALibUtils.shouldPreFilterMorphOrRef(groups);
        double[] constV = shouldFilter ? new double[nCols] : null;
        final List<AColGroup> filteredGroups = shouldFilter ? CLALibUtils.filterGroups(groups, constV) : groups;

        if (ret == null) {
            final boolean sparseByNnz = MatrixBlock.evalSparseFormatInMemory(nRows, nCols, nonZeros);
            // Filtering (a dense constant vector) and overlapping groups require dense decompression.
            final boolean sparse = !shouldFilter && !overlapping && sparseByNnz;
            ret = new MatrixBlock(nRows, nCols, sparse);
            if (sparse) {
                ret.allocateSparseRowsBlock();
            } else {
                ret.allocateDenseBlock();
            }
            if (sparseByNnz && !sparse) {
                LOG.warn("Decompressing into dense but reallocating after to sparse: overlapping - " + overlapping + ", filter - " + shouldFilter);
            }
        } else {
            // Never write into the uncompressed group's data; decompress into a copy.
            final MatrixBlock tmp = new MatrixBlock();
            tmp.copy(ret);
            ret = tmp;
        }

        final int blklen = Math.max(nRows / k, MIN_TASK_ROWS);
        if (groups == filteredGroups) {
            // Filtering did not produce a new group list, so there is no constant vector to add.
            constV = null;
        }
        final double eps = getEps(constV);
        if (k == 1) {
            if (ret.isInSparseFormat()) {
                decompressSparseSingleThread(ret, filteredGroups, nRows, blklen);
            } else {
                decompressDenseSingleThread(ret, filteredGroups, nRows, blklen, constV, eps, overlapping);
            }
        } else if (ret.isInSparseFormat()) {
            decompressSparseMultiThread(ret, filteredGroups, nRows, blklen, k);
        } else {
            decompressDenseMultiThread(ret, filteredGroups, nRows, blklen, constV, eps, k, overlapping);
        }
        ret.recomputeNonZeros(k);
        ret.examSparsity();
        return ret;
    }

    /**
     * Finds an uncompressed group covering the full nRows x nCols shape, removes it from
     * {@code colGroups} and returns its data so it can serve as the decompression base.
     *
     * <p>Only considered when the groups overlap or there is exactly one group; a sparse group is
     * only accepted when it is the single group.
     *
     * @return the data of the removed group, or {@code null} if none qualifies
     */
    private static MatrixBlock getUncompressedColGroupAndRemoveFromListOfColGroups(List<AColGroup> colGroups, boolean overlapping, int nRows, int nCols) {
        if (!overlapping && colGroups.size() != 1) {
            return null;
        }
        for (int i = 0; i < colGroups.size(); i++) {
            final AColGroup g = colGroups.get(i);
            if (!(g instanceof ColGroupUncompressed)) {
                continue;
            }
            final MatrixBlock gMB = ((ColGroupUncompressed) g).getData();
            final boolean fullShape = gMB.getNumColumns() == nCols && gMB.getNumRows() == nRows;
            final boolean usableFormat = !gMB.isInSparseFormat() || colGroups.size() == 1;
            if (fullShape && usableFormat) {
                colGroups.remove(i);
                LOG.debug("Using one of the uncompressed ColGroups as base for decompression");
                return gMB;
            }
        }
        return null;
    }

    /** Decompresses all groups into the sparse block of {@code ret} block-wise, sorting each filled row. */
    private static void decompressSparseSingleThread(MatrixBlock ret, List<AColGroup> filteredGroups, int rlen, int blklen) {
        final SparseBlock sb = ret.getSparseBlock();
        for (int rl = 0; rl < rlen; rl += blklen) {
            final int ru = Math.min(rl + blklen, rlen);
            for (AColGroup grp : filteredGroups) {
                grp.decompressToSparseBlock(ret.getSparseBlock(), rl, ru);
            }
            for (int row = rl; row < ru; row++) {
                if (!sb.isEmpty(row)) {
                    sb.sort(row);
                }
            }
        }
    }

    /** Decompresses all groups into the dense block of {@code ret} block-wise, adding {@code constV} if non-null. */
    private static void decompressDenseSingleThread(MatrixBlock ret, List<AColGroup> filteredGroups, int rlen, int blklen, double[] constV, double eps, boolean overlapping) {
        final DenseBlock db = ret.getDenseBlock();
        final int nCol = ret.getNumColumns();
        for (int rl = 0; rl < rlen; rl += blklen) {
            final int ru = Math.min(rl + blklen, rlen);
            for (AColGroup grp : filteredGroups) {
                grp.decompressToDenseBlock(db, rl, ru);
            }
            if (constV != null) {
                addVector(db, nCol, constV, eps, rl, ru);
            }
        }
    }

    /**
     * Decompresses {@code groups} into the dense block of {@code ret} (presumably already allocated
     * by the caller), adds {@code constV} to every row when non-null, and recomputes non-zeros.
     *
     * @param ret         target block whose dense block is written
     * @param groups      the column groups to decompress
     * @param constV      optional per-column constant vector added to every row, or {@code null}
     * @param eps         values with magnitude at or below eps become 0 after adding constV (0 disables)
     * @param k           the degree of parallelism to use
     * @param overlapping whether the groups overlap in columns
     */
    public static void decompressDense(MatrixBlock ret, List<AColGroup> groups, double[] constV, double eps, int k, boolean overlapping) {
        final Timing time = new Timing(true);
        final int nRows = ret.getNumRows();
        final int blklen = Math.max(nRows / k, MIN_TASK_ROWS);
        if (k > 1) {
            decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping);
        } else {
            decompressDenseSingleThread(ret, groups, nRows, blklen, constV, eps, overlapping);
        }
        ret.recomputeNonZeros(k);
        if (DMLScript.STATISTICS) {
            final double t = time.stop();
            DMLCompressionStatistics.addDecompressTime(t, k);
            if (LOG.isTraceEnabled()) {
                LOG.trace("decompressed block w/ k=" + k + " in " + t + "ms.");
            }
        }
    }

    /**
     * Parallel dense decompression. Overlapping groups or a constant vector are handled by row tasks
     * over all groups. Otherwise one task per (row block, group) is used, since non-overlapping
     * groups write disjoint columns.
     *
     * <p>Note: the per-group tasks do not count non-zeros (they report 0); the public callers in this
     * class recompute non-zeros afterwards.
     */
    private static void decompressDenseMultiThread(MatrixBlock ret, List<AColGroup> filteredGroups, int rlen, int blklen, double[] constV, double eps, int k, boolean overlapping) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<Callable<Long>> tasks = new ArrayList<>();
            if (overlapping || constV != null) {
                for (int rl = 0; rl < rlen; rl += blklen) {
                    tasks.add(new DecompressDenseTask(filteredGroups, ret, eps, rl, Math.min(rl + blklen, rlen), constV));
                }
            } else {
                for (int rl = 0; rl < rlen; rl += blklen) {
                    for (AColGroup g : filteredGroups) {
                        tasks.add(new DecompressDenseSingleColTask(g, ret, eps, rl, Math.min(rl + blklen, rlen), null));
                    }
                }
            }
            long nnz = 0L;
            for (Future<Long> rt : pool.invokeAll(tasks)) {
                nnz += rt.get();
            }
            ret.setNonZeros(nnz);
        }
        catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new DMLCompressionException("Parallel decompression failed", ex);
        }
        catch (ExecutionException ex) {
            throw new DMLCompressionException("Parallel decompression failed", ex);
        }
        finally {
            pool.shutdown();
        }
    }

    /** Parallel sparse decompression with one task per row block. */
    private static void decompressSparseMultiThread(MatrixBlock ret, List<AColGroup> filteredGroups, int rlen, int blklen, int k) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<DecompressSparseTask> tasks = new ArrayList<>();
            for (int rl = 0; rl < rlen; rl += blklen) {
                tasks.add(new DecompressSparseTask(filteredGroups, ret, rl, Math.min(rl + blklen, rlen)));
            }
            for (Future<Object> rt : pool.invokeAll(tasks)) {
                rt.get();
            }
        }
        catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new DMLCompressionException("Parallel decompression failed", ex);
        }
        catch (ExecutionException ex) {
            throw new DMLCompressionException("Parallel decompression failed", ex);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Computes the zero-snapping threshold used when adding {@code constV}: 1e-10 times the range of
     * its finite values (plus 1e-4). Returns 0 (no snapping) when {@code constV} is null.
     */
    private static double getEps(double[] constV) {
        if (constV == null) {
            return 0.0;
        }
        double max = -Double.MAX_VALUE;
        double min = Double.MAX_VALUE;
        for (double v : constV) {
            if (!Double.isFinite(v)) {
                continue;
            }
            if (v > max) {
                max = v;
            }
            if (v < min) {
                min = v;
            }
        }
        return (max + 1.0E-4 - min) * 1.0E-10;
    }

    /**
     * Adds {@code rowV} to every row in [rl, ru) of {@code db}. When {@code eps != 0}, results with
     * magnitude at or below {@code eps} are set to 0.
     */
    private static void addVector(DenseBlock db, int nCols, double[] rowV, double eps, int rl, int ru) {
        if (eps == 0.0) {
            addVectorPlain(db, nCols, rowV, rl, ru);
        } else {
            addVectorSnapToZero(db, nCols, rowV, eps, rl, ru);
        }
    }

    /** Plain addition of {@code rowV} to rows [rl, ru), dispatching by column count and block layout. */
    private static void addVectorPlain(DenseBlock db, int nCols, double[] rowV, int rl, int ru) {
        if (nCols == 1) {
            addValue(db.values(0), rowV[0], rl, ru);
        } else if (db.isContiguous()) {
            addVectorContiguous(db.values(0), rowV, nCols, rl, ru);
        } else {
            addVectorRows(db, rowV, nCols, rl, ru);
        }
    }

    /** Addition of {@code rowV} to rows [rl, ru) with results of magnitude &lt;= eps snapped to 0. */
    private static void addVectorSnapToZero(DenseBlock db, int nCols, double[] rowV, double eps, int rl, int ru) {
        if (nCols == 1) {
            addValueEps(db.values(0), rowV[0], eps, rl, ru);
        } else if (db.isContiguous()) {
            addVectorContiguousEps(db.values(0), rowV, nCols, eps, rl, ru);
        } else {
            addVectorRowsEps(db, rowV, nCols, eps, rl, ru);
        }
    }

    private static void addValue(double[] retV, double v, int rl, int ru) {
        for (int off = rl; off < ru; off++) {
            retV[off] += v;
        }
    }

    private static void addValueEps(double[] retV, double v, double eps, int rl, int ru) {
        for (int off = rl; off < ru; off++) {
            final double e = retV[off] + v;
            retV[off] = Math.abs(e) <= eps ? 0.0 : e;
        }
    }

    private static void addVectorContiguous(double[] retV, double[] rowV, int nCols, int rl, int ru) {
        for (int off = rl * nCols; off < ru * nCols; off += nCols) {
            for (int col = 0; col < nCols; col++) {
                retV[off + col] += rowV[col];
            }
        }
    }

    private static void addVectorContiguousEps(double[] retV, double[] rowV, int nCols, double eps, int rl, int ru) {
        for (int off = rl * nCols; off < ru * nCols; off += nCols) {
            for (int col = 0; col < nCols; col++) {
                final int out = off + col;
                retV[out] += rowV[col];
                if (Math.abs(retV[out]) <= eps) {
                    retV[out] = 0.0;
                }
            }
        }
    }

    private static void addVectorRows(DenseBlock db, double[] rowV, int nCols, int rl, int ru) {
        for (int row = rl; row < ru; row++) {
            final double[] retV = db.values(row);
            final int off = db.pos(row);
            for (int col = 0; col < nCols; col++) {
                retV[off + col] += rowV[col];
            }
        }
    }

    private static void addVectorRowsEps(DenseBlock db, double[] rowV, int nCols, double eps, int rl, int ru) {
        for (int row = rl; row < ru; row++) {
            final double[] retV = db.values(row);
            final int off = db.pos(row);
            for (int col = 0; col < nCols; col++) {
                final int out = off + col;
                retV[out] += rowV[col];
                if (Math.abs(retV[out]) <= eps) {
                    retV[out] = 0.0;
                }
            }
        }
    }

    /** Decompresses rows [rl, ru) of all groups into the sparse block of the target and sorts the rows. */
    private static class DecompressSparseTask implements Callable<Object> {
        private final List<AColGroup> _colGroups;
        private final MatrixBlock _ret;
        private final int _rl;
        private final int _ru;

        protected DecompressSparseTask(List<AColGroup> colGroups, MatrixBlock ret, int rl, int ru) {
            _colGroups = colGroups;
            _ret = ret;
            _rl = rl;
            _ru = ru;
        }

        @Override
        public Object call() throws Exception {
            try {
                final SparseBlock sb = _ret.getSparseBlock();
                for (AColGroup grp : _colGroups) {
                    grp.decompressToSparseBlock(_ret.getSparseBlock(), _rl, _ru);
                }
                for (int i = _rl; i < _ru; i++) {
                    if (!sb.isEmpty(i)) {
                        sb.sort(i);
                    }
                }
                return null;
            }
            catch (Exception e) {
                // The cause propagates to the caller through the Future.
                throw new DMLRuntimeException(e);
            }
        }
    }

    /**
     * Decompresses rows [rl, ru) of a single group into the dense target in inner row blocks.
     * Returns 0: this task does not count non-zeros.
     */
    private static class DecompressDenseSingleColTask implements Callable<Long> {
        private final AColGroup _grp;
        private final MatrixBlock _ret;
        private final double _eps;
        private final int _rl;
        private final int _ru;
        private final int _blklen;
        private final double[] _constV;

        protected DecompressDenseSingleColTask(AColGroup grp, MatrixBlock ret, double eps, int rl, int ru, double[] constV) {
            _grp = grp;
            _ret = ret;
            _eps = eps;
            _rl = rl;
            _ru = ru;
            _blklen = Math.max(TASK_BLOCK_CELLS / ret.getNumColumns(), MIN_TASK_BLOCK_ROWS);
            _constV = constV;
        }

        @Override
        public Long call() {
            try {
                final DenseBlock db = _ret.getDenseBlock();
                final int nCol = _ret.getNumColumns();
                for (int b = _rl; b < _ru; b += _blklen) {
                    final int e = Math.min(b + _blklen, _ru);
                    _grp.decompressToDenseBlock(db, b, e);
                    if (_constV != null) {
                        addVector(db, nCol, _constV, _eps, b, e);
                    }
                }
                return 0L;
            }
            catch (Exception e) {
                throw new DMLCompressionException("Failed dense decompression", e);
            }
        }
    }

    /**
     * Decompresses rows [rl, ru) of all groups into the dense target in inner row blocks, adds the
     * optional constant vector, and returns the non-zero count of the rows written.
     */
    private static class DecompressDenseTask implements Callable<Long> {
        private final List<AColGroup> _colGroups;
        private final MatrixBlock _ret;
        private final double _eps;
        private final int _rl;
        private final int _ru;
        private final int _blklen;
        private final double[] _constV;

        protected DecompressDenseTask(List<AColGroup> colGroups, MatrixBlock ret, double eps, int rl, int ru, double[] constV) {
            _colGroups = colGroups;
            _ret = ret;
            _eps = eps;
            _rl = rl;
            _ru = ru;
            _blklen = Math.max(TASK_BLOCK_CELLS / ret.getNumColumns(), MIN_TASK_BLOCK_ROWS);
            _constV = constV;
        }

        @Override
        public Long call() {
            try {
                final DenseBlock db = _ret.getDenseBlock();
                final int nCol = _ret.getNumColumns();
                long nnz = 0L;
                for (int b = _rl; b < _ru; b += _blklen) {
                    final int e = Math.min(b + _blklen, _ru);
                    for (AColGroup grp : _colGroups) {
                        grp.decompressToDenseBlock(db, b, e);
                    }
                    if (_constV != null) {
                        addVector(db, nCol, _constV, _eps, b, e);
                    }
                    nnz += _ret.recomputeNonZeros(b, e - 1);
                }
                return nnz;
            }
            catch (Exception e) {
                throw new DMLCompressionException("Failed dense decompression", e);
            }
        }
    }
}

