diff --git a/example/flask_rp/conf.json b/example/flask_rp/conf.json index b7cd241..6d7203b 100644 --- a/example/flask_rp/conf.json +++ b/example/flask_rp/conf.json @@ -162,9 +162,7 @@ "redirect_uris": [ "https://{domain}:{port}/authz_cb/local" ], - "post_logout_redirect_uris": [ - "https://{domain}:{port}/session_logout/local" - ], + "post_logout_redirect_uri": "https://{domain}:{port}/session_logout/local", "frontchannel_logout_uri": "https://{domain}:{port}/fc_logout/local", "frontchannel_logout_session_required": true, "backchannel_logout_uri": "https://{domain}:{port}/bc_logout/local", @@ -231,9 +229,7 @@ "redirect_uris": [ "https://{domain}:{port}/authz_cb/django" ], - "post_logout_redirect_uris": [ - "https://{domain}:{port}/session_logout/django" - ], + "post_logout_redirect_uris": "https://{domain}:{port}/session_logout/django", "frontchannel_logout_uri": "https://{domain}:{port}/fc_logout/django", "frontchannel_logout_session_required": true, "backchannel_logout_uri": "https://{domain}:{port}/bc_logout/django", diff --git a/example/flask_rp/templates/opresult.html b/example/flask_rp/templates/opresult.html index 34bd7be..86c6789 100644 --- a/example/flask_rp/templates/opresult.html +++ b/example/flask_rp/templates/opresult.html @@ -18,6 +18,10 @@

Endpoints

{{ url }}
{% endfor %} +
+
ID Token
+
{{ id_token }}
+

User information

{% for key, value in userinfo.items() %} diff --git a/example/flask_rp/views.py b/example/flask_rp/views.py index 5af68ae..caa7141 100644 --- a/example/flask_rp/views.py +++ b/example/flask_rp/views.py @@ -53,10 +53,14 @@ def rp(): uid = '' if iss or uid: + args = { + 'req_args': { + "claims": {"id_token": {"acr": {"value": "https://refeds.org/profile/mfa"}}} + } + } + if uid: - args = {'user_id': uid} - else: - args = {} + args['user_id'] = uid session['op_identifier'] = iss try: @@ -145,6 +149,7 @@ def finalize(op_identifier, request_args): return render_template('opresult.html', endpoints=endpoints, userinfo=res['userinfo'], access_token=res['token'], + id_token=res["id_token"], **kwargs) else: return make_response(res['error'], 400) @@ -152,7 +157,7 @@ def finalize(op_identifier, request_args): def get_op_identifier_by_cb_uri(url: str): uri = splitquery(url)[0] - for k,v in current_app.rph.issuer2rp.items(): + for k, v in current_app.rph.issuer2rp.items(): _cntx = v.get_service_context() for endpoint in ("redirect_uris", "post_logout_redirect_uris", @@ -180,10 +185,10 @@ def repost_fragment(): return finalize(op_identifier, args) -@oidc_rp_views.route('/ihf_cb') -def ihf_cb(self, op_identifier='', **kwargs): +@oidc_rp_views.route('/authz_im_cb') +def authz_im_cb(op_identifier='', **kwargs): logger.debug('implicit_hybrid_flow kwargs: {}'.format(kwargs)) - return render_template('repost_fragment.html') + return render_template('repost_fragment.html', op_identifier=op_identifier) @oidc_rp_views.route('/session_iframe') diff --git a/setup.py b/setup.py index 75a0fe5..c8b8625 100755 --- a/setup.py +++ b/setup.py @@ -67,7 +67,7 @@ def run_tests(self): "Programming Language :: Python :: 3.9", "Topic :: Software Development :: Libraries :: Python Modules"], install_requires=[ - 'oidcmsg==1.3.3-1', + 'oidcmsg==1.4.1', 'pyyaml>=5.1.2', 'responses' ], diff --git a/src/oidcrp/__init__.py b/src/oidcrp/__init__.py index 23161e0..823f282 100644 --- a/src/oidcrp/__init__.py +++ b/src/oidcrp/__init__.py @@ -1,7 +1,7 @@ import logging __author__ = 'Roland Hedberg' -__version__ = '2.0.1' +__version__ = '2.1.0' logger = logging.getLogger(__name__) diff --git a/src/oidcrp/client_auth.py b/src/oidcrp/client_auth.py index f585ec7..571f428 100755 --- a/src/oidcrp/client_auth.py +++ b/src/oidcrp/client_auth.py @@ -12,8 +12,8 @@ from oidcmsg.oauth2 import SINGLE_OPTIONAL_STRING from oidcmsg.oidc import AuthnToken from oidcmsg.time_util import utc_time_sans_frac +from oidcmsg.util import rndstr -from oidcrp.util import rndstr from oidcrp.util import sanitize from .defaults import DEF_SIGN_ALG from .defaults import JWT_BEARER diff --git a/src/oidcrp/configure.py b/src/oidcrp/configure.py index d38cc76..84ac367 100755 --- a/src/oidcrp/configure.py +++ b/src/oidcrp/configure.py @@ -14,7 +14,7 @@ try: from secrets import token_urlsafe as rnd_token except ImportError: - from oidcendpoint import rndstr as rnd_token + from cryptojwt import rndstr as rnd_token DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename', 'template_dir', 'private_path', 'public_path', 'db_file'] diff --git a/src/oidcrp/oauth2/__init__.py b/src/oidcrp/oauth2/__init__.py index b195c8d..6b7612e 100755 --- a/src/oidcrp/oauth2/__init__.py +++ b/src/oidcrp/oauth2/__init__.py @@ -1,14 +1,19 @@ from json import JSONDecodeError import logging +from typing import Optional from oidcmsg.exception import FormatError +from oidcmsg.message import Message +from oidcmsg.oauth2 import is_error_message from oidcrp.entity import Entity +from oidcrp.exception import ConfigurationError from oidcrp.exception import OidcServiceError from oidcrp.exception import ParseError from oidcrp.http import HTTPLib from oidcrp.service import REQUEST_INFO from oidcrp.service import SUCCESSFUL +from oidcrp.service import Service from oidcrp.util import do_add_ons from oidcrp.util import get_deserialization_method @@ -72,7 +77,11 @@ def __init__(self, client_authn_factory=None, keyjar=None, verify_ssl=True, conf # just ignore verify_ssl until it goes away self.verify_ssl = self.httpc_params.get("verify", True) - def do_request(self, request_type, response_body_type="", request_args=None, **kwargs): + def do_request(self, + request_type: str, + response_body_type: Optional[str] = "", + request_args: Optional[dict] = None, + behaviour_args: Optional[dict] = None, **kwargs): _srv = self._service[request_type] _info = _srv.get_request_parameters(request_args=request_args, **kwargs) @@ -93,8 +102,13 @@ def set_client_id(self, client_id): self.client_id = client_id self._service_context.set('client_id', client_id) - def get_response(self, service, url, method="GET", body=None, response_body_type="", - headers=None, **kwargs): + def get_response(self, + service: Service, + url: str, + method: Optional[str] = "GET", + body: Optional[dict] = None, + response_body_type: Optional[str] = "", + headers: Optional[dict] = None, **kwargs): """ :param url: @@ -129,8 +143,13 @@ def get_response(self, service, url, method="GET", body=None, response_body_type return self.parse_request_response(service, resp, response_body_type, **kwargs) - def service_request(self, service, url, method="GET", body=None, - response_body_type="", headers=None, **kwargs): + def service_request(self, + service: Service, + url: str, + method: Optional[str] = "GET", + body: Optional[dict] = None, + response_body_type: Optional[str] = "", + headers: Optional[dict] = None, **kwargs) -> Message: """ The method that sends the request and handles the response returned. This assumes that the response arrives in the HTTP response. @@ -249,3 +268,27 @@ def parse_request_response(self, service, reqresp, response_body_type='', reqresp.text)) raise OidcServiceError("HTTP ERROR: %s [%s] on %s" % ( reqresp.text, reqresp.status_code, reqresp.url)) + + +def dynamic_provider_info_discovery(client: Client, behaviour_args: Optional[dict]=None): + """ + This is about performing dynamic Provider Info discovery + + :param behaviour_args: + :param client: A :py:class:`oidcrp.oidc.Client` instance + """ + try: + client.get_service('provider_info') + except KeyError: + raise ConfigurationError( + 'Can not do dynamic provider info discovery') + else: + _context = client.client_get("service_context") + try: + _context.set('issuer', _context.config['srv_discovery_url']) + except KeyError: + pass + + response = client.do_request('provider_info', behaviour_args=behaviour_args) + if is_error_message(response): + raise OidcServiceError(response['error']) diff --git a/src/oidcrp/oauth2/authorization.py b/src/oidcrp/oauth2/authorization.py index 366b223..d82c8e1 100644 --- a/src/oidcrp/oauth2/authorization.py +++ b/src/oidcrp/oauth2/authorization.py @@ -7,7 +7,7 @@ from oidcmsg.time_util import time_sans_frac from oidcrp.oauth2.utils import get_state_parameter -from oidcrp.oauth2.utils import pick_redirect_uris +from oidcrp.oauth2.utils import pre_construct_pick_redirect_uri from oidcrp.oauth2.utils import set_state_parameter from oidcrp.service import Service @@ -32,7 +32,7 @@ class Authorization(Service): def __init__(self, client_get, client_authn_factory=None, conf=None): Service.__init__(self, client_get, client_authn_factory=client_authn_factory, conf=conf) - self.pre_construct.extend([pick_redirect_uris, set_state_parameter]) + self.pre_construct.extend([pre_construct_pick_redirect_uri, set_state_parameter]) self.post_construct.append(self.store_auth_request) def update_service_context(self, resp, key='', **kwargs): diff --git a/src/oidcrp/oauth2/utils.py b/src/oidcrp/oauth2/utils.py index 2393766..d1c4038 100644 --- a/src/oidcrp/oauth2/utils.py +++ b/src/oidcrp/oauth2/utils.py @@ -1,4 +1,14 @@ +import logging +from typing import Optional +from typing import Union + from oidcmsg.exception import MissingParameter +from oidcmsg.exception import MissingRequiredAttribute +from oidcmsg.message import Message + +from oidcrp.service import Service + +logger = logging.getLogger(__name__) def get_state_parameter(request_args, kwargs): @@ -14,35 +24,48 @@ def get_state_parameter(request_args, kwargs): return _state -def pick_redirect_uris(request_args=None, service=None, **kwargs): - """Pick one redirect_uri base on response_mode out of a list of such.""" - _context = service.client_get("service_context") +def pick_redirect_uri(context, + request_args: Optional[Union[Message, dict]] = None, + response_type: Optional[str] = ''): + if request_args is None: + request_args = {} if 'redirect_uri' in request_args: - return request_args, {} + return request_args["redirect_uri"] - _callback = _context.callback - if _callback: - try: - _response_type = request_args['response_type'] - except KeyError: - _response_type = _context.behaviour['response_types'][0] - request_args['response_type'] = _response_type + if context.redirect_uris: + redirect_uri = context.redirect_uris[0] + elif context.callback: + if not response_type: + _conf_resp_types = context.behaviour.get('response_types', []) + response_type = request_args.get('response_type') + if not response_type and _conf_resp_types: + response_type = _conf_resp_types[0] - try: - _response_mode = request_args['response_mode'] - except KeyError: - _response_mode = '' + _response_mode = request_args.get('response_mode') - if _response_mode == 'form_post': - request_args['redirect_uri'] = _callback['form_post'] - elif _response_type == 'code': - request_args['redirect_uri'] = _callback['code'] + if _response_mode == 'form_post' or response_type == ["form_post"]: + redirect_uri = context.callback['form_post'] + elif response_type == 'code' or response_type == ["code"]: + redirect_uri = context.callback['code'] else: - request_args['redirect_uri'] = _callback['implicit'] + redirect_uri = context.callback['implicit'] + + logger.debug( + f"pick_redirect_uris: response_type={response_type}, response_mode={_response_mode}, " + f"redirect_uri={redirect_uri}") else: - request_args['redirect_uri'] = _context.redirect_uris[0] + logger.error("No redirect_uri") + raise MissingRequiredAttribute('redirect_uri') + return redirect_uri + + +def pre_construct_pick_redirect_uri(request_args: Optional[Union[Message, dict]] = None, + service: Optional[Service] = None, **kwargs): + _context = service.client_get("service_context") + request_args["redirect_uri"] = pick_redirect_uri(_context, + request_args=request_args) return request_args, {} diff --git a/src/oidcrp/oidc/__init__.py b/src/oidcrp/oidc/__init__.py index cb000e4..df9db5e 100755 --- a/src/oidcrp/oidc/__init__.py +++ b/src/oidcrp/oidc/__init__.py @@ -2,6 +2,7 @@ import logging from oidcrp.client_auth import BearerHeader +from oidcrp.oidc.registration import CALLBACK_URIS try: from json import JSONDecodeError @@ -112,6 +113,15 @@ def __init__(self, client_authn_factory=None, keyjar=keyjar, verify_ssl=verify_ssl, config=config, httplib=httplib, services=_srvs, httpc_params=httpc_params) + _context = self.get_service_context() + if _context.callback is None: + _context.callback = {} + + for _cb in CALLBACK_URIS: + _uri = config.get(_cb) + if _uri: + _context.callback[_cb] = _uri + def fetch_distributed_claims(self, userinfo, callback=None): """ diff --git a/src/oidcrp/oidc/access_token.py b/src/oidcrp/oidc/access_token.py index 03a87f1..7828b69 100644 --- a/src/oidcrp/oidc/access_token.py +++ b/src/oidcrp/oidc/access_token.py @@ -1,7 +1,9 @@ import logging from typing import Optional +from typing import Union from oidcmsg import oidc +from oidcmsg.message import Message from oidcmsg.oidc import verified_claim_name from oidcmsg.time_util import time_sans_frac @@ -26,7 +28,9 @@ def __init__(self, access_token.AccessToken.__init__(self, client_get, client_authn_factory=client_authn_factory, conf=conf) - def gather_verify_arguments(self): + def gather_verify_arguments(self, + response: Optional[Union[dict, Message]] = None, + behaviour_args: Optional[dict] = None): """ Need to add some information before running verify() diff --git a/src/oidcrp/oidc/authorization.py b/src/oidcrp/oidc/authorization.py index b95601b..5845640 100644 --- a/src/oidcrp/oidc/authorization.py +++ b/src/oidcrp/oidc/authorization.py @@ -1,14 +1,18 @@ import logging +from typing import Optional +from typing import Union +from oidcmsg import oauth2 from oidcmsg import oidc +from oidcmsg.exception import MissingRequiredAttribute +from oidcmsg.message import Message from oidcmsg.oidc import make_openid_request from oidcmsg.oidc import verified_claim_name from oidcmsg.time_util import time_sans_frac from oidcmsg.time_util import utc_time_sans_frac -from oidcrp.exception import ParameterError from oidcrp.oauth2 import authorization -from oidcrp.oauth2.utils import pick_redirect_uris +from oidcrp.oauth2.utils import pre_construct_pick_redirect_uri from oidcrp.oidc import IDT2REG from oidcrp.oidc.utils import construct_request_uri from oidcrp.oidc.utils import request_object_encryption @@ -28,7 +32,7 @@ def __init__(self, client_get, client_authn_factory=None, conf=None): authorization.Authorization.__init__(self, client_get, client_authn_factory, conf=conf) self.default_request_args = {'scope': ['openid']} - self.pre_construct = [self.set_state, pick_redirect_uris, + self.pre_construct = [self.set_state, pre_construct_pick_redirect_uri, self.oidc_pre_construct] self.post_construct = [self.oidc_post_construct] @@ -47,25 +51,33 @@ def set_state(self, request_args, **kwargs): def update_service_context(self, resp, key='', **kwargs): _context = self.client_get("service_context") - try: - _idt = resp[verified_claim_name('id_token')] - except KeyError: - pass - else: - # If there is a verified ID Token then we have to do nonce - # verification - try: - if _context.state.get_state_by_nonce(_idt['nonce']) != key: - raise ParameterError('Someone has messed with "nonce"') - except KeyError: - raise ValueError('Missing nonce value') - - _context.state.store_sub2state(_idt['sub'], key) if 'expires_in' in resp: resp['__expires_at'] = time_sans_frac() + int(resp['expires_in']) _context.state.store_item(resp.to_json(), 'auth_response', key) + def get_request_from_response(self, response): + _context = self.client_get("service_context") + return _context.state.get_item(oauth2.AuthorizationRequest, 'auth_request', + response["state"]) + + def post_parse_response(self, response, **kwargs): + response = authorization.Authorization.post_parse_response(self, response, **kwargs) + + _idt = response.get(verified_claim_name('id_token')) + if _idt: + # If there is a verified ID Token then we have to do nonce + # verification. + _request = self.get_request_from_response(response) + _req_nonce = _request.get('nonce') + if _req_nonce: + _id_token_nonce = _idt.get('nonce') + if not _id_token_nonce: + raise MissingRequiredAttribute('nonce') + elif _req_nonce != _id_token_nonce: + raise ValueError('Invalid nonce') + return response + def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): _context = self.client_get("service_context") if request_args is None: @@ -148,9 +160,9 @@ def store_request_on_file(self, req, **kwargs): fid.close() return _webname - def construct_request_parameter(self, req, request_method, audience=None, expires_in=0, + def construct_request_parameter(self, req, request_param, audience=None, expires_in=0, **kwargs): - """Construct a request parameter""" + """ Construct a request parameter """ alg = self.get_request_object_signing_alg(**kwargs) kwargs["request_object_signing_alg"] = alg @@ -158,6 +170,9 @@ def construct_request_parameter(self, req, request_method, audience=None, expire if "keys" not in kwargs and alg and alg != "none": kwargs["keys"] = _context.keyjar + if alg == "none": + kwargs["keys"] = [] + _srv_cntx = _context # This is the issuer of the JWT, that is me ! @@ -176,12 +191,15 @@ def construct_request_parameter(self, req, request_method, audience=None, expire if expires_in: req['exp'] = utc_time_sans_frac() + int(expires_in) - _req = make_openid_request(req, **kwargs) + _mor_args = {k: kwargs[k] for k in ["keys", "issuer", "request_object_signing_alg", "recv", + "with_jti", "lifetime"] if k in kwargs} + + _req = make_openid_request(req, **_mor_args) # Should the request be encrypted _req = request_object_encryption(_req, _context, **kwargs) - if request_method == "request": + if request_param == "request": req["request"] = _req else: # MUST be request_uri req["request_uri"] = self.store_request_on_file(_req, **kwargs) @@ -204,19 +222,28 @@ def oidc_post_construct(self, req, **kwargs): if 'prompt' not in req: req['prompt'] = 'consent' - try: - _request_method = kwargs['request_param'] - except KeyError: - pass - else: - del kwargs['request_param'] + _context.state.store_item(req, 'auth_request', req['state']) - self.construct_request_parameter(req, _request_method, **kwargs) + _request_param = kwargs.get('request_param') + if _request_param: + del kwargs['request_param'] + # local_dir, base_path + _config = _context.get('config') + kwargs["local_dir"] = _config.get('local_dir', './requests') + kwargs["base_path"] = _context.get('base_url') + '/' + "requests" + self.construct_request_parameter(req, _request_param, **kwargs) + # removed all arguments except request/request_uri and the required + _leave = ['request', 'request_uri'] + _leave.extend(req.required_parameters()) + _keys = [k for k in req.keys() if k not in _leave] + for k in _keys: + del req[k] - _context.state.store_item(req, 'auth_request', req['state']) return req - def gather_verify_arguments(self): + def gather_verify_arguments(self, + response: Optional[Union[dict, Message]] = None, + behaviour_args: Optional[dict] = None): """ Need to add some information before running verify() diff --git a/src/oidcrp/oidc/end_session.py b/src/oidcrp/oidc/end_session.py index a58c91d..df4d2d4 100644 --- a/src/oidcrp/oidc/end_session.py +++ b/src/oidcrp/oidc/end_session.py @@ -54,13 +54,17 @@ def get_id_token_hint(self, request_args=None, **kwargs): def add_post_logout_redirect_uri(self, request_args=None, **kwargs): if 'post_logout_redirect_uri' not in request_args: - try: - request_args[ - 'post_logout_redirect_uri' - ] = self.client_get("service_context").register_args[ - 'post_logout_redirect_uris'][0] - except KeyError: - pass + _context = self.client_get("service_context") + _uri = _context.register_args.get('post_logout_redirect_uris') + if _uri: + if isinstance(_uri, str): + request_args['post_logout_redirect_uri'] = _uri + else: # assume list + request_args['post_logout_redirect_uri'] = _uri[0] + else: + _uri = _context.callback.get("post_logout_redirect_uris") + if _uri: + request_args['post_logout_redirect_uri'] = _uri[0] return request_args, {} diff --git a/src/oidcrp/oidc/registration.py b/src/oidcrp/oidc/registration.py index 925cc66..06a3ea9 100644 --- a/src/oidcrp/oidc/registration.py +++ b/src/oidcrp/oidc/registration.py @@ -1,9 +1,12 @@ +import hashlib import logging +from typing import List +from typing import Optional +from cryptojwt.utils import as_bytes from oidcmsg import oidc from oidcmsg.oauth2 import ResponseMessage -from oidcrp.oidc.provider_info_discovery import add_redirect_uris from oidcrp.service import Service __author__ = 'Roland Hedberg' @@ -37,19 +40,133 @@ def response_types_to_grant_types(response_types): return list(_res) -def add_request_uri(request_args=None, service=None, **kwargs): - _context = service.client_get("service_context") - if _context.requests_dir: - _pi = _context.provider_info - if _pi: - _req = _pi.get('require_request_uri_registration', False) - if _req is True: - request_args['request_uris'] = _context.generate_request_uris(_context.requests_dir) +def create_callbacks(issuer: str, + hash_seed: str, + base_url: str, + code: Optional[bool] = False, + implicit: Optional[bool] = False, + form_post: Optional[bool] = False, + request_uris: Optional[bool] = False, + backchannel_logout_uri: Optional[bool] = False, + frontchannel_logout_uri: Optional[bool] = False): + """ + To mitigate some security issues the redirect_uris should be OP/AS + specific. This method creates a set of redirect_uris unique to the + OP/AS. + + :param frontchannel_logout_uri: Whether a front-channel logout uri should be constructed + :param backchannel_logout_uri: Whether a back-channel logout uri should be constructed + :param request_uri: Whether a request_uri should be constructed + :param issuer: Issuer ID + :return: A set of redirect_uris + """ + _hash = hashlib.sha256() + _hash.update(hash_seed) + _hash.update(as_bytes(issuer)) + _hex = _hash.hexdigest() + + res = {'__hex': _hex} + + if code: + res['code'] = f"{base_url}/authz_cb/{_hex}" + + if implicit: + res['implicit'] = f"{base_url}/authz_im_cb/{_hex}" + + if form_post: + res['form_post'] = f"{base_url}/authz_fp_cb/{_hex}" + + if request_uris: + res["request_uris"] = f"{base_url}/req_uri/{_hex}" + + if backchannel_logout_uri or frontchannel_logout_uri: + res["post_logout_redirect_uris"] = [f"{base_url}/session_logout/{_hex}"] + + if backchannel_logout_uri: + res["backchannel_logout_uri"] = f"{base_url}/bc_logout/{_hex}" + + if frontchannel_logout_uri: + res["frontchannel_logout_uri"] = f"{base_url}/fc_logout/{_hex}" + + logger.debug(f"Created callback URIs: {res}") + return res + + +def _cmp(a, b): + if b is None: # Don't care about the value as long as there is one + return True + elif isinstance(a, str) and a == b: + return True + elif isinstance(a, list) and b in a: + return True + + return a == b - return request_args, {} +def _in_config_or_client_preferences(config, attr, val): + _val = config.get("client_preferences", {}).get(attr) + if _cmp(_val, val): + return True + _val = config.get(attr) + return _cmp(_val, val) -def add_post_logout_redirect_uris(request_args=None, service=None, **kwargs): + +def add_callbacks(context, ignore: Optional[List[str]] = None): + if ignore is None: + ignore = [] + _iss = context.get('issuer') + + _uris = {} + + _pi = context.get('provider_info') + _cp = context.config.get("client_preferences") + + if "redirect_uris" not in ignore: + # code and/or implicit + if _in_config_or_client_preferences(context.config, "response_types", "code"): + _uris['code'] = True + for rt in ["id_token", "id_token token", "code id_token token", "code idtoken", + "code token"]: + if _in_config_or_client_preferences(context.config, "response_types", rt): + _uris["implicit"] = True + break + + if "form_post" not in ignore: + if _in_config_or_client_preferences(context.config, "form_post_usable", True): + _uris["form_post"] = True + + if "request_uris" not in ignore: + if 'require_request_uri_registration' in _pi and _in_config_or_client_preferences( + context.config, "request_uri_usable", True): + _uris['request_uris'] = True + + if "frontchannel_logout_uri" not in ignore: + if 'frontchannel_logout_supported' in _pi and _in_config_or_client_preferences( + context.config, "frontchannel_logout_usable", True): + _uris["frontchannel_logout_uri"] = True + + if "backchannel_logout_uri" not in ignore: + if 'backchannel_logout_supported' in _pi and _in_config_or_client_preferences( + context.config, "backchannel_logout_usable", True): + _uris["backchannel_logout_uri"] = True + + callbacks = create_callbacks(_iss, + hash_seed=context.get('hash_seed'), + base_url=context.get("base_url"), + **_uris) + context.hash2issuer[callbacks['__hex']] = _iss + + if "redirect_uris" not in ignore: + _redirect_uris = [v for k, v in callbacks.items() if k in ["code", "implicit", "form_post"]] + callbacks["redirect_uris"] = _redirect_uris + context.set('callback', callbacks) + + +CALLBACK_URIS = ["post_logout_redirect_uris", "backchannel_logout_uri", "frontchannel_logout_uri", + "request_uris", 'redirect_uris'] + + +def add_callback_uris(request_args=None, service=None, **kwargs): """ :param request_args: @@ -59,10 +176,17 @@ def add_post_logout_redirect_uris(request_args=None, service=None, **kwargs): :return: """ - if "post_logout_redirect_uris" not in request_args: - _uris = service.client_get("service_context").register_args.get("post_logout_redirect_uris") - if _uris: - request_args["post_logout_redirect_uris"] = _uris + _context = service.client_get("service_context") + _ignore = [k for k in list(request_args.keys()) if k in CALLBACK_URIS] + add_callbacks(_context, ignore=_ignore) + for _key in CALLBACK_URIS: + _req_val = request_args.get(_key) + if not _req_val: + _uri = _context.register_args.get(_key) + if not _uri: + _uri = _context.callback.get(_key) + if _uri: + request_args[_key] = _uri return request_args, {} @@ -107,8 +231,8 @@ def __init__(self, client_get, client_authn_factory=None, conf=None): client_authn_factory=client_authn_factory, conf=conf) self.pre_construct = [self.add_client_behaviour_preference, - add_redirect_uris, add_request_uri, - add_post_logout_redirect_uris, + #add_redirect_uris, + add_callback_uris, add_jwks_uri_or_jwks] self.post_construct = [self.oidc_post_construct] diff --git a/src/oidcrp/oidc/userinfo.py b/src/oidcrp/oidc/userinfo.py index 2b5852d..6c93b80 100644 --- a/src/oidcrp/oidc/userinfo.py +++ b/src/oidcrp/oidc/userinfo.py @@ -1,4 +1,6 @@ import logging +from typing import Optional +from typing import Union from oidcmsg import oidc from oidcmsg.exception import MissingSigningKey @@ -105,10 +107,16 @@ def post_parse_response(self, response, **kwargs): "url": spec["endpoint"] } + # Extension point + for meth in self.post_parse_process: + response = meth(response, _state_interface, kwargs['state']) + _state_interface.store_item(response, 'user_info', kwargs['state']) return response - def gather_verify_arguments(self): + def gather_verify_arguments(self, + response: Optional[Union[dict, Message]] = None, + behaviour_args: Optional[dict] = None): """ Need to add some information before running verify() diff --git a/src/oidcrp/oidc/utils.py b/src/oidcrp/oidc/utils.py index 40b9b90..9605733 100644 --- a/src/oidcrp/oidc/utils.py +++ b/src/oidcrp/oidc/utils.py @@ -83,5 +83,8 @@ def construct_request_uri(local_dir, base_path, **kwargs): while os.path.exists(filename): _name = rndstr(10) filename = os.path.join(_filedir, _name) - _webname = "%s%s" % (_webpath, _name) + if _webpath.endswith("/"): + _webname = f"{_webpath}{_name}" + else: + _webname = f"{_webpath}/{_name}" return filename, _webname diff --git a/src/oidcrp/rp_handler.py b/src/oidcrp/rp_handler.py index e359939..957ab2d 100644 --- a/src/oidcrp/rp_handler.py +++ b/src/oidcrp/rp_handler.py @@ -1,4 +1,4 @@ -import hashlib +from inspect import currentframe import logging import sys import traceback @@ -10,6 +10,7 @@ from cryptojwt.utils import as_bytes from oidcmsg import verified_claim_name from oidcmsg.exception import MessageException +from oidcmsg.exception import MissingRequiredAttribute from oidcmsg.exception import NotForMe from oidcmsg.oauth2 import ResponseMessage from oidcmsg.oauth2 import is_error_message @@ -18,6 +19,7 @@ from oidcmsg.oidc import AuthorizationResponse from oidcmsg.oidc import Claims from oidcmsg.oidc import OpenIDSchema +from oidcmsg.oidc import RegistrationRequest from oidcmsg.oidc.session import BackChannelLogoutRequest from oidcmsg.time_util import time_sans_frac @@ -27,8 +29,9 @@ from .defaults import DEFAULT_RP_KEY_DEFS from .exception import OidcServiceError from .oauth2 import Client +from .oauth2 import dynamic_provider_info_discovery +from .oauth2.utils import pick_redirect_uri from .util import add_path -from .util import dynamic_provider_info_discovery from .util import load_registration_response from .util import rndstr @@ -85,7 +88,7 @@ def __init__(self, base_url, client_configs=None, services=None, keyjar=None, else: self.client_configs = client_configs - # keep track on which RP instance that serves with OP + # keep track on which RP instance that serves which OP self.issuer2rp = {} self.hash2issuer = {} self.httplib = http_lib @@ -149,6 +152,9 @@ def init_client(self, issuer): :param issuer: An issuer ID :return: A Client instance """ + + logger.debug(20 * "*" + " init_client " + 20 * "*") + try: _cnf = self.pick_config(issuer) except KeyError: @@ -180,16 +186,21 @@ def init_client(self, issuer): _context.jwks_uri = self.jwks_uri return client - def do_provider_info(self, client=None, state=''): + def do_provider_info(self, + client: Optional[Client]=None, + state: Optional[str]='', + behaviour_args: Optional[dict]=None) -> str: """ Either get the provider info from configuration or through dynamic discovery. + :param behaviour_args: :param client: A Client instance :param state: A key by which the state of the session can be retrieved :return: issuer ID """ + logger.debug(20 * "*" + " do_provider_info " + 20 * "*") if not client: if state: @@ -199,7 +210,7 @@ def do_provider_info(self, client=None, state=''): _context = client.client_get("service_context") if not _context.get('provider_info'): - dynamic_provider_info_discovery(client) + dynamic_provider_info_discovery(client, behaviour_args=behaviour_args) return _context.get('provider_info')['issuer'] else: _pi = _context.get('provider_info') @@ -235,15 +246,23 @@ def do_provider_info(self, client=None, state=''): except KeyError: return _context.get('issuer') - def do_client_registration(self, client=None, iss_id='', state=''): + def do_client_registration(self, client=None, + iss_id: Optional[str] = '', + state: Optional[str] = '', + request_args: Optional[dict] = None, + behaviour_args: Optional[dict] = None): """ Prepare for and do client registration if configured to do so + :param iss_id: Issuer ID + :param behaviour_args: To fine tune behaviour :param client: A Client instance :param state: A key by which the state of the session can be retrieved """ + logger.debug(20 * "*" + " do_client_registration " + 20 * "*") + if not client: if state: client = self.get_client_from_session_key(state) @@ -252,32 +271,41 @@ def do_client_registration(self, client=None, iss_id='', state=''): _context = client.client_get("service_context") _iss = _context.get('issuer') - if not _context.get('redirect_uris'): - # Create the necessary callback URLs - # as a side effect self.hash2issuer is set - callbacks = self.create_callbacks(_iss) - - _context.set('redirect_uris', [ - v for k, v in callbacks.items() if not k.startswith('__')]) - _context.set('callbacks', callbacks) - else: - self.hash2issuer[iss_id] = _iss + self.hash2issuer[iss_id] = _iss # This should only be interesting if the client supports Single Log Out - if _context.post_logout_redirect_uris is None: - _context.post_logout_redirect_uris = [self.base_url] + # if _context.callback.get("post_logout_redirect_uri") is None: + # _context.callback["post_logout_redirect_uri"] = [self.base_url] - if not _context.client_id: - load_registration_response(client) + if not _context.client_id: # means I have to do dynamic client registration + if request_args is None: + request_args = {} - def add_callbacks(self, service_context): - _callbacks = self.create_callbacks(service_context.get('provider_info')['issuer']) - service_context.set('redirect_uris', [ - v for k, v in _callbacks.items() if not k.startswith('__')]) - service_context.set('callbacks', _callbacks) - return _callbacks + if behaviour_args: + _params = RegistrationRequest().parameters() + request_args.update({k: v for k, v in behaviour_args.items() if k in _params}) - def client_setup(self, iss_id='', user=''): + load_registration_response(client, request_args=request_args) + + def do_webfinger(self, user: str) -> Client: + """ + Does OpenID Provider Issuer discovery using webfinger. + + :param user: Identifier for the target End-User that is the subject of the discovery + request. + :return: A Client instance + """ + + logger.debug(20 * "*" + " do_webfinger " + 20 * "*") + + temporary_client = self.init_client('') + temporary_client.do_request('webfinger', resource=user) + return temporary_client + + def client_setup(self, + iss_id: Optional[str] = '', + user: Optional[str] = '', + behaviour_args: Optional[dict] = None) -> Client: """ First if no issuer ID is given then the identifier for the user is used by the webfinger service to try to find the issuer ID. @@ -286,11 +314,14 @@ def client_setup(self, iss_id='', user=''): the necessary information for the client to be able to communicate with the OP/AS that has the provided issuer ID. + :param behaviour_args: To fine tune behaviour :param iss_id: The issuer ID :param user: A user identifier :return: A :py:class:`oidcrp.oidc.Client` instance """ + logger.debug(20 * "*" + " client_setup " + 20 * "*") + logger.info('client_setup: iss_id={}, user={}'.format(iss_id, user)) if not iss_id: @@ -298,8 +329,7 @@ def client_setup(self, iss_id='', user=''): raise ValueError('Need issuer or user') logger.debug("Connecting to previously unknown OP") - temporary_client = self.init_client('') - temporary_client.do_request('webfinger', resource=user) + temporary_client = self.do_webfinger(user) else: temporary_client = None @@ -315,46 +345,39 @@ def client_setup(self, iss_id='', user=''): return client logger.debug("Get provider info") - issuer = self.do_provider_info(client) + issuer = self.do_provider_info(client, behaviour_args=behaviour_args) logger.debug("Do client registration") - self.do_client_registration(client, iss_id) + self.do_client_registration(client, iss_id, behaviour_args=behaviour_args) self.issuer2rp[issuer] = client return client - def create_callbacks(self, issuer): - """ - To mitigate some security issues the redirect_uris should be OP/AS - specific. This method creates a set of redirect_uris unique to the - OP/AS. - - :param issuer: Issuer ID - :return: A set of redirect_uris - """ - _hash = hashlib.sha256() - _hash.update(self.hash_seed) - _hash.update(as_bytes(issuer)) - _hex = _hash.hexdigest() - self.hash2issuer[_hex] = issuer - return { - 'code': "{}/authz_cb/{}".format(self.base_url, _hex), - 'implicit': "{}/authz_im_cb/{}".format(self.base_url, _hex), - 'form_post': "{}/authz_fp_cb/{}".format(self.base_url, _hex), - '__hex': _hex - } + def _get_response_type(self, context, req_args: Optional[dict] = None): + if req_args: + return req_args.get("response_type", context.get('behaviour')['response_types'][0]) + else: + return context.get('behaviour')['response_types'][0] - def init_authorization(self, client=None, state='', req_args=None): + def init_authorization(self, + client: Optional[Client] = None, + state: Optional[str] = '', + req_args: Optional[dict] = None, + behaviour_args: Optional[dict] = None) -> dict: """ Constructs the URL that will redirect the user to the authorization endpoint of the OP/AS. + :param behaviour_args: + :param state: :param client: A Client instance :param req_args: Non-default Request arguments :return: A dictionary with 2 keys: **url** The authorization redirect URL and **state** the key to the session information in the state data store. """ + + logger.debug(20 * "*" + " init_authorization " + 20 * "*") if not client: if state: client = self.get_client_from_session_key(state) @@ -364,10 +387,13 @@ def init_authorization(self, client=None, state='', req_args=None): _context = client.client_get("service_context") _nonce = rndstr(24) + _response_type = self._get_response_type(_context, req_args) request_args = { - 'redirect_uri': _context.get('redirect_uris')[0], + 'redirect_uri': pick_redirect_uri(_context, + request_args=req_args, + response_type=_response_type), 'scope': _context.get('behaviour')['scope'], - 'response_type': _context.get('behaviour')['response_types'][0], + 'response_type': _response_type, 'nonce': _nonce } @@ -387,12 +413,16 @@ def init_authorization(self, client=None, state='', req_args=None): logger.debug('Authorization request args: {}'.format(request_args)) + # if behaviour_args and "request_param" not in behaviour_args: + # _pi = _context.get("provider_info") + _srv = client.get_service('authorization') - _info = _srv.get_request_parameters(request_args=request_args) + _info = _srv.get_request_parameters(request_args=request_args, + behaviour_args=behaviour_args) logger.debug('Authorization info: {}'.format(_info)) return {'url': _info['url'], 'state': _state} - def begin(self, issuer_id='', user_id=''): + def begin(self, issuer_id='', user_id='', req_args=None, behaviour_args=None): """ This is the first of the 3 high level methods that most users of this library should confine them self to use. @@ -401,6 +431,8 @@ def begin(self, issuer_id='', user_id=''): Once it has the client it will construct an Authorization request. + :param behaviour_args: + :param req_args: :param issuer_id: Issuer ID :param user_id: A user identifier :return: A dictionary containing **url** the URL that will redirect the @@ -409,10 +441,10 @@ def begin(self, issuer_id='', user_id=''): """ # Get the client instance that has been assigned to this issuer - client = self.client_setup(issuer_id, user_id) + client = self.client_setup(issuer_id, user_id, behaviour_args=behaviour_args) try: - res = self.init_authorization(client) + res = self.init_authorization(client, req_args=req_args, behaviour_args=behaviour_args) except Exception: message = traceback.format_exception(*sys.exc_info()) logger.error(message) @@ -447,7 +479,8 @@ def get_client_authn_method(client, endpoint): """ if endpoint == 'token_endpoint': try: - am = client.client_get("service_context").get('behaviour')['token_endpoint_auth_method'] + am = client.client_get("service_context").get('behaviour')[ + 'token_endpoint_auth_method'] except KeyError: return '' else: @@ -456,7 +489,7 @@ def get_client_authn_method(client, endpoint): else: # a list return am[0] - def get_access_token(self, state, client: Optional[Client] = None): + def get_tokens(self, state, client: Optional[Client] = None): """ Use the 'accesstoken' service to get an access token from the OP/AS. @@ -466,7 +499,7 @@ def get_access_token(self, state, client: Optional[Client] = None): :return: A :py:class:`oidcmsg.oidc.AccessTokenResponse` or :py:class:`oidcmsg.oauth2.AuthorizationResponse` """ - logger.debug('get_accesstoken') + logger.debug(20 * "*" + " get_tokens " + 20 * "*") if client is None: client = self.get_client_from_session_key(state) @@ -512,6 +545,9 @@ def refresh_access_token(self, state, client=None, scope=''): :param scope: What the returned token should be valid for. :return: A :py:class:`oidcmsg.oidc.AccessTokenResponse` instance """ + + logger.debug(20 * "*" + " refresh_access_token " + 20 * "*") + if scope: req_args = {'scope': scope} else: @@ -548,6 +584,9 @@ def get_user_info(self, state, client=None, access_token='', :param kwargs: Extra keyword arguments :return: A :py:class:`oidcmsg.oidc.OpenIDSchema` instance """ + + logger.debug(20 * "*" + " get_user_info " + 20 * "*") + if client is None: client = self.get_client_from_session_key(state) @@ -574,33 +613,36 @@ def userinfo_in_id_token(id_token): :param id_token: An :py:class:`oidcmsg.oidc.IDToken` instance :return: A dictionary with user information """ - res = dict([(k, id_token[k]) for k in OpenIDSchema.c_param.keys() if - k in id_token]) + res = dict([(k, id_token[k]) for k in OpenIDSchema.c_param.keys() if k in id_token]) res.update(id_token.extra()) return res - def finalize_auth(self, client, issuer, response): + def finalize_auth(self, client, issuer: str, response: dict, + behaviour_args: Optional[dict] = None): """ Given the response returned to the redirect_uri, parse and verify it. + :param behaviour_args: For fine tuning behaviour :param client: A Client instance :param issuer: An Issuer ID :param response: The authorization response as a dictionary :return: An :py:class:`oidcmsg.oidc.AuthorizationResponse` or :py:class:`oidcmsg.oauth2.AuthorizationResponse` instance. """ + + logger.debug(20 * "*" + " finalize_auth " + 20 * "*") + _srv = client.get_service('authorization') try: - authorization_response = _srv.parse_response(response, - sformat='dict') + authorization_response = _srv.parse_response(response, sformat='dict', + behaviour_args=behaviour_args) except Exception as err: logger.error('Parsing authorization_response: {}'.format(err)) message = traceback.format_exception(*sys.exc_info()) logger.error(message) raise else: - logger.debug( - 'Authz response: {}'.format(authorization_response.to_dict())) + logger.debug('Authz response: {}'.format(authorization_response.to_dict())) if is_error_message(authorization_response): return authorization_response @@ -621,13 +663,16 @@ def finalize_auth(self, client, issuer, response): authorization_response['state']) return authorization_response - def get_access_and_id_token(self, authorization_response=None, state='', - client=None): + def get_access_and_id_token(self, authorization_response=None, + state: Optional[str] = '', + client: Optional[object] = None, + behaviour_args: Optional[dict] = None): """ There are a number of services where access tokens and ID tokens can occur in the response. This method goes through the possible places based on the response_type the client uses. + :param behaviour_args: For fine tuning behaviour :param authorization_response: The Authorization response :param state: The state key (the state parameter in the authorization request) @@ -636,6 +681,8 @@ def get_access_and_id_token(self, authorization_response=None, state='', was returned otherwise None. """ + logger.debug(20 * "*" + " get_access_and_id_token " + 20 * "*") + if client is None: client = self.get_client_from_session_key(state) @@ -652,8 +699,7 @@ def get_access_and_id_token(self, authorization_response=None, state='', if not state: state = authorization_response['state'] - authreq = _context.state.get_item( - AuthorizationRequest, 'auth_request', state) + authreq = _context.state.get_item(AuthorizationRequest, 'auth_request', state) _resp_type = set(authreq['response_type']) access_token = None @@ -665,24 +711,31 @@ def get_access_and_id_token(self, authorization_response=None, state='', if _resp_type in [{'token'}, {'id_token', 'token'}, {'code', 'token'}, {'code', 'id_token', 'token'}]: access_token = authorization_response["access_token"] - elif _resp_type in [{'code'}, {'code', 'id_token'}]: + if behaviour_args: + if behaviour_args.get("collect_tokens", False): + # get what you can from the token endpoint + token_resp = self.get_tokens(state, client=client) + if is_error_message(token_resp): + return False, "Invalid response %s." % token_resp["error"] + # Now which access_token should I use + access_token = token_resp["access_token"] + # May or may not get an ID Token + id_token = token_resp.get('__verified_id_token') + elif _resp_type in [{'code'}, {'code', 'id_token'}]: # get the access token - token_resp = self.get_access_token(state, client=client) + token_resp = self.get_tokens(state, client=client) if is_error_message(token_resp): return False, "Invalid response %s." % token_resp["error"] access_token = token_resp["access_token"] - - try: - id_token = token_resp['__verified_id_token'] - except KeyError: - pass + # May or may not get an ID Token + id_token = token_resp.get('__verified_id_token') return {'access_token': access_token, 'id_token': id_token} # noinspection PyUnusedLocal - def finalize(self, issuer, response): + def finalize(self, issuer, response, behaviour_args: Optional[dict] = None): """ The third of the high level methods that a user of this Class should know about. @@ -690,6 +743,7 @@ def finalize(self, issuer, response): callback URL there might be a number of services that the client should use. Which one those are are defined by the client configuration. + :param behaviour_args: For fine tuning :param issuer: Who sent the response :param response: The Authorization response as a dictionary :returns: A dictionary with two claims: @@ -701,6 +755,9 @@ def finalize(self, issuer, response): client = self.issuer2rp[issuer] + if behaviour_args: + logger.debug(f"Finalize behaviour args: {behaviour_args}") + authorization_response = self.finalize_auth(client, issuer, response) if is_error_message(authorization_response): return { @@ -709,8 +766,10 @@ def finalize(self, issuer, response): } _state = authorization_response['state'] - token = self.get_access_and_id_token(authorization_response, - state=_state, client=client) + token = self.get_access_and_id_token(authorization_response, state=_state, client=client, + behaviour_args=behaviour_args) + _id_token = token.get("id_token") + logger.debug(f"ID Token: {_id_token}") if client.client_get("service", "userinfo") and token['access_token']: inforesp = self.get_user_info( @@ -723,8 +782,8 @@ def finalize(self, issuer, response): 'state': _state } - elif token['id_token']: # look for it in the ID Token - inforesp = self.userinfo_in_id_token(token['id_token']) + elif _id_token: # look for it in the ID Token + inforesp = self.userinfo_in_id_token(_id_token) else: inforesp = {} @@ -732,8 +791,7 @@ def finalize(self, issuer, response): _context = client.client_get("service_context") try: - _sid_support = _context.get('provider_info')[ - 'backchannel_logout_session_supported'] + _sid_support = _context.get('provider_info')['backchannel_logout_session_supported'] except KeyError: try: _sid_support = _context.get('provider_info')[ @@ -741,21 +799,25 @@ def finalize(self, issuer, response): except: _sid_support = False - if _sid_support: + if _sid_support and _id_token: try: - sid = token['id_token']['sid'] + sid = _id_token['sid'] except KeyError: pass else: _context.state.store_sid2state(sid, _state) - _context.state.store_sub2state(token['id_token']['sub'], _state) + if _id_token: + _context.state.store_sub2state(_id_token['sub'], _state) + else: + _context.state.store_sub2state(inforesp['sub'], _state) return { 'userinfo': inforesp, 'state': authorization_response['state'], 'token': token['access_token'], - 'id_token': token['id_token'] + 'id_token': _id_token, + 'session_state': authorization_response.get('session_state', '') } def has_active_authentication(self, state): @@ -823,7 +885,9 @@ def get_valid_access_token(self, state): else: raise OidcServiceError('No valid access token') - def logout(self, state, client=None, post_logout_redirect_uri=''): + def logout(self, state: str, + client: Optional[Client] = None, + post_logout_redirect_uri: Optional[str] = '') -> dict: """ Does a RP initiated logout from an OP. After logout the user will be redirect by the OP to a URL of choice (post_logout_redirect_uri). @@ -834,6 +898,9 @@ def logout(self, state, client=None, post_logout_redirect_uri=''): should be used :return: A US """ + + logger.debug(20 * "*" + " logout " + 20 * "*") + if client is None: client = self.get_client_from_session_key(state) @@ -852,8 +919,23 @@ def logout(self, state, client=None, post_logout_redirect_uri=''): resp = srv.get_request_parameters(state=state, request_args=request_args) + logger.debug(f"EndSession Request: {resp}") return resp + def close(self, state: str, + issuer: Optional[str] = '', + post_logout_redirect_uri: Optional[str] = '') -> dict: + + logger.debug(20 * "*" + " close " + 20 * "*") + + if issuer: + client = self.issuer2rp[issuer] + else: + client = self.get_client_from_session_key(state) + + return self.logout(state=state, client=client, + post_logout_redirect_uri=post_logout_redirect_uri) + def clear_session(self, state): client = self.get_client_from_session_key(state) client.client_get("service_context").state.remove_state(state) @@ -867,8 +949,10 @@ def backchannel_logout(client, request='', request_args=None): """ if request: req = BackChannelLogoutRequest().from_urlencoded(as_unicode(request)) - else: + elif request_args: req = BackChannelLogoutRequest(**request_args) + else: + raise MissingRequiredAttribute('logout_token') _context = client.client_get("service_context") kwargs = { @@ -879,10 +963,13 @@ def backchannel_logout(client, request='', request_args=None): "id_token_signed_response_alg", "RS256") } + logger.debug(f"(backchannel_logout) Verifying request using: {kwargs}") try: req.verify(**kwargs) except (MessageException, ValueError, NotForMe) as err: raise MessageException('Bogus logout request: {}'.format(err)) + else: + logger.debug("Request verified OK") # Find the subject through 'sid' or 'sub' sub = req[verified_claim_name('logout_token')].get('sub') diff --git a/src/oidcrp/service.py b/src/oidcrp/service.py index 0e350e2..3eaa025 100644 --- a/src/oidcrp/service.py +++ b/src/oidcrp/service.py @@ -88,8 +88,9 @@ def __init__(self, self.pre_construct = [] self.post_construct = [] self.construct_extra_headers = [] + self.post_parse_process = [] - def gather_request_args(self, **kwargs): + def gather_request_args(self,**kwargs): """ Go through the attributes that the message class can contain and add values if they are missing but exists in the client info or @@ -204,8 +205,7 @@ def construct(self, request_args=None, **kwargs): # run the pre_construct methods. Will return a possibly new # set of request arguments but also a set of arguments to # be used by the post_construct methods. - request_args, post_args = self.do_pre_construct(request_args, - **kwargs) + request_args, post_args = self.do_pre_construct(request_args, **kwargs) # If 'state' appears among the keyword argument and is not # expected to appear in the request, remove it. @@ -222,6 +222,10 @@ def construct(self, request_args=None, **kwargs): # message type request = self.msg_type(**_args) + _behaviour_args = kwargs.get("behaviour_args") + if _behaviour_args: + post_args.update(_behaviour_args) + return self.do_post_construct(request, **post_args) def init_authentication_method(self, request, authn_method, @@ -340,7 +344,7 @@ def get_headers(self, return _headers def get_request_parameters(self, request_args=None, method="", - request_body_type="", authn_method='', **kwargs): + request_body_type="", authn_method='', **kwargs) -> dict: """ Builds the request message and constructs the HTTP headers. @@ -440,7 +444,9 @@ def post_parse_response(self, response, **kwargs): """ return response - def gather_verify_arguments(self): + def gather_verify_arguments(self, + response: Optional[Union[dict, Message]] = None, + behaviour_args: Optional[dict] = None): """ Need to add some information before running verify() @@ -497,7 +503,11 @@ def _do_response(self, info, sformat, **kwargs): raise return resp - def parse_response(self, info, sformat="", state="", **kwargs): + def parse_response(self, info, + sformat: Optional[str] = "", + state: Optional[str] = "", + behaviour_args: Optional[dict] = None, + **kwargs): """ This the start of a pipeline that will: @@ -509,8 +519,8 @@ def parse_response(self, info, sformat="", state="", **kwargs): 3 runs the do_post_parse_response method iff the response was not an error response. - :param info: The response, can be either in a JSON or an urlencoded - format + :param behaviour_args: + :param info: The response, can be either in a JSON or an urlencoded format :param sformat: Which serialization that was used :param state: The state :param kwargs: Extra key word arguments @@ -550,11 +560,10 @@ def parse_response(self, info, sformat="", state="", **kwargs): if is_error_message(resp): LOGGER.debug('Error response: %s', resp) else: - vargs = self.gather_verify_arguments() + vargs = self.gather_verify_arguments(response=resp, behaviour_args=behaviour_args) LOGGER.debug("Verify response with %s", vargs) try: - # verify the message. If something is wrong an exception is - # thrown + # verify the message. If something is wrong an exception is thrown resp.verify(**vargs) except Exception as err: LOGGER.error( diff --git a/src/oidcrp/service_context.py b/src/oidcrp/service_context.py index 0247162..7a6753c 100644 --- a/src/oidcrp/service_context.py +++ b/src/oidcrp/service_context.py @@ -11,6 +11,7 @@ from cryptojwt.utils import as_bytes from oidcmsg.context import OidcContext from oidcmsg.oidc import RegistrationRequest +from oidcmsg.util import rndstr from oidcrp.state_interface import StateInterface @@ -97,10 +98,12 @@ class ServiceContext(OidcContext): "client_secret_expires_at": 0, 'clock_skew': None, "config": None, + "hash_seed": b'', + "hash2issuer": None, "httpc_params": None, 'issuer': None, "kid": None, - "post_logout_redirect_uris": [], + "post_logout_redirect_uris": None, 'provider_info': None, 'redirect_uris': None, "requests_dir": None, @@ -126,6 +129,7 @@ def __init__(self, base_url="", keyjar=None, config=None, state=None, **kwargs): self.client_preferences = {} self.args = {} self.add_on = {} + self.hash2issuer = {} self.httpc_params = {} self.issuer = "" self.client_id = "" @@ -145,7 +149,7 @@ def __init__(self, base_url="", keyjar=None, config=None, state=None, **kwargs): 'behaviour', 'callback', 'issuer']: _val = config.get(param, _def_value[param]) self.set(param, _val) - if param == 'client_secret': + if param == 'client_secret' and _val: self.keyjar.add_symmetric('', _val) if not self.issuer: @@ -156,6 +160,12 @@ def __init__(self, base_url="", keyjar=None, config=None, state=None, **kwargs): except KeyError: self.clock_skew = 15 + _seed = config.get("hash_seed") + if _seed: + self.hash_seed = as_bytes(_seed) + else: + self.hash_seed = as_bytes(rndstr(32)) + for key, val in kwargs.items(): setattr(self, key, val) diff --git a/src/oidcrp/state_interface.py b/src/oidcrp/state_interface.py index a377da1..bfc7100 100644 --- a/src/oidcrp/state_interface.py +++ b/src/oidcrp/state_interface.py @@ -382,7 +382,7 @@ def remove_state(self, state): :param state: Key to the state """ del self._db[state] - refs = json.loads(self._db_db.get("ref{}ref".format(state))) + refs = json.loads(self._db.get("ref{}ref".format(state))) if refs: for xtyp, _val in refs.items(): del self._db[KEY_PATTERN[xtyp].format(_val)] diff --git a/src/oidcrp/util.py b/src/oidcrp/util.py index eb774cb..17e0d51 100755 --- a/src/oidcrp/util.py +++ b/src/oidcrp/util.py @@ -9,6 +9,7 @@ import ssl import string import sys +from typing import Optional from urllib.parse import parse_qs from urllib.parse import urlsplit from urllib.parse import urlunsplit @@ -534,17 +535,17 @@ def add_path(url, path): return '{}/{}'.format(url, path) -def load_registration_response(client): +def load_registration_response(client, request_args=None): """ If the client has been statically registered that information must be provided during the configuration. If expected to be - done dynamically. This method will do dynamic client registration. + done dynamically this method will do dynamic client registration. :param client: A :py:class:`oidcrp.oidc.Client` instance """ if not client.client_get("service_context").get('client_id'): try: - response = client.do_request('registration') + response = client.do_request('registration', request_args=request_args) except KeyError: raise ConfigurationError('No registration info') except Exception as err: @@ -553,26 +554,3 @@ def load_registration_response(client): else: if 'error' in response: raise OidcServiceError(response.to_json()) - - -def dynamic_provider_info_discovery(client): - """ - This is about performing dynamic Provider Info discovery - - :param client: A :py:class:`oidcrp.oidc.Client` instance - """ - try: - client.get_service('provider_info') - except KeyError: - raise ConfigurationError( - 'Can not do dynamic provider info discovery') - else: - _context = client.client_get("service_context") - try: - _context.set('issuer', _context.config['srv_discovery_url']) - except KeyError: - pass - - response = client.do_request('provider_info') - if is_error_message(response): - raise OidcServiceError(response['error']) diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index a57e904..d16a636 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "SUswNi1MRFlDT0Y2YjU1Z1RfQlo2S3dEa3FTTkV3LThFcnhDTHF5elk2VQ", "n": "0UkUx2ewKyc-XJ1o0ToyGjws_JybAMZj2oYjsPyyvQ_T5dhZ2VmRRRkhsaVJ2xE_GGc7mSG0IjmGFyXp5y0w4mJBcsAEE5-8eBTvQdYIryjW74r3jt6Fi4Hlm1yFMTie3apv8mw79BUj-jT0kh3_m-FiKKUvLsq45DcLtTJ4cx7Ize37dl1sFSpQcoYMk7eiUEM8fiNboiVwvBYNAWVMkUM-LnVUPm3UjvKp0LihYEkZFWOxmuQmj2x25SFUkjus38ERrRqJQBZduxdBHFrWtWg8yOA53BkMU0FFg_r0H3ctl-5GaKw-BWlogU4qXnsq85xy0EoenRk7FPV8g_ulJw", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "NC1pdGRQN002bWM3bk1xX2R0SktscElqbFdtN29ITDV2WVd2b0hOYzREVQ", "crv": "P-256", "x": "kK7Qp1woSerI7rUOAwW_4sU6ZmwV3wwXKX3VU-v2fMI", "y": "iPWd_Pjq6EjxYy08KNFZ3PxhEwgWHgAQTTknlKMKJA0"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "SUswNi1MRFlDT0Y2YjU1Z1RfQlo2S3dEa3FTTkV3LThFcnhDTHF5elk2VQ", "e": "AQAB", "n": "0UkUx2ewKyc-XJ1o0ToyGjws_JybAMZj2oYjsPyyvQ_T5dhZ2VmRRRkhsaVJ2xE_GGc7mSG0IjmGFyXp5y0w4mJBcsAEE5-8eBTvQdYIryjW74r3jt6Fi4Hlm1yFMTie3apv8mw79BUj-jT0kh3_m-FiKKUvLsq45DcLtTJ4cx7Ize37dl1sFSpQcoYMk7eiUEM8fiNboiVwvBYNAWVMkUM-LnVUPm3UjvKp0LihYEkZFWOxmuQmj2x25SFUkjus38ERrRqJQBZduxdBHFrWtWg8yOA53BkMU0FFg_r0H3ctl-5GaKw-BWlogU4qXnsq85xy0EoenRk7FPV8g_ulJw"}, {"kty": "EC", "use": "sig", "kid": "NC1pdGRQN002bWM3bk1xX2R0SktscElqbFdtN29ITDV2WVd2b0hOYzREVQ", "crv": "P-256", "x": "kK7Qp1woSerI7rUOAwW_4sU6ZmwV3wwXKX3VU-v2fMI", "y": "iPWd_Pjq6EjxYy08KNFZ3PxhEwgWHgAQTTknlKMKJA0"}]} \ No newline at end of file diff --git a/tests/request123456.jwt b/tests/request123456.jwt index f72cdd1..2e8824e 100644 --- a/tests/request123456.jwt +++ b/tests/request123456.jwt @@ -1 +1 @@ -eyJhbGciOiJSUzI1NiIsImtpZCI6IlNVc3dOaTFNUkZsRFQwWTJZalUxWjFSZlFsbzJTM2RFYTNGVFRrVjNMVGhGY25oRFRIRjVlbGsyVlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJWNVFGVUxMWGpBS0lOS3V3Y2JlVU43OHRuWkZ2MGloYyIsICJjbGllbnRfaWQiOiAiY2xpZW50X2lkIiwgImlzcyI6ICJjbGllbnRfaWQiLCAiaWF0IjogMTYxNzg3MDI0NiwgImF1ZCI6IFsiaHR0cHM6Ly9leGFtcGxlLmNvbSJdfQ.itHjMHec6T2py2zDxaAS11tFAYdsTZ-SY3AV2j5SCTHPfk9CkXa6g6s_t0oVvcdzLVXaqQ1iU9WqOeirwj4UDxCTHulPzBOcBuC3WbCki2HT9EPI88Lov-kCuz_4juw97lIyU1BkaofaZJSjRcEW_fOY0KuP7BKIDmthTLTWxEGMoMXmXcs_QD13tz0IWjrkqjwbcKjbUxrvGHJbOnOBGwgHPPm46otDMO-hQrtvTOGz4PbdD5XZ4imDV0bJkx72ITpAfM8iODmd9sKrOWkEZEhsRG1ugXa8RgOPNLsLTLjTpzWiVeczJHOiE5H4EC-uwzXiRDiShq-q7VbTKocyYw \ No newline at end of file +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNVc3dOaTFNUkZsRFQwWTJZalUxWjFSZlFsbzJTM2RFYTNGVFRrVjNMVGhGY25oRFRIRjVlbGsyVlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJvWkpBNTRnZTVaUndNalkwOVVLVnpwYkx5MEdNUEwwaCIsICJjbGllbnRfaWQiOiAiY2xpZW50X2lkIiwgImlzcyI6ICJjbGllbnRfaWQiLCAiaWF0IjogMTYzMzU5NTc4OSwgImF1ZCI6IFsiaHR0cHM6Ly9leGFtcGxlLmNvbSJdfQ.KVMPK6leJ5pEXnJ0jXiXu21U176IU9iwkT4FkQV_33jGYTsgdqCqXw5XHR1ciixdcH2cWf0SzTPOgIzGsI4NJiPNdR9xOusYRyYKZciXHq85nrM7fr7dEPaVntWCU6uadH0MNHWCcq2FyBdz2YYDuiFPUXoxkFbfWZoo_jVMAWLxGQtGEitniI49qo0zbeSFck4hBmEtQTUOrGQvg_CjkSZb5oNb5rt_X5T-ZSK9y3AeKru4HLSQRkWj-oD-Fgd60Sm3XqfLQXrx26lk4a8ORah01BMmMsi5jeIUbOTthhhglZhMwoI9xCZ57I4SF7870-PrinIByW8d2keA1-LipQ \ No newline at end of file diff --git a/tests/test_13_oidc_service.py b/tests/test_13_oidc_service.py index c3a30dc..e3f4450 100644 --- a/tests/test_13_oidc_service.py +++ b/tests/test_13_oidc_service.py @@ -7,6 +7,7 @@ from cryptojwt.jwt import JWT from cryptojwt.key_jar import build_keyjar from cryptojwt.key_jar import init_key_jar +from oidcmsg.exception import MissingRequiredAttribute from oidcmsg.oidc import AccessTokenRequest from oidcmsg.oidc import AccessTokenResponse from oidcmsg.oidc import AuthorizationRequest @@ -68,7 +69,7 @@ def create_request(self): } entity = Entity(services=DEFAULT_OIDC_SERVICES, keyjar=CLI_KEY, config=client_config) entity.client_get("service_context").issuer = 'https://example.com' - self.service = entity.client_get("service",'authorization') + self.service = entity.client_get("service", 'authorization') def test_construct(self): req_args = { @@ -137,13 +138,12 @@ def test_request_init(self): def test_request_init_request_method(self): req_args = {'response_type': 'code', 'state': 'state'} self.service.endpoint = 'https://example.com/authorize' - _info = self.service.get_request_parameters(request_args=req_args, - request_method='value') + _info = self.service.get_request_parameters(request_args=req_args, request_method='value') assert set(_info.keys()) == {'url', 'method', 'request'} msg = AuthorizationRequest().from_urlencoded( self.service.get_urlinfo(_info['url'])) - assert set(msg.to_dict()) == {'client_id', 'redirect_uri', 'request', - 'response_type', 'state', 'scope', 'nonce'} + assert set(msg.to_dict()) == {'client_id', 'redirect_uri', 'request', 'response_type', + 'scope'} _jws = jws.factory(msg['request']) assert _jws _resp = _jws.verify_compact( @@ -203,7 +203,7 @@ def test_update_service_context_with_idtoken_wrong_nonce(self): idt = JWT(ISS_KEY, iss=ISS, lifetime=3600) payload = { 'sub': '123456789', 'aud': ['client_id'], - 'nonce': 'nonce' + 'nonce': 'noice' } # have to calculate c_hash alg = 'RS256' @@ -212,9 +212,8 @@ def test_update_service_context_with_idtoken_wrong_nonce(self): _idt = idt.pack(payload) resp = AuthorizationResponse(state='state', code='code', id_token=_idt) - resp = self.service.parse_response(resp.to_urlencoded()) - with pytest.raises(ParameterError): - self.service.update_service_context(resp, 'state2') + with pytest.raises(ValueError): + self.service.parse_response(resp.to_urlencoded()) def test_update_service_context_with_idtoken_missing_nonce(self): req_args = {'response_type': 'code', 'state': 'state', 'nonce': 'nonce'} @@ -230,9 +229,8 @@ def test_update_service_context_with_idtoken_missing_nonce(self): _idt = idt.pack(payload) resp = AuthorizationResponse(state='state', code='code', id_token=_idt) - resp = self.service.parse_response(resp.to_urlencoded()) - with pytest.raises(ValueError): - self.service.update_service_context(resp, 'state') + with pytest.raises(MissingRequiredAttribute): + self.service.parse_response(resp.to_urlencoded()) @pytest.mark.parametrize("allow_sign_alg_none", [True, False]) def test_allow_unsigned_idtoken(self, allow_sign_alg_none): @@ -241,14 +239,14 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none): self.service.get_request_parameters(request_args=req_args) # Build an ID Token idt = JWT(ISS_KEY, iss=ISS, lifetime=3600, sign_alg='none') - payload = {'sub': '123456789', 'aud': ['client_id']} + payload = {'sub': '123456789', 'aud': ['client_id'], 'nonce': req_args['nonce']} _idt = idt.pack(payload) self.service.client_get("service_context").behaviour["verify_args"] = { "allow_sign_alg_none": allow_sign_alg_none } resp = AuthorizationResponse(state='state', code='code', id_token=_idt) if allow_sign_alg_none: - resp = self.service.parse_response(resp.to_urlencoded()) + self.service.parse_response(resp.to_urlencoded()) else: with pytest.raises(UnsupportedAlgorithm): self.service.parse_response(resp.to_urlencoded()) @@ -267,7 +265,7 @@ def create_request(self): } entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) entity.client_get("service_context").issuer = 'https://example.com' - self.service = entity.client_get("service",'authorization') + self.service = entity.client_get("service", 'authorization') def test_construct_code(self): req_args = { @@ -318,7 +316,7 @@ def create_request(self): } entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) entity.client_get("service_context").issuer = 'https://example.com' - self.service = entity.client_get("service",'accesstoken') + self.service = entity.client_get("service", 'accesstoken') # add some history auth_request = AuthorizationRequest( @@ -409,7 +407,7 @@ def create_service(self): } entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) entity.client_get("service_context").issuer = 'https://example.com' - self.service = entity.client_get("service",'provider_info') + self.service = entity.client_get("service", 'provider_info') def test_construct(self): _req = self.service.construct() @@ -601,7 +599,7 @@ def create_request(self): } entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) entity.client_get("service_context").issuer = 'https://example.com' - self.service = entity.client_get("service",'registration') + self.service = entity.client_get("service", 'registration') def test_construct(self): _req = self.service.construct() @@ -616,14 +614,59 @@ def test_config_with_post_logout(self): assert len(_req) == 5 assert 'post_logout_redirect_uris' in _req - def test_config_with_required_request_uri(self): - _pi = self.service.client_get("service_context").provider_info - _pi['require_request_uri_registration'] = True - self.service.client_get("service_context").provider_info = _pi - _req = self.service.construct() - assert isinstance(_req, RegistrationRequest) - assert len(_req) == 5 - assert 'request_uris' in _req + +def test_config_with_required_request_uri(): + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'issuer': ISS, + 'requests_dir': 'requests', + 'base_url': 'https://example.com/cli/', + 'client_preferences': { + "request_uri_usable": True + } + } + entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) + entity.client_get("service_context").issuer = 'https://example.com' + service = entity.client_get("service", 'registration') + _context = service.client_get("service_context") + + _pi = _context.provider_info + _pi['require_request_uri_registration'] = True + _context.config["client_preferences"]["request_uri_usable"] = True + _req = service.construct() + assert isinstance(_req, RegistrationRequest) + assert len(_req) == 5 + assert 'request_uris' in _req + + +def test_config_logout_uri(): + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'issuer': ISS, + 'requests_dir': 'requests', + 'base_url': 'https://example.com/cli/', + 'client_preferences': { + "request_uri_usable": True + } + } + entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) + entity.client_get("service_context").issuer = 'https://example.com' + service = entity.client_get("service", 'registration') + _context = service.client_get("service_context") + + _pi = _context.provider_info + _pi['require_request_uri_registration'] = True + _pi['frontchannel_logout_supported'] = True + _context.config["client_preferences"]["request_uri_usable"] = True + _context.config["client_preferences"]["frontchannel_logout_usable"] = True + _req = service.construct() + assert isinstance(_req, RegistrationRequest) + assert len(_req) == 7 + assert 'request_uris' in _req + assert 'frontchannel_logout_uri' in _req + assert 'post_logout_redirect_uris' in _req class TestUserInfo(object): @@ -638,7 +681,7 @@ def create_request(self): } entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) entity.client_get("service_context").issuer = 'https://example.com' - self.service = entity.client_get("service",'userinfo') + self.service = entity.client_get("service", 'userinfo') entity.client_get("service_context").behaviour = { 'userinfo_signed_response_alg': 'RS256', @@ -780,7 +823,7 @@ def create_request(self): }} entity = Entity(keyjar=CLI_KEY, config=client_config, services=services) entity.client_get("service_context").issuer = 'https://example.com' - self.service = entity.client_get("service",'check_session') + self.service = entity.client_get("service", 'check_session') def test_construct(self): _state_interface = self.service.client_get("service_context").state @@ -809,7 +852,7 @@ def create_request(self): }} entity = Entity(keyjar=CLI_KEY, config=client_config, services=services) entity.client_get("service_context").issuer = 'https://example.com' - self.service = entity.client_get("service",'check_id') + self.service = entity.client_get("service", 'check_id') def test_construct(self): _state_interface = self.service.client_get("service_context").state @@ -839,7 +882,7 @@ def create_request(self): }} entity = Entity(keyjar=CLI_KEY, config=client_config, services=services) entity.client_get("service_context").issuer = 'https://example.com' - self.service = entity.client_get("service",'end_session') + self.service = entity.client_get("service", 'end_session') def test_construct(self): self.service.client_get("service_context").state.store_item( @@ -847,9 +890,8 @@ def test_construct(self): 'token_response', 'abcde') _req = self.service.construct(state='abcde') assert isinstance(_req, EndSessionRequest) - assert len(_req) == 3 - assert set(_req.keys()) == {'state', 'id_token_hint', - 'post_logout_redirect_uri'} + assert len(_req) == 2 + assert set(_req.keys()) == {'state', 'id_token_hint'} def test_authz_service_conf(): @@ -880,7 +922,7 @@ def test_authz_service_conf(): } entity = Entity(keyjar=CLI_KEY, config=client_config, services=services) entity.client_get("service_context").issuer = 'https://example.com' - service = entity.client_get("service",'authorization') + service = entity.client_get("service", 'authorization') req = service.construct() assert 'claims' in req @@ -900,7 +942,7 @@ def test_add_jwks_uri_or_jwks_0(): } entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) entity.client_get("service_context").issuer = 'https://example.com' - service = entity.client_get("service",'registration') + service = entity.client_get("service", 'registration') req_args, post_args = add_jwks_uri_or_jwks({}, service) assert req_args['jwks_uri'] == 'https://example.com/jwks/jwks.json' @@ -919,7 +961,7 @@ def test_add_jwks_uri_or_jwks_1(): } } entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) - service = entity.client_get("service",'registration') + service = entity.client_get("service", 'registration') req_args, post_args = add_jwks_uri_or_jwks({}, service) assert req_args['jwks_uri'] == 'https://example.com/jwks/jwks.json' @@ -938,7 +980,7 @@ def test_add_jwks_uri_or_jwks_2(): } entity = Entity(keyjar=CLI_KEY, config=client_config, jwks_uri='https://example.com/jwks/jwks.json', services=DEFAULT_OIDC_SERVICES) - service = entity.client_get("service",'registration') + service = entity.client_get("service", 'registration') req_args, post_args = add_jwks_uri_or_jwks({}, service) assert req_args['jwks_uri'] == 'https://example.com/jwks/jwks.json' diff --git a/tests/test_14_oidc.py b/tests/test_14_oidc.py index 5ebc41c..42de74b 100755 --- a/tests/test_14_oidc.py +++ b/tests/test_14_oidc.py @@ -108,8 +108,7 @@ def test_construct_refresh_token_request(self): token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - _context.state.store_item(token_response, - 'token_response', 'ABCDE') + _context.state.store_item(token_response, 'token_response', 'ABCDE') req_args = {} msg = self.client.client_get("service",'refresh_token').construct( @@ -131,17 +130,14 @@ def test_do_userinfo_request_init(self): state='state' ) - _context.state.store_item(auth_request, - 'auth_request', 'ABCDE') + _context.state.store_item(auth_request, 'auth_request', 'ABCDE') auth_response = AuthorizationResponse(code='access_code') - _context.state.store_item(auth_response, - 'auth_response', 'ABCDE') + _context.state.store_item(auth_response, 'auth_response', 'ABCDE') token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - _context.state.store_item(token_response, - 'token_response', 'ABCDE') + _context.state.store_item(token_response, 'token_response', 'ABCDE') _srv = self.client.client_get("service",'userinfo') _srv.endpoint = "https://example.com/userinfo" diff --git a/tests/test_20_rp_handler_oidc.py b/tests/test_20_rp_handler_oidc.py index b74b588..1191cd9 100644 --- a/tests/test_20_rp_handler_oidc.py +++ b/tests/test_20_rp_handler_oidc.py @@ -17,6 +17,7 @@ import responses from oidcrp.entity import Entity +from oidcrp.oidc.registration import add_callbacks from oidcrp.rp_handler import RPHandler BASE_URL = 'https://example.com/rp' @@ -29,7 +30,8 @@ "code id_token token", "code token"], "scope": ["openid", "profile", "email", "address", "phone"], "token_endpoint_auth_method": "client_secret_basic", - "verify_args": {"allow_sign_alg_none": True} + "verify_args": {"allow_sign_alg_none": True}, + "request_parameter_preference": ["request_uri", "request"] } CLIENT_CONFIG = { @@ -143,6 +145,40 @@ "userinfo_endpoint": "https://api.github.com/user" }, + 'services': { + 'authorization': { + 'class': 'oidcrp.oidc.authorization.Authorization', + }, + 'access_token': { + 'class': 'oidcrp.oidc.access_token.AccessToken' + }, + 'userinfo': { + 'class': 'oidcrp.oidc.userinfo.UserInfo', + 'kwargs': {'conf': {'default_authn_method': ''}} + }, + 'refresh_access_token': { + 'class': 'oidcrp.oidc.refresh_access_token.RefreshAccessToken' + } + } + }, + 'github2': { + "issuer": "https://github.com/login/oauth/authorize", + 'client_id': 'eeeeeeeee', + 'client_secret': 'aaaaaaaaaaaaaaaaaaaa', + "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)], + "behaviour": { + "response_types": ["code"], + "scope": ["user", "public_repo"], + "token_endpoint_auth_method": '', + "verify_args": {"allow_sign_alg_none": True} + }, + "provider_info": { + "authorization_endpoint": "https://github.com/login/oauth/authorize", + "token_endpoint": "https://github.com/login/oauth/access_token", + "userinfo_endpoint": "https://api.github.com/user", + "request_parameter_supported": True, + "request_uri_parameter_supported": True + }, 'services': { 'authorization': { 'class': 'oidcrp.oidc.authorization.Authorization' @@ -155,8 +191,7 @@ 'kwargs': {'conf': {'default_authn_method': ''}} }, 'refresh_access_token': { - 'class': 'oidcrp.oidc.refresh_access_token' - '.RefreshAccessToken' + 'class': 'oidcrp.oidc.refresh_access_token.RefreshAccessToken' } } } @@ -265,7 +300,7 @@ def test_do_provider_info(self): # Make sure the service endpoints are set for service_type in ['authorization', 'accesstoken', 'userinfo']: - _srv = client.client_get("service",service_type) + _srv = client.client_get("service", service_type) _endp = client.client_get("service_context").get('provider_info')[_srv.endpoint_name] assert _srv.endpoint == _endp @@ -277,7 +312,7 @@ def test_do_client_registration(self): # only 2 things should have happened assert self.rph.hash2issuer['github'] == issuer - assert client.client_get("service_context").post_logout_redirect_uris == [] + assert client.client_get("service_context").callback.get("post_logout_redirect_uris") is None def test_do_client_setup(self): client = self.rph.client_setup('github') @@ -296,25 +331,29 @@ def test_do_client_setup(self): assert len(keys) == 2 for service_type in ['authorization', 'accesstoken', 'userinfo']: - _srv = client.client_get("service",service_type) + _srv = client.client_get("service", service_type) _endp = _srv.client_get("service_context").get('provider_info')[_srv.endpoint_name] assert _srv.endpoint == _endp assert self.rph.hash2issuer['github'] == _context.get('issuer') def test_create_callbacks(self): - cb = self.rph.create_callbacks('https://op.example.com/') + client = self.rph.init_client('https://op.example.com/') + _srv = client.client_get("service", "registration") + _context = _srv.client_get("service_context") + add_callbacks(_context, []) + + cb = _srv.client_get("service_context").callback - assert set(cb.keys()) == {'code', 'implicit', 'form_post', '__hex'} + assert set(cb.keys()) == {'redirect_uris', 'code', 'implicit', '__hex'} _hash = cb['__hex'] - assert cb['code'] == 'https://example.com/rp/authz_cb/{}'.format(_hash) - assert cb['implicit'] == 'https://example.com/rp/authz_im_cb/{}'.format(_hash) - assert cb['form_post'] == 'https://example.com/rp/authz_fp_cb/{}'.format(_hash) + assert cb['code'] == f'https://example.com/rp/authz_cb/{_hash}' + assert cb['implicit'] == f'https://example.com/rp/authz_im_cb/{_hash}' - assert list(self.rph.hash2issuer.keys()) == [_hash] + assert list(_context.hash2issuer.keys()) == [_hash] - assert self.rph.hash2issuer[_hash] == 'https://op.example.com/' + assert _context.hash2issuer[_hash] == 'https://op.example.com/' def test_begin(self): res = self.rph.begin(issuer_id='github') @@ -367,8 +406,9 @@ def test_finalize_auth(self): state=res['state']) resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict()) assert set(resp.keys()) == {'state', 'code'} - aresp = client.client_get("service_context").state.get_item(AuthorizationResponse, 'auth_response', - res['state']) + aresp = client.client_get("service_context").state.get_item(AuthorizationResponse, + 'auth_response', + res['state']) assert set(aresp.keys()) == {'state', 'code'} def test_get_client_authn_method(self): @@ -385,7 +425,7 @@ def test_get_client_authn_method(self): 'token_endpoint') assert authn_method == 'client_secret_post' - def test_get_access_token(self): + def test_get_tokens(self): res = self.rph.begin(issuer_id='github') _session = self.rph.get_session_information(res['state']) client = self.rph.issuer2rp[_session['iss']] @@ -417,14 +457,14 @@ def test_get_access_token(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.client_get("service",'accesstoken').endpoint = _url + client.client_get("service", 'accesstoken').endpoint = _url auth_response = AuthorizationResponse(code='access_code', state=res['state']) resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict()) - resp = self.rph.get_access_token(res['state'], client) + resp = self.rph.get_tokens(res['state'], client) assert set(resp.keys()) == {'access_token', 'expires_in', 'id_token', 'token_type', '__verified_id_token', '__expires_at'} @@ -466,7 +506,7 @@ def test_access_and_id_token(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.client_get("service",'accesstoken').endpoint = _url + client.client_get("service", 'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) @@ -507,7 +547,7 @@ def test_access_and_id_token_by_reference(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.client_get("service",'accesstoken').endpoint = _url + client.client_get("service", 'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) @@ -548,7 +588,7 @@ def test_get_user_info(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.client_get("service",'accesstoken').endpoint = _url + client.client_get("service", 'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) @@ -562,7 +602,7 @@ def test_get_user_info(self): with responses.RequestsMock() as rsps: rsps.add("GET", _url, body='{"sub":"EndUserSubject"}', adding_headers={"Content-Type": "application/json"}, status=200) - client.client_get("service",'userinfo').endpoint = _url + client.client_get("service", 'userinfo').endpoint = _url userinfo_resp = self.rph.get_user_info(res['state'], client, token_resp['access_token']) @@ -595,7 +635,7 @@ def test_get_provider_specific_service(): } } entity = Entity(services=srv_desc) - assert entity.client_get("service",'accesstoken').response_body_type == 'urlencoded' + assert entity.client_get("service", 'accesstoken').response_body_type == 'urlencoded' class TestRPHandlerTier2(object): @@ -634,7 +674,7 @@ def rphandler_setup(self): rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.client_get("service",'accesstoken').endpoint = _url + client.client_get("service", 'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) @@ -649,7 +689,7 @@ def rphandler_setup(self): rsps.add("GET", _url, body='{"sub":"EndUserSubject"}', adding_headers={"Content-Type": "application/json"}, status=200) - client.client_get("service",'userinfo').endpoint = _url + client.client_get("service", 'userinfo').endpoint = _url self.rph.get_user_info(res['state'], client, token_resp['access_token']) self.state = res['state'] @@ -677,7 +717,7 @@ def test_refresh_access_token(self): rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.client_get("service",'refresh_token').endpoint = _url + client.client_get("service", 'refresh_token').endpoint = _url res = self.rph.refresh_access_token(self.state, client, 'openid email') assert res['access_token'] == '2nd_accessTok' @@ -689,7 +729,7 @@ def test_get_user_info(self): with responses.RequestsMock() as rsps: rsps.add("GET", _url, body='{"sub":"EndUserSubject", "mail":"foo@example.com"}', adding_headers={"Content-Type": "application/json"}, status=200) - client.client_get("service",'userinfo').endpoint = _url + client.client_get("service", 'userinfo').endpoint = _url resp = self.rph.get_user_info(self.state, client) assert set(resp.keys()) == {'sub', 'mail'} @@ -722,8 +762,7 @@ def __init__(self, issuer, keyjar=None): self.post_response = {} self.register_post_response('default', 'OK', 200) - def register_get_response(self, path, data, status_code=200, - headers=None): + def register_get_response(self, path, data, status_code=200, headers=None): _headers = headers or {} self.get_response[path] = MockResponse(status_code, data, _headers) @@ -787,13 +826,32 @@ def registration_callback(data): return json.dumps(_req) +def test_rphandler_request_uri(): + rph = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY) + res = rph.begin(issuer_id='github2', behaviour_args={"request_param": "request_uri"}) + _session = rph.get_session_information(res['state']) + _url = res["url"] + _qp = parse_qs(urlparse(_url).query) + assert 'request_uri' in _qp + + +def test_rphandler_request(): + rph = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY) + res = rph.begin(issuer_id='github2', + behaviour_args={"request_param": "request"}) + _session = rph.get_session_information(res['state']) + _url = res["url"] + _qp = parse_qs(urlparse(_url).query) + assert 'request' in _qp + + class TestRPHandlerWithMockOP(object): @pytest.fixture(autouse=True) def rphandler_setup(self): self.issuer = 'https://github.com/login/oauth/authorize' self.mock_op = MockOP(issuer=self.issuer) self.rph = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, - http_lib=self.mock_op, keyjar=KeyJar()) + http_lib=self.mock_op, keyjar=CLI_KEY) def test_finalize(self): auth_query = self.rph.begin(issuer_id='github') @@ -845,7 +903,7 @@ def test_finalize(self): # assume code flow resp = self.rph.finalize(_session['iss'], auth_response.to_dict()) - assert set(resp.keys()) == {'userinfo', 'state', 'token', 'id_token'} + assert set(resp.keys()) == {'userinfo', 'state', 'token', 'id_token', 'session_state'} def test_dynamic_setup(self): user_id = 'acct:foobar@example.com' @@ -863,32 +921,22 @@ def test_dynamic_setup(self): "issuer": "https://server.example.com", "subject_types_supported": ['public'], "token_endpoint": "https://server.example.com/connect/token", - "token_endpoint_auth_methods_supported": ["client_secret_basic", - "private_key_jwt"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"], "userinfo_endpoint": "https://server.example.com/connect/user", "check_id_endpoint": "https://server.example.com/connect/check_id", - "refresh_session_endpoint": - "https://server.example.com/connect/refresh_session", - "end_session_endpoint": - "https://server.example.com/connect/end_session", + "refresh_session_endpoint": "https://server.example.com/connect/refresh_session", + "end_session_endpoint": "https://server.example.com/connect/end_session", "jwks_uri": "https://server.example.com/jwk.json", - "registration_endpoint": - "https://server.example.com/connect/register", - "scopes_supported": ["openid", "profile", "email", "address", - "phone"], - "response_types_supported": ["code", "code id_token", - "token id_token"], - "acrs_supported": ["1", "2", - "http://id.incommon.org/assurance/bronze"], + "registration_endpoint": "https://server.example.com/connect/register", + "scopes_supported": ["openid", "profile", "email", "address", "phone"], + "response_types_supported": ["code", "code id_token", "token id_token"], + "acrs_supported": ["1", "2", "http://id.incommon.org/assurance/bronze"], "user_id_types_supported": ["public", "pairwise"], - "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", - "RSA1_5"], + "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], "id_token_signing_alg_values_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], - "request_object_algs_supported": ["HS256", "RS256", "A128CBC", - "A128KW", - "RSA1_5"] + "request_object_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"] } pcr = ProviderConfigurationResponse(**resp) @@ -902,3 +950,58 @@ def test_dynamic_setup(self): auth_query = self.rph.begin(user_id=user_id) assert auth_query + + def test_dynamic_setup_redirect_uri(self): + user_id = 'acct:foobar@example.com' + _link = Link(rel="http://openid.net/specs/connect/1.0/issuer", + href="https://server.example.com") + webfinger_response = JRD(subject=user_id, links=[_link]) + self.mock_op.register_get_response( + '/.well-known/webfinger', webfinger_response.to_json(), 200, + {'content-type': "application/json"}) + + resp = { + "authorization_endpoint": + "https://server.example.com/connect/authorize", + "issuer": "https://server.example.com", + "subject_types_supported": ['public'], + "token_endpoint": "https://server.example.com/connect/token", + "token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"], + "userinfo_endpoint": "https://server.example.com/connect/user", + "check_id_endpoint": "https://server.example.com/connect/check_id", + "refresh_session_endpoint": "https://server.example.com/connect/refresh_session", + "end_session_endpoint": "https://server.example.com/connect/end_session", + "jwks_uri": "https://server.example.com/jwk.json", + "registration_endpoint": "https://server.example.com/connect/register", + "scopes_supported": ["openid", "profile", "email", "address", "phone"], + "response_types_supported": ["code", "code id_token", "token id_token"], + "acrs_supported": ["1", "2", "http://id.incommon.org/assurance/bronze"], + "user_id_types_supported": ["public", "pairwise"], + "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], + "id_token_signing_alg_values_supported": ["HS256", "RS256", + "A128CBC", "A128KW", + "RSA1_5"], + "request_object_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"], + "request_parameter_supported": True, + "request_uri_parameter_supported": True, + "require_request_uri_registration": True + } + + pcr = ProviderConfigurationResponse(**resp) + self.mock_op.register_get_response( + '/.well-known/openid-configuration', pcr.to_json(), 200, + {'content-type': "application/json"}) + + self.mock_op.register_post_response( + '/connect/register', registration_callback, 200, + {'content-type': "application/json"}) + + res = self.rph.begin(user_id=user_id, + behaviour_args={ + "request_param": "request", + "request_object_signing_alg": "RS256"}) + assert res + + _url = res["url"] + _qp = parse_qs(urlparse(_url).query) + assert 'request' in _qp diff --git a/tests/test_21_rph_defaults.py b/tests/test_21_rph_defaults.py index ac07e72..ced2cad 100644 --- a/tests/test_21_rph_defaults.py +++ b/tests/test_21_rph_defaults.py @@ -72,7 +72,6 @@ def test_begin(self): _context = client.client_get("service_context") # Calculating request so I can build a reasonable response - self.rph.add_callbacks(_context) _req = client.client_get("service",'registration').construct_request() with responses.RequestsMock() as rsps: @@ -129,7 +128,6 @@ def test_begin_2(self): _context = client.client_get("service_context") # Calculating request so I can build a reasonable response - self.rph.add_callbacks(_context) # Publishing a JWKS instead of a JWKS_URI _context.jwks_uri = '' _context.jwks = _context.keyjar.export_jwks() diff --git a/tests/test_40_rp_handler_persistent.py b/tests/test_40_rp_handler_persistent.py index 2dd6faf..737365b 100644 --- a/tests/test_40_rp_handler_persistent.py +++ b/tests/test_40_rp_handler_persistent.py @@ -247,7 +247,7 @@ def test_do_client_registration(self): # only 2 things should have happened assert rph_1.hash2issuer['github'] == issuer - assert not client.client_get("service_context").post_logout_redirect_uris + assert not client.client_get("service_context").callback.get('post_logout_redirect_uris') def test_do_client_setup(self): rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, @@ -275,23 +275,6 @@ def test_do_client_setup(self): assert rph_1.hash2issuer['github'] == _context.get('issuer') - def test_create_callbacks(self): - rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, - keyjar=CLI_KEY, module_dirs=['oidc']) - - cb = rph_1.create_callbacks('https://op.example.com/') - - assert set(cb.keys()) == {'code', 'implicit', 'form_post', '__hex'} - _hash = cb['__hex'] - - assert cb['code'] == 'https://example.com/rp/authz_cb/{}'.format(_hash) - assert cb['implicit'] == 'https://example.com/rp/authz_im_cb/{}'.format(_hash) - assert cb['form_post'] == 'https://example.com/rp/authz_fp_cb/{}'.format(_hash) - - assert list(rph_1.hash2issuer.keys()) == [_hash] - - assert rph_1.hash2issuer[_hash] == 'https://op.example.com/' - def test_begin(self): rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=['oidc']) @@ -376,7 +359,7 @@ def test_get_client_authn_method(self): 'token_endpoint') assert authn_method == 'client_secret_post' - def test_get_access_token(self): + def test_get_tokens(self): rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, keyjar=CLI_KEY, module_dirs=['oidc']) @@ -418,7 +401,7 @@ def test_get_access_token(self): resp = rph_1.finalize_auth(client, _session['iss'], auth_response.to_dict()) - resp = rph_1.get_access_token(res['state'], client) + resp = rph_1.get_tokens(res['state'], client) assert set(resp.keys()) == {'access_token', 'expires_in', 'id_token', 'token_type', '__verified_id_token', '__expires_at'}