package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.stream.Collectors;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.InternalExecutionGraphAccessor;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/runtime/checkpoint/DefaultCheckpointPlanCalculator.class */
/**
 * Default {@link CheckpointPlanCalculator}.
 *
 * <p>If no task has finished, the plan triggers the tasks of input vertices. It then waits for
 * and commits to every task. If some tasks have finished (and that is allowed), the plan has four
 * parts:
 *
 * <ul>
 *   <li>It triggers running tasks that have no running upstream task.
 *   <li>It waits for and commits to every running task.
 *   <li>It lists the finished tasks.
 *   <li>It lists the job vertices whose tasks have all finished.
 * </ul>
 */
public class DefaultCheckpointPlanCalculator implements CheckpointPlanCalculator {
    private final JobID jobId;
    private final CheckpointPlanCalculatorContext context;

    /** Job vertices in the order supplied to the constructor (assumed topological, per the name). */
    private final List<ExecutionJobVertex> jobVerticesInTopologyOrder = new ArrayList<>();

    /** Job vertices by id; used to look up the real parallelism of upstream vertices. */
    private final Map<JobVertexID, ExecutionJobVertex> jobVerticesById = new HashMap<>();

    private final List<ExecutionVertex> allTasks = new ArrayList<>();
    private final List<ExecutionVertex> sourceTasks = new ArrayList<>();

    /** Whether a plan may be computed while some tasks are finished; {@code false} by default. */
    private boolean allowCheckpointsAfterTasksFinished;

    /**
     * Creates a calculator over the given job vertices.
     *
     * @param jobID id of the job, used in error messages
     * @param checkpointPlanCalculatorContext provides the main executor and the finished-task flag
     * @param jobVertices all job vertices of the job; tasks of input vertices become the source
     *     tasks that are triggered when everything is running
     */
    public DefaultCheckpointPlanCalculator(
            JobID jobID,
            CheckpointPlanCalculatorContext checkpointPlanCalculatorContext,
            Iterable<ExecutionJobVertex> jobVertices) {
        this.jobId = Preconditions.checkNotNull(jobID);
        this.context = Preconditions.checkNotNull(checkpointPlanCalculatorContext);
        Preconditions.checkNotNull(jobVertices);

        for (ExecutionJobVertex jobVertex : jobVertices) {
            jobVerticesInTopologyOrder.add(jobVertex);
            jobVerticesById.put(jobVertex.getJobVertexId(), jobVertex);

            List<ExecutionVertex> tasks = Arrays.asList(jobVertex.getTaskVertices());
            allTasks.addAll(tasks);
            if (jobVertex.getJobVertex().isInputVertex()) {
                sourceTasks.addAll(tasks);
            }
        }
    }

    /** Sets whether checkpoints may still be planned after some tasks have finished. */
    public void setAllowCheckpointsAfterTasksFinished(boolean allowCheckpointsAfterTasksFinished) {
        this.allowCheckpointsAfterTasksFinished = allowCheckpointsAfterTasksFinished;
    }

    /**
     * Computes the checkpoint plan asynchronously on the context's main executor.
     *
     * <p>The returned future fails with a {@link CompletionException}. Its cause is a
     * {@link CheckpointException} in these cases:
     *
     * <ul>
     *   <li>Some tasks have finished while this is not allowed.
     *   <li>A task has no current execution attempt.
     *   <li>A task to trigger has not started running yet.
     * </ul>
     */
    @Override
    public CompletableFuture<CheckpointPlan> calculateCheckpointPlan() {
        return CompletableFuture.supplyAsync(
                () -> {
                    try {
                        if (context.hasFinishedTasks() && !allowCheckpointsAfterTasksFinished) {
                            throw new CheckpointException(
                                    String.format(
                                            "some tasks of job %s has been finished, abort the checkpoint",
                                            jobId),
                                    CheckpointFailureReason.NOT_ALL_REQUIRED_TASKS_RUNNING);
                        }

                        checkAllTasksInitiated();

                        CheckpointPlan plan =
                                context.hasFinishedTasks()
                                        ? calculateAfterTasksFinished()
                                        : calculateWithAllTasksRunning();

                        checkTasksStarted(plan.getTasksToTrigger());

                        return plan;
                    } catch (Throwable throwable) {
                        throw new CompletionException(throwable);
                    }
                },
                context.getMainExecutor());
    }

    /** Fails if any task has no current execution attempt. */
    private void checkAllTasksInitiated() throws CheckpointException {
        for (ExecutionVertex task : allTasks) {
            if (task.getCurrentExecutionAttempt() == null) {
                throw new CheckpointException(
                        String.format(
                                "task %s of job %s is not being executed at the moment. Aborting checkpoint.",
                                task.getTaskNameWithSubtaskIndex(), jobId),
                        CheckpointFailureReason.NOT_ALL_REQUIRED_TASKS_RUNNING);
            }
        }
    }

    /** Fails if any task to trigger is still CREATED, SCHEDULED or DEPLOYING. */
    private void checkTasksStarted(List<Execution> toTrigger) throws CheckpointException {
        for (Execution execution : toTrigger) {
            ExecutionState state = execution.getState();
            if (state == ExecutionState.CREATED
                    || state == ExecutionState.SCHEDULED
                    || state == ExecutionState.DEPLOYING) {
                throw new CheckpointException(
                        String.format(
                                "Checkpoint triggering task %s of job %s has not being executed at the moment. Aborting checkpoint.",
                                execution.getVertex().getTaskNameWithSubtaskIndex(), jobId),
                        CheckpointFailureReason.NOT_ALL_REQUIRED_TASKS_RUNNING);
            }
        }
    }

    /** Plan when no task has finished: trigger source tasks, wait for and commit to all tasks. */
    private CheckpointPlan calculateWithAllTasksRunning() {
        List<Execution> tasksToTrigger =
                sourceTasks.stream()
                        .map(ExecutionVertex::getCurrentExecutionAttempt)
                        .collect(Collectors.toList());
        List<Execution> tasksToWaitFor = createTaskToWaitFor(allTasks);

        return new CheckpointPlan(
                Collections.unmodifiableList(tasksToTrigger),
                Collections.unmodifiableList(tasksToWaitFor),
                Collections.unmodifiableList(allTasks),
                Collections.emptyList(),
                Collections.emptyList());
    }

    /**
     * Plan when some tasks have finished.
     *
     * <p>Each running task is waited for and committed to. It is also triggered when two conditions
     * hold:
     *
     * <ul>
     *   <li>{@link #someTasksMustBeTriggered} holds for its vertex.
     *   <li>It has no running POINTWISE predecessor.
     * </ul>
     *
     * <p>Finished tasks are collected separately. Vertices with no running task are recorded as
     * fully finished.
     */
    private CheckpointPlan calculateAfterTasksFinished() {
        Map<JobVertexID, BitSet> taskRunningStatusByVertex = collectTaskRunningStatus();

        List<Execution> tasksToTrigger = new ArrayList<>();
        List<Execution> tasksToWaitFor = new ArrayList<>();
        List<ExecutionVertex> tasksToCommitTo = new ArrayList<>();
        List<Execution> finishedTasks = new ArrayList<>();
        List<ExecutionJobVertex> fullyFinishedJobVertex = new ArrayList<>();

        for (ExecutionJobVertex jobVertex : jobVerticesInTopologyOrder) {
            BitSet taskRunningStatus = taskRunningStatusByVertex.get(jobVertex.getJobVertexId());

            if (taskRunningStatus.cardinality() == 0) {
                fullyFinishedJobVertex.add(jobVertex);
                for (ExecutionVertex task : jobVertex.getTaskVertices()) {
                    finishedTasks.add(task.getCurrentExecutionAttempt());
                }
                continue;
            }

            List<JobEdge> prevJobEdges = jobVertex.getJobVertex().getInputs();
            // Computed once per vertex: if some upstream vertex is known to feed every task,
            // no task of this vertex needs to be triggered.
            boolean someTasksMustBeTriggered =
                    someTasksMustBeTriggered(taskRunningStatusByVertex, prevJobEdges);

            for (ExecutionVertex task : jobVertex.getTaskVertices()) {
                if (taskRunningStatus.get(task.getParallelSubtaskIndex())) {
                    tasksToWaitFor.add(task.getCurrentExecutionAttempt());
                    tasksToCommitTo.add(task);

                    if (someTasksMustBeTriggered
                            && !hasRunningPrecedentTasks(
                                    task, prevJobEdges, taskRunningStatusByVertex)) {
                        tasksToTrigger.add(task.getCurrentExecutionAttempt());
                    }
                } else {
                    finishedTasks.add(task.getCurrentExecutionAttempt());
                }
            }
        }

        return new CheckpointPlan(
                Collections.unmodifiableList(tasksToTrigger),
                Collections.unmodifiableList(tasksToWaitFor),
                Collections.unmodifiableList(tasksToCommitTo),
                Collections.unmodifiableList(finishedTasks),
                Collections.unmodifiableList(fullyFinishedJobVertex));
    }

    /**
     * Returns {@code false} when some input edge guarantees that every task of the consuming
     * vertex has a running predecessor. In that case no task of the vertex needs to be triggered.
     */
    private boolean someTasksMustBeTriggered(
            Map<JobVertexID, BitSet> runningTasksByVertex, List<JobEdge> prevJobEdges) {
        for (JobEdge jobEdge : prevJobEdges) {
            JobVertexID upstreamId = jobEdge.getSource().getProducer().getID();
            if (hasActiveUpstreamVertex(
                    jobEdge.getDistributionPattern(),
                    runningTasksByVertex.get(upstreamId),
                    getParallelism(upstreamId))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Whether every task of the consuming vertex is connected to some running upstream task. This
     * is the case for:
     *
     * <ul>
     *   <li>an ALL_TO_ALL edge with at least one running upstream task;
     *   <li>a POINTWISE edge with all upstream tasks running.
     * </ul>
     *
     * @param distribution distribution pattern of the edge
     * @param upstreamRunningTasks running-status bits of the upstream tasks, indexed by subtask
     * @param upstreamParallelism number of upstream tasks
     */
    private boolean hasActiveUpstreamVertex(
            DistributionPattern distribution,
            BitSet upstreamRunningTasks,
            int upstreamParallelism) {
        if (distribution == DistributionPattern.ALL_TO_ALL) {
            return upstreamRunningTasks.cardinality() > 0;
        }
        if (distribution == DistributionPattern.POINTWISE) {
            // BitSet#size() is the allocated capacity (rounded up to a multiple of 64), not the
            // number of tasks. Compare against the real upstream parallelism instead.
            return upstreamRunningTasks.cardinality() == upstreamParallelism;
        }
        return false;
    }

    /** Number of tasks of the given job vertex. */
    private int getParallelism(JobVertexID jobVertexId) {
        return jobVerticesById.get(jobVertexId).getTaskVertices().length;
    }

    /**
     * Whether the task consumes a partition produced by a running task over a POINTWISE edge.
     *
     * <p>ALL_TO_ALL edges are skipped. This method only matters when
     * {@link #someTasksMustBeTriggered} returned {@code true}. That in turn means no ALL_TO_ALL
     * upstream vertex has a running task.
     */
    private boolean hasRunningPrecedentTasks(
            ExecutionVertex task,
            List<JobEdge> prevJobEdges,
            Map<JobVertexID, BitSet> taskRunningStatusByVertex) {
        InternalExecutionGraphAccessor executionGraphAccessor = task.getExecutionGraphAccessor();

        for (int input = 0; input < prevJobEdges.size(); input++) {
            if (prevJobEdges.get(input).getDistributionPattern() != DistributionPattern.POINTWISE) {
                continue;
            }
            for (IntermediateResultPartitionID partitionId : task.getConsumedPartitions(input)) {
                ExecutionVertex producer =
                        executionGraphAccessor.getResultPartitionOrThrow(partitionId).getProducer();
                if (taskRunningStatusByVertex
                        .get(producer.getJobvertexId())
                        .get(producer.getParallelSubtaskIndex())) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Builds, for every job vertex, a bit set with bit {@code i} set when task {@code i}'s current
     * execution attempt is not finished.
     */
    @VisibleForTesting
    Map<JobVertexID, BitSet> collectTaskRunningStatus() {
        Map<JobVertexID, BitSet> runningStatusByVertex = new HashMap<>();
        for (ExecutionJobVertex jobVertex : jobVerticesInTopologyOrder) {
            ExecutionVertex[] tasks = jobVertex.getTaskVertices();
            BitSet runningTasks = new BitSet(tasks.length);
            for (int i = 0; i < tasks.length; i++) {
                if (!tasks[i].getCurrentExecutionAttempt().isFinished()) {
                    runningTasks.set(i);
                }
            }
            runningStatusByVertex.put(jobVertex.getJobVertexId(), runningTasks);
        }
        return runningStatusByVertex;
    }

    /** Maps each task to its current execution attempt, preserving order. */
    private List<Execution> createTaskToWaitFor(List<ExecutionVertex> tasks) {
        List<Execution> executions = new ArrayList<>(tasks.size());
        for (ExecutionVertex task : tasks) {
            executions.add(task.getCurrentExecutionAttempt());
        }
        return executions;
    }
}
