今天犯了一个低级错误,记录在此:

pytorch中,对于tensor的操作,如果需要梯度传播,那就不能是inplace的。要注意的是:+=,*=之类的运算符是inplace的操作!!

也就是说

 b += a

会报错,不过如果改成

b = b+a

就不会报错,因为此时b所引用的内存位置已经改变了,即id(b)发生了改变

如下:

>>> import torch
>>> a = torch.randn(3,4)
>>> a.shape
torch.Size([3, 4])
>>> b = torch.randn(3,4)
>>> id(a)
2230921339008
>>> a = a+b
>>> id(a)
2230921340032
>>> a += b
>>> id(a)
2230921340032
>>>

另外,还有一个我以前没注意过的点,那就是python中对于符号的重载,还挺复杂的,需要辨析__add__和__radd__ 和 __iadd__的区别

可部分参考Python:表达式 i += x 与 i = i + x 等价吗?

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注