55import dataclasses
66import enum
77import hashlib
8+ import hmac
9+ import logging
810import time
911from collections .abc import Generator
1012from typing import TYPE_CHECKING , Any
1113
1214import jwt as pyjwt
15+ import orjson
1316from django .conf import settings
1417
18+ logger = logging .getLogger (__name__ )
19+
1520if TYPE_CHECKING :
1621 from sentry .auth .services .auth import AuthenticatedToken
1722
@@ -103,6 +108,8 @@ def get_viewer_context() -> ViewerContext | None:
103108# ---------------------------------------------------------------------------
104109
105110_JWT_STANDARD_CLAIMS = frozenset ({"iat" , "exp" , "iss" , "aud" , "nbf" , "jti" , "sub" })
111+ # JWT header field identifying which key was used for signing (RFC 7515 §4.1.4).
112+ _JWT_KEY_ID_HEADER = "kid"
106113
107114
108115def _key_id (key : str ) -> str :
@@ -114,8 +121,8 @@ def _key_id(key: str) -> str:
114121 return hashlib .sha256 (key .encode ("utf-8" )).hexdigest ()[:8 ]
115122
116123
117- def _get_jwt_secret (key : str | None = None ) -> str :
118- """Return the symmetric key to use for JWT signing/verification .
124+ def _get_signing_key (key : str | None = None ) -> str :
125+ """Return the key to use for JWT signing.
119126
120127 Resolution: explicit *key* → ``SEER_API_SHARED_SECRET``.
121128
@@ -132,14 +139,28 @@ def _get_jwt_secret(key: str | None = None) -> str:
132139 raise ValueError ("No signing key available. Set SEER_API_SHARED_SECRET in settings." )
133140
134141
142+ def _get_verification_keys () -> dict [str , str ]:
143+ """Return a ``{kid: key}`` mapping of all known verification keys.
144+
145+ Add new service keys here as more services propagate ViewerContext.
146+ """
147+ keys : dict [str , str ] = {}
148+
149+ seer_secret = getattr (settings , "SEER_API_SHARED_SECRET" , "" )
150+ if seer_secret :
151+ keys [_key_id (seer_secret )] = seer_secret
152+
153+ return keys
154+
155+
135156def encode_viewer_context (
136157 viewer_context : ViewerContext ,
137158 * ,
138159 key : str | None = None ,
139160 ttl : int | None = None ,
140161) -> str :
141162 """Encode a :class:`ViewerContext` as a signed HS256 JWT."""
142- secret = _get_jwt_secret (key )
163+ secret = _get_signing_key (key )
143164
144165 if ttl is None :
145166 ttl = getattr (settings , "VIEWER_CONTEXT_JWT_TTL" , 900 )
@@ -152,7 +173,9 @@ def encode_viewer_context(
152173 "iss" : "sentry" ,
153174 }
154175
155- return pyjwt .encode (payload , secret , algorithm = "HS256" , headers = {"kid" : _key_id (secret )})
176+ return pyjwt .encode (
177+ payload , secret , algorithm = "HS256" , headers = {_JWT_KEY_ID_HEADER : _key_id (secret )}
178+ )
156179
157180
158181def decode_viewer_context (
@@ -161,8 +184,22 @@ def decode_viewer_context(
161184 key : str | None = None ,
162185 leeway : int = 5 ,
163186) -> ViewerContext :
164- """Decode and verify an HS256 JWT into a :class:`ViewerContext`."""
165- secret = _get_jwt_secret (key )
187+ """Decode and verify an HS256 JWT into a :class:`ViewerContext`.
188+
189+ When *key* is provided it is used directly. Otherwise all keys
190+ from ``_get_verification_keys()`` are tried, kid-matched key first.
191+ """
192+ if key is not None :
193+ secret = key
194+ else :
195+ keys_by_kid = _get_verification_keys ()
196+ if not keys_by_kid :
197+ raise ValueError ("No verification keys available." )
198+
199+ kid = pyjwt .get_unverified_header (token ).get (_JWT_KEY_ID_HEADER )
200+ secret = keys_by_kid .get (kid , "" ) if kid else ""
201+ if not secret :
202+ raise pyjwt .exceptions .InvalidKeyError (f"No verification key matches kid={ kid !r} " )
166203
167204 claims = pyjwt .decode (
168205 token ,
@@ -172,10 +209,52 @@ def decode_viewer_context(
172209 issuer = "sentry" ,
173210 leeway = leeway ,
174211 )
175- vc_data = {k : v for k , v in claims .items () if k not in _JWT_STANDARD_CLAIMS }
212+ vc_data = {ck : cv for ck , cv in claims .items () if ck not in _JWT_STANDARD_CLAIMS }
176213 return ViewerContext .deserialize (vc_data )
177214
178215
216+ def viewer_context_from_header (
217+ header_value : str , signature : str | None = None
218+ ) -> ViewerContext | None :
219+ """Decode a ViewerContext from ``X-Viewer-Context`` header(s).
220+
221+ Dual-mode for migration:
222+ - JWT (HS256) — new format, self-contained
223+ - Raw JSON + ``X-Viewer-Context-Signature`` HMAC — legacy format
224+ """
225+ if is_jwt_viewer_context (header_value ):
226+ try :
227+ return decode_viewer_context (header_value )
228+ except Exception :
229+ logger .warning ("viewer_context.jwt_decode_failed" , exc_info = True )
230+ return None
231+
232+ # Legacy: raw JSON + HMAC signature
233+ if signature is not None :
234+ return _verify_legacy_viewer_context (header_value , signature )
235+
236+ return None
237+
238+
239+ def _verify_legacy_viewer_context (context_json : str , signature : str ) -> ViewerContext | None :
240+ """Verify and decode a legacy JSON + HMAC-signed viewer context."""
241+ keys_by_kid = _get_verification_keys ()
242+ context_bytes = context_json .encode ("utf-8" )
243+
244+ for key in keys_by_kid .values ():
245+ computed = hmac .new (key .encode ("utf-8" ), context_bytes , hashlib .sha256 ).hexdigest ()
246+ if hmac .compare_digest (computed , signature ):
247+ try :
248+ data = orjson .loads (context_bytes )
249+ return ViewerContext .deserialize (data )
250+ except Exception :
251+ logger .warning ("viewer_context.legacy_decode_failed" , exc_info = True )
252+ return None
253+
254+ logger .warning ("viewer_context.legacy_signature_mismatch" )
255+ return None
256+
257+
179258def is_jwt_viewer_context (header_value : str ) -> bool :
180259 """Check whether the header value is a JWT by attempting to read its header.
181260
0 commit comments