Skip to content
Snippets Groups Projects
Commit 0ac64a77 authored by dmt's avatar dmt
Browse files

Suppress warnings called by sklearn.

parent f3d6c66a
No related branches found
No related tags found
No related merge requests found
from collections import Counter
from abc import abstractmethod
import warnings
from numpy import array, linspace, less, greater, std, argsort
from scipy.signal import argrelextrema
......@@ -21,8 +22,14 @@ SCIKIT_CLUSTERING_TABLE = {
}
class MachineLearningModel:
def log_warning(func):
def wrapper(*args, **kwargs):
warnings.filterwarnings("ignore")
return func(*args, **kwargs)
return wrapper
class MachineLearningModel:
@abstractmethod
def train(self, data, *args, **kwargs):
pass
......@@ -32,6 +39,7 @@ class FilterMethod(MachineLearningModel):
def __init__(self, model):
self.__model = model
@log_warning
def train(self, data, *args, **kwargs):
self.__model = self.__model.fit(data)
return self
......@@ -48,6 +56,7 @@ class EmbeddedMethod(MachineLearningModel):
def __init__(self, model):
self.__model = model
@log_warning
def train(self, data, *args, **kwargs):
labels = data.get_column_values("Z")
self.__model = self.__model.fit(data, labels)
......@@ -91,6 +100,7 @@ class ConstructionClusteringMLModel(MachineLearningModel):
)
return Counter(labels)
@log_warning
def train(self, data, *args, **kwargs):
self.__model.fit(data)
return self
......@@ -103,6 +113,7 @@ class ReconstructionConceptualMLModel(MachineLearningModel):
self.accuracy = None
self.subject = model.__class__.__name__
@log_warning
def train(self, data, *args, **kwargs):
# TODO (dmt): Improve signature of this function!
labels = args[0]
......@@ -121,6 +132,7 @@ class ReconstructionProceduralMLModel(MachineLearningModel):
self.max_error = None
self.subject = model.__class__.__name__
@log_warning
def train(self, data, *args, **kwargs):
# TODO (dmt): Provide a better way dealing with
# zero values as max_abs_label!
......@@ -149,6 +161,7 @@ class KernelDensityEstimator(MachineLearningModel):
self.bandwidth = bandwidth
self.gridsize = gridsize
@log_warning
def train(self, data, *args, **kwargs):
reshaped_data = array(data).reshape(-1, 1)
if not self.__model:
......@@ -173,6 +186,7 @@ class Autoencoder(MachineLearningModel):
self.__model = None
self.__hidden_outputter = None
@log_warning
def train(self, data, *args, **kwargs):
inputer = Input(shape=(self.io_shape, ))
hidden = Dense(units=self.target_number,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment