package mobvista.dmp.datasource.age_gender

import java.net.URI
import java.util
import java.util.regex.Pattern

import mobvista.dmp.common.CommonSparkJob
import mobvista.dmp.util.MRUtils
import org.apache.commons.cli.Options
import org.apache.commons.lang.StringUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.SparkSession

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
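  * Aggregates device-level gender labels per package: counts the labelled
  * records under each package, computes the male ratio, and writes
  * (package_name, male_ratio, label_type) lines, where label_type is either
  * "confirm" or "calc" depending on the lowThreshold / highThreshold and
  * mRatio / fRatio arguments.
  *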
  * @author wangjf
  */
class CalcPackageGender extends CommonSparkJob with Serializable {

  private val wellSplit = Pattern.compile("#")
  private val dollarSplit = Pattern.compile("\\$")

  override protected def run(args: Array[String]): Int = {

    val commandLine = commParser.parse(options, args)
    if (!checkMustOption(commandLine)) {
      printUsage(options)
      return -1
    } else {
      printOptions(commandLine)
    }

    val input = commandLine.getOptionValue("inputPath")
    val output = commandLine.getOptionValue("outputPath")
    val parallelism = commandLine.getOptionValue("parallelism")
    val lowThreshold = commandLine.getOptionValue("lowThreshold").toInt
    val highThreshold = commandLine.getOptionValue("highThreshold").toInt
    val mRatio = commandLine.getOptionValue("mRatio").toDouble
    val fRatio = commandLine.getOptionValue("fRatio").toDouble

    val spark = SparkSession.builder()
      .appName("CalcPackageGender")
      .config("spark.rdd.compress", "true")
      .config("spark.io.compression.codec", "snappy")
      .config("spark.sql.orc.filterPushdown", "true")
      .config("spark.sql.warehouse.dir", "s3://mob-emr-test/spark-warehouse")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .enableHiveSupport()
      .getOrCreate()
    val sc = spark.sparkContext

    FileSystem.get(new URI("s3://mob-emr-test"), spark.sparkContext.hadoopConfiguration).delete(new Path(output), true)
    try {

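      //  Each input row carries a '#'-separated package_names field and a
      //  '$'-separated label field; explode every row into per-(package, gender)
      //  keys joined via MRUtils.JOINER, count each key, then reshape into
      //  (package, "gender#count") pairs in mapMergeFun.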
      val merge_device_gender_rdd = spark.read.schema(Constant.merge_schema).orc(input).rdd

      val rdd = merge_device_gender_rdd.flatMap(r => {
        val pkg_genders = new ArrayBuffer[(String, Int)]()
        val package_names = r.getAs("package_names").toString
        val gender_business = r.getAs("label").toString
        val pkgs = wellSplit.split(package_names, -1)
        val genders = dollarSplit.split(gender_business, -1)
        for (pkg <- pkgs) {
          for (gender <- genders) {
            if (!gender.equals("null") && StringUtils.isNotBlank(gender)) {
              //  keep only the gender token (the part before '#') of each label
              pkg_genders += ((MRUtils.JOINER.join(pkg, wellSplit.split(gender, -1)(0)), 1))
            }
          }
        }
        pkg_genders
      }).reduceByKey(_ + _)
        .mapPartitions(mapMergeFun)

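      //  Group all "gender#count" strings per package, then derive the total
      //  labelled count, the m/f counts and the male ratio for each package;
      //  mapDictFun finally applies the threshold rules to label the package.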
      rdd.combineByKey(
        (v: String) => Iterable(v),
        (c: Iterable[String], v: String) => c ++ Seq(v),
        (c1: Iterable[String], c2: Iterable[String]) => c1 ++ c2
      ).map(r => {
        val genderNum = mutable.HashMap.empty[String, Int]
        var valid: Int = 0 //  total number of gender-labelled records under this package
        val key = r._1
        for (gender <- r._2) {
          val gender_num = wellSplit.split(gender, -1)
          genderNum += (gender_num(0) -> gender_num(1).toInt)
          valid += gender_num(1).toInt //  sum all gender-labelled counts under the package
        }
        (key, (valid, genderNum))
      }).map(r => {
        val pkg = r._1
        val valid = r._2._1
        val genderNum = r._2._2
        val value: StringBuilder = new StringBuilder
        var maleRatio = 0.0
        var mGender = 0
        var fGender = 0
        for (name <- genderNum.keySet) {
          if (!name.equals("null")) {
            val num = genderNum(name)
            value.append("|") //  concatenate "|gender:count" for every label under the package
            value.append(name)
            value.append(":")
            value.append(num)
          }
        }
        if (valid != 0) { //  skip the ratio calculation when the package has no labels (avoids division by zero)
          mGender = if (genderNum.contains("m")) genderNum("m") else 0
          fGender = if (genderNum.contains("f")) genderNum("f") else 0
          maleRatio = mGender.toDouble / valid
        }
        MRUtils.JOINER.join(pkg, valid.toString, mGender.toString, fGender.toString, maleRatio.toString, value.toString)
      }).mapPartitions(mapDictFun(_, lowThreshold, highThreshold, mRatio, fRatio))
        .coalesce(parallelism.toInt).saveAsTextFile(output)

    } finally {
      sc.stop()
      spark.stop()
    }
    0
  }

  override protected def buildOptions(): Options = {
    val options = new Options
    options.addOption("inputPath", true, "[must] inputPath")
    options.addOption("outputPath", true, "[must] outputPath")
    options.addOption("parallelism", true, "[must] parallelism")
    options.addOption("lowThreshold", true, "[must] lowThreshold")
    options.addOption("highThreshold", true, "[must] highThreshold")
    options.addOption("mRatio", true, "[must] mRatio")
    options.addOption("fRatio", true, "[must] fRatio")
    options
  }

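  /**
    * Splits each joined "package + gender" key back apart and re-emits it as
    * (package, "gender#count"), ready for the per-package grouping step.
    */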
  def mapMergeFun(iter: Iterator[(String, Int)]): Iterator[(String, String)] = {
    val res = new util.ArrayList[(String, String)]()
    while (iter.hasNext) {
      val ir = iter.next
      val fields = MRUtils.SPLITTER.split(ir._1, -1)
      val pkg = fields(0)
      val tag = fields(1)
      val num = ir._2
      res.add((pkg, tag + "#" + num)) //  "#" is the separator wellSplit expects downstream
    }
    res.asScala.iterator
  }

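  /**
    * Maps raw delimited lines to ((package, gender) joined key, 1) pairs.
    * Not referenced within this class.
    */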
  def mapFun(iter: Iterator[String]): Iterator[(String, Int)] = {
    val res = new util.ArrayList[(String, Int)]()
    while (iter.hasNext) {
      val cur = MRUtils.SPLITTER.split(iter.next, -1)
      res.add((MRUtils.JOINER.join(cur(0), cur(1)), 1))
    }
    res.asScala.iterator
  }

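  /**
    * Filters and labels packages: keeps a package only when its male ratio is
    * non-zero and its labelled count reaches lowThreshold, then tags it
    * "confirm" when the count reaches highThreshold and the male ratio is
    * either >= mRatio or <= fRatio, otherwise "calc".
    */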
  def mapDictFun(iter: Iterator[String], lowThreshold: Int, highThreshold: Int, mRatio: Double, fRatio: Double): Iterator[String] = {
    val res = new util.ArrayList[String]()
    while (iter.hasNext) {
      val fields = MRUtils.SPLITTER.split(iter.next, -1)
      if (fields(4) != "0.0" && fields(1).toInt >= lowThreshold) {
        val tag = if (fields(1).toInt >= highThreshold && (fields(4).toDouble >= mRatio || fields(4).toDouble <= fRatio)) {
          "confirm"
        } else {
          "calc"
        }
        //  package_name  male_ratio  label_type
        res.add(MRUtils.JOINER.join(fields(0), fields(4), tag))
      }
    }
    res.asScala.iterator
  }

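  /**
    * Merges two (labelledCount, genderCountMap) pairs, summing the counts of
    * genders present in both maps. Not referenced within this class.
    */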
  def reduceFun(red_1: (Int, Map[String, Int]), red_2: (Int, Map[String, Int])): (Int, Map[String, Int]) = {
    val redMap = red_1._2 ++ red_2._2.map(t => t._1 -> (t._2 + red_1._2.getOrElse(t._1, 0)))
    (red_1._1 + red_2._1, redMap)
  }
}

object CalcPackageGender {
  def main(args: Array[String]): Unit = {
    new CalcPackageGender().run(args)
  }
}