Commit e89805f1 authored by Janos Borst

BertConcepts

parent 17272b86
@@ -83,40 +83,59 @@ class BertAsConcept(TextClassificationAbstract):
        print("Labels:\t", label)
        print("Concepts:\t", concepts)
    # def fit(self, train, valid=None, epochs=1, batch_size=16, valid_batch_size=50, classes_subset=None):
    #     validation = []
    #     train_history = {"loss": []}
    #     for e in range(epochs):
    #         losses = {"loss": str(0.)}
    #         average = Average()
    #         train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
    #
    #         with tqdm(train_loader,
    #                   postfix=[losses], desc="Epoch %i/%i" % (e + 1, epochs)) as pbar:
    #             for i, b in enumerate(train_loader):
    #                 self.optimizer.zero_grad()
    #                 y = b["labels"].to(self.device)
    #                 y[y != 0] = 1
    #                 x = self.transform(b["text"]).to(self.device)
    #                 output, scores = self(x)
    #                 if hasattr(self, "regularize"):
    #                     l = self.loss(output, torch._cast_Float(y)) + self.regularize() + 0.3 * self.loss(scores, self.label_concept_onehot[torch.where(y == 1)[1]][:, None, :].repeat([1, self.max_len, 1]).to(self.device))
    #                 else:
    #                     l = self.loss(output, torch._cast_Float(y)) + 0.3 * self.loss(scores, self.label_concept_onehot[torch.where(y == 1)[1]][:, None, :].repeat([1, self.max_len, 1]).to(self.device))
    #                 l.backward()
    #                 self.optimizer.step()
    #                 average.update(l.item())
    #                 pbar.postfix[0]["loss"] = round(average.compute().item(), 2 * self.PRECISION_DIGITS)
    #                 pbar.update()
    #             # torch.cuda.empty_cache()
    #             if valid is not None:
    #                 validation.append(self.evaluate_classes(classes_subset=classes_subset,
    #                                                         data=valid,
    #                                                         batch_size=valid_batch_size,
    #                                                         return_report=False,
    #                                                         return_roc=False))
    #                 pbar.postfix[0].update(validation[-1])
    #                 pbar.update()
    #         # torch.cuda.empty_cache()
    #         train_history["loss"].append(average.compute().item())
    #     return {"train": train_history, "valid": validation}
\ No newline at end of file
class BertAsConcept2(TextClassificationAbstract):
    """
    https://raw.githubusercontent.com/EMNLP2019LSAN/LSAN/master/attention/model.py
    """
    def __init__(self, classes, representation="roberta", label_freeze=True, max_len=300, **kwargs):
        super(BertAsConcept2, self).__init__(**kwargs)
        # My Stuff
        assert is_transformer(representation), "This model only works with transformers"
        self.classes = classes
        self.max_len = max_len
        self.n_layers = 2
        self.representation = representation
        self._init_input_representations()
        # Original
        self.n_classes = len(classes)
        self.label_freeze = label_freeze
        # Embed the class names once; the pooled output ([1]) serves as a static label embedding.
        self.labels = torch.nn.Parameter(self.embedding(self.transform(list(self.classes.keys())))[1])
        self.labels.requires_grad = not self.label_freeze
        self.label_embedding_dim = self.labels.shape[-1]
        self.input_projection2 = torch.nn.Linear(self.label_embedding_dim, self.embedding_dim)
        # self.metric = Bilinear(self.embedding_dim).to(self.device)
        # self.output_projection = torch.nn.Linear(in_features=self.max_len * self.n_classes, out_features=self.n_classes)
        self.build()

    def forward(self, x, return_scores=False):
        with torch.no_grad():
            embeddings = torch.cat(self.embedding(x)[2][(-1 - self.n_layers):-1], -1)
        p2 = self.input_projection2(self.labels)
        # Token-label affinity scores; summing over the sequence dimension yields the logits.
        scores = torch.matmul(embeddings, p2.t())
        output = scores.sum(-2)
        # output = self.metric(embeddings, p2).sum(-2)
        if return_scores:
            return output, scores
        return output

    def additional_concepts(self, x, k=5):
        self.eval()
        if not hasattr(self, "classes_rev"):
            self.classes_rev = {v: k for k, v in self.classes.items()}
        label_vocabulary_rev = {v: k for k, v in self.label_vocabulary.items()}
        prediction, scores = self(self.transform(x).to(self.device), return_scores=True)
        label = [self.classes_rev[x.item()] for x in torch.where(self.threshold(prediction, 0.5, "hard") == 1)[1]]
        tk = scores.sum(-2).topk(k)[1][0]
        concepts = [label_vocabulary_rev[x.item()] for x in tk]
        print("Labels:\t", label)
        print("Concepts:\t", concepts)
@@ -21,5 +21,5 @@ except:
from .ConceptScores import ConceptScores, ConceptScoresCNN, ConceptScoresCNNAttention, KimCNN2Branch, ConceptProjection, ConceptScoresAttention, ConceptScoresRelevance, ConceptScoresRelevanceWithImportanceWeights
from .ConceptLSAN import ConceptLSAN
from .GloveAsConcept import GloveConcepts
-from .BertAsConcept import BertAsConcept
+from .BertAsConcept import BertAsConcept, BertAsConcept2
@@ -6,14 +6,14 @@ import numpy as np
epochs = 20
-batch_size = 32
+batch_size = 50
mode = "transformer"
representation = "roberta"
optimizer = torch.optim.Adam
optimizer_params = {"lr": 1e-4}#, "betas": (0.9, 0.99)}
optimizer_params = {"lr": 1e-5}#, "betas": (0.9, 0.99)}
loss = torch.nn.BCEWithLogitsLoss
dataset = "blurbgenrecollection"
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
concept_graph = "random"
layers = 1
label_freeze = True
@@ -30,11 +30,11 @@ data = mlmc.data.get_dataset(dataset,
-tc = mlmc.models.BertAsConcept(
+tc = mlmc.models.BertAsConcept2(
classes=data["classes"],
label_freeze=label_freeze,
representation=representation,
optimizer=optimizer,
optimizer=optimizer,s
# optimizer_params=optimizer_params,
loss=loss,
device=device)
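Presumably the script then trains and inspects the model via the generic loop inherited from TextClassificationAbstract, since the model-specific fit override is commented out above. A hedged sketch of that usage, assuming the dataset dict exposes "train" and "valid" splits as in other mlmc experiment scripts:

    # Sketch only: the "train"/"valid" keys and the fit signature are assumptions
    # based on the commented-out fit method above, not confirmed by this commit.
    history = tc.fit(data["train"], valid=data["valid"], epochs=epochs, batch_size=batch_size)

    # Inspect predicted labels and top-k concept tokens for a sample text.
    tc.additional_concepts(["An example blurb about a fantasy novel."], k=5)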