from array import array
from functools import wraps
from itertools import cycle
from os.path import commonprefix

from cml.shared.parameter import PROTOCOL_LEVEL

# Public names exported via ``from <module> import *``.
__all__ = (
    "DataSource",
    "Preprocessor",
    "LearnblockIdentifier"
)


def log_learnblock_processing(func):
    """Decorate a generator method yielding learnblocks with protocol logging.

    Each learnblock produced by ``func`` is logged (dimensions, timestamp
    range, origin, relatives) — but only when ``self.logger`` is set and its
    level equals ``PROTOCOL_LEVEL``.  The learnblocks themselves are passed
    through unchanged.
    """
    @wraps(func)
    def wrapper(self):
        for learnblock in func(self):
            # BUGFIX: the header line used to be logged before the
            # ``self.logger`` check, crashing with AttributeError whenever
            # no logger was attached (DataSource initializes logger=None)
            # and printing the header regardless of the log level.
            if self.logger and self.logger.level == PROTOCOL_LEVEL:
                self.logger.protocol("{:*^100}".format("Learnblock generation"))
                self.logger.protocol(
                    "\t".join([
                        "",
                        "[{:^5}{:^5}]".format(learnblock.rows, learnblock.cols),
                        str(learnblock.min_timestamp),
                        str(learnblock.max_timestamp),
                        "{:<20}".format(str(learnblock.origin)),
                        "{:<20}".format(str(learnblock.relatives))
                    ])
                )
            yield learnblock
    return wrapper


def log_block_processing(func):
    """Decorate a block-returning method with protocol logging.

    Keeps a per-decoration call counter and logs one protocol entry
    (counter, dimensions, timestamp range) per returned block when
    ``self.logger`` is set and its level equals ``PROTOCOL_LEVEL``.
    The block is returned unchanged.
    """
    counter = 0

    # BUGFIX: add functools.wraps so the wrapped method (e.g. __next__)
    # keeps its name and docstring for debugging/introspection.
    @wraps(func)
    def wrapper(self):
        nonlocal counter
        counter += 1
        block = func(self)
        if self.logger and self.logger.level == PROTOCOL_LEVEL:
            self.logger.protocol("{:*^100}".format("Blockprocessing"))
            self.logger.protocol(
                "\t".join([
                    str(counter),
                    "[{:^5}{:^5}]".format(block.rows, block.cols),
                    str(block.min_timestamp),
                    str(block.max_timestamp)
                ])
            )
        return block
    return wrapper


class DataSource:
    """Iterate over a data source in fixed-size blocks and extract learnblocks.

    The source is consumed ``settings.block_size`` rows at a time.  Rows that
    end up in a learnblock are flagged so that subsequent passes ("halde"
    runs) skip them; iteration stops after ``settings.stack_iterations``
    wrap-arounds over the source.
    """

    # array typecode: signed char — one flag byte per source row.
    _SIGNED_CHAR = 'b'

    def __init__(self, source, learnblock_identifier, settings):
        self.settings = settings
        self.logger = None
        self.__source = source
        self.__learnblock_identifier = learnblock_identifier
        # Lazily created in __iter__; 1 == row already used by a learnblock.
        self.__source_halde_flags = None

    @property
    @log_learnblock_processing
    def learnblocks(self):
        """Yield every learnblock identified in the source.

        Raises:
            ValueError: if the configured block size exceeds the source size.
        """
        if self.settings.block_size > len(self):
            # BUGFIX: message previously read "larger then the size the
            # data source" (typo plus missing word).
            raise ValueError("Block size cannot be larger than the size of "
                             "the data source.")

        for block in self:
            learnblock = self.__learnblock_identifier.identify(block)
            if learnblock:
                learnblock.origin = "source"
                # Mark the consumed rows so later passes skip them.
                self._flip_source_halde_flags(learnblock.indexes)
                yield learnblock

    @log_block_processing
    def __next__(self):
        # NOTE(review): requires __iter__ to have run first — it creates
        # the helper generator and the flag array.
        return next(self.__next_helper)

    def _next_helper(self):
        """Yield blocks of not-yet-flagged row indexes, cycling the source."""
        block_indexes = []
        counter = 0
        old_index = 0
        halde_runs = -1

        for i in cycle(range(0, len(self))):
            if halde_runs >= self.settings.stack_iterations:
                # manually stop generator
                return

            if counter == self.settings.block_size:
                old_index = i
                counter = 0
                yield self.__source.get_block_via_index(block_indexes)
                block_indexes.clear()

            elif old_index > i:
                # Wrapped around the source: a new "halde" pass begins;
                # flush whatever partial block was collected.
                halde_runs += 1
                old_index = i
                counter = 0
                yield self.__source.get_block_via_index(block_indexes)
                block_indexes.clear()

            if self.__source_halde_flags[i] == 0:
                block_indexes.append(i)
                counter += 1

    def __iter__(self):
        self.__next_helper = iter(self._next_helper())
        # Reset the per-row flags (0 == still available) on every iteration.
        self.__source_halde_flags = array(
            self._SIGNED_CHAR, [0 for _ in range(len(self))])
        return self

    def _flip_source_halde_flags(self, learnblock_indexes):
        # Mark the given rows as consumed.
        for i in learnblock_indexes:
            self.__source_halde_flags[i] = 1

    def __len__(self):
        return self.__source.length

    def get_block(self, indices=None, columns=None):
        """Return a block for the given row indices / column selection."""
        return self.__source.get_block_via_index(indices, columns=columns)

    def time_sigma_relatives(self, block):
        """Return the first (T, Sigma) relative group found in *block*."""
        return next(iter(self.__learnblock_identifier._identify_relatives(
            block, "T", "Sigma")))

    def estimate_density(self, data):
        """Train the identifier's kernel density estimator on *data*."""
        kernel_density_estimator = self.__learnblock_identifier.\
            density_estimator.train(data)
        return kernel_density_estimator.density()

    def remove_time_dense_relatives(self, block, density):
        """Drop rows of *block* whose density exceeds the configured cutoff."""
        self.__learnblock_identifier._remove_time_dense_relatives(
            block, density)

    def cluster(self, block, density):
        """Cluster *block* along the time axis via the identifier."""
        return self.__learnblock_identifier._cluster_sigma_zeta_relatives(
            block, density
        )

    def new_learnblock(self, values, columns, index, origin):
        """Construct a new block through the underlying source."""
        return self.__source.new_block(values, columns, index, origin)

    def get_time_values(self, indices):
        """Return the time column ("T") of the given rows as a numpy array."""
        return self.__source.get_block_via_index(indices, columns="T")\
            .as_numpy_array()


class Preprocessor:
    """Clean a raw input table according to the configured settings."""

    TIME_COLUMN = "T"
    TARGET_COLUMN = "Z"
    SIGMA_COLUMN = "Sigma"

    def __init__(self, settings):
        self.settings = settings

    def clean(self, table):
        """Apply all configured cleaning steps to *table* and return it."""
        self._drop_irrelevant_columns(table)

        if self.settings.set_targets:
            self._overwrite_target_column(table)

        if self.settings.sort_time_stamp:
            self._sort_according_time_stamp(table)

        if self.settings.cut_time_stamp:
            self._remove_common_time_stamp_prefix(table)

        return table

    def _drop_irrelevant_columns(self, table):
        # TODO (dmt): Features get dropped and not kept! Fix settings!
        features_to_be_removed = [table.get_column_name_by_index(i)
                                  for i in self.settings.set_features]
        for column in features_to_be_removed:
            table.drop_column_by_name(column)

    def _overwrite_target_column(self, table):
        # Overwrite every target value with the configured constant.
        table.set_column_value(self.TARGET_COLUMN, self.settings.set_targets)

    def _sort_according_time_stamp(self, table):
        table.sort(self.TIME_COLUMN)

    def _remove_common_time_stamp_prefix(self, table):
        # TODO (dmt): Check if timestamp column is of type string!
        time_column = table.get_column_values_as_list(self.TIME_COLUMN)
        common_prefix = commonprefix(time_column)
        # BUGFIX: str.lstrip(prefix) strips a *character set*, not the
        # prefix string — e.g. "2020-01-01".lstrip("2020-01-0") strips
        # the whole value.  Slice off the exact prefix length instead.
        cleaned_time_column = [s[len(common_prefix):] for s in time_column]
        # NOTE(review): other setters take a column name first — verify
        # that set_column_values() implicitly targets the time column.
        table.set_column_values(cleaned_time_column)


class LearnblockIdentifier:
    """Find the biggest "learnblock" (group of related rows) in a block."""

    def __init__(self, settings, density_estimator, relative_extrema):
        self.settings = settings
        self.density_estimator = density_estimator
        self._relative_extrema = relative_extrema

    @classmethod
    def _column_pairs(cls):
        # Column pairs examined for duplicated-value relatives.
        yield ("Sigma", "Z")
        yield ("T", "Sigma")
        yield ("T", "Z")

    def identify(self, block):
        """Return the biggest learnblock found in *block*, or None.

        A candidate qualifies when its length reaches
        ``settings.learn_block_minimum``; among qualifying candidates the
        one with the greatest length wins.
        """
        biggest_learn_block = None
        biggest_block_size = 0

        for pair in LearnblockIdentifier._column_pairs():
            for possible_learnblock in self._identify_relatives(block, *pair):
                if self._is_learn_block(possible_learnblock.length):
                    if possible_learnblock.length > biggest_block_size:
                        # BUGFIX: record the new best size; previously it
                        # stayed 0, so the *last* qualifying candidate won
                        # instead of the biggest one.
                        biggest_block_size = possible_learnblock.length
                        biggest_learn_block = possible_learnblock
                        biggest_learn_block.relatives = pair

        return biggest_learn_block

    def _is_learn_block(self, block_length):
        return block_length >= self.settings.learn_block_minimum

    def _identify_relatives(self, block, *args):
        """Yield candidate blocks of rows sharing a duplicated value pair."""
        already_seen = set()
        for value_pair in block.get_duplicated_pairs(args[0], args[1]):
            if value_pair not in already_seen:
                already_seen.add(value_pair)
                kw = {args[0]: value_pair[0], args[1]: value_pair[1]}
                if args[0] == "Sigma" and args[1] == "Z":
                    try:
                        # BUGFIX: do not rebind ``block`` — the loop still
                        # needs the original table on later iterations.
                        for relative_block in self._get_sigma_zeta_relatives(
                                block, **kw):
                            yield relative_block
                    except ValueError:
                        # TODO (dmt): Provide mechanism for logging errors!
                        continue

                else:
                    yield block.get_values(**kw)

    def _get_sigma_zeta_relatives(self, block, **kw):
        """Yield time-clustered sub-blocks of the (Sigma, Z) relatives."""
        relatives = block.get_values(**kw)
        time_column = relatives.get_column_values("T")
        density = self.density_estimator.train(time_column).density()
        self._remove_time_dense_relatives(relatives, density)
        clusters = self._cluster_sigma_zeta_relatives(relatives, density)
        for time_values in clusters:
            yield relatives.new_block_from(time_values)

    def _remove_time_dense_relatives(self, block, density):
        """Drop rows whose density exceeds the sigma-zeta cutoff fraction."""
        max_dens = max(density)
        for index, dens in enumerate(density):
            if dens > max_dens*(self.settings.sigma_zeta_cutoff/100):
                block.drop_row(index)

    def _cluster_sigma_zeta_relatives(self, cutted_block, density):
        # TOOD (dmt): Don't rely on data series from pandas, 'cause ckmeans
        # needs primitives data types.
        # One cluster per relative maximum of the density estimate.
        time_column = list(cutted_block.get_column_values("T"))
        _, maxs = self._relative_extrema(density)
        return ckmeans(time_column, len(maxs))


# TODO (dmt): Refactor Ckmeans algorithm!
# Resource: https://journal.r-project.org/archive/2011-2/RJournal_2011-2_Wang+Song.pdf
def ssq(j, i, sum_x, sum_x_sq):
    """Return the within-cluster sum of squares s(j, i) of x[j..i].

    Computed in O(1) from the prefix sums ``sum_x`` and squared prefix
    sums ``sum_x_sq``; clamped at zero to absorb floating-point noise.
    """
    if j > 0:
        n = i - j + 1
        segment_mean = (sum_x[i] - sum_x[j - 1]) / n
        within = sum_x_sq[i] - sum_x_sq[j - 1] - n * segment_mean ** 2
    else:
        # Segment starts at the beginning: prefix sums need no subtraction.
        within = sum_x_sq[i] - sum_x[i] ** 2 / (i + 1)

    return max(within, 0)


def fill_row_k(imin, imax, k, S, J, sum_x, sum_x_sq, N):
    """Fill row ``k`` of the ckmeans DP matrices for indexes imin..imax.

    Divide-and-conquer step of the dynamic program: compute the optimal
    cost ``S[k][i]`` and cluster start index ``J[k][i]`` for the midpoint
    ``i`` of the range, then recurse on both halves.  The monotonicity of
    the optimal split points bounds the search window ``[jlow, jhigh]``.
    """
    if imin > imax: return

    i = (imin+imax) // 2
    # Initial guess: the k-th cluster contains only x[i].
    S[k][i] = S[k-1][i-1]
    J[k][i] = i

    # Lower bound for the split point, tightened by already-computed
    # entries of J (monotone in both row and column).
    jlow = k

    if imin > k:
        jlow = int(max(jlow, J[k][imin-1]))
    jlow = int(max(jlow, J[k-1][i]))

    # Upper bound for the split point, from the neighbour to the right.
    jhigh = i-1
    if imax < N-1:
        jhigh = int(min(jhigh, J[k][imax+1]))

    for j in range(jhigh, jlow-1, -1):
        # Cost of making x[j..i] the k-th cluster.
        sji = ssq(j, i, sum_x, sum_x_sq)

        # Once even the cheapest remaining split cannot beat the current
        # optimum, no smaller j can either — stop searching.
        if sji + S[k-1][jlow-1] >= S[k][i]: break

        # Examine the lower bound of the cluster border
        # compute s(jlow, i)
        sjlowi = ssq(jlow, i, sum_x, sum_x_sq)

        SSQ_jlow = sjlowi + S[k-1][jlow-1]

        if SSQ_jlow < S[k][i]:
            S[k][i] = SSQ_jlow
            J[k][i] = jlow

        jlow += 1

        SSQ_j = sji + S[k-1][j-1]
        if SSQ_j < S[k][i]:
            S[k][i] = SSQ_j
            J[k][i] = j

    # Recurse on both halves around the midpoint.
    fill_row_k(imin, i-1, k, S, J, sum_x, sum_x_sq, N)
    fill_row_k(i+1, imax, k, S, J, sum_x, sum_x_sq, N)


def fill_dp_matrix(data, S, J, K, N):
    """Fill the ckmeans DP matrices: costs ``S`` and backtrack indexes ``J``.

    ``data`` must be sorted; ``S``/``J`` are (K, N) arrays.  Row 0 is the
    one-cluster base case; rows 1..K-1 are filled via ``fill_row_k``.
    """
    import numpy as np
    # BUGFIX: np.float_ was removed in NumPy 2.0 — use the explicit
    # float64 alias, which is identical on older NumPy versions.
    sum_x = np.zeros(N, dtype=np.float64)
    sum_x_sq = np.zeros(N, dtype=np.float64)

    # median. used to shift the values of x to improve numerical stability
    shift = data[N//2]

    for i in range(N):
        if i == 0:
            sum_x[0] = data[0] - shift
            sum_x_sq[0] = (data[0] - shift) ** 2
        else:
            sum_x[i] = sum_x[i-1] + data[i] - shift
            sum_x_sq[i] = sum_x_sq[i-1] + (data[i] - shift) ** 2

        # Base case: one cluster covering x[0..i].
        S[0][i] = ssq(0, i, sum_x, sum_x_sq)
        J[0][i] = 0

    for k in range(1, K):
        if k < K-1:
            imin = max(1, k)
        else:
            # Last row: only the final column is ever read back.
            imin = N-1

        fill_row_k(imin, N-1, k, S, J, sum_x, sum_x_sq, N)


def ckmeans(data, n_clusters):
    """Optimal 1-D k-means clustering (Wang & Song, 2011).

    Sorts *data* in place and returns a list of ``n_clusters`` contiguous
    slices of the sorted values, minimizing within-cluster sum of squares.

    Raises:
        ValueError: if ``n_clusters`` is non-positive or exceeds len(data).
    """
    import numpy as np
    if n_clusters <= 0:
        raise ValueError("Cannot classify into 0 or less clusters")
    if n_clusters > len(data):
        raise ValueError("Cannot generate more classes than there are data values")

    # if there's only one value, return it; there's no sensible way to split
    # it. This means that len(ckmeans([data], 2)) may not == 2. Is that OK?
    unique = len(set(data))
    if unique == 1:
        return [data]

    data.sort()
    n = len(data)

    # BUGFIX: np.float_ was removed in NumPy 2.0 — use float64 explicitly.
    S = np.zeros((n_clusters, n), dtype=np.float64)

    J = np.zeros((n_clusters, n), dtype=np.uint64)

    fill_dp_matrix(data, S, J, n_clusters, n)

    # Backtrack through J from the last row to recover cluster boundaries.
    clusters = []
    cluster_right = n-1

    for cluster in range(n_clusters-1, -1, -1):
        cluster_left = int(J[cluster][cluster_right])
        clusters.append(data[cluster_left:cluster_right+1])

        if cluster > 0:
            cluster_right = cluster_left - 1

    return list(reversed(clusters))