From acbc5fb12e244c52bae63d6d7abf2638d9db5885 Mon Sep 17 00:00:00 2001 From: dmt <> Date: Thu, 10 Oct 2019 17:40:56 +0200 Subject: [PATCH] Handle Sigma Z relatives in learn block creation. --- cml/domain/data_source.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/cml/domain/data_source.py b/cml/domain/data_source.py index aa95792..966c944 100644 --- a/cml/domain/data_source.py +++ b/cml/domain/data_source.py @@ -110,15 +110,24 @@ class Preprocessor: class LearnblockIdentifier: - def __init__(self, settings): + def __init__(self, settings, density_estimator, relative_extrema): self.settings = settings self.column_pairs = (("T", "Z"), ("T", "Sigma"), ("Sigma", "Z")) + self.density_estimator = density_estimator + self._relative_extrema = relative_extrema def identify(self, block): + biggest_learn_block = None + biggest_block_size = 0 + for pair in self.column_pairs: for possible_learnblock in self._identify_relatives(block, *pair): if self._is_learn_block(possible_learnblock.length): - yield possible_learnblock + if possible_learnblock.length > biggest_block_size: + biggest_learn_block = possible_learnblock + biggest_learn_block.relatives = pair + + return biggest_learn_block def _is_learn_block(self, block_length): return block_length > self.settings.learn_block_minimum @@ -130,7 +139,11 @@ class LearnblockIdentifier: if value_pair not in already_seen: already_seen.add(value_pair) kw = {args[0]: value_pair[0], args[1]: value_pair[1]} - yield block.get_values(**kw) + if args[0] == "Sigma" and args[1] == "Z": + for block in self._get_sigma_zeta_relatives(block, **kw): + yield block + else: + yield block.get_values(**kw) def _get_sigma_zeta_relatives(self, block, **kw): relatives = block.get_values(**kw) -- GitLab