Commit 6c25917a authored by Janos Borst's avatar Janos Borst
Browse files

Auto stash before rebase of "origin/dev"

parent 787624e6
Pipeline #45363 failed with stage
in 15 minutes and 19 seconds
......@@ -15,6 +15,7 @@ import mlmc.models
import mlmc.graph
import mlmc.metrics
import mlmc.representation
import mlmc.modules
# Save and load models for inference
from .save_and_load import save, load
......
......@@ -3,5 +3,9 @@ class Callback:
def __init__(self):
    """Initialize the callback with a human-readable name."""
    # Used to identify this callback (e.g. in logs); subclasses may override.
    self.name = "Callback"
def _on_epoch_end(self, *args, **kwargs):
def _on_epoch_end(self, model):
pass
def _on_train_end(self, model):
    """Hook invoked once when training finishes. No-op by default."""
    pass
def _on_epoch_start(self, model):
    """Hook invoked at the start of each training epoch. No-op by default."""
    pass
......@@ -108,3 +108,13 @@ class ZAGCNNLM(TextClassificationAbstractGraph, TextClassificationAbstractZeroSh
self.label_dict = self.create_label_dict()
self.label_embeddings = torch.stack([self.label_dict[cls] for cls in classes.keys()])
self.label_embeddings = self.label_embeddings.to(self.device)
if not hasattr(self, "_trained_classes"):
self._trained_classes = []
#Auxiliary values
l = list(classes.items())
l.sort(key=lambda x: x[1])
self._config["zeroshot_ind"] = torch.LongTensor([1 if x[0] in self._trained_classes else 0 for x in l])
self._config["mixed_shot"] = not (self._config["zeroshot_ind"].sum() == 0 or self._config["zeroshot_ind"].sum() == self._config["zeroshot_ind"].shape[
0]).item() # maybe obsolete?
......@@ -257,6 +257,10 @@ class TextClassificationAbstract(torch.nn.Module):
for cb in callbacks:
if hasattr(cb, "on_epoch_end"):
cb.on_epoch_end(self)
def _callback_train_end(self, callbacks):
for cb in callbacks:
if hasattr(cb, "on_train_end"):
cb.on_epoch_end(self)
def _callback_epoch_start(self, callbacks):
# TODO: Documentation
for cb in callbacks:
......
......@@ -9,6 +9,8 @@ try:
from apex import amp
except:
pass
from ...data import is_multilabel
class TextClassificationAbstractZeroShot(torch.nn.Module):
"""
......@@ -62,20 +64,32 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
}
return printable
def _zeroshot_fit(self,*args, **kwargs):
# TODO: Documentation
return self.zeroshot_fit_sacred(_run=None, *args,**kwargs)
def zeroshot_fit_sacred(self, data, epochs=10, batch_size=16, _run=None, metrics=None, callbacks=None):
def zeroshot_fit_sacred(self, data, epochs=10, batch_size=16, _run=None, metrics=None, callbacks=None, log=False):
histories = {"train": [], "gzsl": [], "zsl": [], "nsl": []}
if "trained_classes" not in self._config:
self._config["trained_classes"] = []
self._config["trained_classes"].extend(list(data["train"].classes.keys()))
self._config["trained_classes"] = list(set(self._config["trained_classes"]))
for i in range(epochs):
self.create_labels(data["train"].classes)
if is_multilabel(data["train"]):
self.multi()
else:
self.single()
history = self.fit(data["train"],
batch_size=batch_size, epochs=1, metrics=metrics, callbacks=callbacks)
if _run is not None: _run.log_scalar("train_loss", history["train"]["loss"][0], i)
self.create_labels(data["valid_gzsl"].classes)
if is_multilabel(data["valid_gzsl"]):
self.multi()
else:
self.single()
gzsl_loss, GZSL = self.evaluate(data["valid_gzsl"], batch_size=batch_size, metrics=metrics,_fit=True)
if _run is not None: GZSL.log_sacred(_run, i, "gzsl")
GZSL_comp = GZSL.compute()
......@@ -83,6 +97,10 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
histories["gzsl"][-1].update(GZSL_comp)
self.create_labels(data["valid_zsl"].classes)
if is_multilabel(data["valid_zsl"]):
self.multi()
else:
self.single()
zsl_loss, ZSL = self.evaluate(data["valid_zsl"], batch_size=batch_size, metrics=metrics,_fit=True)
if _run is not None: ZSL.log_sacred(_run, i, "zsl")
ZSL_comp = ZSL.compute()
......@@ -90,6 +108,10 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
histories["zsl"][-1].update(ZSL_comp)
self.create_labels(data["valid_nsl"].classes)
if is_multilabel(data["valid_nsl"]):
self.multi()
else:
self.single()
nsl_loss, NSL = self.evaluate(data["valid_nsl"], batch_size=batch_size, metrics=metrics,_fit=True)
if _run is not None: NSL.log_sacred(_run, i, "nsl")
NSL_comp = NSL.compute()
......@@ -104,14 +126,26 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
print("========================================================================================\n")
self.create_labels(data["test_gzsl"].classes)
if is_multilabel(data["test_gzsl"]):
self.multi()
else:
self.single()
gzsl_loss, GZSL = self.evaluate(data["test_gzsl"], batch_size=batch_size,_fit=True)
if _run is not None: GZSL.log_sacred(_run, epochs, "gzsl")
self.create_labels(data["test_zsl"].classes)
if is_multilabel(data["test_zsl"]):
self.multi()
else:
self.single()
zsl_loss, ZSL = self.evaluate(data["test_zsl"], batch_size=batch_size,_fit=True)
if _run is not None: ZSL.log_sacred(_run, epochs, "zsl")
self.create_labels(data["test_nsl"].classes)
if is_multilabel(data["test_nsl"]):
self.multi()
else:
self.single()
nsl_loss, NSL = self.evaluate(data["test_nsl"], batch_size=batch_size,_fit=True)
if _run is not None: NSL.log_sacred(_run, epochs, "nsl")
......@@ -152,6 +186,21 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
l.sort(key=lambda x: x[1])
#Auxiliary values
self._zeroshot_ind = torch.LongTensor([1 if x[0] in self._trained_classes else 0 for x in l])
self._mixed_shot = not (self._zeroshot_ind.sum() == 0 or self._zeroshot_ind.sum() == self._zeroshot_ind.shape[
self._config["zeroshot_ind"] = torch.LongTensor([1 if x[0] in self._trained_classes else 0 for x in l])
self._config["mixed_shot"] = not (self._config["zeroshot_ind"].sum() == 0 or self._config["zeroshot_ind"].sum() == self._config["zeroshot_ind"].shape[
0]).item() # maybe obsolete?
def single(self):
    """Put the model into single-label classification mode.

    Records the target in the config, selects argmax ("max") thresholding,
    softmax activation and cross-entropy loss, then rebuilds the model.
    """
    mode = "single"
    self._config["target"] = mode
    self.target = mode
    self.set_threshold("max")
    self.activation = torch.softmax
    self.loss = torch.nn.CrossEntropyLoss()
    self.build()
def multi(self):
    """Put the model into multi-label classification mode.

    Records the target in the config, selects "mcut" thresholding,
    sigmoid activation and BCE-with-logits loss, then rebuilds the model.
    """
    mode = "multi"
    self._config["target"] = mode
    self.target = mode
    self.set_threshold("mcut")
    self.activation = torch.sigmoid
    self.loss = torch.nn.BCEWithLogitsLoss()
    self.build()
\ No newline at end of file
import mlmc
import torch
from mlmc_lab.mlmc_experimental.models import GR_ranking
from mlmc_lab.mlmc_experimental.loss.LabelwiseRankingLoss import LabelRankingLoss
import mlmc_lab
# Experiment script: zero-shot label ranking on RCV1 with a GR_ranking model.
# NOTE(review): several of the settings below (run, percentage, dataset, data,
# graph, epochs, batch_size) are defined but never passed on in the visible
# code — presumably consumed by an external config/sacred mechanism; verify.
run=None
percentage=0.0
dataset=""
data = None
# Random-graph construction parameters for the label graph.
graph = "random"
graph_n = 1000
graph_dim = 300
graph_density = 0.2
epochs = 15
batch_size = 50
# Hugging Face model id used as the text-representation backbone.
representation = "google/bert_uncased_L-2_H-768_A-12" # "distilroberta-base"# #"distilroberta-base"# "google/bert_uncased_L-2_H-768_A-12"#"google/bert_uncased_L-2_H-128_A-2"#"google/bert_uncased_L-4_H-256_A-4"
finetune = True
device = "cuda:1"
optimizer = torch.optim.Adam
optimizer_params = {"lr": 1e-5}
decision_noise = 0.015
# Download/load the RCV1 dataset (may fetch data on first run).
zsdata = mlmc.data.get("rcv1")
# Build the ranking model; target mode follows the dataset's label type.
gr = GR_ranking(classes=zsdata["train"].classes,
                graph_n = graph_n,
                graph_dim = graph_dim,
                graph_density = graph_density,
                loss=LabelRankingLoss(logits=True, add_categorical=2.0, threshold="mcut"),#loss,#torch.nn.BCEWithLogitsLoss if mlmc.data.is_multilabel(zsdata["train"]) else torch.nn.CrossEntropyLoss,
                target="multi" if mlmc.data.is_multilabel(zsdata["train"]) else "single",
                representation = representation,
                finetune=finetune,
                device=device,
                optimizer=optimizer,
                optimizer_params=optimizer_params,
                decision_noise=decision_noise)
# Validation set: a 10k-example subsample of the test split.
zsdata["valid"] = mlmc.data.sampler(zsdata["test"], absolute=10000)
# Wrap as a zero-shot dataset, holding out the predefined 10 RCV1 classes.
d = mlmc_lab.mlmc_experimental.data.ZeroshotDataset(zsdata, zeroshot_classes=mlmc_lab.constants.ZEROSHOT_10["rcv1"])
# NOTE(review): this rebinds the `data = None` placeholder above; only used
# for the weight plot below.
data = {"GZSL": d["valid_gzsl"],
        "ZSL": d["valid_zsl"],
        "NSL": d["valid_nsl"]}
# Train in zero-shot mode, then visualize the learned weights per split.
gr._zeroshot_fit(d)
gr.plot_weights(data)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment