from string import ascii_uppercase
import mlmc
import torch
from sacred import Experiment, SETTINGS
from sacred.observers import MongoObserver
from sacred.utils import apply_backspaces_and_linefeeds
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, T5ForConditionalGeneration
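# MongoDB connection settings. These are intentionally blank placeholders;
# fill them in (or build the observer URL from them) before running.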
name = ""
user = ""
host = ""
database = ""
auth = ""
pw = ""
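# Sacred bookkeeping: capture stdout/stderr and log each run to the MongoDB observer.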
SETTINGS.CAPTURE_MODE = "sys"
ex = Experiment(name)
ex.observers.append(MongoObserver(url="localhost:27017", db_name=database))
ex.captured_out_filter = apply_backspaces_and_linefeeds

class ZeroshotClassification:
    """Zero-shot classification with a text-to-text (UnifiedQA-style T5) model."""

    def __init__(self, representation, classes, target, format_):
        self.model = T5ForConditionalGeneration.from_pretrained(representation)
        self.tokenizer = AutoTokenizer.from_pretrained(representation)
        self.config = {"representation": representation,
                       "classes": classes,
                       "target": target,
                       "format": format_}

    def init_metrics(self, metrics="default_singlelabel"):
        """
        Initializes the metrics to be used. If no metrics are specified, the default metrics
        for the configured target are used (see mlmc.metrics.metrics_config.items()).

        :param metrics: Name of the metrics (see mlmc.metrics.metrics_dict.keys() and mlmc.metrics.metrics_config.keys())
        :return: A dictionary containing the initialized metrics
        """
        metrics = mlmc.metrics.MetricsDict(metrics)
        metrics.init(self.config)
        metrics.reset()
        return metrics

    def run_model(self, input_string, **generator_args):
        # Encode the prompt, generate an answer, and decode it back to plain text.
        input_ids = self.tokenizer.encode(input_string, return_tensors="pt")
        res = self.model.generate(input_ids, **generator_args)
        return self.tokenizer.batch_decode(res, skip_special_tokens=True)
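
    # A minimal illustrative call, assuming the UnifiedQA input convention of joining
    # question, choices, and context with literal "\n" separators (values hypothetical):
    #   zc = ZeroshotClassification("allenai/unifiedqa-t5-small", classes={"world": 0, "sports": 1}, target="single", format_=True)
    #   zc.run_model("What is this question about? \\n (A) world (B) sports \\n The cup final ended 2-1.")
    # should return a list with one decoded string, ideally matching one of the choice texts.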

@ex.config
def ex_config():
    device = 0
    batch_size = 1
    representation = "tals/albert-base-mnli"
    dataset = "agnews"
    target = "single"
    threshold = "max"
    if target == "multi":
        threshold = "mcut"
    formatted = True
    cut_sample = False
    if target == "multi":
        cut_sample = True
    method = "huggingface"
    whole_dataset = True
    if dataset == "rcv1":
        whole_dataset = False
        dataset_size = 10000

@ex.automain
def run(_run, dataset, formatted):
    data = mlmc.data.get(dataset)
    if formatted:
        if dataset in ["trec6", "trec50", "dbpedia", "agnews", "yelpfull", "amazonfull"]:
            # Replace raw label names with their human-readable formatted versions.
            formatted_classes = {}
            for i, c in enumerate(data["classes"]):
                formatted_class = mlmc.data.dataset_formatter.label_dicts[dataset].get(c, c)
                formatted_classes[formatted_class] = i
            data["classes"] = formatted_classes
    classes = data["classes"]
    if dataset == "rcv1":
        # RCV1 is large; evaluate on a fixed subsample of 10000 test documents.
        data["test"] = mlmc.data.sampler(data["test"], absolute=10000)
    test_dataloader = DataLoader(data["test"], batch_size=1, shuffle=False)
    zc = ZeroshotClassification("allenai/unifiedqa-t5-small", classes=classes, target="single", format_=formatted)
    initialized_metrics = zc.init_metrics()
    threshold_ = mlmc.thresholds.get("max")
    question = "What is this question about?"
    choices = ""
"""
class_counter = 0
for char1 in ascii_uppercase:
for char2, class_ in zip(ascii_uppercase, classes.keys()):
if class_counter < len(classes.keys()):
choices += "("+char1+char2+") " + class_ + " "
class_counter += 1
"""
    # Label each class with a single letter: "(A) class1 (B) class2 ...".
    for char1, class_ in zip(ascii_uppercase, classes.keys()):
        choices += "(" + char1 + ") " + class_ + " "
    for sample in tqdm(test_dataloader):
        truth_l, pred_l = [], []
        text = " ".join(sample["text"][0].replace("\n", "").split())
        # UnifiedQA expects question, choices, and context joined by a literal "\n" token.
        encoded_input = question + " \\n " + choices + "\\n " + text
        # num_return_sequences = 5
        # output = zc.run_model(encoded_input, num_beams=20, num_return_sequences=num_return_sequences, do_sample=True)
        output = zc.run_model(encoded_input)
        # Take the first generated sequence that exactly matches a class name.
        predicted_class = None
        for class_ in output:
            if class_ in classes:
                predicted_class = class_
                break
        # One-hot score vector: 1 for the predicted class, 0 for all others.
        scores_list = [1 if predicted_class == class_ else 0 for class_ in classes]
        scores = torch.tensor([scores_list])
        truth_l.append(torch.squeeze(sample["labels"]))
        pred_l.append(torch.squeeze(threshold_(scores)))
        initialized_metrics.update_metrics((scores, torch.stack(truth_l), torch.stack(pred_l)))
    initialized_metrics.compute()
    initialized_metrics.log_sacred(_run, 1, "test")
    metrics = initialized_metrics.print()
    print(metrics)
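
# A typical invocation, assuming this file is saved as zeroshot_unifiedqa.py (hypothetical
# name) and a MongoDB instance is reachable at localhost:27017:
#   python zeroshot_unifiedqa.py with dataset=agnews formatted=True
# Sacred's "with key=value" syntax overrides the defaults from ex_config.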