forked from OlafenwaMoses/ImageAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsqueezenet.py
More file actions
86 lines (53 loc) · 3.08 KB
/
squeezenet.py
File metadata and controls
86 lines (53 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from tensorflow.python.keras.layers import Input, Conv2D, MaxPool2D, Activation, concatenate, Dropout
from tensorflow.python.keras.layers import GlobalAvgPool2D, GlobalMaxPool2D
from tensorflow.python.keras.models import Model
def squeezenet_fire_module(input, input_channel_small=16, input_channel_large=64):
    """Stack one SqueezeNet "fire" module on top of ``input``.

    A fire module first *squeezes* the tensor with a 1x1 convolution of
    ``input_channel_small`` filters. It then *expands* the result through two
    parallel convolutions of ``input_channel_large`` filters each: a 1x1 with
    "valid" padding and a 3x3 with "same" padding. The two expand outputs are
    concatenated along axis 3. Every convolution is followed by a ReLU.

    Args:
        input: Keras tensor to extend. Concatenation happens on axis 3, which
            presumes a rank-4 channels-last tensor (not checked here).
        input_channel_small: filter count of the squeeze convolution.
        input_channel_large: filter count of each expand convolution.

    Returns:
        The concatenated expand output (``2 * input_channel_large`` channels
        under the channels-last assumption).
    """
    # The parameter keeps the name ``input`` (shadowing the builtin) because
    # callers pass it by keyword, e.g. ``squeezenet_fire_module(input=...)``.
    squeezed = Conv2D(input_channel_small, (1, 1), padding="valid")(input)
    squeezed = Activation("relu")(squeezed)

    # Create the 1x1 expand branch before the 3x3 one. Layers are created in
    # this order, which presumably matters when saved weights are matched to
    # layers by position.
    expand_specs = (((1, 1), "valid"), ((3, 3), "same"))
    branches = []
    for kernel_size, padding_mode in expand_specs:
        branch = Conv2D(input_channel_large, kernel_size, padding=padding_mode)(squeezed)
        branches.append(Activation("relu")(branch))

    return concatenate(branches, axis=3)
def SqueezeNet(include_top=True, weights="imagenet", model_input=None, non_top_pooling=None,
               num_classes=1000, model_path=""):
    """Build a SqueezeNet Keras model and optionally load weights into it.

    The layout matches SqueezeNet v1.1:
    - a 3x3/2 stem convolution with 64 filters;
    - a 3x3/2 max-pool;
    - two fire modules (16/64);
    - a max-pool;
    - two fire modules (32/128);
    - a max-pool;
    - fire modules 48/192, 48/192, 64/256 and 64/256.

    Args:
        include_top: if True, append the classifier head: dropout(0.5), a 1x1
            convolution with ``num_classes`` filters named "last_conv", ReLU,
            global average pooling and softmax. If False, the model ends after
            the last fire module, optionally followed by global pooling (see
            ``non_top_pooling``).
        weights: ``"imagenet"`` or ``"trained"`` loads weights from
            ``model_path`` once the model is built. Any other value (e.g.
            ``None``) leaves the layers freshly initialised.
        model_input: Keras tensor to build on (e.g. the result of ``Input``).
            When None, a channels-last ``Input(shape=(None, None, 3))`` is
            created, so the model accepts RGB images of any spatial size.
        non_top_pooling: only used when ``include_top`` is False.
            ``"Average"`` applies global average pooling and ``"Maximum"``
            applies global max pooling. Any other value (including None)
            applies no pooling, so the output is the last feature map.
        num_classes: number of classes produced by the classifier head. Must
            be 1000 when ``weights == "imagenet"``.
        model_path: path of the weights file to load when ``weights`` is
            ``"imagenet"`` or ``"trained"``.

    Returns:
        A Keras ``Model`` mapping ``model_input`` to the network output.

    Raises:
        ValueError: if ``weights == "imagenet"`` and ``num_classes != 1000``,
            or if weights are requested but ``model_path`` is empty.
    """
    if weights == "imagenet" and num_classes != 1000:
        raise ValueError(
            "The 'imagenet' SqueezeNet weights are for the 1000-class ImageNet "
            "classifier; num_classes must be 1000 when weights='imagenet'"
        )

    load_weights_from_file = weights in ("imagenet", "trained")
    # Fail before building the graph rather than inside load_weights("").
    if load_weights_from_file and not model_path:
        raise ValueError(
            "model_path must point to a weights file when weights={!r}".format(weights)
        )

    if model_input is None:
        # Every layer below is a convolution or pooling, and the head ends in
        # global pooling, so the spatial size can stay unspecified. Channels
        # come last to match the axis=3 concatenation in the fire modules.
        model_input = Input(shape=(None, None, 3))

    # NOTE: layers are created in a fixed order; keep it stable, since saved
    # weight files are presumably matched to layers by position.
    network = Conv2D(64, (3, 3), strides=(2, 2), padding="valid")(model_input)
    network = Activation("relu")(network)
    network = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(network)

    network = squeezenet_fire_module(input=network, input_channel_small=16, input_channel_large=64)
    network = squeezenet_fire_module(input=network, input_channel_small=16, input_channel_large=64)
    network = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(network)

    network = squeezenet_fire_module(input=network, input_channel_small=32, input_channel_large=128)
    network = squeezenet_fire_module(input=network, input_channel_small=32, input_channel_large=128)
    network = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(network)

    network = squeezenet_fire_module(input=network, input_channel_small=48, input_channel_large=192)
    network = squeezenet_fire_module(input=network, input_channel_small=48, input_channel_large=192)
    network = squeezenet_fire_module(input=network, input_channel_small=64, input_channel_large=256)
    network = squeezenet_fire_module(input=network, input_channel_small=64, input_channel_large=256)

    if include_top:
        network = Dropout(0.5)(network)
        # A 1x1 conv produces per-class maps; global average pooling then
        # yields one score per class, which softmax normalises.
        network = Conv2D(num_classes, kernel_size=(1, 1), padding="valid", name="last_conv")(network)
        network = Activation("relu")(network)
        network = GlobalAvgPool2D()(network)
        network = Activation("softmax")(network)
    elif non_top_pooling == "Average":
        network = GlobalAvgPool2D()(network)
    elif non_top_pooling == "Maximum":
        network = GlobalMaxPool2D()(network)
    # Any other non_top_pooling value (including None) leaves the feature map
    # un-pooled.

    model = Model(inputs=model_input, outputs=network)

    if load_weights_from_file:
        # NOTE(review): with include_top=False the head layers do not exist,
        # so a weights file saved from the full model may not load here.
        # Confirm which file callers pass in that case.
        model.load_weights(model_path)

    return model