1. Install pytorch-OpCounter
pip install thop
or, to install directly from the GitHub repository:
pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git
2. Usage
import torch
from torch import nn
from thop import profile
from thop import clever_format


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # define the model's layers here

    def forward(self, x1, x2):
        ...  # model-specific computation on x1 and x2


net = Net()
x1 = torch.randn(1, 3, 224, 224)
x2 = torch.randn(1, 4, 250)
# Multiple inputs: pass them as one tuple; profile() calls net(x1, x2)
macs, params = profile(net, inputs=(x1, x2))
print(macs / 10 ** 9, params / 10 ** 6)  # GMACs, millions of parameters
# Convert to human-readable strings with unit suffixes
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)  # e.g. x.xxxG x.xxxM
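Copyright notice: This is an original article by u012897374, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.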