From 0ac64a7719e5d03aa20473900e7255b89926515c Mon Sep 17 00:00:00 2001 From: dmt <> Date: Mon, 28 Oct 2019 21:37:18 +0100 Subject: [PATCH] Suppress warnings called by sklearn. --- cml/ports/ml_adapter.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/cml/ports/ml_adapter.py b/cml/ports/ml_adapter.py index fcd5212..2f9f4dc 100644 --- a/cml/ports/ml_adapter.py +++ b/cml/ports/ml_adapter.py @@ -1,5 +1,6 @@ from collections import Counter from abc import abstractmethod +import warnings from numpy import array, linspace, less, greater, std, argsort from scipy.signal import argrelextrema @@ -21,8 +22,14 @@ SCIKIT_CLUSTERING_TABLE = { } -class MachineLearningModel: +def log_warning(func): + def wrapper(*args, **kwargs): + warnings.filterwarnings("ignore") + return func(*args, **kwargs) + return wrapper + +class MachineLearningModel: @abstractmethod def train(self, data, *args, **kwargs): pass @@ -32,6 +39,7 @@ class FilterMethod(MachineLearningModel): def __init__(self, model): self.__model = model + @log_warning def train(self, data, *args, **kwargs): self.__model = self.__model.fit(data) return self @@ -48,6 +56,7 @@ class EmbeddedMethod(MachineLearningModel): def __init__(self, model): self.__model = model + @log_warning def train(self, data, *args, **kwargs): labels = data.get_column_values("Z") self.__model = self.__model.fit(data, labels) @@ -91,6 +100,7 @@ class ConstructionClusteringMLModel(MachineLearningModel): ) return Counter(labels) + @log_warning def train(self, data, *args, **kwargs): self.__model.fit(data) return self @@ -103,6 +113,7 @@ class ReconstructionConceptualMLModel(MachineLearningModel): self.accuracy = None self.subject = model.__class__.__name__ + @log_warning def train(self, data, *args, **kwargs): # TODO (dmt): Improve signature of this function! labels = args[0] @@ -121,6 +132,7 @@ class ReconstructionProceduralMLModel(MachineLearningModel): self.max_error = None self.subject = model.__class__.__name__ + @log_warning def train(self, data, *args, **kwargs): # TODO (dmt): Provide a better way dealing with # zero values as max_abs_label! @@ -149,6 +161,7 @@ class KernelDensityEstimator(MachineLearningModel): self.bandwidth = bandwidth self.gridsize = gridsize + @log_warning def train(self, data, *args, **kwargs): reshaped_data = array(data).reshape(-1, 1) if not self.__model: @@ -173,6 +186,7 @@ class Autoencoder(MachineLearningModel): self.__model = None self.__hidden_outputter = None + @log_warning def train(self, data, *args, **kwargs): inputer = Input(shape=(self.io_shape, )) hidden = Dense(units=self.target_number, -- GitLab