Integrating Spark with HBase (a custom HBase DataSource)
Background
Spark supports many data sources, but it has no particularly elegant API for reading from or writing to HBase, even though Spark and HBase integration comes up frequently. So I built a more convenient set of HBase APIs on top of Spark's DataSource API.
Writing to HBase
When writing to HBase, values are written with the data types taken from the DataFrame's schema. Usage examples first:
import spark.implicits._
import org.apache.hack.spark._
val df = spark.createDataset(Seq(("ufo", "play"), ("yy", ""))).toDF("name", "like")
// Approach 1
val options = Map(
  "hbase.table.rowkey.field" -> "name",
  "hbase.table.numReg" -> "12",
  "hbase.table.rowkey.prefix" -> "00",
  "bulkload.enable" -> "false"
)
df.saveToHbase("hbase_table", Some("XXX:2181"), options)
// Approach 2
df.write.format("org.apache.spark.sql.execution.datasources.hbase")
  .options(Map(
    "hbase.table.rowkey.field" -> "name",
    "hbase.table.name" -> "hbase_table",
    "hbase.zookeeper.quorum" -> "XXX:2181",
    "hbase.table.rowkey.prefix" -> "00",
    "hbase.table.numReg" -> "12",
    "bulkload.enable" -> "false"
  )).save()
The two approaches above achieve the same result. The parameters mean the following:
- hbase.zookeeper.quorum: the ZooKeeper quorum address
- hbase.table.rowkey.field: which field of the Spark temp table is used as the HBase rowkey; defaults to the first field
- bulkload.enable: whether to use bulk loading; off by default. It must be enabled when the HBase table being written has nothing but the rowkey column (see the example after this list)
- hbase.table.name: the HBase table name
- hbase.table.family: the column family name, default info
- hbase.table.startKey: the start key for pre-splitting. When the HBase table does not exist it is created automatically; if the three pre-split parameters (startKey, endKey, numReg) are not given, the table gets a single region
- hbase.table.endKey: the end key for pre-splitting
- hbase.table.numReg: the number of regions
- hbase.table.rowkey.prefix: when the rowkey starts with a digit, pre-splitting needs the prefix format to be specified, e.g. 00
- hbase.check_table: whether to check that the table exists before writing to it, default false
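For comparison with Approach 1 above, the same write with bulk loading turned on looks like this; as the core code later shows, this path writes HFiles to a temporary directory and bulk-loads them instead of sending Puts:

// Same data and options as Approach 1, but with bulk loading enabled:
// rows become HFiles that are handed to HBase, rather than individual Puts.
df.saveToHbase("hbase_table", Some("XXX:2181"), Map(
  "hbase.table.rowkey.field" -> "name",
  "hbase.table.numReg" -> "12",
  "hbase.table.rowkey.prefix" -> "00",
  "bulkload.enable" -> "true"
))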
Reading from HBase
Sample code:
// Approach 1
import org.apache.hack.spark._
val options = Map(
  "spark.table.schema" -> "appid:String,appstoreid:int,firm:String",
  "hbase.table.schema" -> ":rowkey,info:appStoreId,info:firm"
)
spark.hbaseTableAsDataFrame("hbase_table", Some("XXX:2181"), options).show(false)
// Approach 2
spark.read.format("org.apache.spark.sql.execution.datasources.hbase")
  .options(Map(
    "spark.table.schema" -> "appid:String,appstoreid:int,firm:String",
    "hbase.table.schema" -> ":rowkey,info:appStoreId,info:firm",
    "hbase.zookeeper.quorum" -> "XXX:2181",
    "hbase.table.name" -> "hbase_table"
  )).load.show(false)
Specifying the schema mapping between the Spark table and the HBase table is not mandatory. By default two fields are generated, rowkey and content, where content is a JSON string built from all the columns; the data type of an individual column can be set via field.type.fieldname (the field name being family:qualifier, as in the core read code below), and everything defaults to StringType. Data mapped this way still needs an extra conversion step in your Spark program before it looks the way you want, and every column gets scanned, so it is relatively inefficient.
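As a small illustration of the default mapping, the read below requests no schema mapping and only overrides the decoding of the info:appStoreId column from the earlier example (the option key format field.type.family:qualifier comes from the core read code shown later):

import org.apache.hack.spark._

// No spark.table.schema / hbase.table.schema: each row comes back as
// (rowkey, content), where content is a JSON string of all cells.
// field.type.info:appStoreId decodes that one column as an Int; all
// other columns stay StringType.
val raw = spark.hbaseTableAsDataFrame(
  "hbase_table",
  Some("XXX:2181"),
  Map("field.type.info:appStoreId" -> "IntegerType")
)
raw.show(false)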
So instead we can define a custom schema mapping to fetch the data:
- hbase.zookeeper.quorum: the ZooKeeper quorum address
- spark.table.schema: the schema of the corresponding Spark temp table, e.g. "ID:String,appname:String,age:Int"
- hbase.table.schema: the corresponding HBase schema, e.g. ":rowkey,info:appname,info:age"
- hbase.table.name: the HBase table name
- spark.rowkey.view.name: the name of a temp view created from a DataFrame of rowkeys; when set, only the data for those rowkeys is fetched (see the sketch after this list)
Note that the two schemas correspond one to one, and HBase will only scan the columns listed in hbase.table.schema.
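The post does not show how spark.rowkey.view.name is consumed internally, so the following is only a sketch of the documented usage; in particular, the column name "rowkey" for the temp view is an assumption, not something stated above:

import spark.implicits._
import org.apache.hack.spark._

// Assumed usage: register a DataFrame of rowkeys as a temp view and pass its
// name, so that only those rowkeys are fetched. The expected column name is
// not documented; "rowkey" is a guess.
Seq("00123", "00456").toDF("rowkey").createOrReplaceTempView("rowkey_view")

spark.hbaseTableAsDataFrame("hbase_table", Some("XXX:2181"), Map(
  "spark.table.schema" -> "appid:String,appstoreid:int,firm:String",
  "hbase.table.schema" -> ":rowkey,info:appStoreId,info:firm",
  "spark.rowkey.view.name" -> "rowkey_view"
)).show(false)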
Core code
Writing to HBase
class DataFrameFunctions(data: DataFrame) extends Logging with Serializable {

  def saveToHbase(tableName: String, zkUrl: Option[String] = None,
                  options: Map[String, String] = new HashMap[String, String]): Unit = {
    // Wrap the HBase configuration so it can be shipped to executors.
    val wrappedConf = {
      implicit val formats = DefaultFormats
      val hc = HBaseConfiguration.create()
      hc.set("hbase.zookeeper.quorum", zkUrl.getOrElse("127.0.0.1:2181"))
      new SerializableConfiguration(hc)
    }
    val hbaseConf = wrappedConf.value

    val rowkey = options.getOrElse("rowkey.field", data.schema.head.name)
    val family = options.getOrElse("family", "info")
    val numReg = options.getOrElse("numReg", -1).toString.toInt
    val startKey = options.getOrElse("startKey", null)
    val endKey = options.getOrElse("endKey", null)
    val rdd = data.rdd
    val f = family

    // Create the table (optionally pre-split) if it does not exist yet.
    val tName = TableName.valueOf(tableName)
    val connection = ConnectionFactory.createConnection(hbaseConf)
    val admin = connection.getAdmin
    if (!admin.isTableAvailable(tName)) {
      HBaseUtils.createTable(connection, tName, family, startKey, endKey, numReg)
    }
    connection.close()

    if (hbaseConf.get("mapreduce.output.fileoutputformat.outputdir") == null) {
      hbaseConf.set("mapreduce.output.fileoutputformat.outputdir", "/tmp")
    }
    val jobConf = new JobConf(hbaseConf, this.getClass)
    jobConf.set(TableOutputFormat.OUTPUT_TABLE, tableName)
    val job = Job.getInstance(jobConf)
    job.setOutputKeyClass(classOf[ImmutableBytesWritable])
    job.setOutputValueClass(classOf[Result])
    job.setOutputFormatClass(classOf[TableOutputFormat[ImmutableBytesWritable]])

    // Split the schema into the rowkey column and the value columns.
    val fields = data.schema.toArray
    val rowkeyIndex = fields.zipWithIndex.filter(f => f._1.name == rowkey).head._2
    val otherFields = fields.zipWithIndex.filter(f => f._1.name != rowkey)
    lazy val setters = otherFields.map(r => HBaseUtils.makeHbaseSetter(r))
    lazy val setters_bulkload = otherFields.map(r => HBaseUtils.makeHbaseSetter_bulkload(r))

    options.getOrElse("bulkload.enable", "true") match {
      case "true" =>
        // Bulk load: write HFiles to a temp dir, then hand them to HBase.
        val tmpPath = s"/tmp/bulkload/${tableName}" + System.currentTimeMillis()
        def convertToPut_bulkload(row: Row) = {
          val rk = Bytes.toBytes(row.getString(rowkeyIndex))
          setters_bulkload.map(_.apply(rk, row, f))
        }
        rdd.flatMap(convertToPut_bulkload)
          .saveAsNewAPIHadoopFile(tmpPath, classOf[ImmutableBytesWritable], classOf[KeyValue],
            classOf[HFileOutputFormat2], job.getConfiguration)
        val bulkLoader: LoadIncrementalHFiles = new LoadIncrementalHFiles(hbaseConf)
        bulkLoader.doBulkLoad(new Path(tmpPath), new HTable(hbaseConf, tableName))
      case "false" =>
        // Plain Puts through TableOutputFormat.
        def convertToPut(row: Row) = {
          val put = new Put(Bytes.toBytes(row.getString(rowkeyIndex)))
          setters.foreach(_.apply(put, row, f))
          (new ImmutableBytesWritable, put)
        }
        rdd.map(convertToPut).saveAsNewAPIHadoopDataset(job.getConfiguration)
    }
  }
}
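HBaseUtils.createTable is referenced above but not shown in the post. Purely as an illustration of the pre-split parameters, a minimal sketch using the standard HBase 1.x Admin API could look like this (the real implementation in the linked repo may differ):

import org.apache.hadoop.hbase.{HColumnDescriptor, HTableDescriptor, TableName}
import org.apache.hadoop.hbase.client.Connection
import org.apache.hadoop.hbase.util.Bytes

// Sketch only: create the table, pre-split between startKey and endKey into
// numReg regions when all three are given, otherwise as a single region.
object HBaseUtilsSketch {
  def createTable(connection: Connection, tName: TableName, family: String,
                  startKey: String, endKey: String, numReg: Int): Unit = {
    val admin = connection.getAdmin
    try {
      val desc = new HTableDescriptor(tName)
      desc.addFamily(new HColumnDescriptor(family))
      if (startKey != null && endKey != null && numReg > 0) {
        admin.createTable(desc, Bytes.toBytes(startKey), Bytes.toBytes(endKey), numReg)
      } else {
        admin.createTable(desc)
      }
    } finally {
      admin.close()
    }
  }
}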
Reading from HBase
class SparkSqlContextFunctions(@transient val spark: SparkSession) extends Serializable {

  private val SPARK_TABLE_SCHEMA: String = "spark.table.schema"
  private val HBASE_TABLE_SCHEMA: String = "hbase.table.schema"

  def hbaseTableAsDataFrame(table: String, zkUrl: Option[String] = None,
                            options: Map[String, String] = new HashMap[String, String]): DataFrame = {
    val wrappedConf = {
      val hc = HBaseConfiguration.create()
      hc.set("hbase.zookeeper.quorum", zkUrl.getOrElse("127.0.0.1:2181"))
      hc.set(TableInputFormat.INPUT_TABLE, table)
      // Only scan the columns listed in hbase.table.schema (the ":rowkey" entry needs no scan column).
      if (options.contains(HBASE_TABLE_SCHEMA)) {
        var str = ArrayBuffer[String]()
        options(HBASE_TABLE_SCHEMA)
          .split(",", -1).map(field =>
            if (!field.startsWith(":")) {
              str += field
            }
          )
        if (str.length > 1) hc.set(TableInputFormat.SCAN_COLUMNS, str.mkString(" "))
      }
      // Copy the remaining relevant options into the Hadoop configuration.
      Array(SPARK_TABLE_SCHEMA, HBASE_TABLE_SCHEMA, TableInputFormat.SCAN_ROW_START, TableInputFormat.SCAN_ROW_STOP)
        .foldLeft((hc, options)) {
          case ((_hc, _options), pram) =>
            if (_options.contains(pram)) _hc.set(pram, _options(pram))
            (_hc, _options)
        }
      new SerializableConfiguration(hc)
    }
    def hbaseConf = wrappedConf.value

    // Either the user-specified schema or the default (rowkey, content) pair.
    def schema: StructType = {
      import org.apache.spark.sql.types._
      Option(hbaseConf.get(SPARK_TABLE_SCHEMA)) match {
        case Some(schema) => HBaseUtils.registerSparkTableSchema(schema)
        case None =>
          StructType(
            Array(
              StructField("rowkey", StringType, nullable = false),
              StructField("content", StringType)
            )
          )
      }
    }

    Option(hbaseConf.get(SPARK_TABLE_SCHEMA)) match {
      case Some(s) =>
        // Custom mapping: zip the Spark schema with the HBase schema and build one getter per column.
        require(hbaseConf.get(HBASE_TABLE_SCHEMA).nonEmpty,
          "Because the parameter spark.table.schema has been set, hbase.table.schema also needs to be set.")
        val sparkTableSchemas = schema.fields.map(f => SparkTableSchema(f.name, f.dataType))
        val hBaseTableSchemas = HBaseUtils.registerHbaseTableSchema(hbaseConf.get(HBASE_TABLE_SCHEMA))
        require(sparkTableSchemas.length == hBaseTableSchemas.length,
          "The length of the parameter spark.table.schema must be the same as the parameter hbase.table.schema.")
        val schemas = sparkTableSchemas.zip(hBaseTableSchemas)
        val setters = schemas.map(schema => HBaseUtils.makeHbaseGetter(schema))
        val hBaseRDD = spark.sparkContext
          .newAPIHadoopRDD(hbaseConf, classOf[TableInputFormat], classOf[ImmutableBytesWritable], classOf[Result])
          .map { case (_, result) => Row.fromSeq(setters.map(r => r.apply(result)).toSeq) }
        spark.createDataFrame(hBaseRDD, schema)
      case None =>
        // Default mapping: rowkey plus a JSON string holding every cell of the row.
        val hBaseRDD = spark.sparkContext
          .newAPIHadoopRDD(hbaseConf, classOf[TableInputFormat], classOf[ImmutableBytesWritable], classOf[Result])
          .map { line =>
            val rowKey = Bytes.toString(line._2.getRow)
            implicit val formats = Serialization.formats(NoTypeHints)
            val content = line._2.getMap.navigableKeySet().flatMap { f =>
              line._2.getFamilyMap(f).map { c =>
                val columnName = Bytes.toString(f) + ":" + Bytes.toString(c._1)
                // field.type.<family>:<qualifier> controls how the cell bytes are decoded.
                options.get("field.type." + columnName) match {
                  case Some(i) =>
                    val value = i match {
                      case "LongType" => Bytes.toLong(c._2)
                      case "FloatType" => Bytes.toFloat(c._2)
                      case "DoubleType" => Bytes.toDouble(c._2)
                      case "IntegerType" => Bytes.toInt(c._2)
                      case "BooleanType" => Bytes.toBoolean(c._2)
                      case "BinaryType" => c._2
                      case "TimestampType" => new Timestamp(Bytes.toLong(c._2))
                      case "DateType" => new java.sql.Date(Bytes.toLong(c._2))
                      case _ => Bytes.toString(c._2)
                    }
                    (columnName, value)
                  case None => (columnName, Bytes.toString(c._2))
                }
              }
            }.toMap
            val contentStr = Serialization.write(content)
            Row.fromSeq(Seq(rowKey, contentStr))
          }
        spark.createDataFrame(hBaseRDD, schema)
    }
  }
}
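HBaseUtils.makeHbaseGetter (and the setter counterparts used in the write path) are also not shown in the post; SparkTableSchema and HBaseTableSchema are the author's own case classes. Just to illustrate the idea, a getter for a single mapped column could look roughly like this, with plain parameters standing in for those case classes:

import org.apache.hadoop.hbase.client.Result
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.types._

// Illustrative only: build a Result => Any extractor for one mapped column.
// An empty family stands for the ":rowkey" pseudo-column.
def makeGetterSketch(sparkType: DataType, family: String, qualifier: String): Result => Any =
  (result: Result) => {
    val bytes =
      if (family.isEmpty) result.getRow
      else result.getValue(Bytes.toBytes(family), Bytes.toBytes(qualifier))
    sparkType match {
      case IntegerType => Bytes.toInt(bytes)
      case LongType    => Bytes.toLong(bytes)
      case DoubleType  => Bytes.toDouble(bytes)
      case BooleanType => Bytes.toBoolean(bytes)
      case _           => Bytes.toString(bytes)
    }
  }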
Any extended DataSource needs a class named DefaultSource: Spark resolves the package name passed to format() by looking for a DefaultSource class inside it.
class DefaultSource extends CreatableRelationProvider with RelationProvider with DataSourceRegister {

  // Read path: spark.read.format(...).load()
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation =
    HBaseRelation(parameters, None)(sqlContext)

  override def shortName(): String = "hbase"

  // Write path: df.write.format(...).save()
  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    val relation = InsertHBaseRelation(data, parameters)(sqlContext)
    relation.insert(data, false)
    relation
  }
}
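Since DefaultSource also implements DataSourceRegister with shortName "hbase", the short name can replace the full package name, but only if the jar registers the class through Java's ServiceLoader (a META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file listing DefaultSource). Whether the linked project ships that file is not shown here, so treat this as a sketch:

// Works only if META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
// lists the fully qualified DefaultSource class; otherwise use the full package
// name as in the earlier examples.
spark.read.format("hbase")
  .options(Map(
    "hbase.zookeeper.quorum" -> "XXX:2181",
    "hbase.table.name" -> "hbase_table"
  )).load().show(false)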
private[sql] case class InsertHBaseRelation(
    dataFrame: DataFrame,
    parameters: Map[String, String]
  )(@transient val sqlContext: SQLContext)
  extends BaseRelation with InsertableRelation with Logging {

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    def getZkURL: String = parameters.getOrElse("zk", parameters.getOrElse("hbase.zookeeper.quorum", sys.error("You must specify parameter zkurl...")))
    def getOutputTableName: String = parameters.getOrElse("outputTableName", sys.error("You must specify parameter outputTableName..."))
    import org.apache.hack.spark._
    data.saveToHbase(getOutputTableName, Some(getZkURL), parameters)
  }

  override def schema: StructType = dataFrame.schema
}
private[sql] case class HBaseRelation(
    parameters: Map[String, String],
    userSpecifiedschema: Option[StructType]
  )(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with Logging {

  def getZkURL: String = parameters.getOrElse("zk", parameters.getOrElse("hbase.zookeeper.quorum", sys.error("You must specify parameter zkurl...")))
  def getInputTableName: String = parameters.getOrElse("inputTableName", sys.error("You must specify parameter inputTableName..."))

  def buildScan(): RDD[Row] = {
    import org.apache.hack.spark._
    sqlContext.sparkSession.hbaseTableAsDataFrame(getInputTableName, Some(getZkURL), parameters).rdd
  }

  override def schema: StructType = {
    import org.apache.hack.spark._
    sqlContext.sparkSession.hbaseTableAsDataFrame(getInputTableName, Some(getZkURL), parameters).schema
  }
}
Reference
Why can a DataFrame call that method directly? Because import org.apache.hack.spark._ brings these implicit conversions into scope:
implicit def toSparkSqlContextFunctions(spark: SparkSession): SparkSqlContextFunctions = {
  new SparkSqlContextFunctions(spark)
}
implicit def toDataFrameFunctions(data: DataFrame): DataFrameFunctions = {
  new DataFrameFunctions(data)
}
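Presumably these implicit conversions live in a package object, so that the single import brings them into scope; a minimal sketch of that assumed layout:

// Assumed layout: a package object makes the implicits visible through
// `import org.apache.hack.spark._`.
package org.apache.hack

import org.apache.spark.sql.{DataFrame, SparkSession}

package object spark {
  implicit def toSparkSqlContextFunctions(spark: SparkSession): SparkSqlContextFunctions =
    new SparkSqlContextFunctions(spark)

  implicit def toDataFrameFunctions(data: DataFrame): DataFrameFunctions =
    new DataFrameFunctions(data)
}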
For the complete integration code, see https://github.com/teeyog/IQL/tree/master/iql-spark/src/main/scala/org/apache/spark/sql/execution/datasources/hbase