Pytorch debug经验之RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

有一次,我在自己写的关于“神经组合优化”的项目中遭遇报错:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [CUDABoolType [1024, 21]] is at version 139; expected version 138 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

此报错令人一头雾水,因为抛出异常的位置位于loss.backward(),显然跟真正出问题的地方相去甚远,而其他有价值的信息就只有关于问题张量的数据类型与形状,所以想要进行问题定位还是比较困难的。

根据异常的提示,我在程序一开始执行了torch.autograd.set_detect_anomaly(True),于是报错变成了:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [CUDABoolType [1024, 21]] is at version 139; expected version 138 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

可以看到除了获得了Gook luck的祝福外,有价值的信息并未增加。不过好在根据这少量的提示,我能看出问题张量为bool类型,其中1024是扩展的batch数,21是客户数量,从而确定了问题张量与mask有关,缩小了检查范围,最终成功定位并消除了问题。

下面我将通过简单的例子说明为何会遇到这个问题以及为了避免这个问题最好注意什么。

import torch
from torch import nn


class A(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.x = nn.Parameter(torch.tensor([1.0]))

    def forward(self) -> torch.Tensor:
        t = torch.tensor([2.0])
        y = self.x * t
        t[:] = 3.0
        return y


y = A()()
y.backward()

运行此程序,将会得到报错:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

可以看出,这正是我在实际项目中遇到的错误,我们分析一下反向传播的流程:在第19行,我们要求dy/dx,根据第13行的计算可以看出,dy/dx=t,所以梯度值应该等于第12行为t赋的值2,可注意在第14行,我们对t进行了就地修改,此时程序再去求梯度的话,会看到t是3,但如果认为dy/dx就是3,那就错误了,因为实际的梯度值是2,可由于我们对t进行了就地修改,所以程序无法获取正确的梯度值,因此会报错。

如果我们把第14行改为t = 3.0,就不会发生此报错,因为此时只是让引用t指向了其他对象,而原本的那个值为2.0的张量还在内存中,被torch框架内部的计算图所引用着,它没有被就地修改,就不会发生错误。

建议

其实在实际项目中,我们最容易犯此类错误的情景是:

1,一个变量x在某个网络A的forward函数内创建,参与了前向传播,然后调用了另一个网络B的forward函数并将x传入,但在B.forward里x被就地修改了。

2,或者在B.forward中x参与了前向传播,之后从B.forward返回到A.forward后,x又在A.forward中被就地修改了。

这些都很难被注意到,我在实际项目也正是犯了2的错误。

想避免这类错误,首先是尽量不要在其他forward函数中就地修改传入参数中的张量,如果实在是需要修改传入的张量,那么就先复制一份,比如在B.forward中,首先进行x = x.clone(),这样可以避免1的错误。

这样可以将问题限制在A.forward内部,然后再去检查是否犯了类似示例程序中第14行的错误,去消除掉对应的就地修改操作,或者在使用此张量前先复制一份就好了。

知乎上有一篇文章也讨论了此问题,不过作者并没有详细说明为什么不能就地修改,但我还是把链接放这供君参考:https://zhuanlan.zhihu.com/p/115808980

留下评论