from transformers.data.processors.utils import InputExample, InputFeatures
from openprompt import PromptDataLoader, PromptForClassification
from openprompt.pipeline_base import PromptModel
from openprompt.prompts import ManualVerbalizer, ManualTemplate
from typing import List, Optional, Dict, Union
from . import Verbalizer, PromptDataLoader
import copy
import warnings
from .trainer import ClassificationRunner
from yacs.config import CfgNode
from openprompt.utils.logging import logger
from openprompt.utils.cuda import model_to_device
from openprompt.prompts import load_template_generator, load_verbalizer_generator
from openprompt.plms import load_plm_from_config
def build_dataloader(dataset, template, tokenizer, tokenizer_wrapper_class, config, split):
    r"""Construct a :obj:`PromptDataLoader` for one data split.

    Args:
        dataset: the examples of this split (e.g. ``List[InputExample]``).
        template: the prompt template used to wrap each example.
        tokenizer: the tokenizer of the pretrained language model.
        tokenizer_wrapper_class: the tokenizer wrapper class matching the PLM.
        config (:obj:`CfgNode`): the global configuration object; the
            ``config[split]`` section supplies batch size / shuffling options.
        split (:obj:`str`): which section of ``config`` to read, e.g.
            ``'train'``, ``'dev'`` or ``'test'``.

    Returns:
        :obj:`PromptDataLoader`: the dataloader for this split.
    """
    dataloader = PromptDataLoader(
        dataset = dataset,
        template = template,
        tokenizer = tokenizer,
        tokenizer_wrapper_class=tokenizer_wrapper_class,
        batch_size = config[split].batch_size,
        shuffle = config[split].shuffle_data,
        # getattr with a default replaces the hasattr/ternary combination.
        teacher_forcing = getattr(config[split], 'teacher_forcing', None),
        # The comparison itself is already the boolean we need; only the
        # generation task requires predicting the eos token.
        predict_eos_token = config.task == "generation",
        **config.dataloader
    )
    return dataloader
class LMBFFClassificationRunner:
    r"""
    This runner implements the LM-BFF training process in paper `Making Pre-trained Language Models Better Few-shot Learners(Gao et al. 2020) <https://arxiv.org/pdf/2012.15723.pdf>`_.

    Args:
        train_dataset (:obj:`List[InputExample]`): The dataset for training
        valid_dataset (:obj:`List[InputExample]`): The dataset for validation
        test_dataset (:obj:`List[InputExample]`): The dataset for test
        verbalizer (:obj:`Optional[Verbalizer]`): The manually designed verbalizer for template generation. Defaults to None.
        template (:obj:`Optional[str]`): The manually designed template text for verbalizer generation. Defaults to None.
        config (:obj:`CfgNode`): A configuration object
    """
    def __init__(self,
                 train_dataset: List[InputExample],
                 valid_dataset: List[InputExample],
                 test_dataset: List[InputExample],
                 verbalizer: Optional[Verbalizer] = None,
                 template: Optional[str] = None,
                 config: CfgNode = None):
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.test_dataset = test_dataset
        # Load the backbone PLM, its tokenizer and wrapper from the config.
        self.model, self.tokenizer, self.model_config, self.tokenizer_wrapper = load_plm_from_config(config)
        # auto_t: automatically generate template texts; auto_v: automatically
        # search label words for the verbalizer.
        self.auto_t = config.classification.auto_t
        self.auto_v = config.classification.auto_v
        self.verbalizer = verbalizer
        self.template = template
        self.config = config
        self._check_param()

    def _check_param(self):
        # Validate that the provided template/verbalizer match the auto_t /
        # auto_v flags; auto_t takes precedence over auto_v.
        if self.auto_t:
            if self.verbalizer is None:
                raise ValueError("no verbalizer for template generation provided!")
            if self.template is not None:
                warnings.warn("auto_t is set True, ignore the given template")
        elif self.auto_v:
            if self.template is None:
                raise ValueError("no template for verbalizer generation provided, or set auto_t=True to automatically generate one")
            if self.verbalizer is not None:
                warnings.warn("auto_v is set True, ignore the given verbalizer")
        else:
            warnings.warn("auto_t and auto_v are both False, the trainer will degenerate to a simple classification trainer")

    def _auto_t(self):
        r"""Generate candidate template texts with a template-generation PLM.

        Returns:
            :obj:`List[str]`: the candidate template texts.
        """
        logger.info("performing auto-t...")
        template_generate_model, template_generate_tokenizer, template_generate_model_config, template_tokenizer_wrapper = load_plm_from_config(self.config.template_generator)
        model = model_to_device(template_generate_model, self.config.environment)
        template_generator = load_template_generator(config=self.config, model=model, tokenizer=template_generate_tokenizer, tokenizer_wrapper=template_tokenizer_wrapper, verbalizer=self.verbalizer)
        template_texts = template_generator.generate(self.train_dataset)  # List[str]
        # Free the generation model before the search phase to keep peak
        # GPU memory down.
        template_generator.release_memory()
        del template_generator, model
        return template_texts

    def _auto_v(self, template):
        r"""Generate candidate label-word sets for the given template.

        Args:
            template: the (fixed) template used while searching label words.

        Returns:
            :obj:`List[List[str]]`: candidate label-word lists, one per class.
        """
        logger.info("performing auto-v...")
        # Work on a copy so the original backbone stays untouched for the
        # subsequent training runs.
        model = copy.deepcopy(self.model)
        model = model_to_device(model, self.config.environment)
        verbalizer_generator = load_verbalizer_generator(config=self.config, model=model, tokenizer=self.tokenizer)
        dataloader = PromptDataLoader(self.train_dataset, template, self.tokenizer, self.tokenizer_wrapper, batch_size=self.config.test.batch_size)
        for data in dataloader:
            data = template.process_batch(data)
            if self.config.environment.num_gpus > 0:
                data = data.to("cuda:{}".format(self.config.environment.local_rank))
            verbalizer_generator.register_buffer(data)
        label_words_list = verbalizer_generator.generate()  # List[List[str]]
        verbalizer_generator.release_memory()
        del verbalizer_generator, model
        return label_words_list

    def _get_best_template_text(self, template_texts_candidates, verbalizer):
        r"""Train/evaluate each candidate template and keep the best one.

        Returns the template text with the highest validation score, or
        ``None`` when no candidate beats a zero score.
        """
        best_metrics = 0.0
        best_template_text = None
        for template_text in template_texts_candidates:
            template = ManualTemplate(self.tokenizer, template_text)
            train_dataloader = build_dataloader(self.train_dataset, template, self.tokenizer, self.tokenizer_wrapper, self.config, 'train')
            valid_dataloader = build_dataloader(self.valid_dataset, template, self.tokenizer, self.tokenizer_wrapper, self.config, 'dev')
            score = self._train_eval(template, verbalizer, train_dataloader, valid_dataloader)
            if score > best_metrics:
                best_metrics = score
                best_template_text = template_text
        logger.info('best template:' + str(best_template_text))
        return best_template_text

    def _get_best_label_words(self, verbalizer_labelwords_candidates, template, verbalizer):
        r"""Train/evaluate each candidate label-word set and keep the best.

        Returns the label words with the highest validation score, or
        ``None`` when no candidate beats a zero score.
        """
        # Deep-copy so mutating .label_words does not affect the caller's
        # verbalizer until the best candidate is chosen.
        current_verbalizer = copy.deepcopy(verbalizer)
        best_metrics = 0.0
        best_label_words = None
        for label_words in verbalizer_labelwords_candidates:
            current_verbalizer.label_words = label_words
            train_dataloader = build_dataloader(self.train_dataset, template, self.tokenizer, self.tokenizer_wrapper, self.config, 'train')
            valid_dataloader = build_dataloader(self.valid_dataset, template, self.tokenizer, self.tokenizer_wrapper, self.config, 'dev')
            score = self._train_eval(template, current_verbalizer, train_dataloader, valid_dataloader)
            if score > best_metrics:
                best_metrics = score
                best_label_words = label_words
        logger.info('best label words:' + str(best_label_words))
        return best_label_words

    def _train_eval(self, template, verbalizer, train_dataloader, valid_dataloader):
        r"""Run one short train/validation cycle and return the best score.

        A fresh deep copy of the backbone is used for every candidate so the
        evaluations are independent; ``runner.clean = True`` discards
        checkpoints after the fit.
        """
        model = PromptForClassification(copy.deepcopy(self.model), template, verbalizer)
        runner = ClassificationRunner(model, config=self.config, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader)
        runner.clean = True
        best_score = runner.fit()
        return best_score

    def run(self):
        r"""
        Run LM-BFF. If both `auto_t` and `auto_v` are set to True in ``config``, automatic template generation will be performed first, followed by automatic verbalizer generation using the best template found.
        """
        best_template = self.template
        best_verbalizer = self.verbalizer
        if self.auto_t:
            template_texts = self._auto_t()
            best_template_text = self._get_best_template_text(template_texts, best_verbalizer)
            best_template = ManualTemplate(self.tokenizer, best_template_text)
        if self.auto_v:
            label_words_list = self._auto_v(best_template)
            best_label_words = self._get_best_label_words(label_words_list, best_template, best_verbalizer)
            best_verbalizer.label_words = best_label_words
        # Final full training run with the selected template/verbalizer,
        # this time keeping checkpoints (clean = False) and testing.
        train_dataloader = build_dataloader(self.train_dataset, best_template, self.tokenizer, self.tokenizer_wrapper, self.config, 'train')
        valid_dataloader = build_dataloader(self.valid_dataset, best_template, self.tokenizer, self.tokenizer_wrapper, self.config, 'dev')
        test_dataloader = build_dataloader(self.test_dataset, best_template, self.tokenizer, self.tokenizer_wrapper, self.config, 'test')
        model = PromptForClassification(copy.deepcopy(self.model), best_template, best_verbalizer)
        runner = ClassificationRunner(model, config=self.config, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, test_dataloader=test_dataloader)
        runner.clean = False
        return runner.run()