package org.apache.spark.shuffle.writer;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Uninterruptibles;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.apache.spark.Aggregator;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManagerId;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.storage.util.StorageType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Function1;
import scala.Option;
import scala.Product2;
import scala.collection.Iterator;

/* loaded from: input_file:org/apache/spark/shuffle/writer/RssShuffleWriter.class */
/**
 * {@link ShuffleWriter} that serializes records through a {@link WriteBufferManager}, hands the
 * resulting {@link ShuffleBlockInfo}s to the {@link RssShuffleManager} for sending, waits until all
 * of them are acknowledged, and finally reports the written block ids per partition.
 *
 * <p>NOTE(review): {@code blockIds} is a concurrent set while {@code partitionToBlockIds} and
 * {@code partitionLengths} are not; {@link #processShuffleBlockInfos(List)} is also passed to the
 * buffer manager as a callback. Confirm that it is only ever invoked on the task thread.
 *
 * @param <K> record key type
 * @param <V> record value type
 * @param <C> combined value type used when map-side combine is enabled
 */
public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
    private static final Logger LOG = LoggerFactory.getLogger(RssShuffleWriter.class);
    private static final String DUMMY_HOST = "dummy_host";
    private static final int DUMMY_PORT = 99999;
    /** First sleep between two "is the commit done yet" checks, in milliseconds. */
    private static final int COMMIT_WAIT_INITIAL_MS = 200;
    /** Upper bound for the doubling sleep between commit checks, in milliseconds. */
    private static final int COMMIT_WAIT_MAX_MS = 5000;

    private final String appId;
    private final int shuffleId;
    // Not final: assigned by the public constructors after the shared private constructor ran.
    private WriteBufferManager bufferManager;
    private final String taskId;
    private final int numMaps;
    private final ShuffleDependency<K, V, C> shuffleDependency;
    private final Partitioner partitioner;
    private final RssShuffleManager shuffleManager;
    private final boolean shouldPartition;
    private final long sendCheckTimeout;
    private final long sendCheckInterval;
    private final int bitmapSplitNum;
    private final Map<Integer, Set<Long>> partitionToBlockIds;
    private final ShuffleWriteClient shuffleWriteClient;
    private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
    private final Set<ShuffleServerInfo> shuffleServersForData;
    private final long[] partitionLengths;
    private final boolean isMemoryShuffleEnabled;
    private final Function<String, Boolean> taskFailureCallback;
    /** Ids of every block handed to the shuffle manager that has not been confirmed yet. */
    private final Set<Long> blockIds = Sets.newConcurrentHashSet();
    protected final long taskAttemptId;
    protected final ShuffleWriteMetrics shuffleWriteMetrics;
    /** Receives one marker object per finished block event; used to wake up the send check. */
    private final BlockingQueue<Object> finishEventQueue = new LinkedBlockingQueue<>();

    /**
     * Test-only constructor that injects a ready-made {@link WriteBufferManager} and uses a
     * failure callback that always returns {@code true}.
     */
    @VisibleForTesting
    public RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId, WriteBufferManager bufferManager, ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, RssShuffleHandle<K, V, C> rssHandle) {
        this(appId, shuffleId, taskId, taskAttemptId, shuffleWriteMetrics, shuffleManager, sparkConf, shuffleWriteClient, rssHandle, (Function<String, Boolean>) failedTaskId -> true);
        this.bufferManager = bufferManager;
    }

    /**
     * Shared initialisation for both public constructors; leaves {@code bufferManager} unset.
     */
    private RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId, ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback) {
        LOG.warn("RssShuffle start write taskAttemptId data{}", taskAttemptId);
        this.shuffleManager = shuffleManager;
        this.appId = appId;
        this.shuffleId = shuffleId;
        this.taskId = taskId;
        this.taskAttemptId = taskAttemptId;
        this.numMaps = rssHandle.getNumMaps();
        this.shuffleWriteMetrics = shuffleWriteMetrics;
        this.shuffleDependency = rssHandle.getDependency();
        this.partitioner = this.shuffleDependency.partitioner();
        // With a single partition every record goes to partition 0, so the partitioner is skipped.
        this.shouldPartition = this.partitioner.numPartitions() > 1;
        this.sendCheckTimeout = (Long) sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
        this.sendCheckInterval = (Long) sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
        this.bitmapSplitNum = (Integer) sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
        this.partitionToBlockIds = Maps.newHashMap();
        this.shuffleWriteClient = shuffleWriteClient;
        this.shuffleServersForData = rssHandle.getShuffleServersForData();
        // A freshly allocated long[] is already zero-filled.
        this.partitionLengths = new long[this.partitioner.numPartitions()];
        this.partitionToServers = rssHandle.getPartitionToServers();
        this.isMemoryShuffleEnabled = isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
        this.taskFailureCallback = taskFailureCallback;
    }

    /**
     * Production constructor: builds its own {@link WriteBufferManager} from the Spark
     * configuration and the task's memory manager, registering
     * {@link #processShuffleBlockInfos(List)} as the manager's block callback.
     */
    public RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId, ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback, TaskContext taskContext) {
        this(appId, shuffleId, taskId, taskAttemptId, shuffleWriteMetrics, shuffleManager, sparkConf, shuffleWriteClient, rssHandle, taskFailureCallback);
        this.bufferManager = new WriteBufferManager(shuffleId, taskId, taskAttemptId, new BufferManagerOptions(sparkConf), rssHandle.getDependency().serializer(), rssHandle.getPartitionToServers(), taskContext.taskMemoryManager(), shuffleWriteMetrics, RssSparkConfig.toRssConf(sparkConf), this::processShuffleBlockInfos);
    }

    /** Returns whether the configured storage type (by enum name) includes memory storage. */
    private boolean isMemoryShuffleEnabled(String storageType) {
        return StorageType.withMemory(StorageType.valueOf(storageType));
    }

    /**
     * Writes all records; on any exception the task failure callback is invoked with this
     * task's id before the exception is rethrown unchanged.
     */
    public void write(Iterator<Product2<K, V>> records) throws IOException {
        try {
            writeImpl(records);
        } catch (Exception e) {
            this.taskFailureCallback.apply(this.taskId);
            throw e;
        }
    }

    /**
     * Buffers every record, flushes the remaining buffers, waits for all blocks to be
     * acknowledged, commits (unless memory shuffle is enabled) and records the write time.
     */
    private void writeImpl(Iterator<Product2<K, V>> records) {
        boolean isCombine = this.shuffleDependency.mapSideCombine();
        Function1<V, C> createCombiner = null;
        if (isCombine) {
            createCombiner = this.shuffleDependency.aggregator().get().createCombiner();
        }
        while (records.hasNext()) {
            // Fail fast as soon as any previously sent block is known to have failed.
            checkIfBlocksFailed();
            Product2<K, V> kv = records.next();
            K key = kv._1();
            int partition = getPartition(key);
            Object value = isCombine ? createCombiner.apply(kv._2()) : kv._2();
            List<ShuffleBlockInfo> readyBlocks = this.bufferManager.addRecord(partition, key, value);
            if (readyBlocks != null && !readyBlocks.isEmpty()) {
                processShuffleBlockInfos(readyBlocks);
            }
        }

        long flushStartMs = System.currentTimeMillis();
        List<ShuffleBlockInfo> remainingBlocks = this.bufferManager.clear();
        if (remainingBlocks != null && !remainingBlocks.isEmpty()) {
            processShuffleBlockInfos(remainingBlocks);
        }
        long checkStartMs = System.currentTimeMillis();
        checkBlockSendResult(this.blockIds);
        long commitStartMs = System.currentTimeMillis();
        long checkDurationMs = commitStartMs - checkStartMs;
        if (!this.isMemoryShuffleEnabled) {
            sendCommit();
        }
        long writeDurationMs = this.bufferManager.getWriteTime() + (System.currentTimeMillis() - flushStartMs);
        this.shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(writeDurationMs));
        LOG.info("Finish write shuffle for appId[" + this.appId + "], shuffleId[" + this.shuffleId + "], taskId[" + this.taskId + "] with write " + writeDurationMs + " ms, include checkSendResult[" + checkDurationMs + "], commit[" + (System.currentTimeMillis() - commitStartMs) + "], " + this.bufferManager.getManagerCostInfo());
    }

    /**
     * Always returns an empty array; the lengths accumulated in {@code partitionLengths} are
     * only exposed through the {@link MapStatus} built in {@link #stop(boolean)}.
     *
     * <p>NOTE(review): presumably intentional — confirm before returning the tracked lengths.
     */
    public long[] getPartitionLengths() {
        return new long[0];
    }

    /**
     * Records the id, partition and length of each block and posts the blocks for sending.
     *
     * @param list blocks to send; {@code null} or empty yields an empty result
     * @return the futures returned by the shuffle manager, one per posted block event
     */
    @VisibleForTesting
    protected List<CompletableFuture<Long>> processShuffleBlockInfos(List<ShuffleBlockInfo> list) {
        if (list == null || list.isEmpty()) {
            return Collections.emptyList();
        }
        for (ShuffleBlockInfo block : list) {
            long blockId = block.getBlockId();
            int partitionId = block.getPartitionId();
            this.blockIds.add(blockId);
            this.partitionToBlockIds.computeIfAbsent(partitionId, unused -> Sets.newHashSet()).add(blockId);
            this.partitionLengths[partitionId] += block.getLength();
        }
        return postBlockEvent(list);
    }

    /**
     * Groups the blocks into events via the buffer manager and sends each event through the
     * shuffle manager. Every event gets a callback that enqueues a marker into
     * {@code finishEventQueue} so {@link #checkBlockSendResult(Set)} can wake up.
     */
    protected List<CompletableFuture<Long>> postBlockEvent(List<ShuffleBlockInfo> list) {
        List<CompletableFuture<Long>> futures = new ArrayList<>();
        for (AddBlockEvent event : this.bufferManager.buildBlockEvents(list)) {
            event.addCallback(() -> {
                if (!this.finishEventQueue.add(new Object())) {
                    LOG.error("Add event " + event + " to finishEventQueue fail");
                }
            });
            futures.add(this.shuffleManager.sendData(event));
        }
        return futures;
    }

    /**
     * Blocks until every id in {@code pendingBlockIds} has been reported as successfully sent,
     * a block failure is detected, or {@code sendCheckTimeout} milliseconds have elapsed.
     *
     * <p>Successfully sent ids are removed from the given set in place. An interrupt received
     * while waiting does not abort the wait; it is remembered and the thread's interrupt flag
     * is restored before returning or throwing.
     *
     * @param pendingBlockIds ids still awaiting acknowledgement (mutated)
     * @throws RssException if blocks failed to send or are still pending after the timeout
     */
    @VisibleForTesting
    protected void checkBlockSendResult(Set<Long> pendingBlockIds) {
        boolean interrupted = false;
        try {
            long deadlineMs = System.currentTimeMillis() + this.sendCheckTimeout;
            while (true) {
                try {
                    // Drop stale wake-ups first; anything finishing after this point re-triggers a check.
                    this.finishEventQueue.clear();
                    checkIfBlocksFailed();
                    pendingBlockIds.removeAll(this.shuffleManager.getSuccessBlockIds(this.taskId));
                    if (pendingBlockIds.isEmpty()) {
                        break;
                    }
                    if (this.finishEventQueue.isEmpty()) {
                        long remainingMs = Math.max(deadlineMs - System.currentTimeMillis(), 0L);
                        // The timed wait must sit inside the InterruptedException handler.
                        if (this.finishEventQueue.poll(remainingMs, TimeUnit.MILLISECONDS) == null) {
                            break;
                        }
                    }
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
            if (!pendingBlockIds.isEmpty()) {
                String errorMsg = "Timeout: Task[" + this.taskId + "] failed because " + pendingBlockIds.size() + " blocks can't be sent to shuffle server in " + this.sendCheckTimeout + " ms.";
                LOG.error(errorMsg);
                throw new RssException(errorMsg);
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Throws if the shuffle manager reports any failed block for this task.
     *
     * @throws RssException when at least one block id is reported as failed
     */
    private void checkIfBlocksFailed() {
        Set<Long> failedBlockIds = this.shuffleManager.getFailedBlockIds(this.taskId);
        if (failedBlockIds.isEmpty()) {
            return;
        }
        String errorMsg = "Send failed: Task[" + this.taskId + "] failed because " + failedBlockIds.size() + " blocks can't be sent to shuffle server.";
        LOG.error(errorMsg);
        throw new RssException(errorMsg);
    }

    /**
     * Sends the commit request on a dedicated thread and polls for its completion with a
     * doubling sleep (200 ms up to 5000 ms). The executor is shut down on every exit path.
     *
     * @throws RssException if the commit returns {@code false} or fails with an exception
     */
    @VisibleForTesting
    protected void sendCommit() {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            Future<Boolean> commitFuture = executor.submit(() -> this.shuffleWriteClient.sendCommit(this.shuffleServersForData, this.appId, this.shuffleId, this.numMaps));
            int waitMs = COMMIT_WAIT_INITIAL_MS;
            long startMs = System.currentTimeMillis();
            while (!commitFuture.isDone()) {
                LOG.info("Wait commit to shuffle server for task[" + this.taskAttemptId + "] cost " + (System.currentTimeMillis() - startMs) + " ms");
                Uninterruptibles.sleepUninterruptibly(waitMs, TimeUnit.MILLISECONDS);
                waitMs = Math.min(waitMs * 2, COMMIT_WAIT_MAX_MS);
            }
            try {
                if (!commitFuture.get()) {
                    throw new RssException("Failed to commit task to shuffle server");
                }
            } catch (InterruptedException e) {
                // Deliberately ignored, as the log message states.
                // NOTE(review): the interrupt flag is not restored here — confirm this is intended.
                LOG.warn("Ignore the InterruptedException which should be caused by internal killed");
            } catch (Exception e) {
                throw new RssException("Exception happened when get commit status", e);
            }
        } finally {
            // Previously skipped on the failure paths, which leaked the worker thread.
            executor.shutdown();
        }
    }

    /**
     * Returns the target partition for {@code t}; always 0 when there is only one partition.
     */
    @VisibleForTesting
    protected <T> int getPartition(T t) {
        return this.shouldPartition ? this.partitioner.getPartition(t) : 0;
    }

    /**
     * Finishes the task. On success the block ids written per partition are reported to the
     * shuffle servers and a {@link MapStatus} carrying the partition lengths is returned;
     * otherwise an empty option is returned. Buffer memory and the task's metadata in the
     * shuffle manager are released in every case, including when reporting throws.
     *
     * @param success whether the task completed successfully
     */
    public Option<MapStatus> stop(boolean success) {
        try {
            if (!success) {
                return Option.empty();
            }
            Map<Integer, List<Long>> blockIdsByPartition = Maps.newHashMap();
            for (Map.Entry<Integer, Set<Long>> entry : this.partitionToBlockIds.entrySet()) {
                blockIdsByPartition.put(entry.getKey(), Lists.newArrayList(entry.getValue()));
            }
            long reportStartMs = System.currentTimeMillis();
            this.shuffleWriteClient.reportShuffleResult(this.partitionToServers, this.appId, this.shuffleId, this.taskAttemptId, blockIdsByPartition, this.bitmapSplitNum);
            LOG.info("Report shuffle result for task[{}] with bitmapNum[{}] cost {} ms", this.taskAttemptId, this.bitmapSplitNum, System.currentTimeMillis() - reportStartMs);
            BlockManagerId blockManagerId = BlockManagerId.apply(this.appId + "_" + this.taskId, DUMMY_HOST, DUMMY_PORT, Option.apply(Long.toString(this.taskAttemptId)));
            return Option.apply(MapStatus.apply(blockManagerId, this.partitionLengths, this.taskAttemptId));
        } finally {
            if (this.bufferManager != null) {
                this.bufferManager.freeAllMemory();
            }
            if (this.shuffleManager != null) {
                this.shuffleManager.clearTaskMeta(this.taskId);
            }
        }
    }

    /** Exposes the live (not copied) partition-to-block-ids map for tests. */
    @VisibleForTesting
    Map<Integer, Set<Long>> getPartitionToBlockIds() {
        return this.partitionToBlockIds;
    }

    /** Returns the buffer manager used by this writer. */
    @VisibleForTesting
    public WriteBufferManager getBufferManager() {
        return this.bufferManager;
    }
}
