package org.apache.spark.ml.util;

import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.StringType$;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Some;
import scala.collection.SeqLike;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.TypeTags;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichDouble$;

/* compiled from: DatasetUtils.scala */
/* loaded from: input_file:org/apache/spark/ml/util/DatasetUtils$.class */
public final class DatasetUtils$ {
    /** Singleton instance; assigned by the private constructor, which the static initializer below invokes. */
    public static DatasetUtils$ MODULE$;
    /** Cached vector-validation UDF; populated by {@code validateVector$lzycompute()}. */
    private UserDefinedFunction validateVector;
    /** Set to {@code true} once {@code validateVector} has been assigned. */
    private volatile boolean bitmap$0;

    // Creating the instance here is what publishes the singleton: the constructor stores `this` in MODULE$.
    static {
        new DatasetUtils$();
    }

    /**
     * Builds a column expression that casts the named column to double and validates it.
     *
     * <p>The expression evaluates {@code raise_error} when the value is null or NaN, or when it
     * equals negative or positive infinity; otherwise it yields the cast value.
     *
     * @param str  name of the column to validate
     * @param str2 display name used as the prefix of both error messages
     * @return the validating column expression
     */
    public Column checkNonNanValues(String str, String str2) {
        final functions$ fn = functions$.MODULE$;
        final Column value = fn.col(str).cast(DoubleType$.MODULE$);

        // First rule: null / NaN values are rejected with a fixed message.
        final Column missing = value.isNull().$bar$bar(value.isNaN());
        final Column missingError = fn.raise_error(fn.lit(str2 + " MUST NOT be Null or NaN"));

        // Second rule: either infinity is rejected; the offending value is appended to the message.
        final Column negativeInfinity = value.$eq$eq$eq(BoxesRunTime.boxToDouble(Double.NEGATIVE_INFINITY));
        final Column positiveInfinity = value.$eq$eq$eq(BoxesRunTime.boxToDouble(Double.POSITIVE_INFINITY));
        final Column infinite = negativeInfinity.$bar$bar(positiveInfinity);
        final Column infiniteError = fn.raise_error(fn.concat(Predef$.MODULE$.wrapRefArray(
                new Column[]{fn.lit(str2 + " MUST NOT be Infinity, but got "), value})));

        return fn.when(missing, missingError)
                .when(infinite, infiniteError)
                .otherwise(value);
    }

    /**
     * Validates a regression label column by delegating to
     * {@link #checkNonNanValues(String, String)} with the display name {@code "Labels"}.
     *
     * @param str name of the label column
     * @return the validating column expression produced by {@code checkNonNanValues}
     */
    public Column checkRegressionLabels(String str) {
        return checkNonNanValues(str, "Labels");
    }

    /**
     * Builds a column expression that casts the named label column to double and validates it
     * as a classification label.
     *
     * <p>When {@code option} is {@code Some(2)} the label must be 0 or 1. In every other case
     * the class count is the option's value, or {@code Integer.MAX_VALUE} when the option is
     * empty; the label must then lie in {@code [0, count)} and be equal to its own integer
     * cast. Null and NaN labels are rejected in both cases.
     *
     * @param str    name of the label column
     * @param option optional number of classes
     * @return the validating column expression
     * @throws IllegalArgumentException via {@code Predef.require} when the class count is not positive
     */
    public Column checkClassificationLabels(String str, Option<Object> option) {
        final functions$ fn = functions$.MODULE$;
        final Column label = fn.col(str).cast(DoubleType$.MODULE$);

        final boolean binary = (option instanceof Some)
                && BoxesRunTime.unboxToInt(((Some) option).value()) == 2;
        if (binary) {
            final Column missing = label.isNull().$bar$bar(label.isNaN());
            final Column notZeroOrOne = label.$eq$bang$eq(BoxesRunTime.boxToInteger(0))
                    .$amp$amp(label.$eq$bang$eq(BoxesRunTime.boxToInteger(1)));
            return fn.when(missing, fn.raise_error(fn.lit("Labels MUST NOT be Null or NaN")))
                    .when(notZeroOrOne, fn.raise_error(fn.concat(Predef$.MODULE$.wrapRefArray(
                            new Column[]{fn.lit("Labels MUST be in {0, 1}, but got "), label}))))
                    .otherwise(label);
        }

        // General case: fall back to Integer.MAX_VALUE classes when no count was supplied.
        final int numClasses = BoxesRunTime.unboxToInt(option.getOrElse(() -> Integer.MAX_VALUE));
        Predef$.MODULE$.require(0 < numClasses && numClasses <= Integer.MAX_VALUE);

        final Column missing = label.isNull().$bar$bar(label.isNaN());
        final Column outOfRange = label.$less(BoxesRunTime.boxToInteger(0))
                .$bar$bar(label.$greater$eq(BoxesRunTime.boxToInteger(numClasses)));
        // A label that differs from its integer cast is reported as a non-integer.
        final Column notIntegral = label.$eq$bang$eq(label.cast(IntegerType$.MODULE$));

        return fn.when(missing, fn.raise_error(fn.lit("Labels MUST NOT be Null or NaN")))
                .when(outOfRange, fn.raise_error(fn.concat(Predef$.MODULE$.wrapRefArray(
                        new Column[]{fn.lit("Labels MUST be in [0, " + numClasses + "), but got "), label}))))
                .when(notIntegral, fn.raise_error(fn.concat(Predef$.MODULE$.wrapRefArray(
                        new Column[]{fn.lit("Labels MUST be Integers, but got "), label}))))
                .otherwise(label);
    }

    /**
     * Builds a column expression that casts the named weight column to double and validates it.
     *
     * <p>The expression evaluates {@code raise_error} when the weight is null or NaN, or when
     * it is below zero or equal to positive infinity; otherwise it yields the cast weight.
     *
     * @param str name of the weight column
     * @return the validating column expression
     */
    public Column checkNonNegativeWeights(String str) {
        final functions$ fn = functions$.MODULE$;
        final Column weight = fn.col(str).cast(DoubleType$.MODULE$);

        final Column missing = weight.isNull().$bar$bar(weight.isNaN());
        final Column missingError = fn.raise_error(fn.lit("Weights MUST NOT be Null or NaN"));

        final Column negative = weight.$less(BoxesRunTime.boxToInteger(0));
        final Column positiveInfinity = weight.$eq$eq$eq(BoxesRunTime.boxToDouble(Double.POSITIVE_INFINITY));
        final Column invalid = negative.$bar$bar(positiveInfinity);
        final Column invalidError = fn.raise_error(fn.concat(Predef$.MODULE$.wrapRefArray(
                new Column[]{fn.lit("Weights MUST NOT be Negative or Infinity, but got "), weight})));

        return fn.when(missing, missingError)
                .when(invalid, invalidError)
                .otherwise(weight);
    }

    /**
     * Resolves an optional weight column name to a weight expression.
     *
     * <p>A defined, non-empty name is validated through
     * {@link #checkNonNegativeWeights(String)}; an empty option or an empty name yields the
     * literal {@code 1.0}.
     *
     * @param option optional name of the weight column
     * @return the validating weight expression, or the literal {@code 1.0}
     */
    public Column checkNonNegativeWeights(Option<String> option) {
        final boolean hasWeightColumn = (option instanceof Some)
                && new StringOps(Predef$.MODULE$.augmentString((String) ((Some) option).value())).nonEmpty();
        return hasWeightColumn
                ? checkNonNegativeWeights((String) ((Some) option).value())
                : functions$.MODULE$.lit(BoxesRunTime.boxToDouble(1.0d));
    }

    /**
     * Wraps a vector column in a validating expression.
     *
     * <p>The expression evaluates {@code raise_error} when the vector is null or when the
     * {@code validateVector} UDF returns false for it; otherwise it yields the column unchanged.
     *
     * @param column vector column to validate
     * @return the validating column expression
     */
    public Column checkNonNanVectors(Column column) {
        final functions$ fn = functions$.MODULE$;

        final Column isNull = column.isNull();
        final Column nullError = fn.raise_error(fn.lit("Vectors MUST NOT be Null"));

        // The UDF reports "valid"; negate it to obtain the failure condition.
        final Column invalid = validateVector()
                .apply(Predef$.MODULE$.wrapRefArray(new Column[]{column}))
                .unary_$bang();
        final Column invalidError = fn.raise_error(fn.concat(Predef$.MODULE$.wrapRefArray(new Column[]{
                fn.lit("Vector values MUST NOT be NaN or Infinity, but got "),
                column.cast(StringType$.MODULE$)})));

        return fn.when(isNull, nullError)
                .when(invalid, invalidError)
                .otherwise(column);
    }

    /**
     * Name-based convenience overload: resolves {@code str} to a column with
     * {@code functions.col} and delegates to {@link #checkNonNanVectors(Column)}.
     *
     * @param str name of the vector column
     * @return the validating column expression
     */
    public Column checkNonNanVectors(String str) {
        return checkNonNanVectors(functions$.MODULE$.col(str));
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v8, types: [org.apache.spark.ml.util.DatasetUtils$] */
    private UserDefinedFunction validateVector$lzycompute() {
        ?? r0 = this;
        synchronized (r0) {
            if (!this.bitmap$0) {
                functions$ functions_ = functions$.MODULE$;
                Function1 function1 = vector -> {
                    return BoxesRunTime.boxToBoolean($anonfun$validateVector$1(vector));
                };
                TypeTags.TypeTag Boolean = package$.MODULE$.universe().TypeTag().Boolean();
                TypeTags universe = package$.MODULE$.universe();
                this.validateVector = functions_.udf(function1, Boolean, universe.TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.DatasetUtils$$typecreator1$1
                    public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                        mirror.universe();
                        return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                    }
                }));
                r0 = this;
                r0.bitmap$0 = true;
            }
        }
        return this.validateVector;
    }

    /**
     * Returns the vector-validation UDF, creating it through
     * {@code validateVector$lzycompute()} while {@code bitmap$0} is still unset.
     *
     * @return the vector-validation UDF
     */
    private UserDefinedFunction validateVector() {
        if (this.bitmap$0) {
            return this.validateVector;
        }
        return validateVector$lzycompute();
    }

    /**
     * Returns a column expression that presents the named column as an ML vector.
     *
     * <ul>
     *   <li>A {@code VectorUDT} column is returned as-is.</li>
     *   <li>An array column with {@code FloatType} elements is converted by a UDF that copies
     *       every element into a {@code double[]} and wraps it in a dense vector.</li>
     *   <li>An array column with {@code DoubleType} elements is converted by a UDF that wraps
     *       the sequence's array in a dense vector.</li>
     * </ul>
     *
     * <p>The column's type is held as a general {@code DataType} and narrowed to
     * {@code ArrayType} only after the {@code instanceof} check; the decompiled original
     * declared the local as {@code ArrayType}, which matches neither the value assigned to it
     * nor the {@code VectorUDT} test.
     *
     * @param dataset dataset whose schema is inspected
     * @param str     name of the column to convert
     * @return a column expression of vector type
     * @throws IllegalArgumentException if the column is neither a vector nor an array, or if
     *                                  the array's element type is neither float nor double
     */
    public Column columnToVector(Dataset<?> dataset, String str) {
        DataType dataType = dataset.schema().apply(str).dataType();
        if (dataType instanceof VectorUDT) {
            return functions$.MODULE$.col(str);
        }
        if (!(dataType instanceof ArrayType)) {
            throw new IllegalArgumentException(dataType + " column cannot be cast to Vector");
        }

        DataType elementType = ((ArrayType) dataType).elementType();
        UserDefinedFunction udf;
        if (elementType instanceof FloatType) {
            // Seq[Float] -> dense vector: copy each element into a double array.
            udf = functions$.MODULE$.udf(seq -> {
                double[] dArr = (double[]) Array$.MODULE$.ofDim(seq.size(), ClassTag$.MODULE$.Double());
                seq.indices().foreach$mVc$sp(i -> {
                    dArr[i] = BoxesRunTime.unboxToFloat(seq.apply(i));
                });
                return Vectors$.MODULE$.dense(dArr);
            }, package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.DatasetUtils$$typecreator1$2
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    mirror.universe();
                    return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                }
            }), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.DatasetUtils$$typecreator2$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    Universe universe = mirror.universe();
                    return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().SingleType(universe.internal().reificationSupport().SingleType(universe.internal().reificationSupport().thisPrefix(mirror.RootClass()), mirror.staticPackage("scala")), mirror.staticModule("scala.package")), universe.internal().reificationSupport().selectType(mirror.staticModule("scala.package").asModule().moduleClass(), "Seq"), new $colon.colon(mirror.staticClass("scala.Float").asType().toTypeConstructor(), Nil$.MODULE$));
                }
            }));
        } else {
            if (!(elementType instanceof DoubleType)) {
                throw new IllegalArgumentException("Array[" + elementType + "] column cannot be cast to Vector");
            }
            // Seq[Double] -> dense vector: the sequence's array is used directly.
            udf = functions$.MODULE$.udf(seq2 -> {
                return Vectors$.MODULE$.dense((double[]) seq2.toArray(ClassTag$.MODULE$.Double()));
            }, package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.DatasetUtils$$typecreator3$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    mirror.universe();
                    return mirror.staticClass("org.apache.spark.ml.linalg.Vector").asType().toTypeConstructor();
                }
            }), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: org.apache.spark.ml.util.DatasetUtils$$typecreator4$1
                public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                    Universe universe = mirror.universe();
                    return universe.internal().reificationSupport().TypeRef(universe.internal().reificationSupport().SingleType(universe.internal().reificationSupport().SingleType(universe.internal().reificationSupport().thisPrefix(mirror.RootClass()), mirror.staticPackage("scala")), mirror.staticModule("scala.package")), universe.internal().reificationSupport().selectType(mirror.staticModule("scala.package").asModule().moduleClass(), "Seq"), new $colon.colon(mirror.staticClass("scala.Double").asType().toTypeConstructor(), Nil$.MODULE$));
                }
            }));
        }
        return udf.apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(str)}));
    }

    /**
     * Selects the named column as an ML vector (via {@code columnToVector}) and maps every
     * row of the resulting RDD to an {@code mllib} vector using {@code Vectors.fromML}.
     *
     * @param dataset dataset to read from
     * @param str     name of the column to convert
     * @return an RDD of {@code mllib} vectors
     * @throws MatchError (raised inside the map function) for a row that does not consist of
     *                    exactly one {@code ml.linalg.Vector} field
     */
    public RDD<Vector> columnToOldVector(Dataset<?> dataset, String str) {
        final Column vectorColumn = columnToVector(dataset, str);
        return dataset.select(Predef$.MODULE$.wrapRefArray(new Column[]{vectorColumn})).rdd().map(row -> {
            final Some fields = Row$.MODULE$.unapplySeq(row);
            final boolean singleField = !fields.isEmpty()
                    && fields.get() != null
                    && ((SeqLike) fields.get()).lengthCompare(1) == 0;
            if (!singleField) {
                throw new MatchError(row);
            }
            final Object first = ((SeqLike) fields.get()).apply(0);
            if (!(first instanceof org.apache.spark.ml.linalg.Vector)) {
                throw new MatchError(row);
            }
            return org.apache.spark.mllib.linalg.Vectors$.MODULE$.fromML((org.apache.spark.ml.linalg.Vector) first);
        }, ClassTag$.MODULE$.apply(Vector.class));
    }

    /**
     * Body of the {@code validateVector} UDF: reports whether every entry of the vector's
     * {@code values()} array is neither NaN nor infinite.
     *
     * @param vector a {@code DenseVector} or {@code SparseVector}
     * @return {@code true} when no entry of {@code values()} is NaN or infinite
     * @throws MatchError if {@code vector} is neither a {@code DenseVector} nor a {@code SparseVector}
     */
    public static final /* synthetic */ boolean $anonfun$validateVector$1(org.apache.spark.ml.linalg.Vector vector) {
        final double[] values;
        if (vector instanceof DenseVector) {
            values = ((DenseVector) vector).values();
        } else if (vector instanceof SparseVector) {
            values = ((SparseVector) vector).values();
        } else {
            throw new MatchError(vector);
        }

        // Both vector kinds are checked the same way, so one loop over the array suffices.
        for (double value : values) {
            if (Double.isNaN(value) || Double.isInfinite(value)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Private constructor; stores the new instance in {@code MODULE$}. Within this class it is
     * invoked only from the static initializer.
     */
    private DatasetUtils$() {
        MODULE$ = this;
    }
}
