Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


pytorch - Unable to backpropagate through torch.rfft

I have two losses: the usual L1 loss, and a second one involving torch.rfft():

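import torch

def dft_amp(img):
    # Full (two-sided) 2-D FFT over the last two dims of a 4-D tensor;
    # torch.rfft (PyTorch <= 1.7) returns real/imag parts in a trailing dim of size 2
    fft_im = torch.rfft(img, signal_ndim=2, onesided=False)
    # Squared magnitude per frequency bin: re^2 + im^2
    fft_amp = fft_im[:, :, :, :, 0]**2 + fft_im[:, :, :, :, 1]**2
    return torch.sqrt(fft_amp)

l1_loss = torch.nn.L1Loss()

# pred and gt are 4-D image tensors from the training loop
loss = l1_loss(pred, gt) + l1_loss(dft_amp(pred), dft_amp(gt))

loss.backward()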

This runs for the first iteration, with neither loss being NaN, but from the second iteration onwards the loss becomes NaN.

If, however, only the plain L1 loss is kept and the DFT-amplitude term is omitted, training proceeds normally.

Does torch.rfft() support backpropagation? I am using PyTorch 1.4.0.

Any suggestions would be appreciated.

Question from: https://stackoverflow.com/questions/65932593/unable-to-backpropagate-through-torch-rfft

He who fights dragons too long becomes a dragon himself; gaze too long into the abyss, and the abyss gazes back…

1 Answer


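The problem seems to lie with torch.sqrt(), not with torch.rfft(). torch.sqrt() cannot handle inputs that are zero or very close to zero during backpropagation: its derivative, 1/(2*sqrt(x)), becomes infinite at x = 0, so the gradient turns into inf or NaN. The first optimizer step then writes those values into the weights, which is why the loss only becomes NaN from the second iteration onwards. Thus, in the code above, replace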

return torch.sqrt(fft_amp)

with

return torch.sqrt(fft_amp + 1e-10)

and the NaN values will disappear.
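To see the failure without PyTorch, here is a minimal pure-Python sketch that re-enacts the backward pass through fft_amp and torch.sqrt for a single frequency bin. It assumes autograd chains the sqrt factor 1/(2*sqrt(s)) and the square factor 2*re separately, as reverse-mode differentiation does; the helpers ieee_div and amp_grad are hypothetical names for illustration, and ieee_div exists only because plain Python raises on division by zero where tensors return inf or NaN.

```python
import math


def ieee_div(a: float, b: float) -> float:
    """Divide the way IEEE-754 floats (and tensors) do: x/0 -> +/-inf, 0/0 -> nan.

    Plain Python raises ZeroDivisionError instead, so this hypothetical helper
    only mimics tensor arithmetic for this sketch.
    """
    if b == 0.0:
        if a == 0.0 or math.isnan(a):
            return math.nan
        return math.copysign(math.inf, a) * math.copysign(1.0, b)
    return a / b


def amp_grad(re: float, im: float, upstream: float = 1.0, eps: float = 0.0):
    """Return (amp, d_re, d_im) for amp = sqrt(re**2 + im**2 + eps).

    The gradient is chained factor by factor (assumed to mirror autograd),
    not computed from the simplified closed form re / amp.
    `upstream` stands in for the L1 loss gradient (+1 or -1).
    """
    s = re * re + im * im + eps           # fft_amp (+ optional epsilon)
    amp = math.sqrt(s)                    # torch.sqrt(fft_amp)
    g_s = ieee_div(upstream, 2.0 * amp)   # sqrt backward: grad / (2 * sqrt(s))
    d_re = g_s * 2.0 * re                 # square backward: grad * 2 * re
    d_im = g_s * 2.0 * im                 # square backward: grad * 2 * im
    return amp, d_re, d_im


if __name__ == "__main__":
    # A bin with zero magnitude: inf * 0 gives nan gradients
    print(amp_grad(0.0, 0.0))
    # The same bin with the 1e-10 offset: both gradients are 0.0
    print(amp_grad(0.0, 0.0, eps=1e-10))
    # A tiny but nonzero value whose square underflows to 0.0: gradient is inf
    print(amp_grad(1e-200, 0.0))
```

Python floats are 64-bit, so the underflow case needs an extreme value here; with PyTorch's default float32, squaring anything on the order of 1e-23 or smaller already underflows to zero, which is one way "very small values" end up behaving like exact zeros.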

Although this fix works, it would be nice if someone could dig into why square roots are problematic in numerical computation.
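A short chain-rule derivation for one frequency bin suggests where the square root breaks down. The symbols are introduced only for illustration: a and b stand for fft_im[..., 0] and fft_im[..., 1], s = a^2 + b^2 is fft_amp, A = sqrt(s) is the returned amplitude, and ε is the 1e-10 offset.

```latex
\frac{\partial A}{\partial a}
  = \underbrace{\frac{1}{2\sqrt{s}}}_{\text{from } \sqrt{\cdot}}
    \cdot \underbrace{2a}_{\text{from } a^2}
  = \frac{a}{\sqrt{a^2 + b^2}}, \qquad (a, b) \neq (0, 0)

% At a = b = 0 the two factors are evaluated separately:
\frac{1}{2\sqrt{0}} \cdot (2 \cdot 0) = \infty \cdot 0 = \text{NaN}

% With the offset, s = a^2 + b^2 + \varepsilon > 0 and the gradient stays bounded:
\frac{\partial A}{\partial a} = \frac{a}{\sqrt{a^2 + b^2 + \varepsilon}},
\qquad \left| \frac{\partial A}{\partial a} \right| < 1
```

Mathematically, sqrt(a^2 + b^2) is the modulus |a + ib|, which is not differentiable at the origin: a/sqrt(a^2 + b^2) tends to different values depending on the direction of approach, so there is no finite gradient for autograd to return there. Adding ε smooths the function at the origin at the cost of a tiny bias (a zero bin gets amplitude sqrt(ε), roughly 1e-5, instead of 0).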



...