/*
 * Decompiled with CFR 0.152.
 */
package no.uib.cipr.matrix.distributed;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Exchanger;
import no.uib.cipr.matrix.distributed.Communicator;
import no.uib.cipr.matrix.distributed.Reduction;

/**
 * Shared state for collective operations between a fixed number of
 * participants ("ranks") running as threads in the same JVM.
 *
 * <p>Every collective operation has its own {@link CyclicBarrier} with
 * {@code size} parties. Each rank registers its buffers and then waits on that
 * barrier; when the last rank arrives, the barrier action (the {@code run()}
 * method of the matching inner class) performs all copies with
 * {@link System#arraycopy}, after which all ranks are released. Consequently
 * every collective method must be called by all {@code size} ranks before any
 * of them returns.
 *
 * <p>Buffers are passed as {@code Object} but are handed to
 * {@link Array#getLength(Object)} and {@link System#arraycopy}, so they must be
 * arrays of compatible component types.
 */
public class CollectiveCommunications {
    /** Number of participating ranks; also the party count of every barrier. */
    final int size;

    /**
     * Pairwise exchangers: {@code ex.get(i).get(j)} and {@code ex.get(j).get(i)}
     * are the same instance, and {@code ex.get(i).get(i)} is {@code null}.
     */
    private final List<List<Exchanger<Communicator.SendRecv>>> ex;
    private final CyclicBarrier barrier;
    private final Broadcast broadcast;
    private final Gather gather;
    private final Scatter scatter;
    private final AllGather allGather;
    private final AllToAll allToAll;
    private final Reduce reduce;

    /**
     * Creates the shared state for {@code size} ranks.
     *
     * @param size number of ranks, at least 1
     * @throws IllegalArgumentException if {@code size < 1}
     */
    public CollectiveCommunications(int size) {
        if (size < 1) {
            throw new IllegalArgumentException("size < 1");
        }
        // 'size' must be assigned before the inner classes are created: their
        // constructors read it to dimension their barriers and buffers.
        this.size = size;
        this.barrier = new CyclicBarrier(size);
        this.broadcast = new Broadcast();
        this.gather = new Gather();
        this.scatter = new Scatter();
        this.allGather = new AllGather();
        this.allToAll = new AllToAll();
        this.reduce = new Reduce();

        this.ex = new ArrayList<List<Exchanger<Communicator.SendRecv>>>(size);
        for (int i = 0; i < size; ++i) {
            List<Exchanger<Communicator.SendRecv>> iex =
                    new ArrayList<Exchanger<Communicator.SendRecv>>(size);
            // Lower ranks already created the exchanger shared with rank i.
            for (int j = 0; j < i; ++j) {
                iex.add(this.ex.get(j).get(i));
            }
            // A rank has no exchanger with itself.
            iex.add(null);
            // Create the exchangers shared with the higher ranks.
            for (int j = i + 1; j < size; ++j) {
                iex.add(new Exchanger<Communicator.SendRecv>());
            }
            this.ex.add(iex);
        }
    }

    /**
     * Returns the number of ranks.
     *
     * @return the {@code size} given at construction
     */
    public int size() {
        return this.size;
    }

    /**
     * Creates the communicator for the given rank, backed by this object and
     * by that rank's row of pairwise exchangers.
     *
     * @param rank rank of the communicator, {@code 0 <= rank < size}
     * @return a new communicator for {@code rank}
     * @throws IllegalArgumentException if {@code rank} is out of range
     */
    public Communicator createCommunicator(int rank) {
        if (rank < 0 || rank >= this.size) {
            throw new IllegalArgumentException("rank < 0 || rank >= size");
        }
        return new Communicator(rank, this.ex.get(rank), this);
    }

    /**
     * Waits on the given barrier, converting its checked exceptions into
     * {@link RuntimeException}s that carry the original exception as cause.
     *
     * @param barrier barrier to wait on
     * @throws RuntimeException if the wait is interrupted or the barrier is
     *         broken; on interruption the calling thread's interrupt status is
     *         restored before throwing
     */
    static void await(CyclicBarrier barrier) {
        try {
            barrier.await();
        }
        catch (InterruptedException e) {
            // Catching InterruptedException clears the interrupt status;
            // restore it so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        catch (BrokenBarrierException e) {
            throw new RuntimeException(e);
        }
    }

    /** Blocks until all {@code size} ranks have called this method. */
    void barrier() {
        CollectiveCommunications.await(this.barrier);
    }

    /**
     * Copies the root's buffer into the buffer of every rank.
     *
     * @param buffer this rank's array; source on the root, destination elsewhere
     * @param root rank whose buffer is the source
     * @param rank rank of the caller
     */
    void broadcast(Object buffer, int root, int rank) {
        this.broadcast.buffer[rank] = buffer;
        // Only the root publishes the root index used by the barrier action.
        if (rank == root) {
            this.broadcast.root = root;
        }
        CollectiveCommunications.await(this.broadcast.barrier);
    }

    /**
     * Copies every rank's {@code sendbuf} into slot {@code rank} of the root's
     * {@code recvbuf}.
     *
     * @param sendbuf array contributed by the caller
     * @param recvbuf one destination array per rank; only the root's is used
     * @param root rank that receives the data
     * @param rank rank of the caller
     */
    void gather(Object sendbuf, Object[] recvbuf, int root, int rank) {
        this.gather.setSend(sendbuf, rank);
        if (rank == root) {
            this.gather.recvbuf = recvbuf;
        }
        CollectiveCommunications.await(this.gather.barrier);
    }

    /**
     * Copies slot {@code i} of the root's {@code sendbuf} into the
     * {@code recvbuf} of rank {@code i}; the number of elements copied is the
     * length of that rank's {@code recvbuf}.
     *
     * @param sendbuf one source array per rank; only the root's is used
     * @param recvbuf destination array of the caller
     * @param root rank that provides the data
     * @param rank rank of the caller
     */
    void scatter(Object[] sendbuf, Object recvbuf, int root, int rank) {
        this.scatter.setRecv(recvbuf, rank);
        if (rank == root) {
            this.scatter.sendbuf = sendbuf;
        }
        CollectiveCommunications.await(this.scatter.barrier);
    }

    /**
     * Copies every rank's {@code sendbuf} into slot {@code rank} of every
     * rank's {@code recvbuf}.
     *
     * @param sendbuf array contributed by the caller
     * @param recvbuf one destination array per rank
     * @param rank rank of the caller
     */
    void allGather(Object sendbuf, Object[] recvbuf, int rank) {
        this.allGather.setSendRecv(sendbuf, recvbuf, rank);
        CollectiveCommunications.await(this.allGather.barrier);
    }

    /**
     * Copies {@code sendbuf[j]} of rank {@code i} into {@code recvbuf[i]} of
     * rank {@code j}, for all pairs of ranks.
     *
     * @param sendbuf one source array per destination rank
     * @param recvbuf one destination array per source rank
     * @param rank rank of the caller
     */
    void allToAll(Object[] sendbuf, Object[] recvbuf, int rank) {
        this.allToAll.setSendRecv(sendbuf, recvbuf, rank);
        CollectiveCommunications.await(this.allToAll.barrier);
    }

    /**
     * Combines every rank's {@code sendbuf} into the root's {@code recvbuf}:
     * the root's {@code op} is first initialised on {@code recvbuf} and then
     * applied to {@code recvbuf} once per rank, in rank order.
     *
     * @param sendbuf data contributed by the caller
     * @param recvbuf result buffer; only the root's is used
     * @param op reduction to apply; only the root's is used
     * @param root rank that receives the result
     * @param rank rank of the caller
     */
    void reduce(Object sendbuf, Object recvbuf, Reduction op, int root, int rank) {
        this.reduce.sendbuf[rank] = sendbuf;
        if (rank == root) {
            this.reduce.op = op;
            this.reduce.recvbuf = recvbuf;
        }
        CollectiveCommunications.await(this.reduce.barrier);
    }

    /**
     * Reduces into rank 0's {@code recvbuf} and then broadcasts that buffer
     * into the {@code recvbuf} of every rank.
     *
     * @param sendbuf data contributed by the caller
     * @param recvbuf result array of the caller
     * @param op reduction to apply
     * @param rank rank of the caller
     */
    void allReduce(Object sendbuf, Object recvbuf, Reduction op, int rank) {
        this.reduce(sendbuf, recvbuf, op, 0, rank);
        this.broadcast(recvbuf, 0, rank);
    }

    /** Barrier action for {@link #reduce}: folds all send buffers into the root's result. */
    private class Reduce
    implements Runnable {
        final CyclicBarrier barrier;
        Reduction op;
        Object[] sendbuf;
        Object recvbuf;

        private Reduce() {
            this.barrier = new CyclicBarrier(size, this);
            this.sendbuf = new Object[size];
        }

        @Override
        public void run() {
            this.op.init(this.recvbuf);
            for (int i = 0; i < size; ++i) {
                this.op.op(this.recvbuf, this.sendbuf[i]);
            }
        }
    }

    /** Barrier action for {@link #allToAll}: pairwise copy between all ranks. */
    private class AllToAll
    implements Runnable {
        final CyclicBarrier barrier;
        private final Object[][] sendbuf;
        private final Object[][] recvbuf;
        /** {@code length[i][j]} is the length of the array rank i sends to rank j. */
        private final int[][] length;

        private AllToAll() {
            this.barrier = new CyclicBarrier(size, this);
            this.sendbuf = new Object[size][size];
            this.recvbuf = new Object[size][size];
            this.length = new int[size][size];
        }

        /** Registers the buffers of {@code rank} and records the length of each send array. */
        public void setSendRecv(Object[] send, Object[] recv, int rank) {
            this.sendbuf[rank] = send;
            this.recvbuf[rank] = recv;
            for (int i = 0; i < size; ++i) {
                this.length[rank][i] = Array.getLength(send[i]);
            }
        }

        @Override
        public void run() {
            for (int i = 0; i < size; ++i) {
                for (int j = 0; j < size; ++j) {
                    System.arraycopy(this.sendbuf[i][j], 0, this.recvbuf[j][i], 0, this.length[i][j]);
                }
            }
        }
    }

    /** Barrier action for {@link #allGather}: every send buffer goes to every rank. */
    private class AllGather
    implements Runnable {
        final CyclicBarrier barrier;
        private final Object[] sendbuf;
        private final Object[][] recvbuf;
        /** {@code length[i]} is the length of the array sent by rank i. */
        private final int[] length;

        private AllGather() {
            this.barrier = new CyclicBarrier(size, this);
            this.sendbuf = new Object[size];
            this.recvbuf = new Object[size][size];
            this.length = new int[size];
        }

        /** Registers the buffers of {@code rank} and records the length of its send array. */
        public void setSendRecv(Object send, Object[] recv, int rank) {
            this.sendbuf[rank] = send;
            this.recvbuf[rank] = recv;
            this.length[rank] = Array.getLength(send);
        }

        @Override
        public void run() {
            for (int i = 0; i < size; ++i) {
                for (int j = 0; j < size; ++j) {
                    System.arraycopy(this.sendbuf[i], 0, this.recvbuf[j][i], 0, this.length[i]);
                }
            }
        }
    }

    /** Barrier action for {@link #scatter}: slot i of the root's send buffer goes to rank i. */
    private class Scatter
    implements Runnable {
        final CyclicBarrier barrier;
        Object[] sendbuf;
        Object[] recvbuf;
        /** {@code length[i]} is the length of the receive array of rank i. */
        private final int[] length;

        private Scatter() {
            this.barrier = new CyclicBarrier(size, this);
            this.sendbuf = new Object[size];
            this.recvbuf = new Object[size];
            this.length = new int[size];
        }

        /** Registers the receive array of {@code rank} and records its length. */
        public void setRecv(Object recvbuf, int rank) {
            this.recvbuf[rank] = recvbuf;
            this.length[rank] = Array.getLength(recvbuf);
        }

        @Override
        public void run() {
            for (int i = 0; i < size; ++i) {
                System.arraycopy(this.sendbuf[i], 0, this.recvbuf[i], 0, this.length[i]);
            }
        }
    }

    /** Barrier action for {@link #gather}: rank i's send buffer goes to slot i at the root. */
    private class Gather
    implements Runnable {
        final CyclicBarrier barrier;
        Object[] recvbuf;
        Object[] sendbuf;
        /** {@code length[i]} is the length of the array sent by rank i. */
        private final int[] length;

        private Gather() {
            this.barrier = new CyclicBarrier(size, this);
            this.recvbuf = new Object[size];
            this.sendbuf = new Object[size];
            this.length = new int[size];
        }

        /** Registers the send array of {@code rank} and records its length. */
        public void setSend(Object sendbuf, int rank) {
            this.sendbuf[rank] = sendbuf;
            this.length[rank] = Array.getLength(sendbuf);
        }

        @Override
        public void run() {
            for (int i = 0; i < size; ++i) {
                System.arraycopy(this.sendbuf[i], 0, this.recvbuf[i], 0, this.length[i]);
            }
        }
    }

    /** Barrier action for {@link #broadcast}: the root's buffer is copied to every rank. */
    private class Broadcast
    implements Runnable {
        final CyclicBarrier barrier;
        int root;
        final Object[] buffer;

        private Broadcast() {
            this.barrier = new CyclicBarrier(size, this);
            this.buffer = new Object[size];
        }

        @Override
        public void run() {
            Object source = this.buffer[this.root];
            int length = Array.getLength(source);
            // The loop includes the root itself; that copy has identical
            // source and destination and leaves the data unchanged.
            for (int i = 0; i < size; ++i) {
                System.arraycopy(source, 0, this.buffer[i], 0, length);
            }
        }
    }
}

