Commit be73f345 authored by Janos Borst

implementing ensembles

parent e2919047
Pipeline #51717 passed with stage in 9 minutes and 54 seconds
from .descision_criteria import ConfidenceDecision, MajorityDecision, EntropyDecision
from .ensemble import Ensemble
from .ensemble_classwise import BinaryEnsemble
\ No newline at end of file
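
For orientation: the decision criteria imported here implement the ensemble's voting strategies. A minimal sketch of the majority-vote idea (simplified, hypothetical signature; the real implementations live in descision_criteria.py and may differ):

import torch

def majority_vote_sketch(scores):
    # scores: one [batch, n_classes] tensor per ensemble member (assumed layout)
    preds = torch.stack([s.argmax(-1) for s in scores], dim=1)  # [batch, n_members]
    majority = preds.mode(dim=1).values                         # most frequent label per example
    agrees = (preds == majority.unsqueeze(1)).float()           # which members voted with the majority
    return agrees.argmax(dim=1)                                 # index of the first agreeing member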
@@ -93,27 +93,39 @@ class Ensemble:
        [m.train() for m in self.m]
        return [x[i] for x, i in zip(zip(*[s[0] for s in scores]), idx)], s, p

    def predict_ensemble_batch(self, *args, **kwargs):
        [m.eval() for m in self.m]  # set mode to evaluation to disable dropout
        kwargs["return_scores"] = True
        scores = [m.predict_batch(*args, **kwargs) for m in self.m]
        # let the decision criterion pick, per example, which member's output to keep
        idx = self.vote([s[1] for s in scores])
        # gather the winning member's score and decision tensors for every example
        s = torch.stack([x[i] for x, i in zip(zip(*[s[1] for s in scores]), idx.tolist())], 0)
        p = torch.stack([x[i] for x, i in zip(zip(*[s[2] for s in scores]), idx.tolist())], 0)
        [m.train() for m in self.m]  # restore training mode
        return [x[i] for x, i in zip(zip(*[s[0] for s in scores]), idx)], s, p

    def single(self, *args, **kwargs):
        [m.single(*args, **kwargs) for m in self.m]

    def multi(self, *args, **kwargs):
        [m.multi(*args, **kwargs) for m in self.m]

    def entailment(self, *args, **kwargs):
        [m.entailment(*args, **kwargs) for m in self.m]
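
The zip(zip(*...), idx) idiom above transposes the per-model lists of outputs into per-example tuples and keeps the entry of the model that won the vote. A standalone illustration with made-up values:

import torch

labels_per_model = [["sports", "world"],      # model 0
                    ["sports", "sci/tech"],   # model 1
                    ["business", "world"]]    # model 2
idx = torch.tensor([0, 2])                    # voted winner per example

# zip(*labels_per_model) yields per-example tuples of all models' labels;
# x[i] keeps the winning model's label for each example
winners = [x[i] for x, i in zip(zip(*labels_per_model), idx.tolist())]
assert winners == ["sports", "world"]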
r = "google/bert_uncased_L-4_H-256_A-4"
from mlmc_lab import mlmc_experimental as mlmce
m = [mlmc.models.EmbeddingBasedWeighted(mode="vanilla", representation=r, sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single", loss=mlmce.loss.EncourageLoss(0.75), device=device),
mlmc.models.EmbeddingBasedWeighted(mode="max", representation=r, sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single",loss=mlmce.loss.EncourageLoss(0.75), device=device),
mlmc.models.EmbeddingBasedWeighted(mode="mean", representation=r, sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single",loss=mlmce.loss.EncourageLoss(0.75),device=device),
mlmc.models.EmbeddingBasedWeighted(mode="max_mean",representation=r, sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single", loss=mlmce.loss.EncourageLoss(0.75),device=device),
mlmc.models.EmbeddingBasedWeighted(mode="attention_max_mean", representation=r, sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single", loss=mlmce.loss.EncourageLoss(0.75),device=device)]
m = m+[mlmc.models.SimpleEncoder(representation="roberta-large-mnli", sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single", device=device)]
e = Ensemble(m)
e.t[-1] = False
e.fit(mlmc.data.sampler(d["train"], absolute=100), epochs=50)
test=mlmc.data.sampler(d["test"], absolute=1000)
print(e.evaluate(test))
e.evaluate_ensemble(test)
#
# r = "google/bert_uncased_L-4_H-256_A-4"
# from mlmc_lab import mlmc_experimental as mlmce
# m = [mlmc.models.EmbeddingBasedWeighted(mode="vanilla", representation=r, sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single", loss=mlmce.loss.EncourageLoss(0.75), device=device),
# mlmc.models.EmbeddingBasedWeighted(mode="max", representation=r, sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single",loss=mlmce.loss.EncourageLoss(0.75), device=device),
# mlmc.models.EmbeddingBasedWeighted(mode="mean", representation=r, sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single",loss=mlmce.loss.EncourageLoss(0.75),device=device),
# mlmc.models.EmbeddingBasedWeighted(mode="max_mean",representation=r, sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single", loss=mlmce.loss.EncourageLoss(0.75),device=device),
# mlmc.models.EmbeddingBasedWeighted(mode="attention_max_mean", representation=r, sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single", loss=mlmce.loss.EncourageLoss(0.75),device=device)]
# # m = m+[mlmc.models.SimpleEncoder(representation="roberta-large-mnli", sformatter=mlmc.data.SFORMATTER["agnews"], finetune=True, classes=d["classes"], target="single", device=device)]
#
# e = Ensemble(m)
# # e.t[-1] = False
# # e.fit(mlmc.data.sampler(d["train"], absolute=100), epochs=50)
# test=mlmc.data.sampler(d["test"], absolute=1000)
#
# print(e.evaluate(test))
# e.predict_ensemble_batch(test)
#
@@ -181,18 +181,18 @@ class BinaryEnsemble:
        [m.multi(*args, **kwargs) for m in self.m]

    def entailment(self, *args, **kwargs):
        [m.entailment(*args, **kwargs) for m in self.m]


def create_model():
    # factory handed to BinaryEnsemble so it can build one binary model per class
    return mlmc.models.EmbeddingBasedWeighted(mode="max",
                                              sformatter=mlmc.data.SFORMATTER["agnews"],
                                              finetune=True, classes={}, target="multi", device=device)

zeromodel = [mlmc.models.SimpleEncoder(representation="roberta-large-mnli", sformatter=mlmc.data.SFORMATTER["agnews"],
                                       finetune=True, classes=d["classes"], target="single", device=device)]  # defined but unused here: zero=None below
from mlmc_lab import mlmc_experimental as mlmce
e = BinaryEnsemble(create_model, classes=d["classes"], loss=mlmce.loss.EncourageLoss(0.75), zero=None)
e.fit(mlmc.data.sampler(d["train"], absolute=100), epochs=50)
e.evaluate(mlmc.data.sampler(d["test"], absolute=1000))
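
For context, a class-wise (binary) ensemble builds one one-vs-rest model per class from the factory above and combines their positive scores into a single multi-class decision. A hypothetical sketch of that combination step, not the actual BinaryEnsemble internals:

import torch

def combine_binary_scores(per_class_scores):
    # per_class_scores: one [batch] tensor per class, each the positive score
    # of that class's binary model (assumed interface)
    stacked = torch.stack(per_class_scores, dim=-1)  # [batch, n_classes]
    return stacked, stacked.argmax(-1)               # scores and predicted class index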
#
# def create_model():
# return mlmc.models.EmbeddingBasedWeighted(mode="max",
# sformatter=mlmc.data.SFORMATTER["agnews"],
# finetune=True, classes={}, target="multi", device=device)
#
#
# zeromodel=[mlmc.models.SimpleEncoder(representation="roberta-large-mnli", sformatter=mlmc.data.SFORMATTER["agnews"],
# finetune=True, classes=d["classes"], target="single", device=device)]
# from mlmc_lab import mlmc_experimental as mlmce
# e = BinaryEnsemble(create_model, classes=d["classes"], loss=mlmce.loss.EncourageLoss(0.75), zero=None)
# e.fit(mlmc.data.sampler(d["train"], absolute=100), epochs=50)
# e.evaluate(mlmc.data.sampler(d["test"], absolute=1000))
#
#
@@ -313,7 +313,7 @@ class TextClassificationAbstract(torch.nn.Module):
    def fit(self, train,
            valid=None, epochs=1, batch_size=16, valid_batch_size=50, patience=-1, tolerance=1e-2,
            return_roc=False, return_report=False, callbacks=None, metrics=None, lr_schedule=None, lr_param={}, log_mlflow=False):
            return_roc=False, return_report=False, callbacks=None, metrics=None, lr_schedule=None, lr_param={}, log_mlflow=False, valid_prefix="valid"):
        """
        Training function
@@ -378,14 +378,14 @@ class TextClassificationAbstract(torch.nn.Module):
        if log_mlflow:
            import mlflow
            mlflow.log_metric("valid_loss", valid_loss, step=e)
            result_metrics.log_mlflow(step=e, prefix="valid")
            mlflow.log_metric(f"{valid_prefix}_loss", valid_loss, step=e)
            result_metrics.log_mlflow(step=e, prefix=valid_prefix)

        valid_loss_dict = {"valid_loss": valid_loss}
        valid_loss_dict = {f"{valid_prefix}_loss": valid_loss}
        valid_loss_dict.update(result_metrics.compute())
        self.validation.append(valid_loss_dict)

        printables = {"valid_loss": valid_loss}
        printables = {f"{valid_prefix}_loss": valid_loss}
        printables.update(result_metrics.print())
        pbar.postfix[0].update(printables)
        pbar.update()
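
The new valid_prefix argument renames the logged metric namespace, which is useful when the same fit loop runs against different held-out splits. A hypothetical call (model m and the dev_split dataset are assumed names):

# logs "dev_loss" and "dev_*" metrics to mlflow instead of "valid_*"
history = m.fit(train, valid=dev_split, epochs=10, log_mlflow=True, valid_prefix="dev")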
@@ -515,12 +515,15 @@ class TextClassificationAbstract(torch.nn.Module):
        self.classes_rev = {v: k for k, v in self.classes.items()}
        for b in tqdm(train_loader, ncols=100):
            predictions.extend(self.predict(b["text"], return_scores=return_scores))
        labels = sum([predictions[x] for x in list(range(0, len(predictions), 3))], [])
        scores = torch.cat([predictions[x] for x in list(range(1, len(predictions) + 1, 3))], dim=0)
        bools = torch.cat([predictions[x] for x in list(range(2, len(predictions), 3))], dim=0)
        del self.classes_rev
        return labels, scores, bools
        if return_scores:
            # with scores, predict returns interleaved (labels, scores, bools) triples: de-interleave them
            labels = sum([predictions[x] for x in list(range(0, len(predictions), 3))], [])
            scores = torch.cat([predictions[x] for x in list(range(1, len(predictions) + 1, 3))], dim=0)
            bools = torch.cat([predictions[x] for x in list(range(2, len(predictions), 3))], dim=0)
            return labels, scores, bools
        else:
            # without scores, predict returns only label lists: concatenate them
            labels = sum([predictions[x] for x in list(range(0, len(predictions)))], [])
            return labels
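
With this change predict_batch honors return_scores the same way predict does. A hypothetical call showing both return shapes (names assumed):

labels = m.predict_batch(data)                                     # just a list of label lists
labels, scores, bools = m.predict_batch(data, return_scores=True)  # scores: [n_examples, n_classes] tensor;
                                                                   # bools: per-class decision tensor of the same shape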
    def run(self, x):
        """