// CustomConnectionFactory.scala 7.04 KB
package mobvista.dmp.utils.cassandra

import com.datastax.dse.driver.api.core.config.DseDriverOption
import com.datastax.oss.driver.api.core.{CqlSession, ProtocolVersion}
import com.datastax.oss.driver.api.core.config.DefaultDriverOption._
import com.datastax.oss.driver.api.core.config.{DriverConfigLoader, ProgrammaticDriverConfigLoaderBuilder => PDCLB}
import com.datastax.oss.driver.internal.core.connection.ExponentialReconnectionPolicy
import com.datastax.oss.driver.internal.core.ssl.DefaultSslEngineFactory
import com.datastax.spark.connector.cql.DefaultConnectionFactory.maybeGetLocalFile
import com.datastax.spark.connector.cql._
import mobvista.dmp.util.{MobvistaAddressTranslator, PropertyUtil}
import org.apache.commons.lang3.StringUtils
import org.apache.spark.SparkEnv
import org.slf4j.LoggerFactory

import java.time.Duration
import scala.collection.JavaConverters._

/**
 * Base [[CassandraConnectionFactory]] that assembles the driver configuration programmatically
 * and opens the [[CqlSession]] handed to the Spark Cassandra connector.
 *
 * Concrete subclasses supply [[getRegion]] and [[getSystem]]; [[setProperties]] is available to
 * publish values as JVM system properties (this class reads `SYSTEM_NAME` and `REGION_NAME`
 * when deciding whether to install the address translator).
 *
 * @package: mobvista.dmp.utils.cassandra
 * @author: wangjf
 * @date: 2019-09-06
 * @time: 17:46
 * @email: jinfeng.wang@mobvista.com
 */
abstract class CustomConnectionFactory extends CassandraConnectionFactory {

  private val LOG = LoggerFactory.getLogger(classOf[CustomConnectionFactory])

  /**
   * Applies every connector-derived driver option to `initBuilder`.
   *
   * Basic, compression and local-DC options are always applied; contact points, the load
   * balancing policy and SSL options are applied only for [[IpBasedContactInfo]].
   *
   * @param conf        connector configuration the option values are taken from
   * @param initBuilder builder to start from
   * @return the builder carrying all applied options
   */
  def connectorConfigBuilder(conf: CassandraConnectorConf, initBuilder: PDCLB): PDCLB = {

    // Pool sizes, timeouts, retry/reconnection policies, Netty shutdown settings and,
    // when a host map is configured, the address translator.
    def basicProperties(builder: PDCLB): PDCLB =
      try {
        // Keep the builder returned by the chain instead of relying on in-place mutation.
        val configured: PDCLB = builder
          // 16 / 12 are this factory's fallbacks when the connector conf leaves the sizes unset.
          .withInt(CONNECTION_POOL_LOCAL_SIZE, conf.localConnectionsPerExecutor.getOrElse(16))
          .withInt(CONNECTION_POOL_REMOTE_SIZE, conf.remoteConnectionsPerExecutor.getOrElse(12))
          // NOTE(review): timeouts are passed as plain ints; presumably read as milliseconds
          // by the driver config - confirm against the driver reference configuration.
          .withInt(CONNECTION_INIT_QUERY_TIMEOUT, conf.connectTimeoutMillis)
          .withDuration(CONTROL_CONNECTION_TIMEOUT, Duration.ofMillis(12000))
          .withInt(REQUEST_TIMEOUT, conf.readTimeoutMillis)
          .withClass(RETRY_POLICY_CLASS, classOf[MultipleRetryPolicy])
          .withClass(RECONNECTION_POLICY_CLASS, classOf[ExponentialReconnectionPolicy])
          .withString(PROTOCOL_VERSION, ProtocolVersion.V3.toString)
          .withDuration(RECONNECTION_BASE_DELAY, Duration.ofMillis(conf.minReconnectionDelayMillis))
          .withDuration(RECONNECTION_MAX_DELAY, Duration.ofMillis(conf.maxReconnectionDelayMillis))
          // Millisecond settings are truncated to whole seconds by the integer division.
          .withInt(NETTY_ADMIN_SHUTDOWN_QUIET_PERIOD, conf.quietPeriodBeforeCloseMillis / 1000)
          .withInt(NETTY_ADMIN_SHUTDOWN_TIMEOUT, conf.timeoutBeforeCloseMillis / 1000)
          .withInt(NETTY_IO_SHUTDOWN_QUIET_PERIOD, conf.quietPeriodBeforeCloseMillis / 1000)
          .withInt(NETTY_IO_SHUTDOWN_TIMEOUT, conf.timeoutBeforeCloseMillis / 1000)
          .withBoolean(NETTY_DAEMON, true)
          .withBoolean(RESOLVE_CONTACT_POINTS, conf.resolveContactPoints)
          .withInt(MultipleRetryPolicy.MaxRetryCount, conf.queryRetryCount)
          .withDuration(DseDriverOption.CONTINUOUS_PAGING_TIMEOUT_FIRST_PAGE, Duration.ofMillis(conf.readTimeoutMillis))
          .withDuration(DseDriverOption.CONTINUOUS_PAGING_TIMEOUT_OTHER_PAGES, Duration.ofMillis(conf.readTimeoutMillis))

        val systemName = System.getProperty("SYSTEM_NAME")
        val regionName = System.getProperty("REGION_NAME")

        LOG.info(s"SYSTEM_NAME: ${systemName}, REGION_NAME: ${regionName}")

        // Unset properties render as "null" in the key, exactly as before.
        val ipStr = PropertyUtil.getProperty("ip.properties", s"${systemName}.${regionName}.host_map")

        // Install the address translator only when a host map exists for this system/region.
        if (StringUtils.isNotBlank(ipStr)) {
          configured.withClass(ADDRESS_TRANSLATOR_CLASS, classOf[MobvistaAddressTranslator])
        } else {
          configured
        }
      } catch {
        case e: ClassNotFoundException =>
          // Best effort: keep going with the builder we were given, but make the failure visible.
          LOG.warn("Exception: ", e)
          builder
      }

    // Sets the protocol compression unless it is absent or "none" (case-insensitive).
    def compressionProperties(b: PDCLB): PDCLB =
      Option(conf.compression)
        .filter(_.toLowerCase != "none")
        .fold(b)(c => b.withString(PROTOCOL_COMPRESSION, c.toLowerCase))

    // Sets the local datacenter for load balancing when one is configured.
    def localDCProperty(b: PDCLB): PDCLB =
      conf.localDC.map(b.withString(LOAD_BALANCING_LOCAL_DATACENTER, _)).getOrElse(b)

    // Contact points and load balancing policy, plus SSL properties if SSL is enabled.
    def ipBasedConnectionProperties(ipConf: IpBasedContactInfo): PDCLB => PDCLB = (builder: PDCLB) => {
      val sslConf = ipConf.cassandraSSLConf

      val withContactPoints: PDCLB = builder
        .withStringList(CONTACT_POINTS, ipConf.hosts.map(h => s"${h.getHostString}:${h.getPort}").toList.asJava)
        .withClass(LOAD_BALANCING_POLICY_CLASS, classOf[LocalNodeFirstLoadBalancingPolicy])

      // Key store settings are forwarded only when client authentication is enabled.
      def clientAuthEnabled(value: Option[String]) =
        if (sslConf.clientAuthEnabled) value else None

      if (sslConf.enabled) {
        Seq(
          SSL_TRUSTSTORE_PATH -> sslConf.trustStorePath,
          SSL_TRUSTSTORE_PASSWORD -> sslConf.trustStorePassword,
          SSL_KEYSTORE_PATH -> clientAuthEnabled(sslConf.keyStorePath),
          SSL_KEYSTORE_PASSWORD -> clientAuthEnabled(sslConf.keyStorePassword))
          // Only options that actually carry a value are written to the builder.
          .foldLeft(withContactPoints) { case (b, (name, value)) =>
            value.map(b.withString(name, _)).getOrElse(b)
          }
          .withClass(SSL_ENGINE_FACTORY_CLASS, classOf[DefaultSslEngineFactory])
          .withStringList(SSL_CIPHER_SUITES, sslConf.enabledAlgorithms.toList.asJava)
          .withBoolean(SSL_HOSTNAME_VALIDATION, false)
      } else {
        withContactPoints
      }
    }

    val universalProperties: Seq[PDCLB => PDCLB] =
      Seq(basicProperties, compressionProperties, localDCProperty)

    val appliedProperties: Seq[PDCLB => PDCLB] = conf.contactInfo match {
      case ipConf: IpBasedContactInfo => universalProperties :+ ipBasedConnectionProperties(ipConf)
      case _ => universalProperties
    }

    // Apply each property group in order, feeding the result of one into the next.
    appliedProperties.foldLeft(initBuilder) { case (builder, properties) => properties(builder) }
  }

  /**
   * Builds a [[CqlSession]] for the given connector configuration.
   *
   * IP-based and cloud-based contact info use the programmatic configuration from
   * [[connectorConfigBuilder]]; profile-file-based contact info loads the driver configuration
   * from the referenced file instead.
   *
   * @param conf connector configuration describing how to reach the cluster
   * @return a newly built session
   */
  override def createSession(conf: CassandraConnectorConf): CqlSession = {

    val configLoaderBuilder = DriverConfigLoader.programmaticBuilder()

    // Built up front for every contact-info kind, as before (it logs and reads properties).
    val configLoader = connectorConfigBuilder(conf, configLoaderBuilder).build()

    val initialBuilder = CqlSession.builder()

    val builderWithContactInfo = conf.contactInfo match {
      case ipConf: IpBasedContactInfo =>
        ipConf.authConf.authProvider.fold(initialBuilder)(initialBuilder.withAuthProvider)
          .withConfigLoader(configLoader)
      case CloudBasedContactInfo(path, authConf) =>
        authConf.authProvider.fold(initialBuilder)(initialBuilder.withAuthProvider)
          .withCloudSecureConnectBundle(maybeGetLocalFile(path))
          .withConfigLoader(configLoader)
      case ProfileFileBasedContactInfo(path) =>
        initialBuilder.withConfigLoader(DriverConfigLoader.fromUrl(maybeGetLocalFile(path)))
    }

    // SparkEnv.get may be null (no active Spark environment); fall back to a fixed label.
    val appName = Option(SparkEnv.get).map(env => env.conf.getAppId).getOrElse("NoAppID")
    builderWithContactInfo
      .withApplicationName(s"Spark-Cassandra-Connector-$appName")
      .withSchemaChangeListener(new MultiplexingSchemaListener())
      .build()
  }

  /**
   * Sets a JVM system property, logging instead of propagating any failure.
   *
   * @param key   system property name
   * @param value system property value
   */
  def setProperties(key: String, value: String): Unit = {
    try {
      System.setProperty(key, value)
    } catch {
      case e: Exception =>
        // Best effort by design: report the failure but do not abort the caller.
        LOG.warn(s"Failed to set system property '$key'", e)
    }
  }

  /** Region identifier supplied by the concrete factory. */
  // NOTE(review): presumably published as the REGION_NAME system property - confirm in subclasses.
  protected def getRegion(): String

  /** System identifier supplied by the concrete factory. */
  // NOTE(review): presumably published as the SYSTEM_NAME system property - confirm in subclasses.
  protected def getSystem(): String
}