Commit 0664ab7e authored by Janos Borst

Paper additions

parent 3f2de6ae
Pipeline #53656 passed with stage in 11 minutes and 4 seconds
@@ -7,7 +7,7 @@ import torch
from .data_loaders_classification import load_eurlex, load_wiki30k, load_huffpost, load_aapd, load_rcv1, \
load_moviesummaries, load_blurbgenrecollection, load_blurbgenrecollection_de, load_20newsgroup, export, \
load_agnews, load_dbpedia, load_ohsumed, load_yahoo_answers, load_movie_reviews, load_amazonfull, load_trec6, \
-    load_trec50, load_yelpfull, load_amazonpolarity
+    load_trec50, load_yelpfull, load_amazonpolarity, load_yelppolarity, load_imdb
from .data_loaders_similarity import load_sts, load_stsb, load_sts12, load_sts13, load_sts14, load_sts16, load_sick
from .data_loaders_nli import load_mnli, load_snli
@@ -31,13 +31,15 @@ register = {
"ohsumed": load_ohsumed,
"yahoo_answers": load_yahoo_answers,
"movie_reviews": load_movie_reviews,
"imdb": load_imdb,
"trec6": load_trec6,
"trec50":load_trec50,
"yelpfull": load_yelpfull,
"amazonfull": load_amazonfull,
"snli": load_snli,
"mnli": load_mnli,
"amazonpolarity": load_amazonpolarity
"amazonpolarity": load_amazonpolarity,
"yelppolarity": load_yelppolarity
}
@@ -773,6 +773,16 @@ def load_movie_reviews():
_save_to_tmp("movie_reviews", (data, classes))
return data, classes
+def load_imdb():
+    from datasets import load_dataset
+    train = load_dataset("imdb", split="train")
+    test = load_dataset("imdb", split="test")
+    data = {
+        "train": (train["text"], [["negative" if x==0 else "positive"] for x in train["label"]]),
+        "test": (test["text"], [["negative" if x==0 else "positive"] for x in test["label"]]),
+    }
+    return data, dict(negative=0, positive=1)
def load_yelpfull():
@@ -796,6 +806,16 @@ def load_amazonpolarity():
}
return data, {"negative":0, "positive":1}
+def load_yelppolarity():
+    from datasets import load_dataset
+    train = load_dataset("yelp_polarity", split="train")
+    test = load_dataset("yelp_polarity", split="test")
+    data = {
+        "train": (train["text"], [["negative"] if x==0 else ["positive"] for x in train["label"]]),
+        "test": (test["text"], [["negative"] if x==0 else ["positive"] for x in test["label"]]),
+    }
+    return data, {"negative":0, "positive":1}
import tempfile
import tarfile
import csv
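Both new loaders follow the pattern of the existing `load_amazonpolarity`: a HuggingFace `datasets` download, integer labels mapped to single-element string lists, and a class map as the second return value. A minimal usage sketch (the `mlmc.data` import path is an assumption, not confirmed by the diff; access through the `register` dict mirrors the first hunk):

```python
# Hedged sketch: the module path "mlmc.data" is assumed.
from mlmc.data import register

data, classes = register["imdb"]()     # the newly registered loader
texts, labels = data["train"]
print(classes)    # {'negative': 0, 'positive': 1}
print(labels[0])  # e.g. ['negative'] - labels come as single-element lists
```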
@@ -74,6 +74,11 @@ label_dicts = {"trec6": {"ABBR": "question about abbreviation",
"3":"neutral restaurant sentiment",
"4":"positive restaurant sentiment",
"5":"very positive restaurant sentiment",
"one":"very negative restaurant sentiment",
"two":"negative restaurant sentiment",
"three":"neutral restaurant sentiment",
"four":"positive restaurant sentiment",
"five":"very positive restaurant sentiment",
},
"amazonfull":{
"1":"very negative restaurant sentiment",
@@ -95,10 +100,10 @@ SFORMATTER = {"agnews": lambda x: f"The topic of this is {label_dicts['agnews'].
"movies_summaries": lambda x: f"This movie is about {x}",
"movie_reviews": lambda x: f"This sounds rather {x}",
"dbpedia": lambda x: f"The topic of this is {label_dicts['dbpedia'].get(x,x)}",
"trec6": lambda x: f"This is a {label_dicts['trec6'][x]}",
"trec50": lambda x: f"This is a {label_dicts['trec50'][x]}",
"yelpfull": lambda x: f"This is a {label_dicts['yelpfull'][x]}",
"amazonfull": lambda x: f"This is a {label_dicts['amazonfull'][x]}",
"trec6": lambda x: f"This is a {label_dicts['trec6'].get(x,x)}",
"trec50": lambda x: f"This is a {label_dicts['trec50'].get(x,x)}",
"yelpfull": lambda x: f"This is a {label_dicts['yelpfull'].get(x,x)}",
"amazonfull": lambda x: f"This is a {label_dicts['amazonfull'].get(x,x)}",
}
SFORMATTER_TARS = {"agnews": lambda x: f"label topic {label_dicts['agnews'].get(x,x)}",
@@ -106,10 +111,10 @@ SFORMATTER_TARS = {"agnews": lambda x: f"label topic {label_dicts['agnews'].get(
"rcv1": lambda x: f"label topics {x}",
"blurbgenrecollection": lambda x: f"label topics {x}",
"movies_summaries": lambda x: f"label genre {x}",
"movie_reviews": lambda x: f"{x} sentiment",
"movie_reviews": lambda x: f"sentiment {x}",
"dbpedia": lambda x: f"label topic {label_dicts['dbpedia'].get(x,x)}",
"trec6": lambda x: f"question {label_dicts['trec6'][x]}",
"trec50": lambda x: f"question {label_dicts['trec50'][x]}",
"yelpfull": lambda x: f"sentiment {label_dicts['yelpfull'][x]}",
"amazonfull": lambda x: f"sentiment {label_dicts['amazonfull'][x]}",
"trec6": lambda x: f"question {label_dicts['trec6'].get(x,x)}",
"trec50": lambda x: f"question {label_dicts['trec50'].get(x,x)}",
"yelpfull": lambda x: f"sentiment {label_dicts['yelpfull'].get(x,x)}",
"amazonfull": lambda x: f"sentiment {label_dicts['amazonfull'].get(x,x)}",
}
\ No newline at end of file
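The change from `label_dicts['…'][x]` to `label_dicts['…'].get(x, x)` in both formatter tables makes the lambdas robust to labels missing from the dictionary: instead of raising `KeyError`, the raw label falls through unchanged. A minimal illustration of the fallback behaviour:

```python
# Toy dict standing in for one of the label_dicts entries.
label_dict = {"1": "very negative restaurant sentiment"}
fmt = lambda x: f"This is a {label_dict.get(x, x)}"

print(fmt("1"))        # This is a very negative restaurant sentiment
print(fmt("unknown"))  # This is a unknown  -> unseen labels pass through
```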
@@ -60,7 +60,7 @@ keywordmap = {"Sports": ["sport"], "Business":["business"], "World": ["world"],
"Education & Reference":["Education", "reference"], "Computers & Internet":["computer", "internet"], "Business & Finance": ["business", "finance"],
"Entertainment & Music":["entertainment", "music"], "Family & Relationships": ["family", "relationship"], "Politics & Government":["politics", "government"],
# "1":["1", "worst", "terrible"], "2":["2","poor", "odd", "simple"], "3":["3", "neutral","ok", "fine"], "4":["4", "bold", "worth", "good", "nice"], "5":["5","amazing", "excellent", "wow"],
"1":["1"], "2":["2"], "3":["3"], "4":["4",], "5":["5"],
"1":["1"], "2":["2"], "3":["3"], "4":["4",], "5":["5"],"one":["1"], "two":["2"], "three":["3"], "four":["4",], "five":["5"],
"negative":["1", "2"], "positive":["4","5"],
"ENTY:sport": ["entity", "sport"], "ENTY:dismed": ["entity","disease", "medicine"], "LOC:city": ["location", "city"],
"DESC:reason": ["description","reason"],
@@ -99,7 +99,7 @@ keywordmap = {"Sports": ["sport"], "Business":["business"], "World": ["world"],
'Teen & Young Adult Historical Fiction': ['teen', 'young', 'adult', 'historical', 'fiction'],
'U.S. History': ['U.S.', 'history'], 'Children’s Picture Books': ['child', 'picture', 'book'],
'Fiction Classics': ['fiction', 'classics'], 'Ancient World History': ['ancient', 'world', 'history'],
-    'Classics': ['classics'], 'Business': ['business'], 'Military Science Fiction': ['military', 'science fiction'],
+    'Classics': ['classics'], 'Military Science Fiction': ['military', 'science fiction'],
'World War I Military History': ['world war', 'i', 'military', 'history'],
'Fiction': ['fiction'], 'Paranormal Romance': ['paranormal', 'romance'], 'Women’s Fiction': ['woman', 'fiction'],
'Crime Mysteries': ['crime', 'mystery'], 'Design': ['design'], 'Personal Growth': ['personal', 'growth'],
@@ -144,18 +144,36 @@ keywordmap = {"Sports": ["sport"], "Business":["business"], "World": ["world"],
'Weddings': ['wedding'], 'Teen & Young Adult Nonfiction': ['teen', 'young', 'adult', 'nonfiction'],
'21st Century U.S. History': ['21st', 'century', 'U.S.', 'history'], 'Gothic & Horror': ['gothic', 'horror'],
'Domestic Politics': ['domestic', 'politics'], 'Reference': ['reference'], 'Beauty': ['beauty'],
-    'Sports': ['sport'], 'Western Fiction': ['Western', 'fiction'],
+    'Western Fiction': ['Western', 'fiction'],
'Teen & Young Adult Science Fiction': ['teen', 'young', 'adult', 'science fiction'],
'Philosophy': ['philosophy'], 'Parenting': ['parent', 'raise'],
'Native American History': ['Native American', 'history'], 'Poetry': ['poetry'], 'Psychology': ['psychology'],
'Inspiration & Motivation': ['inspiration', 'motivation'], 'Step Into Reading': ['beginner', "start", 'reading'],
'Exercise': ['exercise'], 'Bibles': ['bible'], 'Travel: Middle East': ['travel', 'Middle East'],
'New Adult Romance': ['new', 'adult', 'romance'], 'Contemporary Romance': ['contemporary', 'romance'],
-    'Military History': ['military', 'history'], 'Cyber Punk': ["cyberpunk"], 'Film': ['film'],
+    'Military History': ['military', 'history'], 'Cyber Punk': ["cyberpunk"],
'Children’s Middle Grade Sports Books': ['child', 'middle', 'grade', 'sport', 'book'],
'European World History': ['European', 'world', 'history'],
'Political Figure Biographies & Memoirs': ['political', 'figure', 'biography', 'memoir'],
'Children’s Activity & Novelty Books': ['child', 'activity', 'novelty', 'book'],
'1950 – Present Military History': ['present', 'military', 'history'],
'Children’s Board Books': ['child', 'board', 'book'], 'World Politics': ['world', 'politics'],
-    'Food Memoir & Travel': ['food', 'memoir', 'travel'], 'Management': ['management']}
\ No newline at end of file
+    'Food Memoir & Travel': ['food', 'memoir', 'travel'], 'Management': ['management'],
+    'alt.atheism': ["atheism"], 'comp.graphics': ["computer", "graphics", "computer graphics"],
+    'comp.os.ms-windows.misc': ['computer', "operating system", "Windows"],
+    'comp.sys.ibm.pc.hardware': ['computer', 'system', 'IBM', 'hardware'],
+    'comp.sys.mac.hardware': ['computer', 'system', 'Apple', "Mac", 'hardware'], 'comp.windows.x': ["computer", "Windows"],
+    'misc.forsale': ["for sale", "sale"],
+    'rec.autos': ["recreational", "auto", "car"], 'rec.motorcycles': ["recreational", "motorcycle"],
+    'rec.sport.baseball': ["recreational", "sport", "baseball"],
+    'rec.sport.hockey': ["recreational", "sport", "hockey"],
+    'sci.crypt': ["science", "cryptography"],
+    'sci.electronics': ["science", "electronics"],
+    'sci.med': ["science", "medicine"],
+    'sci.space': ["science", "outer space"],
+    'soc.religion.christian': ["society", "religion", "christianity"],
+    'talk.politics.guns': ["politic", "gun"],
+    'talk.politics.mideast': ["politic", "Middle East"],
+    'talk.politics.misc': ["politic"],
+    'talk.religion.misc': ["religion"],
+    }
\ No newline at end of file
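The appended block extends `keywordmap` to the 20 Newsgroups label set, translating each newsgroup name into plain-language keywords. A quick sanity check, assuming `keywordmap` is importable from the package's graph helpers (the absolute path below is hypothetical; the diff itself uses the relative import `from ....graph.helpers import keywordmap`):

```python
from mlmc.graph.helpers import keywordmap  # hypothetical absolute path

print(keywordmap["sci.med"])                # ['science', 'medicine']
print(keywordmap["comp.sys.mac.hardware"])  # ['computer', 'system', 'Apple', 'Mac', 'hardware']
```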
@@ -863,4 +863,9 @@ class TextClassificationAbstract(torch.nn.Module):
"""
x = self.transform(x)
-        return self.forward(x, emb=True)
\ No newline at end of file
+        return self.forward(x, emb=True)
+    def log_mlflow(self):
+        import mlflow
+        mlflow.log_params({k:v for k,v in self._config.items() if k not in ["classes"]})
+        mlflow.log_param("model", self.__class__.__name__)
\ No newline at end of file
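`log_mlflow` pushes the model configuration to the active MLflow run, skipping the `classes` entry (a potentially large label map rather than a scalar parameter) and adding the class name as a `model` parameter. A hedged usage sketch, assuming a built model instance named `model`:

```python
import mlflow

with mlflow.start_run():   # standard MLflow tracking context
    model.log_mlflow()     # logs every _config entry except "classes"
```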
@@ -24,7 +24,7 @@ class NormedLinear(torch.nn.Module):
return r * self.g
class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstractZeroShot):
-    def __init__(self, similarity="cosine", dropout=0.5, measures=[ "keyword_similarity_max", "pooled_similarity", "keyword_similiarity_mean", "fallback_classifier", "weighted_similarity"],
+    def __init__(self, similarity="cosine", dropout=0.5, entropy=True, measures=[ "keyword_similarity_max", "pooled_similarity", "keyword_similiarity_mean", "fallback_classifier", "weighted_similarity"],
graph="wordnet", *args, **kwargs):
super(KMemoryGraph, self).__init__(*args, **kwargs)
self.dropout = torch.nn.Dropout(dropout)
@@ -36,10 +36,10 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
self._config["dropout"] = dropout
self._config["similarity"] = similarity
self._config["entropy"] = entropy
self.agg = TFIDFAggregation()
self._config["scoring"] = measures
self.set_scoring(measures)
self._config["pos"] = ["a", "s", "n", "v"]
self._config["depth"] = 2
self._config["graph"] = graph
@@ -52,6 +52,9 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
self.build()
+    def set_scoring(self, measures):
+        self._config["scoring"] = measures
def update_memory(self):
"""
Method to change the current target variables
@@ -64,12 +67,12 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
graph = gget(self._config["graph"])
self.memory = {
-            k: [k] + self.map[k] + [x for x in sum([list(graph.neighbors(x)) for x in self.map[k]], [])] # if graph.nodes(True)[x]["pos"] in self._config["pos"]
+            k: [k] + self.map[k] + [x for x in sum([list(graph.neighbors(x)) if x in graph else [x] for x in self.map[k]], [])] # if graph.nodes(True)[x]["pos"] in self._config["pos"]
for k in self.classes.keys()
}
subgraph = {
-            k: graph.subgraph(self.map[k] + [x for x in sum([list(graph.neighbors(x)) for x in self.map[k]], [])])
+            k: graph.subgraph(self.map[k] + [x for x in sum([list(graph.neighbors(x)) if x in graph else [x] for x in self.map[k]], [])])
for k in self.classes.keys()
}
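The `if x in graph else [x]` guard covers keywords that are absent from the loaded graph: previously `graph.neighbors(x)` would raise for such nodes, now the keyword is kept as its own one-element neighborhood. A minimal reproduction with networkx:

```python
import networkx as nx

g = nx.Graph([("sport", "game")])
keywords = ["sport", "cricket"]  # "cricket" is not a node of g

# unguarded: list(g.neighbors("cricket")) raises NetworkXError
expanded = sum([list(g.neighbors(x)) if x in g else [x] for x in keywords], [])
print(expanded)  # ['game', 'cricket']
```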
@@ -178,5 +181,14 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
if "weighted_similarity" in self._config["scoring"]:
weighted_similarity = self._sim(ke, label_embedding).squeeze() # weighted-similarity
l.append(weighted_similarity)
-        scores = torch.stack(l,-1).mean(-1)
-        return scores
\ No newline at end of file
if self._config["entropy"]:
with torch.no_grad():
w = torch.stack([self._ent(x.detach()) for x in l], -1).softmax(-1).unsqueeze(-2)
scores = (torch.stack(l, -1) * w).mean(-1)
else:
scores = torch.stack(l,-1).mean(-1)
return scores
def _ent(self, x):
return (x.log_softmax(-1) * x.softmax(-1)).sum(-1)
\ No newline at end of file
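The new branch replaces the uniform mean over the individual measures with entropy-derived weights: `_ent` returns the negative entropy of each score vector, so after the softmax, decisive (low-entropy) measures contribute more to the combined score. A small numeric sketch of the weighting step:

```python
import torch

def _ent(x):
    # negative entropy: larger for a more peaked score distribution
    return (x.log_softmax(-1) * x.softmax(-1)).sum(-1)

decisive = torch.tensor([4.0, 0.0, 0.0])  # confident measure
diffuse = torch.tensor([0.1, 0.0, 0.1])   # near-uniform measure
w = torch.stack([_ent(decisive), _ent(diffuse)], -1).softmax(-1)
print(w)  # the decisive measure receives the larger weight (~0.7 vs ~0.3)
```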
@@ -24,9 +24,13 @@ class NormedLinear(torch.nn.Module):
return r * self.g
class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstractZeroShot):
-    def __init__(self, similarity="cosine", dropout=0.5, measures=[ "keyword_similarity_max", "pooled_similarity", "keyword_similiarity_mean", "fallback_classifier"],
+    def __init__(self, similarity="cosine", dropout=0.5, entropy=True,
+                 measures=None, depth=2,
graph="wordnet", *args, **kwargs):
super(KMemoryGraph, self).__init__(*args, **kwargs)
+        if measures is None:
+            measures = ["keyword_similarity_max", "pooled_similarity", "keyword_similiarity_mean", "fallback_classifier",
+                        "weighted_similarity"]
self.dropout = torch.nn.Dropout(dropout)
self.parameter = torch.nn.Linear(self.embeddings_dim,256)
self.entailment_projection = torch.nn.Linear(3 * self.embeddings_dim, self.embeddings_dim)
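Moving the `measures` default out of the signature avoids Python's shared-mutable-default pitfall: a list literal in a `def` line is evaluated once at function definition and shared across all calls, so mutating it leaks state between instances. A minimal demonstration of why `measures=None` plus an in-body default is safer:

```python
def bad(measures=["keyword_similarity_max"]):
    measures.append("pooled_similarity")  # mutates the shared default
    return measures

print(bad())  # ['keyword_similarity_max', 'pooled_similarity']
print(bad())  # ['keyword_similarity_max', 'pooled_similarity', 'pooled_similarity']
```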
@@ -36,12 +40,12 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
self._config["dropout"] = dropout
self._config["similarity"] = similarity
self._config["entropy"] = entropy
self.agg = TFIDFAggregation()
self._config["scoring"] = measures
self.set_scoring(measures)
self._config["pos"] = ["a", "s", "n", "v"]
self._config["depth"] = 2
self._config["depth"] = depth
self._config["graph"] = graph
from ....graph.helpers import keywordmap
self.map = keywordmap
@@ -51,14 +55,9 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
self._classifier_weight = torch.nn.Parameter(torch.tensor([0.01]))
self.build()
-    def fit(self, train, valid, *args, **kwargs):
-        # for x, y in zip(train.x, train.y):
-        #     for l in y:
-        #         self.memory[l] = list(set(self.memory.get(l, []) + [x]))
-        #     self.update_memory()
-        return super().fit(train, valid, *args, **kwargs)
+    def set_scoring(self, measures):
+        self._config["scoring"] = measures
def update_memory(self):
"""
@@ -71,31 +70,38 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
"""
graph = gget(self._config["graph"])
self.memory = {
k: [k] +self.map[k]+[x for x in sum([list(graph.neighbors(x)) for x in self.map[k]] , [])] # if graph.nodes(True)[x]["pos"] in self._config["pos"]
for k in self.classes.keys()
}
subgraph = {
-            k: graph.subgraph(self.map[k] + [x for x in sum([list(graph.neighbors(x)) for x in self.map[k]], [])])
+            k: graph.subgraph(self.map[k] + [x for x in sum([list(graph.neighbors(x)) if x in graph else [x] for x in self.map[k]], [])])
for k in self.classes.keys()
}
-        self.g = nx.OrderedDiGraph()
+        self.g = nx.OrderedGraph()
for k, v in subgraph.items():
-            self.g = nx.compose(self.g,v)
+            self.g = nx.compose(self.g, v)
self.g.add_node(k)
-            self.g.add_edges_from([(n,k) for n in v.nodes])
-            self.g.add_edge(k,k)
+            self.g.add_edges_from([(k,n) for n in v.nodes])
+            self.g.add_edges_from([(n, k) for n in v.nodes])
+            self.g.add_edge(k, k)
if self._config["depth"] > 1:
for _ in range(self._config["depth"]-1):
self.g = nx.compose(self.g, graph.subgraph(sum([list(graph.neighbors(n)) for n in self.g if n in graph],[])))
self._node_list = sorted(list(self.g.nodes))
self.nodes = self.transform(self._node_list)
self._class_nodes = {k:self._node_list.index(k) for k in self.classes.keys()}
adj = nx.adj_matrix(self.g, self._node_list)
-        self.adjencies = torch.nn.Parameter(torch.cat([torch.tensor(adj[i].toarray()) for i in self._class_nodes.values()],0).float()).to(self.device)
-        self.adjencies = self.adjencies.detach()
+        self.adjencies = torch.nn.Parameter(torch.cat([torch.FloatTensor(adj[i].toarray()) for i in self._class_nodes.values()],0).float()).to(self.device)
+        self.adjencies = self.adjencies.detach()
+        adj = torch.FloatTensor(adj.toarray()).to(self.device)
+        adj = adj / adj.sum(-1)
+        for _ in range(self._config["depth"]):
+            adj = torch.mm(adj, adj.t())
+        adj = torch.stack([adj[i] for i in self._class_nodes.values()],0).float()
+        self.adjencies_all = torch.nn.Parameter(adj.detach()).to(self.device)
def create_labels(self, classes: dict):
super().create_labels(classes)
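The added `adjencies_all` buffer is a depth-smoothed variant of the class adjacency: the dense adjacency is normalised, repeatedly multiplied with its transpose for `depth` rounds, and only then reduced to the class rows, so weight spreads to multi-hop neighbours. A toy version of the smoothing step (row-normalising with `keepdim=True` here for clean broadcasting; treat this as a sketch, not the exact code path):

```python
import torch

adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])   # path graph 0-1-2
adj = adj / adj.sum(-1, keepdim=True)
for _ in range(2):                   # depth = 2
    adj = adj @ adj.t()
print(adj)  # nodes 0 and 2 (two hops apart) now share mass
```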
@@ -143,64 +149,82 @@ class KMemoryGraph(SentenceTextClassificationAbstract, TextClassificationAbstrac
r = r.diag()
return r
-    def forward(self, x):
-        input_embedding = self.vdropout(self.embedding(**{k:x[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0])
+    def forward(self, x, kw=False):
+        input_embedding = self.dropout(self.embedding(**{k:x[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0])
label_embedding = self.dropout(self.embedding(**{k:self.label_dict[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0])
-        nodes_embedding = self.embedding(**{k:self.nodes[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0]
-        # task_embedding = self.embedding(**{k:self.transform([self._config["task"]])[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[1]
-        # input_embedding = input_embedding - task_embedding[None]
-        # nodes_embedding = nodes_embedding - task_embedding[None]
-        # label_embedding = label_embedding - task_embedding[None]
+        nodes_embedding = self.dropout(self.embedding(**{k:self.nodes[k] for k in ['input_ids', 'token_type_ids', 'attention_mask']})[0])
if self.training:
-            input_embedding = input_embedding + 0.01*torch.rand_like(input_embedding)[:,0,None,0,None].round()*torch.rand_like(input_embedding) #
-            input_embedding = input_embedding * ((torch.rand_like(input_embedding[:,:,0])>0.05).float()*2 -1)[...,None]
-        # with torch.no_grad():
-        #     x_mask = x["attention_mask"].detach()
-        #     x_norm = input_embedding.detach() / input_embedding.norm(dim=-1, keepdim=True).detach()
-        #     y = memory_embedding.values()
-        #     # y = [w-w.mean(0)[None] for w in y]
-        #     # words = [0.5 * (1 + torch.einsum("ijn,ln->ilj", x_norm, te / te.norm(dim=-1, keepdim=True))) for te in y]
-        #     words = [torch.einsum("ijn,lkn->iljk", x_norm, te.detach() / te.detach().norm(dim=-1, keepdim=True)).relu() for te in y]
-        #     words = [(w[:,:,1:].max(1)[0].max(-1)[0]) / ((w[:,:,1:]*e["attention_mask"][None,:,None]).sum([1,-1])/e["attention_mask"].sum()) for e,w in zip(self.memory_dicts.values(),words)]
-        #     w = torch.stack(words, -1)
-        #     w[x_mask[:,1:]==0] = 0
-        #     tfidf = torch.cat([w,torch.zeros_like(w[:,0,:])[:,None]],1)
-        #
-        #
-        #     ke = torch.einsum("bwt,bwe->bte", tfidf.softmax(-2)*x_mask[...,None], input_embedding)
+            input_embedding = input_embedding + 0.01 * torch.rand_like(input_embedding)[:, 0, None, 0,
+                                                                                        None].round() * torch.rand_like(input_embedding) #
+            input_embedding = input_embedding * ((torch.rand_like(input_embedding[:, :, 0]) > 0.05).float() * 2 - 1)[..., None]
-        nodes_embedding = self._mean_pooling(nodes_embedding, self.nodes["attention_mask"])
-        words, ke, tfidf = self.agg(input_embedding, [nodes_embedding[torch.where(line==1)] for line in self.adjencies], x_mask=x["attention_mask"])
+        tmp = self._mean_pooling(nodes_embedding, self.nodes["attention_mask"])
+        text_kw = torch.einsum("bse,le->bls",
+                               input_embedding / input_embedding.norm(p=2, dim=-1, keepdim=True),
+                               tmp / tmp.norm(p=2, dim=-1, keepdim=True)).max(1)[0]
+        ke = torch.einsum("bse, bs -> be", input_embedding, text_kw.softmax(-1))
+        nodes_embedding = self._mean_pooling(nodes_embedding, self.nodes["attention_mask"])
input_embedding = self._mean_pooling(input_embedding, x["attention_mask"])
label_embedding = self._mean_pooling(label_embedding, self.label_dict["attention_mask"])
l = []
if "keyword_similarity_max" in self._config["scoring"]:
-            keyword_similarity_max = torch.stack(
-                [self._sim(input_embedding, nodes_embedding[torch.where(line==1)]).max(-1)[0] for line in self.adjencies],
-                -1) # keyword-similarity-max
+            s = self._sim(input_embedding, nodes_embedding)
+            keyword_similarity_max = torch.stack([s[:, torch.where(x == 1)[0]].max(-1)[0] for x in self.adjencies], 1)
l.append(keyword_similarity_max)
if "pooled_similarity" in self._config["scoring"]:
pooled_similarity = self._sim(input_embedding, label_embedding).squeeze() # pooled-similarity
l.append(pooled_similarity)
if "keyword_similiarity_mean" in self._config["scoring"]:
-            keyword_similiarity_mean = torch.mm(self._sim(input_embedding, nodes_embedding).squeeze(), (
-                    self.adjencies / self.adjencies.norm(1, dim=-1, keepdim=True)).t()) # keyword-similarity-mean
+            keyword_similiarity = self._sim(input_embedding, nodes_embedding).squeeze() # keyword-similarity-mean
+            keyword_similiarity_mean = torch.mm(keyword_similiarity, (
+                    self.adjencies / self.adjencies.norm(1, dim=-1, keepdim=True)).t())
l.append(keyword_similiarity_mean)
if "fallback_classifier" in self._config["scoring"]:
fallback_classifier = self._classifier_weight * self._entailment(input_embedding,
label_embedding) # classifier
l.append(fallback_classifier)
if "weighted_similarity" in self._config["scoring"]:
-            weighted_similarity = self._sim(ke, label_embedding).squeeze() # weighted-similarity
+            weighted_similarity = self._sim(ke, label_embedding).squeeze() # weighted_similarity
l.append(weighted_similarity)
if "keyword_similarity_weighted" in self._config["scoring"]:
keyword_similiarity_weighted = self._sim(input_embedding, nodes_embedding).squeeze() # keyword-similarity-mean
keyword_similiarity_weighted = torch.mm(keyword_similiarity_weighted, (
self.adjencies_all / self.adjencies_all.norm(1, dim=-1, keepdim=True)).t())
l.append(keyword_similiarity_weighted)
scores = torch.stack(l,-1).mean(-1)
-        return scores
\ No newline at end of file
+        if kw:
+            return scores, keyword_similiarity, text_kw
+        return scores
+    def keywords(self, x, n=10):
+        self.eval()
+        with torch.no_grad():
+            tok = self.transform(x)
+            scores, graphkw, text_kw = self.forward(tok, kw=True)
+        gkw = [[list(self._node_list)[i] for i in x] for x in graphkw.softmax(-1).topk(n, dim=-1)[1]]
+        labels = [[self.classes_rev[x.item()]] for x in scores.argmax(-1).cpu()]
+        keywords = [list(zip(t, l[1:(1 + len(t))].tolist())) for t, l in zip(tok["text"], text_kw)]
+        keywords_new = []
+        for l in keywords:
+            new_list = []
+            new_tuple = [[], 0]
+            for i in range(1, 1 + len(l)):
+                new_tuple[0].append(l[-i][0])
+                new_tuple[1] += l[-i][1]
+                if not l[-i][0].startswith("##"):
+                    new_tuple[0] = "".join(new_tuple[0][::-1]).replace("##", "")
+                    new_list.append(tuple(new_tuple))
+                    new_tuple = [[], 0]
+            keywords_new.append(new_list)
+        tkw = [sorted(x, key=lambda x: -x[1])[:n] for x in keywords_new]
+        return labels, gkw, tkw
\ No newline at end of file
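The new `keywords` method runs the forward pass with `kw=True`, picks the top-n graph nodes per example, and folds WordPiece subtokens back into words while accumulating each piece's keyword score: the loop walks the token list right-to-left and closes a word whenever a piece does not start with `##`. A self-contained sketch of that merging step:

```python
tokens = [("em", 0.1), ("##bed", 0.2), ("##ding", 0.3), ("layer", 0.4)]

merged, current = [], [[], 0.0]
for tok, score in reversed(tokens):
    current[0].append(tok)
    current[1] += score
    if not tok.startswith("##"):  # first piece of a word reached
        merged.append(("".join(current[0][::-1]).replace("##", ""), current[1]))
        current = [[], 0.0]
print(merged)  # [('layer', 0.4), ('embedding', 0.6...)]
```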