/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.runtime.partitioner;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
import org.apache.flink.runtime.executiongraph.InternalExecutionGraphAccessor;
import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
import org.apache.flink.streaming.runtime.partitioner.StreamPartitionerTest;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.Collector;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/**
 * Tests for {@link RescalePartitioner}.
 *
 * <p>Checks the channel selection of the partitioner itself and the wiring of the execution graph
 * produced when {@code rescale()} connects operators of different parallelism.
 */
class RescalePartitionerTest extends StreamPartitionerTest {

    /** Executor handed to the {@link DefaultExecutionGraph} built in the graph test. */
    @RegisterExtension
    private static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    @Override
    StreamPartitioner<Tuple> createPartitioner() {
        RescalePartitioner<Tuple> partitioner = new RescalePartitioner<>();
        // A rescale partitioner must not report itself as a broadcast partitioner.
        Assertions.assertThat(partitioner.isBroadcast()).isFalse();
        return partitioner;
    }

    /** With 3 channels, consecutive selections go 0, 1, 2 and then wrap back to 0. */
    @Test
    void testSelectChannelsInterval() {
        this.streamPartitioner.setup(3);
        this.assertSelectedChannel(0);
        this.assertSelectedChannel(1);
        this.assertSelectedChannel(2);
        this.assertSelectedChannel(0);
    }

    /**
     * Builds source(2) -rescale-> flatMap(4) -rescale-> sink(2). It then verifies that each map
     * subtask reads a single source partition, split evenly across the two source partitions. It
     * also verifies that each sink subtask reads a disjoint pair of map partitions.
     */
    @Test
    void testExecutionGraphGeneration() throws Exception {
        JobGraph jobGraph = createRescaleJobGraph();

        List<JobVertex> jobVertices = jobGraph.getVerticesSortedTopologicallyFromSources();
        JobVertex sourceVertex = jobVertices.get(0);
        JobVertex mapVertex = jobVertices.get(1);
        JobVertex sinkVertex = jobVertices.get(2);

        Assertions.assertThat(sourceVertex.getParallelism()).isEqualTo(2);
        Assertions.assertThat(mapVertex.getParallelism()).isEqualTo(4);
        Assertions.assertThat(sinkVertex.getParallelism()).isEqualTo(2);

        DefaultExecutionGraph eg =
                TestingDefaultExecutionGraphBuilder.newBuilder()
                        .setVertexParallelismStore(
                                SchedulerBase.computeVertexParallelismStore(jobGraph))
                        .build(EXECUTOR_RESOURCE.getExecutor());
        try {
            eg.attachJobGraph(
                    jobVertices,
                    UnregisteredMetricGroups.createUnregisteredJobManagerJobMetricGroup());
        } catch (JobException e) {
            Assertions.fail("Building ExecutionGraph failed", e);
        }

        ExecutionJobVertex execSourceVertex = eg.getJobVertex(sourceVertex.getID());
        ExecutionJobVertex execMapVertex = eg.getJobVertex(mapVertex.getID());
        ExecutionJobVertex execSinkVertex = eg.getJobVertex(sinkVertex.getID());

        Assertions.assertThat(execSourceVertex.getInputs()).isEmpty();
        assertMapConsumesSourceEvenly(execMapVertex, sourceVertex);
        assertSinkConsumesDisjointMapPartitions(execSinkVertex, mapVertex);
    }

    /**
     * Builds the job graph for a no-op source (parallelism 2), a rescaled no-op flatMap (default
     * parallelism 4) and a rescaled print sink (parallelism 2).
     *
     * <p>Declared {@code static} so the anonymous user functions do not hold a reference to the
     * enclosing test instance.
     */
    private static JobGraph createRescaleJobGraph() {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);

        DataStreamSource<String> text =
                env.addSource(
                                new ParallelSourceFunction<String>() {
                                    private static final long serialVersionUID =
                                            7772338606389180774L;

                                    @Override
                                    public void run(SourceFunction.SourceContext<String> ctx)
                                            throws Exception {}

                                    @Override
                                    public void cancel() {}
                                })
                        .setParallelism(2);

        SingleOutputStreamOperator<Tuple2<String, Integer>> counts =
                text.rescale()
                        .flatMap(
                                new FlatMapFunction<String, Tuple2<String, Integer>>() {
                                    private static final long serialVersionUID =
                                            -5255930322161596829L;

                                    @Override
                                    public void flatMap(
                                            String value, Collector<Tuple2<String, Integer>> out)
                                            throws Exception {}
                                });

        counts.rescale().print().setParallelism(2);
        return env.getStreamGraph().getJobGraph();
    }

    /**
     * Asserts that the map vertex has one input and parallelism 4. It also asserts that every map
     * subtask consumes exactly one partition produced by {@code sourceVertex}, and that exactly two
     * distinct source partitions are consumed, each by two map subtasks.
     */
    private static void assertMapConsumesSourceEvenly(
            ExecutionJobVertex execMapVertex, JobVertex sourceVertex) {
        Assertions.assertThat(execMapVertex.getInputs()).hasSize(1);
        Assertions.assertThat(execMapVertex.getParallelism()).isEqualTo(4);

        // source partition number -> number of map subtasks consuming it
        Map<Integer, Integer> mapInputPartitionCounts = new HashMap<>();
        for (ExecutionVertex mapTaskVertex : execMapVertex.getTaskVertices()) {
            Assertions.assertThat(mapTaskVertex.getNumberOfInputs()).isOne();
            Assertions.assertThat(mapTaskVertex.getConsumedPartitionGroup(0)).hasSize(1);

            IntermediateResultPartitionID consumedPartitionId =
                    mapTaskVertex.getConsumedPartitionGroup(0).getFirst();
            Assertions.assertThat(
                            mapTaskVertex
                                    .getExecutionGraphAccessor()
                                    .getResultPartitionOrThrow(consumedPartitionId)
                                    .getProducer()
                                    .getJobvertexId())
                    .isEqualTo(sourceVertex.getID());

            mapInputPartitionCounts.merge(
                    consumedPartitionId.getPartitionNumber(), 1, Integer::sum);
        }

        Assertions.assertThat(mapInputPartitionCounts).hasSize(2);
        for (int count : mapInputPartitionCounts.values()) {
            Assertions.assertThat(count).isEqualTo(2);
        }
    }

    /**
     * Asserts that the sink vertex has one input and parallelism 2. It also asserts that every sink
     * subtask consumes two partitions produced by {@code mapVertex}, that no map partition is
     * consumed by more than one sink subtask, and that four distinct map partitions are consumed
     * overall.
     */
    private static void assertSinkConsumesDisjointMapPartitions(
            ExecutionJobVertex execSinkVertex, JobVertex mapVertex) {
        Assertions.assertThat(execSinkVertex.getInputs()).hasSize(1);
        Assertions.assertThat(execSinkVertex.getParallelism()).isEqualTo(2);

        InternalExecutionGraphAccessor executionGraphAccessor = execSinkVertex.getGraph();
        Set<Integer> mapSubpartitions = new HashSet<>();
        for (ExecutionVertex sinkTaskVertex : execSinkVertex.getTaskVertices()) {
            Assertions.assertThat(sinkTaskVertex.getNumberOfInputs()).isOne();
            Assertions.assertThat(sinkTaskVertex.getConsumedPartitionGroup(0)).hasSize(2);

            for (IntermediateResultPartitionID consumedPartitionId :
                    sinkTaskVertex.getConsumedPartitionGroup(0)) {
                IntermediateResultPartition consumedPartition =
                        executionGraphAccessor.getResultPartitionOrThrow(consumedPartitionId);
                Assertions.assertThat(consumedPartition.getProducer().getJobvertexId())
                        .isEqualTo(mapVertex.getID());

                int partitionNumber = consumedPartition.getPartitionNumber();
                // A map partition seen twice would mean two sink subtasks share it.
                Assertions.assertThat(mapSubpartitions).doesNotContain(partitionNumber);
                mapSubpartitions.add(partitionNumber);
            }
        }
        Assertions.assertThat(mapSubpartitions).hasSize(4);
    }
}

