Trainer
Users can customize the training process freely, or use our Runner API to run classification or generation.
The Runner API requires a config.yml file; see Play with Configuration.
Classification Runner
- class ClassificationRunner(model: openprompt.pipeline_base.PromptForClassification, config: Optional[yacs.config.CfgNode] = None, train_dataloader: Optional[openprompt.pipeline_base.PromptDataLoader] = None, valid_dataloader: Optional[openprompt.pipeline_base.PromptDataLoader] = None, test_dataloader: Optional[openprompt.pipeline_base.PromptDataLoader] = None, loss_function: Optional[Callable] = None, id2label: Optional[Dict] = None)
A runner for simple training without training tricks. To apply training tricks such as template or verbalizer ensembles, or self-training, use another runner class. This class is implemented specifically for classification. Generation could be integrated into this class through a task option, but we keep it as a separate class for simplicity.
- Parameters
  - model (PromptForClassification) – One PromptForClassification object.
  - train_dataloader (PromptDataLoader, optional) – The dataloader that batchifies and processes the training data.
  - valid_dataloader (PromptDataLoader, optional) – The dataloader that batchifies and processes the validation data.
  - test_dataloader (PromptDataLoader, optional) – The dataloader that batchifies and processes the test data.
  - config (CfgNode) – A configuration object.
  - loss_function (Callable, optional) – The loss function used in the training process.
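What a "simple training without training tricks" runner does (fit on the training batches with a loss function, check the validation data, then report on the test data) can be illustrated with the plain-Python sketch below. It is not OpenPrompt's implementation: ToyModel, SimpleRunner, label_of, and the toy data are hypothetical stand-ins for a PromptForClassification model and PromptDataLoader batches. A 1-D logistic regression replaces the PLM, gradients come from finite differences so that any loss_function can be plugged in, and keeping the parameters with the best validation accuracy is an assumed (though common) design choice.

```python
import math
from typing import Callable, Dict, List, Optional, Tuple

Example = Tuple[float, int]  # (feature, label_id); hypothetical stand-in for a wrapped prompt
Batch = List[Example]


def cross_entropy(prob: float, label: int) -> float:
    """Binary cross-entropy for one example (the default loss in this sketch)."""
    eps = 1e-12
    return -(label * math.log(prob + eps) + (1 - label) * math.log(1 - prob + eps))


class ToyModel:
    """Hypothetical stand-in for a prompt model: 1-D logistic regression."""

    def __init__(self) -> None:
        self.params: Dict[str, float] = {"w": 0.0, "b": 0.0}

    def prob(self, x: float) -> float:
        z = self.params["w"] * x + self.params["b"]
        return 1.0 / (1.0 + math.exp(-z))

    def predict(self, x: float) -> int:
        return int(self.prob(x) >= 0.5)


class SimpleRunner:
    """Hypothetical runner: train, keep the best-validation parameters, then test."""

    def __init__(self, model: ToyModel, train: List[Batch], valid: List[Batch],
                 test: List[Batch],
                 loss_function: Optional[Callable[[float, int], float]] = None,
                 id2label: Optional[Dict[int, str]] = None,
                 lr: float = 1.0, epochs: int = 30) -> None:
        self.model, self.train, self.valid, self.test = model, train, valid, test
        self.loss_function = loss_function or cross_entropy
        self.id2label = id2label or {}
        self.lr, self.epochs = lr, epochs

    def batch_loss(self, batch: Batch) -> float:
        return sum(self.loss_function(self.model.prob(x), y) for x, y in batch) / len(batch)

    def step(self, batch: Batch, h: float = 1e-5) -> None:
        # Central finite differences keep the loss pluggable; real runners use autograd.
        grads = {}
        for name in list(self.model.params):
            original = self.model.params[name]
            self.model.params[name] = original + h
            loss_up = self.batch_loss(batch)
            self.model.params[name] = original - h
            loss_down = self.batch_loss(batch)
            self.model.params[name] = original  # restore before the next parameter
            grads[name] = (loss_up - loss_down) / (2 * h)
        for name, grad in grads.items():
            self.model.params[name] -= self.lr * grad

    def accuracy(self, batches: List[Batch]) -> float:
        examples = [ex for batch in batches for ex in batch]
        return sum(self.model.predict(x) == y for x, y in examples) / len(examples)

    def run(self) -> Dict[str, float]:
        best_acc, best_params = -1.0, dict(self.model.params)
        for _ in range(self.epochs):
            for batch in self.train:
                self.step(batch)
            acc = self.accuracy(self.valid)
            if acc > best_acc:  # keep the checkpoint with the best validation accuracy
                best_acc, best_params = acc, dict(self.model.params)
        self.model.params = best_params
        return {"valid_acc": best_acc, "test_acc": self.accuracy(self.test)}

    def label_of(self, x: float) -> str:
        """Map the predicted label id to a name via id2label, if one was given."""
        pred = self.model.predict(x)
        return self.id2label.get(pred, str(pred))


if __name__ == "__main__":
    train = [[(-2.0, 0), (-1.0, 0)], [(1.0, 1), (2.0, 1)]]
    valid = [[(-1.5, 0), (1.5, 1)]]
    test = [[(-3.0, 0), (3.0, 1)]]
    runner = SimpleRunner(ToyModel(), train, valid, test,
                          id2label={0: "negative", 1: "positive"})
    print(runner.run())          # {'valid_acc': 1.0, 'test_acc': 1.0}
    print(runner.label_of(2.5))  # positive
```

Because the loss is a constructor argument, passing for example `lambda p, y: (p - y) ** 2` changes the training objective without touching the loop itself.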
Generation Runner
- class GenerationRunner(model: openprompt.pipeline_base.PromptForGeneration, config: Optional[yacs.config.CfgNode] = None, train_dataloader: Optional[openprompt.pipeline_base.PromptDataLoader] = None, valid_dataloader: Optional[openprompt.pipeline_base.PromptDataLoader] = None, test_dataloader: Optional[openprompt.pipeline_base.PromptDataLoader] = None)
A runner for simple training without training tricks. To apply training tricks such as template or verbalizer ensembles, or self-training, use another runner class. This class is implemented specifically for generation.
- Parameters
  - model (PromptForGeneration) – One PromptForGeneration object.
  - train_dataloader (PromptDataLoader, optional) – The dataloader that batchifies and processes the training data.
  - valid_dataloader (PromptDataLoader, optional) – The dataloader that batchifies and processes the validation data.
  - test_dataloader (PromptDataLoader, optional) – The dataloader that batchifies and processes the test data.
  - config (CfgNode) – A configuration object.
LM-BFF Classification Runner
- class LMBFFClassificationRunner(train_dataset: List[transformers.data.processors.utils.InputExample], valid_dataset: List[transformers.data.processors.utils.InputExample], test_dataset: List[transformers.data.processors.utils.InputExample], verbalizer: Optional[openprompt.prompt_base.Verbalizer] = None, template: Optional[str] = None, config: Optional[yacs.config.CfgNode] = None)
This runner implements the LM-BFF training process from the paper Making Pre-trained Language Models Better Few-shot Learners (Gao et al. 2020).
- Parameters
  - train_dataset (List[InputExample]) – The dataset for training.
  - valid_dataset (List[InputExample]) – The dataset for validation.
  - test_dataset (List[InputExample]) – The dataset for testing.
  - verbalizer (Optional[Verbalizer]) – The manually designed verbalizer used for template generation. Defaults to None.
  - template (Optional[str]) – The manually designed template used for verbalizer generation. Defaults to None.
  - config (CfgNode) – A configuration object.
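When a manual template is given, LM-BFF searches for label words automatically. The plain-Python sketch below illustrates that idea under stated assumptions; it is not LMBFFClassificationRunner's code, and prune_candidates, zero_shot_accuracy, and search_verbalizer are hypothetical names. Following the paper's description, each class keeps the top-k words ranked by the sum of log P([MASK] = v | T(x)) over that class's training examples, and candidate assignments are then compared by zero-shot accuracy on the training set. The paper goes on to fine-tune the top n assignments and re-rank them on the dev set; this sketch simplifies that step to returning the single best zero-shot assignment. The scores are made up, whereas in practice they come from the masked language model.

```python
import itertools
from typing import Dict, List, Optional, Sequence, Tuple

# One training example: (scores, label_id), where scores[word] stands for
# log P([MASK] = word | T(x)) under a fixed manual template T. In LM-BFF these
# come from the masked LM; here they are made-up numbers over a toy vocabulary,
# and every example is assumed to score every word.
Example = Tuple[Dict[str, float], int]


def prune_candidates(examples: Sequence[Example], num_classes: int, k: int) -> List[List[str]]:
    """For each class, keep the top-k words by summed log-likelihood over its examples."""
    candidates = []
    for c in range(num_classes):
        totals: Dict[str, float] = {}
        for scores, label in examples:
            if label != c:
                continue
            for word, logp in scores.items():
                totals[word] = totals.get(word, 0.0) + logp
        ranked = sorted(totals, key=lambda w: (-totals[w], w))  # ties broken alphabetically
        candidates.append(ranked[:k])
    return candidates


def zero_shot_accuracy(examples: Sequence[Example], assignment: Sequence[str]) -> float:
    """Predict the class whose label word scores highest at [MASK]; return accuracy."""
    correct = 0
    for scores, label in examples:
        pred = max(range(len(assignment)), key=lambda c: scores[assignment[c]])
        correct += int(pred == label)
    return correct / len(examples)


def search_verbalizer(examples: Sequence[Example], num_classes: int,
                      k: int) -> Optional[Tuple[str, ...]]:
    """Return the label-word assignment with the best zero-shot training accuracy."""
    best: Optional[Tuple[str, ...]] = None
    best_acc = -1.0
    for assignment in itertools.product(*prune_candidates(examples, num_classes, k)):
        if len(set(assignment)) < num_classes:  # skip assignments that reuse a word
            continue
        acc = zero_shot_accuracy(examples, assignment)
        if acc > best_acc:
            best, best_acc = assignment, acc
    return best


if __name__ == "__main__":
    examples = [  # label 0 = negative, 1 = positive
        ({"great": -1.0, "good": -1.5, "terrible": -4.0, "bad": -3.0}, 1),
        ({"great": -1.2, "good": -1.0, "terrible": -3.5, "bad": -3.2}, 1),
        ({"great": -3.0, "good": -2.5, "terrible": -1.0, "bad": -1.1}, 0),
        ({"great": -3.5, "good": -2.0, "terrible": -1.4, "bad": -0.9}, 0),
    ]
    print(prune_candidates(examples, num_classes=2, k=2))   # [['bad', 'terrible'], ['great', 'good']]
    print(search_verbalizer(examples, num_classes=2, k=2))  # ('bad', 'great')
```

Pruning to k words per class keeps the number of candidate assignments at k to the power of the number of classes, rather than the vocabulary size to that power.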