from random import sample
from collections import defaultdict
from dataclasses import dataclass
from functools import partial

import krippendorff

from cml.shared.errors import NoModelReconstructedError


# Public API of this module.
__all__ = ("Reconstructor",)


@dataclass
class Metadata:
    """Descriptive metadata attached to a pragmatic machine learning model."""

    # Knowledge domain code (e.g. "C").
    knowledge_domain: str
    # Knowledge tier the model belongs to.
    knowledge_tier: int
    # Identifier of the model within its domain and tier.
    identifier: int
    # Pre image of the model.
    pre_image: list
    # Smallest timestamp covered by the model.
    t_min: int
    # Largest timestamp covered by the model.
    t_max: int
    # Subjects of the model.
    sigma: list
    # Purposes of the model.
    zeta: list

    def __str__(self):
        return (f"Knowledge domain: <{self.knowledge_domain}> "
                f"Knowledge tier: <{self.knowledge_tier}> "
                f"Identifier: <{self.identifier}> "
                f"Pre image: <{self.pre_image}> "
                f"T min: <{self.t_min}> "
                f"T max: <{self.t_max}> "
                f"Subjects: <{self.sigma}> "
                f"Purposes: <{self.zeta}>")


class PragmaticMachineLearningModel:
    """A trained ML model bundled with its pragmatic metadata.

    Identity is defined by :attr:`uid`
    (``"<knowledge_domain>.<knowledge_tier>.<identifier>"``): instances with
    the same uid hash equally and compare equal.
    """

    def __init__(self, meta, model, learnblock):
        self.meta = meta
        self.model = model
        # Captured from the learnblock the model was reconstructed from.
        self.domain_size = learnblock.n_features
        self.domain = learnblock.indexes

    def __hash__(self):
        return hash(self.uid)

    def __eq__(self, other):
        # Compare the uids themselves (not their hashes) so that a hash
        # collision can never make two distinct models equal. For foreign
        # types return NotImplemented, letting Python fall back to its
        # default comparison instead of raising on `model == None` or
        # `model in mixed_list`.
        if isinstance(other, PragmaticMachineLearningModel):
            return self.uid == other.uid
        return NotImplemented

    @property
    def tier(self):
        """Knowledge tier taken from the metadata."""
        return self.meta.knowledge_tier

    @property
    def min_timestamp(self):
        """Smallest timestamp (``meta.t_min``)."""
        return self.meta.t_min

    @property
    def max_timestamp(self):
        """Largest timestamp (``meta.t_max``)."""
        return self.meta.t_max

    @property
    def pre_image(self):
        """Pre image stored in the metadata."""
        return self.meta.pre_image

    @property
    def subject(self):
        """Subjects (``meta.sigma``)."""
        return self.meta.sigma

    @property
    def purpose(self):
        """Purposes (``meta.zeta``)."""
        return self.meta.zeta

    @property
    def uid(self):
        """Unique id ``"<knowledge_domain>.<knowledge_tier>.<identifier>"``."""
        return ".".join([self.meta.knowledge_domain,
                         str(self.meta.knowledge_tier),
                         str(self.meta.identifier)])

    @property
    def sample_times(self):
        """Placeholder: not implemented, currently always ``None``."""
        return None

    def fusion(self, prag_model):
        """Placeholder: not implemented, currently always returns ``None``."""
        return None


class Reconstructor:
    """Reconstruct pragmatic models from a learnblock.

    Each configured ML model is trained on a random part of the learnblock.
    A model becomes a candidate when it satisfies the accuracy/error
    constraints and its Krippendorff's alpha on the held-out part reaches
    ``settings.min_reliability``. :func:`determine_winner` then picks the
    result among the candidates.

    :attr:`category` (``"conceptual"`` or ``"procedural"``) must be set
    before :meth:`reconstruct` is called. :attr:`free_id` must supply
    identifiers for the created models.
    """

    def __init__(self, settings, ml_models, knowlege_domain):
        self.logger = None
        self.settings = settings
        self.ml_models = ml_models
        # NOTE(review): stored but not read anywhere in this class.
        self.knowledge_domain = knowlege_domain
        self._category = None
        self._free_id = None
        # Reconstruction strategy bound by the `category` setter.
        self.__reconstruction = None

    def reconstruct(self, learnblock, which_models=None, meta=None):
        """Reconstruct the best pragmatic model for *learnblock*.

        :param learnblock: block of training data.
        :param which_models: abbreviations of the ML models to try. All
            configured models are tried when this is empty or ``None``.
        :param meta: accepted for interface compatibility; currently unused.
        :return: the model chosen by :func:`determine_winner`.
        :raises RuntimeError: if :attr:`category` has not been set.
        :raises NoModelReconstructedError: if no model satisfied the
            constraints and the minimum reliability.
        """
        if self.__reconstruction is None:
            raise RuntimeError(
                "Reconstructor.category must be set before reconstructing")
        if not which_models:
            which_models = [m.abbreviation for m in self.ml_models]

        # `meta` must go by keyword: the strategy partial already binds
        # `krippen` by keyword, so a third positional argument would
        # collide with it.
        reliabilities_to_model = self.__reconstruction(learnblock,
                                                       which_models,
                                                       meta=meta)
        if reliabilities_to_model:
            return determine_winner(reliabilities_to_model)
        raise NoModelReconstructedError()

    @property
    def category(self):
        """Reconstruction category: ``"conceptual"``, ``"procedural"`` or None."""
        return self._category

    @category.setter
    def category(self, value):
        if value == "conceptual":
            reconstruction = partial(self._reconstruct_conceptual,
                                     krippen="nominal")
        elif value == "procedural":
            reconstruction = partial(self._reconstruct_procedural,
                                     krippen="ratio")
        else:
            raise ValueError(f"Unknown reconstruction category: {value!r}; "
                             "expected 'conceptual' or 'procedural'")
        self.__reconstruction = reconstruction
        self._category = value

    @property
    def free_id(self):
        """Iterator yielding identifiers for newly reconstructed models."""
        return self._free_id

    @free_id.setter
    def free_id(self, value):
        self._free_id = iter(value)

    def _reconstruct_conceptual(self,
                                learnblock,
                                which_models=None,
                                krippen=None,
                                meta=None):
        """Reconstruct conceptual models (domain code ``"C"``).

        :param which_models: abbreviations to try (``None`` means all).
        :param krippen: Krippendorff level of measurement.
        :param meta: currently unused.
        :return: ``defaultdict(list)`` of reliability -> candidate models.
        """
        return self._reconstruct(learnblock, which_models, "conceptual",
                                 "C", krippen)

    def _reconstruct_procedural(self,
                                learnblock,
                                which_models=None,
                                krippen=None,
                                meta=None):
        """Reconstruct procedural models (domain code ``"P"``).

        :param which_models: abbreviations to try (``None`` means all).
        :param krippen: Krippendorff level of measurement.
        :param meta: currently unused.
        :return: ``defaultdict(list)`` of reliability -> candidate models.
        """
        # NOTE(review): "P" mirrors the hard-coded "C" of the conceptual
        # path; confirm this is the intended procedural domain code.
        return self._reconstruct(learnblock, which_models, "procedural",
                                 "P", krippen)

    def _reconstruct(self, learnblock, which_models, category, domain,
                     krippen):
        """Train the selected models and group the acceptable ones.

        A trained model is kept when it passes :meth:`_valid_reconstructed`
        for *category* and its reliability on the held-out evaluation block
        is at least ``settings.min_reliability``.

        :return: ``defaultdict(list)`` mapping reliability -> list of
            :class:`PragmaticMachineLearningModel`.
        """
        reliability_to_model = defaultdict(list)
        for model in self.ml_models:
            if which_models is not None and \
                    model.abbreviation not in which_models:
                continue

            # Train on one part of the block; keep the rest for reliability.
            train_block, eval_block = self.split(learnblock)
            trained_model = model.train(
                train_block.as_numpy_array(),
                list(train_block.get_column_values("Z")))

            if not self._valid_reconstructed(trained_model, category):
                continue

            # Measure reliability on data the model was NOT trained on.
            reliability = self.calc_reliability(trained_model,
                                                eval_block,
                                                krippen)
            # Written as `not >=` so that a NaN alpha is rejected.
            if not reliability >= self.settings.min_reliability:
                continue

            # TODO (dmt): Fix the knowledge tier after first iteration!
            tier = 1
            prag_meta_data = Metadata(
                domain,
                tier,
                next(self.free_id),
                learnblock.indexes,
                learnblock.min_timestamp,
                learnblock.max_timestamp,
                [model.subject],
                [".".join([domain, str(tier), learnblock.purpose])],
            )
            reliability_to_model[reliability].append(
                PragmaticMachineLearningModel(prag_meta_data,
                                              trained_model,
                                              learnblock))
        return reliability_to_model

    def split(self, learnblock):
        """Randomly split *learnblock* into a training and evaluation block.

        The evaluation part holds
        ``int(learnblock.length * settings.reliability_sample)`` rows drawn
        with :func:`random.sample`. The training part holds the remaining
        rows in their original order.

        :return: ``(train_block, eval_block)``
        """
        # random.sample requires a sequence (sets and numpy arrays are
        # rejected on Python 3.11+), so materialize a list first.
        indices = list(learnblock.indexes)
        eval_size = int(learnblock.length * self.settings.reliability_sample)
        eval_idx = sample(indices, eval_size)
        # Keep block order (a set difference would not), so training input
        # is reproducible for a given random seed.
        eval_set = set(eval_idx)
        train_idx = [i for i in indices if i not in eval_set]
        return (learnblock.new_block_from_rows_index(train_idx),
                learnblock.new_block_from_rows_index(eval_idx))

    def calc_reliability(self, trained_model, eval_block, metric):
        """Return Krippendorff's alpha between predictions and true values.

        :param trained_model: model exposing ``predict``.
        :param eval_block: block whose rows are predicted and whose ``"Z"``
            column holds the true values.
        :param metric: ``level_of_measurement`` for ``krippendorff.alpha``
            (e.g. ``"nominal"`` or ``"ratio"``).
        """
        y_pre = trained_model.predict(eval_block.as_numpy_array())
        y_true = list(eval_block.get_column_values("Z"))
        # Two "raters" (model and ground truth) rating the same units.
        reliability_data = [y_pre, y_true]
        return krippendorff.alpha(reliability_data,
                                  level_of_measurement=metric)

    def _valid_reconstructed(self, model, knowledge_domain):
        """Check a trained model against the configured test constraints.

        ``"conceptual"`` requires ``accuracy >= settings.min_test_accuracy``.
        Any other value applies the mean/max error limits.
        """
        if knowledge_domain == "conceptual":
            return model.accuracy >= self.settings.min_test_accuracy
        return (model.mean_error <= self.settings.max_test_error_avg
                and model.max_error <= self.settings.max_test_error_max)


def determine_winner(reliability_to_model):
    """Pick the best model from a mapping of reliability -> candidate models.

    Only the candidates with the HIGHEST reliability are considered. Among
    them, the model with the smallest ``domain_size`` wins; on ties, the
    first one in the list wins.

    :param reliability_to_model: mapping of reliability -> list of models,
        each exposing ``domain_size``.
    :return: the winning model, or ``None`` if the mapping is empty or its
        best entry holds no models.
    """
    if not reliability_to_model:
        return None
    best_reliability = max(reliability_to_model)
    candidates = reliability_to_model[best_reliability]
    # min() returns the first minimal element, matching a strict `<` scan.
    return min(candidates, key=lambda model: model.domain_size, default=None)