package mobvista.dmp.datasource.device

import java.net.URI
import java.util

import mobvista.dmp.common.CommonSparkJob
import mobvista.dmp.format.RCFileInputFormat
import mobvista.dmp.util.BytesRefUtil
import org.apache.commons.cli.Options
import org.apache.commons.lang.StringUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.serde2.columnar.BytesRefArrayWritable
import org.apache.hadoop.io.LongWritable
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}

/**
  * Merges one data source's daily user records into that source's cumulative ("total") user table.
  *
  * 1. Daily records (ORC, RCFile or delimited text, selected by `dailyFormat`) are projected to
  *    (device_id, device_type, platform, country). They are then filtered by the device-id patterns
  *    and de-duplicated to one row per (device_id, device_type).
  * 2. The daily rows are full-outer-joined with the existing total table on (device_id, device_type).
  *    This refreshes platform/country and maintains first_req_day / last_req_day.
  *
  * NOTE(review): an earlier description said the daily data is joined with age and gender data.
  * The `agePath`/`genderPath` options are declared but not read here, and the merge query writes
  * empty age, gender and tags.
  */
class OdsDmpUserInfo extends CommonSparkJob with Serializable {

  /** Separator used to pack the four daily column indices into one string (see `run`). */
  val indexSplit = ","

  /**
    * Parses the command line, builds the merged user table and writes it as ORC to `outputPath`.
    *
    * @param args command-line arguments (see `buildOptions`)
    * @return -1 when mandatory options are missing, 0 after a successful run
    */
  override protected def run(args: Array[String]): Int = {

    val commandLine = commParser.parse(options, args)
    if (!checkMustOption(commandLine)) {
      printUsage(options)
      return -1
    } else {
      printOptions(commandLine)
    }

    val date = commandLine.getOptionValue("date")
    val dailyPath = commandLine.getOptionValue("dailyPath")
    val totalPath = commandLine.getOptionValue("totalPath")
    val parallelism = commandLine.getOptionValue("parallelism").toInt
    val dailyFormat = commandLine.getOptionValue("dailyFormat")
    val dailyDidIndex = commandLine.getOptionValue("dailyDidIndex").toInt
    val dailyDidTypeIndex = commandLine.getOptionValue("dailyDidTypeIndex").toInt
    val dailyPltIndex = commandLine.getOptionValue("dailyPltIndex").toInt
    val dailyCountryIndex = commandLine.getOptionValue("dailyCountryIndex").toInt
    val outputPath = commandLine.getOptionValue("outputPath")
    val compression = commandLine.getOptionValue("compression", "zlib")
    // `agePath`, `genderPath` and `coalesce` are accepted options but are not used by this job.

    // The parse* helpers receive the indices as one delimited string (split on `indexSplit`).
    val indices = s"${dailyDidIndex},${dailyDidTypeIndex},${dailyPltIndex},${dailyCountryIndex}"

    val spark = SparkSession.builder()
      .appName("OdsDmpUserInfo")
      .config("spark.rdd.compress", "true")
      .config("spark.io.compression.codec", "snappy")
      .config("spark.default.parallelism", parallelism)
      .config("spark.sql.orc.filterPushdown", "true")
      .config("spark.sql.warehouse.dir", "s3://mob-emr-test/spark-warehouse")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .getOrCreate()
    val sc = spark.sparkContext

    try {
      readDaily(spark, dailyFormat, dailyPath, indices)
        .filter(userInfo => matchesDevicePattern(userInfo))
        .distinct()
        .createOrReplaceTempView("t_daily")

      readTotal(spark, totalPath)
        .distinct()
        .createOrReplaceTempView("t_total")

      // Resolve the FileSystem from the output path itself. A FileSystem bound to a fixed bucket
      // rejects paths in any other bucket/filesystem ("Wrong FS").
      val output = new Path(outputPath)
      output.getFileSystem(sc.hadoopConfiguration).delete(output, true)

      spark.sql(buildMergeSql(date))
        .write
        .mode(SaveMode.Overwrite)
        .option("orc.compress", compression)
        .orc(outputPath)
    } finally {
      // SparkSession.stop() also stops the underlying SparkContext.
      spark.stop()
    }
    0
  }

  /** Reads the daily input in the requested format and maps each record to an OdsDmpUserInfoVO. */
  private def readDaily(spark: SparkSession, dailyFormat: String, dailyPath: String,
                        indices: String): Dataset[OdsDmpUserInfoVO] = {
    import spark.implicits._
    val sc = spark.sparkContext
    if ("orc".equalsIgnoreCase(dailyFormat)) {
      spark.read.format("orc").load(dailyPath)
        .map(parseORC(_, indices))
    } else if ("rcfile".equalsIgnoreCase(dailyFormat)) {
      sc.newAPIHadoopFile[LongWritable, BytesRefArrayWritable, RCFileInputFormat[LongWritable, BytesRefArrayWritable]](dailyPath)
        .map(tuple => parseRCFile(tuple._2, indices))
        .toDS()
    } else {
      // Any other format value (including a missing one) is treated as delimited text.
      sc.textFile(dailyPath)
        .map(parseText(_, indices))
        .filter(o => mobvista.dmp.common.MobvistaConstant.checkDeviceId(o.device_id))
        .toDS()
    }
  }

  /**
    * Reads the existing total table (ORC). Falls back to an empty dataset when `totalPath` is
    * null/empty, matches nothing, or contains no data.
    */
  private def readTotal(spark: SparkSession, totalPath: String): Dataset[OdsDmpUserInfoVO] = {
    import spark.implicits._
    val sc = spark.sparkContext
    // Check the path before touching it. An unguarded sc.textFile(null) or a missing path would
    // throw instead of taking the empty-total branch. globStatus also handles glob patterns.
    val hasTotal = StringUtils.isNotEmpty(totalPath) && {
      val path = new Path(totalPath)
      val matched = path.getFileSystem(sc.hadoopConfiguration).globStatus(path)
      matched != null && matched.nonEmpty && !sc.textFile(totalPath).isEmpty()
    }
    if (hasTotal) {
      spark.read.format("orc").load(totalPath)
        .map(buildUserInfo)
    } else {
      spark.emptyDataset[OdsDmpUserInfoVO]
    }
  }

  /** True when the record's device id is non-null and matches one of the accepted id patterns. */
  private def matchesDevicePattern(userInfo: OdsDmpUserInfoVO): Boolean = {
    val deviceId = userInfo.device_id
    // The ORC reader can yield null ids; String.matches on null would throw an NPE.
    deviceId != null && (
      deviceId.matches(didPtn) ||
        deviceId.matches(imeiPtn) ||
        deviceId.matches(andriodIdPtn) ||
        deviceId.matches(imeiMd5Ptn) ||
        ("oaid".equalsIgnoreCase(userInfo.device_type) && deviceId.matches(oaidAnotherPtn)))
  }

  /**
    * Builds the merge query over the `t_daily` and `t_total` views.
    *
    * - Daily rows are reduced to one row per (device_id, device_type), keeping the greatest country.
    * - Daily and total are full-outer-joined on (device_id, device_type); daily values win.
    * - first_req_day is `date` for devices absent from total, otherwise the stored value.
    * - last_req_day is `date` for devices seen today, otherwise the stored value.
    * - age, gender and tags are written as ''; a country of GB (any case) becomes UK.
    */
  private def buildMergeSql(date: String): String =
    s"""
       |select t.device_id, t.device_type, t.platform,
       |  case when upper(t.country)='GB' then 'UK' else t.country end as country,
       |  t.age, t.gender, t.tags, t.first_req_day, t.last_req_day
       |from (
       |   select
       |     coalesce(a.device_id, b.device_id) as device_id,
       |     coalesce(a.device_type, b.device_type) as device_type,
       |     coalesce(a.platform, b.platform) as platform,
       |     coalesce(a.country, b.country) as country,
       |     '' as age,
       |     '' as gender,
       |     '' as tags,
       |     case when
       |               b.device_id is null
       |          then
       |               '$date'
       |          else
       |               b.first_req_day
       |          end as first_req_day,
       |     case when
       |               a.device_id is null
       |          then
       |               b.last_req_day
       |          else
       |               '$date'
       |          end as last_req_day
       |   from (
       |     select t.device_id, t.device_type, t.platform, t.country
       |     from (
       |        select t.device_id, t.device_type, t.platform, t.country,
       |          row_number() over(partition by t.device_id, t.device_type order by t.country desc ) as rk
       |        from t_daily t
       |        where t.device_id rlike '$didPtn' or t.device_id rlike '$imeiPtn'  or t.device_id rlike '$andriodIdPtn' or t.device_id rlike '$imeiMd5Ptn' or (t.device_id rlike '$oaidAnotherPtn' and lower(t.device_type)='oaid')
       |     ) t
       |     where t.rk = 1
       |   ) a
       |   full outer join t_total b
       |   on a.device_id=b.device_id and a.device_type=b.device_type
       |) t
       |distribute by t.device_id
       |sort by t.device_id asc
        """.stripMargin

  /**
    * Converts one row of the total table into an OdsDmpUserInfoVO.
    *
    * Columns are read by position and must be, in order: device_id, device_type, platform, country,
    * age, gender, tags, first_req_day, last_req_day.
    *
    * @param row a row of the total ORC table
    * @return the corresponding record
    */
  def buildUserInfo(row: Row): OdsDmpUserInfoVO = {
    OdsDmpUserInfoVO(
      row.getString(0), // device_id
      row.getString(1), // device_type
      row.getString(2), // platform
      row.getString(3), // country
      row.getString(4), // age
      row.getString(5), // gender
      row.getString(6), // tags
      row.getString(7), // first_req_day
      row.getString(8)) // last_req_day
  }

  /**
    * Extracts device id, device type, platform and country from an ORC daily row.
    *
    * @param row     a daily ORC row
    * @param indices the four column indices (id, type, platform, country) joined by `indexSplit`
    * @return a record with only those four fields set; a null column yields a null field
    */
  def parseORC(row: Row, indices: String): OdsDmpUserInfoVO = {
    val idxSplits = splitFun(indices, indexSplit)
    val deviceId = row.getString(idxSplits(0).toInt)
    val deviceType = row.getString(idxSplits(1).toInt)
    val platform = row.getString(idxSplits(2).toInt)
    val country = row.getString(idxSplits(3).toInt)
    new OdsDmpUserInfoVO(deviceId, deviceType, platform, country)
  }

  /**
    * Extracts device id, device type, platform and country from an RCFile daily record.
    *
    * @param value   the RCFile column values of one record
    * @param indices the four column indices (id, type, platform, country) joined by `indexSplit`
    * @return a record with only those four fields set
    */
  def parseRCFile(value: BytesRefArrayWritable, indices: String): OdsDmpUserInfoVO = {
    val idxSplits = splitFun(indices, indexSplit)
    val deviceId = BytesRefUtil.BytesRefWritableToString(value.get(idxSplits(0).toInt))
    val deviceType = BytesRefUtil.BytesRefWritableToString(value.get(idxSplits(1).toInt))
    val platform = BytesRefUtil.BytesRefWritableToString(value.get(idxSplits(2).toInt))
    val country = BytesRefUtil.BytesRefWritableToString(value.get(idxSplits(3).toInt))
    new OdsDmpUserInfoVO(deviceId, deviceType, platform, country)
  }

  /**
    * Extracts device id, device type, platform and country from one delimited text line.
    *
    * The line is split with `splitFun(line)` (separator defined in CommonSparkJob). The country
    * column may be missing, in which case country is "". If the line lacks the id/type/platform
    * columns, or is shorter than the country index, an all-empty record is returned instead of
    * throwing; it is rejected later by the device-id checks.
    *
    * @param line    one line of the daily text input
    * @param indices the four column indices (id, type, platform, country) joined by `indexSplit`
    * @return the parsed record, or an all-empty record for malformed lines
    */
  def parseText(line: String, indices: String): OdsDmpUserInfoVO = {
    val splits = splitFun(line)
    val idxSplits = splitFun(indices, indexSplit)
    val didIdx = idxSplits(0).toInt
    val didTypeIdx = idxSplits(1).toInt
    val pltIdx = idxSplits(2).toInt
    val countryIdx = idxSplits(3).toInt
    // Bounds-check every mandatory column. Checking only the country index let an out-of-range
    // id/type/platform index throw ArrayIndexOutOfBoundsException and fail the whole job.
    val mandatoryPresent = Seq(didIdx, didTypeIdx, pltIdx).forall(_ < splits.length)
    if (splits.length >= countryIdx && mandatoryPresent) {
      val country = if (splits.length > countryIdx) splits(countryIdx) else ""
      new OdsDmpUserInfoVO(splits(didIdx), splits(didTypeIdx), splits(pltIdx), country)
    } else {
      new OdsDmpUserInfoVO("", "", "", "")
    }
  }

  /** Declares the command-line options accepted by this job. */
  override protected def buildOptions(): Options = {
    val options = new Options
    options.addOption("date", true, "[must] date")
    options.addOption("dailyPath", true, "[must] dailyPath")
    options.addOption("agePath", true, "[must] agePath")
    options.addOption("genderPath", true, "[must] genderPath")
    options.addOption("totalPath", true, "totalPath")
    options.addOption("dailyFormat", true, "[must] dailyFormat: orc, rcfile or text")
    options.addOption("dailyDidIndex", true, "[must] index of device id")
    options.addOption("dailyDidTypeIndex", true, "[must] index of device id type")
    options.addOption("dailyPltIndex", true, "[must] index of platform")
    options.addOption("dailyCountryIndex", true, "[must] index of country")
    options.addOption("outputPath", true, "[must] outputPath")
    options.addOption("compression", true, "compression type")
    options.addOption("parallelism", true, "parallelism of shuffle operation")
    options.addOption("coalesce", true, "number of output files")
    options
  }
}

object OdsDmpUserInfo {

  /**
    * Command-line entry point.
    *
    * Runs the job and uses a non-zero status from `run` (e.g. -1 for missing mandatory options)
    * as the process exit code. Previously that status was discarded and the process exited 0,
    * so schedulers saw a usage error as success.
    *
    * @param args command-line arguments forwarded to the job
    */
  def main(args: Array[String]): Unit = {
    val status = new OdsDmpUserInfo().run(args)
    if (status != 0) {
      sys.exit(status)
    }
  }
}


/**
  * One device record of the user-info table.
  *
  * Equality and hashing deliberately consider only `device_id` and `device_type`. Two records for
  * the same device are equal regardless of their other (mutable) attributes. Both methods are
  * null-safe.
  */
case class OdsDmpUserInfoVO(device_id: String, device_type: String, var platform: String, var country: String, var age: String,
                            var gender: String, var tags: String, var first_req_day: String, var last_req_day: String) {

  /** Builds a record with only id, type, platform and country set; every other field is "". */
  def this(device_id: String, device_type: String, platform: String, country: String) = {
    this(device_id, device_type, platform, country, "", "", "", "", "")
  }

  // `##` is null-safe (null.## == 0) and equals hashCode for non-null strings. Non-null records
  // therefore hash exactly as before, and null ids no longer throw an NPE.
  override def hashCode(): Int = device_type.## + device_id.##

  // Scala `==` is null-safe, unlike calling `.equals` on a possibly-null field.
  override def equals(obj: scala.Any): Boolean = obj match {
    case that: OdsDmpUserInfoVO => device_id == that.device_id && device_type == that.device_type
    case _ => false
  }
}