Coverage for src/gitlabracadabra/mixins/protected_branches.py: 82%
187 statements
coverage.py v7.6.12, created at 2025-03-10 17:02 +0100
1#
2# Copyright (C) 2019-2025 Mathieu Parent <math.parent@gmail.com>
3#
4# This program is free software: you can redistribute it and/or modify
5# it under the terms of the GNU Lesser General Public License as published by
6# the Free Software Foundation, either version 3 of the License, or
7# (at your option) any later version.
8#
9# This program is distributed in the hope that it will be useful,
10# but WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12# GNU Lesser General Public License for more details.
13#
14# You should have received a copy of the GNU Lesser General Public License
15# along with this program. If not, see <http://www.gnu.org/licenses/>.
17from __future__ import annotations
19import logging
20from json import dumps as json_dumps
21from time import sleep
22from types import MethodType
24from gitlab.exceptions import GitlabCreateError, GitlabParsingError, GitlabUpdateError
25from gitlab.mixins import CRUDMixin, NoUpdateMixin, SaveMixin
26from gitlab.v4.objects import ProjectProtectedBranch, ProjectProtectedBranchManager
28from gitlabracadabra.gitlab.access_levels import access_level_value
29from gitlabracadabra.objects.object import GitLabracadabraObject
# Name fragments used to translate between the attribute names sent to GitLab
# ("allowed_to_<action>") and the ones received from it ("<action>_access_levels").
ALLOWED_TO_PREFIX = "allowed_to_"
ALLOWED_TO_MERGE = "allowed_to_merge"
ALLOWED_TO_PUSH = "allowed_to_push"
ACCESS_LEVELS_SUFFIX = "_access_levels"

# Keys that can identify the subject of a single access-level entry.
ACCESS_LEVEL = "access_level"
USER_ID = "user_id"
GROUP_ID = "group_id"
DEPLOY_KEY_ID = "deploy_key_id"

logger = logging.getLogger(__name__)
# Compatibility shims for older python-gitlab releases: make protected
# branches saveable/updatable when the installed library does not already
# provide that.
# Before https://github.com/python-gitlab/python-gitlab/commit/a867c48 (v4.5.0)
if SaveMixin not in ProjectProtectedBranch.__bases__:
    ProjectProtectedBranch.__bases__ = (SaveMixin, *ProjectProtectedBranch.__bases__)
if NoUpdateMixin in ProjectProtectedBranchManager.__bases__:
    # Swap the "no update" mixin for the full CRUD one, keeping base order.
    ProjectProtectedBranchManager.__bases__ = tuple(
        CRUDMixin if base == NoUpdateMixin else base for base in ProjectProtectedBranchManager.__bases__
    )

    # https://github.com/python-gitlab/python-gitlab/commit/f711d9e (v3.14)
    def _http_patch(self, path, *, query_data=None, post_data=None, **kwargs):
        """Send a PATCH request and return the decoded JSON response.

        Args:
            self: Object exposing ``http_request`` (bound in
                ``_get_update_method`` to the manager's ``gitlab`` attribute).
            path: Path or URL passed through to ``http_request``.
            query_data: Query parameters; ``None`` is sent as an empty dict.
            post_data: Request payload; ``None`` is sent as an empty dict.
            **kwargs: Extra options forwarded to ``http_request``.

        Returns:
            The result of ``response.json()``.

        Raises:
            GitlabParsingError: If the response body cannot be decoded.
        """
        response = self.http_request(
            "patch",
            path,
            query_data=query_data or {},
            post_data=post_data or {},
            **kwargs,
        )
        try:
            parsed = response.json()
        except Exception as err:
            raise GitlabParsingError(
                error_message="Failed to parse the server message",
            ) from err
        return parsed

    # https://github.com/python-gitlab/python-gitlab/commit/7073a2d (v4.0)
    def _get_update_method(self):
        """Return ``_http_patch`` bound to this manager's ``gitlab`` object."""
        return MethodType(_http_patch, self.gitlab)

    ProjectProtectedBranchManager._get_update_method = _get_update_method  # type: ignore # noqa: SLF001
class ProtectedBranchesMixin(GitLabracadabraObject):
    """Object with protected branches."""

    def _process_protected_branches(self, param_name, param_value, *, dry_run=False, skip_save=False):
        """Process the protected_branches param.

        Creates or updates every protected branch listed in ``param_value``,
        then handles the protected branches that exist on the server but are
        not listed.

        Args:
            param_name: "protected_branches".
            param_value: target protected branches (mapping of name to config).
            dry_run: Dry run.
            skip_save: False.
        """
        assert param_name == "protected_branches"  # noqa: S101
        assert not skip_save  # noqa: S101
        current_protected_branches = self._get_current_protected_branches()
        self._create_or_update_protected_branches(param_value, dry_run, current_protected_branches)
        self._remove_unknown_protected_branches(param_value, dry_run, current_protected_branches)

    @staticmethod
    def _access_level_sort_key(access_level):
        """Return a deterministic sort key for an access-level dict."""
        return json_dumps(access_level, sort_keys=True)

    def _list_protected_branches(self):
        """Return the protected branches currently on the server, keyed by name."""
        return {
            protected_branch.name: protected_branch for protected_branch in self._obj.protectedbranches.list(all=True)
        }

    def _get_current_protected_branches(self):
        """Fetch the current protected branches as a ``{name: object}`` dict.

        When the object was just created and the list is still empty, the
        listing is retried up to ten times, sleeping one second before each
        retry, and the last result is returned.
        """
        current_protected_branches = self._list_protected_branches()
        if not self._just_created:
            return current_protected_branches
        for _ in range(10):
            if current_protected_branches:
                break
            logger.debug("[%s] Waiting one second before retrieving protected branches again", self._name)
            sleep(1)
            current_protected_branches = self._list_protected_branches()
        return current_protected_branches

    def _create_or_update_protected_branches(self, param_value, dry_run, current_protected_branches):
        """Update each target branch that already exists, create the others.

        Branches are processed in sorted name order.
        """
        for protected_name, target_config_str in sorted(param_value.items()):
            if protected_name in current_protected_branches:
                self._update_protected_branch(protected_name, target_config_str, dry_run, current_protected_branches)
            else:
                self._create_protected_branch(protected_name, target_config_str, dry_run, current_protected_branches)

    def _update_protected_branch(self, protected_name, target_config_str, dry_run, current_protected_branches):
        """Bring an existing protected branch in line with its target config.

        Args:
            protected_name: Name of the protected branch.
            target_config_str: Target configuration as written by the user.
            dry_run: When true, only log what would change.
            current_protected_branches: ``{name: object}`` dict of existing
                protected branches.
        """
        target_config = self._target_protected_branch_config(protected_name, target_config_str)
        current_config = self._current_protected_branch_config(current_protected_branches, protected_name)
        # Settings the user did not mention keep their current value.
        for current_k, current_v in current_config.items():
            target_config.setdefault(current_k, current_v)
        if current_config == target_config:
            return
        # CE compatibility
        can_use_allowed_to = "code_owner_approval_required" in current_config
        if not can_use_allowed_to:
            self._create_protected_branch(protected_name, target_config_str, dry_run, current_protected_branches)
            return
        if dry_run:
            logger.info(
                "[%s] NOT Changing protected branch %s: %s -> %s (dry-run)",
                self._name,
                protected_name,
                json_dumps(current_config, sort_keys=True),
                json_dumps(target_config, sort_keys=True),
            )
            return
        logger.info(
            "[%s] Changing protected branch %s: %s -> %s",
            self._name,
            protected_name,
            json_dumps(current_config, sort_keys=True),
            json_dumps(target_config, sort_keys=True),
        )
        protected_branch = current_protected_branches.get(protected_name)
        for target_config_key, target_config_value in target_config.items():
            current_config_value = current_config.get(target_config_key)
            if target_config_value == current_config_value:
                continue
            if target_config_key.startswith(ALLOWED_TO_PREFIX):
                self._update_protected_branch_access_levels(
                    protected_branch,
                    current_config_value,
                    target_config_key,
                    target_config_value,
                )
            else:
                setattr(protected_branch, target_config_key, target_config_value)
        try:
            protected_branch.save()
        except GitlabUpdateError as err:
            logger.warning(
                "[%s] Unable to change protected branch %s: %s",
                self._name,
                protected_name,
                err.error_message,
            )

    def _update_protected_branch_access_levels(
        self,
        protected_branch,
        current_config_value,
        target_config_key,
        target_config_value,
    ):
        """Set ``target_config_key`` on the branch to the list of changes needed.

        The change list holds the target entries missing from the current
        config, followed by an ``{"id": ..., "_destroy": True}`` item for every
        received entry that no longer matches a target entry.

        Args:
            protected_branch: Protected branch object to modify.
            current_config_value: Current normalized entries, or ``None`` when
                the key is absent from the current config.
            target_config_key: "allowed_to_*" attribute name.
            target_config_value: Target normalized entries (left unmodified).
        """
        # The key may be absent from the current config (value is None).
        current_entries = current_config_value or []
        changes = [target_entry for target_entry in target_config_value if target_entry not in current_entries]
        # Work on a copy: each target entry may "keep" at most one received
        # entry, so duplicates on the server side get destroyed, without
        # altering the caller's list.
        unmatched_targets = list(target_config_value)
        received_name = self._received_attribute_name(target_config_key)
        for received_entry in getattr(protected_branch, received_name, None) or []:
            current_access_level = self._current_access_level(received_entry)
            if current_access_level in unmatched_targets:
                unmatched_targets.remove(current_access_level)
                continue
            changes.append({"id": received_entry.get("id"), "_destroy": True})
        setattr(protected_branch, target_config_key, changes)

    def _received_attribute_name(self, attribute_name):
        """Map a sent name ("allowed_to_push") to the received one ("push_access_levels")."""
        if attribute_name.startswith(ALLOWED_TO_PREFIX):
            return attribute_name.removeprefix(ALLOWED_TO_PREFIX) + ACCESS_LEVELS_SUFFIX
        return attribute_name

    def _create_protected_branch(self, protected_name, target_config_str, dry_run, current_protected_branches):
        """Create a protected branch, deleting any existing one of that name first.

        Args:
            protected_name: Name of the protected branch.
            target_config_str: Target configuration as written by the user.
            dry_run: When true, only log what would be created.
            current_protected_branches: ``{name: object}`` dict of existing
                protected branches.
        """
        target_config = self._target_protected_branch_config(protected_name, target_config_str)
        # CE compatibility: also send the single-value form when there is
        # exactly one entry and that entry is a role. A lone user, group or
        # deploy-key entry has no access level to copy.
        for allowed_to_key, access_level_key in (
            (ALLOWED_TO_PUSH, "push_access_level"),
            (ALLOWED_TO_MERGE, "merge_access_level"),
        ):
            allowed_to = target_config.get(allowed_to_key) or []
            if len(allowed_to) == 1 and ACCESS_LEVEL in allowed_to[0]:
                target_config[access_level_key] = allowed_to[0][ACCESS_LEVEL]
        if dry_run:
            logger.info(
                "[%s] NOT Creating protected branch %s: %s (dry-run)",
                self._name,
                protected_name,
                json_dumps(target_config, sort_keys=True),
            )
            return
        logger.info(
            "[%s] Creating protected branch %s: %s",
            self._name,
            protected_name,
            json_dumps(target_config, sort_keys=True),
        )
        try:
            if protected_name in current_protected_branches:
                # GitLab CE can't update push/merge access level
                self._obj.protectedbranches.delete(protected_name)
            self._obj.protectedbranches.create(target_config)
        except GitlabCreateError as err:
            logger.warning(
                "[%s] Unable to create protected branch %s: %s",
                self._name,
                protected_name,
                err.error_message,
            )

    def _target_protected_branch_config(self, protected_name, target_config_str):
        """Normalize the user-supplied configuration of one protected branch.

        ``*_access_level`` values go through ``access_level_value``,
        ``allowed_to_*`` lists through ``_target_access_levels``, and other
        values are copied as-is. ``merge_access_level`` and
        ``push_access_level`` are then folded into the matching
        ``allowed_to_*`` list.

        Returns:
            A new dict that always contains ``"name"``.
        """
        target_config_int = {
            "name": protected_name,
        }
        for target_config_key, target_config_value in target_config_str.items():
            if target_config_key.endswith("_access_level"):
                target_config_int[target_config_key] = access_level_value(target_config_value)
            elif target_config_key.startswith(ALLOWED_TO_PREFIX):
                target_config_int[target_config_key] = self._target_access_levels(target_config_value)
            else:
                target_config_int[target_config_key] = target_config_value
        for access_level_key, allowed_to_key in (
            ("merge_access_level", ALLOWED_TO_MERGE),
            ("push_access_level", ALLOWED_TO_PUSH),
        ):
            if access_level_key not in target_config_int:
                continue
            allowed_to = target_config_int.setdefault(allowed_to_key, [])
            allowed_to.append({ACCESS_LEVEL: target_config_int.pop(access_level_key)})
            # Keep the list in the same order as the current config so that
            # equal sets of entries compare equal.
            allowed_to.sort(key=self._access_level_sort_key)
        return target_config_int

    def _target_access_levels(self, access_levels_str):
        """Convert user-facing access-level entries to the form sent to GitLab.

        Each entry is resolved by the first key found among ``role``,
        ``user``, ``group`` and ``deploy_key``. Entries with none of these
        keys are skipped with a warning.

        Returns:
            A new list sorted by the JSON representation of each entry.
        """
        target_access_levels = []
        for access_level_str in access_levels_str:
            if "role" in access_level_str:
                target_access_levels.append(
                    {
                        ACCESS_LEVEL: access_level_value(access_level_str.get("role")),
                    }
                )
            elif "user" in access_level_str:
                target_access_levels.append(
                    {
                        USER_ID: self.connection.user_cache.id_from_username(access_level_str.get("user")),
                    }
                )
            elif "group" in access_level_str:
                target_access_levels.append(
                    {
                        GROUP_ID: self.connection.group_cache.id_from_full_path(access_level_str.get("group")),
                    }
                )
            elif "deploy_key" in access_level_str:
                target_access_levels.append(
                    {
                        DEPLOY_KEY_ID: self.connection.deploy_key_cache.id_from_title(
                            self._obj.id,
                            access_level_str.get("deploy_key"),
                        ),
                    }
                )
            else:
                logger.warning(
                    "[%s] Ignoring access level without role, user, group or deploy_key: %s",
                    self._name,
                    access_level_str,
                )
        return sorted(target_access_levels, key=self._access_level_sort_key)

    def _current_protected_branch_config(self, current_protected_branches, protected_name):
        """Return the comparable configuration of an existing protected branch.

        Only ``name``, ``allow_force_push``, ``code_owner_approval_required``
        and the push/merge/unprotect access levels are kept. The access levels
        are renamed to their ``allowed_to_*`` form. Returns an empty dict when
        the branch is not currently protected.
        """
        if protected_name not in current_protected_branches:
            return {}
        current_protected_branch = current_protected_branches.get(protected_name)
        current_config = {}
        for param_k, param_v in current_protected_branch.attributes.items():
            if param_k in {"push_access_levels", "merge_access_levels", "unprotect_access_levels"}:
                current_config[self._sent_attribute_name(param_k)] = self._current_access_levels(param_v)
            elif param_k in {"name", "allow_force_push", "code_owner_approval_required"}:
                current_config[param_k] = param_v
        return current_config

    def _current_access_levels(self, access_levels):
        """Normalize received access-level entries and sort them by JSON form."""
        return sorted(
            [self._current_access_level(access_level) for access_level in access_levels],
            key=self._access_level_sort_key,
        )

    def _current_access_level(self, access_level):
        """Reduce a received entry to its single identifying key.

        The first truthy value among user, group and deploy key wins;
        otherwise the entry is identified by its access level.
        """
        for id_key in (USER_ID, GROUP_ID, DEPLOY_KEY_ID):
            id_value = access_level.get(id_key)
            if id_value:
                return {id_key: id_value}
        return {ACCESS_LEVEL: access_level.get(ACCESS_LEVEL)}

    def _sent_attribute_name(self, attribute_name):
        """Map a received name ("push_access_levels") to the sent one ("allowed_to_push")."""
        if attribute_name.endswith(ACCESS_LEVELS_SUFFIX):
            return ALLOWED_TO_PREFIX + attribute_name.removesuffix(ACCESS_LEVELS_SUFFIX)
        return attribute_name

    def _remove_unknown_protected_branches(self, param_value, dry_run, current_protected_branches):
        """Handle protected branches present on the server but not in the target.

        Behavior follows the ``unknown_protected_branches`` setting (default
        "warn"): "ignore"/"skip" does nothing, "delete"/"remove" deletes the
        branch protection (or only logs in dry-run), and any other value logs
        a warning.
        """
        # Remaining protected branches
        unknown_protected_branches = self._content.get("unknown_protected_branches", "warn")
        if unknown_protected_branches in {"ignore", "skip"}:
            return
        should_delete = unknown_protected_branches in {"delete", "remove"}
        for current_protected_name in current_protected_branches:
            if current_protected_name in param_value:
                continue
            if should_delete and dry_run:
                logger.info(
                    "[%s] NOT Deleting unknown protected branch: %s (dry-run)",
                    self._name,
                    current_protected_name,
                )
            elif should_delete:
                logger.info(
                    "[%s] Deleting unknown protected branch: %s",
                    self._name,
                    current_protected_name,
                )
                self._obj.protectedbranches.delete(current_protected_name)
            else:
                logger.warning(
                    "[%s] NOT Deleting unknown protected branch: %s (unknown_protected_branches=%s)",
                    self._name,
                    current_protected_name,
                    unknown_protected_branches,
                )