首页 PaddleHub 帖子详情
PaddleHub 支持 ERNIE 进行文本分类
收藏
快速回复
PaddleHub 问答预训练模型 3200 6
PaddleHub 支持 ERNIE 进行文本分类
收藏
快速回复
PaddleHub 问答预训练模型 3200 6
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import os
import pickle
 
import numpy as np
import pandas as pd
from paddlehub.dataset import InputExample, HubDataset
 
 
def random(data, seed=None):
    """Shuffle *data* and deal it round-robin into train/valid/test splits.

    After a random permutation, items at shuffled positions 0, 3, 6, ... go
    to the training split, positions 1, 4, 7, ... to validation, and
    positions 2, 5, 8, ... to test, so each split receives about one third.

    Note: this function's name shadows the stdlib ``random`` module inside
    this file; the name is kept so existing callers keep working.

    Args:
        data: An indexable sequence (e.g. a list) of samples.
        seed: Optional integer seed. ``None`` (the default) keeps the
            original behavior of shuffling with NumPy's global RNG; an
            integer yields a reproducible split and leaves the global RNG
            state untouched.

    Returns:
        A ``(train_data, valid_data, test_data)`` tuple of lists.
    """
    random_order = list(range(len(data)))
    if seed is None:
        np.random.shuffle(random_order)
    else:
        np.random.RandomState(seed).shuffle(random_order)
    shuffled = [data[j] for j in random_order]
    # Stride slices are equivalent to filtering shuffled positions by i % 3.
    return shuffled[0::3], shuffled[1::3], shuffled[2::3]
 
 
def _index_by_first_column(frame):
    """Map each row's first column to the concatenation of its other columns.

    Every remaining cell is passed through ``str()`` (e.g. a NaN cell becomes
    the text ``"nan"``) and the pieces are joined with no separator. If a
    first-column value repeats, the last row wins.
    """
    return {
        row[0]: "".join(str(value) for value in row[1:])
        for row in frame.values.tolist()
    }


def read_message(data_dir="/home/aistudio/data/data10296",
                 cache_path="sets/all_data.pkl"):
    """Build (or load from cache) the ``[text, label]`` samples and split them.

    When *cache_path* does not exist, the three tab-separated tables under
    *data_dir* are joined. For every row of ``table3_action`` whose column-0
    key appears in ``table1_user`` and whose column-1 key appears in
    ``table2_jd``, the text is the matching ``table1_user`` fields followed
    by the matching ``table2_jd`` fields. The label is columns 2, 3 and 4 of
    the action row concatenated as strings. Rows with an unmatched key are
    skipped. The joined list is pickled to *cache_path*, and later calls
    load that file instead of re-reading the tables (delete it to rebuild).

    Args:
        data_dir: Directory containing ``table1_user``, ``table2_jd`` and
            ``table3_action``.
        cache_path: Pickle file used to cache the joined samples. Its parent
            directory is created if it is missing.

    Returns:
        ``(train_data, valid_data, test_data)`` as produced by ``random``.
        The shuffle is redrawn on every call, even on a cache hit.
    """
    if not os.path.exists(cache_path):
        user_message_index = _index_by_first_column(
            pd.read_csv(os.path.join(data_dir, "table1_user"), sep="\t"))
        jd_message_index = _index_by_first_column(
            pd.read_csv(os.path.join(data_dir, "table2_jd"), sep="\t"))
        match_message = pd.read_csv(os.path.join(data_dir, "table3_action"),
                                    sep="\t")
        x_items = []
        for row in match_message.values.tolist():
            user_key, jd_key = row[0], row[1]
            # Skip actions referencing a user or job row we have no text for.
            if (user_key not in user_message_index
                    or jd_key not in jd_message_index):
                continue
            x_item = user_message_index[user_key] + jd_message_index[jd_key]
            y_label = str(row[2]) + str(row[3]) + str(row[4])
            x_items.append([x_item, y_label])
        # open(..., 'wb') fails if the cache's directory does not exist yet.
        cache_dir = os.path.dirname(cache_path)
        if cache_dir and not os.path.isdir(cache_dir):
            os.makedirs(cache_dir)
        with open(cache_path, 'wb') as f:
            pickle.dump(x_items, f)
    else:
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point cache_path at a file this script wrote itself.
        with open(cache_path, 'rb') as f:
            x_items = pickle.load(f)
    return random(x_items)
 
 
def _read_tsv(input_file):
    """Convert in-memory ``[text, label]`` pairs into ``InputExample`` objects.

    Despite its name, this helper does not read a file. *input_file* is an
    iterable of sequences such as the splits returned by ``read_message``,
    where element 0 is the text and element 1 is the label string.

    Args:
        input_file: Iterable of ``[text, label]`` pairs.

    Returns:
        A list of ``InputExample`` whose ``guid`` is the 0-based position of
        the pair in *input_file*.
    """
    return [
        InputExample(guid=seq_id, label=line[1], text_a=line[0])
        for seq_id, line in enumerate(input_file)
    ]
 
 
class DemoDataset(HubDataset):
    """PaddleHub dataset backed by the joined user / job / action tables.

    Construction calls ``read_message()`` once to get the train/valid/test
    splits of ``[text, label]`` pairs. It then converts each split into an
    ``InputExample`` list exposed through the ``get_*_examples`` accessors.
    In this script those accessors feed ``hub.reader.ClassifyReader``.
    """

    def __init__(self):
        # Placeholder path; nothing in this class reads it (base-class use
        # unverified).
        self.dataset_dir = "path/to/dataset"
        splits = read_message()
        self.train_data, self.valid_data, self.test_data = splits
        # Conversion order (train, then test, then dev) matches the original.
        self._load_train_examples()
        self._load_test_examples()
        self._load_dev_examples()

    def _load_train_examples(self):
        """Convert the raw training pairs into ``self.train_examples``."""
        examples = _read_tsv(self.train_data)
        self.train_examples = examples

    def _load_dev_examples(self):
        """Convert the raw validation pairs into ``self.dev_examples``."""
        examples = _read_tsv(self.valid_data)
        self.dev_examples = examples

    def _load_test_examples(self):
        """Convert the raw test pairs into ``self.test_examples``."""
        examples = _read_tsv(self.test_data)
        self.test_examples = examples

    def get_train_examples(self):
        """Return the training ``InputExample`` list."""
        examples = self.train_examples
        return examples

    def get_dev_examples(self):
        """Return the validation ``InputExample`` list."""
        examples = self.dev_examples
        return examples

    def get_test_examples(self):
        """Return the test ``InputExample`` list."""
        examples = self.test_examples
        return examples

    def get_labels(self):
        """Return the label vocabulary.

        Labels are built in ``read_message`` by concatenating three columns
        of the action table; edit this list to match the combinations that
        actually occur in the real dataset.
        """
        labels = ["000", "100", "110", "111"]
        return labels

    @property
    def num_labels(self):
        """Number of labels in the dataset, i.e. ``len(self.get_labels())``."""
        labels = self.get_labels()
        return len(labels)
 
 
 
# Fine-tuning script: trains a text classifier on DemoDataset on top of the
# pretrained PaddleHub "ernie" module.
# NOTE(review): these statements run at import time (there is no
# `if __name__ == "__main__":` guard), so importing this file starts training.
import paddlehub as hub
# Pretrained ERNIE module, pinned to version 1.0.2.
module = hub.Module(name="ernie", version="1.0.2")
# Builds the train/dev/test InputExample lists via read_message().
dataset = DemoDataset()


# Reader that tokenizes the dataset with the ERNIE module's vocabulary.
# max_seq_len here should stay in sync with module.context(max_seq_len=...)
# below (both are 128).
reader = hub.reader.ClassifyReader(
    dataset=dataset,
    vocab_path=module.get_vocab_path(),
    max_seq_len=128)

# Adam-based fine-tuning strategy with weight decay and a linear-decay LR
# schedule; warmup_proportion presumably is the fraction of steps used for
# LR warmup (TODO confirm against the PaddleHub docs).
strategy = hub.AdamWeightDecayStrategy(
    weight_decay=0.01,
    warmup_proportion=0.1,
    learning_rate=1e-5,
    lr_scheduler="linear_decay",
    optimizer_name="adam")

# Run configuration. use_cuda=True requires a GPU. The checkpoint_dir string
# ("turtorial" is misspelled) is left unchanged so existing checkpoints are
# still found.
config = hub.RunConfig(
    use_cuda=True,
    num_epoch=50,
    checkpoint_dir="ernie_turtorial_demo",
    batch_size=64,
    log_interval=10,
    eval_interval=500,
    strategy=strategy)
# trainable=True lets fine-tuning update the ERNIE parameters as well.
inputs, outputs, program = module.context(
    trainable=True, max_seq_len=128)

# Use "pooled_output" for classification tasks on an entire sentence.
pooled_output = outputs["pooled_output"]

# Names of the input tensors the reader feeds into the program.
# NOTE(review): the order presumably must match the reader's output order;
# confirm against ClassifyReader.
feed_list = [
    inputs["input_ids"].name,
    inputs["position_ids"].name,
    inputs["segment_ids"].name,
    inputs["input_mask"].name,
]

# Classification head on top of pooled_output, with one class per label in
# DemoDataset.get_labels().
cls_task = hub.TextClassifierTask(
    data_reader=reader,
    feature=pooled_output,
    feed_list=feed_list,
    num_classes=dataset.num_labels,
    config=config)
# Train for config.num_epoch epochs and evaluate on the dev split
# (presumably every eval_interval steps; verify in the PaddleHub docs).
cls_task.finetune_and_eval()
0
收藏
回复
全部评论(6)
时间顺序
AIStudio810261
#2 回复于2019-08

啥问题这是?

0
回复
ygq
#3 回复于2019-08

没问题啊就是技术分享啊

 

0
回复
宋老实
#4 回复于2019-08

学习了

0
回复
w
wangwei8638
#5 回复于2019-08

好歹介绍一下主题

0
回复
院长灿爷
#6 回复于2021-04

能否方便交流一下使用ernie进行英文文本分类的大致流程么

0
回复
jsdbzcm
#7 回复于2021-05

学习一下

 

0
回复
需求/bug反馈?一键提issue告诉我们
发现bug?如果您知道修复办法,欢迎提pr直接参与建设飞桨~
在@后输入用户全名并按空格结束,可艾特全站任一用户