How to Modify the YOLOv8 Model Architecture
1. Clone the YOLOv8 package
Clone the code from GitHub.
git clone https://github.com/ultralytics/ultralytics.git
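Because the later steps edit files inside the cloned repository, it can help to confirm which copy of ultralytics Python actually imports. The sketch below uses only the standard library. `package_origin` is a hypothetical helper name, and an editable install of the clone (for example `pip install -e .`) is an assumption, not something the original steps mention.

```python
import importlib.util


def package_origin(name):
    """Return the file a top-level package would be imported from, or None."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


# If this path does not point into the cloned ultralytics folder, edits made
# there will not affect the package that `import ultralytics` loads.
print(package_origin("ultralytics"))
```

If the printed path points into site-packages rather than the clone, the modules added in the following steps will not be picked up.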
2. Define the module to use
1. Go to the ultralytics/nn/modules folder and create a new file.
2. In that file, define the new module class using PyTorch.
Note: the module name must be listed in the __all__ variable of that file.
import torch.nn as nn

__all__ = ('CoordAtt',)  # the trailing comma makes this a one-element tuple

class h_sigmoid(nn.Module):
    ...

class h_swish(nn.Module):
    ...

class CoordAtt(nn.Module):
    ...
3. Add the module to __init__.py in the ultralytics/nn/modules folder.
Also add the module to the __all__ tuple there.
from .coordatt import CoordAtt
...
__all__ = (
    "Conv",
    "Conv2",
    "LightConv",
    "RepConv",
    "DWConv",
    "DWConvTranspose2d",
    "ConvTranspose",
    ...
    "CoordAtt",
)
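A detail that is easy to miss when writing `__all__`, both in the new file and in `__init__.py`: parentheses alone do not make a tuple. A single name needs a trailing comma, otherwise the value is just a string. The snippet below is plain Python and independent of Ultralytics.

```python
not_a_tuple = ("CoordAtt")   # just a parenthesized string
one_tuple = ("CoordAtt",)    # the trailing comma makes a one-element tuple

print(type(not_a_tuple).__name__)  # str
print(type(one_tuple).__name__)    # tuple

# Iterating over the string yields single characters, not the class name.
print(list(not_a_tuple)[:3])       # ['C', 'o', 'o']
print(list(one_tuple))             # ['CoordAtt']
```

Anything that iterates over `__all__`, such as a star import, would see single characters instead of the class name in the first case.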
4. Add the module to ultralytics/nn/tasks.py.
from ultralytics.nn.modules.coordatt import CoordAtt
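The import in tasks.py matters because the model YAML refers to modules by name only. As far as I understand, the model builder in tasks.py turns each name string into a class by looking it up among the names visible in that file, so a class that was never imported cannot be found. The following is a simplified illustration of that idea, with a hypothetical `resolve_module` function and stand-in classes. It is not the Ultralytics implementation.

```python
class Conv:      # stand-in for an existing Ultralytics module
    pass


class CoordAtt:  # stand-in for the newly added module
    pass


def resolve_module(name, namespace):
    """Map a module name from the YAML to a class in the given namespace."""
    try:
        return namespace[name]
    except KeyError:
        raise KeyError(f"{name!r} is not available; was it imported?") from None


without_import = {"Conv": Conv}
with_import = {"Conv": Conv, "CoordAtt": CoordAtt}

print(resolve_module("CoordAtt", with_import).__name__)  # CoordAtt
try:
    resolve_module("CoordAtt", without_import)
except KeyError as err:
    print("lookup failed:", err)
```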
Add the module to the model architecture
Copy and paste the yolov8.yaml file inside the ultralytics/ultralytics/cfg/models/v8 folder.
Rename the copy, yolov8 copy.yaml, to yolov8-ca.yaml.
Edit the architecture inside yolov8-ca.yaml.
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 12

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 15 (P3/8-small)
  - [-1, 1, CoordAtt, [64, 64]] # 16 CA (changed here)
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 19 (P4/16-medium), was 18 before CoordAtt was inserted
  # (rest omitted)
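In each row:
The first element specifies which layer(s) this layer takes its input from. -1 means the immediately preceding layer, and [-1, 12] means the preceding layer and layer 12 are both used as inputs. Non-negative indices are absolute, so inserting a layer shifts the index of every layer after it, and any later row that refers to those layers by number needs to be updated accordingly.
The second element specifies how many times the layer is repeated.
The third element specifies which module to use.
The fourth element specifies the arguments passed to the module.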
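To make the indexing concrete, the rows from the edited head can be walked in plain Python. This is a small sketch, not Ultralytics code: `absolute_sources` is a hypothetical helper, and the starting index of 10 is an assumption based on the `# 12` comment on the third head row (that is, the backbone occupying layers 0 to 9).

```python
# Head rows from the edited YAML, as [from, repeats, module, args].
head = [
    [-1, 1, "nn.Upsample", [None, 2, "nearest"]],
    [[-1, 6], 1, "Concat", [1]],
    [-1, 3, "C2f", [512]],
    [-1, 1, "nn.Upsample", [None, 2, "nearest"]],
    [[-1, 4], 1, "Concat", [1]],
    [-1, 3, "C2f", [256]],
    [-1, 1, "CoordAtt", [64, 64]],
    [-1, 1, "Conv", [256, 3, 2]],
    [[-1, 12], 1, "Concat", [1]],
    [-1, 3, "C2f", [512]],
]

FIRST_HEAD_INDEX = 10  # assumption: the backbone occupies layers 0-9


def absolute_sources(index, source):
    """Turn a 'from' entry into a list of absolute layer indices."""
    sources = source if isinstance(source, list) else [source]
    # Negative values are relative to the current layer; others are absolute.
    return [index + s if s < 0 else s for s in sources]


layers = []
for offset, (source, repeats, module, args) in enumerate(head):
    index = FIRST_HEAD_INDEX + offset
    layers.append((index, module, absolute_sources(index, source)))
    print(index, module, "<-", absolute_sources(index, source))
```

In this listing CoordAtt lands at index 16 and the P4 C2f block at index 19, one higher than in the stock file, which is why rows after the insertion point that use absolute indices deserve a second look.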
Check the modified model architecture
Create the file ultralytics/train.py.
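from ultralytics import YOLO
from torchinfo import summary

# The "n" in yolov8n-ca.yaml selects the nano scale; the file on disk is yolov8-ca.yaml.
model = YOLO('/home/gpuadmin/2023811010/Yolo/ultralytics/ultralytics/cfg/models/v8/yolov8n-ca.yaml')
# Print every layer with its output shape and parameter count for a 640x640 input.
summary(model.model, input_size=(1, 3, 640, 640))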
└─CoordAtt: 2-75 [1, 64, 80, 80] --
│ │ └─AdaptiveAvgPool2d: 3-79 [1, 64, 80, 1] --
│ │ └─AdaptiveAvgPool2d: 3-80 [1, 64, 1, 80] --
│ │ └─Conv2d: 3-81 [1, 8, 160, 1] 520
│ │ └─BatchNorm2d: 3-82 [1, 8, 160, 1] 16
│ │ └─h_swish: 3-83 [1, 8, 160, 1] --
│ │ └─Conv2d: 3-84 [1, 64, 80, 1] 576
│ │ └─Conv2d: 3-85 [1, 64, 1, 80] 576
The output confirms that CoordAtt has been added.
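The numbers in the CoordAtt excerpt can be checked by hand. The sketch below assumes the three Conv2d layers use 1x1 kernels with bias (an assumption, although it is the only simple choice that matches the counts). The channel values 64 and 8 and the 80x80 feature map are read off the output shapes in the excerpt.

```python
def conv2d_params(c_in, c_out, kernel=1, bias=True):
    """Parameter count of a Conv2d with a square kernel."""
    return c_in * c_out * kernel * kernel + (c_out if bias else 0)


def batchnorm_params(channels):
    """One learnable scale and one shift per channel."""
    return 2 * channels


channels, reduced = 64, 8   # read off the output shapes in the summary
height = width = 80

print(conv2d_params(channels, reduced))   # 520, the first Conv2d
print(batchnorm_params(reduced))          # 16, the BatchNorm2d
print(conv2d_params(reduced, channels))   # 576, each of the last two Conv2d
print(height + width)                     # 160, pooled height and width concatenated
```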
Full model architecture
(yolo) gpuadmin@gpuserver:~/2023811010/Yolo/ultralytics$ python train.py
====================================================================================================
Layer (type:depth-idx) Output Shape Param #
====================================================================================================
DetectionModel [1, 84, 8400] --
├─Sequential: 1-1 -- --
│ └─Conv: 2-1 [1, 16, 320, 320] --
│ │ └─Conv2d: 3-1 [1, 16, 320, 320] 432
│ │ └─BatchNorm2d: 3-2 [1, 16, 320, 320] 32
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─Conv: 2-3 [1, 32, 160, 160] --
│ │ └─Conv2d: 3-4 [1, 32, 160, 160] 4,608
│ │ └─BatchNorm2d: 3-5 [1, 32, 160, 160] 64
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-5 [1, 32, 160, 160] 6,272
│ │ └─Conv: 3-7 [1, 32, 160, 160] 1,088
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-11 -- (recursive)
│ │ └─ModuleList: 3-11 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-11 -- (recursive)
│ │ └─ModuleList: 3-11 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-11 -- (recursive)
│ │ └─Conv: 3-13 [1, 32, 160, 160] 1,600
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─Conv: 2-13 [1, 64, 80, 80] --
│ │ └─Conv2d: 3-15 [1, 64, 80, 80] 18,432
│ │ └─BatchNorm2d: 3-16 [1, 64, 80, 80] 128
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-15 [1, 64, 80, 80] 45,440
│ │ └─Conv: 3-18 [1, 64, 80, 80] 4,224
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-25 -- (recursive)
│ │ └─ModuleList: 3-26 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-25 -- (recursive)
│ │ └─ModuleList: 3-26 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-25 -- (recursive)
│ │ └─ModuleList: 3-26 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-25 -- (recursive)
│ │ └─ModuleList: 3-26 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-25 -- (recursive)
│ │ └─Conv: 3-28 [1, 64, 80, 80] 8,320
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─Conv: 2-27 [1, 128, 40, 40] --
│ │ └─Conv2d: 3-30 [1, 128, 40, 40] 73,728
│ │ └─BatchNorm2d: 3-31 [1, 128, 40, 40] 256
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-29 [1, 128, 40, 40] 180,992
│ │ └─Conv: 3-33 [1, 128, 40, 40] 16,640
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-39 -- (recursive)
│ │ └─ModuleList: 3-41 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-39 -- (recursive)
│ │ └─ModuleList: 3-41 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-39 -- (recursive)
│ │ └─ModuleList: 3-41 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-39 -- (recursive)
│ │ └─ModuleList: 3-41 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-39 -- (recursive)
│ │ └─Conv: 3-43 [1, 128, 40, 40] 33,024
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─Conv: 2-41 [1, 256, 20, 20] --
│ │ └─Conv2d: 3-45 [1, 256, 20, 20] 294,912
│ │ └─BatchNorm2d: 3-46 [1, 256, 20, 20] 512
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-43 [1, 256, 20, 20] 394,240
│ │ └─Conv: 3-48 [1, 256, 20, 20] 66,048
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-49 -- (recursive)
│ │ └─ModuleList: 3-52 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-49 -- (recursive)
│ │ └─ModuleList: 3-52 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-49 -- (recursive)
│ │ └─Conv: 3-54 [1, 256, 20, 20] 98,816
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─SPPF: 2-51 [1, 256, 20, 20] 131,584
│ │ └─Conv: 3-56 [1, 128, 20, 20] 33,024
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─SPPF: 2-53 -- (recursive)
│ │ └─MaxPool2d: 3-58 [1, 128, 20, 20] --
│ │ └─MaxPool2d: 3-59 [1, 128, 20, 20] --
│ │ └─MaxPool2d: 3-60 [1, 128, 20, 20] --
│ │ └─Conv: 3-61 [1, 256, 20, 20] 131,584
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─Upsample: 2-55 [1, 256, 40, 40] --
│ └─Concat: 2-56 [1, 384, 40, 40] --
│ └─C2f: 2-57 [1, 128, 40, 40] 98,816
│ │ └─Conv: 3-63 [1, 128, 40, 40] 49,408
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-63 -- (recursive)
│ │ └─ModuleList: 3-67 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-63 -- (recursive)
│ │ └─ModuleList: 3-67 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-63 -- (recursive)
│ │ └─Conv: 3-69 [1, 128, 40, 40] 24,832
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─Upsample: 2-65 [1, 128, 80, 80] --
│ └─Concat: 2-66 [1, 192, 80, 80] --
│ └─C2f: 2-67 [1, 64, 80, 80] 24,832
│ │ └─Conv: 3-71 [1, 64, 80, 80] 12,416
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-73 -- (recursive)
│ │ └─ModuleList: 3-75 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-73 -- (recursive)
│ │ └─ModuleList: 3-75 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-73 -- (recursive)
│ │ └─Conv: 3-77 [1, 64, 80, 80] 6,272
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─CoordAtt: 2-75 [1, 64, 80, 80] --
│ │ └─AdaptiveAvgPool2d: 3-79 [1, 64, 80, 1] --
│ │ └─AdaptiveAvgPool2d: 3-80 [1, 64, 1, 80] --
│ │ └─Conv2d: 3-81 [1, 8, 160, 1] 520
│ │ └─BatchNorm2d: 3-82 [1, 8, 160, 1] 16
│ │ └─h_swish: 3-83 [1, 8, 160, 1] --
│ │ └─Conv2d: 3-84 [1, 64, 80, 1] 576
│ │ └─Conv2d: 3-85 [1, 64, 1, 80] 576
│ └─Conv: 2-76 [1, 64, 40, 40] --
│ │ └─Conv2d: 3-86 [1, 64, 40, 40] 36,864
│ │ └─BatchNorm2d: 3-87 [1, 64, 40, 40] 128
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─Concat: 2-78 [1, 192, 40, 40] --
│ └─C2f: 2-79 [1, 128, 40, 40] 98,816
│ │ └─Conv: 3-89 [1, 128, 40, 40] 24,832
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-85 -- (recursive)
│ │ └─ModuleList: 3-93 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-85 -- (recursive)
│ │ └─ModuleList: 3-93 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-85 -- (recursive)
│ │ └─Conv: 3-95 [1, 128, 40, 40] 24,832
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─Conv: 2-87 [1, 128, 20, 20] --
│ │ └─Conv2d: 3-97 [1, 128, 20, 20] 147,456
│ │ └─BatchNorm2d: 3-98 [1, 128, 20, 20] 256
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─Concat: 2-89 [1, 384, 20, 20] --
│ └─C2f: 2-90 [1, 256, 20, 20] 394,240
│ │ └─Conv: 3-100 [1, 256, 20, 20] 98,816
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-96 -- (recursive)
│ │ └─ModuleList: 3-104 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-96 -- (recursive)
│ │ └─ModuleList: 3-104 -- (recursive)
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─C2f: 2-96 -- (recursive)
│ │ └─Conv: 3-106 [1, 256, 20, 20] 98,816
│ └─Detect: 2-97 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ └─Detect: 2-98 [1, 84, 8400] --
│ │ └─ModuleList: 3-124 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ │ └─ModuleList: 3-124 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ │ └─ModuleList: 3-124 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ │ └─ModuleList: 3-124 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ │ └─ModuleList: 3-124 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ │ └─ModuleList: 3-124 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ │ └─ModuleList: 3-124 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ │ └─ModuleList: 3-124 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ │ └─ModuleList: 3-124 -- (recursive)
│ │ └─ModuleList: 3-125 -- (recursive)
│ │ └─DFL: 3-126 [1, 4, 8400] (16)
====================================================================================================
Total params: 5,508,456
Trainable params: 5,508,440
Non-trainable params: 16
Total mult-adds (G): 4.57
====================================================================================================
Input size (MB): 4.92
Forward/backward pass size (MB): 232.26
Params size (MB): 13.63
Estimated Total Size (MB): 250.80
====================================================================================================
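A few figures in this summary can be reproduced with simple arithmetic, which is a quick way to sanity-check the output. The strides of 8, 16 and 32 and the class count of 80 are assumptions taken from the stock YOLOv8 configuration (only P3/8 and P4/16 are visible in the YAML excerpt above). The grid sizes 80, 40 and 20 do appear in the summary itself.

```python
image_size = 640
strides = (8, 16, 32)   # assumption: P3, P4, P5 as in the stock config

grids = [image_size // s for s in strides]
print(grids)                        # [80, 40, 20]
print(sum(g * g for g in grids))    # 8400 predictions, the last output dimension

num_classes = 80   # assumption: nc left at its default in the model YAML
box_values = 4     # one box is described by four numbers
print(box_values + num_classes)     # 84, the middle output dimension

# Input size: 1 x 3 x 640 x 640 float32 values at 4 bytes each.
input_bytes = 1 * 3 * image_size * image_size * 4
print(round(input_bytes / 1e6, 2))  # 4.92, matching "Input size (MB)"
```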
Train the modified model
Example data.yaml
train: /home/gpuadmin/2023811010/Yolo/dataset/train/images
val: /home/gpuadmin/2023811010/Yolo/dataset/valid/images
test: /home/gpuadmin/2023811010/Yolo/dataset/test/images
nc: 2
names: ['Crack', 'Normality']
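A common slip in a file like this is an `nc` value that no longer matches the number of entries in `names`. The sketch below checks that on a plain dict mirroring the example. `check_dataset_config` is a hypothetical helper, not part of Ultralytics, and the set of keys it requires is simply the set used in this example.

```python
def check_dataset_config(cfg):
    """Return a list of problems found in a dataset config dict."""
    problems = []
    for key in ("train", "val", "names"):
        if key not in cfg:
            problems.append(f"missing key: {key}")
    if "nc" in cfg and "names" in cfg and cfg["nc"] != len(cfg["names"]):
        problems.append(f"nc={cfg['nc']} but {len(cfg['names'])} names given")
    return problems


# Same values as the example data.yaml, written as a Python dict.
config = {
    "train": "/home/gpuadmin/2023811010/Yolo/dataset/train/images",
    "val": "/home/gpuadmin/2023811010/Yolo/dataset/valid/images",
    "test": "/home/gpuadmin/2023811010/Yolo/dataset/test/images",
    "nc": 2,
    "names": ["Crack", "Normality"],
}

print(check_dataset_config(config))  # []
```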
model.train(
    data='data.yaml', epochs=10, patience=30, batch=32
)
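One thing to note about the call above: `patience` is the number of epochs to wait without improvement before training stops early, so with `epochs=10` a patience of 30 can never trigger. The sketch below is a simplified illustration of patience-based early stopping. `epochs_run` is a hypothetical function, not the Ultralytics implementation.

```python
def epochs_run(scores, patience):
    """Return how many epochs run before patience-based early stopping.

    `scores` holds one validation score per epoch (higher is better).
    """
    best = float("-inf")
    since_best = 0
    for epoch, score in enumerate(scores, start=1):
        if score > best:
            best, since_best = score, 0
        else:
            since_best += 1
        if since_best >= patience:
            return epoch   # stopped early
    return len(scores)     # ran all scheduled epochs


flat = [0.5] * 10  # ten epochs with no improvement after the first
print(epochs_run(flat, patience=30))  # 10
print(epochs_run(flat, patience=3))   # 4
```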
This post is licensed under CC BY 4.0 by the author.