package mobvista.dmp.datasource.device

import java.net.URI
import java.util
import java.util.Map.Entry

import com.google.gson.JsonElement
import mobvista.dmp.common.CommonSparkJob
import mobvista.dmp.format.RCFileInputFormat
import mobvista.dmp.util.BytesRefUtil
import mobvista.prd.datasource.util.{GsonUtil, MRUtils}
import org.apache.commons.cli.Options
import org.apache.commons.lang.StringUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.serde2.columnar.BytesRefArrayWritable
import org.apache.hadoop.io.LongWritable
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}

/**
  * Merges each data source's daily user info into that source's full user-info
  * snapshot:
  * 1. Join the daily data with the age and gender data to obtain age and gender.
  * 2. Merge the result of step 1 into the full data, updating the user info in
  *    the full snapshot.
  */
class OdsDmpUserInfoV2 extends CommonSparkJob with Serializable {

  val indexSplit = ","

  override protected def run(args: Array[String]): Int = {

    val commandLine = commParser.parse(options, args)
    if (!checkMustOption(commandLine)) {
      printUsage(options)
      return -1
    } else {
      printOptions(commandLine)
    }

    val date = commandLine.getOptionValue("date")
    val dailyPath = commandLine.getOptionValue("dailyPath")
    //  val agePath = commandLine.getOptionValue("agePath")
    //  val genderPath = commandLine.getOptionValue("genderPath")
    val totalPath = commandLine.getOptionValue("totalPath")
    val parallelism = commandLine.getOptionValue("parallelism").toInt
    //  val coalesce = commandLine.getOptionValue("coalesce").toInt
    val dailyFormat = commandLine.getOptionValue("dailyFormat")
    val dailyDidIndex = commandLine.getOptionValue("dailyDidIndex").toInt
    val dailyDidTypeIndex = commandLine.getOptionValue("dailyDidTypeIndex").toInt
    val dailyPltIndex = commandLine.getOptionValue("dailyPltIndex").toInt
    val dailyCountryIndex = commandLine.getOptionValue("dailyCountryIndex").toInt
    val outputPath = commandLine.getOptionValue("outputPath")
    val compression = commandLine.getOptionValue("compression", "zlib")
    val business = commandLine.getOptionValue("business")

    // Comma-separated column positions of device id, device type, platform and
    // country in the daily input, e.g. "0,1,2,3".
    val indices = s"${dailyDidIndex},${dailyDidTypeIndex},${dailyPltIndex},${dailyCountryIndex}"

    val spark = SparkSession.builder()
      .appName("OdsDmpUserInfoV2")
      .config("spark.rdd.compress", "true")
      .config("spark.io.compression.codec", "snappy")
      .config("spark.default.parallelism", parallelism)
      .config("spark.sql.orc.filterPushdown", "true")
      .config("spark.sql.warehouse.dir", "s3://mob-emr-test/spark-warehouse")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .enableHiveSupport()
      .getOrCreate()
    import spark.implicits._
    val sc = spark.sparkContext

    FileSystem.get(new URI("s3://mob-emr-test"), spark.sparkContext.hadoopConfiguration).delete(new Path(outputPath), true)
    val logic = new OdsDmpUserInfoLogic(date)
    try {
      // Parse the daily data according to its storage format.
      val dailyDS: Dataset[OdsDmpUserInfoVOV2] =
        if ("orc".equalsIgnoreCase(dailyFormat)) {
          spark.read.format("orc").load(dailyPath)
            .map(parseORC(_, indices))
        } else if ("rcfile".equalsIgnoreCase(dailyFormat)) {
          sc.newAPIHadoopFile[LongWritable, BytesRefArrayWritable, RCFileInputFormat[LongWritable, BytesRefArrayWritable]](dailyPath)
            .map(tuple => parseRCFile(tuple._2, indices))
            .toDS()
        } else {
          sc.textFile(dailyPath)
            .map(parseText(_, indices))
            .toDS()
        }

      // Read the full (historical) snapshot; when totalPath is empty (e.g. on the
      // first run), start from an empty dataset instead.
      val totalDS: Dataset[OdsDmpUserInfoVOV2] =
        if (StringUtils.isNotEmpty(totalPath)) {
          spark.read.orc(totalPath)
            .map(buildUserInfo(_))
        } else {
          spark.createDataset(new util.ArrayList[OdsDmpUserInfoVOV2]())
        }

      dailyDS
        .filter(userInfo => userInfo.device_id.matches(didPtn) || userInfo.device_id.matches(imeiPtn))
        .createOrReplaceTempView("t_daily")

      totalDS.createOrReplaceTempView("t_total")

      spark.udf.register("getAgeRatio", getAgeRatio _)
      /*
      val ageSql =
        s"""
           |select device_id,device_type,split(getAgeRatio(age),'#')[0] as age,split(getAgeRatio(age),'#')[1] as ratio
           |from dwh.dm_device_age_v2
           |where year = '${date.substring(0, 4)}' and month = '${date.substring(5, 7)}'
           |and day = '${date.substring(8, 10)}' and update_date = '${date}'
        """.stripMargin
      */
      spark.sql(logic.getAgeSql())
        .createOrReplaceTempView("tmp_age")

      /*
      val ageRatio =
        """
          |select t.device_id, t.device_type, t.age
          |from (
          |  select device_id, device_type, age,
          |     row_number() over(partition by device_id, device_type, age order by ratio desc) as rk
          |  from tmp_age
          |) t
          |where t.rk = '1'
        """.stripMargin
      */
      spark.sql(Constant.ageRatio)
        .createOrReplaceTempView("t_age")

      /*
      val genderSql =
        s"""
           |select t.device_id, t.device_type, t.gender
           |from (
           |  select device_id, device_type, gender,
           |     row_number() over(partition by device_id, device_type, gender order by ratio desc) as rk
           |  from dwh.dm_device_gender_v2
           |  where year = '${date.substring(0, 4)}' and month = '${date.substring(5, 7)}' and day = '${date.substring(8, 10)}'
           |  and update_date = '${date}'
           |) t
           |where t.rk = '1'
        """.stripMargin
      */
      spark.sql(logic.getGenderSql())
        .createOrReplaceTempView("t_gender")

      // Build each device's install list and expose it as t_package.
      val package_name_rdd = spark.sql(logic.getDmInstallListSql(business)).rdd
      val df = logic.getNewInstallList(package_name_rdd).map(MRUtils.SPLITTER.split(_))
        .map(r => InstallList(r(0), r(1), r(2))).toDF
      df.createOrReplaceTempView("t_package")
      /*
      val tagsSql =
        s"""
           |select device_id,device_type,tags from dwh.dm_interest_tag_daily
           |where year = '${date.substring(0, 4)}' and month = '${date.substring(5, 7)}' and day = '${date.substring(8, 10)}'
        """.stripMargin
      spark.sql(tagsSql)
        .createOrReplaceTempView("t_tags")
      */

      /*
      val sql =
        s"""select t.device_id, t.device_type, t.platform,
           |  case when t.country='UK' then 'GB' else t.country end as country,
           |  t.age, t.gender, t.tags, t.package_name, t.first_req_day, t.last_req_day
           |from (
           |   select
           |     coalesce(a.device_id, b.device_id) as device_id,
           |     coalesce(a.device_type, b.device_type) as device_type,
           |     coalesce(a.platform, b.platform) as platform,
           |     coalesce(a.country, b.country, '') as country,
           |     coalesce(a.age, b.age, '') as age,
           |     coalesce(a.gender, b.gender, '') as gender,
           |     '' as tags,
           |     coalesce(a.package_name, b.package_name) as package_name,
           |     case when
           |               b.device_id is null
           |          then
           |               '$date'
           |          else
           |               b.first_req_day
           |          end as first_req_day,
           |     case when
           |               a.device_id is null
           |          then
           |               b.last_req_day
           |          else
           |               '$date'
           |          end as last_req_day
           |   from (
           |     select /*+ mapjoin(t)*/ t.device_id, t.device_type, t.platform, t.country, a.age, g.gender, p.package_name
           |         from (
           |            select t.device_id, t.device_type, t.platform, t.country,
           |              row_number() over(partition by t.device_id, t.device_type order by t.country desc ) as rk
           |            from t_daily t
           |            where t.device_id rlike '$didPtn' or t.device_id rlike '$imeiPtn'
           |         ) t
           |         left outer join t_age a on (upper(a.device_id) = upper(t.device_id) and a.device_type = t.device_type)
           |         left outer join t_gender g on (upper(g.device_id) = upper(t.device_id) and g.device_type = t.device_type)
           |         left outer join t_package p on (upper(p.device_id) = upper(t.device_id) and p.device_type = t.device_type)
           |     where t.rk = 1
           |   ) a
           |   full outer join t_total b
           |   on a.device_id = b.device_id and a.device_type = b.device_type
           |) t
        """.stripMargin
      */
      spark.sql(logic.getUserInfoSql())
        .write
        .mode(SaveMode.Overwrite)
        .option("orc.compress", compression)
        .orc(outputPath)
      //  .rdd.saveAsTextFile(outputPath, classOf[GzipCodec])
    } finally {
      sc.stop()
      spark.stop()
    }
    0
  }

  /**
    * Maps a row of the full user-info snapshot back to an OdsDmpUserInfoVOV2.
    *
    * @param row a row read from the total ORC snapshot
    * @return the populated user-info value object (install_list reset to empty)
    */
  def buildUserInfo(row: Row): OdsDmpUserInfoVOV2 = {
    OdsDmpUserInfoVOV2(
      row.getString(0), // device_id
      row.getString(1), // device_type
      row.getString(2), // platform
      row.getString(3), // country
      row.getString(4), // age
      row.getString(5), // gender
      row.getString(6), // tags
      row.getString(7), // first_req_day
      row.getString(8), //  last_req_day
      "") // package_name
  }
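  // Note: this mapping reads the snapshot positionally, so it assumes the ORC
  // columns appear exactly in the order annotated above.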

  /**
    * Picks the age bucket with the highest proportion from the
    * "age_and_proportion" JSON object and returns it as "age#ratio".
    */
  def getAgeRatio(ageRange: String): String = {
    var age: String = null
    var max: Double = 0.0
    if (ageRange != null) {
      val json = GsonUtil.String2JsonObject(ageRange)
      val ageJson = json.get("age_and_proportion").getAsJsonObject()
      val itr = ageJson.entrySet().iterator()
      while (itr.hasNext()) {
        val entry: Entry[String, JsonElement] = itr.next()
        val temp = entry.getValue().getAsDouble()
        if (temp > max) {
          max = temp
          age = entry.getKey()
        }
      }
      //  return Util.calcLabel(age.toInt)
    }
    age + "#" + max
  }
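  // Example of the mapping above (input shape assumed from the parsing logic):
  //   getAgeRatio("""{"age_and_proportion": {"18-24": 0.6, "25-34": 0.4}}""")
  // returns "18-24#0.6"; a null input yields "null#0.0".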

  /**
    * Extracts a user-info record from an ORC row of the daily input.
    *
    * @param row     a row of the daily ORC input
    * @param indices comma-separated column positions of device id, device type,
    *                platform and country
    * @return the extracted user-info value object
    */
  def parseORC(row: Row, indices: String): OdsDmpUserInfoVOV2 = {
    val idxSplits = splitFun(indices, indexSplit)
    val deviceId = row.getString(idxSplits(0).toInt)
    val deviceType = row.getString(idxSplits(1).toInt)
    val platform = row.getString(idxSplits(2).toInt)
    val country = row.getString(idxSplits(3).toInt)
    new OdsDmpUserInfoVOV2(deviceId, deviceType, platform, country)
  }

  /**
    * Extracts a user-info record from an RCFile row of the daily input.
    *
    * @param value   the column values of one RCFile row
    * @param indices comma-separated column positions of device id, device type,
    *                platform and country
    * @return the extracted user-info value object
    */
  def parseRCFile(value: BytesRefArrayWritable, indices: String): OdsDmpUserInfoVOV2 = {
    val idxSplits = splitFun(indices, indexSplit)
    val deviceId = BytesRefUtil.BytesRefWritableToString(value.get(idxSplits(0).toInt))
    val deviceType = BytesRefUtil.BytesRefWritableToString(value.get(idxSplits(1).toInt))
    val platform = BytesRefUtil.BytesRefWritableToString(value.get(idxSplits(2).toInt))
    val country = BytesRefUtil.BytesRefWritableToString(value.get(idxSplits(3).toInt))
    new OdsDmpUserInfoVOV2(deviceId, deviceType, platform, country)
  }

  /**
    * Extracts a user-info record from a delimited text line of the daily input.
    *
    * @param line    one line of the daily text input
    * @param indices comma-separated column positions of device id, device type,
    *                platform and country
    * @return the extracted user-info value object
    */
  def parseText(line: String, indices: String): OdsDmpUserInfoVOV2 = {
    val splits = splitFun(line)
    val idxSplits = splitFun(indices, indexSplit)
    val deviceId = splits(idxSplits(0).toInt)
    val deviceType = splits(idxSplits(1).toInt)
    val platform = splits(idxSplits(2).toInt)
    val country = splits(idxSplits(3).toInt)
    new OdsDmpUserInfoVOV2(deviceId, deviceType, platform, country)
  }
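  // Example (assuming splitFun splits on the job's default field delimiter): for
  // a line with fields ["00000000-0000-0000-0000-000000000000", "gaid", "android", "CN"]
  // and indices "0,1,2,3", this yields
  // OdsDmpUserInfoVOV2("00000000-0000-0000-0000-000000000000", "gaid", "android", "CN").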

  override protected def buildOptions(): Options = {
    val options = new Options
    options.addOption("date", true, "[must] date")
    options.addOption("dailyPath", true, "[must] dailyPath")
    options.addOption("agePath", true, "[must] agePath")
    options.addOption("genderPath", true, "[must] genderPath")
    options.addOption("totalPath", true, "totalPath")
    options.addOption("dailyFormat", true, "[must] dailyFormat orc or text ")
    options.addOption("dailyDidIndex", true, "[must] index of device id")
    options.addOption("dailyDidTypeIndex", true, "[must] index of device id type")
    options.addOption("dailyPltIndex", true, "[must] index of platform")
    options.addOption("dailyCountryIndex", true, "[must] index of country")
    options.addOption("outputPath", true, "[must] outputPath")
    options.addOption("compression", true, "compression type")
    options.addOption("parallelism", true, "parallelism of shuffle operation")
    options.addOption("coalesce", true, "number of output files")
    options.addOption("business", true, "[must] business")
    options
  }
}

object OdsDmpUserInfoV2 {
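  // A minimal invocation sketch (jar name, paths and option values below are
  // illustrative assumptions, not taken from any real job configuration):
  //
  //   spark-submit --class mobvista.dmp.datasource.device.OdsDmpUserInfoV2 dmp.jar \
  //     -date 2024-01-01 -business adn -dailyFormat orc \
  //     -dailyPath s3://bucket/daily -totalPath s3://bucket/total \
  //     -dailyDidIndex 0 -dailyDidTypeIndex 1 -dailyPltIndex 2 -dailyCountryIndex 3 \
  //     -outputPath s3://bucket/out -parallelism 200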
  def main(args: Array[String]): Unit = {
    new OdsDmpUserInfoV2().run(args)
  }
}


case class OdsDmpUserInfoVOV2(device_id: String, device_type: String, var platform: String, var country: String, var age: String,
                              var gender: String, var tags: String, var first_req_day: String, var last_req_day: String, var install_list: String) {

  def this(device_id: String, device_type: String, platform: String, country: String) = {
    this(device_id, device_type, platform, country, "", "", "", "", "", "")
  }

  // Identity is the (device_id, device_type) pair; equals and hashCode ignore
  // the mutable fields so that merged records deduplicate correctly.
  override def hashCode(): Int =
    this.device_type.hashCode + this.device_id.hashCode

  override def equals(obj: scala.Any): Boolean = obj match {
    case o: OdsDmpUserInfoVOV2 =>
      this.device_id.equals(o.device_id) && this.device_type.equals(o.device_type)
    case _ => false
  }
}