Coverage for src/gitlabracadabra/mixins/protected_branches.py: 82%

187 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-10 17:02 +0100

1# 

2# Copyright (C) 2019-2025 Mathieu Parent <math.parent@gmail.com> 

3# 

4# This program is free software: you can redistribute it and/or modify 

5# it under the terms of the GNU Lesser General Public License as published by 

6# the Free Software Foundation, either version 3 of the License, or 

7# (at your option) any later version. 

8# 

9# This program is distributed in the hope that it will be useful, 

10# but WITHOUT ANY WARRANTY; without even the implied warranty of 

11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

12# GNU Lesser General Public License for more details. 

13# 

14# You should have received a copy of the GNU Lesser General Public License 

15# along with this program. If not, see <http://www.gnu.org/licenses/>. 

16 

17from __future__ import annotations 

18 

19import logging 

20from json import dumps as json_dumps 

21from time import sleep 

22from types import MethodType 

23 

24from gitlab.exceptions import GitlabCreateError, GitlabParsingError, GitlabUpdateError 

25from gitlab.mixins import CRUDMixin, NoUpdateMixin, SaveMixin 

26from gitlab.v4.objects import ProjectProtectedBranch, ProjectProtectedBranchManager 

27 

28from gitlabracadabra.gitlab.access_levels import access_level_value 

29from gitlabracadabra.objects.object import GitLabracadabraObject 

30 

# Prefix shared by the multi-valued "allowed_to_*" parameters sent to GitLab,
# and the two such parameters that have a single-valued legacy counterpart.
ALLOWED_TO_PREFIX = "allowed_to_"
ALLOWED_TO_PUSH = "allowed_to_push"
ALLOWED_TO_MERGE = "allowed_to_merge"

# Suffix of the "*_access_levels" attributes read back from GitLab.
ACCESS_LEVELS_SUFFIX = "_access_levels"

# Keys that may identify a single access level entry.
DEPLOY_KEY_ID = "deploy_key_id"
GROUP_ID = "group_id"
USER_ID = "user_id"
ACCESS_LEVEL = "access_level"

logger = logging.getLogger(__name__)

41 

42 

# Before https://github.com/python-gitlab/python-gitlab/commit/a867c48 (v4.5.0)
# Older python-gitlab releases did not give ProjectProtectedBranch the SaveMixin
# and gave its manager NoUpdateMixin; patch both so that the ``save()`` call
# used by ProtectedBranchesMixin is available.
if SaveMixin not in ProjectProtectedBranch.__bases__:
    ProjectProtectedBranch.__bases__ = (SaveMixin, *ProjectProtectedBranch.__bases__)
if NoUpdateMixin in ProjectProtectedBranchManager.__bases__:
    # Swap NoUpdateMixin for CRUDMixin, keeping every other base in place.
    ProjectProtectedBranchManager.__bases__ = tuple(
        CRUDMixin if base == NoUpdateMixin else base for base in ProjectProtectedBranchManager.__bases__
    )

    # https://github.com/python-gitlab/python-gitlab/commit/f711d9e (v3.14)
    def _http_patch(self, path, *, query_data=None, post_data=None, **kwargs):
        """Send a PATCH request via ``self.http_request`` and return the decoded JSON body.

        Presumably a backport of the client-side PATCH helper for python-gitlab
        releases older than the commit referenced above — confirm.

        Raises:
            GitlabParsingError: the response body could not be decoded as JSON.
        """
        response = self.http_request(
            "patch",
            path,
            query_data=query_data or {},
            post_data=post_data or {},
            **kwargs,
        )
        try:
            return response.json()
        except Exception as err:
            raise GitlabParsingError(error_message="Failed to parse the server message") from err

    # https://github.com/python-gitlab/python-gitlab/commit/7073a2d (v4.0)
    def _get_update_method(self):
        """Return ``_http_patch`` bound to this manager's ``gitlab`` client."""
        return MethodType(_http_patch, self.gitlab)

    ProjectProtectedBranchManager._get_update_method = _get_update_method  # type: ignore # noqa: SLF001

75 

76 

class ProtectedBranchesMixin(GitLabracadabraObject):
    """Object with protected branches.

    Reconciles the ``protected_branches`` parameter with the protected branches
    listed by ``self._obj.protectedbranches``:
    - missing ones are created;
    - differing ones are updated;
    - unlisted ones are handled according to the ``unknown_protected_branches``
      setting ("warn" by default, "ignore"/"skip", or "delete"/"remove").
    """

    def _process_protected_branches(self, param_name, param_value, *, dry_run=False, skip_save=False):
        """Process the protected_branches param.

        Args:
            param_name: "protected_branches".
            param_value: Target protected branches, as a mapping of protected
                branch name to its configuration.
            dry_run: If true, only log the changes that would be made.
            skip_save: Must be False (asserted).
        """
        assert param_name == "protected_branches"  # noqa: S101
        assert not skip_save  # noqa: S101
        current_protected_branches = self._get_current_protected_branches()
        self._create_or_update_protected_branches(param_value, dry_run, current_protected_branches)
        self._remove_unknown_protected_branches(param_value, dry_run, current_protected_branches)

    def _list_protected_branches(self):
        """Fetch all protected branches from GitLab, keyed by name."""
        return {
            protected_branch.name: protected_branch
            for protected_branch in self._obj.protectedbranches.list(all=True)
        }

    def _get_current_protected_branches(self):
        """Return the current protected branches, keyed by name.

        When the object was just created, the list may still be empty; it is
        then fetched again, one second apart, up to 10 more times until it is
        non-empty. The last result is returned either way.
        """
        current_protected_branches = self._list_protected_branches()
        if not self._just_created:
            return current_protected_branches
        for _ in range(10):
            if current_protected_branches:
                break
            logger.debug("[%s] Waiting one second before retrieving protected branches again", self._name)
            sleep(1)
            current_protected_branches = self._list_protected_branches()
        return current_protected_branches

    def _create_or_update_protected_branches(self, param_value, dry_run, current_protected_branches):
        """Create or update every target protected branch, in name order."""
        for protected_name, target_config_str in sorted(param_value.items()):
            if protected_name in current_protected_branches:
                self._update_protected_branch(protected_name, target_config_str, dry_run, current_protected_branches)
            else:
                self._create_protected_branch(protected_name, target_config_str, dry_run, current_protected_branches)

    def _update_protected_branch(self, protected_name, target_config_str, dry_run, current_protected_branches):
        """Update an existing protected branch so that it matches its target configuration.

        Attributes absent from the target configuration keep their current value.

        CE compatibility: if the current protected branch has no
        ``code_owner_approval_required`` attribute, the ``allowed_to_*``
        parameters are not used. The protected branch is then deleted and
        re-created instead.

        A GitlabUpdateError from ``save()`` is logged as a warning, not raised.
        """
        target_config = self._target_protected_branch_config(protected_name, target_config_str)
        current_config = self._current_protected_branch_config(current_protected_branches, protected_name)
        for current_k, current_v in current_config.items():
            target_config.setdefault(current_k, current_v)
        if current_config == target_config:
            return
        # CE compatibility
        can_use_allowed_to = "code_owner_approval_required" in current_config
        if not can_use_allowed_to:
            self._create_protected_branch(protected_name, target_config_str, dry_run, current_protected_branches)
            return
        if dry_run:
            logger.info(
                "[%s] NOT Changing protected branch %s: %s -> %s (dry-run)",
                self._name,
                protected_name,
                json_dumps(current_config, sort_keys=True),
                json_dumps(target_config, sort_keys=True),
            )
            return
        logger.info(
            "[%s] Changing protected branch %s: %s -> %s",
            self._name,
            protected_name,
            json_dumps(current_config, sort_keys=True),
            json_dumps(target_config, sort_keys=True),
        )
        protected_branch = current_protected_branches.get(protected_name)
        for target_config_key, target_config_value in target_config.items():
            current_config_value = current_config.get(target_config_key)
            if target_config_value == current_config_value:
                continue
            if target_config_key.startswith(ALLOWED_TO_PREFIX):
                self._update_protected_branch_access_levels(
                    protected_branch,
                    current_config_value,
                    target_config_key,
                    target_config_value,
                )
            else:
                setattr(protected_branch, target_config_key, target_config_value)
        try:
            protected_branch.save()
        except GitlabUpdateError as err:
            logger.warning(
                "[%s] Unable to change protected branch %s: %s",
                self._name,
                protected_name,
                err.error_message,
            )

    def _update_protected_branch_access_levels(
        self,
        protected_branch,
        current_config_value,
        target_config_key,
        target_config_value,
    ):
        """Set an ``allowed_to_*`` attribute to the changes needed to reach the target.

        The value set on ``protected_branch`` contains, in order:
        - the target entries missing from ``current_config_value``;
        - one ``{"id": ..., "_destroy": True}`` per current entry that is not
          wanted, including duplicated current entries beyond the number wanted.

        Neither ``current_config_value`` nor ``target_config_value`` is modified.
        A missing current value or attribute is treated as an empty list.
        """
        current_config_value = current_config_value or []
        changes = [item for item in target_config_value if item not in current_config_value]
        # Each matching current entry consumes one wanted entry, so work on a copy
        # rather than mutating the caller's target configuration.
        remaining_targets = list(target_config_value)
        received_name = self._received_attribute_name(target_config_key)
        for current_config_value_item in getattr(protected_branch, received_name, None) or []:
            current_access_level = self._current_access_level(current_config_value_item)
            if current_access_level in remaining_targets:
                remaining_targets.remove(current_access_level)
                continue
            changes.append({"id": current_config_value_item.get("id"), "_destroy": True})
        setattr(protected_branch, target_config_key, changes)

    def _received_attribute_name(self, attribute_name):
        """Map a sent ``allowed_to_X`` name to the received ``X_access_levels`` name."""
        if attribute_name.startswith(ALLOWED_TO_PREFIX):
            return attribute_name.removeprefix(ALLOWED_TO_PREFIX) + ACCESS_LEVELS_SUFFIX
        return attribute_name

    def _create_protected_branch(self, protected_name, target_config_str, dry_run, current_protected_branches):
        """Create the protected branch, deleting any existing one with that name first.

        A GitlabCreateError is logged as a warning, not raised.
        """
        target_config = self._target_protected_branch_config(protected_name, target_config_str)
        # CE compatibility: also send the single-valued legacy parameter. This is
        # only possible for a single role entry; a lone user/group/deploy-key
        # entry has no ACCESS_LEVEL key.
        for allowed_to_key, legacy_key in (
            (ALLOWED_TO_PUSH, "push_access_level"),
            (ALLOWED_TO_MERGE, "merge_access_level"),
        ):
            allowed_to = target_config.get(allowed_to_key) or []
            if len(allowed_to) == 1 and ACCESS_LEVEL in allowed_to[0]:
                target_config[legacy_key] = allowed_to[0][ACCESS_LEVEL]
        if dry_run:
            logger.info(
                "[%s] NOT Creating protected branch %s: %s (dry-run)",
                self._name,
                protected_name,
                json_dumps(target_config, sort_keys=True),
            )
            return
        logger.info(
            "[%s] Creating protected branch %s: %s",
            self._name,
            protected_name,
            json_dumps(target_config, sort_keys=True),
        )
        try:
            if protected_name in current_protected_branches:
                # GitLab CE can't update push/merge access level
                self._obj.protectedbranches.delete(protected_name)
            self._obj.protectedbranches.create(target_config)
        except GitlabCreateError as err:
            logger.warning(
                "[%s] Unable to create protected branch %s: %s",
                self._name,
                protected_name,
                err.error_message,
            )

    def _target_protected_branch_config(self, protected_name, target_config_str):
        """Convert a user-facing configuration into GitLab API values.

        - ``*_access_level`` values go through ``access_level_value``.
        - ``allowed_to_*`` lists go through ``_target_access_levels``.
        - Legacy ``merge_access_level`` / ``push_access_level`` values are folded
          into ``allowed_to_merge`` / ``allowed_to_push``.
        """
        target_config_int = {
            "name": protected_name,
        }
        for target_config_key, target_config_value in target_config_str.items():
            if target_config_key.endswith("_access_level"):
                target_config_int[target_config_key] = access_level_value(target_config_value)
            elif target_config_key.startswith(ALLOWED_TO_PREFIX):
                target_config_int[target_config_key] = self._target_access_levels(target_config_value)
            else:
                target_config_int[target_config_key] = target_config_value
        for legacy_key, allowed_to_key in (
            ("merge_access_level", ALLOWED_TO_MERGE),
            ("push_access_level", ALLOWED_TO_PUSH),
        ):
            if legacy_key not in target_config_int:
                continue
            allowed_to = target_config_int.setdefault(allowed_to_key, [])
            allowed_to.append({ACCESS_LEVEL: target_config_int.pop(legacy_key)})
            # Keep the canonical order used by _target_access_levels and
            # _current_access_levels, so that equal configs compare equal.
            allowed_to.sort(key=lambda access_level: json_dumps(access_level, sort_keys=True))
        return target_config_int

    def _target_access_levels(self, access_levels_str):
        """Convert ``allowed_to_*`` entries (role/user/group/deploy_key) to API dicts, canonically sorted.

        Entries with none of those keys are ignored with a warning.
        """
        target_access_levels = []
        for access_level_str in access_levels_str:
            if "role" in access_level_str:
                target_access_levels.append(
                    {
                        ACCESS_LEVEL: access_level_value(access_level_str.get("role")),
                    }
                )
            elif "user" in access_level_str:
                target_access_levels.append(
                    {
                        USER_ID: self.connection.user_cache.id_from_username(access_level_str.get("user")),
                    }
                )
            elif "group" in access_level_str:
                target_access_levels.append(
                    {
                        GROUP_ID: self.connection.group_cache.id_from_full_path(access_level_str.get("group")),
                    }
                )
            elif "deploy_key" in access_level_str:
                target_access_levels.append(
                    {
                        DEPLOY_KEY_ID: self.connection.deploy_key_cache.id_from_title(
                            self._obj.id,
                            access_level_str.get("deploy_key"),
                        ),
                    }
                )
            else:
                # Previously dropped silently, which could remove access unnoticed.
                logger.warning(
                    "[%s] Ignoring unknown access level in protected branch configuration: %r",
                    self._name,
                    access_level_str,
                )
        return sorted(
            target_access_levels,
            key=lambda access_level: json_dumps(access_level, sort_keys=True),
        )

    def _current_protected_branch_config(self, current_protected_branches, protected_name):
        """Return the comparable configuration of a current protected branch, or {} if absent.

        ``*_access_levels`` attributes are renamed to ``allowed_to_*`` and
        simplified. Only ``name``, ``allow_force_push`` and
        ``code_owner_approval_required`` are kept among the other attributes.
        """
        if protected_name not in current_protected_branches:
            return {}
        current_protected_branch = current_protected_branches.get(protected_name)
        current_config = {}
        for param_k, param_v in current_protected_branch.attributes.items():
            if param_k in {"push_access_levels", "merge_access_levels", "unprotect_access_levels"}:
                current_config[self._sent_attribute_name(param_k)] = self._current_access_levels(param_v)
            elif param_k in {"name", "allow_force_push", "code_owner_approval_required"}:
                current_config[param_k] = param_v
        return current_config

    def _current_access_levels(self, access_levels):
        """Simplify and canonically sort a list of current access level entries."""
        return sorted(
            [self._current_access_level(access_level) for access_level in access_levels],
            key=lambda access_level: json_dumps(access_level, sort_keys=True),
        )

    def _current_access_level(self, access_level):
        """Reduce an access level entry to its identifying key.

        Precedence: user, then group, then deploy key, then role.
        """
        if access_level.get(USER_ID):
            return {USER_ID: access_level.get(USER_ID)}
        if access_level.get(GROUP_ID):
            return {GROUP_ID: access_level.get(GROUP_ID)}
        if access_level.get(DEPLOY_KEY_ID):
            return {DEPLOY_KEY_ID: access_level.get(DEPLOY_KEY_ID)}
        return {ACCESS_LEVEL: access_level.get(ACCESS_LEVEL)}

    def _sent_attribute_name(self, attribute_name):
        """Map a received ``X_access_levels`` name to the sent ``allowed_to_X`` name."""
        if attribute_name.endswith(ACCESS_LEVELS_SUFFIX):
            return ALLOWED_TO_PREFIX + attribute_name.removesuffix(ACCESS_LEVELS_SUFFIX)
        return attribute_name

    def _remove_unknown_protected_branches(self, param_value, dry_run, current_protected_branches):
        """Handle current protected branches that are not in the target configuration.

        Controlled by ``unknown_protected_branches`` (default "warn"):
        - "ignore"/"skip": do nothing;
        - "delete"/"remove": delete them (only logged on dry run);
        - anything else: log a warning.
        """
        # Remaining protected branches
        unknown_protected_branches = self._content.get("unknown_protected_branches", "warn")
        if unknown_protected_branches in {"ignore", "skip"}:
            return
        for current_protected_name in current_protected_branches:
            if current_protected_name in param_value:
                continue
            if unknown_protected_branches in {"delete", "remove"} and dry_run:
                logger.info(
                    "[%s] NOT Deleting unknown protected branch: %s (dry-run)",
                    self._name,
                    current_protected_name,
                )
            elif unknown_protected_branches in {"delete", "remove"}:
                logger.info(
                    "[%s] Deleting unknown protected branch: %s",
                    self._name,
                    current_protected_name,
                )
                self._obj.protectedbranches.delete(current_protected_name)
            else:
                logger.warning(
                    "[%s] NOT Deleting unknown protected branch: %s (unknown_protected_branches=%s)",
                    self._name,
                    current_protected_name,
                    unknown_protected_branches,
                )