import functools

from cml.shared.parameter import PROTOCOL_LEVEL
from cml.shared.errors import NotReducableLearnblockWarning


# Public API of this module.
__all__ = ("FeatureSelector",)


def log_selection(func):
    def wrapper(self, learnblock):
        reduced_learnblock = func(self, learnblock)
        if self.logger and self.logger.level == PROTOCOL_LEVEL:
            self.logger.protocol(
                "\t".join([
                    "",
                    "[{:^5}{:^5}]".format(learnblock.rows, learnblock.cols),
                    str(learnblock.min_timestamp),
                    str(learnblock.max_timestamp),
                    "{:<20}".format(str(learnblock.origin)),
                    "{:<20}".format(str(learnblock.relatives)),
                    "{:<20}".format(str(learnblock.purpose))
                ])
            )
        return reduced_learnblock
    return wrapper


class FeatureSelector:
    """Reduce the feature columns of a learnblock.

    Each pass picks either the filter method or the embedded method. It
    trains that method on the learnblock and drops the columns the trained
    model reports. Passes repeat until the learnblock needs no further
    reduction.

    Collaborators, as used by this class:

    * ``filter_method.train(learnblock)`` returns a model whose
      ``reduce(learnblock)`` yields the columns to drop.
    * ``embedded_method.train(learnblock)`` returns a model whose
      ``reduce()`` yields the columns to drop.
    * ``settings`` provides ``max_features``, ``max_filter_x``,
      ``max_filter_y`` and ``max_model_reduction``.
    * ``logger`` starts as ``None``. When set, it must provide
      ``protocol(message)``.
    """

    def __init__(self, filter_method, embedded_method, settings):
        self.filter_method = filter_method
        self.embedded_method = embedded_method
        self.settings = settings
        # Not a constructor argument; presumably assigned by the caller after
        # construction. While it is None, protocol output is skipped.
        self.logger = None

    @log_selection
    def select(self, learnblock):
        """Reduce ``learnblock`` in place and return it.

        A pass runs while ``learn_cols`` exceeds ``settings.max_features``,
        or on every iteration while ``settings.max_model_reduction`` is
        truthy. Each pass drops the columns chosen by the filter method
        (see ``_filter_method_criteria``) or by the embedded method. The
        loop stops once no reduction is required, or once a pass removes
        no column.

        Returns:
            The same ``learnblock`` object, with columns dropped via
            ``drop_columns_by_index``.

        Raises:
            NotReducableLearnblockWarning: a pass removed no column while
                ``learn_cols`` still exceeds ``settings.max_features``.
        """
        self._protocol("{:*^100}".format("Feature Selection"))
        self._protocol(str(self.settings))
        while True:
            if not (self._to_many_features(learnblock.learn_cols)
                    or self.settings.max_model_reduction):
                return learnblock

            past_feature_number = learnblock.learn_cols

            if self._filter_method_criteria(learnblock.learn_cols,
                                            learnblock.rows):
                remove_features = self.filtering(learnblock)
                learnblock.drop_columns_by_index(remove_features)
                self._protocol("{:<20}".format("Filter method"))
            else:
                self._protocol("{:<20}".format("Embedded method"))
                remove_features = self.embedding(learnblock)
                learnblock.drop_columns_by_index(remove_features)

            self._protocol(
                "{:<20}: {}".format("Removed", len(remove_features)))

            # No progress in this pass: either we are done, or the block
            # cannot be brought under the feature limit.
            if learnblock.learn_cols == past_feature_number:
                if self._to_many_features(learnblock.learn_cols):
                    raise NotReducableLearnblockWarning()
                return learnblock

    def filtering(self, learnblock):
        """Train the filter method and return the columns it wants dropped."""
        trained_model = self.filter_method.train(learnblock)
        return trained_model.reduce(learnblock)

    def embedding(self, learnblock):
        """Train the embedded method and return the columns it wants dropped."""
        trained_model = self.embedded_method.train(learnblock)
        return trained_model.reduce()

    def _protocol(self, message):
        """Forward ``message`` to ``logger.protocol``; no-op without a logger."""
        if self.logger is not None:
            self.logger.protocol(message)

    def _to_many_features(self, learn_cols):
        """Return True when ``learn_cols`` exceeds ``settings.max_features``."""
        return learn_cols > self.settings.max_features

    def _filter_method_criteria(self, learn_cols, samples_count):
        """Return True when the filter method should be used for this pass.

        The filter method is chosen when either the column count exceeds
        ``settings.max_filter_x`` or the sample count exceeds
        ``settings.max_filter_y``; otherwise the embedded method is used.
        """
        return learn_cols > self.settings.max_filter_x or \
            samples_count > self.settings.max_filter_y