Commit 87096fd1 authored by Shahbaz Syed

add longformer2roberta

parent 547d2ee4
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, EncoderDecoderModel, LongformerTokenizer
 class NeuralSummarizer(object):
@@ -10,29 +10,44 @@ class NeuralSummarizer(object):
         For our demo, we use the following models denoted as {'model name': 'model code'}
         {
             'T5': 't5-base',
             'BART': 'facebook/bart-large-cnn',
             'DistilBART': 'sshleifer/distilbart-cnn-12-6',
-            'Pegasus': 'google/pegasus-cnn_dailymail'
+            'Pegasus': 'google/pegasus-cnn_dailymail',
+            'Longformer2Roberta': 'patrickvonplaten/longformer2roberta-cnn_dailymail-fp16'
         }

         Args:
             model (str, optional): [summarization model]. Defaults to 't5-base'.
         """
         self.model = model
-        self.summarization_pipeline = pipeline(
-            'summarization', model=self.model)
+        self.tokenizer = None
+        self.encoder_decoder = None
+        self.pipeline = None
+        if self.model != 'patrickvonplaten/longformer2roberta-cnn_dailymail-fp16':
+            self.pipeline = pipeline('summarization', model=self.model)
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model)
+        else:
+            self.encoder_decoder = EncoderDecoderModel.from_pretrained("patrickvonplaten/longformer2roberta-cnn_dailymail-fp16")
+            self.tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
     def summarize(self, text: str = None, ratio: float = 0.2):
-        """Currently used models cannot process sequences longer than 1024 tokens. Thus, truncate the text to first 800 words. This will be improved with the integration of longformer model in the near future
+        """Currently used models cannot process sequences longer than 1024 tokens. Thus, truncate the text to an appropriate number of tokens.
         """
-        text_length = len(text.split())
-        long_text = ""
-        if text_length > 800:
-            long_text = " ".join(text.split()[:800])
-        else:
-            long_text = text
-        long_text_length = len(long_text.split())
-        min_summary_length = round(long_text_length * ratio)
-        summarizer = self.summarization_pipeline(
-            text, min_length=min_summary_length)
-        return summarizer[0]['summary_text']
+        tokens = self.tokenizer(text, return_tensors="pt", truncation=True).input_ids
+        max_model_length = tokens.size()[1]
+        truncated_tokens = tokens[0][:max_model_length - 3]
+        truncated_text = self.tokenizer.decode(truncated_tokens, clean_up_tokenization_spaces=True)
+        min_summary_length = round(len(truncated_text.split()) * ratio)
+        if self.pipeline:
+            summarization = self.pipeline(truncated_text, min_length=min_summary_length, clean_up_tokenization_spaces=True)
+            summary_text = summarization[0]['summary_text']
+            return summary_text
+        if self.encoder_decoder:
+            output_ids = self.encoder_decoder.generate(tokens, min_length=min_summary_length)
+            summary_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+            return summary_text
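
For reference, a minimal usage sketch of the class after this change. The module name `neural_summarizer`, the sample strings, and the ratio value are illustrative assumptions, not part of the commit:

# Minimal usage sketch; assumes the class above is importable as
# neural_summarizer (hypothetical module name) and that the transformers
# package and model weights are available.
from neural_summarizer import NeuralSummarizer

# Pipeline-backed models (T5, BART, DistilBART, Pegasus) go through
# transformers.pipeline('summarization', ...).
summarizer = NeuralSummarizer(model='facebook/bart-large-cnn')
print(summarizer.summarize(text='Some long article text ...', ratio=0.2))

# The Longformer2Roberta checkpoint skips the pipeline and calls
# EncoderDecoderModel.generate() directly, so it can handle inputs up to
# Longformer's 4096-token limit.
long_summarizer = NeuralSummarizer(
    model='patrickvonplaten/longformer2roberta-cnn_dailymail-fp16')
print(long_summarizer.summarize(text='A very long document ...', ratio=0.2))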