Commit 8bc5989b authored by Janos Borst's avatar Janos Borst
Browse files

Merge branch 'dev'

parents f0c00eda 170a82e9
Pipeline #38146 passed with stage
in 51 seconds
......@@ -6,7 +6,7 @@ from torch.utils.data import Dataset
import torch
from .data_loaders import load_eurlex, load_wiki30k, load_huffpost, load_aapd, load_rcv1, \
load_moviesummaries,load_blurbgenrecollection, load_blurbgenrecollection_de, load_20newsgroup,export,\
load_agnews
load_agnews, load_dbpedia, load_ohsumed
# String Mappings
register = {
......@@ -19,7 +19,9 @@ register = {
"blurbgenrecollection": load_blurbgenrecollection,
"blurbgenrecollection_de": load_blurbgenrecollection_de,
"20newsgroup": load_20newsgroup,
"agnews": load_agnews
"agnews": load_agnews,
"dbpedia": load_dbpedia,
"ohsumed": load_ohsumed
}
......@@ -303,4 +305,4 @@ def get_singlelabel_dataset(name, target_dtype=torch._cast_Float):
## Sampler import
from .sampler import sampler, successive_sampler, class_sampler, validation_split
from .data_loaders_text import RawTextDatasetTokenizer, RawTextDataset,RawTextDatasetTensor
\ No newline at end of file
from .data_loaders_text import RawTextDatasetTokenizer, RawTextDataset,RawTextDatasetTensor
......@@ -336,6 +336,7 @@ def load_blurbgenrecollection_de():
for i in soup.findAll("book"):
text.append(i.find("body").text)
labels.append([x.text for x in i.find("categories").findAll("topic")])
labels = [list(set(x)) for x in labels]
if purpose == "dev":
data["valid"] = (text, labels)
else:
......@@ -457,6 +458,135 @@ def load_agnews():
_save_to_tmp("agnews", (data, classes))
return data, classes
def load_dbpedia():
    """Download and cache the DBpedia ontology classification dataset.

    Returns:
        tuple: ``(data, classes)`` where ``data`` maps ``"train"``/``"test"``
        to ``(texts, labels)`` pairs (labels are single-element lists of class
        names) plus the raw ``*_title`` / ``*_description`` columns, and
        ``classes`` maps class name -> integer index.
    """
    import csv

    url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz"
    data = _load_from_tmp("dbpedia")
    if data is not None:
        return data
    else:
        with tempfile.TemporaryDirectory() as tmpdir:
            resp = urlopen(url)
            # NOTE(review): extractall() on a downloaded archive is vulnerable
            # to path traversal if the archive is malicious; acceptable here
            # only because the source URL is pinned.
            tf = tarfile.open(fileobj=resp, mode="r|gz")
            tf.extractall(Path(tmpdir))
            testdir = Path(tmpdir) / 'dbpedia_csv/test.csv'
            traindir = Path(tmpdir) / 'dbpedia_csv/train.csv'
            classesdir = Path(tmpdir) / 'dbpedia_csv/classes.txt'

            def _read_split(csv_path):
                # Each row is (label, title, description). Parse with the csv
                # module so commas inside the quoted title/description fields
                # do not split the row apart (a naive str.split(',') does),
                # and so the surrounding quotes are removed from the text.
                labels, titles, descriptions = [], [], []
                with open(csv_path, "r", encoding="iso-8859-1", newline="") as f:
                    for row in csv.reader(f):
                        labels.append(int(row[0]))
                        titles.append(row[1])
                        descriptions.append(row[2])
                return labels, titles, descriptions

            testlabel, testtitle, testdescription = _read_split(testdir)
            testtext = [" \n ".join([t, d]) for t, d in zip(testtitle, testdescription)]
            trainlabel, traintitle, traindescription = _read_split(traindir)
            traintext = [" \n ".join([t, d]) for t, d in zip(traintitle, traindescription)]

            with open(classesdir, "r") as f:
                classes = [x.replace("\n", "") for x in f.readlines()]
            classes = dict(zip(classes, range(len(classes))))
            rev_classes = {v: k for k, v in classes.items()}
            # CSV labels are 1-based indices into classes.txt.
            data = {
                "train": (traintext, [[rev_classes[x - 1]] for x in trainlabel]),
                "test": (testtext, [[rev_classes[x - 1]] for x in testlabel]),
                "test_title": testtitle,
                "test_description": testdescription,
                "train_title": traintitle,
                "train_description": traindescription
            }
        _save_to_tmp("dbpedia", (data, classes))
        return data, classes
def load_ohsumed():
    """Download and cache the Ohsumed (first 20,000 docs) multilabel dataset.

    Returns:
        tuple: ``(data, classes)`` where ``data`` maps ``"train"``/``"test"``
        to ``(texts, label_lists)`` pairs plus the raw ``*_title`` /
        ``*_description`` columns, and ``classes`` maps class name ->
        integer index.
    """
    url = "http://disi.unitn.eu/moschitti/corpora/ohsumed-first-20000-docs.tar.gz"
    url_classes = "http://disi.unitn.eu/moschitti/corpora/First-Level-Categories-of-Cardiovascular-Disease.txt"
    data = _load_from_tmp("ohsumed")
    if data is not None:
        return data

    def _category_number(dirname):
        # Directory names look like "C01".."C23"; int() handles the leading
        # zero correctly ("C10" -> 10). The previous str.replace("0", "")
        # approach collapsed e.g. "C10" and "C01" both to 1, merging classes.
        return int(dirname[1:])

    def _label_map(split_dir):
        # filename -> list of category numbers. A document can appear in
        # several category directories; collect all of them before reading.
        mapping = {}
        for catg in split_dir.iterdir():
            number = _category_number(catg.name)
            for file in catg.iterdir():
                labels = mapping.setdefault(file.name, [])
                if number not in labels:
                    labels.append(number)
        return mapping

    def _read_split(split_dir, mapping):
        # Read each document exactly once (first directory it is found in).
        # Returns (titles, descriptions, label-number lists).
        seen, raw, labels = set(), [], []
        for catg in split_dir.iterdir():
            for file in catg.iterdir():
                if file.name in seen:
                    continue
                seen.add(file.name)
                with open(file, 'r') as f:
                    # First line is the title, the rest the abstract body.
                    raw.append(f.read().split("\n", 1))
                labels.append(mapping.get(file.name))
        titles = [x[0] for x in raw]
        descriptions = [x[1].replace("\n", "").strip() for x in raw]
        return titles, descriptions, labels

    with tempfile.TemporaryDirectory() as tmpdir:
        resp = urlopen(url)
        # NOTE(review): extractall() on a downloaded archive is vulnerable to
        # path traversal for a malicious archive; source URL is pinned here.
        tf = tarfile.open(fileobj=resp, mode="r|gz")
        tf.extractall(Path(tmpdir))
        testdir = Path(tmpdir) / 'ohsumed-first-20000-docs/test'
        traindir = Path(tmpdir) / 'ohsumed-first-20000-docs/training'

        testtitle, testdescription, testlabel = _read_split(testdir, _label_map(testdir))
        testtext = [" \n ".join([t, d]) for t, d in zip(testtitle, testdescription)]
        traintitle, traindescription, trainlabel = _read_split(traindir, _label_map(traindir))
        traintext = [" \n ".join([t, d]) for t, d in zip(traintitle, traindescription)]

        classes_file = urlopen(url_classes).read().decode("utf-8")
        classes_file = classes_file.split("\n")
        # Strip the trailing " Cnn" code from each line; drop the empty last
        # entry produced by the file's final newline.
        classes_list = [x[:-3].strip() for x in classes_file]
        classes_list.pop()
        classes = dict(zip(classes_list, range(len(classes_list))))
        rev_classes = {v: k for k, v in classes.items()}
        # Category numbers are 1-based; map them back to class names.
        testlabellist = [[rev_classes[v - 1] for v in x] for x in testlabel]
        trainlabellist = [[rev_classes[v - 1] for v in x] for x in trainlabel]
        data = {
            "train": (traintext, trainlabellist),
            "test": (testtext, testlabellist),
            "test_title": testtitle,
            "test_description": testdescription,
            "train_title": traintitle,
            "train_description": traindescription
        }
    _save_to_tmp("ohsumed", (data, classes))
    return data, classes
def export(data, classes, path=Path("./export")):
path = Path(path)
if not path.exists():
......
......@@ -65,7 +65,7 @@ def subgraphs(classes, graph, depth=1, model="glove50", topk=10, allow_non_alig
import re
e = Embedder(model, device=device, return_device=device)
classes_tokens = [" ".join(re.split("[/ _-]", x.lower())) for x in classes.keys()]
classes_tokens = [" ".join(re.split("[/ _.-]", x.lower())) for x in classes.keys()]
class_embeddings = torch.stack([x.mean(-2) for x in e.embed(classes_tokens, None)],0)
......
......@@ -12,11 +12,12 @@ class PrecisionK(Precision):
super(PrecisionK, self).update((transformed, output[1]))
class AccuracyTreshold(Accuracy):
def __init__(self, trf, args_dict={}, *args, **kwargs):
class AccuracyTreshold():
def __init__(self, trf, args_dict={}):
self.trf = trf
self.args_dict = args_dict
super(AccuracyTreshold, self).__init__(*args, **kwargs)
self.l = []
def update(self, output):
super(AccuracyTreshold, self).update((self.trf(x=output[0], **self.args_dict), output[1]))
self.l.extend((self.trf(output[0], **self.args_dict) == output[1]).all(-1).tolist())
def compute(self):
return sum(self.l)/len(self.l)
\ No newline at end of file
......@@ -100,8 +100,8 @@ class TextClassificationAbstract(torch.nn.Module):
"p@1": PrecisionK(k=1, is_multilabel=True, average=True),
"p@3": PrecisionK(k=3, is_multilabel=True, average=True),
"p@5": PrecisionK(k=5, is_multilabel=True, average=True),
"tr@0.5": AccuracyTreshold(trf=threshold_hard, args_dict={"tr": 0.5}, is_multilabel=True),
"mcut": AccuracyTreshold(trf=threshold_mcut, is_multilabel=True),
"tr@0.5": AccuracyTreshold(trf=threshold_hard, args_dict={"tr": 0.5}),
"mcut": AccuracyTreshold(trf=threshold_mcut),
"auc_roc": AUC_ROC(len(self.classes), return_roc=return_roc),
}
if return_report:
......@@ -114,7 +114,7 @@ class TextClassificationAbstract(torch.nn.Module):
del multilabel_metrics["p@3"]
singlelabel_metrics = {
"accuracy": AccuracyTreshold(threshold_max, is_multilabel=False)
"accuracy": AccuracyTreshold(threshold_max)
}
metrics = multilabel_metrics
......@@ -273,6 +273,7 @@ class TextClassificationAbstract(torch.nn.Module):
"""
self.eval()
if not hasattr(self, "classes_rev"):
self.classes_rev = {v: k for k, v in self.classes.items()}
x = self.transform(x).to(self.device)
......
from mlmc.metrics import AccuracyTreshold
from mlmc.representation import threshold_max, threshold_hard, threshold_mcut
import torch
def test_AccuracyTresholdId():
    """With an identity transform the metric is plain exact-row accuracy,
    accumulated across successive updates."""
    metric = AccuracyTreshold(lambda x: x)
    predictions = torch.tensor([[0, 1, 0], [0, 1, 0]])
    half_correct_target = torch.tensor([[0, 1, 0], [0, 0, 1]])
    fully_correct_target = torch.tensor([[0, 1, 0], [0, 1, 0]])
    metric.update((predictions, half_correct_target))
    assert metric.compute() == 0.5
    metric.update((predictions, fully_correct_target))
    assert metric.compute() == 0.75
def test_AccuracyTreshold():
    """Exercise the metric with the max, hard (tr=0.5) and mcut threshold
    transforms, each over two successive updates."""
    target_a = torch.tensor([[0, 1, 0], [0, 0, 1]])
    target_b = torch.tensor([[0, 1, 0], [0, 1, 0]])

    metric = AccuracyTreshold(threshold_max)
    metric.update((torch.tensor([[0, 0.5, 0], [0.3, 0.7, 0.1]]), target_a))
    assert metric.compute() == 0.5
    metric.update((torch.tensor([[0, 0.51, 0.5], [-0.7, -0.5, -0.9]]), target_b))
    assert metric.compute() == 0.75

    metric = AccuracyTreshold(threshold_hard, args_dict={"tr": 0.5})
    metric.update((torch.tensor([[0, 0.51, 0], [0.3, 0.7, 0.1]]), target_a))
    assert metric.compute() == 0.5
    metric.update((torch.tensor([[0, 0.51, 0.5], [-0.7, -0.5, -0.9]]), target_b))
    assert metric.compute() == 0.5

    metric = AccuracyTreshold(threshold_mcut)
    metric.update((torch.tensor([[0, 0.7, 0.001], [0.3, 0.7, 0.1]]), target_a))
    assert metric.compute() == 0.5
    metric.update((torch.tensor([[0, 0.51, 0.5], [-0.7, -0.5, -0.9]]), target_b))
    assert metric.compute() == 0.5
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment