package mobvista.dmp.format;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.io.compress.GzipCodec;
import org.apache.hadoop.mapreduce.OutputCommitter;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.TaskID;
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.ReflectionUtils;

import java.io.DataOutputStream;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.HashMap;
import java.util.Iterator;

public class RDDMultipleOutputFormat<K extends WritableComparable<?>, V extends Writable>
        extends FileOutputFormat<K, V> {
    private static final NumberFormat NUMBER_FORMAT = NumberFormat
            .getInstance();

    static {
        NUMBER_FORMAT.setMinimumIntegerDigits(5);
        NUMBER_FORMAT.setGroupingUsed(false);
    }

    private MultiRecordWriter writer = null;

    public RecordWriter<K, V> getRecordWriter(TaskAttemptContext job)
            throws IOException {
        if (writer == null) {
            writer = new MultiRecordWriter(job, getTaskOutputPath(job));
        }
        return writer;
    }

    private Path getTaskOutputPath(TaskAttemptContext conf) throws IOException {
        Path workPath;
        OutputCommitter committer = super.getOutputCommitter(conf);
        if (committer instanceof FileOutputCommitter) {
            workPath = ((FileOutputCommitter) committer).getWorkPath();
        } else {
            Path outputPath = super.getOutputPath(conf);
            if (outputPath == null) {
                throw new IOException("Undefined job output-path");
            }
            workPath = outputPath;
        }
        return workPath;
    }

    protected String generateFileNameForKeyValue(K key, V value,
                                                 String name) {
        return key.toString() + "/" + name;
    }

    public class MultiRecordWriter extends RecordWriter<K, V> {
        private HashMap<String, RecordWriter<NullWritable, V>> recordWriters;
        private TaskAttemptContext job;
        private Path workPath;

        public MultiRecordWriter(TaskAttemptContext job, Path workPath) {
            super();
            this.job = job;
            this.workPath = workPath;
            recordWriters = new HashMap<>();
        }

        @Override
        public void close(TaskAttemptContext context) throws IOException,
                InterruptedException {
            Iterator<RecordWriter<NullWritable, V>> values = this.recordWriters.values()
                    .iterator();
            while (values.hasNext()) {
                values.next().close(context);
            }
            this.recordWriters.clear();
        }

        @Override
        public void write(K key, V value) throws IOException,
                InterruptedException {
            TaskID taskId = job.getTaskAttemptID().getTaskID();
            int partition = taskId.getId();
            String baseName = generateFileNameForKeyValue(key, value,
                    NUMBER_FORMAT.format(partition));
            RecordWriter<NullWritable, V> rw = this.recordWriters.get(baseName);
            if (rw == null) {
                rw = getBaseRecordWriter(job, baseName);
                this.recordWriters.put(baseName, rw);
            }
            //	key = generateActualKey(key, value);
            rw.write(NullWritable.get(), value);
        }

        //	${mapred.out.dir}/_temporary/_${taskid}/${nameWithExtension}
        private RecordWriter<NullWritable, V> getBaseRecordWriter(TaskAttemptContext job,
                                                                  String baseName) throws IOException {
            Configuration conf = job.getConfiguration();
            boolean isCompressed = getCompressOutput(job);
            String keyValueSeparator = "\t";
            RecordWriter<NullWritable, V> recordWriter;
            if (isCompressed) {
                Class<? extends CompressionCodec> codecClass = getOutputCompressorClass(
                        job, GzipCodec.class);
                CompressionCodec codec = ReflectionUtils.newInstance(
                        codecClass, conf);
                Path file = new Path(workPath, baseName
                        + codec.getDefaultExtension());
                FSDataOutputStream fileOut = file.getFileSystem(conf).create(
                        file, false);
                recordWriter = new LineRecordWriter<>(new DataOutputStream(
                        codec.createOutputStream(fileOut)), keyValueSeparator);
            } else {
                Path file = new Path(workPath, baseName);
                FSDataOutputStream fileOut = file.getFileSystem(conf).create(
                        file, false);
                recordWriter = new LineRecordWriter<>(fileOut,
                        keyValueSeparator);
            }
            return recordWriter;
        }
    }
}