Commit b1eb42aa authored by Janos Borst

minor issues

parent c75481e3
Pipeline #49954 failed in 9 minutes and 25 seconds
@@ -25,20 +25,8 @@ except:
     print("pytorch_geometric not installed.")
     pass
 
-def finetune_mixed_precision_model(model, finetune=True):
-    """
-    Sets a model to use FP16 where appropriate to save memory and speed up training.
-    :param model: A model instance
-    :return: A model with initialized Automatic Mixed Precision
-    """
-    try:
-        from apex import amp
-        model.use_amp = True
-        opt = model.optimizer.__class__(filter(lambda p: p.requires_grad, model.parameters()), **model.optimizer_params)
-        model, opt = amp.initialize(model, opt, opt_level="O2",
-                                    keep_batchnorm_fp32=True, loss_scale="dynamic")
-        model.optimizer = opt
-    except ModuleNotFoundError:
-        model.use_amp = False
-    return model
+
+def get(name: str):
+    import mlmc.models as mm
+    fct = getattr(mm, name.capitalize())
+    return fct
\ No newline at end of file
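The new get factory resolves a model class by name. Note that str.capitalize() upper-cases only the first letter and lower-cases the rest, so only classes named in exactly that form can be looked up. A minimal sketch of the behavior (the class name Transformer is hypothetical):

import mlmc.models

# "transformer".capitalize() == "Transformer", so this resolves mlmc.models.Transformer.
# A class named "KimCNN" could not be fetched this way: "kimcnn".capitalize() == "Kimcnn".
model_cls = mlmc.models.get("transformer")  # hypothetical class name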
-from .abstract_textclassification import TextClassificationAbstract
\ No newline at end of file
+from .abstract_textclassification import TextClassificationAbstract
+from .abstract_label import LabelEmbeddingAbstract
+from .abstract_encoder import EncoderAbstract
+from .abstracts_zeroshot import TextClassificationAbstractZeroShot
+from .abstracts_graph import TextClassificationAbstractGraph
+from .abstract_sentence import SentenceTextClassificationAbstract
\ No newline at end of file
@@ -42,7 +42,8 @@ class LabelEmbeddingAbstract(TextClassificationAbstract):
         """
         self._config["classes"] = classes
-        self._config["n_classes"] = self.n_classes
+        self.classes = classes
+        self._config["n_classes"] = len(self._config["classes"])
         if isinstance(self._config["classes"], dict):
             self.classes_rev = {v: k for k, v in self._config["classes"].items()}
@@ -95,10 +95,10 @@ class TextClassificationAbstract(torch.nn.Module):
         if threshold is not None:
             self.set_threshold(threshold)
 
-        assert not (self.loss is torch.nn.BCEWithLogitsLoss and target == "single"), \
+        assert not (self._config["loss"] is torch.nn.BCEWithLogitsLoss and self._config["target"] == "single"), \
             "You are using BCE with a single label target. " \
             "Not possible, please use torch.nn.CrossEntropyLoss with a single label target."
-        assert not (self.loss is torch.nn.CrossEntropyLoss and target == "multi"), \
+        assert not (self._config["loss"] is torch.nn.CrossEntropyLoss and self._config["target"] == "multi"), \
             "You are using CrossEntropy with a multi label target. " \
             "Not possible, please use torch.nn.BCEWithLogitsLoss with a multi label target."
@@ -120,10 +120,9 @@ class TextClassificationAbstract(torch.nn.Module):
         :param name: Name of the threshold (see mlmc.thresholds.threshold_dict.keys())
         """
-        self.threshold = name
         self._config["threshold"] = name
         if isinstance(name, str):
-            self._threshold_fct = thresholdget(name)
+            self._threshold_fct = thresholdget(self._config["threshold"])
         elif callable(name):
             self._threshold_fct = name
         else:
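Because set_threshold accepts any callable in the elif branch, a custom cutoff can be passed directly. A sketch (the 0.3 cutoff is arbitrary; valid string names come from mlmc.thresholds.threshold_dict.keys()):

def cutoff_03(scores):
    # binarize scores at a fixed cutoff
    return (scores > 0.3).int()

model.set_threshold(cutoff_03)  # callable branch: stored and used as-is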
@@ -135,18 +134,23 @@ class TextClassificationAbstract(torch.nn.Module):
     def set_loss(self, loss):
         self._config["loss"] = loss
-        self.loss = loss
+        if isinstance(self._config["loss"], type) and self._config["loss"] is not None:
+            self.loss = self._config["loss"]().to(self.device)
+        else:
+            self.loss = self._config["loss"].to(self.device)
 
     def build(self):
         """
         Internal build method.
         """
-        if isinstance(self.loss, type) and self.loss is not None:
-            self.loss = self.loss().to(self.device)
+        if isinstance(self._config["loss"], type) and self._config["loss"] is not None:
+            self.loss = self._config["loss"]().to(self.device)
+        else:
+            self.loss = self._config["loss"].to(self.device)
         if isinstance(self.optimizer, type) and self.optimizer is not None:
-            self.optimizer = self.optimizer(filter(lambda p: p.requires_grad, self.parameters()),
-                                            **self.optimizer_params)
+            self.optimizer = self._config["optimizer"](filter(lambda p: p.requires_grad, self.parameters()), **self.optimizer_params)
+        else:
+            self.optimizer = self._config["optimizer"]
         self.to(self.device)
 
     def _init_metrics(self, metrics=None):
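After this change, set_loss and build treat self._config["loss"] uniformly: a class is instantiated and moved to the model's device, while an already-constructed instance is moved as-is. Both call styles are therefore valid, e.g.:

model.set_loss(torch.nn.BCEWithLogitsLoss)    # pass the class: instantiated, then .to(device)
model.set_loss(torch.nn.BCEWithLogitsLoss())  # pass an instance: moved to device unchanged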
 from mlmc.models.abstracts.abstract_encoder import EncoderAbstract
 from ...abstracts.abstracts_zeroshot import TextClassificationAbstractZeroShot
 import torch
 
-class SimpleEncoder(EncoderAbstract):
+class SimpleEncoder(EncoderAbstract, TextClassificationAbstractZeroShot):
     """
     Training a model by casting text and label as an entailment task. Offers good zero-shot capabilities when pretrained
     on an NLI task (you can pretrain (almost) any transformer model with model.pretrain_snli() or model.pretrain_mnli()).
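With TextClassificationAbstractZeroShot added as a base, SimpleEncoder gains the zero-shot workflow the docstring describes. A hypothetical sketch; the constructor arguments shown are assumptions, not confirmed by this diff:

import mlmc

model = mlmc.models.SimpleEncoder(
    classes={"economy": 0, "sports": 1},  # assumed constructor signature
    target="single",
)
model.pretrain_mnli()                     # NLI pretraining, per the docstring
predictions = model.predict(["The match ended two to one."])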
@@ -19,4 +19,5 @@ torch>=1.5.0
 beautifulsoup4~=4.9.3
 requests~=2.25.1
 dgl~=0.5.2
-dill
\ No newline at end of file
+dill
+datasets
\ No newline at end of file