0

I want to override the transform method of StringIndexer in a new class, SafeStringIndexer:

sample code:

package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StringType

/**
 * Variant of the StringIndexer transformation that returns the dataset unchanged when the
 * input column is missing, instead of failing.
 *
 * NOTE(review): `labels` is declared on StringIndexerModel (as a constructor `val`), not on
 * StringIndexer, so it does not resolve inside a subclass of StringIndexer. `labelToIndex`
 * is presumably in the same situation -- confirm against the Spark source.
 */
class SafeStringIndexer extends StringIndexer {
  /**
   * Appends the column `outputCol`, holding the index of each value of `inputCol`.
   *
   * Behaviour by `handleInvalid` setting:
   *  - `StringIndexer.KEEP_INVALID`: null and unseen labels get the index `labels.length`,
   *    and an extra `"__unknown"` value is appended to the output column metadata.
   *  - `StringIndexer.SKIP_INVALID`: rows whose label is null or unseen are dropped.
   *  - anything else: a null or unseen label raises a `SparkException`.
   *
   * @param dataset the dataset to transform
   * @return `dataset` as a DataFrame, with `outputCol` appended; or `dataset.toDF` untouched
   *         when `inputCol` is not present in the schema
   */
  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    // Missing input column: log and pass the data through rather than failing.
    if (!dataset.schema.fieldNames.contains($(inputCol))) {
      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
        "Skip StringIndexerModel.")
      return dataset.toDF
    }
    transformSchema(dataset.schema, logging = true)

    // With KEEP_INVALID the metadata needs one extra slot for the catch-all bucket.
    val filteredLabels = getHandleInvalid match {
      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
      case _ => labels
    }

    val metadata = NominalAttribute.defaultAttr
      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
    // If we are skipping invalid records, filter them out.
    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
      case StringIndexer.SKIP_INVALID =>
        val filterer = udf { label: String =>
          labelToIndex.contains(label)
        }
        // Drop nulls first, then keep only rows whose label is known.
        (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
    }

    // Maps a label to its index; null and unseen labels either fall into the extra
    // `labels.length` bucket (keepInvalid) or raise a SparkException.
    val indexer = udf { label: String =>
      if (label == null) {
        if (keepInvalid) {
          labels.length
        } else {
          throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
            "NULLS, try setting StringIndexer.handleInvalid.")
        }
      } else {
        if (labelToIndex.contains(label)) {
          labelToIndex(label)
        } else if (keepInvalid) {
          labels.length
        } else {
          throw new SparkException(s"Unseen label: $label.  To handle unseen labels, " +
            s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
        }
      }
    }.asNondeterministic()

    // The input column is cast to string before indexing; all existing columns are kept.
    filteredDataset.select(col("*"),
      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
  }
}

source code : https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

I am quite new to Scala (coming from Python), and the official documentation doesn't contain an example as complex as Spark's StringIndexer, so this is not easy for me:

enter image description here

labels is defined in StringIndexerModel, how do I access it?

@Since("1.4.0")
class StringIndexerModel (
    @Since("1.4.0") override val uid: String,
    @Since("1.5.0") val labels: Array[String])
  extends Model[StringIndexerModel] with StringIndexerBase with MLWritable {

First, I tried the approach from https://stackoverflow.com/a/20330282/1637673

class SafeStringIndexer extends StringIndexer {

  import StringIndexerModel._

This doesn't make labels resolve: enter image description here

Then I tried extending StringIndexerModel instead, but I got blocked again:

enter image description here

Now I am quite confused by Scala inheritance. I just want to extend Spark's StringIndexer gracefully, overriding only the transform method and reusing as much of the existing code as possible.

illuminator3
  • 102
  • 3
  • 10
Mithril
  • 11,666
  • 17
  • 90
  • 135

0 Answers0