Utils Functions

Contents

Calibration

calibrate(prompt_model: openprompt.pipeline_base.PromptForClassification, dataloader: openprompt.pipeline_base.PromptDataLoader) → torch.Tensor

Compute the calibration logits for a prompt model. See Paper

Parameters
  • prompt_model (PromptForClassification) – the PromptForClassification model to calibrate.

  • dataloader (PromptDataLoader) – the dataloader used for calibration. It may be a virtual one, i.e. one that contains only a single template-only example.

Returns

(torch.Tensor) A tensor of shape (vocabsize) or (mask_num, vocabsize) holding the logits calculated for each word in the vocabulary.
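
Example

The following is a minimal sketch of the idea behind the return value, not the library's implementation. It assumes a single mask position and assumes the per-example logits are averaged into one vector of shape (vocabsize). The names toy_calibrate and toy_forward, and the toy weight matrix, are hypothetical stand-ins so that the snippet runs with PyTorch alone, without loading a pretrained model.

```python
import torch


def toy_calibrate(forward_fn, batches):
    """Hypothetical stand-in for calibrate(): collect the logits produced at
    the mask position for every batch and average them per vocabulary word.
    Averaging is an assumption made for this sketch."""
    collected = []
    with torch.no_grad():  # calibration needs no gradients
        for batch in batches:
            collected.append(forward_fn(batch))  # (batch_size, vocab_size)
    return torch.cat(collected, dim=0).mean(dim=0)  # (vocab_size,)


# Toy "model": a fixed linear map from 4 features to a 5-word vocabulary.
torch.manual_seed(0)
weight = torch.randn(4, 5)


def toy_forward(batch):
    return batch @ weight


# Two toy batches standing in for what a dataloader would yield.
batches = [torch.randn(2, 4), torch.randn(3, 4)]
calibration_logits = toy_calibrate(toy_forward, batches)
print(calibration_logits.shape)
```

With the real API, the corresponding call would be calibrate(prompt_model, dataloader), where both arguments are built as described under Parameters.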