
How to Modify the YOLOv8 Model Architecture

1. Clone the YOLOv8 package

Clone the code from GitHub:

git clone https://github.com/ultralytics/ultralytics.git
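The later steps edit files inside this clone, so Python has to import the cloned copy rather than a separately pip-installed ultralytics. Running scripts from the repository root, as done later in this post, or installing the clone in editable mode (pip install -e .) are common ways to ensure that. The stdlib-only sketch below reports which ultralytics package the current interpreter would import; ultralytics_package_dir is a hypothetical helper, not part of Ultralytics.

```python
import importlib.util
from pathlib import Path
from typing import Optional


def ultralytics_package_dir() -> Optional[Path]:
    """Return the directory of the ultralytics package this interpreter would import, or None."""
    spec = importlib.util.find_spec("ultralytics")
    if spec is None or spec.origin is None:
        return None  # not importable (or a namespace package without __init__.py)
    return Path(spec.origin).resolve().parent


if __name__ == "__main__":
    # Run from the directory you will launch train.py from; expect a path inside the clone
    print(ultralytics_package_dir())
```

If the printed path points into site-packages instead of the clone, edits made in the following steps will not take effect.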



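2. Define the module to use

1. Create a new file in the ultralytics/nn/modules folder (coordatt.py here, matching the imports used in the next steps).



2. In that file, define the new module class with PyTorch.

Code used:


Note: the module name must be listed in the file's __all__ variable.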

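import torch
import torch.nn as nn

__all__ = ('CoordAtt',)  # the trailing comma makes this a tuple; ('CoordAtt') is just a string

class h_sigmoid(nn.Module):
    ...

class h_swish(nn.Module):
    ...

class CoordAtt(nn.Module):
    ...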
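The class bodies above are elided. For reference, here is a minimal sketch of coordinate attention that follows the structure of the public reference implementation accompanying the Coordinate Attention paper (Hou et al., CVPR 2021). It is an assumption that the author's code matches it exactly, but the sub-layer shapes and parameter counts it produces for 64 channels (520, 16 and 576) agree with the torchinfo output later in this post. The reduction=32 default and the minimum bottleneck of 8 channels are taken from that reference implementation, and the final multiplication assumes oup equals inp.

```python
import torch
import torch.nn as nn


class h_sigmoid(nn.Module):
    """Hard sigmoid: ReLU6(x + 3) / 6."""

    def __init__(self, inplace: bool = True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    """Hard swish: x * h_sigmoid(x)."""

    def __init__(self, inplace: bool = True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    """Coordinate attention: reweights a feature map along height and width separately."""

    def __init__(self, inp: int, oup: int, reduction: int = 32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # N,C,H,W -> N,C,H,1
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # N,C,H,W -> N,C,1,W
        mip = max(8, inp // reduction)  # bottleneck width
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1)

    def forward(self, x):
        identity = x
        _, _, h, w = x.size()
        x_h = self.pool_h(x)                      # N,C,H,1
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # N,C,W,1
        y = torch.cat([x_h, x_w], dim=2)          # N,C,H+W,1
        y = self.act(self.bn1(self.conv1(y)))     # shared 1x1 transform
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)             # back to N,mip,1,W
        a_h = self.conv_h(x_h).sigmoid()          # attention weights along height
        a_w = self.conv_w(x_w).sigmoid()          # attention weights along width
        return identity * a_w * a_h               # requires oup == inp
```

With the YAML row [-1, 1, CoordAtt, [64, 64]] used later, the model builder would construct CoordAtt(64, 64), and the output keeps the input shape.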


3. Add the module to __init__.py in the ultralytics/nn/modules folder

Import the class, and also add the module name to the __all__ tuple:
from .coordatt import CoordAtt
...

__all__ = (
"Conv",
"Conv2",
"LightConv",
"RepConv",
"DWConv",
"DWConvTranspose2d",
"ConvTranspose",
...
"CoordAtt"
)
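Both __all__ variables need to be real tuples. Parentheses alone do not create a tuple, so ('CoordAtt') is the string 'CoordAtt', and a star import iterates over __all__ item by item. The stdlib-only sketch below demonstrates the difference with a throwaway module; demo_coordatt and the stand-in CoordAtt class are hypothetical and exist only for this illustration.

```python
import sys
import types
from typing import Any, List


def star_import(all_value: Any) -> List[str]:
    """Star-import a throwaway module whose __all__ is all_value; return the names bound."""
    mod = types.ModuleType("demo_coordatt")  # hypothetical module name
    mod.CoordAtt = type("CoordAtt", (), {})  # stand-in class, not the real CoordAtt
    mod.__all__ = all_value
    sys.modules["demo_coordatt"] = mod
    namespace: dict = {}
    try:
        exec("from demo_coordatt import *", namespace)
    finally:
        del sys.modules["demo_coordatt"]
    return sorted(k for k in namespace if k != "__builtins__")


print(star_import(("CoordAtt",)))  # ['CoordAtt']
try:
    star_import("CoordAtt")  # same value as ('CoordAtt') without the comma
except AttributeError as err:
    print("AttributeError:", err)  # the star import looked for an attribute named 'C'
```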


4. Add the module to ultralytics/nn/tasks.py

Import the class in tasks.py:
from ultralytics.nn.modules.coordatt import CoordAtt
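Why the import has to live in tasks.py: in the YOLOv8 versions of Ultralytics I have looked at, parse_model in tasks.py turns each module name from the YAML into a class with getattr(torch.nn, ...) for names starting with "nn." and a lookup in the file's globals() otherwise. A name that was never imported there typically fails with a KeyError. The torch-free sketch below imitates only the globals() branch; resolve_module and the stand-in CoordAtt class are hypothetical, not the real parse_model.

```python
from typing import Any, Mapping


class CoordAtt:
    """Stand-in for the real CoordAtt module (illustration only)."""


def resolve_module(name: str, namespace: Mapping[str, Any]) -> Any:
    """Resolve a YAML module name by plain name lookup, as a globals()-based parser does."""
    try:
        return namespace[name]
    except KeyError:
        # Roughly what you see if the class was never imported into tasks.py
        raise KeyError(f"{name!r} is not defined in the parser's namespace") from None


print(resolve_module("CoordAtt", globals()).__name__)  # CoordAtt
```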



Adding the module to the model architecture

  1. In the ultralytics/ultralytics/cfg/models/v8 folder, copy yolov8.yaml and paste it into the same folder.

  2. Rename the copy (yolov8 copy.yaml) to yolov8-ca.yaml.

  3. Edit the architecture inside yolov8-ca.yaml:

# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 12

- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 15 (P3/8-small)
- [-1, 1, CoordAtt, [64, 64]] # 16 CA (added here)

- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 12], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 19 (P4/16-medium)

# (rest omitted)

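In each row:
The first element specifies which layer(s) this layer takes its input from.
-1 means the immediately preceding layer,
and [-1, 12] means the outputs of the preceding layer and of layer 12 (a 0-based index) are both fed in, which is why such rows use Concat.
The second element specifies how many times the module is repeated (for scaled models such as n, Ultralytics multiplies it by the depth factor, so C2f's 3 becomes 1).
The third element specifies which module to use.
The fourth element lists the arguments passed to the module when it is built.
In the YOLOv8 versions of Ultralytics, the arguments of modules that are not handled specially, such as CoordAtt, are passed through unchanged, without the width multiplier that is applied to the channel counts of C2f and Conv. CoordAtt's [64, 64] (input and output channels) therefore has to match the actual channel count at the chosen scale: at the n scale, the preceding C2f [256] outputs 256 × 0.25 = 64 channels.
Inserting a layer at index 16 also shifts every later layer index by one (the P4 C2f becomes layer 19), so later references in the head, such as the layer indices passed to Detect, must be updated to match.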
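The row format and the index shift can be checked by resolving each row's "from" field to absolute layer indices. The sketch below assumes the stock YOLOv8 backbone occupies layers 0 to 9, so the head starts at index 10; module names are plain strings here and index_layers is a hypothetical helper, not Ultralytics code.

```python
from typing import List, Tuple

BACKBONE_LEN = 10  # assumption: the stock yolov8.yaml backbone has layers 0-9

head = [
    [-1, 1, "nn.Upsample", [None, 2, "nearest"]],
    [[-1, 6], 1, "Concat", [1]],
    [-1, 3, "C2f", [512]],
    [-1, 1, "nn.Upsample", [None, 2, "nearest"]],
    [[-1, 4], 1, "Concat", [1]],
    [-1, 3, "C2f", [256]],
    [-1, 1, "CoordAtt", [64, 64]],
    [-1, 1, "Conv", [256, 3, 2]],
    [[-1, 12], 1, "Concat", [1]],
    [-1, 3, "C2f", [512]],
]


def index_layers(rows, start: int) -> List[Tuple[int, str, List[int]]]:
    """Return (layer index, module name, absolute input indices) for each row."""
    out = []
    for i, (frm, _repeats, module, _args) in enumerate(rows, start=start):
        sources = frm if isinstance(frm, list) else [frm]
        absolute = [i + s if s < 0 else s for s in sources]  # -1 -> previous layer
        out.append((i, module, absolute))
    return out


for idx, module, inputs in index_layers(head, BACKBONE_LEN):
    print(idx, module, inputs)
```

The output lists CoordAtt as layer 16 fed by layer 15, and the last C2f as layer 19, which is why the comment on that row reads 19 rather than the stock 18.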



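Checking the modified model architecture

Create ultralytics/train.py in the root of the cloned repository (torchinfo is a separate package: pip install torchinfo):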

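from ultralytics import YOLO
from torchinfo import summary  # pip install torchinfo

# The file on disk is yolov8-ca.yaml; the "n" in yolov8n-ca.yaml selects the n scale
model = YOLO('/home/gpuadmin/2023811010/Yolo/ultralytics/ultralytics/cfg/models/v8/yolov8n-ca.yaml')

# Layer-by-layer summary for one 640x640 RGB image
summary(model.model, input_size=(1, 3, 640, 640))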
└─CoordAtt: 2-75                              [1, 64, 80, 80]           --
│    │    └─AdaptiveAvgPool2d: 3-79                [1, 64, 80, 1]            --
│    │    └─AdaptiveAvgPool2d: 3-80                [1, 64, 1, 80]            --
│    │    └─Conv2d: 3-81                           [1, 8, 160, 1]            520
│    │    └─BatchNorm2d: 3-82                      [1, 8, 160, 1]            16
│    │    └─h_swish: 3-83                          [1, 8, 160, 1]            --
│    │    └─Conv2d: 3-84                           [1, 64, 80, 1]            576
│    │    └─Conv2d: 3-85                           [1, 64, 1, 80]            576

This confirms that CoordAtt has been inserted after the P3 C2f layer, keeping the [1, 64, 80, 80] feature-map shape.
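The script loads yolov8n-ca.yaml even though the file created earlier is named yolov8-ca.yaml. In the Ultralytics YOLOv8 versions I have looked at, the YAML loader strips the scale letter after the version digits to find the file, and a separate helper reads that letter back as the model scale. The two regexes below mirror that behavior as a sketch; unified_cfg_path and guess_scale are hypothetical names, and the exact patterns may differ between releases.

```python
import re
from pathlib import Path
from typing import Optional


def unified_cfg_path(path: str) -> str:
    """Drop the scale letter after the version digits: yolov8n-ca.yaml -> yolov8-ca.yaml."""
    return re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", path)


def guess_scale(path: str) -> Optional[str]:
    """Read the scale letter (n, s, m, l, x) from a yolov8<scale>... file name."""
    match = re.search(r"yolov\d+([nslmx])", Path(path).stem)
    return match.group(1) if match else None


cfg = "/home/gpuadmin/2023811010/Yolo/ultralytics/ultralytics/cfg/models/v8/yolov8n-ca.yaml"
print(unified_cfg_path(cfg))  # ends with /cfg/models/v8/yolov8-ca.yaml
print(guess_scale(cfg))       # n
```

Loading yolov8-ca.yaml directly would leave no scale letter in the name, in which case Ultralytics falls back to a default scale, so keeping the letter in the path passed to YOLO() makes the choice explicit.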



Full model architecture

(yolo) gpuadmin@gpuserver:~/2023811010/Yolo/ultralytics$ python train.py 
====================================================================================================
Layer (type:depth-idx)                             Output Shape              Param #
====================================================================================================
DetectionModel                                     [1, 84, 8400]             --
├─Sequential: 1-1                                  --                        --
│    └─Conv: 2-1                                   [1, 16, 320, 320]         --
│    │    └─Conv2d: 3-1                            [1, 16, 320, 320]         432
│    │    └─BatchNorm2d: 3-2                       [1, 16, 320, 320]         32
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─Conv: 2-3                                   [1, 32, 160, 160]         --
│    │    └─Conv2d: 3-4                            [1, 32, 160, 160]         4,608
│    │    └─BatchNorm2d: 3-5                       [1, 32, 160, 160]         64
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-5                                    [1, 32, 160, 160]         6,272
│    │    └─Conv: 3-7                              [1, 32, 160, 160]         1,088
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-11                                   --                        (recursive)
│    │    └─ModuleList: 3-11                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-11                                   --                        (recursive)
│    │    └─ModuleList: 3-11                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-11                                   --                        (recursive)
│    │    └─Conv: 3-13                             [1, 32, 160, 160]         1,600
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─Conv: 2-13                                  [1, 64, 80, 80]           --
│    │    └─Conv2d: 3-15                           [1, 64, 80, 80]           18,432
│    │    └─BatchNorm2d: 3-16                      [1, 64, 80, 80]           128
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-15                                   [1, 64, 80, 80]           45,440
│    │    └─Conv: 3-18                             [1, 64, 80, 80]           4,224
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-25                                   --                        (recursive)
│    │    └─ModuleList: 3-26                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-25                                   --                        (recursive)
│    │    └─ModuleList: 3-26                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-25                                   --                        (recursive)
│    │    └─ModuleList: 3-26                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-25                                   --                        (recursive)
│    │    └─ModuleList: 3-26                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-25                                   --                        (recursive)
│    │    └─Conv: 3-28                             [1, 64, 80, 80]           8,320
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─Conv: 2-27                                  [1, 128, 40, 40]          --
│    │    └─Conv2d: 3-30                           [1, 128, 40, 40]          73,728
│    │    └─BatchNorm2d: 3-31                      [1, 128, 40, 40]          256
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-29                                   [1, 128, 40, 40]          180,992
│    │    └─Conv: 3-33                             [1, 128, 40, 40]          16,640
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-39                                   --                        (recursive)
│    │    └─ModuleList: 3-41                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-39                                   --                        (recursive)
│    │    └─ModuleList: 3-41                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-39                                   --                        (recursive)
│    │    └─ModuleList: 3-41                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-39                                   --                        (recursive)
│    │    └─ModuleList: 3-41                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-39                                   --                        (recursive)
│    │    └─Conv: 3-43                             [1, 128, 40, 40]          33,024
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─Conv: 2-41                                  [1, 256, 20, 20]          --
│    │    └─Conv2d: 3-45                           [1, 256, 20, 20]          294,912
│    │    └─BatchNorm2d: 3-46                      [1, 256, 20, 20]          512
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-43                                   [1, 256, 20, 20]          394,240
│    │    └─Conv: 3-48                             [1, 256, 20, 20]          66,048
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-49                                   --                        (recursive)
│    │    └─ModuleList: 3-52                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-49                                   --                        (recursive)
│    │    └─ModuleList: 3-52                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-49                                   --                        (recursive)
│    │    └─Conv: 3-54                             [1, 256, 20, 20]          98,816
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─SPPF: 2-51                                  [1, 256, 20, 20]          131,584
│    │    └─Conv: 3-56                             [1, 128, 20, 20]          33,024
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─SPPF: 2-53                                  --                        (recursive)
│    │    └─MaxPool2d: 3-58                        [1, 128, 20, 20]          --
│    │    └─MaxPool2d: 3-59                        [1, 128, 20, 20]          --
│    │    └─MaxPool2d: 3-60                        [1, 128, 20, 20]          --
│    │    └─Conv: 3-61                             [1, 256, 20, 20]          131,584
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─Upsample: 2-55                              [1, 256, 40, 40]          --
│    └─Concat: 2-56                                [1, 384, 40, 40]          --
│    └─C2f: 2-57                                   [1, 128, 40, 40]          98,816
│    │    └─Conv: 3-63                             [1, 128, 40, 40]          49,408
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-63                                   --                        (recursive)
│    │    └─ModuleList: 3-67                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-63                                   --                        (recursive)
│    │    └─ModuleList: 3-67                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-63                                   --                        (recursive)
│    │    └─Conv: 3-69                             [1, 128, 40, 40]          24,832
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─Upsample: 2-65                              [1, 128, 80, 80]          --
│    └─Concat: 2-66                                [1, 192, 80, 80]          --
│    └─C2f: 2-67                                   [1, 64, 80, 80]           24,832
│    │    └─Conv: 3-71                             [1, 64, 80, 80]           12,416
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-73                                   --                        (recursive)
│    │    └─ModuleList: 3-75                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-73                                   --                        (recursive)
│    │    └─ModuleList: 3-75                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-73                                   --                        (recursive)
│    │    └─Conv: 3-77                             [1, 64, 80, 80]           6,272
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─CoordAtt: 2-75                              [1, 64, 80, 80]           --
│    │    └─AdaptiveAvgPool2d: 3-79                [1, 64, 80, 1]            --
│    │    └─AdaptiveAvgPool2d: 3-80                [1, 64, 1, 80]            --
│    │    └─Conv2d: 3-81                           [1, 8, 160, 1]            520
│    │    └─BatchNorm2d: 3-82                      [1, 8, 160, 1]            16
│    │    └─h_swish: 3-83                          [1, 8, 160, 1]            --
│    │    └─Conv2d: 3-84                           [1, 64, 80, 1]            576
│    │    └─Conv2d: 3-85                           [1, 64, 1, 80]            576
│    └─Conv: 2-76                                  [1, 64, 40, 40]           --
│    │    └─Conv2d: 3-86                           [1, 64, 40, 40]           36,864
│    │    └─BatchNorm2d: 3-87                      [1, 64, 40, 40]           128
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─Concat: 2-78                                [1, 192, 40, 40]          --
│    └─C2f: 2-79                                   [1, 128, 40, 40]          98,816
│    │    └─Conv: 3-89                             [1, 128, 40, 40]          24,832
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-85                                   --                        (recursive)
│    │    └─ModuleList: 3-93                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-85                                   --                        (recursive)
│    │    └─ModuleList: 3-93                       --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-85                                   --                        (recursive)
│    │    └─Conv: 3-95                             [1, 128, 40, 40]          24,832
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─Conv: 2-87                                  [1, 128, 20, 20]          --
│    │    └─Conv2d: 3-97                           [1, 128, 20, 20]          147,456
│    │    └─BatchNorm2d: 3-98                      [1, 128, 20, 20]          256
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─Concat: 2-89                                [1, 384, 20, 20]          --
│    └─C2f: 2-90                                   [1, 256, 20, 20]          394,240
│    │    └─Conv: 3-100                            [1, 256, 20, 20]          98,816
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-96                                   --                        (recursive)
│    │    └─ModuleList: 3-104                      --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-96                                   --                        (recursive)
│    │    └─ModuleList: 3-104                      --                        (recursive)
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─C2f: 2-96                                   --                        (recursive)
│    │    └─Conv: 3-106                            [1, 256, 20, 20]          98,816
│    └─Detect: 2-97                                --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    └─Detect: 2-98                                [1, 84, 8400]             --
│    │    └─ModuleList: 3-124                      --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    │    └─ModuleList: 3-124                      --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    │    └─ModuleList: 3-124                      --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    │    └─ModuleList: 3-124                      --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    │    └─ModuleList: 3-124                      --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    │    └─ModuleList: 3-124                      --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    │    └─ModuleList: 3-124                      --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    │    └─ModuleList: 3-124                      --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    │    └─ModuleList: 3-124                      --                        (recursive)
│    │    └─ModuleList: 3-125                      --                        (recursive)
│    │    └─DFL: 3-126                             [1, 4, 8400]              (16)
====================================================================================================
Total params: 5,508,456
Trainable params: 5,508,440
Non-trainable params: 16
Total mult-adds (G): 4.57
====================================================================================================
Input size (MB): 4.92
Forward/backward pass size (MB): 232.26
Params size (MB): 13.63
Estimated Total Size (MB): 250.80
====================================================================================================
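The CoordAtt rows in this summary can be verified by hand. Assuming the block follows the reference implementation (1x1 convolutions with bias, a bottleneck of max(8, C // 32) channels, and output channels equal to input channels), the arithmetic below reproduces the figures for C = 64 on the 80x80 P3 map: 64·8 + 8 = 520 for conv1, 2·8 = 16 for the BatchNorm weight and bias, and 8·64 + 64 = 576 for each of conv_h and conv_w. The 160 in the intermediate shape is the pooled height and width concatenated (80 + 80). The helper names are hypothetical.

```python
def conv1x1_params(c_in: int, c_out: int, bias: bool = True) -> int:
    """Weights of a 1x1 Conv2d plus its optional bias."""
    return c_in * c_out + (c_out if bias else 0)


def coordatt_layers(channels: int, h: int, w: int, reduction: int = 32):
    """Expected (output shape, trainable params) per CoordAtt sub-layer, batch size 1."""
    mip = max(8, channels // reduction)  # assumption: reference-implementation bottleneck rule
    return {
        "conv1": ((1, mip, h + w, 1), conv1x1_params(channels, mip)),
        "bn1": ((1, mip, h + w, 1), 2 * mip),  # weight + bias; running stats are buffers
        "conv_h": ((1, channels, h, 1), conv1x1_params(mip, channels)),
        "conv_w": ((1, channels, 1, w), conv1x1_params(mip, channels)),
    }


layers = coordatt_layers(64, 80, 80)
for name, (shape, params) in layers.items():
    print(name, list(shape), params)
print("total", sum(p for _, p in layers.values()))  # total 1688
```

Separately, the 16 non-trainable parameters in the totals belong to the DFL layer (shown as (16) above), whose fixed 1x1 convolution is not trained.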



Training the modified model

Example data.yaml:

train: /home/gpuadmin/2023811010/Yolo/dataset/train/images
val: /home/gpuadmin/2023811010/Yolo/dataset/valid/images
test: /home/gpuadmin/2023811010/Yolo/dataset/test/images

nc: 2
names: ['Crack', 'Normality']
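Before training, a quick sanity check of the dataset config can catch mismatches such as nc disagreeing with the number of class names. The sketch below works on a plain dict mirroring the data.yaml above (reading the YAML file itself would need a parser such as PyYAML); check_data_config is a hypothetical helper, not part of Ultralytics, and it only checks the keys this example relies on.

```python
from typing import Any, Dict, List


def check_data_config(cfg: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in a YOLO-style dataset config dict."""
    problems = []
    for key in ("train", "val", "names"):
        if key not in cfg:
            problems.append(f"missing key: {key}")
    names = cfg.get("names", [])
    if "nc" in cfg and len(names) != cfg["nc"]:
        problems.append(f"nc={cfg['nc']} but {len(names)} names given")
    return problems


data = {
    "train": "/home/gpuadmin/2023811010/Yolo/dataset/train/images",
    "val": "/home/gpuadmin/2023811010/Yolo/dataset/valid/images",
    "test": "/home/gpuadmin/2023811010/Yolo/dataset/test/images",
    "nc": 2,
    "names": ["Crack", "Normality"],
}
print(check_data_config(data))  # []
```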
Then add the training call to train.py, using the model created there:
model.train(
	data='data.yaml', epochs=10, patience=30, batch=32
	)
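With patience=30 and epochs=10, the patience limit can never be reached, so early stopping cannot end this run before the last epoch. To make the meaning of patience concrete, the sketch below implements generic patience-based early stopping: training stops once patience epochs pass without the validation fitness improving. early_stop_epoch is a hypothetical helper, and Ultralytics' own early-stopping logic may differ in details such as how ties are treated.

```python
from typing import Optional, Sequence


def early_stop_epoch(fitness: Sequence[float], patience: int) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None if it runs to the end."""
    best, best_epoch = float("-inf"), 0
    for epoch, value in enumerate(fitness):
        if value >= best:  # improvement (ties count as improvement here)
            best, best_epoch = value, epoch
        elif epoch - best_epoch >= patience:
            return epoch  # patience epochs passed without improvement
    return None


# 10 epochs with patience=30: the limit is never reached
print(early_stop_epoch([0.50] + [0.40] * 9, patience=30))  # None
print(early_stop_epoch([0.50] + [0.40] * 9, patience=3))   # 3
```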
This post is licensed under CC BY 4.0 by the author.