Commit 22dc9d68 authored by Janos Borst

home

parent 3983a65a
Pipeline #53218 passed with stage in 11 minutes and 28 seconds
@@ -4,6 +4,22 @@ import torch
 from ..modules.dropout import VerticalDropout
 from ..graph import get as gget
 from ..modules.module_tfidf import TFIDFAggregation
+
+
+class NormedLinear(torch.nn.Module):
+    def __init__(self, input_dim, output_dim, bias=True):
+        super(NormedLinear, self).__init__()
+        self.weight = torch.nn.Parameter(torch.randn((input_dim, output_dim)))
+        self.use_bias = bias
+        if bias:
+            self.bias = torch.nn.Parameter(torch.randn((1, output_dim,)))
+        self.g = torch.nn.Parameter(torch.tensor([0.001]))
+
+    def forward(self, x):
+        r = torch.mm(x, self.weight / self.weight.norm(p=2, dim=0, keepdim=True))
+        if self.use_bias:
+            r = r + self.bias
+        return r * self.g
 class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstractZeroShot):
     def __init__(self, similarity="cosine", dropout=0.5, *args, **kwargs):
         super(KMemoryGraph, self).__init__(*args, **kwargs)
@@ -11,7 +27,8 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
         self.parameter = torch.nn.Linear(self.embeddings_dim, 256)
         self.entailment_projection = torch.nn.Linear(3 * self.embeddings_dim, self.embeddings_dim)
         self.entailment_projection2 = torch.nn.Linear(self.embeddings_dim, 1)
-        self.project = torch.nn.Linear(self.embeddings_dim, len(self.classes))
+        self.project = NormedLinear(self.embeddings_dim, len(self.classes), bias=False)
+
         self._config["dropout"] = dropout
         self.create_labels(self.classes)
         self._config["similarity"] = similarity
@@ -109,4 +126,5 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
         r = self._sim(ke, label_embedding).squeeze()
         r3 = self._sim(input_embedding, label_embedding).squeeze()
         # r = torch.einsum("bte,te->bt", ke, label_embedding)
-        return r + r2 + r3 + self.project(input_embedding)  # tfidf.max(1)[0].log_softmax(-1)
\ No newline at end of file
+        p = self.project(input_embedding)
+        return torch.stack([r, r2, r3, p], -1).mean(-1)  # tfidf.max(1)[0].log_softmax(-1)
\ No newline at end of file
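For context, below is a minimal, self-contained sketch of the NormedLinear layer introduced by this commit and of the new way the class scores are combined. The toy dimensions, the `__main__` harness, and the random stand-ins for the similarity scores `r`, `r2`, `r3` are illustrative assumptions, not part of the commit.

```python
import torch

# NormedLinear as added in this commit: a linear projection whose weight columns
# are L2-normalized on every forward pass, scaled by one learnable gain g
# (initialized small, at 0.001).
class NormedLinear(torch.nn.Module):
    def __init__(self, input_dim, output_dim, bias=True):
        super(NormedLinear, self).__init__()
        self.weight = torch.nn.Parameter(torch.randn((input_dim, output_dim)))
        self.use_bias = bias
        if bias:
            self.bias = torch.nn.Parameter(torch.randn((1, output_dim)))
        self.g = torch.nn.Parameter(torch.tensor([0.001]))

    def forward(self, x):
        # Normalize each output column to unit L2 norm, so the projection score
        # reflects direction rather than weight magnitude.
        r = torch.mm(x, self.weight / self.weight.norm(p=2, dim=0, keepdim=True))
        if self.use_bias:
            r = r + self.bias
        return r * self.g


if __name__ == "__main__":
    # Illustrative usage only: dimensions and the similarity-score tensors are
    # made up here. The commit swaps the plain torch.nn.Linear "project" head for
    # NormedLinear and averages its output with the three similarity scores
    # instead of summing everything.
    batch, emb_dim, n_classes = 4, 16, 5
    project = NormedLinear(emb_dim, n_classes, bias=False)

    input_embedding = torch.randn(batch, emb_dim)
    r, r2, r3 = (torch.randn(batch, n_classes) for _ in range(3))  # stand-ins for the _sim outputs

    p = project(input_embedding)
    scores = torch.stack([r, r2, r3, p], -1).mean(-1)  # shape: (batch, n_classes)
    print(scores.shape)
```

Compared to the removed `r + r2 + r3 + self.project(input_embedding)` sum, stacking the four score tensors and taking the mean keeps the combined score on the same scale as the individual similarities, and the small initial gain `g = 0.001` presumably lets the learned projection start as a near-neutral contribution.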