In Spark SQL, a DataFrameWriter offers only four save modes when writing to a database: Append, Overwrite, ErrorIfExists and Ignore. None of them covers a common project requirement: batch-saving data so that existing rows are updated and new rows are inserted (an upsert). Below, the Spark save() source code is adapted to do exactly that.
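For reference, the built-in modes are chosen through DataFrameWriter.mode; a minimal sketch (URL and credentials are placeholders):

import java.util.Properties
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().master("local[2]").getOrCreate()
val df = spark.range(3).toDF("id")

val props = new Properties()
props.setProperty("user", "postgres")     // placeholder credentials
props.setProperty("password", "postgres")

// Pick one of the four built-in behaviors; none of them performs an upsert.
df.write
  .mode(SaveMode.Append) // or SaveMode.Overwrite / ErrorIfExists / Ignore
  .jdbc("jdbc:postgresql://localhost:5432/test", "test_001", props)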
import com.typesafe.config.ConfigFactory
import org.apache.spark.sql.SparkSession

/**
 * Test case: batch-save data, updating rows that already exist and inserting
 * those that do not. For the three-column test table the generated statement is:
 *   INSERT INTO test_001 VALUES ( ?, ?, ? )
 *   ON CONFLICT ( id ) DO
 *   UPDATE SET id = ?, name = ?, age = ?;
 * @author linzhy
 */
object InsertOrUpdateTest {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName)
      .master("local[2]")
      .config("spark.debug.maxToStringFields", "100")
      .getOrCreate()

    val config = ConfigFactory.load()
    val ods_url = config.getString("pg.oucloud_ods.url")
    val ods_user = config.getString("pg.oucloud_ods.user")
    val ods_password = config.getString("pg.oucloud_ods.password")

    // Load the source table over JDBC and register it as a temp view.
    val test_001 = spark.read.format("jdbc")
      .option("url", ods_url)
      .option("dbtable", "test_001")
      .option("user", ods_user)
      .option("password", ods_password)
      .load()

    test_001.createOrReplaceTempView("test_001")

    val sql =
      """
        |SELECT * FROM test_001
        |""".stripMargin

    val dataFrame = spark.sql(sql)

    // Batch-save: update rows that exist, insert rows that do not.
    PgSqlUtil.insertOrUpdateToPgsql(dataFrame, spark.sparkContext, "test_001_copy1", "id")

    spark.stop()
  }
}
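Two prerequisites are worth calling out. The pg.oucloud_ods.* settings are resolved by Typesafe Config, so they are expected in an application.conf on the classpath. And PostgreSQL only accepts ON CONFLICT (id) when the target table (test_001_copy1 here) has a primary key or unique constraint on id; otherwise the statement fails with "there is no unique or exclusion constraint matching the ON CONFLICT specification".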
Source of the insertOrUpdateToPgsql method (defined in PgSqlUtil, which extends Spark's Logging for logError):
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.jdbc.JdbcDialects

/**
 * Batch insert-or-update (upsert), adapted from the Spark write.save() source
 * (JdbcUtils.savePartition).
 * @param dataFrame data to write
 * @param sc        SparkContext; no longer used now that every partition
 *                  manages its own connection, kept so the call site above
 *                  stays unchanged
 * @param table     target table name
 * @param id        conflict column, which must carry a unique constraint
 */
def insertOrUpdateToPgsql(dataFrame: DataFrame, sc: SparkContext, table: String, id: String): Unit = {
  val tableSchema = dataFrame.schema
  val columns = tableSchema.fields.map(_.name).mkString(",")
  val placeholders = tableSchema.fields.map(_ => "?").mkString(",")
  val update = tableSchema.fields.map(x => x.name + "=?").mkString(",")
  // INSERT INTO t (c1,...) VALUES (?,...) ON CONFLICT (id) DO UPDATE SET c1=?,...
  val realsql = s"INSERT INTO $table ($columns) VALUES ($placeholders) ON CONFLICT ($id) DO UPDATE SET $update"

  // Every row binds two parameter groups: slots 1..n for VALUES,
  // slots n+1..2n for UPDATE SET.
  val numFields = tableSchema.fields.length * 2
  val updateIndex = numFields / 2
  val batchSize = 2000

  dataFrame.foreachPartition { (iterator: Iterator[Row]) =>
    // A JDBC PreparedStatement is not serializable and is bound to the
    // connection that created it, so it cannot be broadcast from the driver.
    // As in Spark's JdbcUtils.savePartition, each partition opens its own
    // connection and statement.
    val conn = connectionPool()
    conn.setAutoCommit(false)
    val dialect = JdbcDialects.get(conn.getMetaData.getURL)
    val nullTypes = tableSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
    val setters = tableSchema.fields.map(f => makeSetter(conn, f.dataType))
    val ps = conn.prepareStatement(realsql)
    var rowCount = 0
    try {
      while (iterator.hasNext) {
        val row = iterator.next()
        var i = 0
        while (i < numFields) {
          if (i < updateIndex) {
            // Slots 1..n: the VALUES clause.
            if (row.isNullAt(i)) ps.setNull(i + 1, nullTypes(i))
            else setters(i).apply(ps, row, i, 0)
          } else {
            // Slots n+1..2n: the UPDATE SET clause, re-binding the same row
            // values shifted back by updateIndex.
            if (row.isNullAt(i - updateIndex)) ps.setNull(i + 1, nullTypes(i - updateIndex))
            else setters(i - updateIndex).apply(ps, row, i, updateIndex)
          }
          i += 1
        }
        ps.addBatch()
        rowCount += 1
        if (rowCount % batchSize == 0) {
          // Flush every batchSize rows to keep executor memory bounded.
          ps.executeBatch()
          rowCount = 0
        }
      }
      if (rowCount > 0) {
        ps.executeBatch()
      }
      conn.commit()
    } catch {
      case e: Exception =>
        logError("Error in execution of insert. " + e.getMessage)
        conn.rollback()
    } finally {
      ps.close()
      conn.close()
    }
  }
}
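The method leans on three helpers that this section does not show: connectionPool(), getJdbcType and makeSetter. Their real definitions live elsewhere in the project; the sketch below is a hypothetical reconstruction, with getJdbcType and makeSetter adapted from Spark's internal JdbcUtils and the extra offset parameter matching the way the setters are invoked above.

import java.sql.{Connection, DriverManager, PreparedStatement}
import com.typesafe.config.ConfigFactory
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType}
import org.apache.spark.sql.types._

// Hypothetical: open a plain JDBC connection from the same config keys the
// test case uses; a real pool (e.g. HikariCP) could sit behind this.
def connectionPool(): Connection = {
  val config = ConfigFactory.load()
  DriverManager.getConnection(
    config.getString("pg.oucloud_ods.url"),
    config.getString("pg.oucloud_ods.user"),
    config.getString("pg.oucloud_ods.password"))
}

// Assumed to mirror Spark's JdbcUtils.getJdbcType, which is private there.
def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType =
  dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).getOrElse(
    throw new IllegalArgumentException(s"Can't get JDBC type for $dt"))

// One setter per column; `pos` is the 0-based statement slot, `offset` shifts
// back to the row index so the UPDATE SET group can reuse the same values.
type JDBCValueSetter = (PreparedStatement, Row, Int, Int) => Unit

def makeSetter(conn: Connection, dataType: DataType): JDBCValueSetter = dataType match {
  case IntegerType => (ps, row, pos, offset) => ps.setInt(pos + 1, row.getInt(pos - offset))
  case LongType    => (ps, row, pos, offset) => ps.setLong(pos + 1, row.getLong(pos - offset))
  case DoubleType  => (ps, row, pos, offset) => ps.setDouble(pos + 1, row.getDouble(pos - offset))
  case BooleanType => (ps, row, pos, offset) => ps.setBoolean(pos + 1, row.getBoolean(pos - offset))
  case StringType  => (ps, row, pos, offset) => ps.setString(pos + 1, row.getString(pos - offset))
  case TimestampType => (ps, row, pos, offset) =>
    ps.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos - offset))
  case _ => (ps, row, pos, offset) => ps.setObject(pos + 1, row.get(pos - offset))
}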