
[PyTorch] RuntimeError - dtype, grad_fn

geum 2023. 6. 27. 07:40

RuntimeError: Expected floating point type for target with class probabilities, got Long

ν•™μŠ΅ κ³Όμ •μ—μ„œ loss = loss_fn(pred, label) μ½”λ“œλ₯Ό μ‚¬μš©ν–ˆλŠ”λ°, pred와 label이 float일 쀄 μ•Œμ•˜λŠ”λ° long을 λ°›μ•˜λ‹€λŠ” μ—λŸ¬ 문ꡬ닀. μ˜ˆμ „μ—λŠ” loss κ΅¬ν•˜λŠ” λΆ€λΆ„μ—μ„œ 데이터 νƒ€μž…μ„ 지정해쀀 적이 μ—†λŠ” 것 같은데 μ•„λ§ˆ 데이터셋 ν΄λž˜μŠ€λ‚˜ collate_fn ν•¨μˆ˜μ—μ„œ μ„€μ •ν•΄μ€¬λ˜ λ“―? μ΄λ²ˆμ—λŠ” μ € λΆ€λΆ„μ—μ„œ 데이터 νƒ€μž…μ„ λͺ…μ‹œν•΄μ€¬λ‹€.

 

βœ… Solution

# Cast with .float() rather than re-wrapping with torch.tensor(...),
# which copies the data and detaches pred from the autograd graph
loss = loss_fn(pred.float(), label.float())
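A minimal sketch of the error and the fix (not from the original post — the shapes and the choice of nn.CrossEntropyLoss are assumptions). When the target has the same shape as the logits, CrossEntropyLoss treats it as class probabilities and requires a floating-point tensor, so a default Long tensor triggers exactly this RuntimeError:

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
pred = torch.randn(4, 3, requires_grad=True)  # logits for 4 samples, 3 classes

# One-hot targets built from Python ints default to torch.long
label = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])

try:
    loss_fn(pred, label)
except RuntimeError as e:
    print(e)  # the RuntimeError quoted above

# Casting the target with .float() fixes the dtype
# without detaching pred from the graph
loss = loss_fn(pred, label.float())
loss.backward()
print(pred.grad is not None)
```

Note that casting the target in place with .float() keeps pred's computation graph intact, so loss.backward() still works afterward.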

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

train_dataloader λŒλ©΄μ„œ loss.backward()λ₯Ό μˆ˜ν–‰ν•˜λŠ” κ³Όμ •μ—μ„œ λ§Œλ‚œ μ—λŸ¬λ‹€. μ§€κΈˆκΉŒμ§€ PyTorch μ½”λ“œ μ§œλ©΄μ„œ 처음 λ³Έ μ—λŸ¬ 둜그라 원인이 λ­”μ§€λŠ” 아직 νŒŒμ•… λͺ»ν•¨!

 

βœ… Solution

# Add loss.requires_grad_(True) before loss.backward()
loss.requires_grad_(True)
loss.backward()
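A word of caution on this workaround, with a sketch (not from the original post — the nn.Linear model and shapes are assumptions). Calling requires_grad_(True) on a detached loss silences the error, but the loss is then a fresh leaf tensor, so no gradients ever reach the model. The real fix is to avoid detaching in the first place:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 1)
x = torch.randn(4, 2)
pred = model(x)

# Re-wrapping a tensor with torch.tensor() creates a detached copy —
# one common cause of the "does not require grad" error
detached = torch.tensor(pred.detach())
loss = detached.mean()
try:
    loss.backward()
except RuntimeError as e:
    print(e)

# The workaround silences the error, but gradients stop at the new leaf:
loss.requires_grad_(True)
loss.backward()
print(model.weight.grad)  # None — the model never learns

# Proper fix: keep the graph by computing loss directly from pred
loss = pred.mean()
loss.backward()
print(model.weight.grad is not None)
```

So if this error appears, it is usually worth tracing where the graph was broken (torch.tensor() re-wraps, .detach(), .item(), numpy round-trips, or a torch.no_grad() block) rather than forcing requires_grad on the loss.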