PaddleHub 支持 ERNIE 进行文本分类
收藏
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pickle
import numpy as np
import pandas as pd
from paddlehub.dataset import InputExample, HubDataset
def random(data):
    """Shuffle *data* and split it round-robin into three parts.

    NOTE(review): this function shadows the stdlib ``random`` module name;
    callers in this file use it by this name, so it is kept unchanged.

    Args:
        data: indexable sequence of samples.

    Returns:
        tuple: ``(train_data, valid_data, test_data)`` holding the items at
        shuffled positions 0, 1, 2 (mod 3) respectively — roughly a
        1/3 : 1/3 : 1/3 split.
    """
    order = list(range(len(data)))
    np.random.shuffle(order)
    buckets = ([], [], [])
    for position, source_index in enumerate(order):
        buckets[position % 3].append(data[source_index])
    return buckets[0], buckets[1], buckets[2]
def read_message():
    """Build ``[text, label]`` pairs for the user/JD matching task and split them.

    On the first run this reads three raw tab-separated tables (user
    profiles, job descriptions, user-job actions), concatenates each
    user's / job's remaining columns into one string, joins user and job
    text through the action table, and caches the resulting pairs in
    ``sets/all_data.pkl``.  Later runs just load the cache.

    Returns:
        tuple: ``(train_data, valid_data, test_data)`` as split by
        ``random()``.
    """
    cache_path = "sets/all_data.pkl"
    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            x_items = pickle.load(f)
    else:
        user_message = pd.read_csv("/home/aistudio/data/data10296/table1_user",
                                   sep="\t")
        jd_message = pd.read_csv("/home/aistudio/data/data10296/table2_jd",
                                 sep="\t")
        match_message = pd.read_csv("/home/aistudio/data/data10296/table3_action",
                                    sep="\t")
        # The two id->text lookup tables were built with duplicated inline
        # loops; factored into a shared helper.
        user_message_index = _concat_columns(user_message)
        jd_message_index = _concat_columns(jd_message)
        x_items = []
        for row in match_message.values.tolist():
            # Skip actions whose user id or job id is absent from its table.
            if row[0] not in user_message_index or row[1] not in jd_message_index:
                continue
            x_item = user_message_index[row[0]] + jd_message_index[row[1]]
            # Label = the three action columns concatenated, e.g. "110".
            y_label = str(row[2]) + str(row[3]) + str(row[4])
            x_items.append([x_item, y_label])
        # Bug fix: the original crashed with FileNotFoundError on the first
        # run when the "sets" directory did not exist yet.
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        with open(cache_path, 'wb') as f:
            pickle.dump(x_items, f)
    return random(x_items)


def _concat_columns(frame):
    """Map each row's first column (the id) to the concatenation of the rest."""
    index = {}
    for row in frame.values.tolist():
        index[row[0]] = "".join(str(field) for field in row[1:])
    return index
def _read_tsv(input_file):
    """Convert an iterable of ``[text, label]`` pairs into ``InputExample``s.

    Bug fix (docs): the original docstring claimed this "reads a tab
    separated value file", but *input_file* is not a file handle — it is the
    in-memory list of ``[text, label]`` pairs produced by ``read_message()``.

    Args:
        input_file: iterable of two-element sequences ``[text_a, label]``.

    Returns:
        list: one ``InputExample`` per pair, with sequential integer guids.
    """
    examples = []
    for seq_id, line in enumerate(input_file):
        examples.append(
            InputExample(guid=seq_id, label=line[1], text_a=line[0]))
    return examples
class DemoDataset(HubDataset):
    """PaddleHub dataset backed by the cached user/JD match pairs.

    The train/dev/test splits come from ``read_message()``; each split is
    converted to ``InputExample`` objects once, up front, so the
    ``get_*_examples`` accessors are plain attribute reads.
    """

    def __init__(self):
        # Placeholder path; the actual data is produced by read_message().
        self.dataset_dir = "path/to/dataset"
        self.train_data, self.valid_data, self.test_data = read_message()
        self._load_train_examples()
        self._load_dev_examples()
        self._load_test_examples()

    def _load_train_examples(self):
        # Materialise once so repeated accessor calls stay cheap.
        self.train_examples = _read_tsv(self.train_data)

    def _load_dev_examples(self):
        self.dev_examples = _read_tsv(self.valid_data)

    def _load_test_examples(self):
        self.test_examples = _read_tsv(self.test_data)

    def get_train_examples(self):
        """Return the training-split examples."""
        return self.train_examples

    def get_dev_examples(self):
        """Return the dev (validation) split examples."""
        return self.dev_examples

    def get_test_examples(self):
        """Return the test-split examples."""
        return self.test_examples

    def get_labels(self):
        """Label vocabulary for the classifier.

        NOTE(review): read_message() builds labels by concatenating three
        raw columns, which could also yield values such as "010" or "001" —
        verify this list really covers every label present in the data.
        """
        return ["000", "100", "110", "111"]

    @property
    def num_labels(self):
        """Number of distinct labels in the dataset."""
        return len(self.get_labels())
# ---- Fine-tune ERNIE for text classification with PaddleHub ----
import paddlehub as hub
# Pretrained ERNIE module, version pinned for reproducibility.
module = hub.Module(name="ernie", version="1.0.2")
dataset = DemoDataset()
# Tokenizes with ERNIE's vocab; sequences truncated/padded to 128 tokens.
reader = hub.reader.ClassifyReader(
dataset=dataset,
vocab_path=module.get_vocab_path(),
max_seq_len=128)
# AdamW-style fine-tuning schedule: linear decay after a 10% warmup.
strategy = hub.AdamWeightDecayStrategy(
weight_decay=0.01,
warmup_proportion=0.1,
learning_rate=1e-5,
lr_scheduler="linear_decay",
optimizer_name="adam")
# NOTE(review): "ernie_turtorial_demo" looks like a typo for "tutorial",
# but renaming the directory would orphan any existing checkpoints.
config = hub.RunConfig(
use_cuda=True,
num_epoch=50,
checkpoint_dir="ernie_turtorial_demo",
batch_size=64,
log_interval=10,
eval_interval=500,
strategy=strategy)
# Build the ERNIE program; trainable=True so fine-tuning updates its weights.
inputs, outputs, program = module.context(
trainable=True, max_seq_len=128)
# Use "pooled_output" for classification tasks on an entire sentence.
pooled_output = outputs["pooled_output"]
# NOTE(review): assumes this feed order matches what ClassifyReader
# produces — confirm against the PaddleHub version in use.
feed_list = [
inputs["input_ids"].name,
inputs["position_ids"].name,
inputs["segment_ids"].name,
inputs["input_mask"].name,
]
# Classification head on top of the pooled sentence embedding.
cls_task = hub.TextClassifierTask(
data_reader=reader,
feature=pooled_output,
feed_list=feed_list,
num_classes=dataset.num_labels,
config=config)
cls_task.finetune_and_eval()
0
收藏
请登录后评论
啥问题这是?
没问题啊就是技术分享啊
学习了
好歹介绍一下主题
能否方便交流一下使用ernie进行英文文本分类的大致流程么
学习一下