import numpy as np
import logging

from transformers import (AlbertModel, AlbertTokenizer, BertModel, BertTokenizer, CTRLModel,
                          CTRLTokenizer, DistilBertModel, DistilBertTokenizer, GPT2Model,
                          GPT2Tokenizer, OpenAIGPTModel, OpenAIGPTTokenizer, RobertaModel,
                          RobertaTokenizer, XLMModel, XLMTokenizer, XLNetModel, XLNetTokenizer)
import torch
from pathlib import Path
from urllib import error
from urllib.request import urlopen
from io import BytesIO
from zipfile import ZipFile


EMBEDDINGCACHE = Path.home() / ".mlmc" / "embedding"

MODELS = {"bert": (BertModel, BertTokenizer, 'bert-large-uncased'),
          "bert_cased": (BertModel, BertTokenizer, 'bert-base-cased'),
          "albert": (AlbertModel, AlbertTokenizer, 'albert-large-v2'),
          "gpt": (OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),
          "gpt2": (GPT2Model, GPT2Tokenizer, 'gpt2'),
          "ctrl": (CTRLModel, CTRLTokenizer, 'ctrl'),
          "xlnet": (XLNetModel, XLNetTokenizer, 'xlnet-base-cased'),
          "xlm": (XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024'),
          "distilbert": (DistilBertModel, DistilBertTokenizer, 'distilbert-base-uncased'),
          "roberta": (RobertaModel, RobertaTokenizer, 'roberta-base'),
          }



def load_static(embedding="glove300"):
    embeddingfiles = {"glove50": "glove.6B.50d.txt",
                      "glove100": "glove.6B.100d.txt",
                      "glove200": "glove.6B.200d.txt",
                      "glove300": "glove.6B.300d.txt"}

    if not (EMBEDDINGCACHE / embeddingfiles[embedding]).exists():
        URL = "http://nlp.stanford.edu/data/glove.6B.zip"
        try:
            resp = urlopen(URL)
        except error.HTTPError as e:
            print(e)
            return None
        assert resp.getcode() == 200, "Download failed with status code %i" % (resp.getcode(),)
        print("Downloading glove vectors... This may take a while...")
        zipfile = ZipFile(BytesIO(resp.read()))
        zipfile.extractall(EMBEDDINGCACHE)
    fp = EMBEDDINGCACHE / embeddingfiles[embedding]

    glove = np.loadtxt(fp, dtype='str', comments=None)
    # Drop duplicate words, keeping the first occurrence of each.
    glove = glove[np.unique(glove[:, :1], axis=0, return_index=True)[1]]
    words = glove[:, 0]
    weights = glove[:, 1:].astype('float')
    weights = np.vstack((
        np.zeros(weights.shape[1]),  # index 0: all-zero vector used for padding/masking
        weights,
        np.mean(weights, axis=0),    # last index: mean vector used for <UNK_TOKEN>
    ))
    words = words.tolist() + ["<UNK_TOKEN>"]
    vocabulary = dict(zip(words, range(1, len(words) + 1)))
    return weights, vocabulary

def map_vocab(query, vocab, maxlen):
    ind = [[vocab.get(token, vocab["<UNK_TOKEN>"]) for token in s] for s in query]
    result = torch.zeros((len(query), maxlen)).long()
    for i, e in enumerate(ind):
        result[i, :min(len(e), maxlen)] = torch.LongTensor(e[:min(len(e), maxlen)])
    return result


def get_embedding(name, **kwargs):
    static = load_static(name)
    if static is None:
        return None
    weights, vocabulary = static
    e = torch.nn.Embedding.from_pretrained(torch.Tensor(weights).float(), **kwargs)
    def tokenizer(x, maxlen=500):
        x = [x] if isinstance(x, str) else x
        x = [s.lower().split() for s in x]
        return map_vocab(x, vocabulary, maxlen).long()
    return e, tokenizer
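
# Usage sketch (illustration only, not part of the original module). "glove50" is one
# of the keys handled by load_static; the first call downloads and unpacks the GloVe 6B
# archive into ~/.mlmc/embedding.
#
#   embedding, tokenize = get_embedding("glove50")
#   ids = tokenize(["A first sentence", "a second one"], maxlen=10)  # LongTensor of shape (2, 10)
#   vectors = embedding(ids)                                         # FloatTensor of shape (2, 10, 50)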


def get_transformer(model="bert", **kwargs):
    # Hugging Face Transformers exposes a unified API across architectures;
    # MODELS (above) maps a short name to
    # (model class, tokenizer class, pretrained weights shortcut).

    model_class, tokenizer_class, pretrained_weights = MODELS.get(model, (None, None, None))
    if model_class is None:
        print("Model is not a transformer...")
        return None
    else:
        # Load pretrained model/tokenizer
        tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
        def list_tokenizer(x, maxlen=500):
            x = [x] if isinstance(x, str) else x
            i = torch.nn.utils.rnn.pad_sequence(
                [torch.tensor(tokenizer.encode(sentence, add_special_tokens=False, pad_to_max_length=True))
                 for sentence in x],
                batch_first=True)
            i = i[:, :min(maxlen, i.shape[-1])]
            return i

        model = model_class.from_pretrained(pretrained_weights, **kwargs)
        return model, list_tokenizer
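
# Usage sketch (illustration only, not part of the original module). "distilbert" is a
# key of MODELS above; from_pretrained downloads the pretrained weights on first use.
#
#   model, tokenize = get_transformer("distilbert")
#   ids = tokenize(["A first sentence", "a second one"], maxlen=20)  # padded LongTensor of token ids
#   hidden = model(ids)[0]                                           # last hidden states, (2, seq_len, 768)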


def get_by_arg_(static=None, transformer=None, **kwargs):
    assert (static is None) != (transformer is None), "Exactly one of 'static' and 'transformer' must be given."
    if static is not None:
        return get_embedding(static, **kwargs)
    elif transformer is not None:
        print("Setting transformers.tokenization_utils logger to ERROR.")
        logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
        return get_transformer(transformer, **kwargs)
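
# Illustration only (not part of the original module): exactly one of the two keyword
# arguments may be set.
#
#   emb, tok = get_by_arg_(static="glove100")          # static GloVe embedding
#   model, tok = get_by_arg_(transformer="roberta")    # transformer listed in MODELS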

def get(model, **kwargs):
    logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
    module = get_transformer(model, **kwargs)
    if module is None:
        module = get_embedding(model, **kwargs)
        if module is None:
            raise FileNotFoundError("Unknown embedding or transformer: %r" % (model,))
    return module

def is_transformer(name):
    return name in MODELS
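

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module. It downloads pretrained
    # weights (and, for "glove50", the GloVe 6B archive) on first run, which is why
    # it only executes when this file is run directly.
    for name in ("distilbert", "glove50"):
        print(name, "is a transformer:", is_transformer(name))
        representation, tokenize = get(name)
        batch = tokenize(["A short example sentence."], maxlen=16)
        print(name, "token id tensor shape:", tuple(batch.shape))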