/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle.celeborn;

import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.atomic.LongAdder;
import org.apache.celeborn.client.ShuffleClient;
import org.apache.celeborn.common.CelebornConf;
import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.spark.Aggregator;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.annotation.Private;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle;
import org.apache.spark.shuffle.celeborn.OpenByteArrayOutputStream;
import org.apache.spark.shuffle.celeborn.SendBufferPool;
import org.apache.spark.shuffle.celeborn.SortBasedPusher;
import org.apache.spark.shuffle.celeborn.SparkUtils;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.execution.UnsafeRowSerializer;
import org.apache.spark.sql.execution.metric.SQLMetric;
import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.unsafe.Platform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.Product2;
import scala.collection.Iterator;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;

/**
 * Celeborn shuffle writer that buffers records in a {@link SortBasedPusher} and pushes them to
 * Celeborn workers through a {@link ShuffleClient}.
 *
 * <p>Two write paths exist:
 * <ul>
 *   <li>a fast path ({@link #canUseFastWrite()}) that copies {@link UnsafeRow} bytes directly when
 *       fast write is enabled, the serializer is an {@link UnsafeRowSerializer} and the partitioner
 *       is {@code PartitionIdPassthrough} (so each record key is already the partition id);</li>
 *   <li>a generic path that serializes each key/value pair into a reusable buffer.</li>
 * </ul>
 * Records whose serialized size exceeds {@code clientPushBufferMaxSize} bypass the pusher and are
 * sent directly with {@link ShuffleClient#pushData}.
 */
@Private
public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
    private static final Logger logger = LoggerFactory.getLogger(SortBasedShuffleWriter.class);
    private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
    /** Initial capacity (1 MiB) of the per-record serialization buffer used by the generic path. */
    private static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 0x100000;
    /** Size in bytes of the length prefix written before each UnsafeRow on the fast path. */
    private static final int ROW_LENGTH_PREFIX_BYTES = 4;

    private final ShuffleDependency<K, V, C> dep;
    private final Partitioner partitioner;
    private final ShuffleWriteMetricsReporter writeMetrics;
    private final int shuffleId;
    /** Map id reported to Celeborn; taken from the task's partition id. */
    private final int mapId;
    private final TaskContext taskContext;
    private final ShuffleClient shuffleClient;
    private final int numMappers;
    private final int numPartitions;
    /** Records larger than this are pushed directly instead of going through the pusher. */
    private final long pushBufferMaxSize;
    private final SortBasedPusher pusher;
    private long peakMemoryUsedBytes = 0L;
    /** Reusable buffer that holds one serialized record at a time on the generic path. */
    private final OpenByteArrayOutputStream serBuffer;
    private final SerializationStream serOutputStream;
    /** Bytes written per reduce partition; used to build the {@link MapStatus} in {@link #stop}. */
    private final LongAdder[] mapStatusLengths;
    private long tmpRecordsWritten = 0L;
    private volatile boolean stopping = false;
    private final boolean unsafeRowFastWrite;

    public SortBasedShuffleWriter(
            int shuffleId,
            ShuffleDependency<K, V, C> dep,
            int numMappers,
            TaskContext taskContext,
            CelebornConf conf,
            ShuffleClient client,
            ShuffleWriteMetricsReporter metrics,
            SendBufferPool sendBufferPool) throws IOException {
        this(shuffleId, dep, numMappers, taskContext, conf, client, metrics, sendBufferPool, null);
    }

    /**
     * Creates a writer.
     *
     * @param pusher pusher to use; when {@code null} a new {@link SortBasedPusher} is created from
     *     {@code conf}, {@code taskContext} and {@code sendBufferPool}
     */
    public SortBasedShuffleWriter(
            int shuffleId,
            ShuffleDependency<K, V, C> dep,
            int numMappers,
            TaskContext taskContext,
            CelebornConf conf,
            ShuffleClient client,
            ShuffleWriteMetricsReporter metrics,
            SendBufferPool sendBufferPool,
            SortBasedPusher pusher) throws IOException {
        this.mapId = taskContext.partitionId();
        this.dep = dep;
        this.shuffleId = shuffleId;
        SerializerInstance serializer = dep.serializer().newInstance();
        this.partitioner = dep.partitioner();
        this.writeMetrics = metrics;
        this.taskContext = taskContext;
        this.numMappers = numMappers;
        this.numPartitions = this.partitioner.numPartitions();
        this.shuffleClient = client;
        this.unsafeRowFastWrite = conf.clientPushUnsafeRowFastWrite();
        this.serBuffer = new OpenByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
        this.serOutputStream = serializer.serializeStream(this.serBuffer);
        this.mapStatusLengths = new LongAdder[this.numPartitions];
        for (int i = 0; i < this.numPartitions; ++i) {
            this.mapStatusLengths[i] = new LongAdder();
        }
        this.pushBufferMaxSize = conf.clientPushBufferMaxSize();
        this.pusher = pusher != null
                ? pusher
                : new SortBasedPusher(
                        taskContext.taskMemoryManager(),
                        this.shuffleClient,
                        taskContext,
                        shuffleId,
                        this.mapId,
                        taskContext.attemptNumber(),
                        taskContext.taskAttemptId(),
                        numMappers,
                        this.numPartitions,
                        conf,
                        bytes -> writeMetrics.incBytesWritten(bytes),
                        this.mapStatusLengths,
                        conf.clientPushSortMemoryThreshold(),
                        sendBufferPool);
    }

    public SortBasedShuffleWriter(
            CelebornShuffleHandle<K, V, C> handle,
            TaskContext taskContext,
            CelebornConf conf,
            ShuffleClient client,
            ShuffleWriteMetricsReporter metrics,
            SendBufferPool sendBufferPool) throws IOException {
        this(
                SparkUtils.celebornShuffleId(client, handle, taskContext, true),
                handle.dependency(),
                handle.numMappers(),
                taskContext,
                conf,
                client,
                metrics,
                sendBufferPool);
    }

    public SortBasedShuffleWriter(
            CelebornShuffleHandle<K, V, C> handle,
            TaskContext taskContext,
            CelebornConf conf,
            ShuffleClient client,
            ShuffleWriteMetricsReporter metrics,
            SendBufferPool sendBufferPool,
            SortBasedPusher pusher) throws IOException {
        this(
                SparkUtils.celebornShuffleId(client, handle, taskContext, true),
                handle.dependency(),
                handle.numMappers(),
                taskContext,
                conf,
                client,
                metrics,
                sendBufferPool,
                pusher);
    }

    /** Refreshes {@link #peakMemoryUsedBytes} from the pusher's reported peak. */
    private void updatePeakMemoryUsed() {
        peakMemoryUsedBytes = Math.max(peakMemoryUsedBytes, pusher.getPeakMemoryUsedBytes());
    }

    /** Returns the highest pusher memory usage observed so far, in bytes. */
    public long getPeakMemoryUsedBytes() {
        updatePeakMemoryUsed();
        return peakMemoryUsedBytes;
    }

    /**
     * Routes {@code records} to the fast path, the map-side-combine path or the generic path.
     *
     * @throws UnsupportedOperationException if map-side combine is requested without an aggregator
     */
    void doWrite(Iterator<Product2<K, V>> records) throws IOException {
        if (canUseFastWrite()) {
            fastWrite0(records);
        } else if (dep.mapSideCombine()) {
            if (dep.aggregator().isEmpty()) {
                throw new UnsupportedOperationException(
                        "When using map side combine, an aggregator must be specified.");
            }
            write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
        } else {
            write0(records);
        }
    }

    /** Writes all records, then flushes the pusher and reports mapper end via {@link #close()}. */
    @Override
    public void write(Iterator<Product2<K, V>> records) throws IOException {
        // NOTE(review): if doWrite throws, close() is skipped, so pusher.close() is never called
        // from this class (stop() only calls shuffleClient.cleanup). Verify that pusher resources
        // are released elsewhere on task failure.
        doWrite(records);
        close();
    }

    /**
     * Returns true when the UnsafeRow fast path applies: fast write is enabled, the serializer is an
     * {@link UnsafeRowSerializer} and the partitioner's simple class name is
     * {@code PartitionIdPassthrough}.
     */
    @VisibleForTesting
    boolean canUseFastWrite() {
        if (!unsafeRowFastWrite || !(dep.serializer() instanceof UnsafeRowSerializer)) {
            return false;
        }
        return "PartitionIdPassthrough".equals(partitioner.getClass().getSimpleName());
    }

    /**
     * Fast path: each record's key is the partition id and its value is an {@link UnsafeRow}. The
     * row bytes are inserted into the pusher without re-serialization. Oversized rows are framed as
     * a 4-byte big-endian length followed by the row bytes and pushed directly.
     */
    private void fastWrite0(Iterator<Product2<K, V>> records) throws IOException {
        SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) dep.serializer());
        while (records.hasNext()) {
            Product2<K, V> record = records.next();
            int partitionId = (Integer) record._1();
            UnsafeRow row = (UnsafeRow) record._2();
            int rowSize = row.getSizeInBytes();
            int serializedRecordSize = ROW_LENGTH_PREFIX_BYTES + rowSize;
            if (dataSize != null) {
                dataSize.add(serializedRecordSize);
            }
            if (serializedRecordSize > pushBufferMaxSize) {
                byte[] giantBuffer = new byte[serializedRecordSize];
                writeIntBigEndian(giantBuffer, rowSize);
                Platform.copyMemory(
                        row.getBaseObject(),
                        row.getBaseOffset(),
                        giantBuffer,
                        Platform.BYTE_ARRAY_OFFSET + ROW_LENGTH_PREFIX_BYTES,
                        rowSize);
                pushGiantRecord(partitionId, giantBuffer, serializedRecordSize);
            } else {
                insertRecordOrFlush(
                        row.getBaseObject(), row.getBaseOffset(), rowSize, partitionId, true);
            }
            ++tmpRecordsWritten;
        }
    }

    /**
     * Writes {@code value} into the first four bytes of {@code buffer} in big-endian order.
     *
     * <p>The previous {@code Platform.putInt(buf, off, Integer.reverseBytes(v))} produced these
     * bytes only on little-endian hosts, because {@code Platform.putInt} uses native byte order.
     * Writing the bytes explicitly yields the same layout on any host.
     */
    private static void writeIntBigEndian(byte[] buffer, int value) {
        buffer[0] = (byte) (value >>> 24);
        buffer[1] = (byte) (value >>> 16);
        buffer[2] = (byte) (value >>> 8);
        buffer[3] = (byte) value;
    }

    /**
     * Inserts one record into the pusher. If the pusher rejects it, pushes buffered data and retries
     * once.
     *
     * @throws CelebornIOException if the record is still rejected after the push
     */
    private void insertRecordOrFlush(
            Object base, long offset, int length, int partitionId, boolean isUnsafeRow)
            throws IOException {
        if (pusher.insertRecord(base, offset, length, partitionId, isUnsafeRow)) {
            return;
        }
        doPush();
        if (!pusher.insertRecord(base, offset, length, partitionId, isUnsafeRow)) {
            throw new CelebornIOException("Unable to push after switching pusher!");
        }
    }

    /** Pushes the pusher's buffered data and records the elapsed time as write time. */
    private void doPush() throws IOException {
        long start = System.nanoTime();
        pusher.pushData(true);
        writeMetrics.incWriteTime(System.nanoTime() - start);
    }

    /**
     * Generic path: serializes each key/value pair into {@link #serBuffer} and inserts it into the
     * pusher, or pushes it directly when it exceeds {@link #pushBufferMaxSize}.
     */
    private void write0(Iterator<? extends Product2<?, ?>> records) throws IOException {
        while (records.hasNext()) {
            Product2<?, ?> record = records.next();
            Object key = record._1();
            int partitionId = partitioner.getPartition(key);

            // serBuffer is reused: reset, then serialize exactly one record into it.
            serBuffer.reset();
            serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
            serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
            serOutputStream.flush();

            int serializedRecordSize = serBuffer.size();
            assert serializedRecordSize > 0;
            if (serializedRecordSize > pushBufferMaxSize) {
                pushGiantRecord(partitionId, serBuffer.getBuf(), serializedRecordSize);
            } else {
                insertRecordOrFlush(
                        serBuffer.getBuf(),
                        Platform.BYTE_ARRAY_OFFSET,
                        serializedRecordSize,
                        partitionId,
                        false);
            }
            ++tmpRecordsWritten;
        }
    }

    /**
     * Pushes the first {@code numBytes} of {@code buffer} straight to the shuffle client. The bytes
     * reported by the client are added to the partition length and to the bytes-written metric.
     */
    private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException {
        logger.debug("Push giant record, size {}.", Utils.bytesToString(numBytes));
        int bytesWritten = shuffleClient.pushData(
                shuffleId,
                mapId,
                taskContext.attemptNumber(),
                partitionId,
                buffer,
                0,
                numBytes,
                numMappers,
                numPartitions);
        mapStatusLengths[partitionId].add(bytesWritten);
        writeMetrics.incBytesWritten(bytesWritten);
    }

    /**
     * Flushes and closes the pusher, pushes merged data and calls {@code mapperEnd}. The elapsed
     * time goes to the write-time metric and the record count to the records-written metric.
     */
    private void close() throws IOException {
        logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
        long pushStartTime = System.nanoTime();
        pusher.pushData(false);
        pusher.close();
        shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber());
        writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
        writeMetrics.incRecordsWritten(tmpRecordsWritten);

        long waitStartTime = System.nanoTime();
        shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), numMappers);
        writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
    }

    /**
     * Finishes the writer. On the first call with {@code success == true}, builds and returns the
     * {@link MapStatus} from {@link #mapStatusLengths}. Later calls, and unsuccessful calls, return
     * empty. {@code shuffleClient.cleanup} runs on every call, including when this method throws.
     *
     * @throws IllegalStateException if the map status cannot be created on a successful stop
     */
    @Override
    public Option<MapStatus> stop(boolean success) {
        try {
            taskContext.taskMetrics().incPeakExecutionMemory(getPeakMemoryUsedBytes());
            if (stopping) {
                return Option.empty();
            }
            stopping = true;
            if (!success) {
                return Option.empty();
            }
            BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
            MapStatus mapStatus = SparkUtils.createMapStatus(
                    bmId, SparkUtils.unwrap(mapStatusLengths), taskContext.taskAttemptId());
            if (mapStatus == null) {
                throw new IllegalStateException(
                        "Cannot call stop(true) without having called write()");
            }
            return Option.apply(mapStatus);
        } finally {
            shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
        }
    }

    /**
     * Always throws, because Celeborn does not support Spark push-based shuffle.
     *
     * <p>NOTE(review): presumably declared without {@code @Override} so the class also compiles
     * against Spark versions whose {@code ShuffleWriter} lacks this method. Confirm before adding it.
     */
    public long[] getPartitionLengths() {
        throw new UnsupportedOperationException(
                "Celeborn is not compatible with push-based shuffle, please set spark.shuffle.push.enabled to false");
    }
}

