Source code for openprompt.prompts.prompt_generator

from abc import abstractmethod
from builtins import ValueError
from typing import List, Optional, Dict, Union
from tokenizers import Tokenizer
import json
import torch
import torch.nn.functional as F
from yacs.config import CfgNode
from openprompt.data_utils.utils import InputExample, InputFeatures
from openprompt.pipeline_base import PromptDataLoader, PromptModel

from openprompt.prompt_base import Template, Verbalizer
from openprompt.prompts import ManualTemplate, ManualVerbalizer
from ..utils import logger
from transformers import T5Tokenizer, T5ForConditionalGeneration, BertForMaskedLM, RobertaForMaskedLM, RobertaTokenizer, PreTrainedModel, PreTrainedTokenizer
from tqdm import tqdm
import itertools
import numpy as np
from ..utils import signature
from ..config import convert_cfg_to_dict
from torch.nn.parallel import DataParallel

class LMBFFTemplateGenerationTemplate(ManualTemplate):
    """
    This is a special template used only for the template search in LM-BFF. For example, a template could be ``{"placeholder": "text_a"}{"mask"}{"meta":"labelword"}{"mask"}``, where ``{"meta":"labelword"}`` is replaced by a label word from the verbalizer in the ``wrap_one_example`` method, and each ``{"mask"}`` is replaced by the special tokens used for generation (for T5, ``<extra_id_0>, <extra_id_1>, ...``).

    Args:
        tokenizer (:obj:`PreTrainedTokenizer`): A tokenizer to appoint the vocabulary and the tokenization strategy.
        verbalizer (:obj:`ManualVerbalizer`): A verbalizer to provide label_words.
        text (:obj:`Optional[List[str]]`, optional): Manual template format. Defaults to None.
        placeholder_mapping (:obj:`dict`): A mapping that designates which placeholder represents which original input text. Defaults to ``{'<text_a>': 'text_a', '<text_b>': 'text_b'}``.
    """
    def __init__(self,
                 tokenizer: T5Tokenizer,
                 verbalizer: ManualVerbalizer,
                 text: Optional[List[str]] = None,
                 placeholder_mapping: dict = {'<text_a>':'text_a','<text_b>':'text_b'},
                ):
        super().__init__(tokenizer=tokenizer,
                         text = text,
                         placeholder_mapping=placeholder_mapping)
        self.verbalizer = verbalizer

    def wrap_one_example(self,
                         example: InputExample) -> List[Dict]:
        example.meta['labelword'] = self.verbalizer.label_words[example.label][0].strip()
        wrapped_example = super().wrap_one_example(example)
        return wrapped_example
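For intuition, here is a minimal usage sketch of this template. The checkpoint, label words, and template string are illustrative assumptions, and the exact constructor arguments may differ slightly across OpenPrompt versions:

from openprompt.data_utils.utils import InputExample
from openprompt.prompts import ManualVerbalizer
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-base')  # illustrative checkpoint
verbalizer = ManualVerbalizer(tokenizer=tokenizer, num_classes=2,
                              label_words=[['terrible'], ['great']])
template = LMBFFTemplateGenerationTemplate(
    tokenizer=tokenizer,
    verbalizer=verbalizer,
    text='{"placeholder": "text_a"} {"mask"} {"meta": "labelword"} {"mask"}',
)
example = InputExample(text_a='A fun ride.', label=1)
wrapped = template.wrap_one_example(example)
# example.meta['labelword'] is now 'great' (the first label word of class 1);
# downstream, the T5 tokenizer wrapper maps the two {"mask"} slots to
# <extra_id_0> and <extra_id_1> for generation.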

class TemplateGenerator:
    r"""
    This is the automatic template search implementation for `LM-BFF <https://arxiv.org/pdf/2012.15723.pdf>`_. It uses a generation model to generate multi-part text to fill in the template. By jointly considering all samples in the dataset, it uses beam search decoding to generate a designated number of templates with the highest probability. The generated templates may be uniformly used for all samples in the dataset.

    Args:
        model (:obj:`PreTrainedModel`): A pretrained model for generation.
        tokenizer (:obj:`PreTrainedTokenizer`): A corresponding type tokenizer.
        tokenizer_wrapper (:obj:`TokenizerWrapper`): A corresponding type tokenizer wrapper class.
        verbalizer (:obj:`Verbalizer`): A verbalizer that provides the label words.
        max_length (:obj:`Optional[int]`): The maximum length of the total generated template. Defaults to 20.
        target_number (:obj:`Optional[int]`): The number of separate parts to generate, e.g. in T5, every ``<extra_id_{}>`` token stands for one part. Defaults to 2.
        beam_width (:obj:`Optional[int]`): The beam search width. Defaults to 100.
        length_limit (:obj:`Optional[List[int]]`): The length limit for each part of content; if None, there is no limit. If not None, the list should have a length equal to ``target_number``. Defaults to None.
        forbidden_word_ids (:obj:`Optional[List[int]]`): Any tokenizer-specific token_ids you want to prevent from being generated. Defaults to ``[]``, i.e. all tokens in the vocabulary are allowed in the generated template.
    """
    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: PreTrainedTokenizer,
                 tokenizer_wrapper: Tokenizer,
                 verbalizer: Verbalizer,
                 max_length: Optional[int] = 20,
                 target_number: Optional[int] = 2,
                 beam_width: Optional[int] = 100,
                 length_limit: Optional[List[int]] = None,
                 forbidden_word_ids: Optional[List[int]] = [],
                 config: CfgNode = None):
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_wrapper = tokenizer_wrapper
        self.verbalizer = verbalizer
        self.target_number = target_number  # number of parts to generate in one sample
        self.beam_width = beam_width
        self.max_length = max_length
        self.length_limit = length_limit
        self.probs_buffer, self.labels_buffer = None, None

        # Forbid single space token, "....", and "..........", and some other tokens based on vocab
        self.forbidden_word_ids = forbidden_word_ids
        self.sent_end_id = self.tokenizer.convert_tokens_to_ids('.')

        self.input_ids_buffer, self.attention_mask_buffer, self.labels_buffer = None, None, None

        self.config = config

    @property
    def device(self):
        r"""
        Return the device of the model.
        """
        if isinstance(self.model, DataParallel):
            return self.model.module.device
        else:
            return self.model.device

    def _register_buffer(self, data):
        if self.input_ids_buffer is None:
            self.input_ids_buffer = data.input_ids.detach()
            self.attention_mask_buffer = data.attention_mask.detach()
            self.labels_buffer = data.label.detach()
        else:
            self.input_ids_buffer = torch.vstack([self.input_ids_buffer, data.input_ids.detach()])
            self.attention_mask_buffer = torch.vstack([self.attention_mask_buffer, data.attention_mask.detach()])
            self.labels_buffer = torch.hstack([self.labels_buffer, data.label.detach()])
    @abstractmethod
    def get_part_token_id(self, part_id: int) -> int:
        r"""
        Get the start token id for the current part. It should be specified according to the specific model type. For the T5 model, for example, the start token for ``part_id=0`` is ``<extra_id_0>``, and this method should return the corresponding token_id.

        Args:
            part_id (:obj:`int`): The current part id (starts with 0).

        Returns:
            token_id (:obj:`int`): The corresponding start token_id.
        """
        raise NotImplementedError
    def convert_template(self, generated_template: List[str], original_template: List[Dict]) -> str:
        r"""
        Given the original template used for template generation, convert the generated template into a standard template for the downstream prompt model, returning a ``str``.

        Example:
            generated_template: ``['<extra_id_0>', 'it', 'is', '<extra_id_1>', 'one', '</s>']``

            original_template: ``[{'add_prefix_space': '', 'placeholder': 'text_a'}, {'add_prefix_space': ' ', 'mask': None}, {'add_prefix_space': ' ', 'meta': 'labelword'}, {'add_prefix_space': ' ', 'mask': None}, {'add_prefix_space': '', 'text': '.'}]``

            return: ``{"placeholder": "text_a"} it is {"mask"} one.``
        """
        i = 0
        part_id = 0
        while generated_template[i] != self.tokenizer.additional_special_tokens[part_id] and i < len(generated_template) - 1:
            i += 1
        assert generated_template[i] == self.tokenizer.additional_special_tokens[part_id], \
            'invalid generated_template {}, missing token {}'.format(generated_template, self.tokenizer.additional_special_tokens[part_id])
        i += 1

        output = []
        for d in original_template:
            if 'mask' in d:
                j = i + 1
                part_id += 1
                while generated_template[j] != self.tokenizer.additional_special_tokens[part_id] and j < len(generated_template) - 1:
                    j += 1
                output.append(d.get('add_prefix_space', '') + self.tokenizer.convert_tokens_to_string(generated_template[i:j]))
                i = j + 1
            elif 'meta' in d and d['meta'] == 'labelword':
                output.append(d.get('add_prefix_space', '') + '{"mask"}')
            elif 'text' in d:
                output.append(d.get('add_prefix_space', '') + d['text'])
            else:
                prefix = d.get('add_prefix_space', '')
                if 'add_prefix_space' in d:
                    d.pop('add_prefix_space')
                output.append(prefix + json.dumps(d))
        return ''.join(output)
    def _get_templates(self):
        inner_model = self.model.module if isinstance(self.model, DataParallel) else self.model
        input_ids = self.input_ids_buffer
        attention_mask = self.attention_mask_buffer

        ori_decoder_input_ids = torch.zeros((input_ids.size(0), self.max_length)).long()
        ori_decoder_input_ids[..., 0] = inner_model.config.decoder_start_token_id

        # decoder_input_ids: decoder inputs for the next auto-regressive generation step
        # ll: log likelihood
        # output_id: which part of generated contents we are at
        # output: generated content so far
        # last_length (deprecated): how long we have generated for this part
        current_output = [{'decoder_input_ids': ori_decoder_input_ids, 'll': 0, 'output_id': 1, 'output': [], 'last_length': -1}]
        for i in tqdm(range(self.max_length - 2)):
            new_current_output = []
            for item in current_output:
                if item['output_id'] > self.target_number:
                    # Enough contents
                    new_current_output.append(item)
                    continue
                decoder_input_ids = item['decoder_input_ids']

                # Forward
                batch_size = 32
                turn = input_ids.size(0) // batch_size
                if input_ids.size(0) % batch_size != 0:
                    turn += 1
                aggr_output = []
                for t in range(turn):
                    start = t * batch_size
                    end = min((t + 1) * batch_size, input_ids.size(0))

                    with torch.no_grad():
                        aggr_output.append(self.model(input_ids[start:end], attention_mask=attention_mask[start:end], decoder_input_ids=decoder_input_ids.to(input_ids.device)[start:end])[0])
                aggr_output = torch.cat(aggr_output, 0)

                # Gather results across all input sentences, and sort generated tokens by log likelihood
                aggr_output = aggr_output.mean(0)
                log_denominator = torch.logsumexp(aggr_output[i], -1).item()
                ids = list(range(inner_model.config.vocab_size))
                ids.sort(key=lambda x: aggr_output[i][x].item(), reverse=True)
                ids = ids[:self.beam_width + 3]

                for word_id in ids:
                    output_id = item['output_id']

                    if word_id == self.get_part_token_id(output_id) or word_id == self.tokenizer.eos_token_id:
                        # Finish one part
                        if self.length_limit is not None and item['last_length'] < self.length_limit[output_id - 1]:
                            check = False
                        else:
                            check = True
                        output_id += 1
                        last_length = 0
                    else:
                        last_length = item['last_length'] + 1
                        check = True

                    output_text = item['output'] + [word_id]
                    ll = item['ll'] + aggr_output[i][word_id] - log_denominator
                    new_decoder_input_ids = decoder_input_ids.new_zeros(decoder_input_ids.size())
                    new_decoder_input_ids[:] = decoder_input_ids
                    new_decoder_input_ids[..., i + 1] = word_id

                    if word_id in self.forbidden_word_ids:
                        check = False

                    # Forbid continuous "."
                    if len(output_text) > 1 and output_text[-2] == self.sent_end_id and output_text[-1] == self.sent_end_id:
                        check = False

                    if check:
                        # Add new results to beam search pool
                        new_item = {'decoder_input_ids': new_decoder_input_ids, 'll': ll, 'output_id': output_id, 'output': output_text, 'last_length': last_length}
                        new_current_output.append(new_item)

            if len(new_current_output) == 0:
                break

            new_current_output.sort(key=lambda x: x['ll'], reverse=True)
            new_current_output = new_current_output[:self.beam_width]
            current_output = new_current_output

        return [self.tokenizer.convert_ids_to_tokens(item['output']) for item in current_output]

    def _show_template(self):
        logger.info("Templates are \n{}".format('\n'.join(self.templates_text)))
    @classmethod
    def from_config(cls, config: CfgNode, **kwargs):
        r"""
        Returns:
            template_generator (:obj:`TemplateGenerator`)
        """
        init_args = signature(cls.__init__).args
        _init_dict = {**convert_cfg_to_dict(config), **kwargs}
        init_dict = {key: _init_dict[key] for key in _init_dict if key in init_args}
        init_dict['config'] = config
        template_generator = cls(**init_dict)
        return template_generator
    def release_memory(self):
        self.model = self.model.cpu()
    def generate(self, dataset: List[InputExample]):
        r"""
        Args:
            dataset (:obj:`List[InputExample]`): The dataset based on which the template is to be generated.

        Returns:
            template_text (:obj:`List[str]`): The generated template text.
        """
        template_for_auto_t = LMBFFTemplateGenerationTemplate.from_config(config=self.config.template, tokenizer=self.tokenizer, verbalizer=self.verbalizer)
        dataloader = PromptDataLoader(dataset, template_for_auto_t, tokenizer=self.tokenizer, tokenizer_wrapper_class=self.tokenizer_wrapper, batch_size=len(dataset), decoder_max_length=128)  # register all data at once
        for data in dataloader:
            data = data.to(self.device)
            self._register_buffer(data)

        self.model.eval()
        with torch.no_grad():
            self.templates_text = self._get_templates()  # List[str]
            original_template = template_for_auto_t.text
            self.templates_text = [self.convert_template(template_text, original_template) for template_text in self.templates_text]
        self._show_template()
        return self.templates_text
class T5TemplateGenerator(TemplateGenerator):
    r"""
    Automatic template search using the T5 model. This class inherits from ``TemplateGenerator``.
    """
    def __init__(self,
                 model: T5ForConditionalGeneration,
                 tokenizer: T5Tokenizer,
                 tokenizer_wrapper: Tokenizer,
                 verbalizer: Verbalizer,
                 max_length: Optional[int] = 20,
                 target_number: Optional[int] = 2,
                 beam_width: Optional[int] = 100,
                 length_limit: Optional[List[int]] = None,
                 forbidden_word_ids: Optional[List[int]] = [3, 19794, 22354],
                 config: CfgNode = None):
        super().__init__(model=model,
                         tokenizer=tokenizer,
                         tokenizer_wrapper=tokenizer_wrapper,
                         verbalizer=verbalizer,
                         max_length=max_length,
                         target_number=target_number,
                         beam_width=beam_width,
                         length_limit=length_limit,
                         forbidden_word_ids=forbidden_word_ids,
                         config=config)
    def get_part_token_id(self, part_id):
        return self.tokenizer.additional_special_tokens_ids[part_id]
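A hedged end-to-end sketch of template search with this class. The checkpoint name, beam width, and config wiring are illustrative assumptions; in the full LM-BFF pipeline this is normally orchestrated by a runner rather than called by hand:

from openprompt.plms import load_plm

plm, tokenizer, model_config, WrapperClass = load_plm('t5', 't5-base')
template_generator = T5TemplateGenerator(
    model=plm,
    tokenizer=tokenizer,
    tokenizer_wrapper=WrapperClass,
    verbalizer=verbalizer,  # e.g. the ManualVerbalizer from the sketch above
    beam_width=5,           # small beam for illustration; the default is 100
    config=config,          # a CfgNode whose `template` node configures the generation template
)
template_texts = template_generator.generate(train_dataset)  # train_dataset: List[InputExample]
template_generator.release_memory()
# template_texts holds the surviving beam entries as template strings, e.g.
# '{"placeholder":"text_a"} It is {"mask"}.'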
class VerbalizerGenerator:
    r"""
    This is the automatic label word search implementation in `LM-BFF <https://arxiv.org/pdf/2012.15723.pdf>`_.

    Args:
        model (:obj:`PreTrainedModel`): A pre-trained model for label word generation.
        tokenizer (:obj:`PreTrainedTokenizer`): The corresponding tokenizer.
        candidate_num (:obj:`Optional[int]`): The number of label word combinations to generate. Validation will then be performed on each combination. Defaults to 100.
        label_word_num_per_class (:obj:`Optional[int]`): The number of candidate label words per class. Defaults to 100.
    """
    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: PreTrainedTokenizer,
                 candidate_num: Optional[int] = 100,
                 label_word_num_per_class: Optional[int] = 100):
        self.model = model
        self.tokenizer = tokenizer
        self.candidate_num = candidate_num
        self.label_word_num_per_class = label_word_num_per_class
        self.probs_buffer, self.labels_buffer = None, None

    def register_buffer(self, data):
        self.model.eval()
        with torch.no_grad():
            inner_model = self.model.module if isinstance(self.model, DataParallel) else self.model
            forward_keys = signature(inner_model.forward).args
            input_batch = {key: data[key] for key in data if key in forward_keys}
            logits = self.model.forward(**input_batch).logits[data['loss_ids'] == 1]
            logits = F.softmax(logits.detach(), dim=-1)
        if self.probs_buffer is None:
            self.probs_buffer = logits
            self.labels_buffer = data.label.detach()
        else:
            self.probs_buffer = torch.vstack([self.probs_buffer, logits])
            self.labels_buffer = torch.hstack([self.labels_buffer, data.label.detach()])
    @abstractmethod
    def post_process(self, word: str):
        r"""
        Post-processing for a generated label word.

        Args:
            word (:obj:`str`): The original word token.

        Returns:
            processed_word (:obj:`str`): The post-processed token.
        """
        inner_model = self.model.module if isinstance(self.model, DataParallel) else self.model
        if isinstance(inner_model, RobertaForMaskedLM):
            return word.lstrip('Ġ')
        elif isinstance(inner_model, BertForMaskedLM):
            return word
        else:
            raise RuntimeError("{} is not supported yet".format(type(inner_model)))  # TODO add more models
    @abstractmethod
    def invalid_label_word(self, word: str):
        r"""
        Decide whether the generated token is a valid label word. Heuristic strategies can be implemented here, e.g. requiring that a label word must be the start token of a word.

        Args:
            word (:obj:`str`): The token.

        Returns:
            is_invalid (:obj:`bool`): ``True`` if it cannot be a label word.
        """
        inner_model = self.model.module if isinstance(self.model, DataParallel) else self.model
        if isinstance(inner_model, RobertaForMaskedLM):
            return (not word.startswith('Ġ'))
        elif isinstance(inner_model, BertForMaskedLM):
            return False
        else:
            raise RuntimeError("{} is not supported yet".format(type(inner_model)))  # TODO
    def _show_verbalizer(self):
        logger.info("Verbalizer is {}".format(self.label_words))

    def _find_verbalizer(self):
        logger.info("Finding verbalizer ...")
        label_words = self._get_top_words()
        label_words = self._get_top_group(candidates=label_words)
        return label_words

    def _eval_group(self, group):
        label_logits = self.probs_buffer[:, torch.tensor(group)]
        preds = torch.argmax(label_logits, axis=-1)
        correct = torch.sum(preds == self.labels_buffer)
        return (correct / len(self.labels_buffer)).item()

    def _get_top_group(self, candidates: List[List[int]]):
        groups = list(itertools.product(*candidates))
        group_scores = list(map(self._eval_group, groups))

        # Take top-n.
        best_idx = np.argsort(-np.array(group_scores))[:self.candidate_num]
        best_groups = [groups[i] for i in best_idx]
        return best_groups

    def _get_top_words(self):
        label_words_ids = []
        for label_id in torch.unique(self.labels_buffer):
            scores = self.probs_buffer[self.labels_buffer == label_id].mean(axis=0).cpu().numpy()
            kept = []
            for i in np.argsort(-scores):
                word = self.tokenizer.convert_ids_to_tokens([i])[0]
                if self.invalid_label_word(word):
                    continue
                kept.append(i)
            label_words_ids.append(kept[:self.label_word_num_per_class])
        return label_words_ids
    @classmethod
    def from_config(cls, config: CfgNode, **kwargs):
        r"""
        Returns:
            verbalizer_generator (:obj:`VerbalizerGenerator`)
        """
        init_args = signature(cls.__init__).args
        _init_dict = {**convert_cfg_to_dict(config), **kwargs}
        init_dict = {key: _init_dict[key] for key in _init_dict if key in init_args}
        verbalizer_generator = cls(**init_dict)
        return verbalizer_generator
    def release_memory(self):
        self.model = self.model.cpu()
    def generate(self):
        r"""
        Generate label words.

        Returns:
            label_words (:obj:`List[List[str]]`): A list of generated label words.
        """
        self.label_words_ids = self._find_verbalizer()
        self.label_words = [[self.post_process(word) for word in self.tokenizer.convert_ids_to_tokens(i)] for i in self.label_words_ids]
        self._show_verbalizer()
        return self.label_words
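To make the candidate scoring above concrete, here is a tiny self-contained sketch of the computation ``_eval_group`` performs, using synthetic tensors in place of the internal buffers:

import torch

# MLM probabilities at the <mask> position for 3 examples over a toy
# 5-token vocabulary, plus their gold labels (2 classes).
probs_buffer = torch.tensor([[0.60, 0.20, 0.10, 0.05, 0.05],
                             [0.50, 0.20, 0.10, 0.10, 0.10],
                             [0.10, 0.10, 0.10, 0.20, 0.50]])
labels_buffer = torch.tensor([0, 0, 1])

group = (0, 4)  # candidate label-word ids: token 0 for class 0, token 4 for class 1
label_logits = probs_buffer[:, torch.tensor(group)]  # shape (3, 2)
preds = torch.argmax(label_logits, dim=-1)           # tensor([0, 0, 1])
accuracy = (preds == labels_buffer).float().mean()   # 1.0: this group fits the data perfectly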
class RobertaVerbalizerGenerator(VerbalizerGenerator):
    def __init__(self,
                 model: RobertaForMaskedLM,
                 tokenizer: RobertaTokenizer,
                 candidate_num: Optional[int] = 100,
                 label_word_num_per_class: Optional[int] = 100):
        super().__init__(model=model,
                         tokenizer=tokenizer,
                         candidate_num=candidate_num,
                         label_word_num_per_class=label_word_num_per_class)
    def invalid_label_word(self, word: str):
        return (not word.startswith('Ġ'))
    def post_process(self, word: str):
        return word.lstrip('Ġ')
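Finally, a hedged usage sketch for the verbalizer search. The checkpoint, ``initial_template``, and dataloader arguments are illustrative assumptions; the probability buffers are filled batch by batch before ``generate`` is called:

from openprompt import PromptDataLoader
from openprompt.plms import load_plm

plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-large')
verbalizer_generator = RobertaVerbalizerGenerator(model=plm,
                                                  tokenizer=tokenizer,
                                                  candidate_num=10,
                                                  label_word_num_per_class=10)

# `initial_template` would be a ManualTemplate with a {"mask"} slot for the label word.
dataloader = PromptDataLoader(train_dataset, initial_template, tokenizer=tokenizer,
                              tokenizer_wrapper_class=WrapperClass)
for data in dataloader:
    verbalizer_generator.register_buffer(data.to(plm.device))

label_words_groups = verbalizer_generator.generate()  # List[List[str]] of candidate groups
verbalizer_generator.release_memory()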