Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
asv-ml
mlmc
Commits
6c25917a
Commit
6c25917a
authored
Jan 27, 2021
by
Janos Borst
Browse files
Auto stash before rebase of "origin/dev"
parent
787624e6
Pipeline
#45363
failed with stage
in 15 minutes and 19 seconds
Changes
6
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
mlmc/__init__.py
View file @
6c25917a
...
...
@@ -15,6 +15,7 @@ import mlmc.models
import
mlmc.graph
import
mlmc.metrics
import
mlmc.representation
import
mlmc.modules
# Save and load models for inference
from
.save_and_load
import
save
,
load
...
...
mlmc/callbacks/__init__.py
View file @
6c25917a
...
...
@@ -3,5 +3,9 @@ class Callback:
def
__init__
(
self
):
self
.
name
=
"Callback"
def
_on_epoch_end
(
self
,
*
args
,
**
kwargs
):
def
_on_epoch_end
(
self
,
model
):
pass
def
_on_train_end
(
self
,
model
):
pass
def
_on_epoch_start
(
self
,
model
):
pass
mlmc/models/ZAGCNNLM.py
View file @
6c25917a
...
...
@@ -108,3 +108,13 @@ class ZAGCNNLM(TextClassificationAbstractGraph, TextClassificationAbstractZeroSh
self
.
label_dict
=
self
.
create_label_dict
()
self
.
label_embeddings
=
torch
.
stack
([
self
.
label_dict
[
cls
]
for
cls
in
classes
.
keys
()])
self
.
label_embeddings
=
self
.
label_embeddings
.
to
(
self
.
device
)
if
not
hasattr
(
self
,
"_trained_classes"
):
self
.
_trained_classes
=
[]
#Auxiliary values
l
=
list
(
classes
.
items
())
l
.
sort
(
key
=
lambda
x
:
x
[
1
])
self
.
_config
[
"zeroshot_ind"
]
=
torch
.
LongTensor
([
1
if
x
[
0
]
in
self
.
_trained_classes
else
0
for
x
in
l
])
self
.
_config
[
"mixed_shot"
]
=
not
(
self
.
_config
[
"zeroshot_ind"
].
sum
()
==
0
or
self
.
_config
[
"zeroshot_ind"
].
sum
()
==
self
.
_config
[
"zeroshot_ind"
].
shape
[
0
]).
item
()
# maybe obsolete?
mlmc/models/abstracts/abstracts.py
View file @
6c25917a
...
...
@@ -257,6 +257,10 @@ class TextClassificationAbstract(torch.nn.Module):
for
cb
in
callbacks
:
if
hasattr
(
cb
,
"on_epoch_end"
):
cb
.
on_epoch_end
(
self
)
def
_callback_train_end
(
self
,
callbacks
):
for
cb
in
callbacks
:
if
hasattr
(
cb
,
"on_train_end"
):
cb
.
on_epoch_end
(
self
)
def
_callback_epoch_start
(
self
,
callbacks
):
# TODO: Documentation
for
cb
in
callbacks
:
...
...
mlmc/models/abstracts/abstracts_zeroshot.py
View file @
6c25917a
...
...
@@ -9,6 +9,8 @@ try:
from
apex
import
amp
except
:
pass
from
...data
import
is_multilabel
class
TextClassificationAbstractZeroShot
(
torch
.
nn
.
Module
):
"""
...
...
@@ -62,20 +64,32 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
}
return
printable
def
_zeroshot_fit
(
self
,
*
args
,
**
kwargs
):
# TODO: Documentation
return
self
.
zeroshot_fit_sacred
(
_run
=
None
,
*
args
,
**
kwargs
)
def
zeroshot_fit_sacred
(
self
,
data
,
epochs
=
10
,
batch_size
=
16
,
_run
=
None
,
metrics
=
None
,
callbacks
=
None
):
def
zeroshot_fit_sacred
(
self
,
data
,
epochs
=
10
,
batch_size
=
16
,
_run
=
None
,
metrics
=
None
,
callbacks
=
None
,
log
=
False
):
histories
=
{
"train"
:
[],
"gzsl"
:
[],
"zsl"
:
[],
"nsl"
:
[]}
if
"trained_classes"
not
in
self
.
_config
:
self
.
_config
[
"trained_classes"
]
=
[]
self
.
_config
[
"trained_classes"
].
extend
(
list
(
data
[
"train"
].
classes
.
keys
()))
self
.
_config
[
"trained_classes"
]
=
list
(
set
(
self
.
_config
[
"trained_classes"
]))
for
i
in
range
(
epochs
):
self
.
create_labels
(
data
[
"train"
].
classes
)
if
is_multilabel
(
data
[
"train"
]):
self
.
multi
()
else
:
self
.
single
()
history
=
self
.
fit
(
data
[
"train"
],
batch_size
=
batch_size
,
epochs
=
1
,
metrics
=
metrics
,
callbacks
=
callbacks
)
if
_run
is
not
None
:
_run
.
log_scalar
(
"train_loss"
,
history
[
"train"
][
"loss"
][
0
],
i
)
self
.
create_labels
(
data
[
"valid_gzsl"
].
classes
)
if
is_multilabel
(
data
[
"valid_gzsl"
]):
self
.
multi
()
else
:
self
.
single
()
gzsl_loss
,
GZSL
=
self
.
evaluate
(
data
[
"valid_gzsl"
],
batch_size
=
batch_size
,
metrics
=
metrics
,
_fit
=
True
)
if
_run
is
not
None
:
GZSL
.
log_sacred
(
_run
,
i
,
"gzsl"
)
GZSL_comp
=
GZSL
.
compute
()
...
...
@@ -83,6 +97,10 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
histories
[
"gzsl"
][
-
1
].
update
(
GZSL_comp
)
self
.
create_labels
(
data
[
"valid_zsl"
].
classes
)
if
is_multilabel
(
data
[
"valid_zsl"
]):
self
.
multi
()
else
:
self
.
single
()
zsl_loss
,
ZSL
=
self
.
evaluate
(
data
[
"valid_zsl"
],
batch_size
=
batch_size
,
metrics
=
metrics
,
_fit
=
True
)
if
_run
is
not
None
:
ZSL
.
log_sacred
(
_run
,
i
,
"zsl"
)
ZSL_comp
=
ZSL
.
compute
()
...
...
@@ -90,6 +108,10 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
histories
[
"zsl"
][
-
1
].
update
(
ZSL_comp
)
self
.
create_labels
(
data
[
"valid_nsl"
].
classes
)
if
is_multilabel
(
data
[
"valid_nsl"
]):
self
.
multi
()
else
:
self
.
single
()
nsl_loss
,
NSL
=
self
.
evaluate
(
data
[
"valid_nsl"
],
batch_size
=
batch_size
,
metrics
=
metrics
,
_fit
=
True
)
if
_run
is
not
None
:
NSL
.
log_sacred
(
_run
,
i
,
"nsl"
)
NSL_comp
=
NSL
.
compute
()
...
...
@@ -104,14 +126,26 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
print
(
"========================================================================================
\n
"
)
self
.
create_labels
(
data
[
"test_gzsl"
].
classes
)
if
is_multilabel
(
data
[
"test_gzsl"
]):
self
.
multi
()
else
:
self
.
single
()
gzsl_loss
,
GZSL
=
self
.
evaluate
(
data
[
"test_gzsl"
],
batch_size
=
batch_size
,
_fit
=
True
)
if
_run
is
not
None
:
GZSL
.
log_sacred
(
_run
,
epochs
,
"gzsl"
)
self
.
create_labels
(
data
[
"test_zsl"
].
classes
)
if
is_multilabel
(
data
[
"test_zsl"
]):
self
.
multi
()
else
:
self
.
single
()
zsl_loss
,
ZSL
=
self
.
evaluate
(
data
[
"test_zsl"
],
batch_size
=
batch_size
,
_fit
=
True
)
if
_run
is
not
None
:
ZSL
.
log_sacred
(
_run
,
epochs
,
"zsl"
)
self
.
create_labels
(
data
[
"test_nsl"
].
classes
)
if
is_multilabel
(
data
[
"test_nsl"
]):
self
.
multi
()
else
:
self
.
single
()
nsl_loss
,
NSL
=
self
.
evaluate
(
data
[
"test_nsl"
],
batch_size
=
batch_size
,
_fit
=
True
)
if
_run
is
not
None
:
NSL
.
log_sacred
(
_run
,
epochs
,
"nsl"
)
...
...
@@ -152,6 +186,21 @@ class TextClassificationAbstractZeroShot(torch.nn.Module):
l
.
sort
(
key
=
lambda
x
:
x
[
1
])
#Auxiliary values
self
.
_zeroshot_ind
=
torch
.
LongTensor
([
1
if
x
[
0
]
in
self
.
_trained_classes
else
0
for
x
in
l
])
self
.
_mixed_shot
=
not
(
self
.
_zeroshot_ind
.
sum
()
==
0
or
self
.
_zeroshot_ind
.
sum
()
==
self
.
_zeroshot_ind
.
shape
[
self
.
_
config
[
"
zeroshot_ind
"
]
=
torch
.
LongTensor
([
1
if
x
[
0
]
in
self
.
_trained_classes
else
0
for
x
in
l
])
self
.
_
config
[
"
mixed_shot
"
]
=
not
(
self
.
_
config
[
"
zeroshot_ind
"
]
.
sum
()
==
0
or
self
.
_
config
[
"
zeroshot_ind
"
]
.
sum
()
==
self
.
_
config
[
"
zeroshot_ind
"
]
.
shape
[
0
]).
item
()
# maybe obsolete?
def single(self):
    """Switch the model to single-label mode and rebuild it.

    Records ``"single"`` in ``self._config["target"]`` and ``self.target``,
    selects the ``"max"`` threshold, uses ``torch.softmax`` as activation
    and a fresh ``torch.nn.CrossEntropyLoss`` as loss, then calls
    ``self.build()``.
    """
    # Chained assignment binds left to right: config entry first, then attribute.
    self._config["target"] = self.target = "single"
    self.set_threshold("max")
    self.activation, self.loss = torch.softmax, torch.nn.CrossEntropyLoss()
    # Rebuild last, once every setting above is in place.
    self.build()
def multi(self):
    """Switch the model to multi-label mode and rebuild it.

    Records ``"multi"`` in ``self._config["target"]`` and ``self.target``,
    selects the ``"mcut"`` threshold, uses ``torch.sigmoid`` as activation
    and a fresh ``torch.nn.BCEWithLogitsLoss`` as loss, then calls
    ``self.build()``.
    """
    # Chained assignment binds left to right: config entry first, then attribute.
    self._config["target"] = self.target = "multi"
    self.set_threshold("mcut")
    self.activation, self.loss = torch.sigmoid, torch.nn.BCEWithLogitsLoss()
    # Rebuild last, once every setting above is in place.
    self.build()
\ No newline at end of file
mlmc_lab_debug.py
0 → 100644
View file @
6c25917a
import
mlmc
import
torch
from
mlmc_lab.mlmc_experimental.models
import
GR_ranking
from
mlmc_lab.mlmc_experimental.loss.LabelwiseRankingLoss
import
LabelRankingLoss
import
mlmc_lab
# --- Experiment settings -----------------------------------------------------
# NOTE(review): run, percentage, dataset, graph, epochs and batch_size are
# assigned here but never read again in this script.
run = None
percentage = 0.0
dataset = ""
data = None  # rebound to the validation splits further down
graph = "random"
graph_n = 1000
graph_dim = 300
graph_density = 0.2
epochs = 15
batch_size = 50
# Other checkpoints that were tried here: "distilroberta-base",
# "google/bert_uncased_L-2_H-128_A-2", "google/bert_uncased_L-4_H-256_A-4".
representation = "google/bert_uncased_L-2_H-768_A-12"
finetune = True
device = "cuda:1"  # hard-coded to CUDA device index 1
optimizer = torch.optim.Adam
optimizer_params = {"lr": 1e-5}
decision_noise = 0.015

# --- Data and model ----------------------------------------------------------
zsdata = mlmc.data.get("rcv1")

gr = GR_ranking(
    classes=zsdata["train"].classes,
    graph_n=graph_n,
    graph_dim=graph_dim,
    graph_density=graph_density,
    # Previously considered instead of the ranking loss: BCEWithLogitsLoss for
    # multi-label data, CrossEntropyLoss otherwise.
    loss=LabelRankingLoss(logits=True, add_categorical=2.0, threshold="mcut"),
    target="multi" if mlmc.data.is_multilabel(zsdata["train"]) else "single",
    representation=representation,
    finetune=finetune,
    device=device,
    optimizer=optimizer,
    optimizer_params=optimizer_params,
    decision_noise=decision_noise,
)

# The validation split is sampled from the test split.
zsdata["valid"] = mlmc.data.sampler(zsdata["test"], absolute=10000)
d = mlmc_lab.mlmc_experimental.data.ZeroshotDataset(
    zsdata,
    zeroshot_classes=mlmc_lab.constants.ZEROSHOT_10["rcv1"],
)
data = {
    "GZSL": d["valid_gzsl"],
    "ZSL": d["valid_zsl"],
    "NSL": d["valid_nsl"],
}

# --- Train, then plot --------------------------------------------------------
gr._zeroshot_fit(d)
gr.plot_weights(data)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment