Source code for openprompt.data_utils.conditional_generation_dataset

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file contains the logic for loading data for conditional generation tasks.
Currently only WebNLG is implemented; E2E and DART are placeholders (see ``PROCESSORS`` below).
"""

from openprompt.data_utils.utils import InputExample
import os
import json
from typing import List

from openprompt.data_utils.data_processor import DataProcessor

class WebNLGProcessor(DataProcessor):
    """
    Processor for the WebNLG 2017 dataset (Gardent et al., 2017,
    "The WebNLG Challenge: Generating Text from RDF Data").

    Examples:

    .. code-block:: python

        from openprompt.data_utils.conditional_generation_dataset import PROCESSORS

        base_path = "datasets/CondGen"

        dataset_name = "webnlg_2017"
        dataset_path = os.path.join(base_path, dataset_name)
        processor = PROCESSORS[dataset_name.lower()]()
        train_dataset = processor.get_train_examples(dataset_path)
        valid_dataset = processor.get_dev_examples(dataset_path)
        test_dataset = processor.get_test_examples(dataset_path)

        assert len(train_dataset) == 18025
        assert len(test_dataset) == 4928
        assert test_dataset[0].text_a == " | Abilene_Regional_Airport : cityServed : Abilene,_Texas"
        assert test_dataset[0].text_b == ""
        assert test_dataset[0].tgt_text == "Abilene, Texas is served by the Abilene regional airport."
    """

    def __init__(self):
        super().__init__()
        self.labels = None
    def get_examples(self, data_dir: str, split: str) -> List[InputExample]:
        examples = []
        path = os.path.join(data_dir, "{}.json".format(split))
        with open(path) as f:
            lines_dict = json.load(f)

        full_rela_lst = []
        full_src_lst = []
        full_tgt_lst = []

        for i, example in enumerate(lines_dict['entries']):
            # Each element of "entries" is a dict keyed by its 1-based index.
            sents = example[str(i + 1)]['lexicalisations']
            triples = example[str(i + 1)]['modifiedtripleset']

            rela_lst = []
            temp_triples = ''
            for j, tripleset in enumerate(triples):
                subj, rela, obj = tripleset['subject'], tripleset['property'], tripleset['object']
                rela_lst.append(rela)
                # Linearize each triple as " | subject : property : object".
                temp_triples += ' | '
                temp_triples += '{} : {} : {}'.format(subj, rela, obj)

            if split.lower() == "train":
                # Training: one example per good lexicalisation of the entry.
                for sent in sents:
                    if sent["comment"] == 'good':
                        full_tgt_lst.append(sent["lex"])
                        full_src_lst.append(temp_triples)
                        full_rela_lst.append(rela_lst)
            else:
                # Dev/test: one example per entry; all good lexicalisations
                # are joined with newlines as multiple references.
                full_src_lst.append(temp_triples)
                full_rela_lst.append(rela_lst)
                temp = []
                for sent in sents:
                    if sent["comment"] == 'good':
                        temp.append(sent["lex"])
                full_tgt_lst.append("\n".join(temp))

        assert len(full_rela_lst) == len(full_src_lst)
        assert len(full_rela_lst) == len(full_tgt_lst)

        # The guid is the example's index within the split.
        for i, (src, tgt) in enumerate(zip(full_src_lst, full_tgt_lst)):
            example = InputExample(guid=str(i), text_a=src, tgt_text=tgt)
            examples.append(example)
        return examples
    def get_src_tgt_len_ratio(self):
        pass
PROCESSORS = {
    "webnlg_2017": WebNLGProcessor,
    "webnlg": WebNLGProcessor,
    # "e2e": E2eProcessor,
    # "dart": DartProcessor,
}
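
# A minimal smoke test of the parser above (a sketch, not part of the
# library): the one-entry JSON written below is illustrative and only
# mirrors the WebNLG 2017 layout that ``get_examples`` expects, where
# "entries" is a list whose i-th element is keyed by the 1-based string
# ``str(i + 1)``.
if __name__ == "__main__":
    import tempfile

    entry = {
        "1": {
            "modifiedtripleset": [
                {"subject": "Abilene_Regional_Airport",
                 "property": "cityServed",
                 "object": "Abilene,_Texas"},
            ],
            "lexicalisations": [
                {"comment": "good",
                 "lex": "Abilene, Texas is served by the Abilene regional airport."},
            ],
        }
    }
    with tempfile.TemporaryDirectory() as data_dir:
        for split in ("train", "dev"):
            with open(os.path.join(data_dir, "{}.json".format(split)), "w") as f:
                json.dump({"entries": [entry]}, f)
        processor = WebNLGProcessor()
        train = processor.get_examples(data_dir, "train")
        dev = processor.get_examples(data_dir, "dev")
        assert train[0].text_a == " | Abilene_Regional_Airport : cityServed : Abilene,_Texas"
        # With a single good lexicalisation the train and dev targets coincide;
        # with several, the dev target would join them with "\n".
        assert train[0].tgt_text == dev[0].tgt_text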