Commit 3f2de6ae authored by Janos Borst's avatar Janos Borst

testing for graph network

parent 6d3dcdbe
Pipeline #53378 passed with stage
in 13 minutes and 21 seconds
@@ -138,7 +138,7 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
         return r
     def forward(self, x):
-        input_embedding = self.vdropout(self.embedding(**{k:x[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0])
+        input_embedding = self.dropout(self.embedding(**{k:x[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0])
         label_embedding = self.dropout(self.embedding(**{k:self.label_dict[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0])
         nodes_embedding = self.embedding(**{k:self.nodes[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0]
         memory_embedding = {x:self.embedding(**{k:self.memory_dicts.get(x)[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0] if x in self.memory_dicts else None for x in self.classes.keys()}
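For orientation: the embedding call pattern above forwards only the tokenizer outputs to a transformer encoder and takes index [0], the per-token hidden states. A minimal self-contained sketch, assuming a Hugging Face BERT encoder (the model name is illustrative, not necessarily what the repository uses):

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

x = tokenizer(["an example sentence"], return_tensors="pt")
keys = ['input_ids', 'token_type_ids', 'attention_mask']
hidden = model(**{k: x[k] for k in keys})[0]  # (batch, tokens, hidden): last_hidden_state
input_embedding = torch.nn.Dropout(0.5)(hidden)  # the commit swaps self.vdropout for this plain Dropout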
@@ -24,7 +24,7 @@ class NormedLinear(torch.nn.Module):
         return r * self.g
 class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstractZeroShot):
-    def __init__(self, similarity="cosine", dropout=0.5, measures=[ "keyword_similarity_max", "pooled_similarity", "keyword_similiarity_mean", "fallback_classifier", "weighted_similarity"],
+    def __init__(self, similarity="cosine", dropout=0.5, measures=[ "keyword_similarity_max", "pooled_similarity", "keyword_similiarity_mean", "fallback_classifier"],
                  graph="wordnet", *args, **kwargs):
         super(KMemoryGraph, self).__init__(*args, **kwargs)
         self.dropout = torch.nn.Dropout(dropout)
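The default measures list drops "weighted_similarity"; each remaining entry toggles one score component in forward (see the "scoring" checks further down in this diff). A hedged sketch of that gating pattern, with invented score tensors and a made-up combine helper, not the library's API:

import torch

def combine(scores: dict, measures: list) -> torch.Tensor:
    # stack only the components that are switched on, then average them
    active = [scores[m] for m in measures if m in scores]
    return torch.stack(active, -1).mean(-1)

scores = {
    "keyword_similarity_max": torch.rand(2, 4),  # (batch, n_classes)
    "pooled_similarity": torch.rand(2, 4),
}
logits = combine(scores, ["keyword_similarity_max", "pooled_similarity"])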
@@ -43,7 +43,7 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
self._config["pos"] = ["a", "s", "n", "v"]
self._config["depth"] = 2
self._config["graph"] = graph
from ..graph.helpers import keywordmap
from ....graph.helpers import keywordmap
self.map = keywordmap
self.create_labels(self.classes)
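On the import fix: each leading dot in a relative import climbs one package level, so going from two dots to four suggests this module now sits two package levels deeper relative to graph.helpers than before (the exact package layout is not visible in this diff).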
@@ -89,9 +89,7 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
             self.g.add_edge(k,k)
             self.g.add_edges_from([(k,n) for n in v.nodes])
         self.memory_dicts = {}
-        self.memory_dicts = {k:self.label_embed(ex) for k, ex in self.memory.items() }
         self._node_list = sorted(list(self.g.nodes))
         self.nodes = self.transform(self._node_list)
         self._class_nodes = {k:self._node_list.index(k) for k in self.classes.keys()}
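The surrounding code builds a keyword graph per class (including a self-loop, so each class node neighbors itself) and then indexes nodes in a fixed sorted order. A runnable sketch with networkx and invented classes/keywords; `adjacency` here mirrors how the 0/1 rows of self.adjencies are consumed later in forward:

import networkx as nx
import torch

classes = {"sports": 0, "politics": 1}
keywords = {"sports": ["game", "team"], "politics": ["vote", "law"]}

g = nx.Graph()
for k, kws in keywords.items():
    g.add_edge(k, k)  # self-loop: a class counts among its own neighbors
    g.add_edges_from([(k, n) for n in kws])

node_list = sorted(g.nodes)
class_nodes = {k: node_list.index(k) for k in classes}
adjacency = torch.stack([
    torch.tensor([1.0 if g.has_edge(k, n) else 0.0 for n in node_list])
    for k in classes
])  # (n_classes, n_nodes): one neighborhood row per class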
@@ -149,7 +147,11 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
         input_embedding = self.vdropout(self.embedding(**{k:x[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0])
         label_embedding = self.dropout(self.embedding(**{k:self.label_dict[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0])
         nodes_embedding = self.embedding(**{k:self.nodes[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0]
         memory_embedding = {x:self.embedding(**{k:self.memory_dicts.get(x)[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0] if x in self.memory_dicts else None for x in self.classes.keys()}
+        # task_embedding = self.embedding(**{k:self.transform([self._config["task"]])[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[1]
+        # input_embedding = input_embedding - task_embedding[None]
+        # nodes_embedding = nodes_embedding - task_embedding[None]
+        # label_embedding = label_embedding - task_embedding[None]
         if self.training:
@@ -157,9 +159,6 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
             input_embedding = input_embedding * ((torch.rand_like(input_embedding[:,:,0])>0.05).float()*2 -1)[...,None]
-        # with torch.no_grad():
-        #     x_mask = x["attention_mask"].detach()
-        #     x_norm = input_embedding.detach() / input_embedding.norm(dim=-1, keepdim=True).detach()
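The surviving line in the training branch is a sign-flip augmentation: each token embedding keeps its direction with probability 0.95 and is multiplied by -1 otherwise. Isolated for clarity, with dummy shapes:

import torch

input_embedding = torch.randn(2, 4, 8)  # (batch, tokens, hidden)
# +1 with probability 0.95, -1 with probability 0.05, drawn per token
flip = (torch.rand_like(input_embedding[:, :, 0]) > 0.05).float() * 2 - 1
input_embedding = input_embedding * flip[..., None]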
@@ -175,12 +174,10 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
         #
         # ke = torch.einsum("bwt,bwe->bte", tfidf.softmax(-2)*x_mask[...,None], input_embedding)
-        memory_embedding = {x: self._mean_pooling(memory_embedding[x], self.memory_dicts[x]["attention_mask"]) if memory_embedding[
-            x] is not None else None
-                            for x in memory_embedding}
-        words, ke, tfidf= self.agg(input_embedding, memory_embedding.values(), x_mask = x["attention_mask"])
+        nodes_embedding = self._mean_pooling(nodes_embedding, self.nodes["attention_mask"])
+        words, ke, tfidf= self.agg(input_embedding, [nodes_embedding[torch.where(line==1)] for line in self.adjencies], x_mask = x["attention_mask"])
         input_embedding = self._mean_pooling(input_embedding, x["attention_mask"])
         label_embedding = self._mean_pooling(label_embedding, self.label_dict["attention_mask"])
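_mean_pooling itself is not shown in this diff; a plausible implementation consistent with how it is called (averaging token vectors over non-padding positions) would be:

import torch

def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # token_embeddings: (batch, tokens, hidden), attention_mask: (batch, tokens)
    mask = attention_mask.unsqueeze(-1).float()
    summed = (token_embeddings * mask).sum(1)
    counts = mask.sum(1).clamp(min=1e-9)  # guard against all-padding rows
    return summed / counts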
@@ -188,7 +185,7 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
         l = []
         if "keyword_similarity_max" in self._config["scoring"]:
             keyword_similarity_max = torch.stack(
-                [self._sim(input_embedding, x).max(-1)[0] for i, (k, x) in enumerate(memory_embedding.items())],
+                [self._sim(input_embedding, nodes_embedding[torch.where(line==1)]).max(-1)[0] for line in self.adjencies],
                 -1) # keyword-similarity-max
             l.append(keyword_similarity_max)
         if "pooled_similarity" in self._config["scoring"]: