/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.io.Serializable;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.flink.runtime.checkpoint.KeyGroupState;
import org.apache.flink.runtime.checkpoint.SubtaskState;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.StateHandle;
import org.apache.flink.util.SerializedValue;

/**
 * Container for the checkpointed state of a single job vertex.
 *
 * <p>It stores one {@link SubtaskState} per subtask index (bounded by the parallelism given at
 * construction time) and one {@link KeyGroupState} per key group id. Both collections are held in
 * plain {@link HashMap}s; this class performs no synchronization of its own.
 */
public class TaskState implements Serializable {
    private static final long serialVersionUID = -4845578005863201810L;

    /** Id of the job vertex this state belongs to. */
    private final JobVertexID jobVertexID;

    /** Subtask states, keyed by subtask index in {@code [0, parallelism)}. */
    private final Map<Integer, SubtaskState> subtaskStates;

    /** Key group states, keyed by key group id. */
    private final Map<Integer, KeyGroupState> kvStates;

    /** Exclusive upper bound for valid subtask indices. */
    private final int parallelism;

    /**
     * Creates an empty task state.
     *
     * @param jobVertexID id of the job vertex the state belongs to
     * @param parallelism number of subtasks; valid subtask indices are {@code [0, parallelism)}
     */
    public TaskState(JobVertexID jobVertexID, int parallelism) {
        this.jobVertexID = jobVertexID;
        this.subtaskStates = new HashMap<>(parallelism);
        this.kvStates = new HashMap<>();
        this.parallelism = parallelism;
    }

    /** Returns the id of the job vertex this state belongs to. */
    public JobVertexID getJobVertexID() {
        return this.jobVertexID;
    }

    /**
     * Stores the state of one subtask, replacing any state previously stored for that index.
     *
     * @param subtaskIndex index of the subtask, in {@code [0, parallelism)}
     * @param subtaskState state to store
     * @throws IndexOutOfBoundsException if {@code subtaskIndex} is outside {@code [0, parallelism)}
     */
    public void putState(int subtaskIndex, SubtaskState subtaskState) {
        checkSubtaskIndex(subtaskIndex);
        this.subtaskStates.put(subtaskIndex, subtaskState);
    }

    /**
     * Returns the state stored for one subtask.
     *
     * @param subtaskIndex index of the subtask, in {@code [0, parallelism)}
     * @return the stored state, or {@code null} if none was stored for that index
     * @throws IndexOutOfBoundsException if {@code subtaskIndex} is outside {@code [0, parallelism)}
     */
    public SubtaskState getState(int subtaskIndex) {
        checkSubtaskIndex(subtaskIndex);
        return this.subtaskStates.get(subtaskIndex);
    }

    /**
     * Returns the subtask states stored so far.
     *
     * <p>The returned collection is the live {@code values()} view of the internal map, not a copy.
     */
    public Collection<SubtaskState> getStates() {
        return this.subtaskStates.values();
    }

    /** Returns the sum of the sizes reported by all stored subtask states and key group states. */
    public long getStateSize() {
        long result = 0L;
        for (SubtaskState subtaskState : this.subtaskStates.values()) {
            result += subtaskState.getStateSize();
        }
        for (KeyGroupState keyGroupState : this.kvStates.values()) {
            result += keyGroupState.getStateSize();
        }
        return result;
    }

    /** Returns the number of subtask states stored so far. */
    public int getNumberCollectedStates() {
        return this.subtaskStates.size();
    }

    /** Returns the parallelism given at construction time. */
    public int getParallelism() {
        return this.parallelism;
    }

    /**
     * Stores the state of one key group, replacing any state previously stored for that id.
     *
     * @param keyGroupId id of the key group (not range-checked)
     * @param keyGroupState state to store
     */
    public void putKvState(int keyGroupId, KeyGroupState keyGroupState) {
        this.kvStates.put(keyGroupId, keyGroupState);
    }

    /**
     * Returns the state stored for one key group.
     *
     * @param keyGroupId id of the key group
     * @return the stored state, or {@code null} if none was stored for that id
     */
    public KeyGroupState getKvState(int keyGroupId) {
        return this.kvStates.get(keyGroupId);
    }

    /**
     * Collects the serialized state handles of the requested key groups.
     *
     * @param keyGroupPartition ids of the key groups to look up
     * @return a new map from key group id to the value of {@link KeyGroupState#getKeyGroupState()};
     *     ids for which no key group state is stored are omitted
     */
    public Map<Integer, SerializedValue<StateHandle<?>>> getUnwrappedKvStates(Set<Integer> keyGroupPartition) {
        Map<Integer, SerializedValue<StateHandle<?>>> result = new HashMap<>(keyGroupPartition.size());
        for (Integer keyGroupId : keyGroupPartition) {
            KeyGroupState keyGroupState = this.kvStates.get(keyGroupId);
            if (keyGroupState != null) {
                result.put(keyGroupId, keyGroupState.getKeyGroupState());
            }
        }
        return result;
    }

    /** Returns the number of key group states stored so far. */
    public int getNumberCollectedKvStates() {
        return this.kvStates.size();
    }

    /**
     * Discards every stored subtask state and key group state.
     *
     * <p>States are discarded one after another; the first exception thrown aborts the loop and
     * propagates to the caller, leaving the remaining states undiscarded.
     *
     * @param classLoader class loader handed to each state's {@code discard} method
     * @throws Exception if discarding any state fails
     */
    public void discard(ClassLoader classLoader) throws Exception {
        for (SubtaskState subtaskState : this.subtaskStates.values()) {
            subtaskState.discard(classLoader);
        }
        for (KeyGroupState keyGroupState : this.kvStates.values()) {
            keyGroupState.discard(classLoader);
        }
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof TaskState)) {
            return false;
        }
        TaskState other = (TaskState) obj;
        // Objects.equals keeps equals null-tolerant for jobVertexID, matching hashCode.
        return Objects.equals(this.jobVertexID, other.jobVertexID)
                && this.parallelism == other.parallelism
                && this.subtaskStates.equals(other.subtaskStates)
                && this.kvStates.equals(other.kvStates);
    }

    @Override
    public int hashCode() {
        return this.parallelism + 31 * Objects.hash(this.jobVertexID, this.subtaskStates, this.kvStates);
    }

    /**
     * Verifies that a subtask index lies in {@code [0, parallelism)}.
     *
     * @throws IndexOutOfBoundsException if it does not
     */
    private void checkSubtaskIndex(int subtaskIndex) {
        if (subtaskIndex < 0 || subtaskIndex >= this.parallelism) {
            // Report the bound that was actually checked (parallelism), not the number of
            // states collected so far.
            throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex + " exceeds the maximum number of sub tasks " + this.parallelism);
        }
    }
}

