# Source code for openprompt.data_utils.relation_classification_dataset

"""
This file contains the logic for loading data for all RelationClassification tasks.
"""

import os
import json, csv
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
from typing import *

from openprompt.utils.logging import logger

from openprompt.data_utils.utils import InputExample
from openprompt.data_utils.data_processor import DataProcessor




class TACREDProcessor(DataProcessor):
    """
    Data processor for the `TAC Relation Extraction Dataset (TACRED)
    <https://nlp.stanford.edu/projects/tacred/>`_, one of the largest and most
    widely used datasets for relation classification. It was released together
    with the paper `Position-aware Attention and Supervised Data Improve Slot
    Filling (Zhang et al. 2017)
    <https://nlp.stanford.edu/pubs/zhang2017tacred.pdf>`_.

    This processor is also inherited by :py:class:`TACREVProcessor` and
    :py:class:`ReTACREDProcessor`.

    Examples:

    .. code-block:: python

        from openprompt.data_utils.relation_classification_dataset import PROCESSORS

        base_path = "datasets/RelationClassification"

        dataset_name = "TACRED"
        dataset_path = os.path.join(base_path, dataset_name)
        processor = PROCESSORS[dataset_name.lower()]()
        train_dataset = processor.get_train_examples(dataset_path)
        dev_dataset = processor.get_dev_examples(dataset_path)
        test_dataset = processor.get_test_examples(dataset_path)

        assert processor.get_num_labels() == 42
        assert processor.get_labels() == [
            "no_relation", "org:founded", "org:subsidiaries", "per:date_of_birth",
            "per:cause_of_death", "per:age", "per:stateorprovince_of_birth",
            "per:countries_of_residence", "per:country_of_birth",
            "per:stateorprovinces_of_residence", "org:website", "per:cities_of_residence",
            "per:parents", "per:employee_of", "per:city_of_birth", "org:parents",
            "org:political/religious_affiliation", "per:schools_attended",
            "per:country_of_death", "per:children", "org:top_members/employees",
            "per:date_of_death", "org:members", "org:alternate_names", "per:religion",
            "org:member_of", "org:city_of_headquarters", "per:origin", "org:shareholders",
            "per:charges", "per:title", "org:number_of_employees/members", "org:dissolved",
            "org:country_of_headquarters", "per:alternate_names", "per:siblings",
            "org:stateorprovince_of_headquarters", "per:spouse", "per:other_family",
            "per:city_of_death", "per:stateorprovince_of_death", "org:founded_by",
        ]
        assert len(train_dataset) == 68124
        assert len(dev_dataset) == 22631
        assert len(test_dataset) == 15509
        assert train_dataset[0].text_a == 'Tom Thabane resigned in October last year to form the All Basotho Convention -LRB- ABC -RRB- , crossing the floor with 17 members of parliament , causing constitutional monarch King Letsie III to dissolve parliament and call the snap election .'
        assert train_dataset[0].meta["head"] == "All Basotho Convention"
        assert train_dataset[0].meta["tail"] == "Tom Thabane"
        assert train_dataset[0].label == 41
    """

    def __init__(self):
        """Hand the 42 TACRED relation labels to the base processor."""
        relation_labels = [
            "no_relation",
            "org:founded",
            "org:subsidiaries",
            "per:date_of_birth",
            "per:cause_of_death",
            "per:age",
            "per:stateorprovince_of_birth",
            "per:countries_of_residence",
            "per:country_of_birth",
            "per:stateorprovinces_of_residence",
            "org:website",
            "per:cities_of_residence",
            "per:parents",
            "per:employee_of",
            "per:city_of_birth",
            "org:parents",
            "org:political/religious_affiliation",
            "per:schools_attended",
            "per:country_of_death",
            "per:children",
            "org:top_members/employees",
            "per:date_of_death",
            "org:members",
            "org:alternate_names",
            "per:religion",
            "org:member_of",
            "org:city_of_headquarters",
            "per:origin",
            "org:shareholders",
            "per:charges",
            "per:title",
            "org:number_of_employees/members",
            "org:dissolved",
            "org:country_of_headquarters",
            "per:alternate_names",
            "per:siblings",
            "org:stateorprovince_of_headquarters",
            "per:spouse",
            "per:other_family",
            "per:city_of_death",
            "per:stateorprovince_of_death",
            "org:founded_by",
        ]
        super().__init__(labels=relation_labels)

    def get_examples(self, data_dir, split):
        """
        Load all examples of one split from ``<data_dir>/<split>.json``.

        The file is decoded as a single JSON document and iterated record by
        record; each record is converted by :py:meth:`_build_example`.

        Args:
            data_dir: Directory that contains the split files.
            split: Split name; it is used as the file stem of the ``.json`` file.

        Returns:
            A list with one :obj:`InputExample` per record, in file order.
        """
        file_name = "{}.json".format(split)
        json_path = os.path.join(data_dir, file_name)
        with open(json_path, encoding='utf8') as handle:
            records = json.load(handle)
        return [self._build_example(record) for record in records]

    def _build_example(self, record):
        """Turn one decoded record into an :obj:`InputExample`.

        The record is read through its ``id``, ``relation``, ``token``,
        ``subj_start``/``subj_end`` and ``obj_start``/``obj_end`` keys.
        """
        guid = record["id"]
        label_id = self.get_label_id(record["relation"])
        words = record["token"]
        sentence = " ".join(words)
        # The *_end offsets are used as inclusive bounds, hence the +1 when slicing.
        head_text = " ".join(words[record["subj_start"]:record["subj_end"] + 1])
        tail_text = " ".join(words[record["obj_start"]:record["obj_end"] + 1])
        entities = {"head": head_text, "tail": tail_text}
        return InputExample(guid=guid, text_a=sentence, meta=entities, label=label_id)
class TACREVProcessor(TACREDProcessor):
    """
    Data processor for `TACRED Revisited (TACREV)
    <https://github.com/DFKI-NLP/tacrev>`_, a variant of the TACRED dataset.
    It was proposed by the paper `TACRED Revisited: A Thorough Evaluation of
    the TACRED Relation Extraction Task (Alt et al. 2020)
    <https://aclanthology.org/2020.acl-main.142.pdf>`_.

    This processor inherits :py:class:`TACREDProcessor` and can be used
    similarly; it only defines its own ``__init__``.

    Examples:

    .. code-block:: python

        from openprompt.data_utils.relation_classification_dataset import PROCESSORS

        base_path = "datasets/RelationClassification"

        dataset_name = "TACREV"
        dataset_path = os.path.join(base_path, dataset_name)
        processor = PROCESSORS[dataset_name.lower()]()
        train_dataset = processor.get_train_examples(dataset_path)
        dev_dataset = processor.get_dev_examples(dataset_path)
        test_dataset = processor.get_test_examples(dataset_path)

        assert processor.get_num_labels() == 42
        assert processor.get_labels() == [
            "no_relation", "org:founded", "org:subsidiaries", "per:date_of_birth",
            "per:cause_of_death", "per:age", "per:stateorprovince_of_birth",
            "per:countries_of_residence", "per:country_of_birth",
            "per:stateorprovinces_of_residence", "org:website", "per:cities_of_residence",
            "per:parents", "per:employee_of", "per:city_of_birth", "org:parents",
            "org:political/religious_affiliation", "per:schools_attended",
            "per:country_of_death", "per:children", "org:top_members/employees",
            "per:date_of_death", "org:members", "org:alternate_names", "per:religion",
            "org:member_of", "org:city_of_headquarters", "per:origin", "org:shareholders",
            "per:charges", "per:title", "org:number_of_employees/members", "org:dissolved",
            "org:country_of_headquarters", "per:alternate_names", "per:siblings",
            "org:stateorprovince_of_headquarters", "per:spouse", "per:other_family",
            "per:city_of_death", "per:stateorprovince_of_death", "org:founded_by",
        ]
        assert len(train_dataset) == 68124
        assert len(dev_dataset) == 22631
        assert len(test_dataset) == 15509
    """

    def __init__(self):
        """Run the parent initialiser, then assign the 42 TACREV relation labels."""
        super().__init__()
        tacrev_labels = [
            "no_relation",
            "org:founded",
            "org:subsidiaries",
            "per:date_of_birth",
            "per:cause_of_death",
            "per:age",
            "per:stateorprovince_of_birth",
            "per:countries_of_residence",
            "per:country_of_birth",
            "per:stateorprovinces_of_residence",
            "org:website",
            "per:cities_of_residence",
            "per:parents",
            "per:employee_of",
            "per:city_of_birth",
            "org:parents",
            "org:political/religious_affiliation",
            "per:schools_attended",
            "per:country_of_death",
            "per:children",
            "org:top_members/employees",
            "per:date_of_death",
            "org:members",
            "org:alternate_names",
            "per:religion",
            "org:member_of",
            "org:city_of_headquarters",
            "per:origin",
            "org:shareholders",
            "per:charges",
            "per:title",
            "org:number_of_employees/members",
            "org:dissolved",
            "org:country_of_headquarters",
            "per:alternate_names",
            "per:siblings",
            "org:stateorprovince_of_headquarters",
            "per:spouse",
            "per:other_family",
            "per:city_of_death",
            "per:stateorprovince_of_death",
            "org:founded_by",
        ]
        self.labels = tacrev_labels
class ReTACREDProcessor(TACREDProcessor):
    """
    Data processor for `Re-TACRED <https://github.com/gstoica27/Re-TACRED>`_,
    a variant of the TACRED dataset. It was proposed by the paper `Re-TACRED:
    Addressing Shortcomings of the TACRED Dataset (Stoica et al. 2021)
    <https://arxiv.org/pdf/2104.08398.pdf>`_.

    This processor inherits :py:class:`TACREDProcessor` and can be used
    similarly; it only defines its own ``__init__``, which installs a
    different label list.

    Examples:

    .. code-block:: python

        from openprompt.data_utils.relation_classification_dataset import PROCESSORS

        base_path = "datasets/RelationClassification"

        dataset_name = "ReTACRED"
        dataset_path = os.path.join(base_path, dataset_name)
        processor = PROCESSORS[dataset_name.lower()]()
        train_dataset = processor.get_train_examples(dataset_path)
        dev_dataset = processor.get_dev_examples(dataset_path)
        test_dataset = processor.get_test_examples(dataset_path)

        assert processor.get_num_labels() == 40
        assert processor.get_labels() == [
            "no_relation", "org:members", "per:siblings", "per:spouse",
            "org:country_of_branch", "per:country_of_death", "per:parents",
            "per:stateorprovinces_of_residence", "org:top_members/employees",
            "org:dissolved", "org:number_of_employees/members",
            "per:stateorprovince_of_death", "per:origin", "per:children",
            "org:political/religious_affiliation", "per:city_of_birth", "per:title",
            "org:shareholders", "per:employee_of", "org:member_of", "org:founded_by",
            "per:countries_of_residence", "per:other_family", "per:religion",
            "per:identity", "per:date_of_birth", "org:city_of_branch",
            "org:alternate_names", "org:website", "per:cause_of_death",
            "org:stateorprovince_of_branch", "per:schools_attended",
            "per:country_of_birth", "per:date_of_death", "per:city_of_death",
            "org:founded", "per:cities_of_residence", "per:age", "per:charges",
            "per:stateorprovince_of_birth",
        ]
        assert len(train_dataset) == 58465
        assert len(dev_dataset) == 19584
        assert len(test_dataset) == 13418
    """

    def __init__(self):
        """Run the parent initialiser, then assign the 40 Re-TACRED relation labels."""
        super().__init__()
        retacred_labels = [
            "no_relation",
            "org:members",
            "per:siblings",
            "per:spouse",
            "org:country_of_branch",
            "per:country_of_death",
            "per:parents",
            "per:stateorprovinces_of_residence",
            "org:top_members/employees",
            "org:dissolved",
            "org:number_of_employees/members",
            "per:stateorprovince_of_death",
            "per:origin",
            "per:children",
            "org:political/religious_affiliation",
            "per:city_of_birth",
            "per:title",
            "org:shareholders",
            "per:employee_of",
            "org:member_of",
            "org:founded_by",
            "per:countries_of_residence",
            "per:other_family",
            "per:religion",
            "per:identity",
            "per:date_of_birth",
            "org:city_of_branch",
            "org:alternate_names",
            "org:website",
            "per:cause_of_death",
            "org:stateorprovince_of_branch",
            "per:schools_attended",
            "per:country_of_birth",
            "per:date_of_death",
            "per:city_of_death",
            "org:founded",
            "per:cities_of_residence",
            "per:age",
            "per:charges",
            "per:stateorprovince_of_birth",
        ]
        self.labels = retacred_labels
class SemEvalProcessor(DataProcessor):
    """
    Data processor for `SemEval-2010 Task 8
    <https://aclanthology.org/S10-1006.pdf>`_, a traditional dataset in
    relation classification. It was released together with the paper
    `SemEval-2010 Task 8: Multi-Way Classification of Semantic Relations
    Between Pairs of Nominals (Hendrickx et al. 2010)
    <https://aclanthology.org/S10-1006.pdf>`_.

    Examples:

    .. code-block:: python

        from openprompt.data_utils.relation_classification_dataset import PROCESSORS

        base_path = "datasets/RelationClassification"

        dataset_name = "SemEval"
        dataset_path = os.path.join(base_path, dataset_name)
        processor = PROCESSORS[dataset_name.lower()]()
        train_dataset = processor.get_train_examples(dataset_path)
        dev_dataset = processor.get_dev_examples(dataset_path)
        test_dataset = processor.get_test_examples(dataset_path)

        assert processor.get_num_labels() == 19
        assert processor.get_labels() == [
            "Other",
            "Member-Collection(e1,e2)", "Entity-Destination(e1,e2)",
            "Content-Container(e1,e2)", "Message-Topic(e1,e2)", "Entity-Origin(e1,e2)",
            "Cause-Effect(e1,e2)", "Product-Producer(e1,e2)", "Instrument-Agency(e1,e2)",
            "Component-Whole(e1,e2)",
            "Member-Collection(e2,e1)", "Entity-Destination(e2,e1)",
            "Content-Container(e2,e1)", "Message-Topic(e2,e1)", "Entity-Origin(e2,e1)",
            "Cause-Effect(e2,e1)", "Product-Producer(e2,e1)", "Instrument-Agency(e2,e1)",
            "Component-Whole(e2,e1)",
        ]
        assert len(train_dataset) == 6507
        assert len(dev_dataset) == 1493
        assert len(test_dataset) == 2717
        assert dev_dataset[0].text_a == 'the system as described above has its greatest application in an arrayed configuration of antenna elements .'
        assert dev_dataset[0].meta["head"] == "configuration"
        assert dev_dataset[0].meta["tail"] == "elements"
        assert dev_dataset[0].label == 18
    """

    def __init__(self):
        """Initialise the base processor, then assign the 19 SemEval relation labels."""
        super().__init__()
        semeval_labels = [
            "Other",
            # Relations directed from e1 to e2.
            "Member-Collection(e1,e2)",
            "Entity-Destination(e1,e2)",
            "Content-Container(e1,e2)",
            "Message-Topic(e1,e2)",
            "Entity-Origin(e1,e2)",
            "Cause-Effect(e1,e2)",
            "Product-Producer(e1,e2)",
            "Instrument-Agency(e1,e2)",
            "Component-Whole(e1,e2)",
            # The same relations directed from e2 to e1.
            "Member-Collection(e2,e1)",
            "Entity-Destination(e2,e1)",
            "Content-Container(e2,e1)",
            "Message-Topic(e2,e1)",
            "Entity-Origin(e2,e1)",
            "Cause-Effect(e2,e1)",
            "Product-Producer(e2,e1)",
            "Instrument-Agency(e2,e1)",
            "Component-Whole(e2,e1)",
        ]
        self.labels = semeval_labels

    def get_examples(self, data_dir, split):
        """
        Load all examples of one split from ``<data_dir>/<split>.jsonl``.

        Every line of the file is decoded as its own JSON object and read
        through its ``relation`` and ``token`` keys and the ``name`` entries
        of its ``h`` and ``t`` objects. The zero-based line index, converted
        to a string, serves as the example guid.

        Args:
            data_dir: Directory that contains the split files.
            split: Split name; it is used as the file stem of the ``.jsonl`` file.

        Returns:
            A list with one :obj:`InputExample` per line, in file order.
        """
        file_name = "{}.jsonl".format(split)
        jsonl_path = os.path.join(data_dir, file_name)
        collected = []
        with open(jsonl_path, encoding='utf8') as reader:
            for line_index, raw_line in enumerate(reader):
                record = json.loads(raw_line)
                label_id = self.get_label_id(record["relation"])
                sentence = " ".join(record["token"])
                entities = {
                    "head": record["h"]["name"],
                    "tail": record["t"]["name"],
                }
                collected.append(
                    InputExample(
                        guid=str(line_index),
                        text_a=sentence,
                        meta=entities,
                        label=label_id,
                    )
                )
        return collected
# Registry of the relation-classification processors defined in this module,
# keyed by the lower-cased dataset name (looked up as
# ``PROCESSORS[dataset_name.lower()]`` in the class examples).
PROCESSORS = {
    "tacred": TACREDProcessor,
    "tacrev": TACREVProcessor,
    "retacred": ReTACREDProcessor,
    "semeval": SemEvalProcessor,
}