I want to override the transform method of StringIndexer in a new class, SafeStringIndexer.
Sample code:
package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.annotation.Since
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StringType

class SafeStringIndexer extends StringIndexer {

  @Since("2.0.0")
  override def transform(dataset: Dataset[_]): DataFrame = {
    if (!dataset.schema.fieldNames.contains($(inputCol))) {
      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
        "Skip StringIndexerModel.")
      return dataset.toDF
    }
    transformSchema(dataset.schema, logging = true)

    // does not compile: labels is a member of StringIndexerModel, not StringIndexer
    val filteredLabels = getHandleInvalid match {
      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
      case _ => labels
    }

    val metadata = NominalAttribute.defaultAttr
      .withName($(outputCol)).withValues(filteredLabels).toMetadata()

    // If we are skipping invalid records, filter them out.
    // does not compile either: labelToIndex is private inside StringIndexerModel
    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
      case StringIndexer.SKIP_INVALID =>
        val filterer = udf { label: String =>
          labelToIndex.contains(label)
        }
        (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
    }

    val indexer = udf { label: String =>
      if (label == null) {
        if (keepInvalid) {
          labels.length
        } else {
          throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
            "NULLS, try setting StringIndexer.handleInvalid.")
        }
      } else {
        if (labelToIndex.contains(label)) {
          labelToIndex(label)
        } else if (keepInvalid) {
          labels.length
        } else {
          throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
            s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
        }
      }
    }.asNondeterministic()

    filteredDataset.select(col("*"),
      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
  }
}
Source code: https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
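As far as I can tell from the source, the compile failures are on labels and labelToIndex: both are members of StringIndexerModel, and labelToIndex is even declared private there, so even a subclass of the model could not see it directly. One workaround I am considering (my own hypothetical helper, not part of Spark) is rebuilding the map from the public labels array:

  // Hypothetical helper, not from the Spark source: rebuild the private
  // label-to-index map from the public labels array.
  private lazy val labelToIndex: Map[String, Double] =
    labels.zipWithIndex.map { case (label, i) => (label, i.toDouble) }.toMap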
I am quite new to Scala (coming from Python), and the official documentation doesn't contain an example as complex as Spark's StringIndexer, so this is not so easy for me.
labels is defined in StringIndexerModel; how do I access it from my subclass?
@Since("1.4.0")
class StringIndexerModel (
@Since("1.4.0") override val uid: String,
@Since("1.5.0") val labels: Array[String])
extends Model[StringIndexerModel] with StringIndexerBase with MLWritable {
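Since labels is a public constructor val, my current guess is that I should subclass the model rather than the estimator. A minimal sketch of what I have in mind (untested, assuming Spark 2.3.x, where this constructor is public and the class is not final; SafeStringIndexerModel is my own name):

  package org.apache.spark.ml.feature

  import org.apache.spark.sql.{DataFrame, Dataset}

  // Sketch: subclass the model, which actually carries labels,
  // instead of the estimator, which does not.
  class SafeStringIndexerModel(uid: String, labels: Array[String])
    extends StringIndexerModel(uid, labels) {

    override def transform(dataset: Dataset[_]): DataFrame = {
      // labels is inherited as a public val, so the modified logic from
      // above could live here, using the rebuilt labelToIndex sketched earlier.
      super.transform(dataset) // placeholder: replace with the modified logic
    }
  }

Is that the idiomatic way to do it?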
First, I tried the approach from https://stackoverflow.com/a/20330282/1637673:

class SafeStringIndexer extends StringIndexer {
  import StringIndexerModel._
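If I understand that answer correctly, this only imports members of the companion object StringIndexerModel (static-like helpers such as read and load), never instance fields like labels, so it cannot help here. Coming from Python, this distinction was new to me; a tiny example of what I mean:

  class Foo(val x: Int)     // x lives on each instance of Foo
  object Foo { val y = 0 }  // y lives on the single companion object

  import Foo._              // brings y into scope, but never x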
Then I wanted to extend StringIndexerModel instead, but got blocked again.
Now I am quite confused by Scala inheritance. I just want to extend Spark's StringIndexer (specifically its transform method) gracefully, reusing as much of the existing code as possible.
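In case it helps to show what I am after, here is the shape of the "least code" solution I am imagining (untested sketch; SafeStringIndexer and SafeStringIndexerModel are my own names, and I am assuming fit, copyValues, and setParent behave as in the Spark 2.3 source). The idea is to reuse StringIndexer.fit entirely and only swap in the safe model:

  package org.apache.spark.ml.feature

  import org.apache.spark.sql.Dataset

  class SafeStringIndexer extends StringIndexer {
    override def fit(dataset: Dataset[_]): StringIndexerModel = {
      val model = super.fit(dataset) // let Spark compute the labels
      // copyValues should carry inputCol/outputCol/handleInvalid over to the new model
      copyValues(new SafeStringIndexerModel(model.uid, model.labels).setParent(this))
    }
  }

Is subclassing like this reasonable in Scala, or is there a more graceful pattern?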