package mobvista.dmp.datasource.dm

import java.net.URI

import mobvista.dmp.common.CommonSparkJob
import mobvista.dmp.datasource.dm.entity.{ParentCondition, SqlCondition, TagCondition}
import mobvista.prd.datasource.util.GsonUtil
import org.apache.commons.cli.{BasicParser, Options}
import org.apache.commons.lang.StringUtils
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import org.apache.spark.storage.StorageLevel

import scala.collection.mutable

/**
  * Spark job that resolves a (possibly nested) tag-condition JSON document into a
  * set of device ids and writes the result as ORC under `output/<id>`.
  *
  * The job first materialises the "active_tag" temp view from
  * `Constant.query_active_sql`, then evaluates the tag-condition tree: every leaf
  * becomes a `SELECT device_id FROM active_tag WHERE tag_code IN (...)` query and
  * every group combines its children by intersection (`logic == true`) or
  * de-duplicated union (`logic == false`).
  *
  * @package: mobvista.dmp.datasource.dm
  * @author: wangjf
  * @date: 2019/3/20
  * @time: 11:26 AM
  * @email: jinfeng.wang@mobvista.com
  * @phone: 152-1062-7698
  */
class TagQuery extends CommonSparkJob with Serializable {

  /**
    * Declares the command-line options understood by this job.
    *
    * @return options `json` (condition document), `output` (base output path)
    *         and `coalesce` (number of output partitions); each takes a value
    */
  def commandOptions(): Options = {
    val options = new Options()
    options.addOption("json", true, "json")
    options.addOption("output", true, "output")
    options.addOption("coalesce", true, "coalesce")
    options
  }

  /** Evaluation stack for [[parseJson]]: each evaluated group leaves exactly one DataFrame on top. */
  val stack = new mutable.Stack[DataFrame]

  /** Common prefix of every leaf query issued against the "active_tag" temp view. */
  val common_sql = "SELECT device_id FROM active_tag WHERE"

  /**
    * Parses the arguments, builds the "active_tag" view, evaluates the tag
    * conditions and writes the resulting device ids as ORC.
    *
    * @param args command-line arguments, see [[commandOptions]]
    * @return 0 on success
    * @throws IllegalArgumentException if a required option is missing or
    *                                  `coalesce` is not a positive integer
    */
  override protected def run(args: Array[String]): Int = {
    val parser = new BasicParser()
    val options = commandOptions()
    val commandLine = parser.parse(options, args)
    val json = commandLine.getOptionValue("json")
    val output = commandLine.getOptionValue("output")
    val coalesce = commandLine.getOptionValue("coalesce")

    // Validate everything up front: the job deletes the previous output below, so
    // a bad argument must be rejected before any destructive side effect happens.
    require(StringUtils.isNotBlank(json), "option 'json' is required")
    require(StringUtils.isNotBlank(output), "option 'output' is required")
    require(StringUtils.isNotBlank(coalesce), "option 'coalesce' is required")
    val partitions = coalesce.trim.toInt
    require(partitions > 0, s"option 'coalesce' must be a positive integer, got $partitions")

    val spark = SparkSession
      .builder()
      .appName("TagQuery")
      .config("spark.rdd.compress", "true")
      .config("spark.sql.orc.filterPushdown", "true")
      .config("spark.io.compression.codec", "snappy")
      .config("spark.sql.warehouse.dir", "s3://mob-emr-test/spark-warehouse")
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .enableHiveSupport()
      .getOrCreate()
    try {
      // A document ending in "]" is completed with the two closing braces it lacks.
      // NOTE(review): presumably the caller truncates the trailing "}}" - confirm.
      // The check runs on the trimmed text so trailing whitespace cannot skip the repair.
      val trimmedJson = json.trim
      val jsonString = if (trimmedJson.endsWith("]")) trimmedJson + "}}" else trimmedJson
      val parentCondition: ParentCondition = parseParent(jsonString)

      val id = parentCondition.id
      val target = output + "/" + id
      // NOTE(review): the file system is resolved from a hard-coded bucket, not from `output`.
      FileSystem.get(new URI("s3://mob-emr-test"), spark.sparkContext.hadoopConfiguration)
        .delete(new Path(target), true)

      // Platform flag substituted into the query: "0" for ios, "1" for anything else.
      val platform = if (parentCondition.platform.equals("ios")) "0" else "1"

      val baseSql = Constant.query_active_sql.replace("@dt", parentCondition.date)
        .replace("@platform", platform)
        .replace("@cnt", parentCondition.days)

      // NOTE(review): the values below are spliced into the SQL text unescaped; a
      // quote inside the JSON input changes the statement. Validate upstream.
      val deviceTypeFilter =
        if (StringUtils.isNotBlank(parentCondition.device_type)) {
          s" AND device_type = '${parentCondition.device_type}'"
        } else {
          ""
        }
      val countryFilter =
        if (StringUtils.isNotBlank(parentCondition.country_code)) {
          s" AND country_code = '${parentCondition.country_code}'"
        } else {
          ""
        }
      val sql = baseSql + deviceTypeFilter + countryFilter

      // Persisted because every leaf query of the condition tree reads this view.
      val df = spark.sql(sql).persist(StorageLevel.MEMORY_AND_DISK_SER)
      df.createOrReplaceTempView("active_tag")
      val tagCondition: TagCondition = parentCondition.tagCondition
      parseJson(tagCondition, spark)

      stack.pop() // pop the final result DataFrame left by parseJson
        .repartition(partitions)
        .write
        .mode(SaveMode.Overwrite)
        .option("orc.compress", "zlib")
        .orc(target)

    } finally {
      spark.stop()
    }
    0
  }

  /** Deserialises the top-level condition document into a [[ParentCondition]]. */
  def parseParent(jsonString: String): ParentCondition = {
    GsonUtil.fromJson(GsonUtil.String2JsonObject(jsonString), classOf[ParentCondition])
  }

  /** Deserialises a nested group (an object carrying a "logic" member) into a [[TagCondition]]. */
  def parseTag(jsonString: String): TagCondition = {
    GsonUtil.fromJson(GsonUtil.String2JsonObject(jsonString), classOf[TagCondition])
  }

  /** Deserialises a leaf condition into a [[SqlCondition]]. */
  def parseSql(jsonString: String): SqlCondition = {
    GsonUtil.fromJson(GsonUtil.String2JsonObject(jsonString), classOf[SqlCondition])
  }

  /**
    * Evaluates one condition group and pushes exactly one result DataFrame onto
    * [[stack]].
    *
    * Every entry of `tag.tags` is either a nested group (it has a "logic" member
    * and is evaluated recursively) or a leaf that is turned into a query via
    * [[sql_con]]. All DataFrames produced by this group - and only those - are
    * then combined: intersection when `tag.logic` is true, de-duplicated union
    * otherwise. A group with a single entry yields that entry's DataFrame as is.
    *
    * @param tag   the condition group to evaluate
    * @param spark session used to run the leaf queries
    * @throws IllegalArgumentException if the group has no entries
    */
  def parseJson(tag: TagCondition, spark: SparkSession): Unit = {
    val logic = tag.logic
    val tags = tag.tags
    require(tags != null, "tag condition group has no 'tags' member")

    // Frames below this depth belong to enclosing groups and must not be touched.
    val baseDepth = stack.size
    for (i <- tags.indices) {
      println(tags(i))
      val rawTag = tags(i).toString
      if (GsonUtil.String2JsonObject(rawTag).has("logic")) {
        // Nested group: the recursive call leaves its single result on the stack.
        parseJson(parseTag(rawTag), spark)
      } else {
        val sqlCondition = parseSql(rawTag)
        stack.push(spark.sql(sql_con(sqlCondition.tag_code)))
      }
    }

    val pushed = stack.size - baseDepth
    if (pushed <= 0) {
      throw new IllegalArgumentException("tag condition group contains no tags")
    }
    // Pop only this group's frames; reverse restores the order of `tags`.
    val operands = List.fill(pushed)(stack.pop()).reverse
    val resultDF =
      if (operands.size == 1) {
        operands.head
      } else if (logic) { // true: intersection, false: union
        operands.reduceLeft((left: DataFrame, right: DataFrame) => left.intersect(right))
      } else {
        operands.reduceLeft((left: DataFrame, right: DataFrame) => left.union(right)).dropDuplicates()
      }
    // Push the combined result for the enclosing group (or for run) to consume.
    stack.push(resultDF)
  }

  /**
    * Builds the leaf query selecting the device ids that carry any of the given tag codes.
    *
    * NOTE(review): codes are quoted but not escaped; they must not contain single quotes.
    *
    * @param tag_codes tag codes to match; must contain at least one element
    * @return `common_sql` followed by ` tag_code IN ('c1','c2',...)`
    * @throws IllegalArgumentException if `tag_codes` is null or empty
    */
  def sql_con(tag_codes: Array[String]): String = {
    require(tag_codes != null && tag_codes.nonEmpty, "tag_code must contain at least one code")
    tag_codes
      .map(code => "'" + code + "'")
      .mkString(common_sql + " tag_code IN (", ",", ")")
  }
}

object TagQuery {

  /**
    * Command-line entry point: creates a [[TagQuery]] job and runs it with the
    * given arguments. The status code returned by `run` is discarded.
    *
    * @param args command-line arguments forwarded unchanged to the job
    */
  def main(args: Array[String]): Unit = {
    val job = new TagQuery()
    job.run(args)
  }
}