# -*- coding: utf-8 -*-

from io import BytesIO, StringIO
from multiprocessing import Pool
from pathlib import Path
import tarfile
import time

import pandas as pd

# Data load functions


def read_sequence(fileobj, index_col=None):
    """Parse a single CSV sequence from ``fileobj`` into a DataFrame.

    Parsing uses pandas' C engine; ``index_col`` is handed to
    ``pd.read_csv`` unchanged.
    """
    frame = pd.read_csv(fileobj, engine='c', index_col=index_col)
    return frame


def read_data_archive(path, index_col=None):
    """Yield one parsed sequence per regular file stored in a tar archive.

    The compression is chosen from the last suffix of ``path``: ``.tar`` is
    opened uncompressed, and any other suffix (e.g. ``.gz``, ``.bz2``,
    ``.xz``) is used as the ``tarfile`` compression name. Each member's
    bytes are decoded as UTF-8 and parsed with ``read_sequence``.

    Members without file data (directories, device files, ...) are skipped,
    because ``TarFile.extractfile`` returns ``None`` for them.

    Args:
        path: archive location (str or Path).
        index_col: forwarded to ``read_sequence``.

    Yields:
        The parsed sequence of each member, in archive order.
    """
    ext = Path(path).suffix.strip().lower()
    fmt = '' if ext == '.tar' else ext[1:]
    with tarfile.open(str(path), mode='r:{}'.format(fmt)) as tar:
        for member_info in tar:
            member = tar.extractfile(member_info)
            if member is None:
                # Directory or other non-regular entry: nothing to parse.
                continue
            # Read and close the member before yielding so no file handle
            # stays open while the caller processes the sequence.
            with member:
                text = member.read().decode('utf-8')
            with StringIO(text) as buffer:
                yield read_sequence(buffer, index_col=index_col)


def read_data_directory(data_dir, features=None, index_col=None):
    features = features or slice(None)
    data_path = Path(data_dir)
    examples = []
    for fmt_ext in ('', '.gz', '.bz2', '.xz'):
        fmt_paths = data_path.glob('**/*.tar{}'.format(fmt_ext))
        for archive in filter(Path.is_file, fmt_paths):
            for sequence in read_data_archive(archive, index_col=index_col):
                feats = features or sequence.columns
                examples.append(sequence[feats].copy())
    for csv in filter(Path.is_file, data_path.glob('**/*.csv')):
        with csv.open('r') as f:
            sequence = read_sequence(f, index_col=index_col)
            feats = features or sequence.columns
            examples.append(sequence[feats].copy())
    return examples


def _tar_example(example_id, example_data):
    buffer_str = StringIO()
    buffer_bytes = BytesIO()
    example_data.to_csv(buffer_str)
    size = buffer_bytes.write(buffer_str.getvalue().encode('utf-8'))
    buffer_bytes.seek(0)
    info = tarfile.TarInfo(name='example_{:08d}.csv'.format(example_id))
    info.size = size
    info.mtime = int(time.time())
    info.mode = 0o644
    return info, buffer_bytes


def save_data_archive(data, directory, archive_name, fmt=None):
    """Write ``data`` to ``<directory>/<archive_name>.tar[.<fmt>]`` as CSVs.

    Each item of ``data`` is serialized by ``_tar_example`` into a member
    named ``example_{index:08d}.csv``, numbered from 0 in iteration order.
    Serialization runs in a ``multiprocessing.Pool`` and finishes before the
    target file is opened. A serialization failure or an unsupported
    ``fmt`` therefore leaves any existing file with the same name
    untouched. Otherwise an existing file is overwritten.

    Args:
        data: iterable of objects with a ``to_csv`` method (e.g. DataFrames).
        directory: existing directory (str or Path) to write into.
        archive_name: file name without the ``.tar`` extension.
        fmt: compression, one of ``'gz'``, ``'bz2'`` or ``'xz'``. ``None``
            or ``''`` writes an uncompressed tar. Case, surrounding
            whitespace and a leading dot are ignored.

    Raises:
        ValueError: if ``directory`` does not exist or is not a directory.
        tarfile.CompressionError: if ``fmt`` is not a supported compression.
    """
    # Accept str as well as Path; the checks below need Path methods.
    directory = Path(directory)
    fmt = (fmt or '').strip().lower().lstrip('.')
    ext = '.tar' + ('.' + fmt if fmt else '')
    file_path = directory / '{}{}'.format(archive_name, ext)
    if not directory.exists():
        raise ValueError('The directory "{}" does not exist'.format(directory))
    if not directory.is_dir():
        raise ValueError('Invalid directory "{}"'.format(directory))
    # Same compressions read_data_directory searches for. Validated here so
    # a bad value fails before the target file is created or truncated.
    if fmt not in ('', 'gz', 'bz2', 'xz'):
        raise tarfile.CompressionError(
            'unknown compression type {!r}'.format(fmt))
    with Pool() as pool:
        members = pool.starmap(_tar_example, enumerate(data))
    with file_path.open('wb') as f:
        with tarfile.open(mode='w|{}'.format(fmt), fileobj=f) as tar:
            for info, buffer_bytes in members:
                tar.addfile(tarinfo=info, fileobj=buffer_bytes)
