From acbc5fb12e244c52bae63d6d7abf2638d9db5885 Mon Sep 17 00:00:00 2001
From: dmt <>
Date: Thu, 10 Oct 2019 17:40:56 +0200
Subject: [PATCH] Handle Sigma Z relatives in learn block creation.

---
 cml/domain/data_source.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/cml/domain/data_source.py b/cml/domain/data_source.py
index aa95792..966c944 100644
--- a/cml/domain/data_source.py
+++ b/cml/domain/data_source.py
@@ -110,15 +110,24 @@ class Preprocessor:
 
 
 class LearnblockIdentifier:
-    def __init__(self, settings):
+    def __init__(self, settings, density_estimator, relative_extrema):
         self.settings = settings
         self.column_pairs = (("T", "Z"), ("T", "Sigma"), ("Sigma", "Z"))
+        self.density_estimator = density_estimator
+        self._relative_extrema = relative_extrema
 
     def identify(self, block):
+        biggest_learn_block = None
+        biggest_block_size = 0
+
         for pair in self.column_pairs:
             for possible_learnblock in self._identify_relatives(block, *pair):
                 if self._is_learn_block(possible_learnblock.length):
-                    yield possible_learnblock
+                    if possible_learnblock.length > biggest_block_size:
+                        biggest_learn_block = possible_learnblock
+                        biggest_learn_block.relatives = pair
+
+        return biggest_learn_block
 
     def _is_learn_block(self, block_length):
         return block_length > self.settings.learn_block_minimum
@@ -130,7 +139,11 @@ class LearnblockIdentifier:
             if value_pair not in already_seen:
                 already_seen.add(value_pair)
                 kw = {args[0]: value_pair[0], args[1]: value_pair[1]}
-                yield block.get_values(**kw)
+                if args[0] == "Sigma" and args[1] == "Z":
+                    for block in self._get_sigma_zeta_relatives(block, **kw):
+                        yield block
+                else:
+                    yield block.get_values(**kw)
 
     def _get_sigma_zeta_relatives(self, block, **kw):
         relatives = block.get_values(**kw)
-- 
GitLab