Commit f9f13487 authored by Janos Borst's avatar Janos Borst
Browse files

Bert classify

parent e89805f1
......@@ -108,8 +108,8 @@ class BertAsConcept2(TextClassificationAbstract):
self.label_embedding_dim = self.labels.shape[-1]
self.input_projection2 = torch.nn.Linear(self.label_embedding_dim, self.embedding_dim)
# self.metric = Bilinear(self.embedding_dim).to(self.device)
# self.output_projection = torch.nn.Linear(in_features=self.max_len * self.n_classes, out_features=self.n_classes)
self.metric = Bilinear(self.embedding_dim).to(self.device)
self.output_projection = torch.nn.Linear(in_features=self.max_len , out_features=1)
self.build()
def forward(self, x, return_scores=False):
......@@ -118,10 +118,10 @@ class BertAsConcept2(TextClassificationAbstract):
p2 = self.input_projection2(self.labels)
output = torch.matmul(embeddings,p2.t()).sum(-2)
output = self.output_projection(torch.matmul(embeddings,p2.t()).permute(0,2,1)).squeeze()
# output = self.metric(embeddings,p2).sum(-2)
if return_scores:
return output, metric_scores
return output#, metric_scores
return output
def additional_concepts(self, x, k=5):
......
......@@ -34,7 +34,7 @@ tc = mlmc.models.BertAsConcept2(
classes=data["classes"],
label_freeze=label_freeze,
representation=representation,
optimizer=optimizer,s
optimizer=optimizer,
# optimizer_params=optimizer_params,
loss=loss,
device=device)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment