/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.state.v2;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.concurrent.CompletableFuture;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.operators.MailboxExecutor;
import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController;
import org.apache.flink.runtime.asyncprocessing.MockStateRequestContainer;
import org.apache.flink.runtime.asyncprocessing.StateExecutor;
import org.apache.flink.runtime.asyncprocessing.StateRequest;
import org.apache.flink.runtime.asyncprocessing.StateRequestContainer;
import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
import org.apache.flink.runtime.asyncprocessing.StateRequestType;
import org.apache.flink.runtime.mailbox.SyncMailboxExecutor;
import org.apache.flink.runtime.state.v2.AbstractAggregatingState;
import org.apache.flink.runtime.state.v2.AbstractKeyedStateTestBase;
import org.apache.flink.runtime.state.v2.AggregatingStateDescriptor;
import org.apache.flink.runtime.state.v2.internal.InternalPartitionedState;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link AbstractAggregatingState}: the state requests it issues for clear/get/add, and
 * how {@code asyncMergeNamespaces} folds several namespaces into a target namespace.
 */
class AbstractAggregatingStateTest extends AbstractKeyedStateTestBase {

    /**
     * Verifies the request sequence produced by clear, get and add. Each add first issues an
     * AGGREGATING_GET and then an AGGREGATING_ADD carrying the new accumulator. The expected
     * payloads (2, then 6) show that each add here starts from the initial accumulator 1.
     */
    @Test
    public void testAggregating() {
        SumAggregator aggregator = new SumAggregator(1);
        AggregatingStateDescriptor descriptor =
                new AggregatingStateDescriptor(
                        "testAggState",
                        (AggregateFunction) aggregator,
                        (TypeInformation) BasicTypeInfo.INT_TYPE_INFO);
        AbstractAggregatingState state =
                new AbstractAggregatingState((StateRequestHandler) this.aec, descriptor);
        this.aec.setCurrentContext(this.aec.buildContext("test", "test"));

        state.asyncClear();
        this.validateRequestRun((State) state, StateRequestType.CLEAR, null, 0);

        state.asyncGet();
        this.validateRequestRun((State) state, StateRequestType.AGGREGATING_GET, null, 0);

        state.asyncAdd(1);
        this.validateRequestRun((State) state, StateRequestType.AGGREGATING_GET, null, 1);
        this.validateRequestRun((State) state, StateRequestType.AGGREGATING_ADD, 2, 0);

        state.asyncAdd(5);
        this.validateRequestRun((State) state, StateRequestType.AGGREGATING_GET, null, 1);
        this.validateRequestRun((State) state, StateRequestType.AGGREGATING_ADD, 6, 0);
    }

    /**
     * Adds values under namespaces "1".."4" for key "test" and merges them into namespace "0".
     * After each merge, the target holds the sum and the source entries are removed from the
     * backing executor.
     */
    @Test
    public void testMergeNamespace() throws Exception {
        SumAggregator aggregator = new SumAggregator(0);
        AggregatingStateDescriptor descriptor =
                new AggregatingStateDescriptor(
                        "testState",
                        (AggregateFunction) aggregator,
                        (TypeInformation) BasicTypeInfo.INT_TYPE_INFO);
        // A fresh executor per run, so the size assertions below never see leftover entries.
        AggregatingStateExecutor executor = new AggregatingStateExecutor();
        AsyncExecutionController aec =
                new AsyncExecutionController(
                        (MailboxExecutor) new SyncMailboxExecutor(),
                        (a, b) -> {},
                        (StateExecutor) executor,
                        1,
                        100,
                        10000L,
                        1,
                        null);
        AbstractAggregatingState aggregatingState =
                new AbstractAggregatingState((StateRequestHandler) aec, descriptor);
        aec.setCurrentContext(aec.buildContext("test", "test"));

        aec.setCurrentNamespaceForState((InternalPartitionedState) aggregatingState, "1");
        aggregatingState.asyncAdd(1);
        aec.drainInflightRecords(0);
        AssertionsForClassTypes.assertThat(executor.hashMap.size()).isEqualTo(1);
        AssertionsForClassTypes.assertThat(storedValue(executor, "1")).isEqualTo(1);

        aec.setCurrentNamespaceForState((InternalPartitionedState) aggregatingState, "2");
        aggregatingState.asyncAdd(2);
        aec.drainInflightRecords(0);
        AssertionsForClassTypes.assertThat(executor.hashMap.size()).isEqualTo(2);
        AssertionsForClassTypes.assertThat(storedValue(executor, "1")).isEqualTo(1);
        AssertionsForClassTypes.assertThat(storedValue(executor, "2")).isEqualTo(2);

        aec.setCurrentNamespaceForState((InternalPartitionedState) aggregatingState, "3");
        aggregatingState.asyncAdd(3);
        aec.drainInflightRecords(0);
        AssertionsForClassTypes.assertThat(executor.hashMap.size()).isEqualTo(3);
        AssertionsForClassTypes.assertThat(storedValue(executor, "1")).isEqualTo(1);
        AssertionsForClassTypes.assertThat(storedValue(executor, "2")).isEqualTo(2);
        AssertionsForClassTypes.assertThat(storedValue(executor, "3")).isEqualTo(3);

        // Merge "1", "2", "3" into "0": 1 + 2 + 3 = 6, and the sources are removed.
        ArrayList<String> sources = new ArrayList<>(Arrays.asList("1", "2", "3"));
        aggregatingState.asyncMergeNamespaces("0", sources);
        aec.drainInflightRecords(0);
        AssertionsForClassTypes.assertThat(executor.hashMap.size()).isEqualTo(1);
        AssertionsForClassTypes.assertThat(storedValue(executor, "0")).isEqualTo(6);
        AssertionsForClassTypes.assertThat(storedValue(executor, "1")).isNull();
        AssertionsForClassTypes.assertThat(storedValue(executor, "2")).isNull();
        AssertionsForClassTypes.assertThat(storedValue(executor, "3")).isNull();

        aec.setCurrentNamespaceForState((InternalPartitionedState) aggregatingState, "4");
        aggregatingState.asyncAdd(4);
        aec.drainInflightRecords(0);
        AssertionsForClassTypes.assertThat(executor.hashMap.size()).isEqualTo(2);
        AssertionsForClassTypes.assertThat(storedValue(executor, "0")).isEqualTo(6);
        AssertionsForClassTypes.assertThat(storedValue(executor, "4")).isEqualTo(4);

        // Merge "4" into the already-populated "0": 6 + 4 = 10.
        ArrayList<String> sources1 = new ArrayList<>(Arrays.asList("4"));
        aggregatingState.asyncMergeNamespaces("0", sources1);
        aec.drainInflightRecords(0);
        AssertionsForClassTypes.assertThat(executor.hashMap.size()).isEqualTo(1);
        AssertionsForClassTypes.assertThat(storedValue(executor, "0")).isEqualTo(10);
        AssertionsForClassTypes.assertThat(storedValue(executor, "1")).isNull();
        AssertionsForClassTypes.assertThat(storedValue(executor, "2")).isNull();
        AssertionsForClassTypes.assertThat(storedValue(executor, "3")).isNull();
        AssertionsForClassTypes.assertThat(storedValue(executor, "4")).isNull();
    }

    /** Returns the value {@code executor} stores for key "test" under {@code namespace}, or null. */
    private static Integer storedValue(AggregatingStateExecutor executor, String namespace) {
        return executor.hashMap.get(Tuple2.of("test", namespace));
    }

    /**
     * In-memory {@link StateExecutor} that serves AGGREGATING_ADD / GET / REMOVE requests from a
     * map keyed by (record key, namespace). Any other request type is rejected.
     */
    static class AggregatingStateExecutor implements StateExecutor {

        /**
         * Stored values per (record key, namespace). This is instance-scoped rather than static,
         * so each test run (and each executor) starts from an empty map.
         */
        private final HashMap<Tuple2<String, String>, Integer> hashMap = new HashMap<>();

        /**
         * Executes every request in the (mock) container synchronously against {@link #hashMap}.
         *
         * @param stateRequestContainer expected to be a {@link MockStateRequestContainer}, as
         *     produced by {@link #createStateRequestContainer()}
         * @return an already-completed future
         * @throws UnsupportedOperationException for any non-aggregating request type
         */
        @Override
        @SuppressWarnings({"rawtypes", "unchecked"})
        public CompletableFuture<Void> executeBatchRequests(
                StateRequestContainer stateRequestContainer) {
            // Raw StateRequest: with the wildcard-typed form, getFuture().complete(...) cannot
            // accept a non-null value, so a GET could not hand back the stored Integer.
            for (StateRequest request :
                    ((MockStateRequestContainer) stateRequestContainer).getStateRequestList()) {
                String key = (String) request.getRecordContext().getKey();
                String namespace = (String) request.getNamespace();
                Tuple2<String, String> entryKey = Tuple2.of(key, namespace);
                StateRequestType type = request.getRequestType();
                switch (type) {
                    case AGGREGATING_ADD:
                        hashMap.put(entryKey, (Integer) request.getPayload());
                        request.getFuture().complete(null);
                        break;
                    case AGGREGATING_GET:
                        request.getFuture().complete(hashMap.get(entryKey));
                        break;
                    case AGGREGATING_REMOVE:
                        hashMap.remove(entryKey);
                        request.getFuture().complete(null);
                        break;
                    default:
                        throw new UnsupportedOperationException("Unsupported type: " + type);
                }
            }
            return CompletableFuture.completedFuture(null);
        }

        /** Returns a new, empty {@link MockStateRequestContainer}. */
        @Override
        public StateRequestContainer createStateRequestContainer() {
            return new MockStateRequestContainer();
        }

        /** Always returns {@code false}. */
        @Override
        public boolean fullyLoaded() {
            return false;
        }

        /** No resources to release. */
        @Override
        public void shutdown() {}
    }

    /** Integer-sum aggregate function whose accumulator starts at a configurable value. */
    static class SumAggregator implements AggregateFunction<Integer, Integer, Integer> {
        private final int init;

        /** @param init value returned by {@link #createAccumulator()} */
        public SumAggregator(int init) {
            this.init = init;
        }

        @Override
        public Integer createAccumulator() {
            return this.init;
        }

        @Override
        public Integer add(Integer value, Integer accumulator) {
            return accumulator + value;
        }

        @Override
        public Integer getResult(Integer accumulator) {
            return accumulator;
        }

        @Override
        public Integer merge(Integer a, Integer b) {
            return a + b;
        }
    }
}

