Coverage for src/gitlabracadabra/containers/authenticated_session.py: 85%
111 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 17:02 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 17:02 +0100
1#
2# Copyright (C) 2019-2025 Mathieu Parent <math.parent@gmail.com>
3#
4# This program is free software: you can redistribute it and/or modify
5# it under the terms of the GNU Lesser General Public License as published by
6# the Free Software Foundation, either version 3 of the License, or
7# (at your option) any later version.
8#
9# This program is distributed in the hope that it will be useful,
10# but WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12# GNU Lesser General Public License for more details.
13#
14# You should have received a copy of the GNU Lesser General Public License
15# along with this program. If not, see <http://www.gnu.org/licenses/>.
17from __future__ import annotations
19from time import time
20from typing import TYPE_CHECKING
21from urllib.parse import urlparse
22from urllib.request import parse_http_list, parse_keqv_list
24from requests import PreparedRequest, Response, codes
25from requests.structures import CaseInsensitiveDict
27from gitlabracadabra import __version__ as gitlabracadabra_version
28from gitlabracadabra.auth_info import AuthInfo
29from gitlabracadabra.session import Session
if TYPE_CHECKING:
    # Names imported here are only used in annotations; this branch does not run
    # at runtime (TYPE_CHECKING is False), and annotations are lazy thanks to
    # `from __future__ import annotations`.
    from collections.abc import Iterable, MutableMapping
    from typing import Any

    from requests.auth import AuthBase

    from gitlabracadabra.containers.scope import Scope

    # Query-string parameters: each value is a single string or a list of strings.
    Params = (
        MutableMapping[
            str,
            str | list[str],
        ]
        | None
    )
    # Request body given as an iterable of byte chunks.
    Data = Iterable[bytes]
    # Plain dict form of query parameters, as built for the token endpoint request.
    _SimpleParams = dict[str, str | list[str]]
    # Token cache key: (scheme, hostname, port, scopes string or None).
    _TokenKey = tuple[str, str, int | None, str | None]
class Token:
    """JWT Token.

    Holds a bearer token string together with a locally-computed expiry time.
    """

    def __init__(
        self,
        token: str,
        expires_in: int,
    ) -> None:
        """Instantiate a token.

        Args:
            token: Token.
            expires_in: Lifetime in seconds; anything shorter than 60 seconds
                is raised to 60 seconds.
        """
        floor_seconds = 60

        self._token = token
        self._expires_in = max(expires_in, floor_seconds)
        # The server's issued_at claim is not consulted: the local clock at
        # construction time is the reference point for expiry.
        self._issued_at = time()

    @property
    def token(self) -> str:
        """Get token.

        Returns:
            The raw token string, as given to the constructor.
        """
        return self._token

    @property
    def expiration_time(self) -> float:
        """Get expiration time.

        Returns:
            Local epoch timestamp (seconds) at which the token stops being valid.
        """
        issued = self._issued_at
        lifetime = self._expires_in
        return issued + lifetime

    def is_expired(self) -> bool:
        """Check if token is expired.

        Returns:
            True once the current local time has reached the expiration time.
        """
        now = time()
        return now >= self.expiration_time
class AuthenticatedSession(Session):
    """Session with auth per-host.

    Bearer tokens are cached per (scheme, hostname, port, scopes) and added as
    an ``Authorization`` header by ``authenticated_request`` and, after
    redirects, by ``rebuild_auth``.
    """

    # Marker stored by connect() when the registry answers /v2/ with 200 OK.
    # It only records that no authentication is needed and is never sent.
    _NO_AUTH_TOKEN = "no_auth"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Instantiate a session.

        Args:
            args: Positional arguments, forwarded to the parent Session.
            kwargs: Named arguments, forwarded to the parent Session.
        """
        super().__init__(*args, **kwargs)
        # Replaces the default session headers entirely (not merged).
        self.headers = CaseInsensitiveDict(
            {
                "User-Agent": f"gitlabracadabra/{gitlabracadabra_version}",
                "Docker-Distribution-Api-Version": "registry/2.0",
            }
        )

        # Added attributes
        self.scheme = "https"
        self.connection_hostname = ""
        self.auth_info = AuthInfo()
        # Tokens, by scheme, host, port and scopes (as a string, or None for the scope-less entry)
        self._tokens: dict[_TokenKey, Token] = {}
        self._current_scopes: set[Scope] | None = None

    def authenticated_request(
        self,
        method: str,
        url: str,
        params: Params | None = None,
        data: Data | None = None,
        headers: dict[str, str] | None = None,
        auth: AuthBase | None = None,
        stream: bool | None = None,
    ) -> Response:
        """Send an HTTP request.

        Args:
            method: HTTP method.
            url: Either a path (starting with "/", resolved against scheme and
                connection_hostname) or a full url.
            params: query string params.
            data: Request body stream.
            headers: Request headers. Never modified; a copy is made when an
                Authorization header has to be added.
            auth: HTTPBasicAuth.
            stream: Stream the response.

        Returns:
            A Response.
        """
        if url.startswith("/"):
            url = f"{self.scheme}://{self.connection_hostname}{url}"
        authorization = self._bearer_authorization(url)
        if authorization is not None:
            # Copy rather than mutate: callers pass shared dicts such as
            # self.auth_info.headers, which must not retain the bearer token.
            headers = {**(headers or {}), "Authorization": authorization}
        return self.request(
            method,
            url,
            params=params,
            data=data,
            headers=headers,
            auth=auth,
            stream=stream,
        )

    def rebuild_auth(self, prepared_request: PreparedRequest, response: Response) -> None:
        """Override Session method to inject bearer tokens.

        Args:
            prepared_request: Prepared request.
            response: Response.
        """
        super().rebuild_auth(prepared_request, response)  # type: ignore
        authorization = self._bearer_authorization(prepared_request.url or "")
        if authorization is not None:
            prepared_request.headers["Authorization"] = authorization

    def connect(self, scopes: set[Scope] | None) -> None:
        """Connect.

        Probes ``/v2/`` on the connection host unless a usable token is
        already cached, then records either an anonymous-access marker
        (200 OK) or a bearer token (401 with a Bearer challenge).

        Args:
            scopes: An optional set of scopes.

        Raises:
            ValueError: A Bearer challenge was received but scopes is None.
            requests.HTTPError: The probe failed in any other way.
        """
        self._current_scopes = scopes
        url = f"{self.scheme}://{self.connection_hostname}/v2/"
        if self._get_token(url, scopes) is not None:
            return
        # The anonymous-access marker is stored without scopes.
        if self._get_token(url, None) is not None:
            return
        response = self.authenticated_request("get", url)
        if response.history:
            # Redirected: use the final host for subsequent relative URLs.
            self.connection_hostname = urlparse(response.url).hostname or self.connection_hostname
        if response.status_code == codes["ok"]:
            one_hour = 3600
            self._set_token(response, None, Token(self._NO_AUTH_TOKEN, one_hour))
            return
        # .get(): a 401 without the header must reach raise_for_status() below
        # instead of failing with a bare KeyError.
        www_authenticate = response.headers.get("Www-Authenticate", "")
        # Auth scheme names are case-insensitive (RFC 7235).
        if response.status_code == codes["unauthorized"] and www_authenticate.lower().startswith("bearer "):
            self._get_bearer_token(response)
            return
        response.raise_for_status()

    def _get_bearer_token(self, response: Response) -> None:
        """Request a token from the challenge realm and cache it for the current scopes.

        Args:
            response: The 401 response carrying the Bearer challenge.

        Raises:
            ValueError: No scopes were given to connect().
        """
        if self._current_scopes is None:
            raise ValueError("Bearer authentication requires scopes, but connect() was called without scopes")
        challenge_parameters = self._get_challenge_parameters(response)
        get_params: _SimpleParams = {}
        if "service" in challenge_parameters:
            get_params["service"] = challenge_parameters["service"]
        get_params["scope"] = [
            f"repository:{scope.remote_name}:{scope.actions}" for scope in sorted(self._current_scopes)
        ]
        challenge_response = self.authenticated_request(
            "get",
            challenge_parameters["realm"],
            params=get_params,
            headers=self.auth_info.headers,
            auth=self.auth_info.auth,
        )
        challenge_response.raise_for_status()
        payload = challenge_response.json()
        # "token" takes precedence over "access_token"; `or` also tolerates
        # JSON nulls, which would otherwise become "None" / raise in int().
        token_value = payload.get("token") or payload.get("access_token") or ""
        self._set_token(
            response,
            self._current_scopes,
            Token(
                str(token_value),
                int(payload.get("expires_in") or 0),
            ),
        )

    def _get_challenge_parameters(self, response: Response) -> dict[str, str]:
        """Parse the key=value parameters following "Bearer " in Www-Authenticate.

        Args:
            response: Response with a Www-Authenticate header.

        Returns:
            The challenge parameters (e.g. realm, service), or an empty dict
            when no Bearer challenge is found.
        """
        header = response.headers["Www-Authenticate"]
        # Case-insensitive lookup of the scheme (RFC 7235).
        start = header.lower().find("bearer ")
        challenge = header[start + len("bearer ") :] if start >= 0 else ""
        return parse_keqv_list(parse_http_list(challenge))

    def _bearer_authorization(self, url: str) -> str | None:
        """Build the Authorization header value for url, if a real token is cached.

        Args:
            url: Full request URL.

        Returns:
            "Bearer <token>", or None when no token is cached for the current
            scopes, or when the cached token is the anonymous marker or empty.
        """
        token = self._get_token(url, self._current_scopes)
        if token is None or not token.token or token.token == self._NO_AUTH_TOKEN:
            return None
        return f"Bearer {token.token}"

    def _token_key(self, url: str, scopes: set[Scope] | None) -> _TokenKey:
        """Compute the token cache key (scheme, hostname, port, scopes hash) for url."""
        parsed = urlparse(url)
        return (
            parsed.scheme,
            parsed.hostname or "",
            parsed.port,
            self._scopes_hash(scopes),
        )

    def _get_token(self, url: str, scopes: set[Scope] | None) -> Token | None:
        """Return the cached token for url and scopes, evicting it if expired."""
        key = self._token_key(url, scopes)
        token = self._tokens.get(key)
        if token is not None and token.is_expired():
            del self._tokens[key]
            return None
        return token

    def _set_token(self, response: Response, scopes: set[Scope] | None, token: Token) -> None:
        """Cache token for the (final) URL of response and the given scopes."""
        self._tokens[self._token_key(response.url, scopes)] = token

    def _scopes_hash(self, scopes: set[Scope] | None) -> str | None:
        """Return a stable string for scopes (sorted, comma-joined), or None."""
        if scopes is None:
            return None
        return ",".join(map(str, sorted(scopes)))