Skip to content
Snippets Groups Projects
Commit 0ac64a77 authored by dmt's avatar dmt
Browse files

Suppress warnings called by sklearn.

parent f3d6c66a
No related branches found
No related tags found
No related merge requests found
from collections import Counter from collections import Counter
from abc import abstractmethod from abc import abstractmethod
import warnings
from numpy import array, linspace, less, greater, std, argsort from numpy import array, linspace, less, greater, std, argsort
from scipy.signal import argrelextrema from scipy.signal import argrelextrema
...@@ -21,8 +22,14 @@ SCIKIT_CLUSTERING_TABLE = { ...@@ -21,8 +22,14 @@ SCIKIT_CLUSTERING_TABLE = {
} }
class MachineLearningModel: def log_warning(func):
def wrapper(*args, **kwargs):
warnings.filterwarnings("ignore")
return func(*args, **kwargs)
return wrapper
class MachineLearningModel:
@abstractmethod @abstractmethod
def train(self, data, *args, **kwargs): def train(self, data, *args, **kwargs):
pass pass
...@@ -32,6 +39,7 @@ class FilterMethod(MachineLearningModel): ...@@ -32,6 +39,7 @@ class FilterMethod(MachineLearningModel):
def __init__(self, model): def __init__(self, model):
self.__model = model self.__model = model
@log_warning
def train(self, data, *args, **kwargs): def train(self, data, *args, **kwargs):
self.__model = self.__model.fit(data) self.__model = self.__model.fit(data)
return self return self
...@@ -48,6 +56,7 @@ class EmbeddedMethod(MachineLearningModel): ...@@ -48,6 +56,7 @@ class EmbeddedMethod(MachineLearningModel):
def __init__(self, model): def __init__(self, model):
self.__model = model self.__model = model
@log_warning
def train(self, data, *args, **kwargs): def train(self, data, *args, **kwargs):
labels = data.get_column_values("Z") labels = data.get_column_values("Z")
self.__model = self.__model.fit(data, labels) self.__model = self.__model.fit(data, labels)
...@@ -91,6 +100,7 @@ class ConstructionClusteringMLModel(MachineLearningModel): ...@@ -91,6 +100,7 @@ class ConstructionClusteringMLModel(MachineLearningModel):
) )
return Counter(labels) return Counter(labels)
@log_warning
def train(self, data, *args, **kwargs): def train(self, data, *args, **kwargs):
self.__model.fit(data) self.__model.fit(data)
return self return self
...@@ -103,6 +113,7 @@ class ReconstructionConceptualMLModel(MachineLearningModel): ...@@ -103,6 +113,7 @@ class ReconstructionConceptualMLModel(MachineLearningModel):
self.accuracy = None self.accuracy = None
self.subject = model.__class__.__name__ self.subject = model.__class__.__name__
@log_warning
def train(self, data, *args, **kwargs): def train(self, data, *args, **kwargs):
# TODO (dmt): Improve signature of this function! # TODO (dmt): Improve signature of this function!
labels = args[0] labels = args[0]
...@@ -121,6 +132,7 @@ class ReconstructionProceduralMLModel(MachineLearningModel): ...@@ -121,6 +132,7 @@ class ReconstructionProceduralMLModel(MachineLearningModel):
self.max_error = None self.max_error = None
self.subject = model.__class__.__name__ self.subject = model.__class__.__name__
@log_warning
def train(self, data, *args, **kwargs): def train(self, data, *args, **kwargs):
# TODO (dmt): Provide a better way dealing with # TODO (dmt): Provide a better way dealing with
# zero values as max_abs_label! # zero values as max_abs_label!
...@@ -149,6 +161,7 @@ class KernelDensityEstimator(MachineLearningModel): ...@@ -149,6 +161,7 @@ class KernelDensityEstimator(MachineLearningModel):
self.bandwidth = bandwidth self.bandwidth = bandwidth
self.gridsize = gridsize self.gridsize = gridsize
@log_warning
def train(self, data, *args, **kwargs): def train(self, data, *args, **kwargs):
reshaped_data = array(data).reshape(-1, 1) reshaped_data = array(data).reshape(-1, 1)
if not self.__model: if not self.__model:
...@@ -173,6 +186,7 @@ class Autoencoder(MachineLearningModel): ...@@ -173,6 +186,7 @@ class Autoencoder(MachineLearningModel):
self.__model = None self.__model = None
self.__hidden_outputter = None self.__hidden_outputter = None
@log_warning
def train(self, data, *args, **kwargs): def train(self, data, *args, **kwargs):
inputer = Input(shape=(self.io_shape, )) inputer = Input(shape=(self.io_shape, ))
hidden = Dense(units=self.target_number, hidden = Dense(units=self.target_number,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment