package mobvista.dmp.utils.clickhouse

import mobvista.dmp.utils.clickhouse.ClickHouseResultSetExt._
import mobvista.dmp.utils.clickhouse.Utils._
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.types._
import ru.yandex.clickhouse.ClickHouseDataSource

import scala.collection.mutable

/**
  * @package: mobvista.dmp.utils.clickhouse
  * @author: wangjf
  * @date: 2019-07-16
  * @time: 18:28
  * @email: jinfeng.wang@mobvista.com
  * @phone: 152-1062-7698
  */
case class DataFrameExt(df: org.apache.spark.sql.DataFrame) extends Serializable {

  /**
    * @desc execute sql
    * @param sql
    * @param clusterNameO
    * @param ds
    */
  def executeSQL(sql: String, clusterNameO: Option[String] = None)(implicit ds: ClickHouseDataSource): Unit = {
    val client = ClickHouseClient(clusterNameO)(ds)
    clusterNameO match {
      case None => client.query(sql)
      case Some(_) => client.queryCluster(sql)
    }
  }

  def dropClickHouseDb(dbName: String, clusterNameO: Option[String] = None)
                      (implicit ds: ClickHouseDataSource): Unit = {
    val client = ClickHouseClient(clusterNameO)(ds)
    clusterNameO match {
      case None => client.dropDb(dbName)
      case Some(_) => client.dropDbCluster(dbName)
    }
  }

  def dropClickHouseTable(dbName: String, tbName: String, clusterNameO: Option[String] = None)
                         (implicit ds: ClickHouseDataSource): Unit = {
    val client = ClickHouseClient(clusterNameO)(ds)
    clusterNameO match {
      case None => client.dropTable(dbName, tbName)
      case Some(_) => client.dropTableCluster(dbName, tbName)
    }
  }

  def createClickHouseDb(dbName: String, clusterNameO: Option[String] = None)
                        (implicit ds: ClickHouseDataSource): Unit = {
    val client = ClickHouseClient(clusterNameO)(ds)
    clusterNameO match {
      case None => client.createDb(dbName)
      case Some(_) => client.createDbCluster(dbName)
    }
  }

  def createClickHouseTable(dbName: String, tableName: String,
                            partitionColumnNames: Seq[String],
                            indexColumns: Seq[String],
                            orderColumnNames: Seq[String],
                            clusterNameO: Option[String] = None)
                           (implicit ds: ClickHouseDataSource): Unit = {
    val client = ClickHouseClient(clusterNameO)(ds)
    val sqlStmt = createClickHouseTableDefinitionSQL(dbName, tableName, partitionColumnNames, indexColumns, orderColumnNames)
    clusterNameO match {
      case None => client.query(sqlStmt)
      case Some(clusterName) =>
        // create local table on every node
        client.queryCluster(sqlStmt)
        // create distributed table (view) on every node
        val sqlStmt2 =
          s"CREATE TABLE IF NOT EXISTS ${dbName}.${tableName}_all AS ${dbName}.${tableName} ENGINE = Distributed($clusterName, $dbName, $tableName, rand());"
        client.queryCluster(sqlStmt2)
    }
  }

  /**
    * @desc alter table tableName detach partition partition_expr
    * @param dbName
    * @param tableName
    * @param userDefinedPartitionExpr
    * @param clusterNameO
    * @param ds
    */
  def detachPartition(dbName: String, tableName: String, userDefinedPartitionExpr: String,
                      clusterNameO: Option[String] = None)(implicit ds: ClickHouseDataSource): Unit = {
    val client = ClickHouseClient(clusterNameO)(ds)

    // When userDefinedPartitionExpr is blank, detach the partition returned by
    // getLastPartitionSQL; otherwise detach userDefinedPartitionExpr.
    val partitionExpr = if (StringUtils.isBlank(userDefinedPartitionExpr)) {
      val lastPartStmt = getLastPartitionSQL(dbName, tableName)
      val rs = client.query(lastPartStmt)
      val r = rs.map(x => x.getString("partition"))
      r.last
    } else {
      userDefinedPartitionExpr
    }

    val detachStmt =
      s"""
         |ALTER TABLE ${dbName}.${tableName} DETACH PARTITION ${partitionExpr}
      """.stripMargin
    clusterNameO match {
      case None => client.query(detachStmt)
      case Some(_) =>
        // detach partition on every node
        client.queryCluster(detachStmt)
    }
  }

  /**
    * @desc alter table tableName drop partition partition_expr
    * @param dbName
    * @param tableName
    * @param userDefinedPartitionExpr
    * @param clusterNameO
    * @param ds
    */
  def dropPartition(dbName: String, tableName: String, userDefinedPartitionExpr: String,
                    clusterNameO: Option[String] = None)(implicit ds: ClickHouseDataSource): Unit = {
    val client = ClickHouseClient(clusterNameO)(ds)

    // When userDefinedPartitionExpr is blank, drop the partition returned by
    // getLastPartitionSQL; otherwise drop userDefinedPartitionExpr.
    val partitionExpr = if (StringUtils.isBlank(userDefinedPartitionExpr)) {
      val lastPartStmt = getLastPartitionSQL(dbName, tableName)
      val rs = client.query(lastPartStmt)
      val r = rs.map(x => x.getString("partition"))
      r.last
    } else {
      userDefinedPartitionExpr
    }

    val dropStmt =
      s"""
         |ALTER TABLE ${dbName}.${tableName} DROP PARTITION ${partitionExpr}
      """.stripMargin
    clusterNameO match {
      case None => client.query(dropStmt)
      case Some(_) =>
        // drop partition on every node
        client.queryCluster(dropStmt)
    }
  }

  // Note: LIMIT 1,1 skips the newest partition and returns the one after it.
  def getLastPartitionSQL(dbName: String, tableName: String): String = {
    s"""
       |SELECT partition FROM system.parts WHERE database = '${dbName}' AND table = '${tableName}' GROUP BY partition ORDER BY partition DESC LIMIT 1,1
    """.stripMargin
  }

  def saveToClickHouse(dbName: String, tableName: String, partitionVals: Seq[String],
                       partitionColumnNames: Seq[String],
                       clusterNameO: Option[String] = None, batchSize: Int = 100000)
                      (implicit ds: ClickHouseDataSource) = {
    val defaultHost = ds.getHost
    val defaultPort = ds.getPort

    val (clusterTableName, clickHouseHosts) = clusterNameO match {
      case Some(_) =>
        // get nodes from cluster
        val client = ClickHouseClient(clusterNameO)(ds)
        (s"${tableName}_all", client.getClusterNodes())
      case None =>
        (tableName, Seq(defaultHost))
    }

    val schema = df.schema
    // following code is going to be run on executors
    val insertResults = df.rdd.mapPartitions((partition: Iterator[org.apache.spark.sql.Row]) => {
      // pick a random target host for this partition
      val rnd = scala.util.Random.nextInt(clickHouseHosts.length)
      val targetHost = clickHouseHosts(rnd)
      val targetHostDs = ClickHouseConnectionFactory.get(targetHost, defaultPort)
      targetHostDs.getProperties.setSocketTimeout(120000)
      targetHostDs.getProperties.setConnectionTimeout(120000)

      // explicit closing
      using(targetHostDs.getConnection) { conn =>
        val descSql = s"desc $dbName.$tableName"
        val descStatement = conn.prepareStatement(descSql)
        val results = descStatement.executeQuery()
        val columns = results.map(x => x.getString("name")).reverse
        val insertStatementSql = generateInsertStatement(schema, dbName, clusterTableName, columns)
        val statement = conn.prepareStatement(insertStatementSql)

        var totalInsert = 0
        var counter = 0
        while (partition.hasNext) {
          counter += 1
          val row = partition.next()
          val offSet = partitionColumnNames.size + 1

          // bind partition values: the first one is the Date column, the rest are strings
          for (i <- partitionVals.indices) {
            if (i == 0)
              statement.setDate(i + 1, java.sql.Date.valueOf(partitionVals(i)))
            else
              statement.setString(i + 1, partitionVals(i))
          }

          // map fields
          schema.foreach { f =>
            val fieldName = f.name
            val fieldIdx = row.fieldIndex(fieldName)
            val fieldVal = row.get(fieldIdx)
            if (fieldVal != null) {
              val obj =
                if (fieldVal.isInstanceOf[mutable.WrappedArray[_]])
                  fieldVal.asInstanceOf[mutable.WrappedArray[_]].array
                else
                  fieldVal
              statement.setObject(fieldIdx + offSet, obj)
            } else {
              val defVal = defaultNullValue(f.dataType, fieldVal)
              statement.setObject(fieldIdx + offSet, defVal)
            }
          }
          statement.addBatch()

          if (counter >= batchSize) {
            val r = statement.executeBatch()
            totalInsert += r.sum
            counter = 0
          }
        } // end: while

        if (counter > 0) {
          val r = statement.executeBatch()
          totalInsert += r.sum
          counter = 0
        }

        // return: Seq((host, insertCount))
        List((targetHost, totalInsert)).toIterator
      }
    }).collect()

    // aggregate insert results by host
    insertResults.groupBy(_._1)
      .map(x => (x._1, x._2.map(_._2).sum))
  }

  private def generateInsertStatement(schema: org.apache.spark.sql.types.StructType,
                                      dbName: String,
                                      tableName: String,
                                      columns: Seq[String]) = {
    val vals = 1 to columns.length map (i => "?")
    s"INSERT INTO $dbName.$tableName (${columns.mkString(",")}) VALUES (${vals.mkString(",")})"
  }

  private def defaultNullValue(sparkType: org.apache.spark.sql.types.DataType, v: Any) = sparkType match {
    case DoubleType => 0
    case LongType => 0
    case FloatType => 0
    case IntegerType => 0
    case StringType => null
    case BooleanType => false
    case _ => null
  }

  private def createClickHouseTableDefinitionSQL(dbName: String, tableName: String,
                                                 partitionColumnNames: Seq[String],
                                                 indexColumns: Seq[String],
                                                 orderColumnNames: Seq[String]) = {
    val header =
      s"""
        CREATE TABLE IF NOT EXISTS $dbName.$tableName(
      """

    var columns: List[String] = List()
    var dt = ""
    var parts: List[String] = List()
    for (i <- partitionColumnNames.indices) {
      if (i == 0) {
        // the first partition column is the Date column used by toYYYYMMDD
        dt = partitionColumnNames(i)
        columns = columns.::(s"${partitionColumnNames(i)} Date")
      } else {
        parts = parts.::(partitionColumnNames(i))
        val timeSet = Set("hour", "minute", "second", "hh", "mm", "ss")
        val partType = if (timeSet.contains(partitionColumnNames(i).toLowerCase)) {
          "FixedString(2)"
        } else {
          "String"
        }
        columns = columns.::(s"${partitionColumnNames(i)} ${partType}")
      }
    }
    columns = columns.reverse
    columns = df.schema.map { f =>
      Seq(f.name, sparkType2ClickHouseType(f.dataType)).mkString(" ")
    }.toList.:::(columns)

    val columnsStr = columns.mkString(",\n")

    val partitioner = if (parts.nonEmpty) {
      if (orderColumnNames.isEmpty) {
        s"""
           |) ENGINE = ReplacingMergeTree() PARTITION BY (toYYYYMMDD(${dt}),${parts.mkString(",")}) ORDER BY (${dt},${parts.mkString(",")},${indexColumns.mkString(",")})
        """.stripMargin
      } else {
        s"""
           |) ENGINE = ReplacingMergeTree() PARTITION BY (toYYYYMMDD(${dt}),${parts.mkString(",")}) ORDER BY (${dt},${parts.mkString(",")},${orderColumnNames.mkString(",")})
        """.stripMargin
      }
    } else {
      if (orderColumnNames.isEmpty) {
        s"""
           |) ENGINE = ReplacingMergeTree() PARTITION BY toYYYYMMDD(${dt}) ORDER BY (${dt},${indexColumns.mkString(",")})
        """.stripMargin
      } else {
        s"""
           |) ENGINE = ReplacingMergeTree() PARTITION BY toYYYYMMDD(${dt}) ORDER BY (${dt},${orderColumnNames.mkString(",")})
        """.stripMargin
      }
    }

    Seq(header, columnsStr, partitioner).mkString("\n")
  }

  private def sparkType2ClickHouseType(sparkType: org.apache.spark.sql.types.DataType) = sparkType match {
    case LongType => "Int64"
    case DateType => "Date"
    case DoubleType => "Float64"
    case FloatType => "Float32"
    case IntegerType => "Int32"
    case StringType => "String"
    case BooleanType => "UInt8"
    case ArrayType(IntegerType, _) => "Array(Int32)"
    case ArrayType(StringType, _) => "Array(String)"
    case _ =>
      println("unknownType ==>> " + sparkType)
      "unknown"
  }
}
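
// ---------------------------------------------------------------------------
// Illustrative usage sketch only (not part of the original utility): it shows
// how DataFrameExt could be wired together to create a cluster table and then
// bulk-load a DataFrame. The JDBC URL, database/table names, cluster name
// "bip_ck_cluster", input path, and column names below are all assumptions.
// ---------------------------------------------------------------------------
object DataFrameExtUsageSketch {

  import org.apache.spark.sql.SparkSession

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("clickhouse-export").getOrCreate()

    // hypothetical input: a DataFrame whose schema maps onto
    // sparkType2ClickHouseType (String / Int / Array(String) columns, etc.)
    val df = spark.read.parquet("hdfs:///tmp/device_profile")

    // default node used for DDL; writes fan out to cluster nodes at runtime
    implicit val ds: ClickHouseDataSource =
      new ClickHouseDataSource("jdbc:clickhouse://127.0.0.1:8123")

    val cluster = Some("bip_ck_cluster") // assumed cluster name
    val dfExt = DataFrameExt(df)

    dfExt.createClickHouseDb("dwh", cluster)
    dfExt.createClickHouseTable(
      dbName = "dwh",
      tableName = "device_profile",
      partitionColumnNames = Seq("dt"), // first partition column becomes the Date column
      indexColumns = Seq("device_id"),
      orderColumnNames = Seq("device_id"),
      clusterNameO = cluster)

    // returns a Map(host -> inserted row count)
    val counts = dfExt.saveToClickHouse(
      dbName = "dwh",
      tableName = "device_profile",
      partitionVals = Seq("2019-07-16"), // value bound to the Date partition column
      partitionColumnNames = Seq("dt"),
      clusterNameO = cluster)
    counts.foreach { case (host, n) => println(s"$host -> $n rows") }

    spark.stop()
  }
}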