// CustomConnectionFactory.scala (7.04 KB) -- committed by wang-jinfeng
package mobvista.dmp.utils.cassandra

import com.datastax.dse.driver.api.core.config.DseDriverOption
import com.datastax.oss.driver.api.core.{CqlSession, ProtocolVersion}
import com.datastax.oss.driver.api.core.config.DefaultDriverOption._
import com.datastax.oss.driver.api.core.config.{DriverConfigLoader, ProgrammaticDriverConfigLoaderBuilder => PDCLB}
import com.datastax.oss.driver.internal.core.connection.ExponentialReconnectionPolicy
import com.datastax.oss.driver.internal.core.ssl.DefaultSslEngineFactory
import com.datastax.spark.connector.cql.DefaultConnectionFactory.maybeGetLocalFile
import com.datastax.spark.connector.cql._
import mobvista.dmp.util.{MobvistaAddressTranslator, PropertyUtil}
import org.apache.commons.lang3.StringUtils
import org.apache.spark.SparkEnv
import org.slf4j.LoggerFactory

import java.time.Duration
import scala.collection.JavaConverters._

/**
 * Base [[CassandraConnectionFactory]] that assembles the Java driver configuration
 * programmatically from a [[CassandraConnectorConf]] and opens the [[CqlSession]].
 *
 * In addition to the connector settings it optionally installs
 * [[MobvistaAddressTranslator]] when a host map is configured in `ip.properties`
 * for the `SYSTEM_NAME` / `REGION_NAME` JVM system properties.
 *
 * Package: mobvista.dmp.utils.cassandra
 * Created: 2019-09-06 17:46
 *
 * @author wangjf (jinfeng.wang@mobvista.com)
 */
abstract class CustomConnectionFactory extends CassandraConnectionFactory {

  private val LOG = LoggerFactory.getLogger(classOf[CustomConnectionFactory])

  /**
   * Applies all connector-derived driver options to `initBuilder`.
   *
   * Basic, compression and local-DC options are always applied; contact points,
   * load-balancing policy and SSL options are applied only when
   * `conf.contactInfo` is an [[IpBasedContactInfo]].
   *
   * @param conf        connector configuration the option values are read from
   * @param initBuilder programmatic config-loader builder to start from
   * @return the builder with every applicable option set
   */
  def connectorConfigBuilder(conf: CassandraConnectorConf, initBuilder: PDCLB): PDCLB = {

    // Pooling, timeouts, retry/reconnection policies, protocol version and Netty shutdown
    // settings, plus the optional address translator.
    def basicProperties(builder: PDCLB): PDCLB = {
      // Pool sizes fall back to the fixed values 16 (local) and 12 (remote); a CPU-count
      // based default for the local pool was previously present here but commented out.
      // The chain result is kept instead of relying on the builder mutating itself.
      val configured = builder
        .withInt(CONNECTION_POOL_LOCAL_SIZE, conf.localConnectionsPerExecutor.getOrElse(16))
        .withInt(CONNECTION_POOL_REMOTE_SIZE, conf.remoteConnectionsPerExecutor.getOrElse(12))
        .withInt(CONNECTION_INIT_QUERY_TIMEOUT, conf.connectTimeoutMillis)
        .withDuration(CONTROL_CONNECTION_TIMEOUT, Duration.ofMillis(12000))
        .withInt(REQUEST_TIMEOUT, conf.readTimeoutMillis)
        .withClass(RETRY_POLICY_CLASS, classOf[MultipleRetryPolicy])
        .withClass(RECONNECTION_POLICY_CLASS, classOf[ExponentialReconnectionPolicy])
        .withString(PROTOCOL_VERSION, ProtocolVersion.V3.toString)
        .withDuration(RECONNECTION_BASE_DELAY, Duration.ofMillis(conf.minReconnectionDelayMillis))
        .withDuration(RECONNECTION_MAX_DELAY, Duration.ofMillis(conf.maxReconnectionDelayMillis))
        // The connector values are in milliseconds; they are divided by 1000 before being set.
        .withInt(NETTY_ADMIN_SHUTDOWN_QUIET_PERIOD, conf.quietPeriodBeforeCloseMillis / 1000)
        .withInt(NETTY_ADMIN_SHUTDOWN_TIMEOUT, conf.timeoutBeforeCloseMillis / 1000)
        .withInt(NETTY_IO_SHUTDOWN_QUIET_PERIOD, conf.quietPeriodBeforeCloseMillis / 1000)
        .withInt(NETTY_IO_SHUTDOWN_TIMEOUT, conf.timeoutBeforeCloseMillis / 1000)
        .withBoolean(NETTY_DAEMON, true)
        .withBoolean(RESOLVE_CONTACT_POINTS, conf.resolveContactPoints)
        .withInt(MultipleRetryPolicy.MaxRetryCount, conf.queryRetryCount)
        .withDuration(DseDriverOption.CONTINUOUS_PAGING_TIMEOUT_FIRST_PAGE, Duration.ofMillis(conf.readTimeoutMillis))
        .withDuration(DseDriverOption.CONTINUOUS_PAGING_TIMEOUT_OTHER_PAGES, Duration.ofMillis(conf.readTimeoutMillis))

      // System.getProperty returns null for an unset property; wrap it so that a
      // "null.null.host_map" key is never built.
      val systemName = Option(System.getProperty("SYSTEM_NAME"))
      val regionName = Option(System.getProperty("REGION_NAME"))

      LOG.info(s"SYSTEM_NAME: ${systemName.orNull}, REGION_NAME: ${regionName.orNull}")

      val hostMapKey = for {
        system <- systemName
        region <- regionName
      } yield s"$system.$region.host_map"

      hostMapKey match {
        case Some(key) =>
          try {
            val ipStr = PropertyUtil.getProperty("ip.properties", key)
            // The translator is installed only when a non-blank host map exists for this key.
            if (StringUtils.isNotBlank(ipStr)) {
              configured.withClass(ADDRESS_TRANSLATOR_CLASS, classOf[MobvistaAddressTranslator])
            } else {
              configured
            }
          } catch {
            // Best effort: the session is still configured, just without address translation.
            case e: ClassNotFoundException =>
              LOG.warn("Exception: ", e)
              configured
          }
        case None =>
          LOG.warn("SYSTEM_NAME and/or REGION_NAME system property is not set; skipping address translator lookup")
          configured
      }
    }

    // Sets the protocol compression unless it is absent or "none" (case-insensitive).
    def compressionProperties(b: PDCLB): PDCLB =
      Option(conf.compression)
        .filter(_.toLowerCase != "none")
        .fold(b)(c => b.withString(PROTOCOL_COMPRESSION, c.toLowerCase))

    // Sets the load-balancing local datacenter when one is configured.
    def localDCProperty(b: PDCLB): PDCLB =
      conf.localDC.fold(b)(dc => b.withString(LOAD_BALANCING_LOCAL_DATACENTER, dc))

    // Contact points and load-balancing policy; SSL properties are added only if SSL is enabled.
    def ipBasedConnectionProperties(ipConf: IpBasedContactInfo): PDCLB => PDCLB = (builder: PDCLB) => {
      val withContactPoints = builder
        .withStringList(CONTACT_POINTS, ipConf.hosts.map(h => s"${h.getHostString}:${h.getPort}").toList.asJava)
        .withClass(LOAD_BALANCING_POLICY_CLASS, classOf[LocalNodeFirstLoadBalancingPolicy])

      val sslConf = ipConf.cassandraSSLConf

      // Key store settings are only forwarded when client authentication is enabled.
      def clientAuthEnabled(value: Option[String]) =
        if (sslConf.clientAuthEnabled) value else None

      if (sslConf.enabled) {
        Seq(
          SSL_TRUSTSTORE_PATH -> sslConf.trustStorePath,
          SSL_TRUSTSTORE_PASSWORD -> sslConf.trustStorePassword,
          SSL_KEYSTORE_PATH -> clientAuthEnabled(sslConf.keyStorePath),
          SSL_KEYSTORE_PASSWORD -> clientAuthEnabled(sslConf.keyStorePassword))
          // Only options that actually have a value are written to the builder.
          .foldLeft(withContactPoints) { case (b, (name, value)) =>
            value.map(b.withString(name, _)).getOrElse(b)
          }
          .withClass(SSL_ENGINE_FACTORY_CLASS, classOf[DefaultSslEngineFactory])
          .withStringList(SSL_CIPHER_SUITES, sslConf.enabledAlgorithms.toList.asJava)
          .withBoolean(SSL_HOSTNAME_VALIDATION, false)
      } else {
        withContactPoints
      }
    }

    val universalProperties: Seq[PDCLB => PDCLB] =
      Seq(basicProperties, compressionProperties, localDCProperty)

    val appliedProperties: Seq[PDCLB => PDCLB] = conf.contactInfo match {
      case ipConf: IpBasedContactInfo => universalProperties :+ ipBasedConnectionProperties(ipConf)
      case _ => universalProperties
    }

    // Thread the builder through every property function in order.
    appliedProperties.foldLeft(initBuilder) { case (builder, properties) => properties(builder) }
  }

  /**
   * Builds a [[CqlSession]] for the given connector configuration.
   *
   * IP-based and cloud-based contact info use the programmatic configuration produced by
   * [[connectorConfigBuilder]]; profile-file based contact info loads the driver
   * configuration from the file's URL instead, so the programmatic options are not applied
   * in that case.
   *
   * @param conf connector configuration
   * @return a newly built session
   */
  override def createSession(conf: CassandraConnectorConf): CqlSession = {

    val configLoaderBuilder = DriverConfigLoader.programmaticBuilder()

    val configLoader = connectorConfigBuilder(conf, configLoaderBuilder).build()

    val initialBuilder = CqlSession.builder()

    val builderWithContactInfo = conf.contactInfo match {
      case ipConf: IpBasedContactInfo =>
        // The auth provider is optional; without one the plain builder is used.
        ipConf.authConf.authProvider.fold(initialBuilder)(initialBuilder.withAuthProvider)
          .withConfigLoader(configLoader)
      case CloudBasedContactInfo(path, authConf) =>
        authConf.authProvider.fold(initialBuilder)(initialBuilder.withAuthProvider)
          .withCloudSecureConnectBundle(maybeGetLocalFile(path))
          .withConfigLoader(configLoader)
      case ProfileFileBasedContactInfo(path) =>
        initialBuilder.withConfigLoader(DriverConfigLoader.fromUrl(maybeGetLocalFile(path)))
    }

    // SparkEnv.get may be null; fall back to a placeholder application id.
    val appName = Option(SparkEnv.get).map(env => env.conf.getAppId).getOrElse("NoAppID")
    builderWithContactInfo
      .withApplicationName(s"Spark-Cassandra-Connector-$appName")
      .withSchemaChangeListener(new MultiplexingSchemaListener())
      .build()
  }

  /**
   * Sets a JVM system property, logging instead of throwing if that fails.
   *
   * @param key   system property name
   * @param value system property value
   */
  def setProperties(key: String, value: String): Unit = {
    try {
      System.setProperty(key, value)
    } catch {
      // Deliberately best-effort: a failure here is reported but not propagated.
      case e: Exception =>
        LOG.warn(s"Failed to set system property '$key'", e)
    }
  }

  /**
   * Region identifier supplied by the concrete factory.
   * NOTE(review): not called within this class; presumably it corresponds to the
   * `REGION_NAME` system property read in `connectorConfigBuilder` -- confirm against subclasses.
   */
  protected def getRegion(): String

  /**
   * System identifier supplied by the concrete factory.
   * NOTE(review): not called within this class; presumably it corresponds to the
   * `SYSTEM_NAME` system property read in `connectorConfigBuilder` -- confirm against subclasses.
   */
  protected def getSystem(): String
}