Commit b4a03531 authored by Janos Borst's avatar Janos Borst
Browse files

Possibility for entailment pretraining

parent 97150af8
......@@ -570,5 +570,18 @@ def is_multilabel(x):
"""
return type(x) in (MultiLabelDataset, MultiOutputMultiLabelDataset)
class EntailmentDataset(Dataset):
    """Dataset of text pairs for entailment pretraining.

    Each item pairs two text inputs with an entailment label; the label
    name is mapped through ``classes`` to its numeric id at access time.
    """

    def __init__(self, x1, x2, labels, classes):
        """
        :param x1: Sequence of first texts (premises).
        :param x2: Sequence of second texts (hypotheses), aligned with ``x1``.
        :param labels: Sequence of label names, one per (x1, x2) pair.
        :param classes: Mapping from label name to numeric class id.
        :raises ValueError: If ``x1``, ``x2`` and ``labels`` differ in length.
        """
        # Guard against silently truncated data: __len__ only looks at x1,
        # so a mismatch would otherwise go unnoticed until indexing fails.
        if not (len(x1) == len(x2) == len(labels)):
            raise ValueError(
                "x1, x2 and labels must have equal lengths, got "
                "%d, %d, %d" % (len(x1), len(x2), len(labels))
            )
        self.x1 = x1
        self.x2 = x2
        self.labels = labels
        self.classes = classes

    def __len__(self):
        """Return the number of (x1, x2) pairs."""
        return len(self.x1)

    def __getitem__(self, item):
        """Return a dict with both texts and the numeric label for ``item``."""
        return {"x1": self.x1[item], "x2": self.x2[item], "labels": self.classes[self.labels[item]]}
## Sampler import
from .sampler import sampler, successive_sampler, class_sampler, validation_split
import torch
from tqdm import tqdm
from abc import abstractmethod
from ...data import MultiLabelDataset, SingleLabelDataset
from ...data import MultiLabelDataset, SingleLabelDataset, EntailmentDataset
from copy import deepcopy
from tqdm import tqdm
try:
from apex import amp
......@@ -11,7 +12,6 @@ except:
pass
from ...data import is_multilabel
class TextClassificationAbstractZeroShot(torch.nn.Module):
"""
Abstract class for Multilabel Models. Defines fit, evaluate, predict and threshold methods for virtually any
......@@ -203,4 +203,56 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
self.set_threshold("mcut")
self.activation = torch.sigmoid
self.loss = torch.nn.BCEWithLogitsLoss()
self.build()
\ No newline at end of file
self.build()
def _entail_forward(self, x1, x2):
    """Score ``x1`` against ``x2`` by treating ``x2`` as the label set.

    Registers ``x2`` as the current labels via ``create_labels`` and then
    runs the regular forward pass on ``x1``.
    """
    self.create_labels(x2)
    scores = self.forward(x1)
    return scores
def entailment_pretrain(self, data, valid=None, epochs=10, batch_size=16):
    """Pretrain the model on an entailment task.

    :param data: EntailmentDataset providing (x1, x2, label) training triples.
    :param valid: Optional EntailmentDataset evaluated after every epoch.
    :param epochs: Number of passes over ``data``.
    :param batch_size: Training batch size (validation uses twice this value).
    :return: Dict with the per-epoch training-loss history under ``"train"``
        and the most recent validation result under ``"valid"``
        (``None`` when no validation set was given).
    """
    from ignite.metrics import Average  # hoisted: was re-imported every epoch

    train_history = {"loss": []}
    # BUGFIX: previously unbound when ``valid`` was None, so the final
    # ``return`` raised NameError. Initialize before the loop.
    validation_result = None
    for e in range(epochs):
        # entailment_eval() switches the model to eval mode; restore
        # training mode at the start of each epoch so dropout etc. stay active.
        self.train()
        losses = {"loss": str(0.)}
        dl = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
        with tqdm(dl,
                  postfix=[losses], desc="Epoch %i/%i" % (e + 1, epochs), ncols=100) as pbar:
            average = Average()
            for b in dl:
                self.zero_grad()
                scores = self._entail_forward(self.transform(b["x1"]).to(self.device),
                                              self.transform(b["x2"]).to(self.device))
                l = self.loss(scores, b["labels"].to(self.device).float())
                l.backward()
                self.optimizer.step()
                average.update(l.detach().item())
                pbar.postfix[0]["loss"] = round(average.compute().item(), 8)
                pbar.update()
            if valid is not None:
                validation_result = self.entailment_eval(valid, batch_size=batch_size * 2)
                pbar.postfix[0]["valid_loss"] = round(validation_result["loss"], 8)
                pbar.postfix[0]["valid_accuracy"] = round(validation_result["accuracy"], 8)
                pbar.update()
        train_history["loss"].append(average.compute().item())
    return {"train": train_history, "valid": validation_result}
def entailment_eval(self, data, batch_size=16):
    """Evaluate entailment loss and accuracy on ``data``.

    Switches the model to eval mode (the caller is responsible for
    restoring training mode afterwards).

    :param data: EntailmentDataset providing (x1, x2, label) triples.
    :param batch_size: Evaluation batch size.
    :return: Dict with the mean ``"loss"`` and element-wise ``"accuracy"``.
    """
    from ignite.metrics import Average

    self.eval()
    # shuffle=False: order is irrelevant for aggregate metrics, and a
    # deterministic pass makes evaluation reproducible.
    dl = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False)
    average = Average()
    accuracy = Average()
    with torch.no_grad():  # hoisted: was re-entered for every batch
        for b in dl:
            scores = self._entail_forward(self.transform(b["x1"]).to(self.device),
                                          self.transform(b["x2"]).to(self.device))
            l = self.loss(scores, b["labels"].to(self.device).float())
            # NOTE(review): when self.loss is BCEWithLogitsLoss the scores are
            # raw logits, so ``scores > 0.5`` is not the 0.5-probability
            # boundary (that would be ``scores > 0`` or sigmoid first) —
            # confirm the intended threshold.
            for i in (b["labels"].to(self.device) == (scores > 0.5)):
                accuracy.update(i.item())
            average.update(l.detach().item())
    return {"loss": average.compute().item(), "accuracy": accuracy.compute().item()}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment