Commit dc245dd4 authored by Janos Borst's avatar Janos Borst
Browse files


parent 94c11257
Pipeline #52996 passed with stage
in 11 minutes and 3 seconds
......@@ -159,9 +159,9 @@ class KeywordCoherence():
kw_acc = sum([(x.item() in y) for x,y in zip(counts.sum(-1).argmax(-1), predictions)]) / len(predictions)
kw_sw = [sum([(x.item() in y) for x,y in zip(counts[:,:,(i-idx):i].sum(-1).argmax(-1), predictions)]) / len(predictions) for i in range(min(idx,ndx-1),ndx)]
import matplotlib.pyplot as plt
# import matplotlib.pyplot as plt
# plt.plot(kw_sw)
return {"kw_acc": kw_acc, "kwsw_aoc":sum(kw_sw) / len(kw_sw), "kw_sw": kw_sw, }
......@@ -379,7 +379,7 @@ class TextClassificationAbstract(torch.nn.Module):
if log_mlflow:
import mlflow
mlflow.log_metric(f"{valid_prefix}_loss" ,valid_loss, step=e)
result_metrics.log_mlflow(step=e, prefix=valid_prefix)
result_metrics.log_mlflow(step=e, prefix=valid_prefix, model=self)
valid_loss_dict = {f"{valid_prefix}_loss": valid_loss}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment