from abc import ABC
from functools import partial
from collections import defaultdict, deque
from typing import Tuple, Optional
import traceback

import krippendorff

from cml.domain.data_source import DataSource
from cml.domain.reconstruction import PragmaticMachineLearningModel
from cml.shared.settings import DeconstructionSettings
from cml.shared.errors import (
    DeconstructionFailed,
    NoModelReconstructedError,
    NotEnoughFeaturesWarning
)


__all__ = (
    "Deconstructor",
    "KnowledgeDatabase",
    "RelativeFinder",
    "ConceptualKnowledgeDatabase",
    "ProceduralKnowledgeDatabase",
    "KnowledgeDomain"
)

# Queue of (tier, learnblock) pairs scheduled for reconstruction on a higher
# tier; entries are purged when their originating model is removed or replaced.
TS_QUEUE = deque()


def notify_inverted_index(func):
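    """Mirror the decorated call on every registered observer.

    Each observer is expected to expose a method with the same name as the
    decorated one (insert, remove, replace, extend).
    """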
    def wrapper(self, *args, **kwargs):
        for obs in self.observer:
            getattr(obs, func.__name__)(*args, **kwargs)
        return func(self, *args, **kwargs)
    return wrapper


def notify_deque(func):
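    """Purge TS_QUEUE entries tied to the model being removed or replaced.

    Any queued (tier, learnblock) pair whose learnblock originates from that
    model is dropped before the decorated method runs.
    """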
    def wrapper(self, *args, **kwargs):
        # replace(replaced, replacer) passes two models; entries originating
        # from the replaced model are purged
        if len(args) == 2:
            model, *_ = args
        # remove(model) passes a single model
        else:
            model, = args
        deque_copy = TS_QUEUE.copy()
        for tier, learnblock in deque_copy:
            if model.uid in learnblock.origin:
                TS_QUEUE.remove((tier, learnblock))
        return func(self, *args, **kwargs)
    return wrapper


class KnowledgeDatabase(ABC):
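    """Tiered store of pragmatic machine learning models.

    Mutating operations are mirrored to registered observers (such as a
    RelativeFinder) via notify_inverted_index and purge stale entries from
    TS_QUEUE via notify_deque.
    """
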
    def __init__(self, highest_tier):
        self.database = [KnowledgeDomain(i) for i in range(highest_tier)]
        self.observer = []
        self.n_models = 0
        self.logger = None
        super().__init__()

    def __contains__(self, item: PragmaticMachineLearningModel):
        if not isinstance(item, PragmaticMachineLearningModel):
            raise TypeError()
        try:
            self.get(item.uid)
            return True
        except KeyError:
            return False

    @notify_inverted_index
    def insert(self, model: PragmaticMachineLearningModel):
        if model not in self:
            self.database[model.tier].insert(model)
            self.n_models += 1

            ###################################################################
            self.logger.protocol("{:<20}: {}".format("Inserted", model))
            ###################################################################

    @notify_deque
    @notify_inverted_index
    def remove(self, model: PragmaticMachineLearningModel):
        self.database[model.tier].remove(model)
        self.n_models -= 1

        ###################################################################
        self.logger.protocol("{:<20}: {}".format("Removed", model))
        ###################################################################

    @notify_deque
    @notify_inverted_index
    def replace(self,
                replaced: PragmaticMachineLearningModel,
                replacer: PragmaticMachineLearningModel):
        self.database[replaced.tier].replace(replaced, replacer)
        ###################################################################
        self.logger.protocol("{:<20}: {} {:<20}: {}".format(
            "Replaced", str(replaced), "with", str(replacer)))
        ###################################################################

    @notify_inverted_index
    def extend(self): pass

    def get(self, uid: str):
        _, tier, _ = uid.split(".")
        return self.database[int(tier)].get(uid)

    def deserialize(self): pass

    def serialize(self): pass

    def model_counter(self):
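        """Return a callable mapping a tier to its next free model counter."""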
        def counts(tier):
            return self.database[tier].biggest_id + 1
        return counts

    def remove_dependent_models(self,
                                relative_model: PragmaticMachineLearningModel):
        for domain in self.database:
            # iterate over a snapshot: self.remove() mutates domain.knowledge
            for model in list(domain.knowledge.values()):
                if model.origin == relative_model.origin:
                    self.remove(model)

    def inject_logger(self, logger):
        self.logger = logger


class ConceptualKnowledgeDatabase(KnowledgeDatabase):
    def __init__(self, highest_tier):
        super().__init__(highest_tier)


class ProceduralKnowledgeDatabase(KnowledgeDatabase):
    def __init__(self, highest_tier):
        super().__init__(highest_tier)


class KnowledgeDomain:
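    """Models belonging to a single tier.

    The knowledge dict is keyed by the models themselves; get() looks entries
    up by uid string, which assumes PragmaticMachineLearningModel hashes and
    compares equal by its uid.
    """
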
    def __init__(self, tier: int):
        self.tier = tier
        self.knowledge = {}
        self.biggest_id = 0

    def get(self, uid: str):
        return self.knowledge[uid]

    def insert(self, model: PragmaticMachineLearningModel):
        self.knowledge[model] = model
        self.update_biggest_id(model.counter)

    def remove(self, model: PragmaticMachineLearningModel):
        del self.knowledge[model]

    def replace(self,
                replaced: PragmaticMachineLearningModel,
                replacer: PragmaticMachineLearningModel):
        del self.knowledge[replaced]
        self.knowledge[replacer] = replacer
        self.update_biggest_id(replacer.counter)

    def update_biggest_id(self, new_id):
        self.biggest_id = max(self.biggest_id, new_id)


class RelativeFinder:
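    """Inverted indexes over time span (T), subject (Sigma) and purpose (Z)
    used to look up models related to a given model."""
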
    def __init__(self):
        self.index_t = defaultdict(set)
        self.index_z = defaultdict(set)
        self.index_sigma = defaultdict(set)

    def find(self,
             pair: Tuple[str, Optional[str]],
             model: PragmaticMachineLearningModel):
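        """Yield every model sharing the requested index combination with
        `model` (excluding the model itself); yield a single None when no
        relative exists.
        """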

        if pair == ("T", "Z"):
            set_t = self.index_t[(model.min_timestamp, model.max_timestamp)]
            set_z = self.index_z[model.purpose]
            relatives = set_t.intersection(set_z)
        elif pair == ("T", "Sigma"):
            set_t = self.index_t[(model.min_timestamp, model.max_timestamp)]
            set_s = self.index_sigma[model.subject]
            relatives = set_t.intersection(set_s)
        elif pair == ("Sigma", "Z"):
            set_s = self.index_sigma[model.subject]
            set_z = self.index_z[model.purpose]
            relatives = set_s.intersection(set_z)
        elif pair == ("complete", ):
            set_t = self.index_t[(model.min_timestamp, model.max_timestamp)]
            set_s = self.index_sigma[model.subject]
            set_z = self.index_z[model.purpose]
            relatives = set_t.intersection(set_s).intersection(set_z)
        else:
            raise ValueError("unknown index pair: {!r}".format(pair))

        if model in relatives:
            relatives.remove(model)
        for relative in relatives:
            yield relative
        if not relatives:
            yield None

    def remove(self, model: PragmaticMachineLearningModel):
        t_list, s_list, z_list = self.get_index_lists(model,
                                                      time=True,
                                                      subject=True,
                                                      purpose=True)
        t_list.remove(model)
        s_list.remove(model)
        z_list.remove(model)

    def replace(self,
                replaced: PragmaticMachineLearningModel,
                replacer: PragmaticMachineLearningModel):
        t_list, s_list, z_list = self.get_index_lists(replaced,
                                                      time=True,
                                                      subject=True,
                                                      purpose=True)
        t_list.remove(replaced)
        s_list.remove(replaced)
        z_list.remove(replaced)
        self.index_t[(replacer.min_timestamp, replacer.max_timestamp)].add(
            replacer)
        self.index_z[replacer.purpose].add(replacer)
        self.index_sigma[replacer.subject].add(replacer)

    def insert(self, model: PragmaticMachineLearningModel):
        self.index_t[(model.min_timestamp, model.max_timestamp)].add(model)
        self.index_z[model.purpose].add(model)
        self.index_sigma[model.subject].add(model)

    def get_index_lists(self,
                        model: PragmaticMachineLearningModel,
                        time: bool,
                        subject: bool,
                        purpose: bool):
        t_list = s_list = z_list = None
        if time:
            t_list = self.index_t[(model.min_timestamp, model.max_timestamp)]
        if subject:
            s_list = self.index_sigma[model.subject]
        if purpose:
            z_list = self.index_z[model.purpose]
        return t_list, s_list, z_list

    def extend(self): pass


def log(func):
    def wrapper(self, tier, pragmatic, learnblock):
        ###################################################################
        self.logger.protocol("{:*^100}".format("Deconstruction"))
        ###################################################################
        return func(self, tier, pragmatic, learnblock)
    return wrapper


class Deconstructor:
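    """Deconstructs a freshly reconstructed pragmatic model against related
    models in the knowledge database, trying the strategies yielded by
    _strategies() in order."""
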
    TIME_COLUMN = "T"
    SUBJECT_COLUMN = "Sigma"
    PURPOSE_COLUMN = "Z"
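    # Injected at runtime: a callable mapping a tier to the next free model
    # counter (assumed to be the closure returned by
    # KnowledgeDatabase.model_counter()).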
    NEXT_MODEL_COUNTER = None

    def __init__(self,
                 knowledge_database: KnowledgeDatabase,
                 relative_finder: RelativeFinder,
                 source: DataSource,
                 settings: DeconstructionSettings):

        self.knowledge_database = knowledge_database
        self.relative_finder = relative_finder
        self.settings = settings
        self.source = source
        self.logger = None
        self.reconstructor = None

    def _strategies(self, block):
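        """Yield (index pair, strategy) tuples in the order they are tried;
        the T/Z strategy is currently disabled."""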
        yield (("complete", ), partial(self.deconstruct_complete, block=block))
        yield (("Sigma", "Z"), partial(self.deconstruct_sigma_zeta, block=block))
        yield (("T", "Sigma"), partial(self.deconstruct_time_sigma, block=block))
        # yield (("T", "Z"), partial(self.deconstruct_time_zeta, block=block))

    @log
    def deconstruct(self,
                    tier: int,
                    prag_model: PragmaticMachineLearningModel,
                    learnblock) -> None:
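        """Try every strategy against every relative of the pragmatic model.

        In "minimal" mode the first successful deconstruction ends the run;
        if no deconstruction succeeds, the pragmatic model itself is stored.
        """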

        success = False
        ###################################################################
        self.logger.protocol(self.settings)
        self.logger.protocol("{:<20}: {}".format("Pragmatic", str(prag_model)))
        ###################################################################
        for pair, strategy in self._strategies(learnblock):

            ###################################################################
            self.logger.protocol("{:<20}".format(str(pair)))
            ###################################################################

            for relative in self.relative_finder.find(pair, prag_model):

                ###############################################################
                self.logger.protocol(
                    "{:<20}: {}".format("Relative", str(relative)))
                ###############################################################

                try:
                    strategy(tier, prag_model, relative)
                    success = True
                    if self.settings.deconst_mode == "minimal":
                        return

                except NoModelReconstructedError:
                    continue

                except DeconstructionFailed:
                    continue

                except Exception:
                    # error.with_traceback() needs a traceback argument;
                    # print the full traceback instead
                    traceback.print_exc()

        if not success:
            # All deconstructions failed, so save the pragmatic model
            self.knowledge_database.insert(prag_model)

    def deconstruct_time_sigma(self,
                               tier: int,
                               p_model: PragmaticMachineLearningModel,
                               r_model: PragmaticMachineLearningModel,
                               block):
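        """T/Sigma strategy.

        When the relative overlaps in time and the two models' labels are only
        weakly reliable (or disagree systematically), build a learnblock from
        both models' Z outputs and queue it on TS_QUEUE for reconstruction one
        tier higher.
        """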
        success = False
        if r_model and p_model.tier < self.settings.highest_tier-1:

            second_block = r_model.trained_with(self.source)
            overlapping = second_block.new_block_from(block.get_column_values("T"))

            if overlapping.rows >= self.settings.learn_block_minimum:
                alpha = self.calculate_reliability(p_model.pre_image_labels,
                                                   r_model.pre_image_labels)
                alpha_systematic = alpha < 0
                alpha_weak_reliability = 0 <= alpha < self.settings.min_reliability
                if (self.settings.allow_weak_reliability and
                   alpha_weak_reliability) or alpha_systematic:
                    overlapping_b = block.new_block_from(overlapping.get_column_values("T"))
                    overblock = self.source.new_learnblock(
                        values=list(zip(
                            overlapping.get_column_values("Z"),
                            overlapping_b.get_column_values("Z"),
                            overlapping.get_column_values("T"),
                            ("\"\"" for _ in range(overlapping.rows)),
                            ("\"\"" for _ in range(overlapping.rows)))),
                        columns=(p_model.uid, r_model.uid, "T", "Sigma", "Z"),
                        index=[i for i in range(overlapping.rows)],
                        origin=[p_model.uid, r_model.uid]
                    )

                    # samples
                    # data = list(zip(over_block.get_column_values("Z"),
                    #                 over_block.get_column_values("T"),
                    #                 ["\"\"" for _ in range(over_block.rows)],
                    #                 ["\"\"" for _ in range(over_block.rows)]))
                    # feature = ".".join(["0", str(tier+1), "1"])
                    # columns = [feature, "T", "Sigma", "Z"]
                    # source = self.source.new_learnblock(
                    #     values=data, columns=columns, index=over_block.indexes,
                    #     origin=[p_model.uid, r_model.uid])
                    TS_QUEUE.append((tier+1, overblock))
                    success = True

        if not success:
            raise DeconstructionFailed()

    def deconstruct_time_zeta(self,
                              tier: int,
                              prag_model: PragmaticMachineLearningModel,
                              relative_model: PragmaticMachineLearningModel,
                              block):
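        """T/Z strategy (currently not yielded by _strategies).

        Store the pragmatic model, then fuse it with a time-overlapping
        relative and try to reconstruct a replacement for the relative.
        """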
        success = False
        self.knowledge_database.insert(prag_model)

        if relative_model:
            # Get learnblock that trained relative model
            second_block = relative_model.trained_with(self.source)

            # Get samples that have overlapping timestamp
            over_block = block.overlapping_rows(second_block, subset=["T"])
            if over_block.rows >= self.settings.learn_block_minimum:

                # Create new metadata for a pragmatic model
                new_model = prag_model.fusion(
                    relative_model, self.NEXT_MODEL_COUNTER(tier))

                # Which models should be used for the reconstruction
                which_ml_models = new_model.sigma

                # Get learningblock
                train_block = block.fusion(second_block)

                # Start the reconstruction
                try:
                    recon_model = self.reconstructor.reconstruct(
                        tier, train_block, which_ml_models, new_model)
                    self.knowledge_database.replace(relative_model, recon_model)
                    success = True
                except (NoModelReconstructedError, NotEnoughFeaturesWarning):
                    pass
                except ValueError:
                    traceback.print_exc()

                # alpha = self.calc_reliability(relative_model, prag_model, block)
                # if alpha >= self.settings.min_reliability:

        if not success:
            raise DeconstructionFailed()

    def deconstruct_sigma_zeta(self,
                               tier: int,
                               p_model: PragmaticMachineLearningModel,
                               r_model: PragmaticMachineLearningModel,
                               block):
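        """Sigma/Z strategy.

        Fuse the pragmatic model with a time-compatible relative over their
        shared features and reconstruct; on failure fall back to the
        configured deconst_mode.
        """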
        success = False

        if r_model and self.time_constraint(p_model, r_model, "SigmaZ"):

            # Get learnblock that trained relative model
            second_block = r_model.trained_with(self.source)

            overlapping_block = block.same_features_fusion(second_block)

            # Check constraint
            if overlapping_block.n_features >= 2:
                # Model fusion
                new_model = p_model.fusion(r_model, self.NEXT_MODEL_COUNTER(tier))
                which_ml_models = new_model.sigma
                try:
                    # Reconstruct model
                    recon_m = self.reconstructor.reconstruct(
                        tier, overlapping_block, which_ml_models, new_model)

                    self.knowledge_database.replace(r_model, recon_m)
                    success = True

                except (NoModelReconstructedError, NotEnoughFeaturesWarning):
                    if self.settings.deconst_mode == "conservative":
                        self.knowledge_database.remove(p_model)
                    elif self.settings.deconst_mode == "integrative":
                        self.knowledge_database.insert(p_model)
                    elif self.settings.deconst_mode == "oppurtunistic":
                        if block.rows > second_block.rows:
                            self.knowledge_database.remove(r_model)
                            self.knowledge_database.insert(p_model)
                        else:
                            self.knowledge_database.remove(p_model)
                            self.knowledge_database.insert(r_model)

        if not success:
            raise DeconstructionFailed()

    def time_constraint(self,
                        prag_model: PragmaticMachineLearningModel,
                        relative_model: PragmaticMachineLearningModel,
                        _type: str):
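        """Check whether the two models are temporally compatible.

        For "SigmaZ" the check depends on deconst_max_distance_t: 0 delegates
        to _overlapping_time, values in (0, 1) to _enclosing_time, and 1
        disables the constraint. "complete" uses its own timestamp comparison.
        """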

        if _type == "SigmaZ":
            if self.settings.deconst_max_distance_t == 0:
                return self._overlapping_time(prag_model, relative_model)

            elif 0 < self.settings.deconst_max_distance_t < 1:
                return self._enclosing_time(prag_model, relative_model)

            elif self.settings.deconst_max_distance_t == 1:
                return True
        elif _type == "complete":
            return ((prag_model.min_timestamp >= relative_model.max_timestamp
                     >= prag_model.max_timestamp)
                    or (prag_model.min_timestamp <= relative_model.min_timestamp
                        and prag_model.max_timestamp
                        >= relative_model.max_timestamp))

    def _overlapping_time(self,
                          prag_model: PragmaticMachineLearningModel,
                          relative_model: PragmaticMachineLearningModel):
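        # True when the two time spans do not overlap (one ends at or before
        # the point where the other starts).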

        return relative_model.min_timestamp >= prag_model.max_timestamp or \
            prag_model.min_timestamp >= relative_model.max_timestamp

    def _enclosing_time(self, prag_model, relative_model):
        m_dash_max_time = max(prag_model.max_timestamp,
                              relative_model.max_timestamp)
        m_dash_min_time = min(prag_model.min_timestamp,
                              relative_model.min_timestamp)
        relative_m_dash_time = self.settings.deconst_max_distance_t*(
                m_dash_max_time-m_dash_min_time)

        return (relative_model.min_timestamp - prag_model.max_timestamp) < \
            relative_m_dash_time or (prag_model.min_timestamp -
            relative_model.max_timestamp) < relative_m_dash_time

    def deconstruct_complete(self,
                             tier: int,
                             p_model: PragmaticMachineLearningModel,
                             r_model: PragmaticMachineLearningModel,
                             block):
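        """Complete strategy.

        Fuse the pragmatic model with a relative that shares at least two
        features or satisfies the "complete" time constraint, reconstruct on
        the fused training blocks and replace the relative; fall back to
        model_differentiation when that reconstruction fails.
        """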
        success = False

        try:
            # Check feature intersection constraint
            if r_model and self._feature_intersection(p_model, r_model) >= 2:
                new_model = p_model.fusion(
                    r_model, self.NEXT_MODEL_COUNTER(tier))

            # Check time constraint
            elif r_model and self.time_constraint(p_model, r_model, "complete"):

                # Create a sub-model from the T/Sigma relative samples
                second_block = r_model.trained_with(self.source)
                new_block = block.same_features_fusion(second_block)
                ts_relatives = self.source.time_sigma_relatives(new_block)
                which_ml_models = p_model.subject + r_model.subject
                self.reconstructor.reconstruct(
                    tier, ts_relatives, which_ml_models)
                new_model = p_model.fusion(r_model, self.NEXT_MODEL_COUNTER(tier))
            else:
                return

            # Create learnblock
            first_block = p_model.trained_with(self.source)
            second_block = r_model.trained_with(self.source)
            new_block = first_block.same_features_fusion(second_block)
            which_ml_models = new_model.sigma

            try:
                # Reconstruct model
                recon_model = self.reconstructor.reconstruct(
                    tier, new_block, which_ml_models, new_model)
                self.knowledge_database.remove(r_model)
                self.knowledge_database.insert(recon_model)
                success = True

            except (NoModelReconstructedError, NotEnoughFeaturesWarning):
                success = self.model_differentiation(tier, new_block, r_model)

        except (NoModelReconstructedError, NotEnoughFeaturesWarning):
            self.knowledge_database.remove(r_model)
            self.knowledge_database.insert(p_model)

        finally:
            if not success:
                raise DeconstructionFailed()

    def model_differentiation(self,
                              tier: int,
                              block,
                              relative_model: PragmaticMachineLearningModel):
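        """Split the block into time clusters and reconstruct one model per
        cluster; remove the relative model (and its dependent models) when
        nothing can be reconstructed.
        """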
        success = False

        time_column = block.get_column_values("T")
        density = self.source.estimate_density(time_column)
        self.source.remove_time_dense_relatives(block, density)
        clusters = self.source.cluster(block, density)
        for time_values in clusters:
            learnblock = block.new_block_from(time_values)
            try:
                reconstructed_model = self.reconstructor.reconstruct(
                    tier, learnblock)
                self.knowledge_database.insert(reconstructed_model)
                success = True
            except NoModelReconstructedError:
                self.knowledge_database.remove_dependent_models(relative_model)

        if not success:
            self.knowledge_database.remove(relative_model)

        return success

    def _feature_intersection(self,
                              prag_model: PragmaticMachineLearningModel,
                              relative_model: PragmaticMachineLearningModel):
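        """Number of feature columns shared by the two models' training
        blocks."""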
        first_block = prag_model.trained_with(self.source)
        second_block = relative_model.trained_with(self.source)
        return len(
            set(first_block.columns()).intersection(set(second_block.columns()))
        )

    def calculate_reliability(self, predicts_a, predicts_b):
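        """Krippendorff's alpha between the two label sequences: nominal level
        for conceptual knowledge, ratio level for procedural knowledge.
        """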
        predictions = [predicts_a, predicts_b]
        if self.reconstructor.category == "conceptual":
            return krippendorff.alpha(predictions, level_of_measurement="nominal")
        elif self.reconstructor.category == "procedural":
            return krippendorff.alpha(predictions, level_of_measurement="ratio")
        else:
            raise ValueError("unknown reconstructor category")
    #
    # def calc_reliability(self,
    #                      model_a: PragmaticMachineLearningModel,
    #                      model_b: PragmaticMachineLearningModel,
    #                      block):
    #     y_one = model_a.model.predict(block.as_numpy_array())
    #     y_two = model_b.model.predict(block.as_numpy_array())
    #     reliability_data = [y_one, y_two]
    #     if self.reconstructor.category == "conceptual":
    #         return krippendorff.alpha(reliability_data,
    #                                   level_of_measurement="nominal")
    #     elif self.reconstructor.category == "procedural":
    #         return krippendorff.alpha(reliability_data,
    #                                   level_of_measurement="ratio")
    #     else:
    #         raise ValueError()