现在位置: 首页 > 剪裁
2020年02月18日 编程语言 ⁄ 共 408字 评论关闭

pytorch梯度剪裁方式 我就废话不多说,看例子吧! import torch.nn as nn outputs = model(data) loss= loss_fn(outputs, target) optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2) optimizer.step() nn.utils.clip_grad_norm_ 的参数: param

阅读全文