from os import makedirs, mkdir
from os.path import expanduser, join
from datetime import datetime
from logging import getLogger, FileHandler, StreamHandler

from cml.shared.errors import (
    NoModelReconstructedError,
    NotReducableLearnblockWarning,
    ModiError
)
from cml.shared.parameter import PROTOCOL_LEVEL
from cml.usecases.usecase import Usecase
from cml.domain.knowledge import TS_QUEUE


class KnowledgeSearchUsecase(Usecase):
    """Run a tiered knowledge search over the components of a request.

    Tier 1 processes every learnblock of the source. Higher tiers are fed
    from the shared ``TS_QUEUE`` until it is drained.
    """

    # Handlers attached to the shared "iterationLogger" by the most recent
    # init_loggers call. getLogger() hands every instance the same logger
    # object, so this bookkeeping lives on the class, not the instance.
    _log_handlers = []

    def __init__(self):
        super().__init__()

    def process(self, request):
        """Execute the knowledge search described by *request*.

        Returns:
            ``request.deconstructor.knowledge_database`` after all tiers ran.

        Raises:
            ModiError: if constructor, reconstructor and deconstructor do not
                share the same ``mode``.
        """
        if not self.valid_modi(request):
            raise ModiError()
        source, constructor, selector, reconstructor, deconstructor = \
            self.init_loggers(request)
        self.init_components(reconstructor, deconstructor)
        self.run(source, constructor, selector, reconstructor, deconstructor,
                 request.parallel)
        return deconstructor.knowledge_database

    def run(self, source, constructor, selector, reconstructor, deconstructor,
            parallel: bool):
        """Process tier 1 from *source*, then every queued higher-tier block.

        Tier-1 learnblocks go through ``selector.select`` before being
        reconstructed. Learnblocks popped from ``TS_QUEUE`` are reconstructed
        directly. When *parallel* is true, the ``*_parallel`` variants of
        construct/reconstruct are used.
        """
        construct = (constructor.construct_parallel if parallel
                     else constructor.construct)
        reconstruct = (reconstructor.reconstruct_parallel if parallel
                       else reconstructor.reconstruct)
        tier = 1

        deconstructor.logger.protocol("{:#^100}".format("TIER 1"))
        for learnblock in source.learnblocks:
            for complete_learnblock in construct(learnblock):
                try:
                    reduced = selector.select(complete_learnblock)
                    pragmatic = reconstruct(tier, reduced)
                    deconstructor.deconstruct(tier, pragmatic, reduced)
                except (NotReducableLearnblockWarning,
                        NoModelReconstructedError):
                    # Expected per-block outcomes: skip and keep searching.
                    continue

        while True:
            # Only the pop may signal "queue empty". An IndexError raised by
            # construct/reconstruct/deconstruct is a real error and must not
            # silently end the search.
            try:
                new_tier, learnblock = TS_QUEUE.pop()
            except IndexError:  # deque is empty
                break

            # NOTE(review): a popped tier lower than the current one is
            # processed at the current tier; this presumes TS_QUEUE yields
            # non-decreasing tiers - confirm.
            if new_tier > tier:
                banner = "TIER {}".format(new_tier)
                deconstructor.logger.protocol("{:#^100}".format(banner))
                tier = new_tier

            for complete_learnblock in construct(learnblock):
                try:
                    pragmatic = reconstruct(tier, complete_learnblock)
                    deconstructor.deconstruct(tier, pragmatic,
                                              complete_learnblock)
                except (NotReducableLearnblockWarning,
                        NoModelReconstructedError):
                    continue

    def valid_modi(self, request):
        """Return True when constructor, reconstructor and deconstructor
        all have the same ``mode``."""
        return request.constructor.mode == request.reconstructor.mode == \
               request.deconstructor.mode

    def init_loggers(self, request):
        """Attach one shared "iterationLogger" to every pipeline component.

        Messages at ``PROTOCOL_LEVEL`` and above go to ``iteration.log`` in a
        fresh directory from ``create_log_dir()``, and also to stderr (the
        ``StreamHandler`` default) when ``request.stdout`` is set. Handlers
        left over from a previous call are detached and closed first.

        Returns:
            ``(source, constructor, feature_selector, reconstructor,
            deconstructor)`` taken from *request*.
        """
        source = request.deconstructor.source
        log_dir = create_log_dir()
        iteration_logger = getLogger("iterationLogger")

        # getLogger() returns a process-wide singleton. Without this cleanup,
        # every process() call would stack another handler: old log files
        # stay open and keep receiving messages, and console output repeats.
        for stale_handler in KnowledgeSearchUsecase._log_handlers:
            iteration_logger.removeHandler(stale_handler)
            stale_handler.close()
        KnowledgeSearchUsecase._log_handlers = []

        iteration_handler = FileHandler(join(log_dir, "iteration.log"))
        iteration_handler.setLevel(PROTOCOL_LEVEL)
        iteration_logger.addHandler(iteration_handler)
        KnowledgeSearchUsecase._log_handlers.append(iteration_handler)

        source.logger = iteration_logger
        request.constructor.logger = iteration_logger
        request.feature_selector.logger = iteration_logger
        request.reconstructor.logger = iteration_logger
        request.deconstructor.logger = iteration_logger
        request.deconstructor.knowledge_database.inject_logger(iteration_logger)

        if request.stdout:
            iteration_stdout_handler = StreamHandler()
            iteration_stdout_handler.setLevel(PROTOCOL_LEVEL)
            iteration_logger.addHandler(iteration_stdout_handler)
            KnowledgeSearchUsecase._log_handlers.append(
                iteration_stdout_handler)

        return (source,
                request.constructor,
                request.feature_selector,
                request.reconstructor,
                request.deconstructor)

    def init_components(self, reconstructor, deconstructor):
        """Wire reconstructor and deconstructor together.

        The reconstructor receives the deconstructor's
        ``next_model_counter`` (presumably so both number models from one
        counter - confirm). The deconstructor gets a reference back to the
        reconstructor.
        """
        reconstructor.next_model_counter = deconstructor.next_model_counter
        deconstructor.reconstructor = reconstructor


def create_log_dir(root=None):
    """Create and return a timestamped directory for a run's log files.

    Args:
        root: Directory in which the run directory is created. Defaults to
            ``~/.cml``. Missing parent directories are created as needed.

    Returns:
        The path ``<root>/<ddmmYYYYHHMMSS>``.
    """
    if root is None:
        root = join(expanduser('~'), ".cml")
    log_dir = join(root, datetime.now().strftime("%d%m%Y%H%M%S"))
    # makedirs rather than mkdir: on a first run ~/.cml does not exist yet,
    # and two runs started within the same second share a timestamp (the
    # directory is then reused) - neither case should crash the run.
    makedirs(log_dir, exist_ok=True)
    return log_dir