Commit 271ec4e7 authored by Janos Borst's avatar Janos Borst

testing stuff

parent fd4f3eb0
Pipeline #53271 passed in 12 minutes and 3 seconds
@@ -30,6 +30,10 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
self.parameter = torch.nn.Linear(self.embeddings_dim,256)
self.entailment_projection = torch.nn.Linear(3 * self.embeddings_dim, self.embeddings_dim)
self.entailment_projection2 = torch.nn.Linear(self.embeddings_dim, 1)
# self.entailment_projection = torch.nn.Linear(3 * self.embeddings_dim, self.embeddings_dim)
# self.entailment_projection2 = NormedLinear(self.embeddings_dim, 1, bias=True)
self.project = NormedLinear(self.embeddings_dim, len(self.classes), bias=False)
self._config["dropout"] = dropout
@@ -71,7 +75,7 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
}
self.create_labels(self.classes)
self.vdropout = VerticalDropout(0.5)
self._classifier_weight = torch.nn.Parameter(torch.tensor([0.01]))
self.build()
def fit(self, train, valid,*args, **kwargs):
@@ -147,7 +151,19 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
r = r.diag()
return r
def forward(self, x):
def _entailment(self, x, y):
# build pairwise features [x_i, y_j, |x_i - y_j|] and project them to one score per (input, label) pair
b = tuple([1] * (len(x.shape) - 2))
e = self.entailment_projection(self.dropout(torch.cat([
x.unsqueeze(-2).repeat(*(b + (1, y.shape[0], 1))),
y.unsqueeze(-3).repeat(*(b + (x.shape[0], 1, 1))),
(x.unsqueeze(-2) - y.unsqueeze(-3)).abs()
], -1)))
r = self.entailment_projection2(e).squeeze(-1)
if self._config["target"] == "entailment":
r = r.diag()
return r
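# --- Hedged sketch, not part of the commit: a shape check for the pairwise entailment
# --- features built above, assuming 2D inputs x: (N, d) and y: (L, d). The projection
# --- layers below are stand-ins for entailment_projection / entailment_projection2.
# import torch
# N, L, d = 4, 7, 16
# x = torch.randn(N, d)                          # pooled input embeddings
# y = torch.randn(L, d)                          # pooled label embeddings
# proj1 = torch.nn.Linear(3 * d, d)
# proj2 = torch.nn.Linear(d, 1)
# feats = torch.cat([
#     x.unsqueeze(-2).repeat(1, L, 1),           # (N, L, d) each input copied per label
#     y.unsqueeze(-3).repeat(N, 1, 1),           # (N, L, d) each label copied per input
#     (x.unsqueeze(-2) - y.unsqueeze(-3)).abs()  # (N, L, d) elementwise distance
# ], -1)                                         # (N, L, 3d)
# scores = proj2(proj1(feats)).squeeze(-1)       # (N, L) pairwise scores
# assert scores.shape == (N, L)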
def forward(self, x, return_keywords=False):
input_embedding = self.vdropout(self.embedding(**x)[0])
label_embedding = self.dropout(self.embedding(**self.label_dict)[0])
nodes_embedding = self.embedding(**self.nodes)[0]
@@ -160,16 +176,89 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
memory_embedding = {x: self._mean_pooling(memory_embedding[x], self.memory_dicts[x]["attention_mask"]) if memory_embedding[x] is not None else None for x in memory_embedding}
words, ke, tfidf = self.agg(input_embedding, memory_embedding.values(), x_mask=x["attention_mask"])
nodes_embedding = self._mean_pooling(nodes_embedding, self.nodes["attention_mask"])
input_embedding = self._mean_pooling(input_embedding, x["attention_mask"])
label_embedding = self._mean_pooling(label_embedding, self.label_dict["attention_mask"])
r1 = self._sim(ke,label_embedding).squeeze()
r2 = torch.stack([self._sim(input_embedding, x).max(-1)[0] for i,(k, x) in enumerate(memory_embedding.items())],-1)
r3 = self._sim(input_embedding,label_embedding).squeeze()
r4 = torch.mm(self._sim(input_embedding,nodes_embedding).squeeze(), (self.adjencies/self.adjencies.norm(1,dim=-1, keepdim=True)).t())
# p = self.project(input_embedding)
l = [r1, r2,r3, r4]
# l = [r2, r4]
return torch.stack(l,-1).mean(-1)#tfidf.max(1)[0].log_softmax(-1)
\ No newline at end of file
r1 = self._sim(ke,label_embedding).squeeze() # weighted-similarity
r2 = torch.stack([self._sim(input_embedding, x).max(-1)[0] for i,(k, x) in enumerate(memory_embedding.items())],-1) # keyword-similarity-max
r3 = self._sim(input_embedding,label_embedding).squeeze() # pooled-similarity
r4 = torch.mm(self._sim(input_embedding,nodes_embedding).squeeze(), (self.adjencies/self.adjencies.norm(1,dim=-1, keepdim=True)).t()) # keyword-similarity-mean
p = self._classifier_weight * self._entailment(input_embedding, label_embedding)  # classifier
l = [r1, r2, r3, r4, p]
# l = [r3, r2, r4]
scores = torch.stack(l, -1).mean(-1)
if return_keywords:
return scores, tfidf.softmax(-1)
return scores  # tfidf.max(1)[0].log_softmax(-1)
def transform(self, x, max_length=None, return_tokens=False) -> dict:
if max_length is None:
max_length = self._config["max_len"]
r = {k: v.to(self.device) for k, v in
self.tokenizer(x, padding=True, max_length=max_length, truncation=True,
add_special_tokens=True, return_tensors='pt').items()}
if return_tokens:
return r, [self.tokenizer.tokenize(s) for s in x]
return r
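# --- Hedged usage sketch, not part of the commit: how transform() and forward() fit
# --- together at inference time; `model` stands for a built KMemoryGraph instance
# --- and the texts are made up.
# batch = ["first example text", "second example text"]
# with torch.no_grad():
#     inputs, tokens = model.transform(batch, return_tokens=True)    # tokenized tensors and word pieces
#     scores, keyword_weights = model(inputs, return_keywords=True)  # (N, L) scores and keyword weights
# pred = scores.argmax(-1)                                           # highest-scoring label index per input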
def keywords(self, x, y, n=10):
self.eval()
with torch.no_grad():
i, tokens = self.transform(x, return_tokens=True)
scores, keywords = self.forward(i, return_keywords=True)
import matplotlib.pyplot as plt
sorted_scores, prediction = scores.sort(-1)
idx = scores.argmax(-1)
label_specific_scores = torch.stack([k[i] for k, i in zip(keywords, idx)])
keywords = [list(zip(t,l[1:(1+len(t))].tolist())) for t,l in zip(tokens, label_specific_scores)]
keywords_new = []
for l in keywords:
new_list = []
new_tuple = [[], 0]
for i in range(1, 1+len(l)):
new_tuple[0].append(l[-i][0])
new_tuple[1] += l[-i][1]
if not l[-i][0].startswith("##"):
new_tuple[0] = "".join(new_tuple[0][::-1]).replace("##", "")
new_list.append(tuple(new_tuple))
new_tuple = [[], 0]
keywords_new.append(new_list)
import numpy as np
prediction = [[self.classes_rev[x] for x in y] for y in prediction.detach().cpu().tolist()]
binary = np.array([[p in c for p in pred] for c, pred in zip(y, prediction)])
import seaborn as sns
ax = sns.heatmap((sorted_scores.softmax(-1).cpu()+binary), annot=np.array(prediction), fmt="")
plt.show()
return [(p[-1], t, sorted(x, key=lambda x: -x[1])[:n]) for p, t,x in zip(prediction, y, keywords_new)]
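# --- Hedged usage sketch, not part of the commit: keywords() returns, per input, the top
# --- predicted label, the gold labels passed in, and the n highest-weighted merged word
# --- pieces (it also renders a seaborn heatmap); `model`, the text and the label are assumptions.
# texts = ["an example text about an assumed topic"]
# gold = [["SomeLabel"]]
# for top_label, gold_labels, top_terms in model.keywords(texts, gold, n=5):
#     print(top_label, gold_labels, top_terms)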
def scores(self, x):
"""
Returns 2D tensor with length of x and number of labels as shape: (N, L)
Args:
x:
Returns:
"""
self.eval()
assert not (self._config["target"] == "single" and self._config["threshold"] != "max"), \
"You are running in single-target mode but the threshold method is not set to 'max'."
if not hasattr(self, "classes_rev"):
self.classes_rev = {v: k for k, v in self.classes.items()}
x = self.transform(x)
with torch.no_grad():
output = self.act(self(x))
# output = 0.5*(output+1)
self.train()
return output
\ No newline at end of file
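# --- Hedged usage sketch, not part of the commit: scores() takes raw strings, tokenizes
# --- them internally and returns activated (N, L) scores; `model` is an assumed trained instance.
# probs = model.scores(["an example document to classify"])
# best = probs.argmax(-1)                         # index of the highest-scoring label per input
# print(model.classes_rev[best[0].item()], probs[0, best[0]].item())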
@@ -9,9 +9,9 @@ class TFIDFAggregation(torch.nn.Module):
def forward(self, x, y, x_mask=None):
x_norm = x/x.norm(dim=-1, keepdim=True)
with torch.no_grad():
y = [w-w.mean(0)[None] for w in y]
words = [torch.einsum("ijn,ln->ilj",x_norm , te/te.norm(dim=-1, keepdim=True) ) for te in y]
words = [(w* x_mask[:,None]) for w in words]
# y = [w-w.mean(0)[None] for w in y]
words = [0.5 * (1 + torch.einsum("ijn,ln->ilj", x_norm, te / te.norm(dim=-1, keepdim=True))) for te in y]
words = [(w.softmax(-1) * x_mask[:, None]) for w in words]
cidf = 1. / sum([w.sum(1)[0] / x_mask[:, None].sum(-1) for w in words])
cidf[cidf.isinf()]=0
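# --- Hedged sketch, not part of the commit: the cosine-similarity word weighting used above
# --- for a single memory/keyword set; shapes and names below are assumptions.
# import torch
# B, T, L, d = 2, 6, 5, 16
# x = torch.randn(B, T, d)            # token embeddings
# mask = torch.ones(B, T)             # attention mask
# te = torch.randn(L, d)              # keyword embeddings for one memory
# x_norm = x / x.norm(dim=-1, keepdim=True)
# te_norm = te / te.norm(dim=-1, keepdim=True)
# w = 0.5 * (1 + torch.einsum("ijn,ln->ilj", x_norm, te_norm))  # (B, L, T) cosine rescaled to [0, 1]
# w = w.softmax(-1) * mask[:, None]                             # weight over tokens, padding masked out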