Spark 写入 PostgreSQL：Spark SQL 批量插入或更新（Upsert）数据到 PgSQL

在 Spark SQL 中，将数据保存到数据库时只有 Append、Overwrite、ErrorIfExists、Ignore 四种保存模式（SaveMode），无法满足项目需求。现参照 Spark save 的源码做进一步改造，实现批量保存数据：记录存在则更新，不存在则插入。

/**
 * Manual test for `PgSqlUtil.insertOrUpdateToPgsql`: batch-saves rows, updating those whose
 * key already exists and inserting the rest.
 *
 * Reads every row of the PostgreSQL table `test_001` over JDBC (connection settings come from
 * the `pg.oucloud_ods.url` / `.user` / `.password` config keys) and upserts them into
 * `test_001_copy1`, using `id` as the `ON CONFLICT` key. For a source table with columns
 * `id, name, age` the statement sent to the database has the form:
 * {{{
 * INSERT INTO test_001_copy1 (id,name,age) VALUES (?,?,?)
 * on conflict(id) do update set id=?,name=?,age=?
 * }}}
 *
 * @author linzhy
 */
object InsertOrUpdateTest {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      // getSimpleName on a Scala object ends with '$'; strip it for a readable app name.
      .appName(this.getClass.getSimpleName.stripSuffix("$"))
      .master("local[2]")
      .config("spark.debug.maxToStringFields", "100")
      .getOrCreate()

    // Always release the session, even if reading or writing fails.
    try {
      val config = ConfigFactory.load()
      val odsUrl = config.getString("pg.oucloud_ods.url")
      val odsUser = config.getString("pg.oucloud_ods.user")
      val odsPassword = config.getString("pg.oucloud_ods.password")

      val test001 = spark.read.format("jdbc")
        .option("url", odsUrl)
        .option("dbtable", "test_001")
        .option("user", odsUser)
        .option("password", odsPassword)
        .load()

      test001.createOrReplaceTempView("test_001")

      val sql =
        """
          |SELECT * FROM test_001
          |""".stripMargin

      val dataFrame = spark.sql(sql)

      // Batch upsert: update rows whose id already exists, insert the rest.
      PgSqlUtil.insertOrUpdateToPgsql(dataFrame, spark.sparkContext, "test_001_copy1", "id")
    } finally {
      spark.stop()
    }
  }

}

insertOrUpdateToPgsql 方法源码

/**
 * Batch upsert of a DataFrame into a PostgreSQL table. Rows whose conflict key already exists
 * are updated; all other rows are inserted. Modelled on the JDBC path of Spark's
 * `DataFrameWriter.save()`, but emits:
 * {{{
 * INSERT INTO table (c1,c2,...) VALUES (?,?,...) on conflict(id) do update set c1=?,c2=?,...
 * }}}
 * Every column is therefore bound twice: once in VALUES and once in the SET clause.
 *
 * JDBC objects are never shipped from the driver, because `Connection` and
 * `PreparedStatement` are not serializable and cannot be shared across executors. Instead,
 * each partition does the following:
 *  - opens its own connection via `connectionPool()` (assumes that method is callable on
 *    executors — TODO confirm);
 *  - prepares its own statement;
 *  - flushes every `batchSize` rows;
 *  - commits once at the end.
 *
 * On failure a partition rolls back and rethrows, so Spark fails (and may retry) the task.
 *
 * NOTE: the write is transactional per partition, not across the whole DataFrame. Retrying a
 * partition is safe because the statement is an upsert.
 *
 * @param dataFrame rows to write; its column names are used verbatim as the target columns
 * @param sc        kept for source compatibility; no longer used
 * @param table     target table name, interpolated into the SQL verbatim (must be trusted)
 * @param id        conflict-target column(s) for `on conflict(...)`, interpolated verbatim
 * @param batchSize number of rows per `executeBatch()` call; must be positive
 * @throws IllegalArgumentException if `batchSize` is not positive
 * @throws org.apache.spark.SparkException (typically) when any partition fails to write
 */
def insertOrUpdateToPgsql(dataFrame: DataFrame, sc: SparkContext, table: String, id: String,
                          batchSize: Int = 2000): Unit = {
  require(batchSize > 0, s"batchSize must be positive, got $batchSize")

  val tableSchema = dataFrame.schema
  val columnNames = tableSchema.fields.map(_.name)
  val fieldCount = columnNames.length

  val columns = columnNames.mkString(",")
  val placeholders = columnNames.map(_ => "?").mkString(",")
  val update = columnNames.map(name => s"$name=?").mkString(",")
  val realsql =
    s"INSERT INTO $table ($columns) VALUES ($placeholders) on conflict($id) do update set $update"

  // The explicit parameter type disambiguates the Scala/Java foreachPartition overloads.
  dataFrame.foreachPartition { (iterator: Iterator[org.apache.spark.sql.Row]) =>
    val conn = connectionPool()
    try {
      conn.setAutoCommit(false)
      // Resolved per partition: the dialect comes from this connection's URL, and the
      // setters are built with (and may capture) this connection.
      val dialect = JdbcDialects.get(conn.getMetaData.getURL)
      val nullTypes = tableSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
      val setters = tableSchema.fields.map(f => makeSetter(conn, f.dataType))
      val ps = conn.prepareStatement(realsql)
      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < fieldCount) {
            // Column i fills JDBC parameter i+1 (VALUES) and fieldCount+i+1 (SET clause);
            // the setter receives the same (pos, offset) arguments as before.
            val updatePos = i + fieldCount
            if (row.isNullAt(i)) {
              ps.setNull(i + 1, nullTypes(i))
              ps.setNull(updatePos + 1, nullTypes(i))
            } else {
              setters(i).apply(ps, row, i, 0)
              setters(i).apply(ps, row, updatePos, fieldCount)
            }
            i += 1
          }
          ps.addBatch()
          rowCount += 1
          if (rowCount == batchSize) {
            ps.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          ps.executeBatch()
        }
      } finally {
        ps.close()
      }
      conn.commit()
    } catch {
      case scala.util.control.NonFatal(e) =>
        logError("Error in execution of insert. " + e.getMessage)
        // Do not let a failing rollback hide the original error.
        try conn.rollback()
        catch { case scala.util.control.NonFatal(r) => e.addSuppressed(r) }
        // insertError(connectionPool("OuCloud_ODS"),"insertOrUpdateToPgsql",e.getMessage)
        throw e
    } finally {
      conn.close()
    }
  }
}


版权声明:本文为weixin_42186387原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。