Commit de551148 authored by Janos Borst's avatar Janos Borst
Browse files

workday over

parent 93f50e26
Pipeline #36532 passed with stages
in 2 minutes and 19 seconds
......@@ -27,21 +27,25 @@ class MultiLabelDataset(Dataset):
It also inherits torch.utils.data.Dataset so that the DataLoader can later be used to iterate over it
"""
def __init__(self, x, y, classes, purpose="train", target_dtype=torch._cast_Float, one_hot=True, **kwargs):
    """Create a multilabel dataset.

    Args:
        x: list of input texts.
        y: list of label lists, one list of label names per example.
        classes: dict mapping label name -> integer class index.
        purpose: split tag, e.g. "train" or "test".
        target_dtype: callable applied to the label tensor before returning
            (e.g. torch._cast_Float so BCE-style losses get float targets).
        one_hot: if True, __getitem__ returns a multi-hot label vector;
            otherwise a single class index.
        **kwargs: extra attributes stored directly on the instance.
    """
    # Diff residue removed: the superseded signature (target_dtype=torch.LongTensor,
    # no one_hot flag) duplicated this def and would have shadowed it.
    self.__dict__.update(kwargs)
    self.classes = classes
    self.purpose = purpose
    self.x = x
    self.y = y
    self.one_hot = one_hot
    self.target_dtype = target_dtype
def __len__(self):
    """Return the number of samples, i.e. the length of the text list."""
    n = len(self.x)
    return n
def __getitem__(self, idx):
    """Return sample idx as a dict {'text': str, 'labels': tensor-or-int}.

    If self.one_hot, 'labels' is a multi-hot vector of length len(self.classes)
    (one_hot over each label index, summed), cast via self.target_dtype.
    Otherwise 'labels' is the class index of the example's first label.
    """
    # Diff residue removed: the superseded unconditional one-hot body
    # returned before the one_hot branch could ever run.
    if self.one_hot:
        label_ids = [self.classes[tag] for tag in self.y[idx]]
        multi_hot = torch.nn.functional.one_hot(torch.LongTensor(label_ids), len(self.classes)).sum(0)
        return {'text': self.x[idx], 'labels': self.target_dtype(multi_hot)}
    else:
        # BUG FIX: original indexed self.y[2][0] — a hard-coded sample 2 —
        # so every item returned sample 2's label. Use idx instead.
        return {'text': self.x[idx], 'labels': self.classes[self.y[idx][0]]}
def transform(self, fct):
    """Map fct over every stored text, replacing self.x in place."""
    self.x = list(map(fct, self.x))
......
......@@ -487,8 +487,8 @@ def load_agnews():
classes = dict(zip(classes, range(len(classes))))
rev_classes = {v: k for k, v in classes.items()}
data = {
"train": (traindata,[[rev_classes[x-1]] for x in trainlabel]),
"test": (testdata, [[rev_classes[x-1]] for x in testlabel]),
"train": (traintext,[[rev_classes[x-1]] for x in trainlabel]),
"test": (testtext, [[rev_classes[x-1]] for x in testlabel]),
"test_title": testtitle,
"test_description": testdescription,
"train_title": traintitle,
......
......@@ -61,9 +61,9 @@ class TextClassificationAbstract(torch.nn.Module):
x = self.transform(b["text"])
output = self(x.to(self.device)).cpu()
if hasattr(self, "regularize"):
l = self.loss(output, torch._cast_Float(y)) + self.regularize()
l = self.loss(output, y) + self.regularize()
else:
l = self.loss(output, torch._cast_Float(y))
l = self.loss(output, y)
output = torch.sigmoid(output)
# Subset evaluation if ...
......@@ -109,9 +109,9 @@ class TextClassificationAbstract(torch.nn.Module):
x = self.transform(b["text"]).to(self.device)
output = self(x)
if hasattr(self, "regularize"):
l = self.loss(output, torch._cast_Float(y)) + self.regularize()
l = self.loss(output, y) + self.regularize()
else:
l = self.loss(output, torch._cast_Float(y))
l = self.loss(output, y[:,0])
l.backward()
self.optimizer.step()
......
......@@ -136,4 +136,4 @@ class LanguageModelAbstract(torch.nn.Module):
def representations(self, s):
    """Return the model's representations for a string or list of strings.

    A single string is wrapped into a one-element list, transformed,
    moved to self.device, and passed through the model with
    representations=True.
    """
    if not isinstance(s, list):
        s = [s]
    # Diff residue removed: the superseded tokenizer-based return sat above
    # this line and made the committed return unreachable.
    return self(self.transform(s).to(self.device), representations=True)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment