Source code for openprompt.prompt_base

from abc import abstractmethod
import json

from transformers.file_utils import ModelOutput
from openprompt.config import convert_cfg_to_dict

from transformers import PreTrainedModel
from openprompt.utils.utils import signature

from yacs.config import CfgNode
from openprompt.data_utils import InputFeatures, InputExample
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Sequence, Union
from transformers.tokenization_utils import PreTrainedTokenizer

from openprompt.utils.logging import logger
import numpy as np
import torch.nn.functional as F


class Template(nn.Module):
    r'''
    Base class for all the templates.
    Most of the methods are abstract, with some exceptions to hold the common methods for all templates,
    such as ``loss_ids``, ``save``, ``load``.

    Args:
        tokenizer (:obj:`PreTrainedTokenizer`): A tokenizer to appoint the vocabulary and the tokenization strategy.
        placeholder_mapping (:obj:`dict`): A mapping from placeholder tokens to the attributes of the original input text.
    '''

    registered_inputflag_names = ["loss_ids", "shortenable_ids"]

    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 placeholder_mapping: dict = {'<text_a>': 'text_a', '<text_b>': 'text_b'},
                ):
        super().__init__()
        self.tokenizer = tokenizer
        self.placeholder_mapping = placeholder_mapping
        self._in_on_text_set = False

        self.mixed_token_start = "{"
        self.mixed_token_end = "}"
    def get_default_loss_ids(self) -> List[int]:
        '''Get the loss indices for the template using mask.
        e.g. when self.text is ``'{"placeholder": "text_a"}. {"meta": "word"} is {"mask"}.'``,
        output is ``[0, 0, 0, 0, 1, 0]``.

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]:

            - 1 for a masked token.
            - 0 for a sequence token.
        '''
        return [1 if 'mask' in d else 0 for d in self.text]
    def get_default_shortenable_ids(self) -> List[int]:
        """Every template needs shortenable_ids, denoting which part of the template can be truncated to fit
        the language model's ``max_seq_length``. Default: the input text is shortenable, while the template text
        and other special tokens are not shortenable.

        e.g. when self.text is ``'{"placeholder": "text_a"} {"placeholder": "text_b", "shortenable": False} {"meta": "word"} is {"mask"}.'``,
        output is ``[1, 0, 0, 0, 0, 0]``.

        Returns:
            :obj:`List[int]`: A list of integers in the range ``[0, 1]``:

            - 1 for the input tokens.
            - 0 for the template sequence tokens.
        """
        idx = []
        for d in self.text:
            if 'shortenable' in d:
                idx.append(1 if d['shortenable'] else 0)
            else:
                idx.append(1 if 'placeholder' in d else 0)
        return idx
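    # Example (illustrative, not part of the original source): with a template whose
    # parsed ``self.text`` is
    #     [{'placeholder': 'text_a'}, {'text': 'It was'}, {'mask': None}, {'text': '.'}]
    # the two defaults above would give
    #     get_default_loss_ids()        -> [0, 0, 1, 0]   # only the {"mask"} piece is a loss position
    #     get_default_shortenable_ids() -> [1, 0, 0, 0]   # only the input placeholder may be truncated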
    def get_default_soft_token_ids(self) -> List[int]:
        r'''
        This function identifies which tokens are soft tokens.

        Sometimes the tokens in the template are not from the vocabulary,
        but a sequence of soft tokens. In this case, you need to implement this function.

        Raises:
            NotImplementedError: if needed, add ``soft_token_ids`` into the ``registered_inputflag_names`` attribute of the Template class and implement this method.
        '''
        raise NotImplementedError
    def incorporate_text_example(self,
                                 example: InputExample,
                                 text = None,
                                ):
        if text is None:
            text = self.text.copy()
        else:
            text = text.copy()

        for i, d in enumerate(text):
            if 'placeholder' in d:
                text[i] = d["add_prefix_space"] + d.get("post_processing", lambda x: x)(getattr(example, d['placeholder']))
            elif 'meta' in d:
                text[i] = d["add_prefix_space"] + d.get("post_processing", lambda x: x)(example.meta[d['meta']])
            elif 'soft' in d:
                text[i] = ''  # unused
            elif 'mask' in d:
                text[i] = '<mask>'
            elif 'special' in d:
                text[i] = d['special']
            elif 'text' in d:
                text[i] = d["add_prefix_space"] + d['text']
            else:
                raise ValueError(f'can not parse {d}')
        return text
    def _check_template_format(self, ):
        r"""Check whether the template format is correct.
        TODO: add more checks
        """
        mask_num = 0
        for i, d in enumerate(self.text):
            if 'mask' in d:
                mask_num += 1

        if mask_num == 0:
            raise RuntimeError(f"'mask' position not found in the template: {self.text}. Please check!")
    def parse_text(self, text: str) -> List[Dict]:
        parsed = []
        i = 0
        while i < len(text):
            d = {"add_prefix_space": ' ' if (i > 0 and text[i-1] == ' ') else ''}
            while i < len(text) and text[i] == ' ':
                d["add_prefix_space"] = ' '
                i = i + 1
            if i == len(text):
                break

            if text[i] != self.mixed_token_start:
                j = i + 1
                while j < len(text):
                    if text[j] == self.mixed_token_start:
                        break
                    j = j + 1
                d["text"] = text[i:j].rstrip(' ')
                i = j

            else:
                j = i + 1
                mixed_token_cnt = 1 # { {} {} } nested support
                while j < len(text):
                    if text[j] == self.mixed_token_end:
                        mixed_token_cnt -= 1
                        if mixed_token_cnt == 0:
                            break
                    elif text[j] == self.mixed_token_start:
                        mixed_token_cnt += 1
                    j = j + 1
                if j == len(text):
                    raise ValueError(f"mixed_token_start {self.mixed_token_start} at position {i} has no corresponding mixed_token_end {self.mixed_token_end}")
                dict_str = '{' + text[i+1:j] + '}'
                try:
                    val = eval(dict_str)
                    if isinstance(val, set):
                        val = {k: None for k in val}
                    d.update(val)
                except:
                    import traceback
                    print(traceback.format_exc())
                    print(f"syntax error in {dict_str}")
                    exit()
                i = j + 1

            parsed.append(d)

        return parsed
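    # Example (illustrative, not part of the original source): a minimal sketch of how
    # ``parse_text`` splits a mixed-token template string. ``parse_text`` never touches
    # the tokenizer, so ``None`` is passed here purely for brevity.
    #
    #   >>> t = Template(tokenizer=None)
    #   >>> t.parse_text('{"placeholder": "text_a"} It was {"mask"} .')
    #   [{'add_prefix_space': '', 'placeholder': 'text_a'},
    #    {'add_prefix_space': ' ', 'text': 'It was'},
    #    {'add_prefix_space': ' ', 'mask': None},
    #    {'add_prefix_space': ' ', 'text': '.'}]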
    # @abstractmethod
    def wrap_one_example(self,
                         example: InputExample) -> List[Dict]:
        r'''Given an input example which contains input text, which can be referenced
        by the values of self.template.placeholder_mapping.
        This function processes the example into a list of dicts.
        Each dict functions as a group, which carries the same properties, such as
        whether it's shortenable, whether it's the masked position, whether it's a soft token, etc.
        Since a text will be tokenized in the subsequent processing procedure,
        these attributes are broadcasted along the tokenized sentence.

        Args:
            example (:obj:`InputExample`): An :py:class:`~openprompt.data_utils.data_utils.InputExample` object, which should have attributes that are able to be filled in the template.

        Returns:
            :obj:`List[Dict]`: A list of dicts of the same length as self.text, e.g. ``[{"loss_ids": 0, "text": "It was"}, {"loss_ids": 1, "text": "<mask>"}, ]``
        '''

        if self.text is None:
            raise ValueError("template text has not been initialized")
        if isinstance(example, InputExample):
            text = self.incorporate_text_example(example)

            not_empty_keys = example.keys()
            for placeholder_token in self.placeholder_mapping:
                not_empty_keys.remove(self.placeholder_mapping[placeholder_token]) # placeholder has been processed, remove
            not_empty_keys.remove('meta') # meta has been processed

            keys, values = ['text'], [text]
            for inputflag_name in self.registered_inputflag_names:
                keys.append(inputflag_name)
                v = None
                if hasattr(self, inputflag_name) and getattr(self, inputflag_name) is not None:
                    v = getattr(self, inputflag_name)
                elif hasattr(self, "get_default_" + inputflag_name):
                    v = getattr(self, "get_default_" + inputflag_name)()
                    setattr(self, inputflag_name, v) # cache
                else:
                    raise ValueError("""
                    Template's inputflag '{}' is registered but not initialized.
                    Try using template.{} = [...] to initialize,
                    or create a method get_default_{}(self) in your template.
                    """.format(inputflag_name, inputflag_name, inputflag_name))

                if len(v) != len(text):
                    raise ValueError("Template: len({})={} doesn't match len(text)={}."
                                     .format(inputflag_name, len(v), len(text)))
                values.append(v)

            wrapped_parts_to_tokenize = []
            for piece in list(zip(*values)):
                wrapped_parts_to_tokenize.append(dict(zip(keys, piece)))

            wrapped_parts_not_tokenize = {key: getattr(example, key) for key in not_empty_keys}
            return [wrapped_parts_to_tokenize, wrapped_parts_not_tokenize]
        else:
            raise TypeError(f"Expected an InputExample, but got {type(example)}")
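    # Example (illustrative, not part of the original source): a sketch of the wrapped
    # output, assuming a ManualTemplate-style subclass with the template text
    # '{"placeholder": "text_a"} It was {"mask"} .' and an InputExample whose text_a is
    # "a pleasure to watch." -- the first returned element is roughly
    #     [{'text': 'a pleasure to watch.', 'loss_ids': 0, 'shortenable_ids': 1},
    #      {'text': ' It was',              'loss_ids': 0, 'shortenable_ids': 0},
    #      {'text': '<mask>',               'loss_ids': 1, 'shortenable_ids': 0},
    #      {'text': ' .',                   'loss_ids': 0, 'shortenable_ids': 0}]
    # and the second element collects the not-yet-tokenized attributes of the example
    # (e.g. guid and label).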
    @abstractmethod
    def process_batch(self, batch):
        r"""Template should rewrite this method if you need to process the batch input,
        such as substituting embeddings.
        """
        return batch # not being processed
    def post_processing_outputs(self, outputs):
        r"""Post process the outputs of the language model according to the needs of the template.
        Most templates don't need post processing.
        A template like SoftTemplate, which appends the soft template as a module (rather than a sequence
        of input tokens) to the input, should remove the outputs at these positions to keep ``seq_len`` the same.
        """
        return outputs
    def save(self,
             path: str,
             **kwargs) -> None:
        r'''
        A save method API.

        Args:
            path (str): A path to save your template.
        '''
        raise NotImplementedError
    @property
    def text(self):
        return self._text

    @text.setter
    def text(self, text):
        self._text = text
        if text is None:
            return
        if not self._in_on_text_set:
            self.safe_on_text_set()
        # else:
        #     logger.warning("Reset text in on_text_set function. Is this intended?")
        self._check_template_format()
    def safe_on_text_set(self) -> None:
        r"""With this wrapper function, setting text inside ``on_text_set()``
        will not trigger ``on_text_set()`` again, to prevent endless recursion.
        """
        self._in_on_text_set = True
        self.on_text_set()
        self._in_on_text_set = False
    @abstractmethod
    def on_text_set(self):
        r"""
        A hook to do something when the template text is set.
        The designer of the template should explicitly know what should be done when the template text is set.
        """
        raise NotImplementedError
    def from_file(self,
                  path: str,
                  choice: int = 0,
                 ):
        r'''
        Read the template from a local file.

        Args:
            path (:obj:`str`): The path of the local template file.
            choice (:obj:`int`): The index of the line in the file to use as the template.
        '''
        with open(path, 'r') as fin:
            text = fin.readlines()[choice].rstrip()
            logger.info(f"using template: {text}")
        self.text = text
        return self
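    # Example (illustrative, not part of the original source): a template file simply
    # holds one template string per line, and ``choice`` selects the line index.
    # For a hypothetical ``manual_template.txt`` containing
    #     {"placeholder": "text_a"} It was {"mask"} .
    #     {"placeholder": "text_a"} All in all, it was {"mask"} .
    # ``template.from_file("manual_template.txt", choice=1)`` loads the second line.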
    @classmethod
    def from_config(cls,
                    config: CfgNode,
                    **kwargs):
        r"""Load a template from the template's configuration node.

        Args:
            config (:obj:`CfgNode`): the sub-configuration of the template, i.e. ``config[config.template]``
                        if config is a global config node.
            kwargs: Other kwargs that might be used when initializing the template.
                        The actual values should match the arguments of the ``__init__`` function.
        """

        init_args = signature(cls.__init__).args
        _init_dict = {**convert_cfg_to_dict(config), **kwargs}
        init_dict = {key: _init_dict[key] for key in _init_dict if key in init_args}
        template = cls(**init_dict)
        if hasattr(template, "from_file"):
            if not hasattr(config, "file_path"):
                pass
            else:
                if (not hasattr(config, "text") or config.text is None) and config.file_path is not None:
                    if config.choice is None:
                        config.choice = 0
                    template.from_file(config.file_path, config.choice)
                elif (hasattr(config, "text") and config.text is not None) and config.file_path is not None:
                    raise RuntimeError("The text can't be set from both `text` and `file_path`.")
        return template
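    # Example (illustrative, not part of the original source): a minimal sketch of the
    # configuration-driven path above; the file name is hypothetical, ``ManualTemplate``
    # is assumed to be a subclass that accepts ``text`` and implements ``from_file``,
    # and ``tokenizer`` is any pre-loaded PreTrainedTokenizer.
    #
    #   >>> from yacs.config import CfgNode
    #   >>> cfg = CfgNode()
    #   >>> cfg.text = None
    #   >>> cfg.file_path = "manual_template.txt"
    #   >>> cfg.choice = 0
    #   >>> template = ManualTemplate.from_config(cfg, tokenizer=tokenizer)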
class Verbalizer(nn.Module):
    r'''
    Base class for all the verbalizers.

    Args:
        tokenizer (:obj:`PreTrainedTokenizer`): A tokenizer to appoint the vocabulary and the tokenization strategy.
        classes (:obj:`Sequence[str]`): A sequence of classes that need to be projected.
        num_classes (:obj:`int`): Optionally, the number of classes can be given directly instead of ``classes``.
    '''
    def __init__(self,
                 tokenizer: Optional[PreTrainedTokenizer] = None,
                 classes: Optional[Sequence[str]] = None,
                 num_classes: Optional[int] = None,
                ):
        super().__init__()
        self.tokenizer = tokenizer
        self.classes = classes
        if classes is not None and num_classes is not None:
            assert len(classes) == num_classes, "len(classes) != num_classes, check your config."
            self.num_classes = num_classes
        elif num_classes is not None:
            self.num_classes = num_classes
        elif classes is not None:
            self.num_classes = len(classes)
        else:
            self.num_classes = None
            # raise AttributeError("Unable to configure num_classes")
        self._in_on_label_words_set = False

    @property
    def label_words(self,):
        r'''
        Label words means the words in the vocabulary projected by the labels.
        E.g. if we want to establish a projection in sentiment classification: positive :math:`\rightarrow` {`wonderful`, `good`},
        in this case, `wonderful` and `good` are label words.
        '''
        if not hasattr(self, "_label_words"):
            raise RuntimeError("label words haven't been set.")
        return self._label_words

    @label_words.setter
    def label_words(self, label_words):
        if label_words is None:
            return
        self._label_words = self._match_label_words_to_label_ids(label_words)
        if not self._in_on_label_words_set:
            self.safe_on_label_words_set()
        # else:
        #     logger.warning("Reset label words in on_label_words_set function. Is this intended?")

    def _match_label_words_to_label_ids(self, label_words):
        # TODO: newly added function after docs were written
        # TODO: rename this function
        """
        Sort the label words dict of the verbalizer to match the label order of the classes.
        """
        if isinstance(label_words, dict):
            if self.classes is None:
                raise ValueError("""
                The classes attribute of the Verbalizer must be set when the given label words is a dict,
                since each class's label words are matched to that class's index in classes.
                """)
            if set(label_words.keys()) != set(self.classes):
                raise ValueError("The names of the classes in the verbalizer are different from those of the dataset.")
            label_words = [ # sort the dict to match the dataset
                label_words[c]
                for c in self.classes
            ] # length: label_size of the whole task
        elif isinstance(label_words, list) or isinstance(label_words, tuple):
            pass
            # logger.info("""
            # Your given label words is a list. By default, the ith label word in the list will match class i of the dataset.
            # Please make sure that they have the same order.
            # Or you can pass label words as a dict, mapping from class names to label words.
            # """)
        else:
            raise ValueError("Verbalizer label words must be list, tuple or dict")
        return label_words
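    # Example (illustrative, not part of the original source): when label words are given
    # as a dict, they are re-ordered to follow ``self.classes``.
    #
    #   >>> v = Verbalizer(tokenizer=None, classes=["negative", "positive"])
    #   >>> v.label_words = {"positive": ["wonderful", "good"], "negative": ["bad", "terrible"]}
    #   >>> v.label_words
    #   [['bad', 'terrible'], ['wonderful', 'good']]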
    def safe_on_label_words_set(self,):
        self._in_on_label_words_set = True
        self.on_label_words_set()
        self._in_on_label_words_set = False
    def on_label_words_set(self,):
        r"""A hook to do something when the textual label words are set.
        """
        pass
    @property
    def vocab(self,) -> Dict:
        if not hasattr(self, '_vocab'):
            self._vocab = self.tokenizer.convert_ids_to_tokens(np.arange(self.vocab_size).tolist())
        return self._vocab

    @property
    def vocab_size(self,) -> int:
        return self.tokenizer.vocab_size
    @abstractmethod
    def generate_parameters(self, **kwargs) -> List:
        r"""
        The verbalizer can be seen as an extra layer on top of the original pre-trained model.
        In a manual verbalizer, it is a fixed one-hot vector of dimension ``vocab_size``,
        with the position of the label word being 1 and 0 everywhere else.
        In other situations, the parameters may be a continuous vector over the vocab,
        with each dimension representing a weight of that token.
        Moreover, the parameters may be set to trainable to allow label word selection.

        Therefore, this function serves as an abstract method for generating the parameters
        of the verbalizer, and must be instantiated in any derived class.

        Note that the parameters need to be registered as a part of the pytorch module so that
        they can be trained and saved; this can be achieved by wrapping a tensor using ``nn.Parameter()``.
        """
        raise NotImplementedError
    def register_calibrate_logits(self, logits: torch.Tensor):
        r"""
        This function aims to register logits that need to be calibrated, and detach the original logits from the current graph.
        """
        if logits.requires_grad:
            logits = logits.detach()
        self._calibrate_logits = logits
    def process_outputs(self,
                        outputs: torch.Tensor,
                        batch: Union[Dict, InputFeatures],
                        **kwargs):
        r"""By default, the verbalizer will process the logits of the PLM's output.

        Args:
            outputs (:obj:`torch.Tensor`): The current logits generated by the pre-trained language model.
            batch (:obj:`Union[Dict, InputFeatures]`): The input features of the data.
        """
        return self.process_logits(outputs, batch=batch, **kwargs)
    def gather_outputs(self, outputs: ModelOutput):
        r""" Retrieve the useful output for the verbalizer from the whole model output.
        By default, it will only retrieve the logits.

        Args:
            outputs (:obj:`ModelOutput`): The output from the pretrained language model.

        Return:
            :obj:`torch.Tensor`: The gathered output, should be of shape (``batch_size``, ``seq_len``, ``any``)
        """
        return outputs.logits
    @staticmethod
    def aggregate(label_words_logits: torch.Tensor) -> torch.Tensor:
        r""" Aggregate the logits of multiple label words into the label's logits.
        Basic aggregator: the mean of each label word's logits becomes the label's logits.
        Can be re-implemented in advanced verbalizers.

        Args:
            label_words_logits (:obj:`torch.Tensor`): The logits of the label words only.

        Return:
            :obj:`torch.Tensor`: The final logits calculated by the label words.
        """
        if label_words_logits.dim() > 2:
            return label_words_logits.mean(dim=-1)
        else:
            return label_words_logits
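    # Example (illustrative, not part of the original source): ``aggregate`` averages over
    # the last (label-word) dimension only when it is present.
    #
    #   >>> x = torch.tensor([[[1.0, 3.0], [2.0, 2.0]]])   # (batch=1, classes=2, words=2)
    #   >>> Verbalizer.aggregate(x)
    #   tensor([[2., 2.]])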
    def normalize(self, logits: torch.Tensor) -> torch.Tensor:
        r"""
        Given logits regarding the entire vocab, calculate the probs over the label words set by softmax.

        Args:
            logits (:obj:`Tensor`): The logits of the entire vocab.

        Returns:
            :obj:`Tensor`: The probability distribution over the label words set.
        """
        batch_size = logits.shape[0]
        return F.softmax(logits.reshape(batch_size, -1), dim=-1).reshape(*logits.shape)
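    # Example (illustrative, not part of the original source): ``normalize`` softmaxes over
    # everything except the batch dimension, so for per-class logits of shape (batch, num_classes):
    #
    #   >>> Verbalizer().normalize(torch.tensor([[0.0, 2.0]]))
    #   tensor([[0.1192, 0.8808]])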
    @abstractmethod
    def project(self,
                logits: torch.Tensor,
                **kwargs) -> torch.Tensor:
        r"""This method receives input logits of shape ``[batch_size, vocab_size]``, and uses the
        parameters of this verbalizer to project the logits over the entire vocab into the logits of label words.

        Args:
            logits (:obj:`Tensor`): The logits over the entire vocab generated by the pre-trained language model with shape [``batch_size``, ``max_seq_length``, ``vocab_size``]

        Returns:
            :obj:`Tensor`: The normalized probs (sum to 1) of each label.
        """
        raise NotImplementedError
    def handle_multi_token(self, label_words_logits, mask):
        r"""
        Support multiple methods to handle the multiple sub-tokens produced by the tokenizer.
        We suggest using 'first' or 'max' if some parts of the tokenization are not meaningful.
        Can broadcast to 3-d tensor.

        Args:
            label_words_logits (:obj:`torch.Tensor`):

        Returns:
            :obj:`torch.Tensor`
        """
        if self.multi_token_handler == "first":
            label_words_logits = label_words_logits.select(dim=-1, index=0)
        elif self.multi_token_handler == "max":
            label_words_logits = label_words_logits - 1000*(1 - mask.unsqueeze(0))
            label_words_logits = label_words_logits.max(dim=-1).values
        elif self.multi_token_handler == "mean":
            label_words_logits = (label_words_logits * mask.unsqueeze(0)).sum(dim=-1) / (mask.unsqueeze(0).sum(dim=-1) + 1e-15)
        else:
            raise ValueError("multi_token_handler {} not configured".format(self.multi_token_handler))
        return label_words_logits
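    # Example (illustrative, not part of the original source): for a label word tokenized
    # into two sub-tokens with logits [2.0, 4.0] and mask [1, 1],
    #     'first' keeps 2.0, 'max' keeps 4.0, and 'mean' yields 3.0;
    # if the second sub-token were padding (mask [1, 0]), 'max' and 'mean' would both
    # reduce to 2.0, while 'first' is unchanged.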
    @classmethod
    def from_config(cls,
                    config: CfgNode,
                    **kwargs):
        r"""Load a verbalizer from the verbalizer's configuration node.

        Args:
            config (:obj:`CfgNode`): the sub-configuration of the verbalizer, i.e. ``config[config.verbalizer]``
                        if config is a global config node.
            kwargs: Other kwargs that might be used when initializing the verbalizer.
                        The actual values should match the arguments of the ``__init__`` function.
        """

        init_args = signature(cls.__init__).args
        _init_dict = {**convert_cfg_to_dict(config), **kwargs} if config is not None else kwargs
        init_dict = {key: _init_dict[key] for key in _init_dict if key in init_args}
        verbalizer = cls(**init_dict)
        if hasattr(verbalizer, "from_file"):
            if not hasattr(config, "file_path"):
                pass
            else:
                if (not hasattr(config, "label_words") or config.label_words is None) and config.file_path is not None:
                    if config.choice is None:
                        config.choice = 0
                    verbalizer.from_file(config.file_path, config.choice)
                elif (hasattr(config, "label_words") and config.label_words is not None) and config.file_path is not None:
                    raise RuntimeError("The label words can't be set from both `label_words` and `file_path`.")
        return verbalizer
    def from_file(self,
                  path: str,
                  choice: Optional[int] = 0,
                 ):
        r"""Load the predefined label words from a verbalizer file.
        Currently supports three types of file format:
        1. a .jsonl or .json file, in which is a single verbalizer in dict format.
        2. a .jsonl or .json file, in which is a list of verbalizers in dict format.
        3. a .txt or a .csv file, in which the label words of a class are listed on one line, separated by commas. Begin a new verbalizer with an empty line.
        This format is recommended when you don't know the name of each class.

        The details of the verbalizer format can be seen in :ref:`How_to_write_a_verbalizer`.

        Args:
            path (:obj:`str`): The path of the local verbalizer file.
            choice (:obj:`int`): The choice of verbalizer in a file containing multiple verbalizers.

        Returns:
            Verbalizer : `self` object
        """
        if path.endswith(".txt") or path.endswith(".csv"):
            with open(path, 'r') as f:
                lines = f.readlines()
                label_words_all = []
                label_words_single_group = []
                for line in lines:
                    line = line.strip().strip(" ")
                    if line == "":
                        if len(label_words_single_group) > 0:
                            label_words_all.append(label_words_single_group)
                        label_words_single_group = []
                    else:
                        label_words_single_group.append(line)
                if len(label_words_single_group) > 0: # if there is no empty line at the end of the file
                    label_words_all.append(label_words_single_group)
                if choice >= len(label_words_all):
                    raise RuntimeError("choice {} exceeds the number of verbalizers {}"
                                       .format(choice, len(label_words_all)))

                label_words = label_words_all[choice]
                label_words = [label_words_per_label.strip().split(",")
                               for label_words_per_label in label_words]

        elif path.endswith(".jsonl") or path.endswith(".json"):
            with open(path, "r") as f:
                label_words_all = json.load(f)
                # if it is a file containing multiple verbalizers
                if isinstance(label_words_all, list):
                    if choice >= len(label_words_all):
                        raise RuntimeError("choice {} exceeds the number of verbalizers {}"
                                           .format(choice, len(label_words_all)))
                    label_words = label_words_all[choice]
                elif isinstance(label_words_all, dict):
                    label_words = label_words_all
                    if choice > 0:
                        logger.warning("Choice of verbalizer is {}, but the file only contains one verbalizer.".format(choice))

        self.label_words = label_words
        if self.num_classes is not None:
            num_classes = len(self.label_words)
            assert num_classes == self.num_classes, 'number of classes in the verbalizer file does not match the predefined num_classes.'
        return self
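    # Example (illustrative, not part of the original source): a hypothetical
    # ``sentiment_verbalizer.txt`` containing two verbalizers separated by an empty line,
    #
    #     bad,terrible
    #     good,wonderful
    #
    #     awful,poor
    #     great,excellent
    #
    # ``verbalizer.from_file("sentiment_verbalizer.txt", choice=1)`` loads the second
    # verbalizer, i.e. ``label_words`` becomes [['awful', 'poor'], ['great', 'excellent']].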