package mobvista.dmp.datasource.age_gender

import java.net.URI

import mobvista.dmp.common.CommonSparkJob
import mobvista.dmp.util.{DateUtil, MRUtils}
import org.apache.commons.cli.Options
import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.{SaveMode, SparkSession}

/**
  * Spark job that computes an age label per device: it joins merged
  * device/package records against a package → age/tag dictionary, keeps
  * the devices that received a non-blank age, and writes the result as ORC.
  *
  * @author wangjf
  */
class CalcDeviceAge extends CommonSparkJob with Serializable {

  override protected def run(args: Array[String]): Int = {

    val commandLine = commParser.parse(options, args)
    if (!checkMustOption(commandLine)) {
      printUsage(options)
      return -1
    } else {
      printOptions(commandLine)
    }

    val date = commandLine.getOptionValue("date")
    val merge_input = commandLine.getOptionValue("merge_input")
    val dict_input = commandLine.getOptionValue("dict_input")
    val output = commandLine.getOptionValue("output")
    val parallelism = commandLine.getOptionValue("parallelism")

    // Kryo serialization with a project-specific registrator, ORC filter
    // pushdown, and Hive support
    val spark = SparkSession.builder()
      .appName("CalcDeviceAge")
      .config("spark.rdd.compress", "true")
      .config("spark.io.compression.codec", "snappy")
      .config("spark.sql.orc.filterPushdown", "true")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.kryo.registrationRequired", "false")
      .config("spark.kryo.registrator", "mobvista.dmp.datasource.age_gender.MyRegisterKryo")
      .enableHiveSupport()
      .getOrCreate()

    val sc = spark.sparkContext

    // remove any previous output so repeated runs start from a clean path
    FileSystem.get(new URI("s3://mob-emr-test"), sc.hadoopConfiguration).delete(new Path(output), true)

    try {

      // dictionary rows are tab-separated: package name, age column, tag column;
      // substring(1) drops the first character of the age column before the age
      // and tag are joined into the dictionary value
      val packageMap = sc.textFile(dict_input).map(_.split("\t"))
        .map(r => (r(0), MRUtils.JOINER.join(r(1).substring(1), r(2))))
        .collectAsMap()

      // broadcast the dictionary so every executor holds a single read-only copy
      val bPackageMap = sc.broadcast(packageMap)

      // the date argument is yyyyMMdd, while update_date in the data is yyyy-MM-dd
      val update_date = DateUtil.format(DateUtil.parse(date, "yyyyMMdd"), "yyyy-MM-dd")
      // read the merged records for the target day, run the age-calculation
      // logic per partition, and keep only devices that received an age
      val df = spark.read.schema(Constant.merge_schema).orc(merge_input)
        .where(s"update_date = '$update_date'")
        .rdd
        .mapPartitions(Logic.calcDeviceAgeLogicJson(_, bPackageMap))
        .filter(d => StringUtils.isNotBlank(d.age))

      /*
      // legacy output path: gzip-compressed text; re-enabling it would require
      // importing org.apache.hadoop.io.compress.GzipCodec
      df.repartition(parallelism.toInt)
        .map(r => {
          MRUtils.JOINER.join(r.device_id, r.device_type, r.package_names, r.age, r.tag)
        }).saveAsTextFile(output, classOf[GzipCodec])
      */
      import spark.implicits._
      // repartition(n) is coalesce(n, shuffle = true); write the result as
      // zlib-compressed ORC
      df.repartition(parallelism.toInt)
        .toDF
        .write.mode(SaveMode.Overwrite)
        .option("orc.compress", "zlib")
        .orc(output)
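
      // optional sanity check (a sketch, not part of the original job):
      //   println(s"wrote ${spark.read.orc(output).count()} rows to $output")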

    } finally {
      // spark.stop() also shuts down the underlying SparkContext
      spark.stop()
    }
    0
  }

  override protected def buildOptions(): Options = {
    val options = new Options
    options.addOption("date", true, "[must] target date, yyyyMMdd")
    options.addOption("merge_input", true, "[must] merged device/package data path (ORC)")
    options.addOption("dict_input", true, "[must] package age/tag dictionary path (tab-separated text)")
    options.addOption("output", true, "[must] output path (ORC)")
    options.addOption("parallelism", true, "[must] number of output partitions")
    options
  }
}

object CalcDeviceAge {
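  /**
    * A hypothetical spark-submit invocation (jar name and S3 paths are
    * placeholders, not taken from this repository):
    *
    * {{{
    * spark-submit --class mobvista.dmp.datasource.age_gender.CalcDeviceAge \
    *   dmp.jar \
    *   -date 20240101 \
    *   -merge_input s3://mob-emr-test/.../merge \
    *   -dict_input s3://mob-emr-test/.../age_dict \
    *   -output s3://mob-emr-test/.../device_age \
    *   -parallelism 100
    * }}}
    */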
  def main(args: Array[String]): Unit = {
    new CalcDeviceAge().run(args)
  }
}