Commit dc245dd4 authored by Janos Borst's avatar Janos Borst
Browse files

unsupervised

parent 94c11257
Pipeline #52996 passed with stage
in 11 minutes and 3 seconds
...@@ -159,9 +159,9 @@ class KeywordCoherence(): ...@@ -159,9 +159,9 @@ class KeywordCoherence():
kw_acc = sum([(x.item() in y) for x,y in zip(counts.sum(-1).argmax(-1), predictions)]) / len(predictions) kw_acc = sum([(x.item() in y) for x,y in zip(counts.sum(-1).argmax(-1), predictions)]) / len(predictions)
kw_sw = [sum([(x.item() in y) for x,y in zip(counts[:,:,(i-idx):i].sum(-1).argmax(-1), predictions)]) / len(predictions) for i in range(min(idx,ndx-1),ndx)] kw_sw = [sum([(x.item() in y) for x,y in zip(counts[:,:,(i-idx):i].sum(-1).argmax(-1), predictions)]) / len(predictions) for i in range(min(idx,ndx-1),ndx)]
import matplotlib.pyplot as plt # import matplotlib.pyplot as plt
plt.plot(kw_sw) # plt.plot(kw_sw)
plt.show() # plt.show()
return {"kw_acc": kw_acc, "kwsw_aoc":sum(kw_sw) / len(kw_sw), "kw_sw": kw_sw, } return {"kw_acc": kw_acc, "kwsw_aoc":sum(kw_sw) / len(kw_sw), "kw_sw": kw_sw, }
......
...@@ -379,7 +379,7 @@ class TextClassificationAbstract(torch.nn.Module): ...@@ -379,7 +379,7 @@ class TextClassificationAbstract(torch.nn.Module):
if log_mlflow: if log_mlflow:
import mlflow import mlflow
mlflow.log_metric(f"{valid_prefix}_loss" ,valid_loss, step=e) mlflow.log_metric(f"{valid_prefix}_loss" ,valid_loss, step=e)
result_metrics.log_mlflow(step=e, prefix=valid_prefix) result_metrics.log_mlflow(step=e, prefix=valid_prefix, model=self)
valid_loss_dict = {f"{valid_prefix}_loss": valid_loss} valid_loss_dict = {f"{valid_prefix}_loss": valid_loss}
valid_loss_dict.update(result_metrics.compute(model=self)) valid_loss_dict.update(result_metrics.compute(model=self))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment