Preprocessing a Custom Text Dataset with torchtext 🌞

1. Setup

Import the necessary packages:

import torchdata.datapipes as dp
import torchtext.transforms as T
import spacy
from torchtext.vocab import build_vocab_from_iterator

Download the tokenization models:
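
The models can be fetched with the spacy CLI; the exact versions pinned in the next step are an assumption inferred from the directory names being loaded there:

python -m spacy download en_core_web_md
python -m spacy download zh_core_web_md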

Load the models:

eng = spacy.load("en_core_web_md-3.7.1")  # Load the English model to tokenize English text
zh = spacy.load("zh_core_web_md-3.7.0")   # Load the Chinese model to tokenize Chinese text

Load the dataset:

FILE_PATH = 'cmn.txt'
data_pipe = dp.iter.IterableWrapper([FILE_PATH])
data_pipe = dp.iter.FileOpener(data_pipe, mode='rb')
data_pipe = data_pipe.parse_csv(skip_lines=0, delimiter='\t', as_tuple=True)
  1. On line 2, we create an iterable of file names
  2. On line 3, we pass the iterable to FileOpener, which opens the file in read mode
  3. On line 4, we call a function to parse the file, which again returns an iterable of tuples representing each row of the tab-separated file

A DataPipe can be thought of as a dataset object on which we can perform various operations.
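
To verify that the pipeline works, we can peek at one raw sample. The three-column layout shown in the comment is an assumption based on the Tatoeba-style cmn.txt file, whose third column holds attribution information:

for sample in data_pipe:
    print(sample)  # e.g. ('Hi.', '嗨。', '<attribution info>') — hypothetical sample
    break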

def removeAttribution(row):
    # Keep only the first two fields: the source and target sentences
    return row[:2]


data_pipe = data_pipe.map(removeAttribution)

The map function applies a given function to every element of data_pipe.
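
After the map, each sample should contain only the sentence pair (a quick check; the printed pair is illustrative):

for sample in data_pipe:
    print(sample)  # e.g. ('Hi.', '嗨。') — attribution column removed
    break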

Tokenization:

def engTokenize(text):
    # Tokenize an English sentence into a list of token strings
    engTokenList = [token.text for token in eng.tokenizer(text)]
    return engTokenList


def zhTokenize(text):
    # Tokenize a Chinese sentence into a list of token strings
    return [token.text for token in zh.tokenizer(text)]
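
A quick sanity check of the two tokenizers (the sample sentences are mine; the Chinese segmentation shown is illustrative, since it depends on the model's word segmenter):

print(engTokenize("Have a good day!"))
# ['Have', 'a', 'good', 'day', '!']
print(zhTokenize("今天天气很好。"))
# e.g. ['今天', '天气', '很', '好', '。']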

2. Building the Vocabulary

def getTokens(data_iter, place):
    # Yield token lists for the source (place == 0) or target (place == 1) language
    for english, chinese in data_iter:
        if place == 0:
            yield engTokenize(english)
        else:
            yield zhTokenize(chinese)


source_vocab = build_vocab_from_iterator(
    getTokens(data_pipe, 0),
    min_freq=2,
    specials=['<pad>', '<sos>', '<eos>', '<unk>'],
    special_first=True
)
source_vocab.set_default_index(source_vocab['<unk>'])

target_vocab = build_vocab_from_iterator(
    getTokens(data_pipe, 1),
    min_freq=2,
    specials=['<pad>', '<sos>', '<eos>', '<unk>'],
    special_first=True
)
target_vocab.set_default_index(target_vocab['<unk>'])

print(source_vocab.get_itos()[:9])

source_vocab.get_itos() returns a list in which each position holds the token at that vocabulary index. Because special_first=True, the special tokens occupy the first indices: <pad>=0, <sos>=1, <eos>=2, <unk>=3; set_default_index makes any out-of-vocabulary token map to <unk>.
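
The vocabulary object can also be called directly to convert tokens to indices. The special-token indices below follow from special_first=True; indices of ordinary words depend on the corpus:

print(source_vocab(['<pad>', '<sos>', '<eos>', '<unk>']))  # [0, 1, 2, 3]
print(source_vocab(['word-not-in-vocab']))                 # [3] — falls back to <unk>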

3. Numericalizing Sentences with the Vocabulary

def getTransform(vocab):
    # Map tokens to indices, then add <sos> (index 1) at the beginning
    # and <eos> (index 2) at the end of each sequence
    text_transform = T.Sequential(
        T.VocabTransform(vocab=vocab),
        T.AddToken(1, begin=True),
        T.AddToken(2, begin=False)
    )
    return text_transform


def applyTransform(sequence_pair):
    # Tokenize and numericalize both sides of a sentence pair
    return (
        getTransform(source_vocab)(engTokenize(sequence_pair[0])),
        getTransform(target_vocab)(zhTokenize(sequence_pair[1]))
    )


data_pipe = data_pipe.map(applyTransform)
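
One transformed sample should now be a pair of index sequences, each starting with 1 (<sos>) and ending with 2 (<eos>); the indices in between are hypothetical, as they depend on the corpus:

temp = next(iter(data_pipe))
print(temp)  # e.g. ([1, 616, 9, 2], [1, 739, 7, 2])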

4. Making Batches

Typically, we train models in batches. When working with sequence-to-sequence models, it is recommended to keep the lengths of the sequences in a batch similar. For that, we will use the bucketbatch function of data_pipe.

def sortBucket(bucket):
    # Sort a bucket of sentence pairs by source length, then by target length
    return sorted(bucket, key=lambda x: (len(x[0]), len(x[1])))


data_pipe = data_pipe.bucketbatch(
    batch_size=4, batch_num=5, bucket_num=1,
    use_in_batch_shuffle=False, sort_key=sortBucket
)
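
Here batch_size is the number of sentence pairs per batch, batch_num is the number of batches to keep in each bucket, and bucket_num is the number of buckets kept in the pool for shuffling; sort_key sorts each bucket so that sequences of similar length land in the same batch.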

A batch in data_pipe is currently of the form [(X_1, y_1), (X_2, y_2), (X_3, y_3), (X_4, y_4)], where X_i and y_i are a source and target sequence.

We now convert it into the form ((X_1, X_2, X_3, X_4), (y_1, y_2, y_3, y_4)). For this we write a small function:

def separateSourceTarget(sequence_pairs):
    # Turn a batch of (source, target) pairs into (all sources, all targets)
    sources, targets = zip(*sequence_pairs)
    return sources, targets


data_pipe = data_pipe.map(separateSourceTarget)

5. Padding

def applyPadding(pair_of_sequences):
    # Convert each batch of index sequences into a tensor,
    # padding shorter sequences with 0 (the index of <pad>)
    return (T.ToTensor(0)(list(pair_of_sequences[0])), T.ToTensor(0)(list(pair_of_sequences[1])))


data_pipe = data_pipe.map(applyPadding)
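
T.ToTensor with a padding_value pads every sequence in a batch to the length of the longest one before stacking them into a tensor. A minimal illustration with arbitrary values:

import torchtext.transforms as T

to_tensor = T.ToTensor(padding_value=0)
print(to_tensor([[1, 5, 2], [1, 5, 6, 2]]))
# tensor([[1, 5, 2, 0],
#         [1, 5, 6, 2]])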

source_index_to_string = source_vocab.get_itos()
target_index_to_string = target_vocab.get_itos()


def showSomeTransformedSentences(data_pipe):
    """
    Show what the sentences look like after applying all transforms.
    Here we print the actual words instead of the corresponding indices.
    """
    for sources, targets in data_pipe:
        if sources[0][-1] != 0:
            continue  # Just to visualize padding of shorter sentences
        for i in range(4):
            source = ""
            for token in sources[i]:
                source += " " + source_index_to_string[token]
            target = ""
            for token in targets[i]:
                target += " " + target_index_to_string[token]
            print(f"Source: {source}")
            print(f"Target: {target}")
        break


showSomeTransformedSentences(data_pipe)

The output:

Source:  <sos> <unk> ! <eos> <pad>
Target:  <sos> 完美 ! <eos>
Source:  <sos> Hold on . <eos>
Target:  <sos> 坚持 。 <eos>
Source:  <sos> See you . <eos>
Target:  <sos> 再见 ! <eos>
Source:  <sos> Shut up ! <eos>
Target:  <sos> <unk> ! <eos>
