/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.executiongraph;

import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeoutException;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.blob.BlobWriter;
import org.apache.flink.runtime.blob.TestingBlobWriter;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutorService;
import org.apache.flink.runtime.deployment.CachedShuffleDescriptors;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactoryTest;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.partition.NoOpJobMasterPartitionTracker;
import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.runtime.scheduler.SchedulerTestingUtils;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.function.RunnableWithException;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

class RemoveCachedShuffleDescriptorTest {
    private static final int PARALLELISM = 4;
    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE = TestingUtils.defaultExecutorExtension();
    private ScheduledExecutorService scheduledExecutorService;
    private ComponentMainThreadExecutor mainThreadExecutor;
    private ManuallyTriggeredScheduledExecutorService ioExecutor;

    RemoveCachedShuffleDescriptorTest() {
    }

    @BeforeEach
    void setup() {
        this.scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
        this.mainThreadExecutor = ComponentMainThreadExecutorServiceAdapter.forSingleThreadExecutor(this.scheduledExecutorService);
        this.ioExecutor = new ManuallyTriggeredScheduledExecutorService();
    }

    @AfterEach
    void teardown() {
        if (this.scheduledExecutorService != null) {
            this.scheduledExecutorService.shutdownNow();
        }
    }

    @Test
    void testRemoveNonOffloadedCacheForAllToAllEdgeAfterFinished() throws Exception {
        this.testRemoveCacheForAllToAllEdgeAfterFinished(new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
    }

    @Test
    void testRemoveOffloadedCacheForAllToAllEdgeAfterFinished() throws Exception {
        this.testRemoveCacheForAllToAllEdgeAfterFinished(new TestingBlobWriter(0), 4, 3);
    }

    private void testRemoveCacheForAllToAllEdgeAfterFinished(TestingBlobWriter blobWriter, int expectedBefore, int expectedAfter) throws Exception {
        JobID jobId = new JobID();
        JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", 4);
        JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", 4);
        SchedulerBase scheduler = this.createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.ALL_TO_ALL, blobWriter);
        ExecutionGraph executionGraph = scheduler.getExecutionGraph();
        this.executionInMainThread(() -> {
            Object[] shuffleDescriptors = TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors(RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, v2).getAllSerializedShuffleDescriptors(), jobId, blobWriter);
            Assertions.assertThat((Object[])shuffleDescriptors).hasSize(4);
            Assertions.assertThat((int)blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
        });
        CompletableFuture.runAsync(() -> ExecutionGraphTestUtils.finishJobVertex(executionGraph, v2.getID()), (Executor)this.mainThreadExecutor).join();
        this.ioExecutor.triggerAll();
        this.executionInMainThread(() -> {
            Assertions.assertThat((Object)RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, v2)).isNull();
            Assertions.assertThat((int)blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
        });
    }

    @Test
    void testRemoveNonOffloadedCacheForAllToAllEdgeAfterFailover() throws Exception {
        this.testRemoveCacheForAllToAllEdgeAfterFailover(new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
    }

    @Test
    void testRemoveOffloadedCacheForAllToAllEdgeAfterFailover() throws Exception {
        this.testRemoveCacheForAllToAllEdgeAfterFailover(new TestingBlobWriter(0), 4, 3);
    }

    private void testRemoveCacheForAllToAllEdgeAfterFailover(TestingBlobWriter blobWriter, int expectedBefore, int expectedAfter) throws Exception {
        JobID jobId = new JobID();
        JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", 4);
        JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", 4);
        SchedulerBase scheduler = this.createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.ALL_TO_ALL, blobWriter);
        ExecutionGraph executionGraph = scheduler.getExecutionGraph();
        this.executionInMainThread(() -> {
            Object[] shuffleDescriptors = TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors(RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, v2).getAllSerializedShuffleDescriptors(), jobId, blobWriter);
            Assertions.assertThat((Object[])shuffleDescriptors).hasSize(4);
            Assertions.assertThat((int)blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
        });
        this.triggerGlobalFailoverAndComplete(scheduler, v1, v2);
        this.ioExecutor.triggerAll();
        this.executionInMainThread(() -> {
            Assertions.assertThat((Object)RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, v2)).isNull();
            Assertions.assertThat((int)blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
        });
    }

    @Test
    void testRemoveNonOffloadedCacheForPointwiseEdgeAfterFinished() throws Exception {
        this.testRemoveCacheForPointwiseEdgeAfterFinished(new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
    }

    @Test
    void testRemoveOffloadedCacheForPointwiseEdgeAfterFinished() throws Exception {
        this.testRemoveCacheForPointwiseEdgeAfterFinished(new TestingBlobWriter(0), 7, 6);
    }

    private void testRemoveCacheForPointwiseEdgeAfterFinished(TestingBlobWriter blobWriter, int expectedBefore, int expectedAfter) throws Exception {
        JobID jobId = new JobID();
        JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", 4);
        JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", 4);
        SchedulerBase scheduler = this.createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.POINTWISE, blobWriter);
        ExecutionGraph executionGraph = scheduler.getExecutionGraph();
        this.executionInMainThread(() -> {
            Object[] shuffleDescriptors = TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors(RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, v2).getAllSerializedShuffleDescriptors(), jobId, blobWriter);
            Assertions.assertThat((Object[])shuffleDescriptors).hasSize(1);
            Assertions.assertThat((int)blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
        });
        ExecutionVertex ev21 = Objects.requireNonNull(executionGraph.getJobVertex(v2.getID())).getTaskVertices()[0];
        CompletableFuture.runAsync(() -> ExecutionGraphTestUtils.finishExecutionVertex(executionGraph, ev21), (Executor)this.mainThreadExecutor).join();
        this.ioExecutor.triggerAll();
        this.executionInMainThread(() -> {
            Assertions.assertThat((Object)RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, v2, 0)).isNull();
            Object[] shuffleDescriptorsForOtherVertex = TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors(RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, v2, 1).getAllSerializedShuffleDescriptors(), jobId, blobWriter);
            Assertions.assertThat((Object[])shuffleDescriptorsForOtherVertex).hasSize(1);
            Assertions.assertThat((int)blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
        });
    }

    @Test
    void testRemoveNonOffloadedCacheForPointwiseEdgeAfterFailover() throws Exception {
        this.testRemoveCacheForPointwiseEdgeAfterFailover(new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
    }

    @Test
    void testRemoveOffloadedCacheForPointwiseEdgeAfterFailover() throws Exception {
        this.testRemoveCacheForPointwiseEdgeAfterFailover(new TestingBlobWriter(0), 7, 6);
    }

    private void testRemoveCacheForPointwiseEdgeAfterFailover(TestingBlobWriter blobWriter, int expectedBefore, int expectedAfter) throws Exception {
        JobID jobId = new JobID();
        JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", 4);
        JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", 4);
        SchedulerBase scheduler = this.createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.POINTWISE, blobWriter);
        ExecutionGraph executionGraph = scheduler.getExecutionGraph();
        this.executionInMainThread(() -> {
            Object[] shuffleDescriptors = TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors(RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, v2).getAllSerializedShuffleDescriptors(), jobId, blobWriter);
            Assertions.assertThat((Object[])shuffleDescriptors).hasSize(1);
            Assertions.assertThat((int)blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
        });
        this.triggerExceptionAndComplete(executionGraph, v1, v2);
        this.ioExecutor.triggerAll();
        this.executionInMainThread(() -> {
            Assertions.assertThat((Object)RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, v2, 0)).isNull();
            Object[] shuffleDescriptorsForOtherVertex = TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors(RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, v2, 1).getAllSerializedShuffleDescriptors(), jobId, blobWriter);
            Assertions.assertThat((Object[])shuffleDescriptorsForOtherVertex).hasSize(1);
            Assertions.assertThat((int)blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
        });
    }

    private SchedulerBase createSchedulerAndDeploy(JobID jobId, JobVertex v1, JobVertex v2, DistributionPattern distributionPattern, BlobWriter blobWriter) throws Exception {
        return SchedulerTestingUtils.createSchedulerAndDeploy(false, jobId, v1, new JobVertex[]{v2}, distributionPattern, blobWriter, this.mainThreadExecutor, this.ioExecutor, NoOpJobMasterPartitionTracker.INSTANCE, (ScheduledExecutorService)EXECUTOR_RESOURCE.getExecutor());
    }

    private void triggerGlobalFailoverAndComplete(SchedulerBase scheduler, JobVertex upstream, JobVertex downstream) throws TimeoutException {
        Exception t = new Exception();
        ExecutionGraph executionGraph = scheduler.getExecutionGraph();
        CompletableFuture.runAsync(() -> {
            scheduler.handleGlobalFailure(t);
            for (ExecutionVertex ev : Objects.requireNonNull(executionGraph.getJobVertex(downstream.getID())).getTaskVertices()) {
                ev.getCurrentExecutionAttempt().completeCancelling();
            }
        }, (Executor)this.mainThreadExecutor).join();
        for (ExecutionVertex ev : Objects.requireNonNull(executionGraph.getJobVertex(upstream.getID())).getTaskVertices()) {
            ExecutionGraphTestUtils.waitUntilExecutionVertexState(ev, ExecutionState.DEPLOYING, 1000L);
        }
    }

    private void triggerExceptionAndComplete(ExecutionGraph executionGraph, JobVertex upstream, JobVertex downstream) throws TimeoutException {
        ExecutionVertex ev11 = Objects.requireNonNull(executionGraph.getJobVertex(upstream.getID())).getTaskVertices()[0];
        ExecutionVertex ev21 = Objects.requireNonNull(executionGraph.getJobVertex(downstream.getID())).getTaskVertices()[0];
        CompletableFuture.runAsync(() -> ev21.markFailed((Throwable)new PartitionNotFoundException(new ResultPartitionID())), (Executor)this.mainThreadExecutor).join();
        ExecutionGraphTestUtils.waitUntilExecutionVertexState(ev11, ExecutionState.DEPLOYING, 1000L);
    }

    private void executionInMainThread(RunnableWithException runnableWithException) {
        CompletableFuture.runAsync(() -> Assertions.assertThatNoException().isThrownBy(() -> ((RunnableWithException)runnableWithException).run()), (Executor)this.mainThreadExecutor).join();
    }

    private static CachedShuffleDescriptors getConsumedCachedShuffleDescriptor(ExecutionGraph executionGraph, JobVertex vertex) {
        return RemoveCachedShuffleDescriptorTest.getConsumedCachedShuffleDescriptor(executionGraph, vertex, 0);
    }

    private static CachedShuffleDescriptors getConsumedCachedShuffleDescriptor(ExecutionGraph executionGraph, JobVertex vertex, int taskNum) {
        ExecutionJobVertex ejv = executionGraph.getJobVertex(vertex.getID());
        List consumedResults = Objects.requireNonNull(ejv).getInputs();
        IntermediateResult consumedResult = (IntermediateResult)consumedResults.get(0);
        return consumedResult.getCachedShuffleDescriptors(ejv.getTaskVertices()[taskNum].getConsumedPartitionGroup(0));
    }
}

