Spark Custom UDAF (Strongly Typed, DSL Syntax)

1. Extend Aggregator
2. Implement its methods
3. Convert the aggregator into a query column (toColumn)
4. Apply it to a Dataset to get the result

package com.wxx.bigdata.sql03

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

object CustomerUDAFClassAPP {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[2]").appName("CustomerUDAFClassAPP").getOrCreate()

    val df = spark.read.json("data/test/user.json")
    import spark.implicits._
    //Convert the DataFrame to a strongly typed Dataset
    val ds = df.as[Users]

    //Convert the aggregator into a column that can be used in a query
    val avgAge = CustomerAvg.toColumn.name("avgAge")
    //Apply the function
    ds.select(avgAge).show()

    spark.stop()
  }

}
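//For reference, data/test/user.json is expected in the JSON Lines format Spark
//reads by default, one object per line; a minimal sample (hypothetical data):
//{"name": "zhangsan", "age": 20}
//{"name": "lisi", "age": 30}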
//Note: if Users declares age as Int, df.as[Users] fails, because Spark infers
//JSON integers as bigint (Long) and refuses the narrowing cast:
//Exception in thread "main" org.apache.spark.sql.AnalysisException: Cannot up cast `age` from bigint to int as it may truncate
//The type path of the target object is:
//- field (class: "scala.Int", name: "age")
//- root class: "com.wxx.bigdata.sql03.Users"
case class Users(name : String, age : BigInt)
case class AvgBuffer(var sum : BigInt, var count : Int)

object CustomerAvg extends Aggregator[Users, AvgBuffer, Double]{
  //Initial (zero) value of the aggregation buffer
  override def zero: AvgBuffer = {
    AvgBuffer(0, 0)
  }
  //Combine an input row into the buffer (runs on each executor)
  override def reduce(b: AvgBuffer, a: Users): AvgBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }
  //Merge two partial buffers
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  //Compute the final result from the merged buffer
  override def finish(reduction: AvgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  //Encoders for the intermediate buffer and the output value
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
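If you want the average per group rather than over the whole Dataset, the same aggregator can be applied with groupByKey. A minimal sketch, keyed on name purely for illustration and assuming it runs inside main() above, where ds and spark.implicits._ are in scope:

//Group rows by name, then run the typed aggregator on each group;
//the result is a Dataset[(String, Double)] of (name, average age) pairs
ds.groupByKey(_.name)
  .agg(CustomerAvg.toColumn.name("avgAge"))
  .show()

Note that the typed DSL needs no function registration. In Spark 3.0+ an Aggregator can additionally be registered for SQL use via org.apache.spark.sql.functions.udaf, but its input type must then match the column type (e.g. Long) rather than the whole row.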

 

