时效性:2023-07-26。transformers库版本:4.29.0 ~ 4.31.0

GitHub 关于此问题有讨论

transformers.models.t5.modeling_t5.T5Attention有方法prune_heads(heads_to_prune),接收一个 List heads_to_prune,内容为需要剪枝的 head 的 index。使用后,这些 heads 将被直接移除,而不是 mask 成 0。因此该方法能确实地缩小模型。

剪枝后,进行 inference,有时会产生报错:IndexError: index 2 is out of bounds for dimension 0 with size 2。其中数字 2 可能为任意数字。

报错发生在modeling_t5.py:554。我在代码中加入了try-except以打印附近的问题:

if self.pruned_heads:
	mask = torch.ones(position_bias.shape[1])
	try:   # 新添加
		mask[list(self.pruned_heads)] = 0    # 报错行
	except BaseException:   # 新添加
		print(mask)   # 新添加
		print(mask.shape)   # 新添加
		print(self.pruned_heads)   # 新添加
		raise BaseException("here")   # 新添加
	position_bias_masked = position_bias[:, mask.bool()]
else:
	position_bias_masked = position_bias

打印结果如下:

...
tensor([0., 0.])
torch.Size([2])
{0, 1, 2, 3, 5, 7}
...
Traceback (most recent call last):
...

可见,mask的长度仅有 2,也即 index 只能为 0 或 1,因此读到pruned_heads2后报错。但,该 attention 层应当有 8 个 heads。回看代码,mask建立时长度与position_bias相同。因此,是positional_bias侧不能正确地处理包含剪枝的情况。

原 GitHub Issue的提出者发起了一个Pull Request以处理此问题。解决方法为:在这段代码上方一点的地方,即modeling_t5.py:536,有下列代码:

position_bias = torch.zeros(
	(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
)

torch.zeros()的第一个元组中的第二个参数只计算了self.n_heads,即 head 数目。如果有剪枝,该数据无法对应上。将其增加self.pruned_heads的数目即可防止 out of bound 错误。修改如下:

position_bias = torch.zeros(
	(1, self.n_heads + len(self.pruned_heads), real_seq_length, key_length), device=scores.device, dtype=scores.dtype
)

说实话我也不是很确定这是不是修好了,不过至少 inference 没有出问题。