diff --git a/cml/domain/data_source.py b/cml/domain/data_source.py index 966c944677fc9b543b25887db38b7ff52eba53ae..68cf58a8b02b3bfc7d3c16e3617f588ac5dad30c 100644 --- a/cml/domain/data_source.py +++ b/cml/domain/data_source.py @@ -2,6 +2,8 @@ from os.path import commonprefix from itertools import cycle from array import array +from cml.shared.parameter import PROTOCOL_LEVEL + __all__ = ( "DataSource", "Preprocessor", @@ -9,22 +11,62 @@ __all__ = ( ) +def log_learnblock_processing(func): + def wrapper(self): + for learnblock in func(self): + if self.logger and self.logger.level == PROTOCOL_LEVEL: + self.logger.protocol( + "\t".join([ + "", + str(learnblock.length), + str(learnblock.min_timestamp), + str(learnblock.max_timestamp), + str(learnblock.relatives), + ]) + ) + yield learnblock + return wrapper + + +def log_block_processing(func): + counter = 0 + + def wrapper(self): + nonlocal counter + counter += 1 + block = func(self) + if self.logger and self.logger.level == PROTOCOL_LEVEL: + self.logger.protocol( + "\t".join([ + str(counter), + str(block.length), + str(block.min_timestamp), + str(block.max_timestamp) + ]) + ) + return block + return wrapper + + class DataSource: _SIGNED_CHAR = 'b' def __init__(self, source, learnblock_identifier, *, block_size): self.block_size = block_size + self.logger = None self.__source = source self.__learnblock_identifier = learnblock_identifier self.__source_halde_flags = None @property + @log_learnblock_processing def learnblocks(self): for block in self: learnblock = self.__learnblock_identifier.identify(block) self._flip_source_halde_flags(learnblock.indexes) yield learnblock + @log_block_processing def __next__(self): return next(self.__next_helper) @@ -44,7 +86,6 @@ class DataSource: elif old_index > i: halde_runs += 1 old_index = i - counter = 0 yield self.__source.get_block_via_index(block_indexes) block_indexes.clear()