Commit 97150af8 authored by Janos Borst's avatar Janos Borst
Browse files

Added provisoin for label length argument

parent 6c25917a
......@@ -22,7 +22,7 @@ class TextClassificationAbstract(torch.nn.Module):
"""
def __init__(self, classes, target="multi", representation="google/bert_uncased_L-2_H-128_A-2",
activation=None, loss=None, optimizer=torch.optim.Adam, max_len=200,
activation=None, loss=None, optimizer=torch.optim.Adam, max_len=200, label_len=20,
optimizer_params=None, device="cpu", finetune=False, threshold="mcut", n_layers=1, **kwargs):
"""
Abstract initializer of a Text Classification network.
......@@ -30,6 +30,7 @@ class TextClassificationAbstract(torch.nn.Module):
classes: A dictionary of classes and ther corresponding index. This argument is mandatory.
representation: The string of the input representation. (Supporting the full transformers list, and glove50, glove100, glove200, glove300)
max_len: The maximum number of tokens for the input.
label_len: The maximum number of tokens for labels.
target: single label oder multilabel mode. defined by keystrings: ("single", "multi").
Sets some basic options, like loss function, activation and
metrics to sensible defaults.
......@@ -91,7 +92,8 @@ class TextClassificationAbstract(torch.nn.Module):
"activation":self.activation, "loss": self.loss,
"optimizer": self.optimizer, "max_len": self.max_len,
"optimizer_params": self.optimizer_params, "device": self.device,
"finetune": finetune, "threshold": threshold, "n_layers": self.n_layers
"finetune": finetune, "threshold": threshold, "n_layers": self.n_layers,
"label_len":label_len,
}
self._config.update(kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment