package org.apache.flink.runtime.checkpoint.channel;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
import org.apache.flink.runtime.checkpoint.RescaleMappings;
import org.apache.flink.runtime.checkpoint.channel.RecoveredChannelStateHandler;
import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.io.network.partition.consumer.RecoveredInputChannel;

/* compiled from: RecoveredChannelStateHandler.java */
/* loaded from: input_file:org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandler.class */
/**
 * {@link RecoveredChannelStateHandler} that feeds recovered input-channel state into the
 * {@link RecoveredInputChannel}s of the given input gates.
 *
 * <p>Recovered state is keyed by the {@link InputChannelInfo} it was recorded under. Each such
 * channel is translated into one or more channels of the same gate via the inverted per-gate
 * mapping obtained from the {@link InflightDataRescalingDescriptor}. Both the inverted mappings
 * and the resolved channel lists are cached, so each lookup is computed at most once.
 */
class InputChannelRecoveredStateHandler
        implements RecoveredChannelStateHandler<InputChannelInfo, Buffer> {
    private final InputGate[] inputGates;

    private final InflightDataRescalingDescriptor channelMapping;

    /** Cache: recorded channel -> recovered channels it maps to (never empty, see below). */
    private final Map<InputChannelInfo, List<RecoveredInputChannel>> rescaledChannels =
            new HashMap<>();

    /** Cache: gate index -> inverted channel mapping of that gate. */
    private final Map<Integer, RescaleMappings> oldToNewMappings = new HashMap<>();

    /**
     * @param inputGates gates whose channels receive the recovered state, indexed by gate index
     * @param channelMapping descriptor providing the per-gate channel mappings
     */
    InputChannelRecoveredStateHandler(
            InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping) {
        this.inputGates = inputGates;
        this.channelMapping = channelMapping;
    }

    /**
     * Requests a buffer into which state recorded for {@code channelInfo} can be read.
     *
     * <p>The buffer is requested (blocking) from the first recovered channel that
     * {@code channelInfo} maps to. It is returned wrapped in a {@link ChannelStateByteBuffer},
     * together with the buffer itself as context.
     *
     * @throws IllegalStateException if {@code channelInfo} has no mapping, or a mapped channel is
     *     not a {@link RecoveredInputChannel}
     */
    @Override
    public RecoveredChannelStateHandler.BufferWithContext<Buffer> getBuffer(
            InputChannelInfo channelInfo) throws IOException, InterruptedException {
        Buffer buffer = getMappedChannels(channelInfo).get(0).requestBufferBlocking();
        return new RecoveredChannelStateHandler.BufferWithContext<>(
                ChannelStateByteBuffer.wrap(buffer), buffer);
    }

    /**
     * Hands a recovered buffer to every channel that {@code channelInfo} maps to.
     *
     * <p>Buffers without readable bytes are not forwarded. For each mapped channel, a serialized
     * {@link SubtaskConnectionDescriptor} event is enqueued first, followed by its own retained
     * reference to {@code buffer}. The caller's reference to {@code buffer} is always recycled
     * here, even if forwarding fails.
     *
     * @param channelInfo the channel the state was recorded for
     * @param subtaskIndex first component of the emitted {@link SubtaskConnectionDescriptor};
     *     presumably the subtask that originally wrote the state (TODO confirm against the
     *     interface contract)
     * @param buffer the recovered data; the caller's reference is released by this method
     */
    @Override
    public void recover(InputChannelInfo channelInfo, int subtaskIndex, Buffer buffer)
            throws IOException {
        try {
            if (buffer.readableBytes() > 0) {
                for (RecoveredInputChannel channel : getMappedChannels(channelInfo)) {
                    channel.onRecoveredStateBuffer(
                            EventSerializer.toBuffer(
                                    new SubtaskConnectionDescriptor(
                                            subtaskIndex, channelInfo.getInputChannelIdx()),
                                    false));
                    // Each receiving channel gets its own reference; ours is dropped below.
                    channel.onRecoveredStateBuffer(buffer.retainBuffer());
                }
            }
        } finally {
            buffer.recycleBuffer();
        }
    }

    /**
     * Calls {@code finishReadRecoveredState()} on every input gate, in order.
     *
     * <p>If one gate throws, the remaining gates are not notified.
     */
    @Override
    public void close() throws IOException {
        for (InputGate gate : inputGates) {
            gate.finishReadRecoveredState();
        }
    }

    /**
     * Returns channel {@code channelIdx} of gate {@code gateIdx} as a recovered channel.
     *
     * @throws IllegalStateException if that channel is not a {@link RecoveredInputChannel}
     */
    private RecoveredInputChannel getChannel(int gateIdx, int channelIdx) {
        InputChannel channel = inputGates[gateIdx].getChannel(channelIdx);
        if (channel instanceof RecoveredInputChannel) {
            return (RecoveredInputChannel) channel;
        }
        throw new IllegalStateException(
                "Cannot restore state to a non-recovered input channel: " + channel);
    }

    /** Returns the (cached) non-empty list of recovered channels that {@code info} maps to. */
    private List<RecoveredInputChannel> getMappedChannels(InputChannelInfo info) {
        return rescaledChannels.computeIfAbsent(info, this::calculateMapping);
    }

    /**
     * Resolves the recovered channels that the recorded channel {@code info} maps to.
     *
     * <p>Uses the inverted channel mapping of {@code info}'s gate, cached per gate index.
     *
     * @throws IllegalStateException if {@code info} maps to no channel, or a mapped channel is not
     *     a {@link RecoveredInputChannel}
     */
    private List<RecoveredInputChannel> calculateMapping(InputChannelInfo info) {
        final int gateIdx = info.getGateIdx();
        // Invert once per gate so lookups go from the recorded channel index to new indexes.
        final RescaleMappings oldToNewMapping =
                oldToNewMappings.computeIfAbsent(
                        gateIdx, idx -> channelMapping.getChannelMapping(idx).invert());
        final List<RecoveredInputChannel> channels =
                Arrays.stream(oldToNewMapping.getMappedIndexes(info.getInputChannelIdx()))
                        .mapToObj(newChannelIdx -> getChannel(gateIdx, newChannelIdx))
                        .collect(Collectors.toList());
        if (channels.isEmpty()) {
            throw new IllegalStateException(
                    "Recovered a buffer from old "
                            + info
                            + " that has no mapping in "
                            + channelMapping.getChannelMapping(gateIdx));
        }
        return channels;
    }
}
