package mobvista.dmp.util;

import com.datastax.oss.driver.api.core.addresstranslation.AddressTranslator;
import com.datastax.oss.driver.api.core.context.DriverContext;
import com.datastax.oss.driver.internal.core.util.Loggers;
import edu.umd.cs.findbugs.annotations.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.naming.NamingEnumeration;
import javax.naming.NamingException;
import javax.naming.directory.Attribute;
import javax.naming.directory.Attributes;
import javax.naming.directory.DirContext;
import javax.naming.directory.InitialDirContext;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;

/**
 * Cassandra driver {@link AddressTranslator} that rewrites node addresses using a static
 * private-to-public host map.
 *
 * <p>The map is obtained via {@code PropertyUtil.getProperty("ip.properties", key)} where
 * {@code key} is {@code <SYSTEM_NAME>.<REGION_NAME>.host_map}, and {@code SYSTEM_NAME} /
 * {@code REGION_NAME} are JVM system properties. The value is a comma-separated list of
 * {@code privateAddress:publicAddress} pairs, e.g. {@code 10.0.0.1:52.1.2.3,10.0.0.2:52.1.2.4}.
 * Because entries are split on {@code ':'}, IPv6 addresses cannot be expressed in this format.
 *
 * <p>An address whose IP is not in the map (or an unresolved address) is returned unchanged.
 * The parsed map is immutable after construction, so {@link #translate} may be called from
 * several driver threads concurrently.
 *
 * @package: mobvista.dmp.util
 * @author: wangjf
 * @date: 2019-09-06
 * @time: 18:31
 * @email: jinfeng.wang@mobvista.com
 * @phone: 152-1062-7698
 */
public class MobvistaAddressTranslator implements AddressTranslator {

    private static final Logger LOG = LoggerFactory.getLogger(MobvistaAddressTranslator.class);

    private static final String SYSTEM_NAME = System.getProperty("SYSTEM_NAME");

    private static final String REGION_NAME = System.getProperty("REGION_NAME");

    /** Key looked up in {@code ip.properties}; also used in log messages. */
    private static final String HOST_MAP_KEY = SYSTEM_NAME + "." + REGION_NAME + ".host_map";

    /**
     * Raw {@code private:public,...} host map. Treated as possibly {@code null}/blank because
     * the behavior of {@code PropertyUtil} for a missing key is not visible here.
     */
    private static final String HOST_MAP = PropertyUtil.getProperty("ip.properties", HOST_MAP_KEY);

    /** Immutable mapping from private IP string to public address (IP or hostname). */
    private final Map<String, String> privateToPublic;

    /**
     * Creates the translator. The driver instantiates it reflectively with its context.
     *
     * @param context the driver context; required by the driver's reflective constructor
     *     contract but not otherwise used
     */
    public MobvistaAddressTranslator(
            @SuppressWarnings("unused") @NonNull DriverContext context) {
        this.privateToPublic = parseHostMap(HOST_MAP);
        LOG.info("Loaded {} address mapping(s) from '{}'", privateToPublic.size(), HOST_MAP_KEY);
    }

    /**
     * Parses a {@code private:public,private:public} string into an unmodifiable map.
     * Blank entries (e.g. from a trailing comma) are ignored; malformed entries are logged and
     * skipped instead of failing session creation. Later duplicates override earlier ones.
     *
     * @param hostMap raw property value, may be {@code null}
     * @return unmodifiable private-to-public mapping, empty when nothing usable is configured
     */
    private static Map<String, String> parseHostMap(String hostMap) {
        if (hostMap == null || hostMap.trim().isEmpty()) {
            LOG.warn("No host map configured under '{}'; addresses will not be translated",
                    HOST_MAP_KEY);
            return Collections.emptyMap();
        }
        Map<String, String> mapping = new HashMap<>();
        for (String entry : hostMap.split(",")) {
            String trimmed = entry.trim();
            if (trimmed.isEmpty()) {
                continue; // tolerate trailing or doubled commas
            }
            String[] parts = trimmed.split(":");
            if (parts.length != 2 || parts[0].trim().isEmpty() || parts[1].trim().isEmpty()) {
                LOG.warn("Ignoring malformed host map entry '{}' under '{}'", trimmed, HOST_MAP_KEY);
                continue;
            }
            mapping.put(parts[0].trim(), parts[1].trim());
        }
        return Collections.unmodifiableMap(mapping);
    }

    /**
     * Maps a node's private address to its configured public address, keeping the port.
     *
     * @param inetSocketAddress address reported by the cluster
     * @return the translated address, or {@code inetSocketAddress} itself when the address is
     *     unresolved or its IP has no mapping
     */
    @NonNull
    @Override
    public InetSocketAddress translate(@NonNull InetSocketAddress inetSocketAddress) {
        InetAddress address = inetSocketAddress.getAddress();
        if (address == null) {
            // Unresolved socket address: there is no IP string to look up.
            return inetSocketAddress;
        }
        String publicAddress = privateToPublic.get(address.getHostAddress());
        if (publicAddress == null) {
            return inetSocketAddress;
        }
        // If the mapped value is a hostname this constructor resolves it; an IP literal is
        // used as-is.
        return new InetSocketAddress(publicAddress, inetSocketAddress.getPort());
    }

    /** No-op: the translator only holds an in-memory map, so there is nothing to release. */
    @Override
    public void close() {
    }
}