计算模型的params和FLOPs

1. 安装pytorch-OpCounter

pip install thop

或者

pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git

2. 使用

from torch import nn
from thop import profile
from thop import clever_format


class Net(nn.Module):
	def __init__(self):
		super().__init__()
	
	def forward(self, x1, x2):
		...
	
x1 = torch.randn(1, 3, 224, 224)
x2 = torch.randn(1, 4, 250)

# applicable to multiple inputs
macs, params = profile(net, inputs=((x1, x2)))
print(macs / 10 ** 9, params / 10 ** 6)

# convert to str
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)  # output ***.G, ***.M

版权声明:本文为u012897374原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。