import sys
from copy import deepcopy
from functools import partial, wraps
from multiprocessing import Process, Queue
from queue import Empty

from cml.shared.parameter import PROTOCOL_LEVEL


# Names exported by a star-import of this module; only Constructor is part of
# that public surface.
__all__ = (
    "Constructor",
)


def log_construction(func):
    """Wrap a construction generator method with protocol logging.

    The decorated method must have the signature ``(self, learnblock)`` and
    yield constructed learnblocks; every block is re-yielded unchanged.
    When ``self.logger`` is set, a "Construction" banner is written once
    iteration starts and, if ``self.logger.level == PROTOCOL_LEVEL``, a
    one-line summary is written for every result before it is yielded.

    Args:
        func: generator method taking ``(self, learnblock)``.

    Returns:
        A generator method carrying *func*'s name and docstring.
    """
    @wraps(func)
    def wrapper(self, learnblock):
        # The logger is optional (it may be None), so guard the banner the
        # same way the per-result summary below is guarded.
        if self.logger:
            self.logger.protocol("{:*^100}".format("Construction"))
        for complete_learnblock in func(self, learnblock):
            if self.logger and self.logger.level == PROTOCOL_LEVEL:
                self.logger.protocol("{:<20} ==>".format("Result"))
                # Shape, origin and relatives are read from the input block;
                # timestamps and purpose from the constructed block.
                self.logger.protocol(
                    "\t".join([
                        "",
                        "[{:^5}{:^5}]".format(learnblock.rows, learnblock.cols),
                        str(complete_learnblock.min_timestamp),
                        str(complete_learnblock.max_timestamp),
                        "{:<20}".format(str(learnblock.origin)),
                        "{:<20}".format(str(learnblock.relatives)),
                        "{:<20}".format(str(complete_learnblock.purpose)),
                    ]))

            yield complete_learnblock
    return wrapper


def update_construction(func):
    """Wrap a setter so the construction callable is rebuilt after it runs.

    The decorated function takes ``(self, value)``. After it stores the
    value, ``self._construction`` is rebuilt via
    ``self._init_construction(self.mode)``, so a callable that captured
    settings values when it was built picks up the new setting.

    Args:
        func: setter function taking ``(self, value)``.

    Returns:
        The wrapped setter (returns None, as property setters do).
    """
    @wraps(func)
    def wrapper(self, value):
        func(self, value)
        # Rebuild from the current mode; ``construction_type`` (used here
        # previously) is not defined anywhere, so the old self-assignment
        # raised AttributeError on every call.
        self._construction = self._init_construction(self.mode)
    return wrapper


class BackgroundConstructor(Process):
    """Worker process that labels learnblocks received on a queue.

    Each learnblock read from ``input_queue`` is either forwarded once,
    unchanged (when it is already labeled), or trained with every model in
    ``models`` at every cluster count from 2 to ``settings.max_categories``.
    A labeled result is put on ``output_queue`` only when every cluster has
    at least ``settings.min_category_size`` members. The worker returns
    (ending the process) once no learnblock arrives within ``idle_timeout``
    seconds.

    ``logger`` is stored but not used by the worker itself.
    """

    def __init__(self, models, logger, settings, input_queue, output_queue,
                 idle_timeout=30):
        super(BackgroundConstructor, self).__init__(name="BackgroundConstructor")
        self.models = models
        self.logger = logger
        self.settings = settings
        self.input_queue = input_queue
        self.output_queue = output_queue
        # Seconds to wait for new work before the worker shuts down.
        self.idle_timeout = idle_timeout

    def run(self):
        """Process learnblocks until the input queue stays idle too long."""
        while True:
            try:
                learnblock = self.input_queue.get(timeout=self.idle_timeout)
            except Empty:
                # No work within the idle timeout: let the process finish.
                return

            if learnblock.labeled:
                # Nothing to construct; forward it exactly once.
                self.output_queue.put(learnblock)
                continue

            # The third element of prepare_args is the cluster count, not a
            # size threshold; the threshold comes from the settings.
            min_category_size = self.settings.min_category_size
            for block, model, _complexity in self.prepare_args(learnblock):
                labeled_learnblock = self._label(block, model, min_category_size)
                if labeled_learnblock is not None:
                    self.output_queue.put(labeled_learnblock)

    @staticmethod
    def _label(learnblock, model, min_category_size):
        """Train *model* on *learnblock*; return the labeled block or None.

        None is returned when any cluster has fewer than
        *min_category_size* members.
        """
        trained_model = model.train(learnblock)
        if any(size < min_category_size
               for size in trained_model.cluster_sizes.values()):
            return None
        labels = trained_model.get_labels()
        labeled_learnblock = learnblock.set_labels(labels)
        labeled_learnblock.n_cluster = model.cluster
        labeled_learnblock.purpose = "{}{:02}".format(model.abbreviation,
                                                      model.cluster)
        return labeled_learnblock

    def prepare_args(self, learnblock):
        """Yield ``(learnblock copy, model copy, cluster count)`` triples.

        One triple per model in ``self.models`` and per cluster count from 2
        to ``settings.max_categories``. ``cluster`` is set on the model in
        ``self.models`` before it is copied.
        """
        for ml_model in self.models:
            for complexity in range(2, self.settings.max_categories+1):
                ml_model.cluster = complexity
                yield deepcopy(learnblock), deepcopy(ml_model), complexity


class Constructor:
    """Turn unlabeled learnblocks into labeled ones.

    In ``"conceptual"`` mode every model in ``ml_models`` is trained at each
    cluster count from 2 to ``settings.max_categories``. Results whose
    clusters all have at least ``settings.min_category_size`` members are
    yielded.

    In any other mode, procedural construction is used, bounded by
    ``settings.max_model_targets`` and ``settings.max_target_error``.

    Already-labeled learnblocks are passed through unchanged. Settings are
    read each time a construction runs, so changes made through the
    properties below take effect for the next call.
    """

    def __init__(self, mode, ml_models, settings):
        self.mode = mode
        self.settings = settings
        self.ml_models = ml_models
        self._construction = self._init_construction(mode)
        # Optional; when set it must provide ``protocol(message)`` and ``level``.
        self.logger = None
        self.__process = None
        self.__input_queue = None
        self.__output_queue = None

    @log_construction
    def construct(self, learnblock):
        """Yield every block constructed from *learnblock*, tagged with its origin."""
        for block in self._construction(learnblock):
            block.origin = learnblock.origin
            yield block

    @property
    def max_target_error(self):
        return self.settings.max_target_error

    @max_target_error.setter
    @update_construction
    def max_target_error(self, value):
        self.settings.max_target_error = value

    @property
    def max_model_targets(self):
        return self.settings.max_model_targets

    @max_model_targets.setter
    @update_construction
    def max_model_targets(self, value):
        self.settings.max_model_targets = value

    @property
    def min_category_size(self):
        return self.settings.min_category_size

    @min_category_size.setter
    @update_construction
    def min_category_size(self, value):
        self.settings.min_category_size = value

    @property
    def max_categories(self):
        return self.settings.max_categories

    @max_categories.setter
    def max_categories(self, value):
        # No rebuild needed: the construction methods read this setting
        # from self.settings each time they run.
        self.settings.max_categories = value

    def prepare_args(self, learnblock):
        """Yield ``(learnblock copy, model copy, cluster count)`` triples.

        One triple per model and per cluster count from 2 to
        ``settings.max_categories``; ``cluster`` is set on the original
        model before it is copied.
        """
        for ml_model in self.ml_models:
            for complexity in range(2, self.settings.max_categories+1):
                ml_model.cluster = complexity
                yield deepcopy(learnblock), deepcopy(ml_model), complexity

    def prepare_background_process(self):
        """Start a BackgroundConstructor with fresh queues unless one is alive."""
        if not self.__process or not self.__process.is_alive():
            self.__input_queue = Queue()
            self.__output_queue = Queue()
            background = BackgroundConstructor(self.ml_models,
                                               self.logger,
                                               self.settings,
                                               self.__input_queue,
                                               self.__output_queue)
            background.name = "BackgroundConstructor"
            background.daemon = True
            background.start()
            self.__process = background

    def construct_parallel(self, learnblock):
        """Construct *learnblock* in the background process and yield results.

        NOTE(review): this waits for one result, then yields only what is
        already queued. If the worker produces no result, it blocks forever.
        Results produced after the queue momentarily empties are yielded by
        the next call instead. A proper fix needs an end-of-batch signal
        from BackgroundConstructor.
        """
        self.prepare_background_process()
        self.__input_queue.put(learnblock)
        yield self.__output_queue.get()
        while not self.__output_queue.empty():
            yield self.__output_queue.get()

    def _init_construction(self, mode):
        """Return the construction method for *mode*.

        Limits left as ``None`` are resolved from ``self.settings`` when the
        method runs, so later settings changes are honoured without
        rebuilding this callable.
        """
        if mode == "conceptual":
            return self._construct_conceptual_knowledge
        return self._construct_procedural_knowledge

    def _protocol(self, message):
        """Write *message* via ``logger.protocol`` if a logger is attached."""
        if self.logger:
            self.logger.protocol(message)

    def _construct_conceptual_knowledge(self,
                                        learnblock,
                                        categorial_complexity=None,
                                        min_category_size=None):
        """Yield labeled copies of *learnblock* produced by clustering.

        Args:
            learnblock: block to label; yielded as-is if already labeled.
            categorial_complexity: highest cluster count to try; defaults
                to ``settings.max_categories``.
            min_category_size: minimum members per cluster; defaults to
                ``settings.min_category_size``. A result is discarded as
                soon as one cluster is smaller.
        """
        if learnblock.labeled:
            yield learnblock
            return
        if categorial_complexity is None:
            categorial_complexity = self.settings.max_categories
        if min_category_size is None:
            min_category_size = self.settings.min_category_size

        for ml_model in self.ml_models:
            for cluster_number in range(2, categorial_complexity + 1):
                self._protocol("{:<20}".format(ml_model.subject))
                self._protocol(
                    "{:<20}: {}".format("# Cluster", str(cluster_number)))
                self._protocol(
                    "{:<20}: {}".format("MinCategorysize", min_category_size))

                ml_model.cluster = cluster_number
                trained_model = ml_model.train(learnblock)
                for cluster, size in trained_model.cluster_sizes.items():
                    self._protocol(
                        "{:<20}: {}".format("Cluster Nr. {}".format(cluster),
                                            size))
                    if size < min_category_size:
                        self._protocol("{:-^100}".format("Discarded"))
                        break
                else:
                    # Every cluster was large enough: keep this labeling.
                    labels = trained_model.get_labels()
                    labeled_learnblock = learnblock.set_labels(labels)
                    labeled_learnblock.n_cluster = cluster_number
                    labeled_learnblock.purpose = "{}{:02}".format(
                        ml_model.abbreviation, cluster_number)
                    yield labeled_learnblock
                    self._protocol("{:*^100}".format("Construction"))

    def _construct_procedural_knowledge(self,
                                        learnblock,
                                        procedural_complexity=None,
                                        max_target_error=None):
        """Yield labeled copies of *learnblock* from procedural models.

        Each entry of ``ml_models`` is called to create a fresh model
        (presumably a class or factory; confirm against callers). The model
        is trained for target counts 2..procedural_complexity. Results whose
        ``target_error`` is below ``max_target_error`` yield one labeled
        block per target.

        Args:
            learnblock: block to label; yielded as-is if already labeled.
            procedural_complexity: highest target count; defaults to
                ``settings.max_model_targets``.
            max_target_error: exclusive error bound; defaults to
                ``settings.max_target_error``.
        """
        if learnblock.labeled:
            yield learnblock
            return
        if procedural_complexity is None:
            procedural_complexity = self.settings.max_model_targets
        if max_target_error is None:
            max_target_error = self.settings.max_target_error

        for ml_model in self.ml_models:
            for target_number in range(2, procedural_complexity + 1):
                model = ml_model()
                model.io_shape = learnblock.learn_cols
                model.target_number = target_number
                trained_model = model.train(learnblock.as_numpy_array())
                if trained_model.target_error < max_target_error:
                    for labels in trained_model.targets:
                        labeled_learnblock = learnblock.set_labels(list(labels))
                        labeled_learnblock.n_cluster = target_number
                        yield labeled_learnblock

# def construct_conceptual_knowledge(*args):
#     learnblock, model, min_category_size = args[0]
#     if learnblock.labeled:
#         return learnblock
#
#     else:
#         trained_model = model.train(learnblock)
#         for cluster, size in trained_model.cluster_sizes.items():
#             if size < min_category_size:
#                 return
#         else:
#             labels = trained_model.get_labels()
#             labeled_learnblock = learnblock.set_labels(labels)
#             labeled_learnblock.n_cluster = model.cluster
#             purpose = "{}{:02}".format(model.abbreviation, model.cluster)
#             labeled_learnblock.purpose = purpose
#             return labeled_learnblock