Commit 2b29c0d4 authored by Janos Borst's avatar Janos Borst
Browse files

Some DOcstrings

parent 272cf3d8
......@@ -26,6 +26,7 @@ class TextClassificationAbstract(torch.nn.Module):
self.PRECISION_DIGITS = 4
def build(self):
"""internal build method"""
if isinstance(self.loss, type) and self.loss is not None:
self.loss = self.loss().to(self.device)
if isinstance(self.optimizer, type) and self.optimizer is not None:
......@@ -33,6 +34,7 @@ class TextClassificationAbstract(torch.nn.Module):
self.to(self.device)
def evaluate_classes(self, classes_subset=None, **kwargs):
"""wrapper for evaluation function if you just want to evaluate on subsets of the classes"""
if classes_subset is None:
return self.evaluate(**kwargs)
else:
......@@ -41,7 +43,7 @@ class TextClassificationAbstract(torch.nn.Module):
def evaluate(self, data, batch_size=50, return_roc=False, return_report=False, mask=None):
"""
Evaluation, return accuracy and loss
Evaluation, return accuracy and loss and some multilabel measure
"""
self.eval() # set mode to evaluation to disable dropout
p_1 = Precision(is_multilabel=True,average=True)
......@@ -133,6 +135,7 @@ class TextClassificationAbstract(torch.nn.Module):
def predict(self, x, return_scores=False, tr=0.65, method="hard"):
"""Classifiy sentence string or a list of strings."""
self.eval()
if not hasattr(self, "classes_rev"):
self.classes_rev = {v: k for k, v in self.classes.items()}
......@@ -144,7 +147,8 @@ class TextClassificationAbstract(torch.nn.Module):
return [[(self.classes_rev[i.item()], s[i].item()) for i in torch.where(p==1)[0]] for s, p in zip(output,prediction)]
return [[self.classes_rev[i.item()] for i in torch.where(p==1)[0]] for p in prediction]
def predict_dataset(self, data, batch_size=50, tr=0.65, method="hard"):
def predict_dataset(self, data, batch_size=50, tr=0.5, method="hard"):
"""Predict all labels for a dataset int the mlmc.data.MultilabelDataset format."""
train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
predictions = []
for b in tqdm(train_loader):
......@@ -152,6 +156,11 @@ class TextClassificationAbstract(torch.nn.Module):
return predictions
def threshold(self, x, tr=0.5, method="hard"):
"""Thresholding function for outputs of the neural network.
So far a hard threshold ( tr=0.5, method="hard") is supported and
dynamic cutting (method="mcut")
"""
if method=="hard":
return (x>tr).int()
if method=="mcut":
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment