package mobvista.dmp.datasource.dm

import java.net.URI
import java.text.SimpleDateFormat
import java.util
import java.util.Date

import com.google.gson.{JsonArray, JsonObject}
import mobvista.dmp.common.CommonSparkJob
import mobvista.dmp.util.MRUtils
import mobvista.prd.datasource.util.GsonUtil
import org.apache.commons.cli.Options
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.{SaveMode, SparkSession}

/**
  * @author wangjf
  */
class DmDeviceTagAll extends CommonSparkJob with Serializable {

  /**
    * Merges "old" and "new" device tags and writes the result as zlib-compressed ORC.
    *
    * Data sources:
    *  - "new" tags come from `Constant.device_tag_sql_1`, with `@date` / `@ga_date` substituted.
    *  - "old" tags come from `Constant.interest_tag_sql`, with `@date` substituted.
    *
    * Both sides are keyed by device_id and tagged with their source ("new" / "old"). They are then
    * grouped per device and each partition is handed to `CustomerIterator`. The resulting DataFrame
    * is written to `output` using `coalesce` partitions.
    *
    * @param args command-line arguments; see [[buildOptions]]
    * @return 1 when a required option is missing, otherwise 0 after the job finishes
    */
  override protected def run(args: Array[String]): Int = {
    val commandLine = commParser.parse(options, args)
    if (!checkMustOption(commandLine)) {
      printUsage(options)
      printOptions(commandLine)
      return 1
    }
    printOptions(commandLine)

    val date = commandLine.getOptionValue("date")
    val ga_date = commandLine.getOptionValue("ga_date")
    val output = commandLine.getOptionValue("output")
    val coalesce = commandLine.getOptionValue("coalesce")
    // NOTE(review): the "input" option is still declared required in buildOptions, but the active
    // code path does not read it (it was only used by previously commented-out code).

    // Fail fast on malformed arguments before a SparkSession is created or the output is deleted.
    val yyyyMMdd = new SimpleDateFormat("yyyyMMdd")
    yyyyMMdd.parse(date)
    yyyyMMdd.parse(ga_date)
    val numPartitions = coalesce.toInt

    val spark = SparkSession.builder()
      .appName(s"DmDeviceTagAll.${date}")
      .config("spark.rdd.compress", "true")
      .config("spark.io.compression.codec", "lz4")
      .config("spark.io.compression.lz4.blockSize", "64k")
      .config("spark.sql.orc.filterPushdown", "true")
      .config("spark.sql.autoBroadcastJoinThreshold", "104857600")
      .config("spark.sql.warehouse.dir", "s3://mob-emr-test/spark-warehouse")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.kryo.registrator", "mobvista.dmp.datasource.dm.MyRegisterKryo")
      .enableHiveSupport()
      .getOrCreate()
    val sc = spark.sparkContext

    try {
      // Resolve the FileSystem from the output path itself rather than a hard-coded bucket, so
      // outputs outside s3://mob-emr-test are deleted instead of failing the FS path check.
      val outputPath = new Path(output)
      outputPath.getFileSystem(sc.hadoopConfiguration).delete(outputPath, true)

      spark.udf.register("check_deviceId", Constant.check_deviceId _)
      spark.udf.register("str2Json", Constant.str2Json _)

      val device_tag_sql = Constant.device_tag_sql_1.replace("@date", date).replace("@ga_date", ga_date)
        .replace("@str2Json", "str2Json")

      val newTagPairs = spark.sql(device_tag_sql).rdd.map(r => {
        DmDeviceTag(r.getAs("device_id"), r.getAs("device_type"), r.getAs("platform"), r.getAs("package_name"), r.getAs("tags"), r.getAs("update_date"))
      }).map(ir => {
        (ir.device_id, ("new", MRUtils.JOINER.join(ir.package_name, ir.device_type, ir.platform, ir.update_date, ir.tags)))
      })

      val interest_tag_sql = Constant.interest_tag_sql.replace("@date", date)

      val oldTagPairs = spark.sql(interest_tag_sql).rdd.map(r => {
        DmInterestTag(r.getAs("device_id"), r.getAs("device_type"), r.getAs("platform"), r.getAs("tags"))
      }).map(ir => {
        (ir.device_id, ("old", MRUtils.JOINER.join(ir.device_type, ir.platform, ir.tags)))
      })

      import spark.implicits._
      // groupByKey yields the same (device_id, Iterable[(source, payload)]) shape the former
      // combineByKey produced, without its O(n^2) per-key `c ++ Seq(v)` list appends and without
      // a useless map-side combine for list building.
      val df = oldTagPairs.union(newTagPairs)
        .groupByKey()
        .mapPartitions(v => new CustomerIterator(v))
        .toDF

      df.repartition(numPartitions).write.mode(SaveMode.Overwrite)
        .option("orc.compress", "zlib")
        .orc(output)
    } finally {
      spark.stop()
    }
    0
  }

  /**
    * Merges one or more `;`-separated JSON-array strings into a single JSON array.
    *
    * Entries that carry both `package_name` and `date` (parsed as `yyyy-MM-dd`) are de-duplicated
    * per `package_name`, keeping only the entry with the latest `date`. On ties, the first one seen
    * wins. Entries missing either field are copied through unchanged.
    *
    * According to the original author's note, this is meant to run as a UDF so the merge avoids an
    * extra shuffle and the associated disk I/O.
    *
    * @param tags `;`-separated JSON-array strings; `null` input and blank segments are ignored
    * @return the merged JSON array serialized as a string (`[]` when nothing was merged)
    */
  def combineJsonArray(tags: String): String = {
    import scala.collection.JavaConverters._

    val merged = new JsonArray
    val latestByPackage = new util.HashMap[String, (Date, JsonObject)]()
    val sdf = new SimpleDateFormat("yyyy-MM-dd")

    Option(tags).toSeq
      .flatMap(_.split(";"))
      .filter(_.trim.nonEmpty)
      .foreach { segment =>
        val jsonNode = GsonUtil.String2JsonArray(segment)
        for (i <- 0 until jsonNode.size) {
          val json = jsonNode.get(i).getAsJsonObject
          if (!json.has("package_name") || !json.has("date")) {
            merged.add(json)
          } else {
            val packageName = json.get("package_name").getAsString
            val entryDate = sdf.parse(json.get("date").getAsString)
            val current = latestByPackage.get(packageName)
            // Replace only when strictly newer, so the first entry wins on equal dates.
            if (current == null || current._1.before(entryDate)) {
              latestByPackage.put(packageName, (entryDate, json))
            }
          }
        }
      }

    latestByPackage.values.asScala.foreach { case (_, json) => merged.add(json) }
    merged.toString
  }

  /**
    * Counts the elements of a JSON-array string.
    *
    * @param tags a JSON-array string, parsed with `GsonUtil.String2JsonArray`
    * @return the number of elements in the array
    */
  def toJsonArraySize(tags: String): Int = {
    GsonUtil.String2JsonArray(tags).size()
  }

  /**
    * Builds the glob path `input/date_path/STAR/business`, where STAR is a single `*` wildcard
    * segment.
    *
    * @param input     root directory
    * @param date_path date directory, e.g. `yyyy/MM/dd`
    * @param business  name of the leaf business directory
    * @return the joined glob path
    */
  def buildPath(input: String, date_path: String, business: String): String = {
    s"$input/$date_path/*/$business"
  }

  /**
    * Declares the command-line options of the job. All of them are marked "[must]".
    *
    * @return the option set used to parse `run`'s arguments
    */
  override protected def buildOptions(): Options = {
    val options = new Options
    options.addOption("date", true, "[must] date")
    options.addOption("ga_date", true, "[must] ga_date")
    options.addOption("input", true, "[must] input")
    options.addOption("output", true, "[must] output")
    options.addOption("coalesce", true, "[must] coalesce")
    options
  }
}

object DmDeviceTagAll {
  /**
    * Entry point of the job.
    *
    * Any non-zero return code from `run` (e.g. 1 on missing required options) is propagated as the
    * process exit status. Without this, a misconfigured launch would still exit with 0 and look
    * successful to the scheduler.
    *
    * @param args command-line arguments forwarded to `run`
    */
  def main(args: Array[String]): Unit = {
    val exitCode = new DmDeviceTagAll().run(args)
    if (exitCode != 0) {
      sys.exit(exitCode)
    }
  }
}

/** One device/package tag record: tag type, first- and second-level tags, and the update date. */
case class DeviceTag(
  device_id: String,
  device_type: String,
  platform: String,
  package_name: String,
  tag_type: String,
  first_tag: String,
  second_tag: String,
  update_date: String
)

/** One "new"-side tag row read by `DmDeviceTagAll`: device, package, serialized tags and update date. */
case class DmDeviceTag(
  device_id: String,
  device_type: String,
  platform: String,
  package_name: String,
  tags: String,
  update_date: String
)