from os.path import commonprefix
from itertools import cycle
from array import array

# NOTE(review): only the first entry of ``__all__`` was visible in the patch
# hunk this block was reconstructed from -- re-add any further exports that
# the module declares.
__all__ = (
    "DataSource",
)


class DataSource:
    """Iterate over a source in blocks of not-yet-consumed row indexes.

    Every index of the source carries a flag: ``0`` means the index is still
    available and will be served again, ``1`` means it has been consumed by a
    learnblock and is skipped from then on.  The source is walked pass after
    pass; each pass serves the indexes that are still flagged ``0`` in blocks
    of at most ``block_size`` indexes.  Iteration ends after a pass in which
    no available index was found.

    Args:
        source: Object exposing ``length`` and
            ``get_block_via_index(indexes)``.
        learnblock_identifier: Object exposing ``identify(block)``; the
            returned learnblock must expose ``indexes``.
        block_size: Maximum number of indexes per served block
            (keyword-only, must be at least 1 when iteration starts).
    """

    # ``array`` type code used for the per-index flags (one signed byte each).
    _SIGNED_CHAR = 'b'

    def __init__(self, source, learnblock_identifier, *, block_size):
        self.block_size = block_size
        self.__source = source
        self.__learnblock_identifier = learnblock_identifier
        # Both are (re)created by ``__iter__``.
        self.__source_halde_flags = None
        self.__next_helper = None

    @property
    def learnblocks(self):
        """Yield one learnblock per served block.

        The indexes reported by each learnblock are flagged as consumed
        before the learnblock is handed out, so later blocks never contain
        them again.
        """
        for block in self:
            learnblock = self.__learnblock_identifier.identify(block)
            self._flip_source_halde_flags(learnblock.indexes)
            yield learnblock

    def __next__(self):
        """Return the next block; raise ``StopIteration`` when none is left."""
        if self.__next_helper is None:
            # Allow ``next(data_source)`` without a prior ``iter()`` call.
            iter(self)
        return next(self.__next_helper)

    def _next_helper(self):
        """Generate blocks from the indexes that are still flagged ``0``.

        Flags are read at the moment an index is visited, so indexes flagged
        by the consumer between two blocks are already skipped when the
        following block is assembled.
        """
        flags = self.__source_halde_flags
        pending = []

        while True:
            found_unflagged = False

            for i in range(len(flags)):
                if flags[i] != 0:
                    continue

                found_unflagged = True
                pending.append(i)

                if len(pending) >= self.block_size:
                    yield self.__source.get_block_via_index(pending)
                    # Fresh list: the source may keep a reference to the
                    # one it was just given.
                    pending = []

            if not found_unflagged:
                # Nothing left to serve; another pass could never yield.
                return

            if pending:
                # End of a pass: serve the partial block instead of carrying
                # it over, which would append the same indexes again.
                yield self.__source.get_block_via_index(pending)
                pending = []

    def __iter__(self):
        """Restart iteration: reset every flag to ``0`` and return ``self``.

        Raises:
            ValueError: If ``block_size`` is smaller than 1.
        """
        if self.block_size < 1:
            raise ValueError("block_size must be at least 1")

        self.__source_halde_flags = array(self._SIGNED_CHAR, [0]) * len(self)
        self.__next_helper = self._next_helper()
        return self

    def _flip_source_halde_flags(self, learnblock_indexes):
        """Flag every index in ``learnblock_indexes`` as consumed."""
        for i in learnblock_indexes:
            self.__source_halde_flags[i] = 1

    def __len__(self):
        """Return the length reported by the underlying source."""
        return self.__source.length