Source code for openprompt.lm_bff_trainer

from transformers.data.processors.utils import InputExample, InputFeatures


from openprompt import PromptDataLoader, PromptForClassification
from openprompt.pipeline_base import PromptModel

from openprompt.prompts import ManualVerbalizer, ManualTemplate
from typing import List, Optional, Dict, Union
from . import Verbalizer, PromptDataLoader
import copy
import warnings
from .trainer import ClassificationRunner
from yacs.config import CfgNode
from openprompt.utils.logging import logger
from openprompt.utils.cuda import model_to_device
from openprompt.prompts import load_template_generator, load_verbalizer_generator
from openprompt.plms import load_plm_from_config




def build_dataloader(dataset, template, tokenizer,tokenizer_wrapper_class, config, split):
    r"""Create the :obj:`PromptDataLoader` for one data split.

    Args:
        dataset: The examples to wrap; forwarded unchanged to the dataloader.
        template: The template forwarded to the dataloader.
        tokenizer: The tokenizer forwarded to the dataloader.
        tokenizer_wrapper_class: The tokenizer wrapper class forwarded to the dataloader.
        config: The configuration object. ``config[split]`` supplies ``batch_size``,
            ``shuffle_data`` and (optionally) ``teacher_forcing``; ``config.task``
            decides ``predict_eos_token``; every entry of ``config.dataloader`` is
            passed through as an extra keyword argument.
        split: Key of the split section inside ``config`` (the callers in this
            file use ``'train'``, ``'dev'`` and ``'test'``).

    Returns:
        The constructed :obj:`PromptDataLoader`.
    """
    split_config = config[split]
    # A split section without a ``teacher_forcing`` entry falls back to None.
    teacher_forcing = getattr(split_config, 'teacher_forcing', None)
    # The EOS token is only predicted for generation tasks.
    is_generation = config.task == "generation"
    return PromptDataLoader(
        dataset=dataset,
        template=template,
        tokenizer=tokenizer,
        tokenizer_wrapper_class=tokenizer_wrapper_class,
        batch_size=split_config.batch_size,
        shuffle=split_config.shuffle_data,
        teacher_forcing=teacher_forcing,
        predict_eos_token=is_generation,
        **config.dataloader
    )


class LMBFFClassificationRunner:
    r"""
    This runner implements the LM-BFF training process in paper `Making Pre-trained Language Models Better Few-shot Learners(Gao et al. 2020) <https://arxiv.org/pdf/2012.15723.pdf>`_.

    Args:
        train_dataset (:obj:`List[InputExample]`): The dataset for training
        valid_dataset (:obj:`List[InputExample]`): The dataset for validation
        test_dataset (:obj:`List[InputExample]`): The dataset for test
        verbalizer (:obj:`Optional[Verbalizer]`): The manually designed verbalizer for template generation. It is also the
            object whose ``label_words`` are replaced when ``auto_v`` is enabled, so it is required in both modes. Defaults to None.
        template (:obj:`Optional[Template]`): The manually designed template for verbalizer generation. The runner calls
            ``template.process_batch`` on it, so a template object (not a plain string) is expected. Defaults to None.
        config (:obj:`CfgNode`): A configuration object
    """

    def __init__(self,
                 train_dataset: List[InputExample],
                 valid_dataset: List[InputExample],
                 test_dataset: List[InputExample],
                 verbalizer: Optional[Verbalizer] = None,
                 template: Optional[str] = None,
                 config: CfgNode = None):
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.test_dataset = test_dataset
        self.model, self.tokenizer, self.model_config, self.tokenizer_wrapper = load_plm_from_config(config)
        self.auto_t = config.classification.auto_t
        self.auto_v = config.classification.auto_v
        self.verbalizer = verbalizer
        self.template = template
        self.config = config
        self._check_param()

    def _check_param(self):
        r"""Validate the combination of ``auto_t`` / ``auto_v`` and the given template / verbalizer.

        Raises:
            ValueError: if ``auto_t`` is set without a verbalizer, or if only ``auto_v``
                is set and either the template or the verbalizer is missing.
        """
        if self.auto_t:
            if self.verbalizer is None:
                raise ValueError("no verbalizer for template generation provided!")
            if self.template is not None:
                warnings.warn("auto_t is set True, ignore the given template")
        elif self.auto_v:
            if self.template is None:
                raise ValueError("no template for verbalizer generation provided, or set auto_t=True to automatically generate one")
            if self.verbalizer is None:
                # The verbalizer is deep-copied and its ``label_words`` reassigned during
                # the candidate search, so ``None`` can never work. Fail here instead of
                # with an AttributeError after the expensive generation step.
                raise ValueError("no verbalizer provided: auto_v replaces the label words of the given verbalizer, so one is required")
            warnings.warn("auto_v is set True, ignore the given verbalizer")
        else:
            warnings.warn("auto_t and auto_v are both False, the trainer will degenerate to a simple classification trainer")

    @staticmethod
    def _select_best(candidates, score_fn):
        r"""Return the candidate with the highest score, or None if ``candidates`` is empty.

        The first candidate is always accepted, so a run in which every score is zero
        or negative still yields a usable candidate. Ties keep the earliest candidate.
        """
        best_candidate = None
        best_score = None
        for candidate in candidates:
            score = score_fn(candidate)
            if best_score is None or score > best_score:
                best_score = score
                best_candidate = candidate
        return best_candidate

    def _auto_t(self):
        r"""Generate candidate template texts from the training set.

        Returns:
            The result of ``template_generator.generate(self.train_dataset)``.
        """
        logger.info("performing auto-t...")
        template_generate_model, template_generate_tokenizer, template_generate_model_config, template_tokenizer_wrapper = load_plm_from_config(self.config.template_generator)
        model = model_to_device(template_generate_model, self.config.environment)
        template_generator = load_template_generator(config=self.config, model = model, tokenizer=template_generate_tokenizer, tokenizer_wrapper = template_tokenizer_wrapper, verbalizer = self.verbalizer)
        template_texts = template_generator.generate(self.train_dataset) # List[str]
        template_generator.release_memory()
        del template_generator, model
        return template_texts

    def _auto_v(self, template):
        r"""Generate candidate label-word lists for ``template`` from the training set.

        Args:
            template: The template used to build and process the training batches.

        Returns:
            The result of ``verbalizer_generator.generate()``.
        """
        logger.info("performing auto-v...")
        # Work on a copy so that ``self.model`` itself is not moved to the device.
        model = copy.deepcopy(self.model)
        model = model_to_device(model, self.config.environment)
        verbalizer_generator = load_verbalizer_generator(config=self.config, model=model, tokenizer=self.tokenizer)
        dataloader = PromptDataLoader(self.train_dataset, template, self.tokenizer, self.tokenizer_wrapper, batch_size=self.config.test.batch_size)
        for data in dataloader:
            data = template.process_batch(data)
            if self.config.environment.num_gpus > 0:
                data = data.to("cuda:{}".format(self.config.environment.local_rank))
            verbalizer_generator.register_buffer(data)
        label_words_list = verbalizer_generator.generate() # List[List[str]]
        verbalizer_generator.release_memory()
        del verbalizer_generator, model
        return label_words_list

    def _get_best_template_text(self, template_texts_candidates, verbalizer):
        r"""Train and evaluate every candidate template text and return the best one.

        Args:
            template_texts_candidates: Candidate template texts to compare.
            verbalizer: The verbalizer paired with every candidate template.

        Returns:
            The candidate with the highest score from :meth:`_train_eval`, or None
            when there are no candidates.
        """
        def score_template_text(template_text):
            template = ManualTemplate(self.tokenizer, template_text)
            train_dataloader = build_dataloader(self.train_dataset, template, self.tokenizer, self.tokenizer_wrapper, self.config, 'train')
            valid_dataloader = build_dataloader(self.valid_dataset, template, self.tokenizer, self.tokenizer_wrapper, self.config, 'dev')
            return self._train_eval(template, verbalizer, train_dataloader, valid_dataloader)

        best_template_text = self._select_best(template_texts_candidates, score_template_text)
        logger.info('best template:' + str(best_template_text))
        return best_template_text

    def _get_best_label_words(self, verbalizer_labelwords_candidates, template, verbalizer):
        r"""Train and evaluate every candidate label-word list and return the best one.

        Args:
            verbalizer_labelwords_candidates: Candidate label-word lists to compare.
            template: The template paired with every candidate.
            verbalizer: The verbalizer to copy; the given object itself is not modified here.

        Returns:
            The candidate with the highest score from :meth:`_train_eval`, or None
            when there are no candidates.
        """
        # The candidates are tried on a private copy of the verbalizer.
        current_verbalizer = copy.deepcopy(verbalizer)

        def score_label_words(label_words):
            current_verbalizer.label_words = label_words
            train_dataloader = build_dataloader(self.train_dataset, template, self.tokenizer, self.tokenizer_wrapper, self.config, 'train')
            valid_dataloader = build_dataloader(self.valid_dataset, template, self.tokenizer, self.tokenizer_wrapper, self.config, 'dev')
            return self._train_eval(template, current_verbalizer, train_dataloader, valid_dataloader)

        best_label_words = self._select_best(verbalizer_labelwords_candidates, score_label_words)
        logger.info('best label words:' + str(best_label_words))
        return best_label_words

    def _train_eval(self, template, verbalizer, train_dataloader, valid_dataloader):
        r"""Fit a fresh copy of the model with the given prompt and return the score of ``runner.fit()``."""
        # Every trial starts from an untouched copy of ``self.model``.
        model = PromptForClassification(copy.deepcopy(self.model), template, verbalizer)
        runner = ClassificationRunner(model, config=self.config, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader)
        runner.clean = True
        best_score = runner.fit()
        return best_score

    def run(self):
        r"""
        Run LM-BFF. if both `auto_t` and `auto_v` are set to True in ``config``, automatic template generation will be performed first.

        Returns:
            The result of ``ClassificationRunner.run()`` for the final template and verbalizer.
        """
        best_template = self.template
        best_verbalizer = self.verbalizer
        if self.auto_t:
            template_texts = self._auto_t()
            best_template_text = self._get_best_template_text(template_texts, best_verbalizer)
            best_template = ManualTemplate(self.tokenizer, best_template_text)
        if self.auto_v:
            label_words_list = self._auto_v(best_template)
            best_label_words = self._get_best_label_words(label_words_list, best_template, best_verbalizer)
            # NOTE: ``best_verbalizer`` is ``self.verbalizer``, so the verbalizer passed
            # to the constructor is updated in place.
            best_verbalizer.label_words = best_label_words

        train_dataloader = build_dataloader(self.train_dataset, best_template, self.tokenizer, self.tokenizer_wrapper, self.config, 'train')
        valid_dataloader = build_dataloader(self.valid_dataset, best_template, self.tokenizer, self.tokenizer_wrapper, self.config, 'dev')
        test_dataloader = build_dataloader(self.test_dataset, best_template, self.tokenizer, self.tokenizer_wrapper, self.config, 'test')
        model = PromptForClassification(copy.deepcopy(self.model), best_template, best_verbalizer)
        runner = ClassificationRunner(model, config=self.config, train_dataloader=train_dataloader, valid_dataloader=valid_dataloader, test_dataloader=test_dataloader)
        runner.clean = False
        return runner.run()