Source code for openprompt.data_utils.text_classification_dataset


# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file contains the logic for loading data for all TextClassification tasks.
"""

import os
import json, csv
from abc import ABC, abstractmethod
from collections import defaultdict, Counter
from typing import List, Dict, Callable

from openprompt.utils.logging import logger

from openprompt.data_utils.utils import InputExample
from openprompt.data_utils.data_processor import DataProcessor


class MnliProcessor(DataProcessor):
    # TODO Test needed
    def __init__(self):
        super().__init__()
        self.labels = ["contradiction", "entailment", "neutral"]

    def get_examples(self, data_dir, split):
        path = os.path.join(data_dir, "{}.csv".format(split))
        examples = []
        with open(path, encoding='utf8') as f:
            reader = csv.reader(f, delimiter=',')
            for idx, row in enumerate(reader):

                label, headline, body = row
                text_a = headline.replace('\\', ' ')
                text_b = body.replace('\\', ' ')
                example = InputExample(
                    guid=str(idx), text_a=text_a, text_b=text_b, label=int(label)-1)
                examples.append(example)

        return examples



[docs]class AgnewsProcessor(DataProcessor): """ `AG News <https://arxiv.org/pdf/1509.01626.pdf>`_ is a News Topic classification dataset we use dataset provided by `LOTClass <https://github.com/yumeng5/LOTClass>`_ Examples: .. code-block:: python from openprompt.data_utils.text_classification_dataset import PROCESSORS base_path = "datasets/TextClassification" dataset_name = "agnews" dataset_path = os.path.join(base_path, dataset_name) processor = PROCESSORS[dataset_name.lower()]() trainvalid_dataset = processor.get_train_examples(dataset_path) test_dataset = processor.get_test_examples(dataset_path) assert processor.get_num_labels() == 4 assert processor.get_labels() == ["World", "Sports", "Business", "Tech"] assert len(trainvalid_dataset) == 120000 assert len(test_dataset) == 7600 assert test_dataset[0].text_a == "Fears for T N pension after talks" assert test_dataset[0].text_b == "Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul." assert test_dataset[0].label == 2 """ def __init__(self): super().__init__() self.labels = ["World", "Sports", "Business", "Tech"]
[docs] def get_examples(self, data_dir, split): path = os.path.join(data_dir, "{}.csv".format(split)) examples = [] with open(path, encoding='utf8') as f: reader = csv.reader(f, delimiter=',') for idx, row in enumerate(reader): label, headline, body = row text_a = headline.replace('\\', ' ') text_b = body.replace('\\', ' ') example = InputExample(guid=str(idx), text_a=text_a, text_b=text_b, label=int(label)-1) examples.append(example) return examples
[docs]class DBpediaProcessor(DataProcessor): """ `Dbpedia <https://aclanthology.org/L16-1532.pdf>`_ is a Wikipedia Topic Classification dataset. we use dataset provided by `LOTClass <https://github.com/yumeng5/LOTClass>`_ Examples: .. code-block:: python from openprompt.data_utils.text_classification_dataset import PROCESSORS base_path = "datasets/TextClassification" dataset_name = "dbpedia" dataset_path = os.path.join(base_path, dataset_name) processor = PROCESSORS[dataset_name.lower()]() trainvalid_dataset = processor.get_train_examples(dataset_path) test_dataset = processor.get_test_examples(dataset_path) assert processor.get_num_labels() == 14 assert len(trainvalid_dataset) == 560000 assert len(test_dataset) == 70000 """ def __init__(self): super().__init__() self.labels = ["company", "school", "artist", "athlete", "politics", "transportation", "building", "river", "village", "animal", "plant", "album", "film", "book",]
[docs] def get_examples(self, data_dir, split): examples = [] label_file = open(os.path.join(data_dir,"{}_labels.txt".format(split)),'r') labels = [int(x.strip()) for x in label_file.readlines()] with open(os.path.join(data_dir,'{}.txt'.format(split)),'r') as fin: for idx, line in enumerate(fin): splited = line.strip().split(". ") text_a, text_b = splited[0], splited[1:] text_a = text_a+"." text_b = ". ".join(text_b) example = InputExample(guid=str(idx), text_a=text_a, text_b=text_b, label=int(labels[idx])) examples.append(example) return examples
[docs]class ImdbProcessor(DataProcessor): """ `IMDB <https://ai.stanford.edu/~ang/papers/acl11-WordVectorsSentimentAnalysis.pdf>`_ is a Movie Review Sentiment Classification dataset. we use dataset provided by `LOTClass <https://github.com/yumeng5/LOTClass>`_ Examples: .. code-block:: python from openprompt.data_utils.text_classification_dataset import PROCESSORS base_path = "datasets/TextClassification" dataset_name = "imdb" dataset_path = os.path.join(base_path, dataset_name) processor = PROCESSORS[dataset_name.lower()]() trainvalid_dataset = processor.get_train_examples(dataset_path) test_dataset = processor.get_test_examples(dataset_path) assert processor.get_num_labels() == 2 assert len(trainvalid_dataset) == 25000 assert len(test_dataset) == 25000 """ def __init__(self): super().__init__() self.labels = ["negative", "positive"]
[docs] def get_examples(self, data_dir, split): examples = [] label_file = open(os.path.join(data_dir, "{}_labels.txt".format(split)), 'r') labels = [int(x.strip()) for x in label_file.readlines()] with open(os.path.join(data_dir, '{}.txt'.format(split)),'r') as fin: for idx, line in enumerate(fin): text_a = line.strip() example = InputExample(guid=str(idx), text_a=text_a, label=int(labels[idx])) examples.append(example) return examples
@staticmethod def get_test_labels_only(data_dir, dirname): label_file = open(os.path.join(data_dir,dirname,"{}_labels.txt".format('test')),'r') labels = [int(x.strip()) for x in label_file.readlines()] return labels
# class AmazonProcessor(DataProcessor): # """ # `Amazon <https://cs.stanford.edu/people/jure/pubs/reviews-recsys13.pdf>`_ is a Product Review Sentiment Classification dataset. # we use dataset provided by `LOTClass <https://github.com/yumeng5/LOTClass>`_ # Examples: # TODO implement this # """ # def __init__(self): # # raise NotImplementedError # super().__init__() # self.labels = ["bad", "good"] # def get_examples(self, data_dir, split): # examples = [] # label_file = open(os.path.join(data_dir, "{}_labels.txt".format(split)), 'r') # labels = [int(x.strip()) for x in label_file.readlines()] # if split == "test": # logger.info("Sample a mid-size test set for effeciecy, use sampled_test_idx.txt") # with open(os.path.join(self.args.data_dir,self.dirname,"sampled_test_idx.txt"),'r') as sampleidxfile: # sampled_idx = sampleidxfile.readline() # sampled_idx = sampled_idx.split() # sampled_idx = set([int(x) for x in sampled_idx]) # with open(os.path.join(data_dir,'{}.txt'.format(split)),'r') as fin: # for idx, line in enumerate(fin): # if split=='test': # if idx not in sampled_idx: # continue # text_a = line.strip() # example = InputExample(guid=str(idx), text_a=text_a, label=int(labels[idx])) # examples.append(example) # return examples class AmazonProcessor(DataProcessor): """ `Amazon <https://cs.stanford.edu/people/jure/pubs/reviews-recsys13.pdf>`_ is a Product Review Sentiment Classification dataset. we use dataset provided by `LOTClass <https://github.com/yumeng5/LOTClass>`_ Examples: # TODO implement this """ def __init__(self): super().__init__() self.labels = ["bad", "good"] def get_examples(self, data_dir, split): examples = [] label_file = open(os.path.join(data_dir, "{}_labels.txt".format(split)), 'r') labels = [int(x.strip()) for x in label_file.readlines()] with open(os.path.join(data_dir,'{}.txt'.format(split)),'r') as fin: for idx, line in enumerate(fin): text_a = line.strip() example = InputExample(guid=str(idx), text_a=text_a, label=int(labels[idx])) examples.append(example) return examples class YahooProcessor(DataProcessor): """ Yahoo! Answers Topic Classification Dataset Examples: .. code-block:: python from openprompt.data_utils.text_classification_dataset import PROCESSORS base_path = "datasets/TextClassification" """ def __init__(self): super().__init__() self.labels = ["Society & Culture", "Science & Mathematics", "Health", "Education & Reference", "Computers & Internet", "Sports", "Business & Finance", "Entertainment & Music" ,"Family & Relationships", "Politics & Government"] def get_examples(self, data_dir, split): path = os.path.join(data_dir, "{}.csv".format(split)) examples = [] with open(path, encoding='utf8') as f: reader = csv.reader(f, delimiter=',') for idx, row in enumerate(reader): label, question_title, question_body, answer = row text_a = ' '.join([question_title.replace('\\n', ' ').replace('\\', ' '), question_body.replace('\\n', ' ').replace('\\', ' ')]) text_b = answer.replace('\\n', ' ').replace('\\', ' ') example = InputExample(guid=str(idx), text_a=text_a, text_b=text_b, label=int(label)-1) examples.append(example) return examples # class SST2Processor(DataProcessor): # """ # #TODO test needed # """ # def __init__(self): # raise NotImplementedError # super().__init__() # self.labels = ["negative", "positive"] # def get_examples(self, data_dir, split): # examples = [] # path = os.path.join(data_dir,"{}.tsv".format(split)) # with open(path, 'r') as f: # reader = csv.DictReader(f, delimiter='\t') # for idx, example_json in enumerate(reader): # text_a = example_json['sentence'].strip() # example = InputExample(guid=str(idx), text_a=text_a, label=int(example_json['label'])) # examples.append(example) # return examples
[docs]class SST2Processor(DataProcessor): """ `SST-2 <https://nlp.stanford.edu/sentiment/index.html>`_ dataset is a dataset for sentiment analysis. It is a modified version containing only binary labels (negative or somewhat negative vs somewhat positive or positive with neutral sentences discarded) on top of the original 5-labeled dataset released first in `Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank <https://aclanthology.org/D13-1170.pdf>`_ We use the data released in `Making Pre-trained Language Models Better Few-shot Learners (Gao et al. 2020) <https://arxiv.org/pdf/2012.15723.pdf>`_ Examples: .. code-block:: python from openprompt.data_utils.lmbff_dataset import PROCESSORS base_path = "datasets/TextClassification" dataset_name = "SST-2" dataset_path = os.path.join(base_path, dataset_name) processor = PROCESSORS[dataset_name.lower()]() train_dataset = processor.get_train_examples(dataset_path) dev_dataset = processor.get_dev_examples(dataset_path) test_dataset = processor.get_test_examples(dataset_path) assert processor.get_num_labels() == 2 assert processor.get_labels() == ['0','1'] assert len(train_dataset) == 6920 assert len(dev_dataset) == 872 assert len(test_dataset) == 1821 assert train_dataset[0].text_a == 'a stirring , funny and finally transporting re-imagining of beauty and the beast and 1930s horror films' assert train_dataset[0].label == 1 """ def __init__(self): super().__init__() self.labels = ['0', '1']
[docs] def get_examples(self, data_dir, split): path = os.path.join(data_dir, f"{split}.tsv") examples = [] with open(path, encoding='utf-8')as f: lines = f.readlines() for idx, line in enumerate(lines[1:]): linelist = line.strip().split('\t') text_a = linelist[0] label = linelist[1] guid = "%s-%s" % (split, idx) example = InputExample(guid=guid, text_a=text_a, label=self.get_label_id(label)) examples.append(example) return examples
PROCESSORS = { "agnews": AgnewsProcessor, "dbpedia": DBpediaProcessor, "amazon" : AmazonProcessor, "imdb": ImdbProcessor, "sst-2": SST2Processor, "mnli": MnliProcessor, "yahoo": YahooProcessor, }