"""
This file contains the logic for loading data for all RelationClassification tasks.
"""
import os
import json, csv
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
from typing import *
from openprompt.utils.logging import logger
from openprompt.data_utils.utils import InputExample
from openprompt.data_utils.data_processor import DataProcessor
[docs]class TACREDProcessor(DataProcessor):
"""
`TAC Relation Extraction Dataset (TACRED) <https://nlp.stanford.edu/projects/tacred/>`_ is one of the largest and most widely used datasets for relation classification.
It was released together with the paper `Position-aware Attention and Supervised Data Improve Slot Filling (Zhang et al. 2017) <https://nlp.stanford.edu/pubs/zhang2017tacred.pdf>`_
This processor is also inherited by :py:class:`TACREVProcessor` and :py:class:`ReTACREDProcessor`.
Examples:
.. code-block:: python
from openprompt.data_utils.relation_classification_dataset import PROCESSORS
base_path = "datasets/RelationClassification"
dataset_name = "TACRED"
dataset_path = os.path.join(base_path, dataset_name)
processor = PROCESSORS[dataset_name.lower()]()
train_dataset = processor.get_train_examples(dataset_path)
dev_dataset = processor.get_dev_examples(dataset_path)
test_dataset = processor.get_test_examples(dataset_path)
assert processor.get_num_labels() == 42
assert processor.get_labels() == ["no_relation", "org:founded", "org:subsidiaries", "per:date_of_birth", "per:cause_of_death", "per:age", "per:stateorprovince_of_birth", "per:countries_of_residence", "per:country_of_birth", "per:stateorprovinces_of_residence", "org:website", "per:cities_of_residence", "per:parents", "per:employee_of", "per:city_of_birth", "org:parents", "org:political/religious_affiliation", "per:schools_attended", "per:country_of_death", "per:children", "org:top_members/employees", "per:date_of_death", "org:members", "org:alternate_names", "per:religion", "org:member_of", "org:city_of_headquarters", "per:origin", "org:shareholders", "per:charges", "per:title", "org:number_of_employees/members", "org:dissolved", "org:country_of_headquarters", "per:alternate_names", "per:siblings", "org:stateorprovince_of_headquarters", "per:spouse", "per:other_family", "per:city_of_death", "per:stateorprovince_of_death", "org:founded_by"]
assert len(train_dataset) == 68124
assert len(dev_dataset) == 22631
assert len(test_dataset) == 15509
assert train_dataset[0].text_a == 'Tom Thabane resigned in October last year to form the All Basotho Convention -LRB- ABC -RRB- , crossing the floor with 17 members of parliament , causing constitutional monarch King Letsie III to dissolve parliament and call the snap election .'
assert train_dataset[0].meta["head"] == "All Basotho Convention"
assert train_dataset[0].meta["tail"] == "Tom Thabane"
assert train_dataset[0].label == 41
"""
def __init__(self):
super().__init__(labels = [
"no_relation", "org:founded", "org:subsidiaries", "per:date_of_birth", "per:cause_of_death", "per:age", "per:stateorprovince_of_birth", "per:countries_of_residence", "per:country_of_birth", "per:stateorprovinces_of_residence", "org:website", "per:cities_of_residence", "per:parents", "per:employee_of", "per:city_of_birth", "org:parents", "org:political/religious_affiliation", "per:schools_attended", "per:country_of_death", "per:children", "org:top_members/employees", "per:date_of_death", "org:members", "org:alternate_names", "per:religion", "org:member_of", "org:city_of_headquarters", "per:origin", "org:shareholders", "per:charges", "per:title", "org:number_of_employees/members", "org:dissolved", "org:country_of_headquarters", "per:alternate_names", "per:siblings", "org:stateorprovince_of_headquarters", "per:spouse", "per:other_family", "per:city_of_death", "per:stateorprovince_of_death", "org:founded_by"
])
def get_examples(self, data_dir, split):
path = os.path.join(data_dir, "{}.json".format(split))
examples = []
with open(path, encoding='utf8') as f:
example_jsons = json.load(f)
for example_json in example_jsons:
guid = example_json["id"]
label = self.get_label_id(example_json["relation"])
tokens = example_json["token"]
text_a = " ".join(tokens)
meta = {
"head": " ".join(tokens[example_json["subj_start"]: example_json["subj_end"]+1]),
"tail": " ".join(tokens[example_json["obj_start"]: example_json["obj_end"]+1]),
}
example = InputExample(guid=guid, text_a=text_a, meta=meta, label=label)
examples.append(example)
return examples
[docs]class TACREVProcessor(TACREDProcessor):
"""
`TACRED Revisted (TACREV) <https://github.com/DFKI-NLP/tacrev>`_ is a variant of the TACRED dataset
It was proposed by the paper `TACRED Revisited: A Thorough Evaluation of the TACRED Relation Extraction Task (Alt et al. 2020) <https://aclanthology.org/2020.acl-main.142.pdf>`_
This processor inherit :py:class:`TACREDProcessor` and can be used similarly
Examples:
.. code-block:: python
from openprompt.data_utils.relation_classification_dataset import PROCESSORS
base_path = "datasets/RelationClassification"
dataset_name = "TACREV"
dataset_path = os.path.join(base_path, dataset_name)
processor = PROCESSORS[dataset_name.lower()]()
train_dataset = processor.get_train_examples(dataset_path)
dev_dataset = processor.get_dev_examples(dataset_path)
test_dataset = processor.get_test_examples(dataset_path)
assert processor.get_num_labels() == 42
assert processor.get_labels() == ["no_relation", "org:founded", "org:subsidiaries", "per:date_of_birth", "per:cause_of_death", "per:age", "per:stateorprovince_of_birth", "per:countries_of_residence", "per:country_of_birth", "per:stateorprovinces_of_residence", "org:website", "per:cities_of_residence", "per:parents", "per:employee_of", "per:city_of_birth", "org:parents", "org:political/religious_affiliation", "per:schools_attended", "per:country_of_death", "per:children", "org:top_members/employees", "per:date_of_death", "org:members", "org:alternate_names", "per:religion", "org:member_of", "org:city_of_headquarters", "per:origin", "org:shareholders", "per:charges", "per:title", "org:number_of_employees/members", "org:dissolved", "org:country_of_headquarters", "per:alternate_names", "per:siblings", "org:stateorprovince_of_headquarters", "per:spouse", "per:other_family", "per:city_of_death", "per:stateorprovince_of_death", "org:founded_by"]
assert len(train_dataset) == 68124
assert len(dev_dataset) == 22631
assert len(test_dataset) == 15509
"""
def __init__(self):
super().__init__()
self.labels = [
"no_relation", "org:founded", "org:subsidiaries", "per:date_of_birth", "per:cause_of_death", "per:age", "per:stateorprovince_of_birth", "per:countries_of_residence", "per:country_of_birth", "per:stateorprovinces_of_residence", "org:website", "per:cities_of_residence", "per:parents", "per:employee_of", "per:city_of_birth", "org:parents", "org:political/religious_affiliation", "per:schools_attended", "per:country_of_death", "per:children", "org:top_members/employees", "per:date_of_death", "org:members", "org:alternate_names", "per:religion", "org:member_of", "org:city_of_headquarters", "per:origin", "org:shareholders", "per:charges", "per:title", "org:number_of_employees/members", "org:dissolved", "org:country_of_headquarters", "per:alternate_names", "per:siblings", "org:stateorprovince_of_headquarters", "per:spouse", "per:other_family", "per:city_of_death", "per:stateorprovince_of_death", "org:founded_by"
]
[docs]class ReTACREDProcessor(TACREDProcessor):
"""
`Re-TACRED <https://github.com/gstoica27/Re-TACRED>`_ is a variant of the TACRED dataset
It was proposed by the paper `Re-TACRED: Addressing Shortcomings of the TACRED Dataset (Stoica et al. 2021) <https://arxiv.org/pdf/2104.08398.pdf>`_
This processor inherit :py:class:`TACREDProcessor` and can be used similarly
Examples:
.. code-block:: python
from openprompt.data_utils.relation_classification_dataset import PROCESSORS
base_path = "datasets/RelationClassification"
dataset_name = "ReTACRED"
dataset_path = os.path.join(base_path, dataset_name)
processor = PROCESSORS[dataset_name.lower()]()
train_dataset = processor.get_train_examples(dataset_path)
dev_dataset = processor.get_dev_examples(dataset_path)
test_dataset = processor.get_test_examples(dataset_path)
assert processor.get_num_labels() == 40
assert processor.get_labels() == ["no_relation", "org:members", "per:siblings", "per:spouse", "org:country_of_branch", "per:country_of_death", "per:parents", "per:stateorprovinces_of_residence", "org:top_members/employees", "org:dissolved", "org:number_of_employees/members", "per:stateorprovince_of_death", "per:origin", "per:children", "org:political/religious_affiliation", "per:city_of_birth", "per:title", "org:shareholders", "per:employee_of", "org:member_of", "org:founded_by", "per:countries_of_residence", "per:other_family", "per:religion", "per:identity", "per:date_of_birth", "org:city_of_branch", "org:alternate_names", "org:website", "per:cause_of_death", "org:stateorprovince_of_branch", "per:schools_attended", "per:country_of_birth", "per:date_of_death", "per:city_of_death", "org:founded", "per:cities_of_residence", "per:age", "per:charges", "per:stateorprovince_of_birth"]
assert len(train_dataset) == 58465
assert len(dev_dataset) == 19584
assert len(test_dataset) == 13418
"""
def __init__(self):
super().__init__()
self.labels = [
"no_relation", "org:members", "per:siblings", "per:spouse", "org:country_of_branch", "per:country_of_death", "per:parents", "per:stateorprovinces_of_residence", "org:top_members/employees", "org:dissolved", "org:number_of_employees/members", "per:stateorprovince_of_death", "per:origin", "per:children", "org:political/religious_affiliation", "per:city_of_birth", "per:title", "org:shareholders", "per:employee_of", "org:member_of", "org:founded_by", "per:countries_of_residence", "per:other_family", "per:religion", "per:identity", "per:date_of_birth", "org:city_of_branch", "org:alternate_names", "org:website", "per:cause_of_death", "org:stateorprovince_of_branch", "per:schools_attended", "per:country_of_birth", "per:date_of_death", "per:city_of_death", "org:founded", "per:cities_of_residence", "per:age", "per:charges", "per:stateorprovince_of_birth"
]
[docs]class SemEvalProcessor(DataProcessor):
"""
`SemEval-2010 Task 8 <https://aclanthology.org/S10-1006.pdf>`_ is a a traditional dataset in relation classification.
It was released together with the paper `SemEval-2010 Task 8: Multi-Way Classification of Semantic Relations Between Pairs of Nominals (Hendrickx et al. 2010) <https://aclanthology.org/S10-1006.pdf>`_
Examples:
.. code-block:: python
from openprompt.data_utils.relation_classification_dataset import PROCESSORS
base_path = "datasets/RelationClassification"
dataset_name = "SemEval"
dataset_path = os.path.join(base_path, dataset_name)
processor = PROCESSORS[dataset_name.lower()]()
train_dataset = processor.get_train_examples(dataset_path)
dev_dataset = processor.get_dev_examples(dataset_path)
test_dataset = processor.get_test_examples(dataset_path)
assert processor.get_num_labels() == 19
assert processor.get_labels() == ["Other", "Member-Collection(e1,e2)", "Entity-Destination(e1,e2)", "Content-Container(e1,e2)", "Message-Topic(e1,e2)", "Entity-Origin(e1,e2)", "Cause-Effect(e1,e2)", "Product-Producer(e1,e2)", "Instrument-Agency(e1,e2)", "Component-Whole(e1,e2)", "Member-Collection(e2,e1)", "Entity-Destination(e2,e1)", "Content-Container(e2,e1)", "Message-Topic(e2,e1)", "Entity-Origin(e2,e1)", "Cause-Effect(e2,e1)", "Product-Producer(e2,e1)", "Instrument-Agency(e2,e1)", "Component-Whole(e2,e1)"]
assert len(train_dataset) == 6507
assert len(dev_dataset) == 1493
assert len(test_dataset) == 2717
assert dev_dataset[0].text_a == 'the system as described above has its greatest application in an arrayed configuration of antenna elements .'
assert dev_dataset[0].meta["head"] == "configuration"
assert dev_dataset[0].meta["tail"] == "elements"
assert dev_dataset[0].label == 18
"""
def __init__(self):
super().__init__()
self.labels = ["Other", "Member-Collection(e1,e2)", "Entity-Destination(e1,e2)", "Content-Container(e1,e2)", "Message-Topic(e1,e2)", "Entity-Origin(e1,e2)", "Cause-Effect(e1,e2)", "Product-Producer(e1,e2)", "Instrument-Agency(e1,e2)", "Component-Whole(e1,e2)", "Member-Collection(e2,e1)", "Entity-Destination(e2,e1)", "Content-Container(e2,e1)", "Message-Topic(e2,e1)", "Entity-Origin(e2,e1)", "Cause-Effect(e2,e1)", "Product-Producer(e2,e1)", "Instrument-Agency(e2,e1)", "Component-Whole(e2,e1)"]
def get_examples(self, data_dir, split):
path = os.path.join(data_dir, "{}.jsonl".format(split))
examples = []
with open(path, encoding='utf8') as f:
for choicex, line in enumerate(f):
example_json = json.loads(line)
label = self.get_label_id(example_json["relation"])
text_a = " ".join(example_json["token"])
meta = {
"head": example_json["h"]["name"],
"tail": example_json["t"]["name"],
}
example = InputExample(guid=str(choicex), text_a=text_a, meta=meta, label=label)
examples.append(example)
return examples
PROCESSORS = {
"tacred": TACREDProcessor,
"tacrev": TACREVProcessor,
"retacred": ReTACREDProcessor,
"semeval": SemEvalProcessor,
}