diff --git a/cml/controller/api.py b/cml/controller/api.py index 0e940d3c8caad314d46ec25352fc400e02ff3785..6fcd5a367dcc9beec3e706e5c1708c8696c318f6 100644 --- a/cml/controller/api.py +++ b/cml/controller/api.py @@ -28,7 +28,9 @@ from cml.shared.request import ( from cml.ports.ml_adapter import ( KernelDensityEstimator, find_relative_extrema, - Autoencoder + Autoencoder, + ReconstructionConceptualMLModel, + ReconstructionProceduralMLModel ) @@ -119,9 +121,19 @@ def feature_selection(filter_ml_model, embedded_ml_model): return feature_selection_usecase.execute(feature_selection_req) -def reconstruction(*args, **kwargs): +def reconstruction(reconstruction_type, *args, **kwargs): + # TODO (dmt): Create ReconstructionModels depending on the sklearn + # models. + if reconstruction_type == "conceptual": + ml_models = [ReconstructionConceptualMLModel(model) for model in args] + elif reconstruction_type == "procedural": + ml_models = [ReconstructionProceduralMLModel(model) for model in args] + else: + raise ValueError("reconstruction_type is wrong.") settings = specific_settings_factory("reconstruction") - reconstruction_req = ReconstructionRequest(settings) + reconstruction_req = ReconstructionRequest(settings, + ml_models, + reconstruction_type) reconstruction_usecase = ReconstructionUsecase() return reconstruction_usecase.execute(reconstruction_req)