pytorch报错:one of the variables needed for gradient computation has been modified by。。。
今天犯了一个低级错误,记录在此:
pytorch中,对于tensor的操作,如果需要梯度传播,那就不能是inplace的。要注意的是:+=,*=之类的运算符是inplace的操作!!
也就是说
b += a
会报错,不过如果改成
b = b+a
就不会报错,因为此时b所引用的内存位置已经改变了,即id(b)发生了改变
如下:
>>> import torch
>>> a = torch.randn(3,4)
>>> a.shape
torch.Size([3, 4])
>>> b = torch.randn(3,4)
>>> id(a)
2230921339008
>>> a = a+b
>>> id(a)
2230921340032
>>> a += b
>>> id(a)
2230921340032
>>>
另外,还有一个我以前没注意过的点,那就是python中对于符号的重载,还挺复杂的,需要辨析__add__和__radd__ 和 __iadd__的区别