from string import ascii_uppercase

import mlmc
import torch
from sacred import Experiment, SETTINGS
from sacred.observers import MongoObserver
from sacred.utils import apply_backspaces_and_linefeeds
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, T5ForConditionalGeneration

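# Sacred / MongoDB experiment settings; left blank here and expected to be filled in before running.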
name = ""
user = ""
host = ""
database = ""
auth = ""
pw = ""

SETTINGS.CAPTURE_MODE = "sys"
ex = Experiment(name)
ex.observers.append(MongoObserver(url="localhost:27017", db_name=database))
ex.captured_out_filter = apply_backspaces_and_linefeeds


class ZeroshotClassification:
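    """Wraps a generative T5-style checkpoint (e.g. UnifiedQA) for zero-shot classification via question answering."""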
    def __init__(self, representation, classes, target, format_):
        self.model = T5ForConditionalGeneration.from_pretrained(representation)
        self.tokenizer = AutoTokenizer.from_pretrained(representation)
        self.config = {"representation": representation,
                       "classes": classes,
                       "target": target,
                       "format": format_, }

    def init_metrics(self, metrics="default_singlelabel"):
        """
        Initializes metrics to be used. If no metrics are specified then depending on the target the default metrics
        for this target will be used. (see mlmc.metrics.metrics_config.items())

        :param metrics: Name of the metrics (see mlmc.metrics.metrics_dict.keys() and mlmc.metrics.metrics_config.keys())
        :return: A dictionary containing the initialized metrics
        """
        metrics = mlmc.metrics.MetricsDict(metrics)
        metrics.init(self.config)
        metrics.reset()
        return metrics

    def run_model(self, input_string, **generator_args):
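        # Tokenize the prompt, generate with the seq2seq model, and decode the generated ids back to text.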
        input_ids = self.tokenizer.encode(input_string, return_tensors="pt")
        res = self.model.generate(input_ids, **generator_args)
        return self.tokenizer.batch_decode(res, skip_special_tokens=True)

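
# Default experiment configuration; Sacred can override these values from the command line.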
@ex.config
def ex_config():
    device = 0
    batch_size = 1
    representation = "tals/albert-base-mnli"
    dataset = "agnews"
    target = "single"
    threshold = "max"
    if target == "multi":
        threshold = "mcut"
    formatted = True
    cut_sample = False
    if target == "multi":
        cut_sample = True
    method = "huggingface"
    whole_dataset = True
    if dataset == "rcv1":
        whole_dataset = False
        dataset_size = 10000

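
# Entry point: Sacred injects the config values (dataset, formatted) as arguments.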
@ex.automain
def run(_run, dataset, formatted):
    data = mlmc.data.get(dataset)

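    # Replace raw label identifiers with human-readable names where a mapping exists for the dataset.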
    if formatted:
        if dataset in ["trec6", "trec50", "dbpedia", "agnews", "yelpfull", "amazonfull"]:
            formatted_classes = {}
            for i, c in enumerate(data["classes"]):
                formatted_class = mlmc.data.dataset_formatter.label_dicts[dataset].get(c, c)
                formatted_classes[formatted_class] = i
            data["classes"] = formatted_classes

    classes = data["classes"]

    if dataset == "rcv1":
        data["test"] = mlmc.data.sampler(data["test"], absolute=10000)

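    # The evaluation loop below assumes batch_size=1 (it indexes sample["text"][0]);
    # the checkpoint is fixed to allenai/unifiedqa-t5-small here.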
    test_dataloader = DataLoader(data["test"], batch_size=1, shuffle=False)
    zc = ZeroshotClassification("allenai/unifiedqa-t5-small", classes=classes, target="single", format_=formatted)
    initialized_metrics = zc.init_metrics()
    threshold_ = mlmc.thresholds.get("max")

    question = "What is this question about?"
    choices = ""

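    # Alternative labelling scheme (kept for reference): two-letter choice labels for datasets with more than 26 classes.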
    """
    class_counter = 0
    for char1 in ascii_uppercase:
        for char2, class_ in zip(ascii_uppercase, classes.keys()):
            if class_counter < len(classes.keys()):
                choices += "("+char1+char2+") " + class_ + " "
            class_counter += 1
    """

    for char1, class_ in zip(ascii_uppercase, classes.keys()):
        choices += "("+char1+") " + class_ + " "

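    # UnifiedQA expects a single plain-text input of the form
    # "<question> \n (A) <choice> (B) <choice> ... \n <context>",
    # where "\n" is written as a literal backslash-n separator in the prompt string.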
    for sample in tqdm(test_dataloader):
        truth_l, pred_l = [], []
        text = " ".join(sample["text"][0].replace("\n", "").split())
        encoded_input = question + " \\n " + choices + "\\n " + text
        # num_return_sequences = 5
        # output = zc.run_model(encoded_input, num_beams=20, num_return_sequences=num_return_sequences, do_sample=True)
        output = zc.run_model(encoded_input)
        # Take the first generated sequence that matches a known class name; otherwise the score vector stays all-zero.
        predicted_class = None
        for class_ in output:
            if class_ in classes:
                predicted_class = class_
                break
        scores_list = [1 if predicted_class == class_ else 0 for class_ in classes]
        scores = torch.tensor([scores_list])
        truth_l.append(torch.squeeze(sample["labels"]))
        pred_l.append(torch.squeeze(threshold_(scores)))
        initialized_metrics.update_metrics((scores, torch.stack(truth_l), torch.stack(pred_l)))

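    # Aggregate the metrics, log them to the Sacred run, and print the results.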
    initialized_metrics.compute()
    initialized_metrics.log_sacred(_run, 1, "test")
    metrics = initialized_metrics.print()
    print(metrics)