from transformers.tokenization_utils import PreTrainedTokenizer
from openprompt.data_utils import InputFeatures
from openprompt import Verbalizer
from typing import List, Optional, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from openprompt.utils.logging import logger
class AutomaticVerbalizer(Verbalizer):
r"""
This implementation is slightly different from the original code in that
1). we allow re-selecting the verbalizer after a fixed training steps.
The original implementation only performs one step selection after getting
the initial logits on the training data. To adopt their implementation,
please only do ``optimize()`` after the first pass of training data.
2). We strictly follows the probility calculation in Equation (3) in the
paper, which take softmax over the logits.
3). We do not implements the ``combine_patterns'' if-branch. Since it's
not a pure verbalizer type, and doesn't yield much improvement. However,
it can be achieve by using EnsembleTrainer to pass text wrapped by
multiple templates together with this verbalizer.
We use a probs_buffer to store the probability :math:`q_{P,t}(1|\mathbf{x})` that to be used in later verbalizer selection,
and a label_buffer to store the label :math:`y` that to be used in later verbalizer selection.
Args:
num_candidates (:obj:`int`, optional): the number of candidate tokens kept per class for further selection, following Section 4.1
label_word_num_per_class (:obj:`int`, optional): set greater than 1 to support the Multi-Verbalizers of Section 4.2
num_searches (:obj:`int`, optional): the maximum number of label-word searches. After this number is reached, the verbalizer keeps the label words from the previous search.
search_id (:obj:`int`, optional): the id of the current search, used to determine when to stop searching for label words.
score_fct (:obj:`str`, optional): the scoring function for label-word selection. ``llr`` means log likelihood ratio, corresponding to Equation (7); ``ce`` means cross entropy, corresponding to Equation (6). As the paper points out, ``llr`` is significantly better than ``ce``; we keep ``ce`` only to match the original code.
balance (:obj:`bool`, optional): whether to normalize for an unbalanced training set, as in Equation (5).
"""
def __init__(self,
tokenizer: PreTrainedTokenizer = None,
num_candidates: Optional[int]= 1000,
label_word_num_per_class: Optional[int] = 1,
num_searches: Optional[int] = 1,
score_fct: Optional[str] = 'llr',
balance: Optional[bool] = True,
num_classes: Optional[int] = None,
classes: Optional[List[str]] = None,
init_using_split: Optional[str] = "train",
**kwargs):
super().__init__(num_classes=num_classes, tokenizer = tokenizer, classes=classes)
self.num_candidates = num_candidates
self.label_word_num_per_class = label_word_num_per_class
self.probs_buffer, self.labels_buffer = None, None
assert num_searches > 0, "num_searches must be a positive integer, got {}.".format(num_searches)
self.num_searches = num_searches
self.search_id = 0
self.accumulate_step = 0 # currently not used; reserved for non-epoch-level optimization.
self.accumulate = True # a flag indicating whether to accumulate examples for optimization;
# set to False once the label-word search is finished.
self.score_fct = score_fct
self.balance = balance
self.init_using_split = init_using_split
def register_buffer(self, logits, labels):
r'''
Args:
logits (:obj:`torch.Tensor`): the logits over the vocabulary for the current batch; a softmax is applied before storing.
labels (:obj:`torch.Tensor`): the gold labels of the current batch.
'''
logits = F.softmax(logits.detach(),dim=-1)
labels = labels.detach()
if self.probs_buffer is None :
self.probs_buffer = logits
self.labels_buffer = labels
else:
self.probs_buffer = torch.vstack([self.probs_buffer, logits])
self.labels_buffer = torch.hstack([self.labels_buffer, labels])
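# After accumulating N examples, probs_buffer holds the softmaxed probabilities with shape
# (N, vocab_size) and labels_buffer holds the gold labels with shape (N,).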
def process_logits(self, logits: torch.Tensor, **kwargs):
if self.accumulate: # accumulate the probs/labels buffers only while the label-word search is still active.
self.accumulate_step+=1
self.register_buffer(logits, kwargs['batch']['label'])
if hasattr(self, "label_words_ids"): # TODO: the content of this branch is the same as in super()
# project
label_words_logits = self.project(logits, **kwargs) #Output: (batch_size, num_classes) or (batch_size, num_classes, num_label_words_per_label)
# normalize
label_words_probs = self.normalize(label_words_logits)
# calibrate
if hasattr(self, "_calibrate_logits") and self._calibrate_logits is not None:
label_words_probs = self.calibrate(label_words_probs=label_words_probs)
# convert to logits
label_words_logits = torch.log(label_words_probs+1e-15)
# aggregate
if label_words_logits.dim()>2:
label_logits = self.aggregate(label_words_logits)
else:
label_logits = label_words_logits
return label_logits
else:
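# No label words have been selected yet: return placeholder random scores.
# They are not a function of `logits`, so no (misleading) gradient flows back into the PLM.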
return torch.randn((logits.size(0), self.num_classes), requires_grad=True).to(logits.device)
def project(self,
logits: torch.Tensor,
**kwargs, # TODO
) -> torch.Tensor:
r"""When this verbalizer hasn't perform optimize(), it has no
``label_words_ids``, thus will give random predictions, and should
have no connection to the model to give (miss-leading) grads.
Args:
logits (:obj:`torch.Tensor`): The original logits over the vocabulary.
Returns:
:obj:`torch.Tensor`: The projected logits of label words.
"""
label_words_logits = logits[:, self.label_words_ids]
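# With label_words_ids of shape (num_classes, words_per_label), the advanced indexing above
# yields label_words_logits of shape (batch_size, num_classes, words_per_label).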
return label_words_logits
def optimize(self):
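# Intentionally a no-op: label-word selection happens at epoch level in optimize_to_initialize().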
pass
def optimize_to_initialize(self):
r"""This is an epoch-level optimize. If used in batch-level like an ordinary
gradient descend optimizer, the result may not be very satisfying since the accumated
examples (i.e., the probs_buffer and the labels_buffer) are not enough if the batchsize
is small.
"""
if self.search_id < self.num_searches:
self.label_words_ids = self._find_verbalizer(words_per_label=self.label_word_num_per_class,
num_candidates=self.num_candidates,
score_fct=self.score_fct,
balance=self.balance)
self.probs_buffer, self.labels_buffer = None, None
self.search_id += 1
if self.search_id == self.num_searches: # finish optimization
self.accumulate = False
else:
logger.info("Verbalizer's max num_searches reached, use the previous label words.")
self._show_verbalizer()
def _show_verbalizer(self):
tokens = [self.tokenizer.convert_ids_to_tokens(i) for i in self.label_words_ids]
logger.info("Verbalizer is {}".format(tokens))
def _find_verbalizer(self, words_per_label: int = 1, num_candidates: int = 1000, balance: bool = True,
score_fct: str = 'llr'):
# if score_fct == 'random':
# return {label: random.sample(self.word2idx.keys(), words_per_label) for label in self.labels}
logger.info("Finding verbalizer ...")
probs = self.probs_buffer
labels = self.labels_buffer
candidates = self._get_candidates(num_candidates=num_candidates, probs=probs, labels=labels)
label_words = self._get_top_words(probs=probs, candidates=candidates, balance=balance, words_per_label=words_per_label,
score_fct=score_fct)
return label_words
def _get_candidates(self,
num_candidates: int,
probs: torch.Tensor,
labels: torch.Tensor,
) -> List[torch.Tensor]:
if num_candidates <= 0:
return [torch.arange(self.vocab_size) for label_id in range(self.num_classes)]
log_probs = torch.log(probs+1e-15)
candidate_ids = []
for label_id in range(self.num_classes):
label_mask = (labels==label_id).to(torch.float).unsqueeze(-1)
score = torch.sum(log_probs * label_mask, dim=0)
candidate_id = torch.argsort(score, descending=True)[:num_candidates]
candidate_ids.append(candidate_id)
return candidate_ids
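# The filtering above follows Section 4.1: for each class, every vocabulary token is scored by
# the sum of its log-probabilities over that class's accumulated examples, and only the
# `num_candidates` highest-scoring token ids are kept for the label-word search.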
def _get_top_words(self,
probs: torch.Tensor,
candidates: List[torch.Tensor],
balance: bool = True,
words_per_label: int = 10,
score_fct: Optional[str] = 'llr'):
label_words_ids = []
for label_id in range(self.num_classes):
label_mask = (self.labels_buffer==label_id).to(torch.float)
probs_per_label = probs[:, candidates[label_id]]
if score_fct == 'llr':
s = self._log_likelihood_ratio(probs_per_label, label_mask, balance)
elif score_fct == 'ce':
s = self._cross_entropy(probs_per_label, label_mask, balance)
else:
raise ValueError(f"Score function '{score_fct}' not implemented")
sorted_ids = torch.argsort(s, descending=True)[:words_per_label]
selected_ids = candidates[label_id][sorted_ids]
label_words_ids.append(selected_ids)
label_words_ids = torch.vstack(label_words_ids)
return label_words_ids
def _log_likelihood_ratio(self, probs, label_mask, balance):
if balance:
scale_factor = torch.sum(label_mask) / torch.sum(1 - label_mask) \
* (1-label_mask).unsqueeze(-1)
else:
scale_factor = (1-label_mask).unsqueeze(-1)
label_mask = label_mask.unsqueeze(-1)
pos_score = torch.sum(torch.log(probs+1e-15) * label_mask, dim=0) - torch.sum(torch.log(1 - probs + 1e-15) * label_mask, dim=0)
neg_score = torch.sum(torch.log(1 - probs +1e-15) * scale_factor, dim=0) - torch.sum(torch.log(probs+1e-15) * scale_factor, dim=0)
return pos_score + neg_score
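# As implemented above, the (optionally balanced) log likelihood ratio score of a candidate
# token t for class c is
#     s(t, c) = sum_{x in c} [log q_t(x) - log(1 - q_t(x))]
#             + w * sum_{x not in c} [log(1 - q_t(x)) - log q_t(x)],
# where q_t(x) is the buffered probability of t for example x, and w = |c| / |not c| when
# `balance` is True (w = 1 otherwise).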
def _cross_entropy(self, probs, label_mask, balance):
if balance:
scale_factor = torch.sum(label_mask) / torch.sum(1 - label_mask) \
* (1-label_mask).unsqueeze(-1)
else:
scale_factor = (1-label_mask).unsqueeze(-1)
label_mask = label_mask.unsqueeze(-1)
pos_score = torch.sum(torch.log(probs+1e-15) * label_mask, dim=0)
neg_score = torch.sum(torch.log(1 - probs +1e-15) * scale_factor, dim=0)
return pos_score + neg_score
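# As implemented above, the (optionally balanced) cross-entropy score of a candidate token t
# for class c is
#     s(t, c) = sum_{x in c} log q_t(x) + w * sum_{x not in c} log(1 - q_t(x)),
# with the same balancing weight w as in `_log_likelihood_ratio`.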
def from_file(self,
path: str,
choice: Optional[int] = 0 ):
raise NotImplementedError("This verbalizer is learned and can't be set from file.")
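# A minimal, self-contained sketch of the search itself (assumptions: synthetic logits stand in
# for a real PLM forward pass and the "bert-base-uncased" tokenizer is available; this is not
# the official OpenPrompt training recipe).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    verbalizer = AutomaticVerbalizer(tokenizer=tokenizer,
                                     classes=["negative", "positive"],
                                     num_candidates=100,
                                     label_word_num_per_class=3)
    # Simulate one pass over training data: vocabulary logits plus gold labels per batch.
    for _ in range(4):
        fake_logits = torch.randn(8, tokenizer.vocab_size)
        fake_labels = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1])
        verbalizer.register_buffer(fake_logits, fake_labels)  # stores softmaxed probs and labels
    verbalizer.optimize_to_initialize()          # runs the label-word search and logs the result
    print(verbalizer.label_words_ids.shape)      # torch.Size([2, 3])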