
Error message

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Method 1

Most answers found online suggest passing retain_graph=True to backward():

loss.backward(retain_graph=True)

However, after trying this I still ran into other errors. In a loop like the one below, each iteration's output still references the graphs of all previous iterations, so retaining the graph likely only defers the problem.
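To see why the error fires in the first place, here is a minimal sketch (illustrative only, not from the original post): PyTorch frees a graph's intermediate buffers after the first backward() pass, so a second backward() through the same graph fails unless retain_graph=True was set on the first call.

import torch

x = torch.ones(2, requires_grad=True)
y = (x * 2).sum()

y.backward(retain_graph=True)  # keep the buffers for a second pass
y.backward()                   # without retain_graph above, this raises the RuntimeError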

Method 2

In my case, the problem occurred in a custom recurrent neural network, where the network's output is fed back as the input at the next time step.

Code that triggers the error:

for idx, (u, x1, x2, x1_next, x2_next) in enumerate(train_loader):
    if idx == 0:
        x1_input = x1
        x2_input = x2
    output = net(u.to(device),
                 x1_input.to(device),
                 x2_input.to(device))
    # These slices still carry this iteration's autograd graph, so the
    # next iteration's backward() tries to walk an already-freed graph.
    x1_input = output[:, 0]
    x2_input = output[:, 1]

The fix is to use a numpy.array as the intermediate variable. Converting to numpy copies the values out of the autograd graph, so each iteration's backward() only sees the graph built in that iteration:

for idx, (u, x1, x2, x1_next, x2_next) in enumerate(train_loader):
    if idx == 0:
        x1_input = x1.data.numpy()
        x2_input = x2.data.numpy()
    output = net(u.to(device),
                 torch.from_numpy(x1_input).float().to(device),
                 torch.from_numpy(x2_input).float().to(device))
    # Copying the values out through numpy drops the autograd history,
    # so each time step starts from plain numbers with no graph attached.
    x1_input = output[:, 0].data.cpu().numpy()
    x2_input = output[:, 1].data.cpu().numpy()

x1_input and x2_input should also be initialized before the loop (for example with numpy.zeros); otherwise an error is raised.
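A more idiomatic PyTorch alternative (a sketch under the same assumptions as the loop above, not something the original post tested) is to call .detach() on the outputs instead of round-tripping through numpy. detach() returns a tensor that shares storage but is cut off from the autograd graph, so the state can stay on the GPU:

for idx, (u, x1, x2, x1_next, x2_next) in enumerate(train_loader):
    if idx == 0:
        x1_input = x1.to(device)
        x2_input = x2.to(device)
    output = net(u.to(device), x1_input, x2_input)
    # detach() breaks the graph between time steps without leaving the GPU
    x1_input = output[:, 0].detach()
    x2_input = output[:, 1].detach()

This achieves the same disconnection between time steps as the numpy workaround, with less copying.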

