本篇文章给大家分享的是有关pytorch中怎么查看可训练参数,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。
pytorch中model.parameters()函数定义如下:
def parameters(self):
r"""Returns an iterator over module parameters.
This is typically passed to an optimizer.
Yields:
Parameter: module parameter
Example::
>>> for param in model.parameters():
>>> print(type(param.data), param.size())
<class 'torch.FloatTensor'> (20L,)
<class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)
"""
for name, param in self.named_parameters():
yield param
所以,我们可以遍历named_parameters()中的所有的参数,只打印那些param.requires_grad=True的变量。具体实现代码如下所示:
for name, param in model.named_parameters():
if param.requires_grad:
print(name)
以上就是pytorch中怎么查看可训练参数,小编相信有部分知识点可能是我们日常工作会见到或用到的。希望你能通过这篇文章学到更多知识。更多详情敬请关注天达云行业资讯频道。