/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
import org.apache.flink.runtime.shuffle.ShuffleMaster;
import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor;
import org.apache.flink.util.Preconditions;

/**
 * Utilities for filling in the network memory of a {@link SlotSharingGroup}'s resource profile
 * from the shuffle memory that a {@link ShuffleMaster} computes for each task in the group.
 */
public class SsgNetworkMemoryCalculationUtils {

    /**
     * Replaces the resource profile of {@code ssg} with a copy whose network memory is the sum of
     * the shuffle memory required by every job vertex in the group.
     *
     * <p>Nothing is changed if the current profile is {@link ResourceProfile#UNKNOWN} or already
     * has a non-zero network memory. CPU cores, task heap/off-heap memory, managed memory and
     * extended resources are copied unchanged from the original profile.
     *
     * @param ssg the slot sharing group whose resource profile may be replaced
     * @param ejvs lookup from a job vertex id to its execution job vertex; it is applied to every
     *     vertex of {@code ssg} and to the consumer vertex of each of their produced data sets
     * @param shuffleMaster computes the shuffle memory required by a single task
     */
    static void enrichNetworkMemory(
            SlotSharingGroup ssg,
            Function<JobVertexID, ExecutionJobVertex> ejvs,
            ShuffleMaster<?> shuffleMaster) {
        ResourceProfile original = ssg.getResourceProfile();
        // Only profiles that are known and carry no network memory yet are enriched.
        if (original.equals(ResourceProfile.UNKNOWN)
                || !original.getNetworkMemory().equals(MemorySize.ZERO)) {
            return;
        }

        MemorySize networkMemory = MemorySize.ZERO;
        for (JobVertexID jvId : ssg.getJobVertexIds()) {
            ExecutionJobVertex ejv = ejvs.apply(jvId);
            TaskInputsOutputsDescriptor desc = buildTaskInputsOutputsDescriptor(ejv, ejvs);
            MemorySize requiredNetworkMemory = shuffleMaster.computeShuffleMemorySizeForTask(desc);
            networkMemory = networkMemory.add(requiredNetworkMemory);
        }

        ResourceProfile enriched =
                ResourceProfile.newBuilder()
                        .setCpuCores(original.getCpuCores())
                        .setTaskHeapMemory(original.getTaskHeapMemory())
                        .setTaskOffHeapMemory(original.getTaskOffHeapMemory())
                        .setManagedMemory(original.getManagedMemory())
                        .setNetworkMemory(networkMemory)
                        .setExtendedResources(original.getExtendedResources().values())
                        .build();
        ssg.setResourceProfile(enriched);
    }

    /**
     * Builds the descriptor handed to the shuffle master for one task of {@code ejv}: per-input
     * maximum channel counts, per-output maximum subpartition counts, and the result partition
     * type of each produced data set.
     */
    private static TaskInputsOutputsDescriptor buildTaskInputsOutputsDescriptor(
            ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) {
        Map<IntermediateDataSetID, Integer> maxInputChannelNums = getMaxInputChannelNums(ejv);
        Map<IntermediateDataSetID, Integer> maxSubpartitionNums =
                getMaxSubpartitionNums(ejv, ejvs);
        Map<IntermediateDataSetID, ResultPartitionType> partitionTypes =
                getPartitionTypes(ejv.getJobVertex());
        return TaskInputsOutputsDescriptor.from(
                maxInputChannelNums, maxSubpartitionNums, partitionTypes);
    }

    /**
     * Computes, for every result consumed by {@code ejv}, the maximum number of edges to one of
     * its execution vertices. The value comes from
     * {@link EdgeManagerBuildUtil#computeMaxEdgesToTargetExecutionVertex}, using {@code ejv}'s
     * parallelism, the result's number of assigned partitions and the edge's distribution
     * pattern.
     *
     * @throws IllegalStateException if the i-th job edge and the i-th consumed result refer to
     *     different data sets
     */
    private static Map<IntermediateDataSetID, Integer> getMaxInputChannelNums(
            ExecutionJobVertex ejv) {
        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
        List<JobEdge> inputEdges = ejv.getJobVertex().getInputs();
        List<IntermediateResult> consumedResults = ejv.getInputs();

        // Job edges and consumed results are matched positionally, so verify they line up.
        for (int i = 0; i < inputEdges.size(); ++i) {
            JobEdge inputEdge = inputEdges.get(i);
            IntermediateResult consumedResult = consumedResults.get(i);
            Preconditions.checkState(
                    consumedResult.getId().equals(inputEdge.getSourceId()),
                    "Input %s of job vertex %s is result %s, but its job edge has source %s.",
                    i,
                    ejv.getJobVertex().getID(),
                    consumedResult.getId(),
                    inputEdge.getSourceId());

            int maxNum =
                    EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
                            ejv.getParallelism(),
                            consumedResult.getNumberOfAssignedPartitions(),
                            inputEdge.getDistributionPattern());
            ret.put(consumedResult.getId(), maxNum);
        }
        return ret;
    }

    /**
     * Computes, for every data set produced by {@code ejv}, the maximum number of subpartitions
     * of one producer task. The value comes from
     * {@link EdgeManagerBuildUtil#computeMaxEdgesToTargetExecutionVertex}, using {@code ejv}'s
     * parallelism, the single consumer vertex's parallelism and the output edge's distribution
     * pattern.
     *
     * @throws IllegalStateException if a produced data set does not have exactly one consumer
     */
    private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNums(
            ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) {
        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
        for (IntermediateDataSet producedDataSet : ejv.getJobVertex().getProducedDataSets()) {
            List<JobEdge> consumers = producedDataSet.getConsumers();
            Preconditions.checkState(
                    consumers.size() == 1,
                    "Currently a result should have exactly one consumer job vertex. "
                            + "Result %s has %s.",
                    producedDataSet.getId(),
                    consumers.size());

            JobEdge outputEdge = consumers.get(0);
            ExecutionJobVertex consumerJobVertex = ejvs.apply(outputEdge.getTarget().getID());
            int maxNum =
                    EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
                            ejv.getParallelism(),
                            consumerJobVertex.getParallelism(),
                            outputEdge.getDistributionPattern());
            ret.put(producedDataSet.getId(), maxNum);
        }
        return ret;
    }

    /**
     * Maps each data set produced by {@code jv} to its result partition type. The first entry
     * wins if the same id occurs more than once.
     */
    private static Map<IntermediateDataSetID, ResultPartitionType> getPartitionTypes(
            JobVertex jv) {
        Map<IntermediateDataSetID, ResultPartitionType> ret = new HashMap<>();
        jv.getProducedDataSets().forEach(ds -> ret.putIfAbsent(ds.getId(), ds.getResultType()));
        return ret;
    }

    /** Static utility class; not meant to be instantiated. */
    private SsgNetworkMemoryCalculationUtils() {}
}

