from typing import *
from abc import abstractmethod
from openprompt.data_utils.utils import InputExample
[docs]class DataProcessor:
"""
labels of the dataset is optional
here's the examples of loading the labels:
:obj:`I`: ``DataProcessor(labels = ['positive', 'negative'])``
:obj:`II`: ``DataProcessor(labels_path = 'datasets/labels.txt')``
labels file should have label names separated by any blank characters, such as
.. code-block::
positive neutral
negative
Args:
labels (:obj:`Sequence[Any]`, optional): class labels of the dataset. Defaults to None.
labels_path (:obj:`str`, optional): Defaults to None. If set and :obj:`labels` is None, load labels from :obj:`labels_path`.
"""
def __init__(self,
labels: Optional[Sequence[Any]] = None,
labels_path: Optional[str] = None
):
if labels is not None:
self.labels = labels
elif labels_path is not None:
with open(labels_path, "r") as f:
self.labels = ' '.join(f.readlines()).split()
@property
def labels(self) -> List[Any]:
if not hasattr(self, "_labels"):
raise ValueError("DataProcessor doesn't set labels or label_mapping yet")
return self._labels
@labels.setter
def labels(self, labels: Sequence[Any]):
if labels is not None:
self._labels = labels
self._label_mapping = {k: i for (i, k) in enumerate(labels)}
@property
def label_mapping(self) -> Dict[Any, int]:
if not hasattr(self, "_labels"):
raise ValueError("DataProcessor doesn't set labels or label_mapping yet")
return self._label_mapping
@label_mapping.setter
def label_mapping(self, label_mapping: Mapping[Any, int]):
self._labels = [item[0] for item in sorted(label_mapping.items(), key=lambda item: item[1])]
self._label_mapping = label_mapping
@property
def id2label(self) -> Dict[int, Any]:
if not hasattr(self, "_labels"):
raise ValueError("DataProcessor doesn't set labels or label_mapping yet")
return {i: k for (i, k) in enumerate(self._labels)}
[docs] def get_label_id(self, label: Any) -> int:
"""get label id of the corresponding label
Args:
label: label in dataset
Returns:
int: the index of label
"""
return self.label_mapping[label] if label is not None else None
[docs] def get_labels(self) -> List[Any]:
"""get labels of the dataset
Returns:
List[Any]: labels of the dataset
"""
return self.labels
[docs] def get_num_labels(self):
"""get the number of labels in the dataset
Returns:
int: number of labels in the dataset
"""
return len(self.labels)
[docs] def get_train_examples(self, data_dir: Optional[str] = None) -> InputExample:
"""
get train examples from the training file under :obj:`data_dir`
call ``get_examples(data_dir, "train")``, see :py:meth:`~openprompt.data_utils.data_processor.DataProcessor.get_examples`
"""
return self.get_examples(data_dir, "train")
[docs] def get_dev_examples(self, data_dir: Optional[str] = None) -> List[InputExample]:
"""
get dev examples from the development file under :obj:`data_dir`
call ``get_examples(data_dir, "dev")``, see :py:meth:`~openprompt.data_utils.data_processor.DataProcessor.get_examples`
"""
return self.get_examples(data_dir, "dev")
[docs] def get_test_examples(self, data_dir: Optional[str] = None) -> List[InputExample]:
"""
get test examples from the test file under :obj:`data_dir`
call ``get_examples(data_dir, "test")``, see :py:meth:`~openprompt.data_utils.data_processor.DataProcessor.get_examples`
"""
return self.get_examples(data_dir, "test")
[docs] def get_unlabeled_examples(self, data_dir: Optional[str] = None) -> List[InputExample]:
"""
get unlabeled examples from the unlabeled file under :obj:`data_dir`
call ``get_examples(data_dir, "unlabeled")``, see :py:meth:`~openprompt.data_utils.data_processor.DataProcessor.get_examples`
"""
return self.get_examples(data_dir, "unlabeled")
[docs] @abstractmethod
def get_examples(self, data_dir: Optional[str] = None, split: Optional[str] = None) -> List[InputExample]:
"""get the :obj:`split` of dataset under :obj:`data_dir`
:obj:`data_dir` is the base path of the dataset, for example:
training file could be located in ``data_dir/train.txt``
Args:
data_dir (str): the base path of the dataset
split (str): ``train`` / ``dev`` / ``test`` / ``unlabeled``
Returns:
List[InputExample]: return a list of :py:class:`~openprompt.data_utils.data_utils.InputExample`
"""
raise NotImplementedError