from abc import ABC
from collections import defaultdict
from functools import partial
from functools import wraps
from itertools import count

from cml.shared.parameter import HIGHEST_TIER
from cml.shared.errors import DeconstructionFailed, NoModelReconstructedError


__all__ = (
    "Deconstructor",
    "KnowledgeDatabase",
    "RelativeFinder",
    "ConceptualKnowledgeDatabase",
    "ProceduralKnowledgeDatabase",
    "KnowledgeDomain",
)


def notify_inverted_index(func):
    """Decorate a database mutator so its observers receive the same call.

    Before the wrapped method runs, the method of the same name is called on
    every object in ``self.observer`` with the same positional and keyword
    arguments (``self`` itself is not forwarded). The wrapped method's return
    value is passed through unchanged.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        for obs in self.observer:
            getattr(obs, func.__name__)(*args, **kwargs)
        # ``func`` is the plain function from the class body, not a bound
        # method, so ``self`` has to be passed on explicitly.
        return func(self, *args, **kwargs)
    return wrapper


class KnowledgeDatabase(ABC):
    """Store of knowledge models split into tiered knowledge domains.

    ``database`` is a list of eight :class:`KnowledgeDomain` objects (tiers
    0-7); a model is filed under ``database[model.knowledge_domain]``.
    Methods decorated with ``notify_inverted_index`` first forward the call
    to every object in ``observer`` and then change the database itself.
    """

    def __init__(self):
        self.database = [KnowledgeDomain(i) for i in range(8)]
        # Objects that receive every decorated mutation; each must provide
        # methods with the same names and parameters.
        self.observer = []
        super().__init__()

    def generate_free_ids(self):
        """Yield consecutive integer ids 1, 2, 3, ... without end."""
        yield from count(1)

    def deserialize(self): pass

    def serialize(self): pass

    @notify_inverted_index
    def extend(self): pass

    @notify_inverted_index
    def insert(self, model):
        """File *model* under its knowledge domain."""
        self._domain_of(model).insert(model)

    @notify_inverted_index
    def remove(self, model):
        """Remove *model* from its knowledge domain.

        Removing a model that is not stored is a no-op, since callers may
        try to remove a model that was never inserted.
        """
        self._domain_of(model).knowledge.pop(model, None)

    @notify_inverted_index
    def replace(self, old_model, new_model):
        """Remove *old_model* and file *new_model* in its place.

        The domains are edited directly (not via ``remove``/``insert``) so
        observers receive only this single ``replace`` call.
        """
        self._domain_of(old_model).knowledge.pop(old_model, None)
        self._domain_of(new_model).insert(new_model)

    def _domain_of(self, model):
        """Return the KnowledgeDomain that *model* is filed under."""
        return self.database[model.knowledge_domain]


class ConceptualKnowledgeDatabase(KnowledgeDatabase):
    """Conceptual-knowledge variant of :class:`KnowledgeDatabase`.

    It currently adds no behavior; construction is inherited unchanged.
    """


class ProceduralKnowledgeDatabase(KnowledgeDatabase):
    """Procedural-knowledge variant of :class:`KnowledgeDatabase`.

    It currently adds no behavior; construction is inherited unchanged.
    """


class KnowledgeDomain:
    """Container for the models filed under one tier.

    Models live in ``knowledge``, a dict that maps every model to itself.
    """

    def __init__(self, tier):
        self.knowledge = {}
        self.tier = tier

    def insert(self, model):
        """Store *model*, overwriting the value kept under an equal key."""
        store = self.knowledge
        store[model] = model


class RelativeFinder:
    """Placeholder index for locating relatives of a model.

    Keeps one ``defaultdict(list)`` index per dimension (T, Z, Sigma). The
    mutators carry the same names as ``KnowledgeDatabase``'s decorated
    methods, presumably so an instance can sit in its ``observer`` list;
    they accept the arguments those calls forward but do not update the
    indices yet.
    """

    def __init__(self):
        self.index_t = defaultdict(list)
        self.index_z = defaultdict(list)
        self.index_sigma = defaultdict(list)

    def find_relatives(self, pair=None):
        """Return an iterator over the relatives for the column *pair*.

        Not implemented yet: always returns an empty iterator, so callers can
        loop over the result without special-casing "nothing found".
        """
        # TODO (dmt): Yield every relative that is found; an exhausted
        # iterator then signals that nothing (more) was found.
        return iter(())

    def remove(self, model=None): pass

    def replace(self, old_model=None, new_model=None): pass

    def insert(self, model): pass

    def extend(self): pass


class Deconstructor:
    """Deconstruct a pragmatic model against relatives that are already known.

    :meth:`deconstruct` tries the T-Z, T-Sigma, Sigma-Z and complete
    strategies in that order. For each strategy ``relative_finder`` is called
    with the strategy's column pair, and every relative it yields is handed to
    the strategy, which may insert, fuse, replace or remove models in
    ``knowledge_database``. ``source`` (used to fetch data blocks) and
    ``reconstructor`` start as ``None`` and must be set before a strategy that
    uses them runs.
    """

    TIME_COLUMN = "T"
    SUBJECT_COLUMN = "Sigma"
    PUPORSE_COLUMN = "Z"
    # Correctly spelled alias; the misspelled name is kept for existing users.
    PURPOSE_COLUMN = PUPORSE_COLUMN

    def __init__(self, knowledge_database, relative_finder, settings):
        self.knowledge_database = knowledge_database
        self.relative_finder = relative_finder
        self.settings = settings
        self.source = None
        self.reconstructor = None

    def _strategies(self, block):
        """Yield ``(column_pair, strategy)`` tuples in the order they are tried.

        The time-based strategies also need the learnblock, which is bound
        with ``partial`` so every strategy is called as
        ``strategy(prag_model, relative_model)``.
        """
        time, subject, purpose = (self.TIME_COLUMN, self.SUBJECT_COLUMN,
                                  self.PUPORSE_COLUMN)
        yield ((time, purpose),
               partial(self.deconstruct_time_zeta, block=block))
        yield ((time, subject),
               partial(self.deconstruct_time_sigma, block=block))
        yield ((subject, purpose), self.deconstruct_sigma_zeta)
        yield (("complete", ), self.deconstruct_complete)

    def deconstruct(self, prag_model, learnblock):
        """Deconstruct *prag_model* with every strategy and every relative.

        A strategy that raises ``DeconstructionFailed`` counts as unsuccessful
        and the next relative is tried. When ``settings.deconst_mode`` is
        ``"minimal"`` the method returns after the first successful attempt.
        """
        for pair, strategy in self._strategies(learnblock):
            # NOTE(review): relative_finder is *called* with the pair, while
            # RelativeFinder exposes find_relatives() and no __call__ —
            # confirm which callable is injected here.
            for relative in self.relative_finder(pair):
                try:
                    strategy(prag_model, relative)
                except DeconstructionFailed:
                    continue
                if self.settings.deconst_mode == "minimal":
                    return

    def _joint_times(self, prag_model, relative_model):
        """Return the distinct T values over both models' pre-image blocks."""
        first_block = self.source.get_block(prag_model.pre_image)
        second_block = self.source.get_block(relative_model.pre_image)
        times_p = set(first_block.get_column_values(self.TIME_COLUMN))
        times_r = set(second_block.get_column_values(self.TIME_COLUMN))
        return times_p.union(times_r)

    def deconstruct_time_sigma(self, prag_model, relative_model, block):
        """T-Sigma deconstruction.

        *prag_model* is inserted exactly once. If a relative exists,
        *prag_model* is below ``HIGHEST_TIER`` and the joint time stamps
        contain at least ``settings.min_learnblock`` distinct values, the two
        models are fused when their reliability is negative (systematic) or,
        if ``settings.allow_weak_reliability`` is set, weak (at least 0 but
        below ``settings.min_reliability``). The fused model is inserted too.
        """
        self.knowledge_database.insert(prag_model)
        if not (relative_model and prag_model.tier < HIGHEST_TIER):
            return
        joint_times = self._joint_times(prag_model, relative_model)
        if len(joint_times) < self.settings.min_learnblock:
            return

        alpha = self.calc_reliability(relative_model, prag_model, block)
        alpha_systematic = alpha < 0
        alpha_weak_reliability = 0 <= alpha < self.settings.min_reliability
        if ((self.settings.allow_weak_reliability and alpha_weak_reliability)
                or alpha_systematic):
            new_model = prag_model.fusion(relative_model)
            self.knowledge_database.insert(new_model)
            # TODO: Create a new learnblock from the joint time stamps and
            # start a new sequence of construction, reconstruction and
            # deconstruction from it.

    def deconstruct_time_zeta(self, prag_model, relative_model, block):
        """T-Z deconstruction.

        *prag_model* is always inserted. With a relative whose joint time
        stamps contain at least ``settings.min_learnblock`` distinct values,
        the models are fused, and the fusion replaces *relative_model* when
        the reliability reaches ``settings.min_reliability``.
        """
        self.knowledge_database.insert(prag_model)
        # Like the other strategies: without a relative there is nothing to
        # fuse with.
        if not relative_model:
            return
        joint_times = self._joint_times(prag_model, relative_model)
        if len(joint_times) >= self.settings.min_learnblock:
            new_model = prag_model.fusion(relative_model)
            alpha = self.calc_reliability(relative_model, prag_model, block)
            if alpha >= self.settings.min_reliability:
                self.knowledge_database.replace(relative_model, new_model)

    def deconstruct_sigma_zeta(self, prag_model, relative_model):
        """Sigma-Z deconstruction.

        Without a relative, or when the Sigma-Z time constraint fails,
        *prag_model* is simply inserted. Otherwise, if the overlap of the two
        pre-image blocks (``overlapping_rows``) has at least two rows, the
        fused model is reconstructed from the fused block and replaces
        *relative_model*. If no model can be reconstructed,
        ``settings.deconst_mode`` decides: "conservative" removes
        *prag_model*, "integrative" inserts it, and "opportunistic" keeps the
        model whose pre-image block has more rows (ties keep the relative).
        """
        if not (relative_model and self.time_constraint(prag_model,
                                                        relative_model,
                                                        "SigmaZ")):
            self.knowledge_database.insert(prag_model)
            return

        first_block = self.source.get_block(prag_model.pre_image)
        second_block = self.source.get_block(relative_model.pre_image)
        overlapping_block = first_block.overlapping_rows(second_block)
        if overlapping_block.rows < 2:
            return

        new_model = prag_model.fusion(relative_model)
        try:
            new_block = first_block.fusion(second_block)
            which_ml_models = new_model.subject
            recon_m = self.reconstructor.reconstruct(new_block,
                                                     which_ml_models,
                                                     new_model)
        except NoModelReconstructedError:
            mode = self.settings.deconst_mode
            if mode == "conservative":
                self.knowledge_database.remove(prag_model)
            elif mode == "integrative":
                self.knowledge_database.insert(prag_model)
            # The historical misspelling is still accepted.
            elif mode in ("oppurtunistic", "opportunistic"):
                if first_block.rows > second_block.rows:
                    keep, drop = prag_model, relative_model
                else:
                    keep, drop = relative_model, prag_model
                self.knowledge_database.remove(drop)
                self.knowledge_database.insert(keep)
        else:
            self.knowledge_database.replace(relative_model, recon_m)

    def time_constraint(self, prag_model, relative_model, _type):
        """Check whether the two models' time ranges are compatible.

        ``"SigmaZ"``: governed by ``settings.deconst_max_distance_t`` — 0
        requires overlapping ranges (touching endpoints count), a value in
        (0, 1) allows a gap smaller than that fraction of the combined span,
        and 1 always passes. ``"complete"``: one model's time range must
        enclose the other's. Any other combination returns ``None`` (falsy).
        """
        if _type == "SigmaZ":
            max_distance = self.settings.deconst_max_distance_t
            if max_distance == 0:
                return self._overlapping_time(prag_model, relative_model)
            elif 0 < max_distance < 1:
                return self._enclosing_time(prag_model, relative_model)
            elif max_distance == 1:
                return True
        elif _type == "complete":
            prag_encloses = (
                prag_model.min_timestamp <= relative_model.min_timestamp and
                relative_model.max_timestamp <= prag_model.max_timestamp)
            relative_encloses = (
                relative_model.min_timestamp <= prag_model.min_timestamp and
                prag_model.max_timestamp <= relative_model.max_timestamp)
            return prag_encloses or relative_encloses

    def _overlapping_time(self, prag_model, relative_model):
        """Return True if the two time ranges overlap or touch."""
        return (relative_model.min_timestamp <= prag_model.max_timestamp and
                prag_model.min_timestamp <= relative_model.max_timestamp)

    def _enclosing_time(self, prag_model, relative_model):
        """Return True if the gap between the ranges is small enough.

        The allowed gap is ``deconst_max_distance_t`` times the span from the
        earliest start to the latest end of both ranges. Overlapping ranges
        have a negative gap and therefore always pass.
        """
        m_dash_max_time = max(prag_model.max_timestamp,
                              relative_model.max_timestamp)
        m_dash_min_time = min(prag_model.min_timestamp,
                              relative_model.min_timestamp)
        allowed_gap = self.settings.deconst_max_distance_t * (
            m_dash_max_time - m_dash_min_time)
        # Both directed gaps must stay below the limit; with ``or`` the
        # negative gap on the far side would let every pair pass.
        return ((relative_model.min_timestamp - prag_model.max_timestamp)
                < allowed_gap and
                (prag_model.min_timestamp - relative_model.max_timestamp)
                < allowed_gap)

    def deconstruct_complete(self, prag_model, relative_model):
        """Complete deconstruction.

        Without a relative *prag_model* is inserted. Otherwise a candidate is
        built: the direct fusion of both models if they satisfy the
        "complete" time constraint and share at least two feature columns,
        else a reconstruction from the time-sigma relatives of their joint
        pre-image. The candidate is then reconstructed from its own pre-image
        block; on success the result takes the place of *relative_model*.
        """
        if not relative_model:
            self.knowledge_database.insert(prag_model)
            return

        if (self.time_constraint(prag_model, relative_model, "complete") and
                self._feature_intersection(prag_model, relative_model) >= 2):
            new_model = prag_model.fusion(relative_model)
        else:
            new_block = self.source.get_block(
                prag_model.pre_image + relative_model.pre_image)
            ts_relatives = self.source.time_simga_relatives(new_block)
            which_ml_models = prag_model.subject + relative_model.subject
            new_model = self.reconstructor.reconstruct(ts_relatives,
                                                       which_ml_models)

        new_block = self.source.get_block(new_model.pre_image)
        # NOTE(review): Sigma-Z takes the subjects from the fused model
        # (``new_model.subject``); here they come from the block — confirm.
        which_ml_models = new_block.subject
        try:
            recon_model = self.reconstructor.reconstruct(new_block,
                                                         which_ml_models,
                                                         new_model)
        except NoModelReconstructedError:
            # TODO (dmt): Implement the model differentiation!
            return
        self.knowledge_database.remove(relative_model)
        self.knowledge_database.insert(recon_model)

    def _feature_intersection(self, prag_model, relative_model):
        """Return how many columns the two models' pre-image blocks share."""
        first_block = self.source.get_block(prag_model.pre_image)
        second_block = self.source.get_block(relative_model.pre_image)
        return len(set(first_block.columns).intersection(second_block.columns))

    def calc_reliability(self, model_a, model_b, block):
        """Return the reliability of the two models on *block*.

        Placeholder: always returns 0.
        """
        return 0