Source code for openprompt.prompts.knowledgeable_verbalizer

import os
from openprompt.prompts.manual_template import ManualTemplate
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.modeling_utils import PreTrainedModel
from openprompt.data_utils import InputFeatures
import re
from openprompt.prompts.manual_verbalizer import ManualVerbalizer
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
from openprompt.utils.logging import logger



class KnowledgeableVerbalizer(ManualVerbalizer):
    r"""
    This is the implementation of the knowledgeable verbalizer, which uses external knowledge to expand the set of label words.
    This class inherits from the ``ManualVerbalizer`` class.

    Args:
        tokenizer (:obj:`PreTrainedTokenizer`): The tokenizer of the current pre-trained model to point out the vocabulary.
        classes (:obj:`classes`): The classes (or labels) of the current task.
        prefix (:obj:`str`, optional): The prefix string of the verbalizer.
        multi_token_handler (:obj:`str`, optional): The handling strategy for multiple tokens produced by the tokenizer.
        max_token_split (:obj:`int`, optional): The maximum number of tokens a label word may be split into; words split into more tokens are discarded. ``-1`` means no limit.
        verbalizer_lr (:obj:`float`, optional): The learning rate of the verbalizer optimization.
        candidate_frac (:obj:`float`, optional): The fraction of the vocabulary with the lowest calibration logits whose label words are filtered out in ``register_calibrate_logits``.
        pred_temp (:obj:`float`, optional): The temperature applied to the label-word weights at prediction (evaluation) time.
    """
    def __init__(self,
                 tokenizer: PreTrainedTokenizer = None,
                 classes: Sequence[str] = None,
                 prefix: Optional[str] = " ",
                 multi_token_handler: Optional[str] = "first",
                 max_token_split: Optional[int] = -1,
                 verbalizer_lr: Optional[float] = 5e-2,
                 candidate_frac: Optional[float] = 0.5,
                 pred_temp: Optional[float] = 1.0,
                 **kwargs):
        super().__init__(classes=classes, prefix=prefix, multi_token_handler=multi_token_handler, tokenizer=tokenizer, **kwargs)
        self.max_token_split = max_token_split
        self.verbalizer_lr = verbalizer_lr
        self.candidate_frac = candidate_frac
        self.pred_temp = pred_temp
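    # A minimal construction sketch (model name, classes and label words are all
    # hypothetical, kept as comments so nothing runs at class-definition time):
    #
    #     from transformers import AutoTokenizer
    #     tokenizer = AutoTokenizer.from_pretrained("roberta-large")
    #     verbalizer = KnowledgeableVerbalizer(
    #         tokenizer=tokenizer,
    #         classes=["negative", "positive"],
    #         max_token_split=1,     # drop knowledge-base words the tokenizer splits up
    #         candidate_frac=0.5,    # filter words built from the lower half of the vocab at calibration
    #     )
    #     # Assigning label_words goes through the base Verbalizer property setter,
    #     # which is expected to call on_label_words_set() below.
    #     verbalizer.label_words = [
    #         ["bad", "terrible", "awful", "horrible"],      # expansion for "negative"
    #         ["good", "great", "wonderful", "excellent"],   # expansion for "positive"
    #     ]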
    def on_label_words_set(self):
        self.label_words = self.delete_common_words(self.label_words)
        self.label_words = self.add_prefix(self.label_words, self.prefix)
        self.generate_parameters()
    def delete_common_words(self, d):
        r"""Remove label words that are shared by more than one class: a shared word is
        popped from every class list in which it appears at a position other than index 0.
        """
        word_count = {}
        for d_perclass in d:
            for w in d_perclass:
                if w not in word_count:
                    word_count[w] = 1
                else:
                    word_count[w] += 1
        for w in word_count:
            if word_count[w] >= 2:
                for d_perclass in d:
                    if w in d_perclass[1:]:
                        findidx = d_perclass[1:].index(w)
                        d_perclass.pop(findidx + 1)
        return d
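    # Worked example (hypothetical lists): "fine" occurs in both classes, so it is popped
    # from each list where it sits beyond the head word at index 0:
    #
    #     d = [["good", "fine", "great"],
    #          ["bad",  "fine", "poor"]]
    #     delete_common_words(d)  ->  [["good", "great"], ["bad", "poor"]]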
    @staticmethod
    def add_prefix(label_words, prefix):
        r"""Add a prefix to the label words. For example, if a label word is in the middle of
        a template, the prefix should be ' '.
        """
        new_label_words = []
        for words in label_words:
            new_label_words.append([prefix + word.lstrip(prefix) for word in words])
        return new_label_words
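    # Example of the prefix (the leading space matters for BPE tokenizers such as RoBERTa's,
    # where " good" and "good" map to different token ids); inputs are hypothetical:
    #
    #     KnowledgeableVerbalizer.add_prefix([["good", "great"], ["bad"]], " ")
    #     ->  [[" good", " great"], [" bad"]]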
    def generate_parameters(self) -> List:
        r"""In the basic manual verbalizer, the parameters are generated from the label words directly.
        In this implementation, label words that the tokenizer splits into more than
        ``max_token_split`` tokens are discarded.
        """
        all_ids = []
        label_words = []
        # print([len(x) for x in self.label_words], flush=True)
        for words_per_label in self.label_words:
            ids_per_label = []
            words_keep_per_label = []
            for word in words_per_label:
                ids = self.tokenizer.encode(word, add_special_tokens=False)
                if self.max_token_split > 0 and len(ids) > self.max_token_split:
                    # In the knowledgeable verbalizer, the label words may be very rare, so we may
                    # want to remove the label words which are not recognized by the tokenizer.
                    logger.warning("Word {} is split into {} (>{}) tokens: {}. Ignored.".format(word,
                                   len(ids), self.max_token_split, self.tokenizer.convert_ids_to_tokens(ids)))
                    continue
                else:
                    words_keep_per_label.append(word)
                    ids_per_label.append(ids)
            label_words.append(words_keep_per_label)
            all_ids.append(ids_per_label)
        self.label_words = label_words

        max_len = max([max([len(ids) for ids in ids_per_label]) for ids_per_label in all_ids])
        max_num_label_words = max([len(ids_per_label) for ids_per_label in all_ids])
        words_ids_mask = [[[1]*len(ids) + [0]*(max_len-len(ids)) for ids in ids_per_label]
                          + [[0]*max_len]*(max_num_label_words-len(ids_per_label))
                          for ids_per_label in all_ids]
        words_ids = [[ids + [0]*(max_len-len(ids)) for ids in ids_per_label]
                     + [[0]*max_len]*(max_num_label_words-len(ids_per_label))
                     for ids_per_label in all_ids]

        words_ids_tensor = torch.tensor(words_ids)
        words_ids_mask = torch.tensor(words_ids_mask)
        self.label_words_ids = nn.Parameter(words_ids_tensor, requires_grad=False)
        self.words_ids_mask = nn.Parameter(words_ids_mask, requires_grad=False)  # A 3-d mask
        self.label_words_mask = nn.Parameter(torch.clamp(words_ids_mask.sum(dim=-1), max=1), requires_grad=False)
        self.label_words_weights = nn.Parameter(torch.zeros(self.num_classes, max_num_label_words), requires_grad=True)
        print("##Num of label words for each label: {}".format(self.label_words_mask.sum(-1).cpu().tolist()), flush=True)
        # print(self.label_words_ids.data.shape, flush=True)
        # print(self.words_ids_mask.data.shape, flush=True)
        # print(self.label_words_mask.data.shape, flush=True)
        # print(self.label_words_weights.data.shape, flush=True)
        # exit()

        # self.verbalizer_optimizer = torch.optim.AdamW(self.parameters(), lr=self.verbalizer_lr)
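    # Sketch of the buffers built by generate_parameters(), assuming 2 classes, at most
    # 4 surviving label words per class and at most 2 tokens per word (sizes hypothetical):
    #
    #     label_words_ids     (2, 4, 2)  token ids, zero-padded per word and per class
    #     words_ids_mask      (2, 4, 2)  1 for real tokens, 0 for padding tokens
    #     label_words_mask    (2, 4)     1 for real label words, 0 for padded word slots
    #     label_words_weights (2, 4)     learnable per-word weights used by aggregate()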
    def register_calibrate_logits(self, logits: torch.Tensor):
        r"""For the Knowledgeable Verbalizer, it is necessary to filter out the words that have a
        low prior probability. Therefore, we re-compute the label words after registering the
        calibration logits.
        """
        if logits.requires_grad:
            logits = logits.detach()
        self._calibrate_logits = logits
        cur_label_words_ids = self.label_words_ids.data.cpu().tolist()
        rm_calibrate_ids = set(torch.argsort(self._calibrate_logits)[:int(self.candidate_frac*logits.shape[-1])].cpu().tolist())

        new_label_words = []
        for i_label, words_ids_per_label in enumerate(cur_label_words_ids):
            new_label_words.append([])
            for j_word, word_ids in enumerate(words_ids_per_label):
                if j_word >= len(self.label_words[i_label]):
                    break
                if len((set(word_ids).difference(set([0]))).intersection(rm_calibrate_ids)) == 0:
                    new_label_words[-1].append(self.label_words[i_label][j_word])
        self.label_words = new_label_words
        self.to(self._calibrate_logits.device)
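    # Calibration sketch: the logits registered here are typically the model's vocabulary
    # logits at the mask position, averaged over an unlabeled support set (contextualized
    # calibration). The helper `mask_logits` below is hypothetical; the exact OpenPrompt
    # utility for obtaining these logits may differ between versions:
    #
    #     with torch.no_grad():
    #         support_logits = []
    #         for batch in support_dataloader:                  # unlabeled examples
    #             support_logits.append(mask_logits(plm, batch))    # hypothetical: (batch, vocab_size)
    #         verbalizer.register_calibrate_logits(torch.cat(support_logits, dim=0).mean(dim=0))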
    def project(self,
                logits: torch.Tensor,
                **kwargs,
                ) -> torch.Tensor:
        r"""Project the vocabulary logits onto the label words; padded label-word slots are
        masked out with a large negative offset.
        """
        label_words_logits = logits[:, self.label_words_ids]
        label_words_logits = self.handle_multi_token(label_words_logits, self.words_ids_mask)
        label_words_logits -= 10000*(1-self.label_words_mask)
        return label_words_logits
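    # Shape walkthrough for project() (hypothetical batch of 4 over a 50265-token
    # vocabulary, "first" multi-token handling):
    #
    #     logits                         (4, 50265)
    #     logits[:, label_words_ids]     (4, num_classes, max_num_label_words, max_len)
    #     after handle_multi_token       (4, num_classes, max_num_label_words)
    #
    # Padded word slots receive a -10000 offset so they are ignored downstream.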
    def aggregate(self, label_words_logits: torch.Tensor) -> torch.Tensor:
        r"""Use the learned weights to aggregate the logits of the label words.

        Args:
            label_words_logits (:obj:`torch.Tensor`): The logits of the label words.

        Returns:
            :obj:`torch.Tensor`: The aggregated logits from the label words.
        """
        if not self.training:
            label_words_weights = F.softmax(self.pred_temp*self.label_words_weights - 10000*(1-self.label_words_mask), dim=-1)
        else:
            label_words_weights = F.softmax(self.label_words_weights - 10000*(1-self.label_words_mask), dim=-1)

        label_words_logits = (label_words_logits * self.label_words_mask * label_words_weights).sum(-1)
        return label_words_logits
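    # Aggregation in equation form: for class c with label-word logits o_{c,1..m}, weights
    # w_{c,1..m} and mask m_{c,1..m} (masked positions pushed to -10000 before the softmax),
    #
    #     score(c) = sum_j softmax_j(w_c)_j * m_{c,j} * o_{c,j}
    #
    # At evaluation time the weights are scaled by pred_temp before the softmax.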
    # def optimize(self,):
    #     self.verbalizer_optimizer.step()
    #     self.verbalizer_optimizer.zero_grad()