/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.executiongraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartitionTest;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/** Tests for the edge-building logic in {@link EdgeManagerBuildUtil}. */
class EdgeManagerBuildUtilTest {
    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    /** (upstream, downstream) parallelism pairs exercised by the max-edges tests. */
    private static final int[][] PARALLELISM_PAIRS = {
        {17, 17}, {17, 23}, {17, 34}, {34, 17}, {23, 17}
    };

    @Test
    void testGetMaxNumEdgesToTargetInPointwiseConnection() throws Exception {
        testGetMaxNumEdgesToTargetForAllPairs(DistributionPattern.POINTWISE);
    }

    @Test
    void testGetMaxNumEdgesToTargetInAllToAllConnection() throws Exception {
        testGetMaxNumEdgesToTargetForAllPairs(DistributionPattern.ALL_TO_ALL);
    }

    @Test
    void testConnectAllToAll() throws Exception {
        int upstream = 3;
        int downstream = 2;
        // The graph is built as a dynamic graph (last argument true); its job vertices are then
        // initialized explicitly with a hand-written JobVertexInputInfo.
        // NOTE(review): the pattern passed here is POINTWISE although this test checks an
        // all-to-all shape; every consumer below reads the full partition range, so the
        // resulting groups are all-to-all either way - confirm the pattern is intentional.
        ExecutionGraph eg =
                setupExecutionGraph(upstream, downstream, DistributionPattern.POINTWISE, true);

        // Every consumer subtask reads the full partition range [0, upstream - 1].
        List<IndexRange> partitionRanges = new ArrayList<>(downstream);
        for (int i = 0; i < downstream; ++i) {
            partitionRanges.add(new IndexRange(0, upstream - 1));
        }
        Pair<ExecutionJobVertex, ExecutionJobVertex> vertices =
                initializeProducerAndConsumer(eg, partitionRanges);
        ExecutionJobVertex producer = vertices.getLeft();
        ExecutionJobVertex consumer = vertices.getRight();

        IntermediateResult result =
                Objects.requireNonNull(eg.getJobVertex(producer.getJobVertexId()))
                        .getProducedDataSets()[0];
        IntermediateResultPartition partition1 = result.getPartitions()[0];
        IntermediateResultPartition partition2 = result.getPartitions()[1];
        IntermediateResultPartition partition3 = result.getPartitions()[2];
        ExecutionVertex vertex1 = consumer.getTaskVertices()[0];
        ExecutionVertex vertex2 = consumer.getTaskVertices()[1];

        // All partitions share a single consumer vertex group containing every consumer.
        ConsumerVertexGroup consumerVertexGroup = partition1.getConsumerVertexGroups().get(0);
        Assertions.assertThat(consumerVertexGroup)
                .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID());
        Assertions.assertThat(partition2.getConsumerVertexGroups().get(0))
                .isEqualTo(consumerVertexGroup);
        Assertions.assertThat(partition3.getConsumerVertexGroups().get(0))
                .isEqualTo(consumerVertexGroup);

        // All consumers share a single consumed partition group containing every partition.
        ConsumedPartitionGroup consumedPartitionGroup = vertex1.getConsumedPartitionGroup(0);
        Assertions.assertThat(consumedPartitionGroup)
                .containsExactlyInAnyOrder(
                        partition1.getPartitionId(),
                        partition2.getPartitionId(),
                        partition3.getPartitionId());
        Assertions.assertThat(vertex2.getConsumedPartitionGroup(0))
                .isEqualTo(consumedPartitionGroup);

        // The two groups must reference each other.
        Assertions.assertThat(consumerVertexGroup.getConsumedPartitionGroup())
                .isEqualTo(consumedPartitionGroup);
        Assertions.assertThat(consumedPartitionGroup.getConsumerVertexGroup())
                .isEqualTo(consumerVertexGroup);
    }

    @Test
    void testConnectPointwise() throws Exception {
        int upstream = 4;
        int downstream = 4;
        ExecutionGraph eg =
                setupExecutionGraph(upstream, downstream, DistributionPattern.POINTWISE, true);

        // Consumer subtasks 0 and 1 both read partition 0, subtask 2 reads partitions 1..2,
        // and subtask 3 reads partition 3.
        List<IndexRange> partitionRanges =
                Arrays.asList(
                        new IndexRange(0, 0),
                        new IndexRange(0, 0),
                        new IndexRange(1, 2),
                        new IndexRange(3, 3));
        Pair<ExecutionJobVertex, ExecutionJobVertex> vertices =
                initializeProducerAndConsumer(eg, partitionRanges);
        ExecutionJobVertex producer = vertices.getLeft();
        ExecutionJobVertex consumer = vertices.getRight();

        IntermediateResult result =
                Objects.requireNonNull(eg.getJobVertex(producer.getJobVertexId()))
                        .getProducedDataSets()[0];
        IntermediateResultPartition partition1 = result.getPartitions()[0];
        IntermediateResultPartition partition2 = result.getPartitions()[1];
        IntermediateResultPartition partition3 = result.getPartitions()[2];
        IntermediateResultPartition partition4 = result.getPartitions()[3];
        ExecutionVertex vertex1 = consumer.getTaskVertices()[0];
        ExecutionVertex vertex2 = consumer.getTaskVertices()[1];
        ExecutionVertex vertex3 = consumer.getTaskVertices()[2];
        ExecutionVertex vertex4 = consumer.getTaskVertices()[3];

        // Consumer vertex groups, one per distinct partition range.
        ConsumerVertexGroup consumerVertexGroup1 = partition1.getConsumerVertexGroups().get(0);
        ConsumerVertexGroup consumerVertexGroup2 = partition2.getConsumerVertexGroups().get(0);
        ConsumerVertexGroup consumerVertexGroup3 = partition4.getConsumerVertexGroups().get(0);
        Assertions.assertThat(consumerVertexGroup1)
                .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID());
        Assertions.assertThat(consumerVertexGroup2).containsExactlyInAnyOrder(vertex3.getID());
        Assertions.assertThat(partition3.getConsumerVertexGroups().get(0))
                .isEqualTo(consumerVertexGroup2);
        Assertions.assertThat(consumerVertexGroup3).containsExactlyInAnyOrder(vertex4.getID());

        // Consumed partition groups, one per distinct partition range.
        ConsumedPartitionGroup consumedPartitionGroup1 = vertex1.getConsumedPartitionGroup(0);
        ConsumedPartitionGroup consumedPartitionGroup2 = vertex3.getConsumedPartitionGroup(0);
        ConsumedPartitionGroup consumedPartitionGroup3 = vertex4.getConsumedPartitionGroup(0);
        Assertions.assertThat(consumedPartitionGroup1)
                .containsExactlyInAnyOrder(partition1.getPartitionId());
        Assertions.assertThat(vertex2.getConsumedPartitionGroup(0))
                .isEqualTo(consumedPartitionGroup1);
        Assertions.assertThat(consumedPartitionGroup2)
                .containsExactlyInAnyOrder(
                        partition2.getPartitionId(), partition3.getPartitionId());
        Assertions.assertThat(consumedPartitionGroup3)
                .containsExactlyInAnyOrder(partition4.getPartitionId());

        // Each consumer/consumed group pair must reference each other.
        Assertions.assertThat(consumerVertexGroup1.getConsumedPartitionGroup())
                .isEqualTo(consumedPartitionGroup1);
        Assertions.assertThat(consumedPartitionGroup1.getConsumerVertexGroup())
                .isEqualTo(consumerVertexGroup1);
        Assertions.assertThat(consumerVertexGroup2.getConsumedPartitionGroup())
                .isEqualTo(consumedPartitionGroup2);
        Assertions.assertThat(consumedPartitionGroup2.getConsumerVertexGroup())
                .isEqualTo(consumerVertexGroup2);
        Assertions.assertThat(consumerVertexGroup3.getConsumedPartitionGroup())
                .isEqualTo(consumedPartitionGroup3);
        Assertions.assertThat(consumedPartitionGroup3.getConsumerVertexGroup())
                .isEqualTo(consumerVertexGroup3);
    }

    /**
     * Runs {@link #testGetMaxNumEdgesToTarget} for every entry of {@link #PARALLELISM_PAIRS}.
     *
     * @param pattern the distribution pattern connecting the two vertices
     */
    private void testGetMaxNumEdgesToTargetForAllPairs(DistributionPattern pattern)
            throws Exception {
        for (int[] parallelisms : PARALLELISM_PAIRS) {
            testGetMaxNumEdgesToTarget(parallelisms[0], parallelisms[1], pattern);
        }
    }

    /**
     * Checks that {@link EdgeManagerBuildUtil#computeMaxEdgesToTargetExecutionVertex} matches the
     * largest number of edges actually built, in both directions.
     *
     * <p>The actual upstream maximum is the size of each producer's first consumer vertex group.
     * The actual downstream maximum is the size of each consumer's first consumed partition
     * group.
     *
     * @param upstream parallelism of the producer vertex
     * @param downstream parallelism of the consumer vertex
     * @param pattern the distribution pattern connecting the two vertices
     */
    private void testGetMaxNumEdgesToTarget(
            int upstream, int downstream, DistributionPattern pattern) throws Exception {
        Pair<ExecutionJobVertex, ExecutionJobVertex> pair =
                setupExecutionGraph(upstream, downstream, pattern);
        ExecutionJobVertex upstreamEJV = pair.getLeft();
        ExecutionJobVertex downstreamEJV = pair.getRight();

        int calculatedMaxForUpstream =
                EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
                        upstream, downstream, pattern);
        int actualMaxForUpstream = -1;
        for (ExecutionVertex ev : upstreamEJV.getTaskVertices()) {
            Assertions.assertThat(ev.getProducedPartitions()).hasSize(1);
            IntermediateResultPartition partition =
                    ev.getProducedPartitions().values().iterator().next();
            actualMaxForUpstream =
                    Math.max(
                            actualMaxForUpstream,
                            partition.getConsumerVertexGroups().get(0).size());
        }
        Assertions.assertThat(actualMaxForUpstream)
                .as(
                        "max edges per upstream vertex (upstream=%d, downstream=%d, pattern=%s)",
                        upstream, downstream, pattern)
                .isEqualTo(calculatedMaxForUpstream);

        int calculatedMaxForDownstream =
                EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
                        downstream, upstream, pattern);
        int actualMaxForDownstream = -1;
        for (ExecutionVertex ev : downstreamEJV.getTaskVertices()) {
            Assertions.assertThat(ev.getNumberOfInputs()).isEqualTo(1);
            actualMaxForDownstream =
                    Math.max(actualMaxForDownstream, ev.getConsumedPartitionGroup(0).size());
        }
        Assertions.assertThat(actualMaxForDownstream)
                .as(
                        "max edges per downstream vertex (upstream=%d, downstream=%d, pattern=%s)",
                        upstream, downstream, pattern)
                .isEqualTo(calculatedMaxForDownstream);
    }

    /**
     * Initializes the producer and then the consumer job vertex of a two-vertex graph built by
     * {@link #setupExecutionGraph(int, int, DistributionPattern, boolean)}.
     *
     * <p>Consumer subtask {@code i} is given {@code partitionRanges.get(i)} as its partition
     * range.
     *
     * @param eg the graph to initialize; expected to contain exactly producer then consumer
     * @param partitionRanges one partition range per consumer subtask
     * @return the producer (left) and consumer (right) job vertices
     */
    private Pair<ExecutionJobVertex, ExecutionJobVertex> initializeProducerAndConsumer(
            ExecutionGraph eg, List<IndexRange> partitionRanges) throws Exception {
        List<ExecutionVertexInputInfo> executionVertexInputInfos =
                new ArrayList<>(partitionRanges.size());
        for (int i = 0; i < partitionRanges.size(); ++i) {
            // The subpartition range is fixed to (0, 0); presumably the edge manager does not
            // read it - TODO confirm.
            executionVertexInputInfos.add(
                    new ExecutionVertexInputInfo(i, partitionRanges.get(i), new IndexRange(0, 0)));
        }
        JobVertexInputInfo jobVertexInputInfo = new JobVertexInputInfo(executionVertexInputInfos);

        Iterator<ExecutionJobVertex> vertexIterator = eg.getVerticesTopologically().iterator();
        ExecutionJobVertex producer = vertexIterator.next();
        ExecutionJobVertex consumer = vertexIterator.next();

        eg.initializeJobVertex(
                producer,
                1L,
                Collections.emptyMap(),
                UnregisteredMetricGroups.createUnregisteredJobManagerJobMetricGroup());
        eg.initializeJobVertex(
                consumer,
                1L,
                Collections.singletonMap(
                        producer.getProducedDataSets()[0].getId(), jobVertexInputInfo),
                UnregisteredMetricGroups.createUnregisteredJobManagerJobMetricGroup());
        return Pair.of(producer, consumer);
    }

    /**
     * Builds a non-dynamic two-vertex graph and returns its job vertices in topological order.
     *
     * @return the upstream (left) and downstream (right) job vertices
     */
    private Pair<ExecutionJobVertex, ExecutionJobVertex> setupExecutionGraph(
            int upstream, int downstream, DistributionPattern pattern) throws Exception {
        // A typed iterator is required: with a raw Iterator, Pair.of(...) would be
        // Pair<Object, Object> and would not match the declared return type.
        Iterator<ExecutionJobVertex> jobVertices =
                setupExecutionGraph(upstream, downstream, pattern, false)
                        .getVerticesTopologically()
                        .iterator();
        ExecutionJobVertex upstreamVertex = jobVertices.next();
        ExecutionJobVertex downstreamVertex = jobVertices.next();
        return Pair.of(upstreamVertex, downstreamVertex);
    }

    /**
     * Builds a graph in which "vertex2" consumes a PIPELINED data set produced by "vertex1" and
     * attaches both vertices.
     *
     * @param upstream parallelism of "vertex1"
     * @param downstream parallelism of "vertex2"
     * @param pattern the distribution pattern of the connecting edge
     * @param isDynamicGraph whether to build via {@code buildDynamicGraph} instead of
     *     {@code build}
     * @return the attached execution graph
     */
    private ExecutionGraph setupExecutionGraph(
            int upstream, int downstream, DistributionPattern pattern, boolean isDynamicGraph)
            throws Exception {
        JobVertex v1 = new JobVertex("vertex1");
        JobVertex v2 = new JobVertex("vertex2");
        v1.setParallelism(upstream);
        v2.setParallelism(downstream);
        v1.setInvokableClass(AbstractInvokable.class);
        v2.setInvokableClass(AbstractInvokable.class);
        v2.connectNewDataSetAsInput(v1, pattern, ResultPartitionType.PIPELINED);

        List<JobVertex> ordered = new ArrayList<>(Arrays.asList(v1, v2));
        // NOTE(review): 128 is passed as the last argument; presumably a default max
        // parallelism for the dynamic case - confirm against the helper.
        TestingDefaultExecutionGraphBuilder builder =
                TestingDefaultExecutionGraphBuilder.newBuilder()
                        .setVertexParallelismStore(
                                IntermediateResultPartitionTest
                                        .computeVertexParallelismStoreConsideringDynamicGraph(
                                                ordered, isDynamicGraph, 128));
        DefaultExecutionGraph eg =
                isDynamicGraph
                        ? builder.buildDynamicGraph(EXECUTOR_RESOURCE.getExecutor())
                        : builder.build(EXECUTOR_RESOURCE.getExecutor());
        eg.attachJobGraph(
                ordered, UnregisteredMetricGroups.createUnregisteredJobManagerJobMetricGroup());
        return eg;
    }
}

