Commit 183eb319 authored by Janos Borst's avatar Janos Borst
Browse files

lr scheduling added

parent 05402770
Pipeline #45902 failed with stage
in 11 minutes and 41 seconds
......@@ -272,7 +272,7 @@ class TextClassificationAbstract(torch.nn.Module):
def fit(self, train,
valid=None, epochs=1, batch_size=16, valid_batch_size=50, patience=-1, tolerance=1e-2,
return_roc=False, return_report=False, callbacks=None, metrics=None):
return_roc=False, return_report=False, callbacks=None, metrics=None, lr_schedule=None, lr_param ={}):
"""
Training function
......@@ -310,6 +310,8 @@ class TextClassificationAbstract(torch.nn.Module):
best_loss = 10000000
last_best_loss_update = 0
if lr_schedule is not None:
scheduler = lr_schedule(self.optimizer, **lr_param)
for e in range(epochs):
self._callback_epoch_start(callbacks)
......@@ -319,6 +321,7 @@ class TextClassificationAbstract(torch.nn.Module):
with tqdm(train_loader,
postfix=[losses], desc="Epoch %i/%i" % (e + 1, epochs), ncols=100) as pbar:
loss = self._epoch(train_loader, pbar=pbar)
if lr_schedule is not None: scheduler.step()
train_history["loss"].append(loss)
# Validation if available
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment