spark-decisionTreeRegressor(DTR回归决策树)源码解析

 

特征处理

 

获取特征类别

【概述】当前部分主要作用时获取

  /**
   * Examine a schema to identify categorical (Binary and Nominal) features.
   *
   * @param featuresSchema  Schema of the features column.
   *                        If a feature does not have metadata, it is assumed to be continuous.
   *                        If a feature is Nominal, then it must have the number of values
   *                        specified.
   * @return  Map: feature index to number of categories.
   *          The map's set of keys will be the set of categorical feature indices.
   */
  def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = {
    val metadata = AttributeGroup.fromStructField(featuresSchema)
    if (metadata.attributes.isEmpty) {
      HashMap.empty[Int, Int]
    } else {
      /*获取每列特征,的基元个数*/
      metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) =>
        if (attr == null) {
          /*若未经特征工程处理,此处为空,不做处理*/
          Iterator()
        } else {
          /*若讲过ML特征工程处理此处不为空*/
          attr match {
            /*若当前特征处理标记为NumericAttribute或者Un...该条列不做返回*/
            case _: NumericAttribute | UnresolvedAttribute => Iterator()
            /*若当前列经过而分类处理,返回当前列id和基元数目(2)*/
            case binAttr: BinaryAttribute => Iterator(idx -> 2)
            /*若当前列经过标准化处理(例如分桶)返回当前Id和数据元个数即为类别*/
            case nomAttr: NominalAttribute =>
              nomAttr.getNumValues match {
                case Some(numValues: Int) => Iterator(idx -> numValues)
                case None => throw new IllegalArgumentException(s"Feature $idx is marked as" +
                  " Nominal (categorical), but it does not have the number of values specified.")
              }
          }
        }
      }.toMap
    }
  }

  /**
   *封装参数组,若某列经过特征工程处理,会在StructField.metadata封装相关摘要信息,根据此特性判断,每个特征是否处理过
   *
   **/
  def fromStructField(field: StructField): AttributeGroup = {
    require(field.dataType == new VectorUDT)
    if (field.metadata.contains(ML_ATTR)) {
      fromMetadata(field.metadata.getMetadata(ML_ATTR), field.name)
    } else {
      new AttributeGroup(field.name)
    }
  }

下图为:经过MLlib特征工程 处理后列的MetaData包含ml_attr和特征处理相关的数据元,未经特征工程处理的MetaData则为空,查看方式df.schema.fields(_.metadata)

训练

  override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
    val categoricalFeatures: Map[Int, Int] =
      MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
    val strategy = getOldStrategy(categoricalFeatures)

    val instr = Instrumentation.create(this, oldDataset)
    instr.logParams(params: _*)

    val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
      seed = $(seed), instr = Some(instr), parentUID = Some(uid))

    val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
    instr.logSuccess(m)
    m
  }

 


版权声明:本文为u010990043原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。