Using ERNIE for Text Classification on a Custom Dataset (August 2019)
Tags: AI Studio platform usage, Q&A dataset
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pickle

import numpy as np
import pandas as pd
# HubDataset is the PaddleHub 1.x base class for custom datasets
# (removed in later releases; see the ImportError reported in the comments below).
from paddlehub.dataset import InputExample, HubDataset


def random_split(data):
    """Shuffle the data and split it into three equal parts (train/valid/test)."""
    random_order = list(range(len(data)))
    np.random.shuffle(random_order)
    # Distribute the shuffled indices round-robin: 1/3 train, 1/3 valid, 1/3 test.
    train_data = [data[j] for i, j in enumerate(random_order) if i % 3 == 0]
    valid_data = [data[j] for i, j in enumerate(random_order) if i % 3 == 1]
    test_data = [data[j] for i, j in enumerate(random_order) if i % 3 == 2]
    return train_data, valid_data, test_data
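
As a quick sanity check, here is a minimal sketch with hypothetical toy data showing that the three splits partition the input round-robin:

# Hypothetical toy data, purely for illustration.
toy = [["text %d" % k, "000"] for k in range(10)]
tr, va, te = random_split(toy)
print(len(tr), len(va), len(te))  # 4 3 3
assert sorted(tr + va + te) == sorted(toy)  # every item lands in exactly one split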


def read_message():
    """Join the user, job-description and action tables into [text, label] pairs,
    caching the result so the CSVs are only parsed once."""
    if not os.path.exists("sets/all_data.pkl"):
        x_items = []
        user_message = pd.read_csv("/home/aistudio/data/data10296/table1_user",
                                   sep="\t")
        jd_message = pd.read_csv("/home/aistudio/data/data10296/table2_jd",
                                 sep="\t")
        match_message = pd.read_csv("/home/aistudio/data/data10296/table3_action",
                                    sep="\t")
        # Index each user by id, concatenating all remaining columns into one string.
        user_message_index = {}
        for row in user_message.values.tolist():
            user_message_index[row[0]] = ''.join(str(field) for field in row[1:])
        # Do the same for each job description.
        jd_message_index = {}
        for row in jd_message.values.tolist():
            jd_message_index[row[0]] = ''.join(str(field) for field in row[1:])
        # For every (user, jd) action record, build the input text by concatenating
        # the user profile and the job description; skip records missing either side.
        for row in match_message.values.tolist():
            if row[0] not in user_message_index or row[1] not in jd_message_index:
                continue
            x_item = user_message_index[row[0]] + jd_message_index[row[1]]
            # The label joins the three action flags into a string such as "100".
            y_label = str(row[2]) + str(row[3]) + str(row[4])
            x_items.append([x_item, y_label])
        os.makedirs("sets", exist_ok=True)  # pickle.dump fails if the dir is missing
        with open('sets/all_data.pkl', 'wb') as f:
            pickle.dump(x_items, f)
    else:
        with open('sets/all_data.pkl', 'rb') as f:
            x_items = pickle.load(f)
    return random_split(x_items)
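
For orientation, every element produced by read_message() is a two-item list: the concatenated user-plus-JD text and the three-flag label string. The printed value below is invented for illustration:

train_data, valid_data, test_data = read_message()
print(train_data[0])  # e.g. ['<user fields><jd fields>', '100'] (illustrative)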


def _read_tsv(input_file):
    """Wrap already-parsed [text, label] pairs as PaddleHub InputExamples.

    Despite the name, this does not read a file; it consumes the lists
    produced by read_message().
    """
    examples = []
    for seq_id, line in enumerate(input_file):
        examples.append(InputExample(guid=seq_id, label=line[1], text_a=line[0]))
    return examples


class DemoDataset(HubDataset):
    """Custom dataset that exposes train/dev/test InputExamples to PaddleHub."""

    def __init__(self):
        # Placeholder; the data actually comes from read_message() above.
        self.dataset_dir = "path/to/dataset"
        self.train_data, self.valid_data, self.test_data = read_message()
        self._load_train_examples()
        self._load_test_examples()
        self._load_dev_examples()

    def _load_train_examples(self):
        self.train_examples = _read_tsv(self.train_data)

    def _load_dev_examples(self):
        self.dev_examples = _read_tsv(self.valid_data)

    def _load_test_examples(self):
        self.test_examples = _read_tsv(self.test_data)

    def get_train_examples(self):
        return self.train_examples

    def get_dev_examples(self):
        return self.dev_examples

    def get_test_examples(self):
        return self.test_examples

    def get_labels(self):
        """Define the label set according to the real dataset: here, the four
        observed combinations of the three concatenated action flags."""
        return ["000", "100", "110", "111"]

    @property
    def num_labels(self):
        """
        Return the number of labels in the dataset.
        """
        return len(self.get_labels())
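
A quick smoke test of the wrapper (the printed counts are illustrative, not from a real run):

dataset = DemoDataset()
print(dataset.num_labels)                 # 4
print(len(dataset.get_train_examples()))  # about one third of all pairs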



import paddlehub as hub
module = hub.Module(name="ernie", version="1.0.2")
dataset = DemoDataset()


reader = hub.reader.ClassifyReader(
    dataset=dataset,
    vocab_path=module.get_vocab_path(),
    max_seq_len=128)

strategy = hub.AdamWeightDecayStrategy(
    weight_decay=0.01,
    warmup_proportion=0.1,
    learning_rate=1e-5,
    lr_scheduler="linear_decay",
    optimizer_name="adam")

config = hub.RunConfig(
    use_cuda=True,
    num_epoch=50,
    checkpoint_dir="ernie_tutorial_demo",
    batch_size=64,
    log_interval=10,
    eval_interval=500,
    strategy=strategy)
inputs, outputs, program = module.context(
    trainable=True, max_seq_len=128)

# Use "pooled_output" for classification tasks on an entire sentence.
# ("sequence_output" would be used instead for token-level tasks.)
pooled_output = outputs["pooled_output"]

feed_list = [
    inputs["input_ids"].name,
    inputs["position_ids"].name,
    inputs["segment_ids"].name,
    inputs["input_mask"].name,
]

cls_task = hub.TextClassifierTask(
    data_reader=reader,
    feature=pooled_output,
    feed_list=feed_list,
    num_classes=dataset.num_labels,
    config=config)
cls_task.finetune_and_eval()
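
After fine-tuning, the same task object can serve predictions. This is a sketch following the PaddleHub 1.x text-classification demo pattern; the exact shape of the predict() results can differ between 1.x releases, and the input text here is hypothetical:

import numpy as np

data = [["<user fields><jd fields>"]]  # hypothetical input row
run_states = cls_task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
labels = dataset.get_labels()
for batch_result in results:
    # run_results[0] holds the class probabilities; argmax picks the label index
    for idx in np.argmax(batch_result[0], axis=1):
        print(labels[idx])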
All comments (10)
AIStudio810261
#2, replied 2019-08

Wow, very nice~~~
ygq
#3, replied 2019-08

I believe this product will change the landscape in China: simple enough, fast enough, and accurate enough.
宋老实
#4, replied 2019-08

Very nice; I'm studying it.
wangwei8638
#5, replied 2020-03

A question: I'm running into the following error:

ImportError: cannot import name 'HubDataset' from 'paddlehub.dataset' (/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddlehub/dataset/__init__.py)
水水水的老师
#6, replied 2020-03

Nice, nice.
AIStudio810258
#7, replied 2020-03

An old thread dug back up. Paddle has released a new version; I don't know whether it affects hub.
AIStudio810258
#8, replied 2020-04

This post is really useful. I'm just looking into how to read a custom dataset with hub...
AIStudio810258
#9, replied 2020-04 (quoting #8)

It works for me; it's solving my problem.
AIStudio810258
#10, replied 2020-04 (quoting #6)

Thanks for bumping this.
micahvista
#11, replied 2020-04

Support! I'll use this as a reference; I've been working on this topic recently too.