/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.state.metrics;

import java.util.Arrays;
import java.util.Collection;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.runtime.state.internal.InternalAggregatingState;
import org.apache.flink.runtime.state.internal.InternalKvState;
import org.apache.flink.runtime.state.internal.InternalListState;
import org.apache.flink.runtime.state.internal.InternalMapState;
import org.apache.flink.runtime.state.internal.InternalReducingState;
import org.apache.flink.runtime.state.internal.InternalValueState;
import org.apache.flink.runtime.state.metrics.LatencyTrackingAggregatingState;
import org.apache.flink.runtime.state.metrics.LatencyTrackingListState;
import org.apache.flink.runtime.state.metrics.LatencyTrackingMapState;
import org.apache.flink.runtime.state.metrics.LatencyTrackingReducingState;
import org.apache.flink.runtime.state.metrics.LatencyTrackingStateConfig;
import org.apache.flink.runtime.state.metrics.LatencyTrackingStateFactory;
import org.apache.flink.runtime.state.metrics.LatencyTrackingValueState;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mockito;

/**
 * Tests for {@link LatencyTrackingStateFactory}.
 *
 * <p>Every test is executed twice, once with latency tracking enabled and once with it disabled
 * (see {@link #enabled()}). With tracking enabled the factory result is expected to be the
 * matching {@code LatencyTracking*State} type; with tracking disabled the factory is expected to
 * hand back the very instance that was passed in.
 */
@ExtendWith(ParameterizedTestExtension.class)
public class LatencyTrackingStateFactoryTest {
    /** Injected by the parameterized extension; selects which branch each test asserts. */
    @Parameter
    public boolean enableLatencyTracking;

    /**
     * Supplies the values for {@link #enableLatencyTracking}.
     *
     * @return both flag values, so each test runs for the enabled and the disabled case
     */
    @Parameters(name = "enable latency tracking: {0}")
    public static Collection<Boolean> enabled() {
        return Arrays.asList(true, false);
    }

    /**
     * Builds the config handed to the factory in every test.
     *
     * @return a config whose enabled flag mirrors {@link #enableLatencyTracking} and whose
     *     metric group is a fresh {@link UnregisteredMetricsGroup}
     */
    private LatencyTrackingStateConfig getLatencyTrackingStateConfig() {
        UnregisteredMetricsGroup metricsGroup = new UnregisteredMetricsGroup();
        return LatencyTrackingStateConfig.newBuilder()
                .setEnabled(enableLatencyTracking)
                .setMetricGroup(metricsGroup)
                .build();
    }

    /** Checks the factory result for a value state with a {@link ValueStateDescriptor}. */
    @TestTemplate
    <K, N> void testTrackValueState() throws Exception {
        // Mockito.mock(Class) can only yield the raw type; the mock is never invoked with values.
        @SuppressWarnings("unchecked")
        InternalValueState<K, N, String> valueState = Mockito.mock(InternalValueState.class);
        ValueStateDescriptor<String> valueStateDescriptor =
                new ValueStateDescriptor<>("value", String.class);

        InternalKvState<K, N, ?> latencyTrackingState =
                LatencyTrackingStateFactory.createStateAndWrapWithLatencyTrackingIfEnabled(
                        valueState, valueStateDescriptor, getLatencyTrackingStateConfig());

        if (enableLatencyTracking) {
            Assertions.assertThat(latencyTrackingState)
                    .isInstanceOf(LatencyTrackingValueState.class);
        } else {
            // Disabled tracking must return the original object, not merely an equal one.
            Assertions.assertThat(latencyTrackingState).isSameAs(valueState);
        }
    }

    /** Checks the factory result for a list state with a {@link ListStateDescriptor}. */
    @TestTemplate
    <K, N> void testTrackListState() throws Exception {
        // Mockito.mock(Class) can only yield the raw type; the mock is never invoked with values.
        @SuppressWarnings("unchecked")
        InternalListState<K, N, String> listState = Mockito.mock(InternalListState.class);
        ListStateDescriptor<String> listStateDescriptor =
                new ListStateDescriptor<>("list", String.class);

        InternalKvState<K, N, ?> latencyTrackingState =
                LatencyTrackingStateFactory.createStateAndWrapWithLatencyTrackingIfEnabled(
                        listState, listStateDescriptor, getLatencyTrackingStateConfig());

        if (enableLatencyTracking) {
            Assertions.assertThat(latencyTrackingState)
                    .isInstanceOf(LatencyTrackingListState.class);
        } else {
            // Disabled tracking must return the original object, not merely an equal one.
            Assertions.assertThat(latencyTrackingState).isSameAs(listState);
        }
    }

    /** Checks the factory result for a map state with a {@link MapStateDescriptor}. */
    @TestTemplate
    <K, N> void testTrackMapState() throws Exception {
        // Mockito.mock(Class) can only yield the raw type; the mock is never invoked with values.
        @SuppressWarnings("unchecked")
        InternalMapState<K, N, String, Long> mapState = Mockito.mock(InternalMapState.class);
        MapStateDescriptor<String, Long> mapStateDescriptor =
                new MapStateDescriptor<>("map", String.class, Long.class);

        InternalKvState<K, N, ?> latencyTrackingState =
                LatencyTrackingStateFactory.createStateAndWrapWithLatencyTrackingIfEnabled(
                        mapState, mapStateDescriptor, getLatencyTrackingStateConfig());

        if (enableLatencyTracking) {
            Assertions.assertThat(latencyTrackingState)
                    .isInstanceOf(LatencyTrackingMapState.class);
        } else {
            // Disabled tracking must return the original object, not merely an equal one.
            Assertions.assertThat(latencyTrackingState).isSameAs(mapState);
        }
    }

    /** Checks the factory result for a reducing state with a {@link ReducingStateDescriptor}. */
    @TestTemplate
    <K, N> void testTrackReducingState() throws Exception {
        // Mockito.mock(Class) can only yield the raw type; the mock is never invoked with values.
        @SuppressWarnings("unchecked")
        InternalReducingState<K, N, Long> reducingState =
                Mockito.mock(InternalReducingState.class);
        ReducingStateDescriptor<Long> reducingStateDescriptor =
                new ReducingStateDescriptor<>("reducing", Long::sum, Long.class);

        InternalKvState<K, N, ?> latencyTrackingState =
                LatencyTrackingStateFactory.createStateAndWrapWithLatencyTrackingIfEnabled(
                        reducingState, reducingStateDescriptor, getLatencyTrackingStateConfig());

        if (enableLatencyTracking) {
            Assertions.assertThat(latencyTrackingState)
                    .isInstanceOf(LatencyTrackingReducingState.class);
        } else {
            // Disabled tracking must return the original object, not merely an equal one.
            Assertions.assertThat(latencyTrackingState).isSameAs(reducingState);
        }
    }

    /**
     * Checks the factory result for an aggregating state with an
     * {@link AggregatingStateDescriptor}.
     */
    @TestTemplate
    <K, N> void testTrackAggregatingState() throws Exception {
        // Mockito.mock(Class) can only yield the raw type; the mock is never invoked with values.
        @SuppressWarnings("unchecked")
        InternalAggregatingState<K, N, Long, Long, Long> aggregatingState =
                Mockito.mock(InternalAggregatingState.class);
        AggregatingStateDescriptor<Long, Long, Long> aggregatingStateDescriptor =
                new AggregatingStateDescriptor<>(
                        "aggregate",
                        new AggregateFunction<Long, Long, Long>() {
                            private static final long serialVersionUID = 1L;

                            @Override
                            public Long createAccumulator() {
                                return 0L;
                            }

                            @Override
                            public Long add(Long value, Long accumulator) {
                                return value + accumulator;
                            }

                            @Override
                            public Long getResult(Long accumulator) {
                                return accumulator;
                            }

                            @Override
                            public Long merge(Long a, Long b) {
                                return a + b;
                            }
                        },
                        Long.class);

        InternalKvState<K, N, ?> latencyTrackingState =
                LatencyTrackingStateFactory.createStateAndWrapWithLatencyTrackingIfEnabled(
                        aggregatingState,
                        aggregatingStateDescriptor,
                        getLatencyTrackingStateConfig());

        if (enableLatencyTracking) {
            Assertions.assertThat(latencyTrackingState)
                    .isInstanceOf(LatencyTrackingAggregatingState.class);
        } else {
            // Disabled tracking must return the original object, not merely an equal one.
            Assertions.assertThat(latencyTrackingState).isSameAs(aggregatingState);
        }
    }
}

