时效性:2023-07-26。transformers
库版本:4.29.0
~ 4.31.0
GitHub 关于此问题有讨论。
transformers.models.t5.modeling_t5.T5Attention
有方法prune_heads(heads_to_prune)
,接收一个
List heads_to_prune
,内容为需要剪枝的 head 的
index。使用后,这些 heads 将被直接移除,而不是 mask 成
0。因此该方法能确实地缩小模型。
剪枝后,进行
inference,有时会产生报错:IndexError: index 2 is out of bounds for dimension 0 with size 2
。其中数字
2 可能为任意数字。
报错发生在modeling_t5.py:554
。我在代码中加入了try-except
以打印附近的问题:
if self.pruned_heads:
mask = torch.ones(position_bias.shape[1])
try: # 新添加
mask[list(self.pruned_heads)] = 0 # 报错行
except BaseException: # 新添加
print(mask) # 新添加
print(mask.shape) # 新添加
print(self.pruned_heads) # 新添加
raise BaseException("here") # 新添加
position_bias_masked = position_bias[:, mask.bool()]
else:
position_bias_masked = position_bias
打印结果如下:
...
tensor([0., 0.])
torch.Size([2])
{0, 1, 2, 3, 5, 7}
...
Traceback (most recent call last):
...
可见,mask
的长度仅有 2,也即 index 只能为 0 或
1,因此读到pruned_heads
的2
后报错。但,该
attention 层应当有 8 个
heads。回看代码,mask
建立时长度与position_bias
相同。因此,是positional_bias
侧不能正确地处理包含剪枝的情况。
原
GitHub Issue的提出者发起了一个Pull
Request以处理此问题。解决方法为:在这段代码上方一点的地方,即modeling_t5.py:536
,有下列代码:
position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
)
torch.zeros()
的第一个元组中的第二个参数只计算了self.n_heads
,即
head
数目。如果有剪枝,该数据无法对应上。将其增加self.pruned_heads
的数目即可防止
out of bound 错误。修改如下:
position_bias = torch.zeros(
(1, self.n_heads + len(self.pruned_heads), real_seq_length, key_length), device=scores.device, dtype=scores.dtype
)
说实话我也不是很确定这是不是修好了,不过至少 inference 没有出问题。