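Feature processing
Getting the categorical features
[Overview] The main purpose of this part is to obtain the categorical features: a map from each categorical feature's index to its number of categories.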
```scala
// From Spark ML's MetadataUtils object
import scala.collection.immutable.HashMap

import org.apache.spark.ml.attribute._
import org.apache.spark.sql.types.StructField

/**
 * Examine a schema to identify categorical (Binary and Nominal) features.
 *
 * @param featuresSchema  Schema of the features column.
 *                        If a feature does not have metadata, it is assumed to be continuous.
 *                        If a feature is Nominal, then it must have the number of values
 *                        specified.
 * @return  Map: feature index to number of categories.
 *          The map's set of keys will be the set of categorical feature indices.
 */
def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = {
  val metadata = AttributeGroup.fromStructField(featuresSchema)
  if (metadata.attributes.isEmpty) {
    HashMap.empty[Int, Int]
  } else {
    // For each feature, get its number of categories (distinct values)
    metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) =>
      if (attr == null) {
        // No metadata for this feature (it was not produced by an ML feature transformer):
        // it is treated as continuous and skipped
        Iterator()
      } else {
        // The feature was processed by an ML feature transformer, so attr is non-null
        attr match {
          // NumericAttribute or UnresolvedAttribute: continuous, nothing is returned for this column
          case _: NumericAttribute | UnresolvedAttribute => Iterator()
          // BinaryAttribute: return the column index and 2 categories
          case binAttr: BinaryAttribute => Iterator(idx -> 2)
          // NominalAttribute (e.g. after bucketing): return the column index and the
          // number of values, which is the number of categories
          case nomAttr: NominalAttribute =>
            nomAttr.getNumValues match {
              case Some(numValues: Int) => Iterator(idx -> numValues)
              case None => throw new IllegalArgumentException(s"Feature $idx is marked as" +
                " Nominal (categorical), but it does not have the number of values specified.")
            }
        }
      }
    }.toMap
  }
}
```
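To experiment with the pattern match outside Spark, here is a minimal, dependency-free sketch of the same logic. The types Attr, Numeric, Unresolved, Binary and Nominal are hypothetical stand-ins for Spark's NumericAttribute, UnresolvedAttribute, BinaryAttribute and NominalAttribute, and CategoricalSketch is an illustrative name; none of them are Spark APIs.

```scala
// Hypothetical stand-ins for Spark's ml.attribute types (NOT the real Spark classes)
sealed trait Attr
case object Numeric extends Attr
case object Unresolved extends Attr
case object Binary extends Attr
final case class Nominal(numValues: Option[Int]) extends Attr

object CategoricalSketch {
  // attrs == None models "no ml_attr metadata at all";
  // a null entry models "no metadata for this particular feature"
  def categoricalFeatures(attrs: Option[Array[Attr]]): Map[Int, Int] =
    attrs match {
      case None => Map.empty
      case Some(arr) =>
        arr.zipWithIndex.flatMap { case (attr, idx) =>
          attr match {
            case null                 => Iterator()          // no metadata: continuous
            case Numeric | Unresolved => Iterator()          // continuous
            case Binary               => Iterator(idx -> 2)  // two categories
            case Nominal(Some(n))     => Iterator(idx -> n)  // n categories
            case Nominal(None) =>
              throw new IllegalArgumentException(
                s"Feature $idx is marked as Nominal (categorical), " +
                  "but it does not have the number of values specified.")
          }
        }.toMap
    }
}

val attrs: Array[Attr] = Array(Numeric, null, Binary, Nominal(Some(4)), Unresolved)
println(CategoricalSketch.categoricalFeatures(Some(attrs))) // Map(2 -> 2, 3 -> 4)
```

Only the binary feature (index 2) and the nominal feature (index 3) end up as keys; the numeric, unresolved and metadata-less features are treated as continuous.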
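```scala
/**
 * Builds an AttributeGroup from a column's StructField. If a column was produced by an ML
 * feature transformer, summary information is stored in StructField.metadata under the
 * ML_ATTR ("ml_attr") key; this is how each feature is determined to have been processed or not.
 */
def fromStructField(field: StructField): AttributeGroup = {
  require(field.dataType == new VectorUDT)
  if (field.metadata.contains(ML_ATTR)) {
    // Metadata present: parse the per-feature attributes
    fromMetadata(field.metadata.getMetadata(ML_ATTR), field.name)
  } else {
    // No metadata: an AttributeGroup without attributes, so every feature is treated as continuous
    new AttributeGroup(field.name)
  }
}
```
The figure below shows that, after MLlib feature engineering, a column's metadata contains ml_attr together with the attributes describing the feature processing, while a column that has not been processed has empty metadata. The metadata can be inspected with df.schema.fields.map(_.metadata).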

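The branch in fromStructField can be sketched without Spark as follows. Column, AttrGroup and AttrGroupSketch are hypothetical simplifications: Spark stores this information in an org.apache.spark.sql.types.Metadata object, not a Map, and the attribute types here are plain strings. The sketch only illustrates that a column without an ml_attr entry yields no attributes, which getCategoricalFeatures then turns into an empty map.

```scala
// Hypothetical simplified column: metadata is a Map instead of Spark's Metadata class.
// A feature-processed column carries an "ml_attr" entry listing each feature's type by index.
final case class Column(name: String, metadata: Map[String, Vector[String]])

// Simplified AttributeGroup: attributes == None means "no ML metadata at all"
final case class AttrGroup(name: String, attributes: Option[Vector[String]])

object AttrGroupSketch {
  val MlAttr = "ml_attr"

  // Mirrors the if/else in fromStructField: read attributes only when ml_attr is present
  def fromColumn(col: Column): AttrGroup =
    col.metadata.get(MlAttr) match {
      case Some(attrs) => AttrGroup(col.name, Some(attrs))
      case None        => AttrGroup(col.name, None) // unprocessed column: no attributes
    }
}

val processed = Column("features", Map("ml_attr" -> Vector("numeric", "binary")))
val raw       = Column("features", Map.empty)

println(AttrGroupSketch.fromColumn(processed).attributes) // Some(Vector(numeric, binary))
println(AttrGroupSketch.fromColumn(raw).attributes)       // None
```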
Training
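```scala
override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
  // Categorical features (index -> number of categories), read from the features column metadata
  val categoricalFeatures: Map[Int, Int] =
    MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
  // Convert the DataFrame to the old RDD-based LabeledPoint format
  val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
  // Build the old mllib Strategy; the categorical feature map is passed in here
  val strategy = getOldStrategy(categoricalFeatures)
  val instr = Instrumentation.create(this, oldDataset)
  instr.logParams(params: _*)
  // A single decision tree is trained as a random forest of one tree that considers all features
  val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
    seed = $(seed), instr = Some(instr), parentUID = Some(uid))
  val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
  instr.logSuccess(m)
  m
}
```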
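The RandomForest.run call trains a single decision tree as a forest of one tree (numTrees = 1) with featureSubsetStrategy = "all", i.e. every feature is a split candidate at every node. The sketch below illustrates how such a strategy string can translate into a per-node feature count. FeatureSubsetSketch and numFeaturesPerNode are hypothetical names, and the rules for "sqrt", "log2", "onethird" and "auto" reflect our reading of Spark's tree implementation, so treat them as an assumption rather than the library's contract.

```scala
object FeatureSubsetSketch {
  // Hypothetical helper: number of candidate features examined at each tree node.
  // The mapping below is an assumption modeled on Spark's tree code, not its API.
  def numFeaturesPerNode(strategy: String, numFeatures: Int,
                         numTrees: Int, isClassification: Boolean): Int = {
    // "auto" is resolved first: a single tree uses all features
    val resolved = strategy match {
      case "auto" => if (numTrees == 1) "all" else if (isClassification) "sqrt" else "onethird"
      case other  => other
    }
    resolved match {
      case "all"      => numFeatures
      case "sqrt"     => math.sqrt(numFeatures.toDouble).ceil.toInt
      case "log2"     => math.max(1, (math.log(numFeatures.toDouble) / math.log(2)).ceil.toInt)
      case "onethird" => (numFeatures / 3.0).ceil.toInt
      case s          => throw new IllegalArgumentException(s"Unsupported featureSubsetStrategy: $s")
    }
  }
}

// The setting used in train(): one tree, all features
println(FeatureSubsetSketch.numFeaturesPerNode("all", 10, numTrees = 1, isClassification = false)) // 10
// A regression forest using "onethird" would examine ceil(10 / 3.0) = 4 features per node
println(FeatureSubsetSketch.numFeaturesPerNode("onethird", 10, numTrees = 20, isClassification = false)) // 4
```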
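Copyright notice: This is an original article by u010990043, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.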