模型
模型结构和数据相对比较简单,核心在于mask部分
mask表示当前字只和它前面的字有关系同时预测的也是下一个字
例如
‘’’
我是你爹
“是” 这个字与我字相关 “你” 字和”我是”两个字有关
‘’’
所以mask才是一个与token长度一样三角矩阵,用这个矩阵是mask掉关系矩阵
中的某些关系
其他的就是Transformer结构都一样
训练数据构建
比较简单的数据
比如原始的一句话 “我是你爹”
input 部分就是”我是你”
label 部分就是”是你的爹”
‘’’
def collate_fn(batch):
input_ids = []
label_ids = []
l = []
for i in batch:
l.append(len(i[‘input_ids’]))
l.append(len(i[‘label_ids’]))
length = max(l)
for i in batch:
input_id = i[‘input_ids’]
label_id = i[‘label_ids’]
input_id = input_id + [0] * (length - len(input_id))
label_id = label_id + [0] * (length - len(label_id))
input_ids.append(input_id)
label_ids.append(label_id)
return {‘input_ids’: torch.tensor(input_ids), ‘label_ids’: torch.tensor(label_ids)}
class TextData(Dataset):
def init(self, path, config=None):
super(TextData, self).init()
#print(path)
self.root = path
self.config = config
self.data = []
with open(path)as ff:
for ll in ff:
#print(ll)
try:
data = ll.strip().split(‘\t’)[1]
self.data.append(data)
except:
pass
self.tokener = tokenization.BertTokenizer(‘./vocab.txt’)
def getitem(self, index):
data = self.tokener.encode(self.data[index])[:(self.config.n_positions-2)]
data = [101] + data + [102]
input_ids = data[:-1]
label_ids = data[1:]
return {‘input_ids’: input_ids, ‘label_ids’: label_ids}
def __len__(self):
return len(self.data) '''
模型结构
1、TokenEmbedding+PosEmbedding
2、PaddingMask + FeatureMask
3、LayerNorm
4、Attention:QK+Mask 和V计算
5、计算输入输出残差
6、LayerNorm
7、MLP
8、计算残差
重复(345678)12次