Source code for openprompt.prompts.generation_verbalizer

from functools import partial
import json
from openprompt.data_utils.utils import InputExample
from transformers.tokenization_utils import PreTrainedTokenizer
from yacs.config import CfgNode
from openprompt.data_utils import InputFeatures
import re
from openprompt import Verbalizer
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from openprompt.utils.logging import logger
from openprompt.utils.crossfit_metrics import *


class GenerationVerbalizer(Verbalizer):
    r"""
    This verbalizer is useful when the label prediction is better defined by a piece of the input.
    For example, in coreference resolution, the tgt_text is a proper noun mentioned in the text.
    There is no fixed mapping between a class label and its label words. This verbalizer
    can be used as the verbalizer of the ``COPA`` and ``WSC`` datasets in SuperGLUE.

    This verbalizer is especially powerful when combined with the
    `All NLP Tasks Are Generation Tasks <https://arxiv.org/abs/2103.10360>`_ paradigm (also see
    `Crossfit <https://arxiv.org/abs/2104.08835>`_). It can make any piece of text the tgt_text,
    which is then filled into the ``{'mask'}`` slot of the template.

    For example, when the label word is ``"good"``, the tgt_text is ``"good"``;
    when the label word is ``{"text": "good"}``, the tgt_text is also ``"good"``;
    when the label word is ``{"meta": "choice1"}``, the tgt_text is the ``meta['choice1']`` field of the ``InputExample``;
    when the label word is ``{"meta": "choice1"} {"placeholder": "text_a"} .``, the tgt_text is the ``meta['choice1']``
    field of the ``InputExample``, followed by the ``text_a`` field of the ``InputExample``, and then a ``'.'``.

    A use case can be seen in `Tutorial 4.1
    <https://github.com/thunlp/OpenPrompt/blob/main/tutorial/4.1_all_tasks_are_generation.py>`_.

    Args:
        tokenizer (:obj:`PreTrainedTokenizer`): The tokenizer of the current pre-trained model, used to point out the vocabulary.
        classes (:obj:`List[Any]`): The classes (or labels) of the current task.
        num_classes (:obj:`int`, optional): The number of classes.
        is_rule (:obj:`bool`, optional): Whether the label words use the rule syntax of :obj:`MixedTemplate`.
        label_words (:obj:`dict`, optional): The label words of the generation verbalizer.

    Example:
        To use this verbalizer to train a T5 model to predict both an answer and an explanation
        through two masks, with a template (defined by :obj:`MixedTemplate`) such as:

        >>> input_example = InputExample(text_a="Can fish run?", meta={"answer": "no", "explanation": "The fish have no legs"}, label=0)
        >>> template = "{'placeholder':'text_a'} answer: {'mask'} explanation: {'mask'}"

        the verbalizer can be:

        >>> label_words = {0: ["no", "{'meta':'explanation'}"], 1: ["yes", "{'meta':'explanation'}"]}
        >>> verbalizer = GenerationVerbalizer(tokenizer, classes=None, is_rule=True, label_words=label_words)
    """
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 classes: Optional[List[str]] = None,
                 num_classes: Optional[int] = None,
                 is_rule: Optional[bool] = False,
                 label_words: Optional[dict] = None,
                ):
        if classes is None and label_words is not None:
            classes = list(label_words.keys())  # infer the classes from the label-word keys
        super().__init__(tokenizer=tokenizer, num_classes=num_classes, classes=classes)
        self.prefix = ''
        self.is_rule = is_rule
        self.mixed_token_start = "{"
        self.mixed_token_end = "}"
        if label_words is not None:  # use label words as an initialization
            self.label_words = label_words
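    # A minimal construction sketch, assuming a ``tokenizer`` loaded elsewhere
    # (e.g. via openprompt.plms.load_plm). When ``classes`` is omitted, it is
    # inferred from the keys of ``label_words``:
    #
    #     >>> verbalizer = GenerationVerbalizer(tokenizer, is_rule=False,
    #     ...                                   label_words={0: "no", 1: "yes"})
    #     >>> verbalizer.classes
    #     [0, 1]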
    def wrap_one_example(self, example: InputExample) -> InputExample:
        r"""Take an InputExample and fill its ``tgt_text`` with the label words.
        """
        if not isinstance(self.label_words[example.label], list):
            label_word = [self.label_words[example.label]]
        else:
            label_word = self.label_words[example.label]

        if example.tgt_text is not None:
            logger.warning(f"The example already has tgt_text {example.tgt_text}; "
                           "it will not be overwritten with new label words. Is this intended?")
        else:
            if not self.is_rule:
                instance_label_word = label_word
            else:
                # each rule has been compiled into a callable that renders the
                # parsed label word against this example
                instance_label_word = [i(example) for i in label_word]
            if len(instance_label_word) == 1:
                example.tgt_text = instance_label_word[0]
            else:
                example.tgt_text = instance_label_word
        return example
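    # A sketch of the expected behaviour with plain-string label words,
    # assuming the ``verbalizer`` built above: the label word of the example's
    # class is written into ``tgt_text``; a list of label words would instead
    # yield a list of target strings, one per ``{'mask'}`` in the template.
    #
    #     >>> example = InputExample(text_a="Can fish run?", label=0)
    #     >>> verbalizer.wrap_one_example(example).tgt_text
    #     'no'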
    def on_label_words_set(self):
        r"""Process the text into the label words (sometimes a function),
        according to the syntax of :obj:`MixedTemplate`.
        """
        if isinstance(self.label_words[0], list):
            self.label_words = [x[0] for x in self.label_words]

        if self.is_rule:
            for idx, label_word in enumerate(self.label_words):
                try:
                    d = self.parse_text(label_word)
                except Exception:
                    raise RuntimeError(f"is_rule={self.is_rule} but label_word: {label_word} can't be converted to an object.")
                # bind the parsed pieces; the label word becomes a callable
                # InputExample -> str
                self.label_words[idx] = partial(lambda x, text: self.incorporate_text_example(text, x), text=d)
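    # With ``is_rule=True``, each label word is compiled once into such a
    # callable. A sketch, assuming the same ``tokenizer`` as above and that the
    # base Verbalizer maps the dict onto the class list in the usual way:
    #
    #     >>> verb = GenerationVerbalizer(tokenizer, is_rule=True,
    #     ...                             label_words={0: "{'meta':'explanation'}",
    #     ...                                          1: "{'meta':'explanation'}"})
    #     >>> verb.label_words[0](InputExample(meta={"explanation": "The fish have no legs"}))
    #     'The fish have no legs'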
    def parse_text(self, text: str) -> List[Dict]:
        r"""Split a rule string into a list of dicts, one per plain-text
        segment or ``{...}`` token, following the MixedTemplate syntax."""
        parsed = []
        i = 0
        while i < len(text):
            d = {"add_prefix_space": ' ' if (i > 0 and text[i-1] == ' ') else ''}
            while i < len(text) and text[i] == ' ':
                d["add_prefix_space"] = ''
                i = i + 1
            if i == len(text):
                break

            if text[i] != self.mixed_token_start:
                # a plain-text segment: consume up to the next token start
                j = i + 1
                while j < len(text):
                    if text[j] == self.mixed_token_start:
                        break
                    j = j + 1
                d["text"] = text[i:j].rstrip(' ')
                i = j
            else:
                # a {...} token: consume up to the matching token end
                j = i + 1
                while j < len(text):
                    if text[j] == self.mixed_token_end:
                        break
                    j = j + 1
                if j == len(text):
                    raise ValueError(f"mixed_token_start {self.mixed_token_start} at position {i} has no corresponding mixed_token_end {self.mixed_token_end}")
                dict_str = '{' + text[i+1:j] + '}'
                try:
                    val = eval(dict_str)
                    if isinstance(val, set):
                        val = {k: None for k in val}
                    d.update(val)
                except Exception:
                    import traceback
                    print(traceback.format_exc())
                    print(f"syntax error in {dict_str}")
                    exit()
                i = j + 1

            parsed.append(d)

        return parsed

    def incorporate_text_example(self, text: List[Dict], example: InputExample) -> str:
        r"""Render a parsed label word (the output of :meth:`parse_text`)
        against an :obj:`InputExample`."""
        text = text.copy()
        for i, d in enumerate(text):
            if 'placeholder' in d:
                text[i] = d["add_prefix_space"] + d.get("post_processing", lambda x: x)(getattr(example, d['placeholder']))
            elif 'meta' in d:
                text[i] = d["add_prefix_space"] + d.get("post_processing", lambda x: x)(example.meta[d['meta']])
            elif 'soft' in d:
                raise RuntimeError("soft token not supported in verbalizer")  # unused
            elif 'mask' in d:
                raise RuntimeError("mask token not supported in verbalizer")
            elif 'special' in d:
                raise RuntimeError("special token not supported in verbalizer")
            elif 'text' in d:
                text[i] = d["add_prefix_space"] + d['text']
            else:
                raise ValueError(f"cannot parse {d}")
        text = " ".join(text)
        return text
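# A sketch of the two helpers above on a rule string, with values traced from
# the code and assuming an ``example`` whose
# ``meta == {'explanation': 'The fish have no legs'}``:
#
#     >>> verb.parse_text("{'meta':'explanation'} .")
#     [{'add_prefix_space': '', 'meta': 'explanation'}, {'add_prefix_space': '', 'text': '.'}]
#     >>> verb.incorporate_text_example(verb.parse_text("{'meta':'explanation'} ."), example)
#     'The fish have no legs .'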