Source code for openprompt.prompts.mixed_template


import os
import string
from openprompt.utils.logging import logger

from openprompt.data_utils.utils import InputExample, InputFeatures
from typing import *

from transformers import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from openprompt import Template

import torch
from torch import nn

class MixedTemplate(Template):
    r"""The Mixed Template class, defined by a string of ``text``. See more examples in the
    `tutorial <https://github.com/thunlp/OpenPrompt/blob/ca27491101df0108a8dd753e5b1e79bf591f65d3/tutorial/1.1_mixed_template.py>`_.

    Args:
        model (:obj:`PreTrainedModel`): The pre-trained language model for the current prompt-learning task.
        tokenizer (:obj:`PreTrainedTokenizer`): A tokenizer that specifies the vocabulary and the tokenization strategy.
        text (:obj:`Optional[str]`, optional): The manual template format. Defaults to None.
    """

    registered_inputflag_names = ["soft_token_ids", "loss_ids", "shortenable_ids"]

    def __init__(self,
                 model: PreTrainedModel,
                 tokenizer: PreTrainedTokenizer,
                 text: Optional[str] = None,
                 placeholder_mapping: dict = {'<text_a>': 'text_a', '<text_b>': 'text_b'},
                ):
        super().__init__(tokenizer=tokenizer, placeholder_mapping=placeholder_mapping)
        self.raw_embedding = model.get_input_embeddings()
        self.embedding_size = self.raw_embedding.weight.shape[-1]
        self.text = text
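
    # A minimal construction sketch (assumed setup: ``plm`` and ``tokenizer`` as
    # returned by ``openprompt.plms.load_plm``; the template string is only illustrative):
    #
    #   template = MixedTemplate(
    #       model=plm,
    #       tokenizer=tokenizer,
    #       text='{"placeholder": "text_a"} {"soft": "It was"} {"soft": None, "duplicate": 3} {"mask"} .',
    #   )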

    def get_default_soft_token_ids(self) -> List[int]:
        return self.soft_token_ids

    def prepare(self):
        r"""Get the soft token indices (``soft_token_ids``) for the template.

        ``"soft_id"`` can be used to reference a previous soft token, meaning that those
        positions share the same embedding. **Note that ``"soft_id"`` indices start from 1, not 0.**

        e.g. when self.text is ``'{"soft": None} {"soft": "the", "soft_id": 1} {"soft": None} {"soft": "it", "soft_id": 3} {"soft_id": 1} {"soft": "was"} {"mask"}'``,
        the output is [1, 2, 3, 4, 2, 5, 0].
        """
        num_soft_token = 0
        text = []
        soft_token_ids = []
        idx_mp = {}
        emb_mp = {}
        for d in self.text:
            if "soft" not in d and "soft_id" not in d:
                text.append(d)
                soft_token_ids.append(0)
                continue

            old_num = num_soft_token

            if "soft_id" in d:
                if not isinstance(d["soft_id"], int) or d["soft_id"] <= 0:
                    raise ValueError(f'soft_id should be an integer greater than zero, but got {d["soft_id"]}')
                if d["soft_id"] in idx_mp:
                    # This soft_id has appeared before: reuse its soft token ids.
                    id_list = idx_mp[d["soft_id"]]
                    text.extend([{"soft": None} for _ in range(len(id_list))])
                    soft_token_ids.extend(id_list)
                    continue
                else:
                    if "soft" not in d:
                        d["soft"] = None

            if d["soft"] is None:
                if "duplicate" in d:
                    if "same" in d and d["same"]:
                        # Duplicated positions that share a single soft token id.
                        num_soft_token += 1
                        id_list = [num_soft_token for _ in range(d["duplicate"])]
                    else:
                        # Duplicated positions, each with its own soft token id.
                        num_soft_token += d["duplicate"]
                        id_list = list(range(old_num + 1, num_soft_token + 1))
                else:
                    num_soft_token += 1
                    id_list = [num_soft_token]
                text.extend([{"soft": ""} for _ in range(len(id_list))])
            else:
                # Soft tokens initialized from the embeddings of the given hard tokens.
                token_ids = self.tokenizer(d["add_prefix_space"] + d["soft"], add_special_tokens=False)["input_ids"]
                surface_forms = self.tokenizer.convert_ids_to_tokens(token_ids)
                assert len(token_ids) == len(surface_forms)
                num_soft_token += len(token_ids)
                id_list = list(range(old_num + 1, num_soft_token + 1))
                for idx, soft_id in enumerate(id_list):
                    emb_mp[soft_id] = token_ids[idx]
                text.extend([{"soft": surface_form} for surface_form in surface_forms])

            soft_token_ids.extend(id_list)

            if "soft_id" in d:
                idx_mp[d["soft_id"]] = id_list

        self.num_soft_token = num_soft_token
        self.text = text
        self.soft_token_ids = soft_token_ids

        # Generate the embedding needed for soft tokens
        self.soft_embedding = nn.Embedding(1 + self.num_soft_token, self.embedding_size)
        for soft_id, token_id in emb_mp.items():
            self.soft_embedding.weight.data[soft_id, :] = self.raw_embedding.weight.data[token_id, :].clone().detach().requires_grad_(True)
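
        # Illustrative note on the rules above (a restatement, not extra behaviour):
        # '{"soft": None, "duplicate": 3}' yields three soft tokens with distinct ids;
        # adding '"same": True' makes the three positions share one id, hence one
        # trainable embedding; '{"soft": "It was"}' yields one soft token per sub-word,
        # each initialized from the PLM embedding of that sub-word.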
# if "post_processing" in d: # if d["post_processing"] == "mlp": # pass # TODO one mlp or more than one # else: # raise ValueError(f'post_processing of {d["post_processing"]} is not supported yet') def parse_text(self, text: str) -> List[Dict]: parsed = [] i = 0 while i < len(text): d = {"add_prefix_space": ' ' if (i > 0 and text[i-1] == ' ') else ''} while i < len(text) and text[i] == ' ': d["add_prefix_space"] = ' ' i = i + 1 if i == len(text): break if text[i] != self.mixed_token_start: j = i + 1 while j < len(text): if text[j] == self.mixed_token_start: break j = j + 1 d["text"] = text[i:j].rstrip(' ') i = j else: j = i + 1 mixed_token_cnt = 1 # { {} {} } nested support while j < len(text): if text[j] == self.mixed_token_end: mixed_token_cnt -= 1 if mixed_token_cnt == 0: break elif text[j] == self.mixed_token_start: mixed_token_cnt += 1 j = j + 1 if j == len(text): raise ValueError(f"mixed_token_start {self.mixed_token_start} at position {i} has no corresponding mixed_token_end {self.mixed_token_end}") dict_str = '{'+text[i+1:j]+'}' try: val = eval(dict_str) if isinstance(val, set): val = {k: None for k in val} d.update(val) except: import traceback print(traceback.format_exc()) print(f"syntax error in {dict_str}") exit() i = j + 1 parsed.append(d) return parsed
[docs] def on_text_set(self): """ when template text was set 1. parse text 2. generate parameter needed """ self.text = self.parse_text(self.text) self.prepare()

    def incorporate_text_example(self, example: InputExample):
        text = self.text.copy()
        for i, d in enumerate(text):
            if 'placeholder' in d:
                text[i] = d["add_prefix_space"] + d.get("post_processing", lambda x: x)(getattr(example, d['placeholder']))
            elif 'meta' in d:
                text[i] = d["add_prefix_space"] + d.get("post_processing", lambda x: x)(example.meta[d['meta']])
            elif 'soft' in d:
                text[i] = d['soft']  # unused
            elif 'mask' in d:
                text[i] = '<mask>'
            elif 'special' in d:
                text[i] = d['special']
            elif 'text' in d:
                text[i] = d["add_prefix_space"] + d['text']
            else:
                raise ValueError(f'can not parse {d}')
        return text

    def process_batch(self, batch: Union[Dict, InputFeatures]) -> Union[Dict, InputFeatures]:
        """Convert ``input_ids`` to ``inputs_embeds``.

        For normal tokens, use the embedding layer of the PLM; for soft tokens, use a new
        embedding layer that is initialized with the embeddings of their corresponding hard tokens.
        """
        raw_embeds = self.raw_embedding(batch['input_ids'])
        soft_embeds = self.soft_embedding(batch['soft_token_ids'])
        inputs_embeds = torch.where((batch['soft_token_ids'] > 0).unsqueeze(-1), soft_embeds, raw_embeds)

        batch['input_ids'] = None
        batch['inputs_embeds'] = inputs_embeds
        return batch
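
# Continuing the construction sketch near ``__init__`` (``wrap_one_example`` is inherited
# from the base ``Template`` class; ``InputExample`` comes from ``openprompt.data_utils``):
#
#   from openprompt.data_utils import InputExample
#
#   wrapped = template.wrap_one_example(
#       InputExample(guid=0, text_a="Albert Einstein was born in Ulm.")
#   )
#
# At forward time, ``process_batch`` looks up ``soft_token_ids`` in ``self.soft_embedding``
# and, wherever an id is greater than zero, substitutes that trainable vector for the frozen
# PLM word embedding via ``torch.where``; positions with id 0 keep the ordinary embeddings
# looked up from ``input_ids``.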