from torch.utils.data.sampler import RandomSampler
from transformers.configuration_utils import PretrainedConfig
from transformers.generation_utils import GenerationMixin
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from typing import *
from openprompt.data_utils import InputExample, InputFeatures
from torch.utils.data._utils.collate import default_collate
from tqdm.std import tqdm
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.modeling_utils import PreTrainedModel
from openprompt.plms.utils import TokenizerWrapper
from openprompt.prompt_base import Template, Verbalizer
from collections import defaultdict
from openprompt.utils import round_list, signature
import numpy as np
from torch.utils.data import DataLoader
from yacs.config import CfgNode
from openprompt.utils.logging import logger
from transformers import AdamW, get_linear_schedule_with_warmup
class PromptDataLoader(object):
r"""
PromptDataLoader wraps the original dataset. The input data is first wrapped with the
prompt's template and then tokenized by a wrapped tokenizer.
Args:
dataset (:obj:`Dataset` or :obj:`List`): Either a Dataset object or a list containing the input examples.
template (:obj:`Template`): A derived class of :obj:`Template`.
tokenizer (:obj:`PreTrainedTokenizer`): The pretrained tokenizer.
tokenizer_wrapper_class (:obj:`TokenizerWrapper`): The class of the tokenizer wrapper; used together with ``tokenizer`` when no instantiated ``tokenizer_wrapper`` is given.
verbalizer (:obj:`Verbalizer`, optional): The verbalizer; some verbalizers may also pre-process the examples.
max_seq_length (:obj:`int`, optional): The maximum sequence length of the input ids. It's used to truncate sentences.
batch_size (:obj:`int`, optional): The batch size of the data loader.
shuffle (:obj:`bool`, optional): Whether to shuffle the data.
teacher_forcing (:obj:`bool`, optional): Whether to fill the mask with the target text. Set to ``True`` when training a generation model.
decoder_max_length (:obj:`int`, optional): The maximum decoder length of an encoder-decoder model.
predict_eos_token (:obj:`bool`, optional): Whether to predict the <eos> token. It is suggested to set this to ``True`` for generation tasks.
truncate_method (:obj:`str`, optional): The truncation method to use. Select from `head`, `tail`, `balanced`.
drop_last (:obj:`bool`, optional): Whether to drop the last incomplete batch.
kwargs: Other keyword arguments that might be passed into a tokenizer wrapper.
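Example (a minimal sketch; ``dataset``, ``mytemplate``, ``tokenizer`` and ``WrapperClass`` are
assumed to have been created elsewhere, e.g. via ``openprompt.plms.load_plm`` and a ``ManualTemplate``)::

    train_dataloader = PromptDataLoader(
        dataset=dataset["train"],               # a list of InputExample
        template=mytemplate,
        tokenizer=tokenizer,
        tokenizer_wrapper_class=WrapperClass,
        max_seq_length=256,
        batch_size=4,
        shuffle=True,
    )
    for batch in train_dataloader:              # each batch is a collated InputFeatures
        ...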
"""
def __init__(self,
dataset: Union[Dataset, List],
template: Template,
tokenizer_wrapper: Optional[TokenizerWrapper] = None,
tokenizer: PreTrainedTokenizer = None,
tokenizer_wrapper_class = None,
verbalizer: Optional[Verbalizer] = None,
max_seq_length: Optional[int] = 512,
batch_size: Optional[int] = 1,
shuffle: Optional[bool] = False,
teacher_forcing: Optional[bool] = False,
decoder_max_length: Optional[int] = -1,
predict_eos_token: Optional[bool] = False,
truncate_method: Optional[str] = "tail",
drop_last: Optional[bool] = False,
**kwargs,
):
assert hasattr(dataset, "__iter__"), f"The dataset must have __iter__ method. dataset is {dataset}"
assert hasattr(dataset, "__len__"), f"The dataset must have __len__ method. dataset is {dataset}"
self.raw_dataset = dataset
self.wrapped_dataset = []
self.tensor_dataset = []
self.template = template
self.verbalizer = verbalizer
self.batch_size = batch_size
self.shuffle = shuffle
self.teacher_forcing = teacher_forcing
if tokenizer_wrapper is None:
if tokenizer_wrapper_class is None:
raise RuntimeError("Either tokenizer_wrapper or tokenizer_wrapper_class should be specified.")
if tokenizer is None:
raise RuntimeError("No tokenizer specified to instantiate tokenizer_wrapper.")
tokenizer_wrapper_init_keys = signature(tokenizer_wrapper_class.__init__).args
prepare_kwargs = {
"max_seq_length" : max_seq_length,
"truncate_method" : truncate_method,
"decoder_max_length" : decoder_max_length,
"predict_eos_token" : predict_eos_token,
"tokenizer" : tokenizer,
**kwargs,
}
to_pass_kwargs = {key: prepare_kwargs[key] for key in prepare_kwargs if key in tokenizer_wrapper_init_keys}
self.tokenizer_wrapper = tokenizer_wrapper_class(**to_pass_kwargs)
else:
self.tokenizer_wrapper = tokenizer_wrapper
# check the satisfiability of each component
assert hasattr(self.template, 'wrap_one_example'), "Your template has no method \
named wrap_one_example"
# process
self.wrap()
self.tokenize()
if self.shuffle:
sampler = RandomSampler(self.tensor_dataset)
else:
sampler = None
self.dataloader = DataLoader(
self.tensor_dataset,
batch_size = self.batch_size,
sampler= sampler,
collate_fn = InputFeatures.collate_fct,
drop_last = drop_last,
)
def wrap(self):
r"""A simple interface that passes the examples to the prompt and wraps the text with the template.
"""
if isinstance(self.raw_dataset, Dataset) or isinstance(self.raw_dataset, List):
assert len(self.raw_dataset) > 0, 'The dataset to be wrapped is empty.'
# for idx, example in tqdm(enumerate(self.raw_dataset),desc='Wrapping'):
for idx, example in enumerate(self.raw_dataset):
if self.verbalizer is not None and hasattr(self.verbalizer, 'wrap_one_example'): # some verbalizer may also process the example.
example = self.verbalizer.wrap_one_example(example)
wrapped_example = self.template.wrap_one_example(example)
self.wrapped_dataset.append(wrapped_example)
else:
raise NotImplementedError
def tokenize(self) -> None:
r"""Pass the wrapped text into a prompt-specialized tokenizer.
The actual PreTrainedTokenizer inside the wrapper is flexible, e.g. ALBERT, BERT, T5, ...
"""
for idx, wrapped_example in tqdm(enumerate(self.wrapped_dataset),desc='tokenizing'):
# for idx, wrapped_example in enumerate(self.wrapped_dataset):
inputfeatures = InputFeatures(**self.tokenizer_wrapper.tokenize_one_example(wrapped_example, self.teacher_forcing), **wrapped_example[1]).to_tensor()
self.tensor_dataset.append(inputfeatures)
def __len__(self):
return len(self.dataloader)
def __iter__(self,):
return self.dataloader.__iter__()
class PromptModel(nn.Module):
r'''``PromptModel`` is the encapsulation of a ``Template`` and a pre-trained model.
With OpenPrompt, these modules can be flexibly combined. This class is the base class of ``PromptForClassification`` and ``PromptForGeneration``.
Args:
plm (:obj:`PreTrainedModel`): The pre-trained language model for the current prompt-learning task.
template (:obj:`Template`): The ``Template`` object used to wrap the input data.
freeze_plm (:obj:`bool`): Whether or not to freeze the pretrained language model.
plm_eval_mode (:obj:`bool`): A stronger freezing mode than ``freeze_plm``: the dropout of the model is also turned off, no matter whether the other parts are set to train mode.
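Example (a minimal sketch; ``plm``, ``mytemplate`` and ``batch`` are assumed to have been created elsewhere)::

    prompt_model = PromptModel(plm=plm, template=mytemplate, freeze_plm=True)
    outputs = prompt_model(batch)   # batch is an InputFeatures produced by a PromptDataLoader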
'''
def __init__(self,
plm: PreTrainedModel,
template: Template,
freeze_plm: bool = False,
plm_eval_mode: bool=False,
):
super().__init__()
self.plm = plm
self.template = template
self.freeze_plm = freeze_plm
self.plm_eval_mode = plm_eval_mode
if freeze_plm:
for param in self.plm.parameters():
param.requires_grad = False
if plm_eval_mode:
self.plm.eval()
for param in self.plm.parameters():
param.requires_grad = False
# get model's forward function's keywords
self.forward_keys = signature(self.plm.forward).args
self._prepare_main_input_name()
def _prepare_main_input_name(self):
model = self.plm
if hasattr(model, "encoder") and hasattr(model.encoder, "main_input_name"):
if model.encoder.main_input_name != model.main_input_name:
main_input_name = model.encoder.main_input_name
else:
main_input_name = model.main_input_name
else:
main_input_name = getattr(model, "main_input_name", "input_ids")
self.main_input_name = main_input_name
def train(self, mode: bool = True):
if not isinstance(mode, bool):
raise ValueError("training mode is expected to be boolean")
self.training = mode
for name, module in self.named_children():
if not (self.plm_eval_mode and 'plm' in name and mode):
module.train(mode)
return self
def forward(self, batch: Union[Dict, InputFeatures]) -> torch.Tensor:
r"""
This is a forward method to make wrapped input data go through the model, and return the output logits.
Typically, this function aims to predict the ``<mask>`` position.
Args:
batch (:obj:`Union[Dict, InputFeatures]`): The input features of batchified data sequences.
"""
batch = self.template.process_batch(batch)
input_batch = {key: batch[key] for key in batch if key in self.forward_keys}
outputs = self.plm(**input_batch, output_hidden_states=True)
outputs = self.template.post_processing_outputs(outputs)
return outputs
class PromptForClassification(nn.Module):
r'''``PromptModel`` with a classification head on top. The classification head maps
the logits at all positions of the sequence (the return value of a ``PromptModel``) to the
logits of the labels, using a verbalizer.
Args:
plm (:obj:`PreTrainedModel`): A pre-trained model you decide to use for classification, e.g. BERT.
template (:obj:`Template`): A ``Template`` object used to wrap the input text for classification, e.g. ``ManualTemplate``.
verbalizer (:obj:`Verbalizer`): A ``Verbalizer`` object used to project the labels to label words for classification, e.g. ``ManualVerbalizer``.
freeze_plm (:obj:`bool`): Whether or not to freeze the pretrained language model.
plm_eval_mode (:obj:`bool`): A stronger freezing mode than ``freeze_plm``: the dropout of the model is also turned off, no matter whether the other parts are set to train mode.
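Example (a minimal sketch; ``plm``, ``mytemplate``, ``myverbalizer`` and ``batch`` are assumed to have been created elsewhere)::

    prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer)
    logits = prompt_model(batch)            # shape: (batch_size, num_classes)
    preds = torch.argmax(logits, dim=-1)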
'''
def __init__(self,
plm: PreTrainedModel,
template: Template,
verbalizer: Verbalizer,
freeze_plm: bool = False,
plm_eval_mode: bool=False
):
super().__init__()
self.prompt_model = PromptModel(plm, template, freeze_plm, plm_eval_mode)
self.verbalizer = verbalizer
@property
def plm(self):
return self.prompt_model.plm
@property
def template(self):
return self.prompt_model.template
@property
def device(self,):
r"""Register the device parameter."""
return self.plm.device
def forward(self, batch: Union[Dict, InputFeatures]) -> torch.Tensor:
r"""
Get the logits of label words.
Args:
batch (:obj:`Union[Dict, InputFeatures]`): The original batch
Returns:
:obj:`torch.Tensor`: The logits of the label words (obtained by the current verbalizer).
"""
outputs = self.prompt_model(batch)
outputs = self.verbalizer.gather_outputs(outputs)
if isinstance(outputs, tuple):
outputs_at_mask = [self.extract_at_mask(output, batch) for output in outputs]
else:
outputs_at_mask = self.extract_at_mask(outputs, batch)
label_words_logits = self.verbalizer.process_outputs(outputs_at_mask, batch=batch)
return label_words_logits
def predict(self):
pass
def forward_without_verbalize(self, batch: Union[Dict, InputFeatures]) -> torch.Tensor:
outputs = self.prompt_model(batch)
outputs = self.verbalizer.gather_outputs(outputs)
outputs_at_mask = self.extract_at_mask(outputs, batch)
return outputs_at_mask
@property
def tokenizer(self):
r'''Utility property, to get the tokenizer more easily.
'''
return self.verbalizer.tokenizer
## We comment out this code since it conflicts with [OpenDelta](https://github.com/thunlp/OpenDelta)
# def state_dict(self, *args, **kwargs):
# """ Save the model using template, plm and verbalizer's save methods."""
# _state_dict = {}
# if not self.prompt_model.freeze_plm:
# _state_dict['plm'] = self.plm.state_dict(*args, **kwargs)
# _state_dict['template'] = self.template.state_dict(*args, **kwargs)
# _state_dict['verbalizer'] = self.verbalizer.state_dict(*args, **kwargs)
# return _state_dict
# def load_state_dict(self, state_dict, *args, **kwargs):
# """ Load the model using template, plm and verbalizer's load methods."""
# if 'plm' in state_dict and not self.prompt_model.freeze_plm:
# self.plm.load_state_dict(state_dict['plm'], *args, **kwargs)
# self.template.load_state_dict(state_dict['template'], *args, **kwargs)
# self.verbalizer.load_state_dict(state_dict['verbalizer'], *args, **kwargs)
def parallelize(self, device_map=None):
r"""Parallelize the model across devices.
"""
if hasattr(self.plm, "parallelize"):
self.plm.parallelize(device_map)
self.device_map = self.plm.device_map
self.template.cuda()
self.verbalizer.cuda()
else:
raise NotImplementedError("parallelize method was not implemented for this plm.")
def deparallelize(self):
r"""Deparallelize the model across devices.
"""
if hasattr(self.plm, "deparallelize"):
self.plm.deparallelize()
self.device_map = None
self.template.cpu()
self.verbalizer.cpu()
else:
raise NotImplementedError("deparallelize method was not implemented for this plm.")
class PromptForGeneration(nn.Module, GenerationMixin):
r'''``PromptModel`` with generation loss calculation and generation utils integrated.
Args:
plm (:obj:`PreTrainedModel`): A pre-trained model you decide to use for generation, e.g. GPT.
template (:obj:`Template`): A ``Template`` object used to wrap the input text for generation, e.g. ``PrefixTemplate``.
tokenizer (:obj:`Tokenizer`): A ``Tokenizer`` of the current model.
gen_config (:obj:`CfgNode`): The generation configs to pass into `GenerationMixin.generate <https://huggingface.co/transformers/_modules/transformers/generation_utils.html#GenerationMixin.generate>`_
freeze_plm (:obj:`bool`): Whether or not to freeze the pretrained language model.
plm_eval_mode (:obj:`bool`): A stronger freezing mode than ``freeze_plm``: the dropout of the model is also turned off, no matter whether the other parts are set to train mode.
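Example (a minimal sketch; ``plm``, ``mytemplate``, ``tokenizer`` and ``batch`` are assumed to have been
created elsewhere, e.g. via ``load_plm("gpt2", "gpt2")`` and a ``PrefixTemplate``)::

    prompt_model = PromptForGeneration(plm=plm, template=mytemplate, tokenizer=tokenizer)
    loss = prompt_model(batch)                                    # training: returns the generation loss
    _, sentences = prompt_model.generate(batch, max_length=128)   # inference: post-processed sentences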
'''
def __init__(self,
plm: PreTrainedModel,
template: Template,
freeze_plm: bool = False,
plm_eval_mode: bool = False,
gen_config: Optional[CfgNode] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
):
super().__init__()
self.freeze_plm = freeze_plm
if tokenizer is None:
assert template.tokenizer is not None, "A tokenizer must be provided either through the tokenizer argument or by the template."
self.tokenizer = template.tokenizer
else:
self.tokenizer = tokenizer
self.prompt_model = PromptModel(plm, template, freeze_plm, plm_eval_mode)
self.loss_fct = nn.CrossEntropyLoss(reduction='none')
self.config = plm.config
if gen_config:
for key in gen_config:
setattr(self.config, key, gen_config[key])
self.in_generation_function = False
self.main_input_name = self.prompt_model.main_input_name # for transformers 4.17.0 and higher.
@property
def plm(self):
return self.prompt_model.plm
@property
def template(self):
return self.prompt_model.template
@property
def device(self):
return self.plm.device
def shift_logits_and_labels(self,
logits,
loss_ids,
reference_ids):
r"""
Left-shift the labels, and set the labels at positions that are
not loss positions to -100, which is the ignore index of PyTorch's
loss function.
Args:
logits (:obj:`torch.Tensor`): The logits output by the model.
loss_ids (:obj:`torch.Tensor`): A mask marking the positions that contribute to the loss.
reference_ids (:obj:`torch.Tensor`): The reference (target) token ids.
Returns:
shift_logits (:obj:`torch.Tensor`):
shift_input_ids (:obj:`torch.Tensor`):
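Example (an illustrative sketch on a single sequence)::

    # loss_ids      = [0, 0, 1, 1]   -> only the last two positions carry loss
    # reference_ids = [5, 6, 7, 8]
    # after the left shift:
    # shift_loss_ids  = [0, 1, 1]
    # shift_input_ids = [6, 7, 8] -> masked with -100 where shift_loss_ids is 0: [-100, 7, 8]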
"""
shift_logits = logits[..., :-1, :].contiguous()
shift_loss_ids = loss_ids[..., 1:].contiguous()
shift_input_ids = reference_ids[..., 1:].contiguous()
shift_input_ids = torch.where(shift_loss_ids>0, shift_input_ids, -100)
return shift_logits, shift_input_ids
def forward(self, *args, **kwargs):
r"""During generation, this method uses the plm's forward function,
because in the first step we directly call the process_batch function to
build the initial input with the template; after that, the whole template
has been processed into the past_key_values,
and we can use the normal generation function.
During training, forward is linked to the ``_forward`` function,
in which the loss is calculated for all positions at the same time.
"""
if self.in_generation_function:
return self.plm.forward(*args, **kwargs)
else:
return self._forward(*args, **kwargs)
def _forward(self, batch: Union[Dict, InputFeatures]) -> torch.Tensor:
r"""
This is the forward method for training a generation model in the prompt-learning framework.
Args:
batch (:obj:`Union[Dict, InputFeatures]`): The input features of batchified data sequences.
Returns:
loss (:obj:`torch.Tensor`): The loss of the current generation procedure.
"""
if self.config.is_encoder_decoder:
reference_ids = batch['decoder_input_ids']
else:
reference_ids = batch['input_ids'] # in case this field is dropped by some templates
outputs = self.prompt_model(batch)
logits = outputs.logits
logits, labels = self.shift_logits_and_labels(logits, batch['loss_ids'], reference_ids)
batch_size, seq_len, vocab_size = logits.shape
loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
loss = loss.view(batch_size, -1).sum(dim=-1) # TODO support more objectives
loss = loss.mean()
return loss
def generate(self, batch: Union[Dict, InputFeatures], verbose: Optional[bool]=False, **generation_kwargs):
r""" This function wraps the generate() method of the parent class ``GenerationMixin``.
Forward uses the ``PretrainedModel``'s forward method.
generation_kwargs include all the parameters that are passed to
``transformers.generation_utils.GenerationMixin.generate``.
Args:
batch (:obj:`Union[Dict, InputFeatures]`): The input features of batchified data sequences.
verbose (:obj:`Optional[bool]`): Set to ``True`` to print the generated sentences.
Returns:
output_sequences (:obj:`List[List[int]]`): The raw sequences generated by the generation model.
generated_sentences (:obj:`List[str]`): The generated sentences that have been post-processed.
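Example (a minimal sketch; ``batch`` is assumed to come from a ``PromptDataLoader`` built with
``teacher_forcing=False`` and ``predict_eos_token=True``)::

    output_ids, sentences = prompt_model.generate(batch, num_beams=5, max_length=64, verbose=True)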
"""
input_generation_kwargs = {key: value for key,value in generation_kwargs.items() if key in signature(GenerationMixin.generate).args}
if self.config.is_encoder_decoder:
loss_ids_start = batch['loss_ids'].argmax(dim=-1)
assert loss_ids_start.min() == loss_ids_start.max(), "Generation starts from different positions within the batch."
batch['decoder_input_ids'] = batch['decoder_input_ids'][:, :loss_ids_start.min()+1]
input_length = batch['decoder_input_ids'].size(1)
batch_size = batch['decoder_input_ids'].size(0)
self.generate_ith_token = 0
self.in_generation_function = True
output_sequences = super().generate(**batch, **input_generation_kwargs, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id)
self.in_generation_function = False
output_sequences = output_sequences.cpu().tolist()
generated_sentences = self.post_processing(output_sequences=output_sequences, input_lengths=input_length)
else:
input_length = batch['input_ids'].size(1)
batch_size = batch['input_ids'].size(0)
# Currently huggingface transformers only supports single-sample generation, or padding to the left (instead of the right),
# because it only extracts the last position of the output.
# So we generate one by one.
if 'input_ids_len' in batch:
input_real_lens = batch['input_ids_len']
else:
input_real_lens = torch.sum((batch['input_ids'] != self.tokenizer.pad_token_id).to(torch.int), dim=-1)
output_sequences = []
for instance_id in range(batch_size):
# remove the pad token
instance = {key: batch[key][instance_id:instance_id+1][:,:input_real_lens[instance_id]] for key in batch if isinstance(batch[key], torch.Tensor) and batch[key].shape[:2]==torch.Size([batch_size, input_length])}
self.generate_ith_token = 0
self.in_generation_function = True
output_sequence = super().generate(**instance, **input_generation_kwargs, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id)
self.in_generation_function = False
output_sequences.extend(output_sequence.cpu().tolist()) # TODO: support generating multiple sentences
generated_sentences = self.post_processing(output_sequences=output_sequences, input_lengths=input_real_lens.cpu().tolist())
if verbose:
logger.info(f"Generated:{generated_sentences}")
return output_sequences, generated_sentences
def post_processing(self, output_sequences, input_lengths):
r"""
Post-process the sequences generated by the generation model.
Args:
output_sequences (:obj:`List[List[int]]`): The raw sequences generated by the generation model.
input_lengths (:obj:`int` or `list`): The length(s) of the input sequence.
Returns:
:obj:`List[str]`: The generated sentences that have been post-processed.
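Example (an illustrative sketch; the token ids are arbitrary)::

    # output_sequences = [[11, 12, 13, 14, 15]] with input_lengths = 3
    # -> the first 3 tokens (the prompt) are stripped and [14, 15] is decoded,
    #    truncating at the first eos token if the tokenizer defines one.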
"""
generated_sentences = []
if type(input_lengths)==int:
input_lengths = [input_lengths] * len(output_sequences)
for sent_id, seq in enumerate(output_sequences):
seq = seq[input_lengths[sent_id]:]
if hasattr(self.tokenizer, "eos_token") and self.tokenizer.eos_token is not None:
text_output = self.tokenizer.decode(seq, clean_up_tokenization_spaces=True, skip_special_tokens=False)
idx = text_output.find(self.tokenizer.eos_token)
if idx >= 0:
text_output = text_output[:idx]
else:
text_output = self.tokenizer.decode(seq, clean_up_tokenization_spaces=True, skip_special_tokens=True)
text_output = text_output.strip()
generated_sentences.append(text_output)
return generated_sentences
def _update_model_kwargs_for_generation(self,
outputs, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
) -> Dict[str, Any]:
r""" The parent class's ``_update_model_kwargs_for_generation`` method will
add ``past_key_values`` to model_kwargs, and update ``token_type_ids`` and ``attention_mask``.
In case some of the model_kwargs are modified in the prepare_inputs_for_generation function
and should be used as the subsequent model_kwargs, we update these kwargs before the parent
class's call.
Other updates should be added here after the parent's function call.
Args:
outputs (:obj:`torch.Tensor`):
is_encoder_decoder (:obj:`bool`, defaults to False):
"""
if self.generate_ith_token == 0:
for key in self.last_model_inputs:
if key in model_kwargs:
model_kwargs[key] = self.last_model_inputs[key]
model_kwargs = super(PromptForGeneration, PromptForGeneration)._update_model_kwargs_for_generation(outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=is_encoder_decoder)
self.generate_ith_token += 1
return model_kwargs
def _prepare_encoder_decoder_kwargs_for_generation(
self, input_ids: torch.LongTensor, model_kwargs, model_input_name: Optional[str] = None,
) -> Dict[str, Any]:
r""" This function resembles the corresponding function in ``GenerationMixin``.
Args:
input_ids (:obj:`torch.LongTensor`): The input ids.
"""
if "encoder_outputs" not in model_kwargs:
# retrieve encoder hidden states
encoder = self.plm.get_encoder()
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
}
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
batch = {model_input_name:input_ids, **encoder_kwargs}
model_inputs = self.prompt_model.prepare_model_inputs(batch) # This line differs from the original code base: we should process the input
# with our template, then pass it into the model.
# some of the arguments may have been changed by the template,
# e.g. the attention mask. Here we update the model_kwargs
for key in model_kwargs:
if key in model_inputs:
model_kwargs[key] = model_inputs[key]
model_kwargs["encoder_outputs"] = encoder(return_dict=True, **model_inputs)
return model_kwargs
## We comment out this code since it conflicts with [OpenDelta](https://github.com/thunlp/OpenDelta)
# def state_dict(self, *args, **kwargs):
# """ Save the model using template and plm's save methods. """
# _state_dict = {}
# if not self.prompt_model.freeze_plm:
# _state_dict['plm'] = self.plm.state_dict(*args, **kwargs)
# _state_dict['template'] = self.template.state_dict(*args, **kwargs)
# return _state_dict
# def load_state_dict(self, state_dict, *args, **kwargs):
# """ Load the model using template and plm's load methods. """
# if 'plm' in state_dict and not self.prompt_model.freeze_plm:
# self.plm.load_state_dict(state_dict['plm'], *args, **kwargs)
# self.template.load_state_dict(state_dict['template'], *args, **kwargs)
def _reorder_cache(self, past, beam_idx):
r"""Use the plm's default _reorder_cache function
"""
return self.plm._reorder_cache(past, beam_idx)
def parallelize(self, device_map=None):
r"""Parallelize the model across devices.
"""
if hasattr(self.plm, "parallelize"):
self.plm.parallelize(device_map)
self.device_map = self.plm.device_map
else:
raise NotImplementedError("parallelize method was not implemented for this plm.")
def deparallelize(self):
r"""Deparallelize the model across devices.
"""
if hasattr(self.plm, "deparallelize"):
self.plm.deparallelize()
self.device_map = None
else:
raise NotImplementedError("deparallelize method was not implemented for this plm.")