package mobvista.dmp.datasource.age_gender

import java.net.URI
import java.util
import java.util.regex.Pattern

import mobvista.dmp.common.CommonSparkJob
import mobvista.dmp.util.MRUtils
import org.apache.commons.cli.Options
import org.apache.commons.lang.StringUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.SparkSession

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
  * Spark job that computes, for every package, the distribution of age labels
  * over the input rows that list the package.
  *
  * Input: ORC rows read with `Constant.merge_schema`; the job uses the
  * `package_names` column (split on `#`) and the `label` column (entries split
  * on `$`, the age label being the part of each entry before the first `#`).
  *
  * Output: text lines built with `MRUtils.JOINER` from the package name, the
  * ratio string `|age:ratio|age:ratio...` and a tag
  * (`unbelievable`, `confirm` or `calc`), sorted by package name.
  *
  * @author wangjf
  */
class CalcPackageAge extends CommonSparkJob with Serializable {

  // Splits the package list, and separates an age label from whatever follows it.
  private val wellSplit = Pattern.compile("#")
  // Splits the `label` column into its individual age entries.
  private val dollarSplit = Pattern.compile("\\$")

  /**
    * Parses the command line and runs the job.
    *
    * @param args command-line arguments, see [[buildOptions]]
    * @return -1 when a mandatory option is missing, 0 once the job has finished
    */
  override protected def run(args: Array[String]): Int = {
    val commandLine = commParser.parse(options, args)
    if (!checkMustOption(commandLine)) {
      printUsage(options)
      -1
    } else {
      printOptions(commandLine)

      val input = commandLine.getOptionValue("inputPath")
      val output = commandLine.getOptionValue("outputPath")
      // Every numeric option is parsed up front so that a malformed value fails
      // the job before any existing output is deleted.
      val parallelism = commandLine.getOptionValue("parallelism").toInt
      val low = commandLine.getOptionValue("low").toInt
      val high = commandLine.getOptionValue("high").toInt
      val unbelievable = commandLine.getOptionValue("unbelievable").toDouble

      computeAndSavePackageAges(input, output, parallelism, low, high, unbelievable)
      0
    }
  }

  /**
    * Builds the Spark pipeline and writes the result to `output`.
    *
    * Any existing `output` directory is deleted first. The Spark session is
    * always stopped, even when the delete or the pipeline fails.
    */
  private def computeAndSavePackageAges(input: String, output: String, parallelism: Int,
                                        low: Int, high: Int, unbelievable: Double): Unit = {
    val spark = SparkSession.builder()
      .appName("CalcPackageAge")
      .config("spark.rdd.compress", "true")
      .config("spark.io.compression.codec", "snappy")
      .config("spark.sql.orc.filterPushdown", "true")
      .config("spark.sql.warehouse.dir", "s3://mob-emr-test/spark-warehouse")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .enableHiveSupport()
      .getOrCreate()

    try {
      // NOTE(review): the file system is always resolved against the
      // mob-emr-test bucket, whatever `output` points to - confirm this is intended.
      FileSystem.get(new URI("s3://mob-emr-test"), spark.sparkContext.hadoopConfiguration)
        .delete(new Path(output), true)

      // (package, "age#count"): how many rows carry each (package, age) pair.
      val pkgAgeCounts = spark.read.schema(Constant.merge_schema).orc(input).rdd
        .flatMap { row =>
          // Explicit type argument: without one the compiler infers `Nothing`
          // as the result type of `getAs`.
          val pkgs = wellSplit.split(row.getAs[Any]("package_names").toString, -1).toSeq
          val ages = dollarSplit.split(row.getAs[Any]("label").toString, -1).toSeq
            .filter(age => !age.equals("null") && StringUtils.isNotBlank(age))
            .map(age => wellSplit.split(age, -1)(0))
          for (pkg <- pkgs; age <- ages) yield (MRUtils.JOINER.join(pkg, age), 1)
        }
        .reduceByKey(_ + _)
        .mapPartitions(mapMergeFun)

      pkgAgeCounts.combineByKey(
        (v: String) => Iterable(v),
        (c: Iterable[String], v: String) => c ++ Seq(v),
        (c1: Iterable[String], c2: Iterable[String]) => c1 ++ c2
      ).map { case (pkg, ageCounts) => buildPackageSummary(pkg, ageCounts) }
        .mapPartitions(mapDictFun(_, low, high, unbelievable))
        .sortByKey()
        .map { case (pkg, value) => MRUtils.JOINER.join(pkg, value) }
        .coalesce(parallelism)
        .saveAsTextFile(output)
    } finally {
      // Stopping the session also stops its underlying SparkContext.
      spark.stop()
    }
  }

  /**
    * Folds the `age#count` entries of one package into a single summary line.
    *
    * @param pkg       package name
    * @param ageCounts entries of the form `age#count` for this package
    * @return `pkg`, the total count, the highest age ratio and the ratio string
    *         (`|age:ratio|...`, or `0.0` when no age label is left), joined
    *         with `MRUtils.JOINER`
    */
  private def buildPackageSummary(pkg: String, ageCounts: Iterable[String]): String = {
    val ageNum = mutable.HashMap.empty[String, Int]
    var valid = 0 // number of rows that carry an age label
    for (ageCount <- ageCounts) {
      val parts = wellSplit.split(ageCount, -1)
      val count = parts(1).toInt
      ageNum += (parts(0) -> count)
      valid += count // accumulate every labelled row of the package
    }

    var maxRatio = 0.0
    val value = new StringBuilder
    for (name <- ageNum.keySet if !name.equals("null")) {
      val ratio = if (valid != 0) ageNum(name).toDouble / valid else 0.0
      if (ratio > maxRatio) maxRatio = ratio
      // Append this age label and its share to the package's ratio string.
      value.append("|").append(name).append(":").append(ratio)
    }
    if (value.isEmpty) { // the package has no age label at all
      value.append(0.0)
    }
    MRUtils.JOINER.join(pkg, valid.toString, maxRatio.toString, value)
  }

  /**
    * Declares the command-line options of the job; each one takes a value.
    *
    * @return the option set: inputPath, outputPath, parallelism, low, high, unbelievable
    */
  override protected def buildOptions(): Options = {
    val opts = new Options
    Seq("inputPath", "outputPath", "parallelism", "low", "high", "unbelievable")
      .foreach(name => opts.addOption(name, true, s"[must] $name"))
    opts
  }

  /**
    * Re-keys `(package<sep>age, count)` records by package.
    *
    * @param iter records whose key holds the package and the age, as split by `MRUtils.SPLITTER`
    * @return lazily produced `(package, "age#count")` records
    */
  def mapMergeFun(iter: Iterator[(String, Int)]): Iterator[(String, String)] =
    iter.map { case (pkgAge, num) =>
      val fields = MRUtils.SPLITTER.split(pkgAge, -1)
      // `pattern()` yields the literal "#" the pattern was compiled from.
      (fields(0), s"${fields(1)}${wellSplit.pattern()}$num")
    }

  /**
    * Turns each line into a `(field0<sep>field1, 1)` record.
    *
    * @param iter lines to be split with `MRUtils.SPLITTER`
    * @return lazily produced records keyed by the first two fields joined with `MRUtils.JOINER`
    */
  def mapFun(iter: Iterator[String]): Iterator[(String, Int)] =
    iter.map { line =>
      val cur = MRUtils.SPLITTER.split(line, -1)
      (MRUtils.JOINER.join(cur(0), cur(1)), 1)
    }

  /**
    * Tags the per-package summary lines and drops the ones with too little data.
    *
    * A line is kept only when its ratio string is not `0.0` and its total count
    * is at least `low`. Kept lines are tagged `confirm` when the count reaches
    * `high`, `calc` otherwise.
    *
    * NOTE(review): a package whose highest ratio reaches `unbelievable` is
    * emitted twice - once tagged `unbelievable` and once tagged
    * `confirm`/`calc`. This keeps the existing behaviour; confirm it is intended.
    *
    * @param iter         lines holding package, total count, highest ratio and ratio string
    * @param low          minimum total count for a package to be emitted
    * @param high         total count from which a package is tagged `confirm`
    * @param unbelievable highest-ratio threshold for the extra `unbelievable` record
    * @return lazily produced `(package, ratios<sep>tag)` records
    */
  def mapDictFun(iter: Iterator[String], low: Int, high: Int, unbelievable: Double): Iterator[(String, String)] =
    iter.flatMap { line =>
      val fields = MRUtils.SPLITTER.split(line, -1)
      val pkg = fields(0)
      val ratios = fields(3)
      if (ratios != "0.0" && fields(1).toInt >= low) {
        val confidence = if (fields(1).toInt >= high) "confirm" else "calc"
        // Parsed as Double: a Float loses precision when widened for the
        // comparison (0.9f is below 0.9), which misclassifies boundary values.
        val tags =
          if (fields(2).toDouble >= unbelievable) Seq("unbelievable", confidence)
          else Seq(confidence)
        tags.map(tag => (pkg, MRUtils.JOINER.join(ratios, tag)))
      } else {
        Seq.empty[(String, String)]
      }
    }

  /**
    * Merges two `(total, per-age counts)` aggregates.
    *
    * @param red_1 first aggregate
    * @param red_2 second aggregate
    * @return the summed totals and the per-age counts added key by key
    */
  def reduceFun(red_1: (Int, Map[String, Int]), red_2: (Int, Map[String, Int])): (Int, Map[String, Int]) = {
    val (leftTotal, leftCounts) = red_1
    val (rightTotal, rightCounts) = red_2
    val merged = rightCounts.foldLeft(leftCounts) { case (acc, (age, count)) =>
      acc.updated(age, count + leftCounts.getOrElse(age, 0))
    }
    (leftTotal + rightTotal, merged)
  }
}

/** Command-line entry point of the [[CalcPackageAge]] job. */
object CalcPackageAge {

  /**
    * Creates a job instance and runs it with the given arguments.
    * The status returned by `run` is discarded.
    *
    * @param args command-line arguments forwarded to the job
    */
  def main(args: Array[String]): Unit = {
    val job = new CalcPackageAge()
    job.run(args)
  }
}