Text Classification on a Custom Dataset with ERNIE (August 2019)
The first block builds the custom dataset: it joins the user table and the JD (job description) table through the action table, concatenates each matched pair into one text string, combines the three action columns into a single label string, and wraps the result in a HubDataset subclass that PaddleHub can consume.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pickle

import numpy as np
import pandas as pd

from paddlehub.dataset import InputExample, HubDataset


def random(data):
    """Shuffle the samples and split them into train/valid/test thirds."""
    random_order = list(range(len(data)))
    np.random.shuffle(random_order)
    train_data = [data[j] for i, j in enumerate(random_order) if i % 3 == 0]
    valid_data = [data[j] for i, j in enumerate(random_order) if i % 3 == 1]
    test_data = [data[j] for i, j in enumerate(random_order) if i % 3 == 2]
    return train_data, valid_data, test_data


def read_message():
    """Build [text, label] pairs from the three source tables (cached as a pickle)."""
    if not os.path.exists("sets/all_data.pkl"):
        x_items = []
        user_message = pd.read_csv("/home/aistudio/data/data10296/table1_user", sep="\t")
        jd_message = pd.read_csv("/home/aistudio/data/data10296/table2_jd", sep="\t")
        match_message = pd.read_csv("/home/aistudio/data/data10296/table3_action", sep="\t")

        # Index the user table: id -> concatenation of all remaining columns.
        user_message_index = {}
        for i in user_message.values.tolist():
            user_message_str = ''
            for message in i[1:]:
                user_message_str += str(message)
            user_message_index[i[0]] = user_message_str

        # Index the JD table the same way: id -> concatenated text.
        jd_message_index = {}
        for i in jd_message.values.tolist():
            jd_message_str = ''
            for message in i[1:]:
                jd_message_str += str(message)
            jd_message_index[i[0]] = jd_message_str

        # For every (user_id, jd_id, flag1, flag2, flag3) action record, build the
        # input text as user text + JD text and the label as the three flags joined,
        # e.g. "000", "100", "110", "111".
        for i in match_message.values.tolist():
            if i[0] in user_message_index:
                x_item = str(user_message_index[i[0]])
            else:
                continue
            if i[1] in jd_message_index:
                x_item += str(jd_message_index[i[1]])
            else:
                continue
            y_label = str(i[2]) + str(i[3]) + str(i[4])
            c = [x_item, y_label]
            x_items.append(c)

        with open('sets/all_data.pkl', 'wb') as f:
            pickle.dump(x_items, f)
    else:
        with open('sets/all_data.pkl', 'rb') as f:
            x_items = pickle.load(f)

    train_data, valid_data, test_data = random(x_items)
    return train_data, valid_data, test_data


def _read_tsv(input_file):
    """Convert a list of [text, label] pairs into InputExample objects."""
    examples = []
    seq_id = 0
    for line in input_file:
        example = InputExample(guid=seq_id, label=line[1], text_a=line[0])
        seq_id += 1
        examples.append(example)
    return examples


class DemoDataset(HubDataset):
    """Custom dataset exposing train/dev/test InputExamples to PaddleHub."""

    def __init__(self):
        self.dataset_dir = "path/to/dataset"
        self.train_data, self.valid_data, self.test_data = read_message()
        self._load_train_examples()
        self._load_test_examples()
        self._load_dev_examples()

    def _load_train_examples(self):
        self.train_examples = _read_tsv(self.train_data)

    def _load_dev_examples(self):
        self.dev_examples = _read_tsv(self.valid_data)

    def _load_test_examples(self):
        self.test_examples = _read_tsv(self.test_data)

    def get_train_examples(self):
        return self.train_examples

    def get_dev_examples(self):
        return self.dev_examples

    def get_test_examples(self):
        return self.test_examples

    def get_labels(self):
        """Define the label set according to the real dataset."""
        return ["000", "100", "110", "111"]

    @property
    def num_labels(self):
        """Return the number of labels in the dataset."""
        return len(self.get_labels())
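Before fine-tuning, it can help to sanity-check the wrapper. The snippet below is only an illustrative sketch; it assumes the three source tables (or the sets/all_data.pkl cache produced above) are present on the machine, and it simply prints the split sizes and one example.

# Quick sanity check of the custom dataset (illustrative only).
demo = DemoDataset()
print("train examples:", len(demo.get_train_examples()))
print("dev examples:  ", len(demo.get_dev_examples()))
print("test examples: ", len(demo.get_test_examples()))

first = demo.get_train_examples()[0]
print("label :", first.label)                 # one of "000", "100", "110", "111"
print("text_a:", first.text_a[:100], "...")   # first 100 characters of the joined user + JD text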
""" return len(self.get_labels()) import paddlehub as hub module = hub.Module(name="ernie", version="1.0.2") dataset = DemoDataset() reader = hub.reader.ClassifyReader( dataset=dataset, vocab_path=module.get_vocab_path(), max_seq_len=128) strategy = hub.AdamWeightDecayStrategy( weight_decay=0.01, warmup_proportion=0.1, learning_rate=1e-5, lr_scheduler="linear_decay", optimizer_name="adam") config = hub.RunConfig( use_cuda=True, num_epoch=50, checkpoint_dir="ernie_turtorial_demo", batch_size=64, log_interval=10, eval_interval=500, strategy=strategy) inputs, outputs, program = module.context( trainable=True, max_seq_len=128) #对整个句子中的分类任务使用“pooled_output”。 Use "pooled_output" for classification tasks on an entire sentence. pooled_output = outputs["pooled_output"] feed_list = [ inputs["input_ids"].name, inputs["position_ids"].name, inputs["segment_ids"].name, inputs["input_mask"].name, ] cls_task = hub.TextClassifierTask( data_reader=reader, feature=pooled_output, feed_list=feed_list, num_classes=dataset.num_labels, config=config) cls_task.finetune_and_eval()
Wow, this is really nice~~~
I believe this product will change the landscape in China: simple enough, fast enough, and accurate enough.
Very nice, I'm studying it now.
A quick question: I'm running into the following error,
ImportError: cannot import name 'HubDataset' from 'paddlehub.dataset' (/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddlehub/dataset/__init__.py)
Nice, nice.
Digging up an old post. Paddle has released a new version; I wonder whether it affects Hub.
This article is really useful. I was just looking into how to read a custom dataset with Hub...
Useful for me, it's solving my problem right now.
Thanks, bumping this.
Nice work, I'll use it as a reference; I've also been working on this topic recently.