# 求网络的 FLOPs — compute a network's FLOPs

from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout

from keras_flops import get_flops
# Build the network whose FLOPs will be measured.
# ResNet101 is built from scratch (weights=None) for a 480x480x3 input and a
# 3-class head. build_small_cnn() below is a tiny alternative network that can
# be swapped in to sanity-check the FLOP counter.
from tensorflow.keras.applications.resnet import ResNet101

INPUT_SHAPE = (480, 480, 3)  # passed to ResNet101 as `input_shape`
NUM_CLASSES = 3  # passed to ResNet101 as `classes`
BATCH_SIZE = 1  # passed to get_flops as `batch_size`


def build_small_cnn(input_shape=(32, 32, 1), num_classes=10):
    """Return a small Conv2D -> MaxPooling2D -> Dense classifier.

    Use it in place of ResNet101 for a quick check of ``get_flops`` on a
    network small enough to reason about by hand.

    Args:
        input_shape: Shape of one input sample, without the batch dimension.
        num_classes: Number of units in the final softmax layer.

    Returns:
        An uncompiled functional ``Model``.
    """
    inp = Input(input_shape)
    x = Conv2D(32, kernel_size=(3, 3), activation="relu")(inp)
    x = MaxPooling2D(pool_size=(2, 2))(x)
    x = Flatten()(x)
    x = Dense(128, activation="relu")(x)
    out = Dense(num_classes, activation="softmax")(x)
    return Model(inp, out)


model = ResNet101(weights=None, input_shape=INPUT_SHAPE, classes=NUM_CLASSES)
model.summary()

# Calculate the FLOPs reported by keras_flops for the given batch size.
flops = get_flops(model, batch_size=BATCH_SIZE)
# Report in units of 10**9 with three decimal places. The previous ``.03``
# spec meant three *significant digits* (e.g. 1234.5 -> "1.23e+03"), and the
# lowercase "g" read like grams rather than the giga prefix.
print(f"flops: {flops / 10 ** 9:.3f} G")

# 版权声明：本文为 qq_61365595 原创文章，遵循 CC 4.0 BY-SA 版权协议，转载请附上原文出处链接和本声明。