diff --git a/example/flask_rp/conf.json b/example/flask_rp/conf.json
index b7cd241..6d7203b 100644
--- a/example/flask_rp/conf.json
+++ b/example/flask_rp/conf.json
@@ -162,9 +162,7 @@
"redirect_uris": [
"https://{domain}:{port}/authz_cb/local"
],
- "post_logout_redirect_uris": [
- "https://{domain}:{port}/session_logout/local"
- ],
+ "post_logout_redirect_uri": "https://{domain}:{port}/session_logout/local",
"frontchannel_logout_uri": "https://{domain}:{port}/fc_logout/local",
"frontchannel_logout_session_required": true,
"backchannel_logout_uri": "https://{domain}:{port}/bc_logout/local",
@@ -231,9 +229,7 @@
"redirect_uris": [
"https://{domain}:{port}/authz_cb/django"
],
- "post_logout_redirect_uris": [
- "https://{domain}:{port}/session_logout/django"
- ],
+ "post_logout_redirect_uris": "https://{domain}:{port}/session_logout/django",
"frontchannel_logout_uri": "https://{domain}:{port}/fc_logout/django",
"frontchannel_logout_session_required": true,
"backchannel_logout_uri": "https://{domain}:{port}/bc_logout/django",
diff --git a/example/flask_rp/templates/opresult.html b/example/flask_rp/templates/opresult.html
index 34bd7be..86c6789 100644
--- a/example/flask_rp/templates/opresult.html
+++ b/example/flask_rp/templates/opresult.html
@@ -18,6 +18,10 @@
Endpoints
{{ url }}
{% endfor %}
+
+ - ID Token
+ - {{ id_token }}
+
User information
{% for key, value in userinfo.items() %}
diff --git a/example/flask_rp/views.py b/example/flask_rp/views.py
index 5af68ae..caa7141 100644
--- a/example/flask_rp/views.py
+++ b/example/flask_rp/views.py
@@ -53,10 +53,14 @@ def rp():
uid = ''
if iss or uid:
+ args = {
+ 'req_args': {
+ "claims": {"id_token": {"acr": {"value": "https://refeds.org/profile/mfa"}}}
+ }
+ }
+
if uid:
- args = {'user_id': uid}
- else:
- args = {}
+ args['user_id'] = uid
session['op_identifier'] = iss
try:
@@ -145,6 +149,7 @@ def finalize(op_identifier, request_args):
return render_template('opresult.html', endpoints=endpoints,
userinfo=res['userinfo'],
access_token=res['token'],
+ id_token=res["id_token"],
**kwargs)
else:
return make_response(res['error'], 400)
@@ -152,7 +157,7 @@ def finalize(op_identifier, request_args):
def get_op_identifier_by_cb_uri(url: str):
uri = splitquery(url)[0]
- for k,v in current_app.rph.issuer2rp.items():
+ for k, v in current_app.rph.issuer2rp.items():
_cntx = v.get_service_context()
for endpoint in ("redirect_uris",
"post_logout_redirect_uris",
@@ -180,10 +185,10 @@ def repost_fragment():
return finalize(op_identifier, args)
-@oidc_rp_views.route('/ihf_cb')
-def ihf_cb(self, op_identifier='', **kwargs):
+@oidc_rp_views.route('/authz_im_cb')
+def authz_im_cb(op_identifier='', **kwargs):
logger.debug('implicit_hybrid_flow kwargs: {}'.format(kwargs))
- return render_template('repost_fragment.html')
+ return render_template('repost_fragment.html', op_identifier=op_identifier)
@oidc_rp_views.route('/session_iframe')
diff --git a/setup.py b/setup.py
index ec274de..27f19ad 100755
--- a/setup.py
+++ b/setup.py
@@ -74,7 +74,7 @@ def run_tests(self):
"Programming Language :: Python :: 3.9",
"Topic :: Software Development :: Libraries :: Python Modules"],
install_requires=[
- 'oidcmsg==1.3.3-1',
+ 'oidcmsg==1.4.1',
'pyyaml>=5.1.2',
'responses'
],
diff --git a/src/oidcrp/__init__.py b/src/oidcrp/__init__.py
index 23161e0..823f282 100644
--- a/src/oidcrp/__init__.py
+++ b/src/oidcrp/__init__.py
@@ -1,7 +1,7 @@
import logging
__author__ = 'Roland Hedberg'
-__version__ = '2.0.1'
+__version__ = '2.1.0'
logger = logging.getLogger(__name__)
diff --git a/src/oidcrp/client_auth.py b/src/oidcrp/client_auth.py
index f585ec7..571f428 100755
--- a/src/oidcrp/client_auth.py
+++ b/src/oidcrp/client_auth.py
@@ -12,8 +12,8 @@
from oidcmsg.oauth2 import SINGLE_OPTIONAL_STRING
from oidcmsg.oidc import AuthnToken
from oidcmsg.time_util import utc_time_sans_frac
+from oidcmsg.util import rndstr
-from oidcrp.util import rndstr
from oidcrp.util import sanitize
from .defaults import DEF_SIGN_ALG
from .defaults import JWT_BEARER
diff --git a/src/oidcrp/configure.py b/src/oidcrp/configure.py
index d38cc76..84ac367 100755
--- a/src/oidcrp/configure.py
+++ b/src/oidcrp/configure.py
@@ -14,7 +14,7 @@
try:
from secrets import token_urlsafe as rnd_token
except ImportError:
- from oidcendpoint import rndstr as rnd_token
+ from cryptojwt import rndstr as rnd_token
DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename', 'template_dir',
'private_path', 'public_path', 'db_file']
diff --git a/src/oidcrp/oauth2/__init__.py b/src/oidcrp/oauth2/__init__.py
index b195c8d..6b7612e 100755
--- a/src/oidcrp/oauth2/__init__.py
+++ b/src/oidcrp/oauth2/__init__.py
@@ -1,14 +1,19 @@
from json import JSONDecodeError
import logging
+from typing import Optional
from oidcmsg.exception import FormatError
+from oidcmsg.message import Message
+from oidcmsg.oauth2 import is_error_message
from oidcrp.entity import Entity
+from oidcrp.exception import ConfigurationError
from oidcrp.exception import OidcServiceError
from oidcrp.exception import ParseError
from oidcrp.http import HTTPLib
from oidcrp.service import REQUEST_INFO
from oidcrp.service import SUCCESSFUL
+from oidcrp.service import Service
from oidcrp.util import do_add_ons
from oidcrp.util import get_deserialization_method
@@ -72,7 +77,11 @@ def __init__(self, client_authn_factory=None, keyjar=None, verify_ssl=True, conf
# just ignore verify_ssl until it goes away
self.verify_ssl = self.httpc_params.get("verify", True)
- def do_request(self, request_type, response_body_type="", request_args=None, **kwargs):
+ def do_request(self,
+ request_type: str,
+ response_body_type: Optional[str] = "",
+ request_args: Optional[dict] = None,
+ behaviour_args: Optional[dict] = None, **kwargs):
_srv = self._service[request_type]
_info = _srv.get_request_parameters(request_args=request_args, **kwargs)
@@ -93,8 +102,13 @@ def set_client_id(self, client_id):
self.client_id = client_id
self._service_context.set('client_id', client_id)
- def get_response(self, service, url, method="GET", body=None, response_body_type="",
- headers=None, **kwargs):
+ def get_response(self,
+ service: Service,
+ url: str,
+ method: Optional[str] = "GET",
+ body: Optional[dict] = None,
+ response_body_type: Optional[str] = "",
+ headers: Optional[dict] = None, **kwargs):
"""
:param url:
@@ -129,8 +143,13 @@ def get_response(self, service, url, method="GET", body=None, response_body_type
return self.parse_request_response(service, resp,
response_body_type, **kwargs)
- def service_request(self, service, url, method="GET", body=None,
- response_body_type="", headers=None, **kwargs):
+ def service_request(self,
+ service: Service,
+ url: str,
+ method: Optional[str] = "GET",
+ body: Optional[dict] = None,
+ response_body_type: Optional[str] = "",
+ headers: Optional[dict] = None, **kwargs) -> Message:
"""
The method that sends the request and handles the response returned.
This assumes that the response arrives in the HTTP response.
@@ -249,3 +268,27 @@ def parse_request_response(self, service, reqresp, response_body_type='',
reqresp.text))
raise OidcServiceError("HTTP ERROR: %s [%s] on %s" % (
reqresp.text, reqresp.status_code, reqresp.url))
+
+
+def dynamic_provider_info_discovery(client: Client, behaviour_args: Optional[dict]=None):
+ """
+ This is about performing dynamic Provider Info discovery
+
+ :param behaviour_args:
+ :param client: A :py:class:`oidcrp.oidc.Client` instance
+ """
+ try:
+ client.get_service('provider_info')
+ except KeyError:
+ raise ConfigurationError(
+ 'Can not do dynamic provider info discovery')
+ else:
+ _context = client.client_get("service_context")
+ try:
+ _context.set('issuer', _context.config['srv_discovery_url'])
+ except KeyError:
+ pass
+
+ response = client.do_request('provider_info', behaviour_args=behaviour_args)
+ if is_error_message(response):
+ raise OidcServiceError(response['error'])
diff --git a/src/oidcrp/oauth2/authorization.py b/src/oidcrp/oauth2/authorization.py
index 366b223..d82c8e1 100644
--- a/src/oidcrp/oauth2/authorization.py
+++ b/src/oidcrp/oauth2/authorization.py
@@ -7,7 +7,7 @@
from oidcmsg.time_util import time_sans_frac
from oidcrp.oauth2.utils import get_state_parameter
-from oidcrp.oauth2.utils import pick_redirect_uris
+from oidcrp.oauth2.utils import pre_construct_pick_redirect_uri
from oidcrp.oauth2.utils import set_state_parameter
from oidcrp.service import Service
@@ -32,7 +32,7 @@ class Authorization(Service):
def __init__(self, client_get, client_authn_factory=None, conf=None):
Service.__init__(self, client_get,
client_authn_factory=client_authn_factory, conf=conf)
- self.pre_construct.extend([pick_redirect_uris, set_state_parameter])
+ self.pre_construct.extend([pre_construct_pick_redirect_uri, set_state_parameter])
self.post_construct.append(self.store_auth_request)
def update_service_context(self, resp, key='', **kwargs):
diff --git a/src/oidcrp/oauth2/utils.py b/src/oidcrp/oauth2/utils.py
index 2393766..d1c4038 100644
--- a/src/oidcrp/oauth2/utils.py
+++ b/src/oidcrp/oauth2/utils.py
@@ -1,4 +1,14 @@
+import logging
+from typing import Optional
+from typing import Union
+
from oidcmsg.exception import MissingParameter
+from oidcmsg.exception import MissingRequiredAttribute
+from oidcmsg.message import Message
+
+from oidcrp.service import Service
+
+logger = logging.getLogger(__name__)
def get_state_parameter(request_args, kwargs):
@@ -14,35 +24,48 @@ def get_state_parameter(request_args, kwargs):
return _state
-def pick_redirect_uris(request_args=None, service=None, **kwargs):
- """Pick one redirect_uri base on response_mode out of a list of such."""
- _context = service.client_get("service_context")
+def pick_redirect_uri(context,
+ request_args: Optional[Union[Message, dict]] = None,
+ response_type: Optional[str] = ''):
+ if request_args is None:
+ request_args = {}
if 'redirect_uri' in request_args:
- return request_args, {}
+ return request_args["redirect_uri"]
- _callback = _context.callback
- if _callback:
- try:
- _response_type = request_args['response_type']
- except KeyError:
- _response_type = _context.behaviour['response_types'][0]
- request_args['response_type'] = _response_type
+ if context.redirect_uris:
+ redirect_uri = context.redirect_uris[0]
+ elif context.callback:
+ if not response_type:
+ _conf_resp_types = context.behaviour.get('response_types', [])
+ response_type = request_args.get('response_type')
+ if not response_type and _conf_resp_types:
+ response_type = _conf_resp_types[0]
- try:
- _response_mode = request_args['response_mode']
- except KeyError:
- _response_mode = ''
+ _response_mode = request_args.get('response_mode')
- if _response_mode == 'form_post':
- request_args['redirect_uri'] = _callback['form_post']
- elif _response_type == 'code':
- request_args['redirect_uri'] = _callback['code']
+ if _response_mode == 'form_post' or response_type == ["form_post"]:
+ redirect_uri = context.callback['form_post']
+ elif response_type == 'code' or response_type == ["code"]:
+ redirect_uri = context.callback['code']
else:
- request_args['redirect_uri'] = _callback['implicit']
+ redirect_uri = context.callback['implicit']
+
+ logger.debug(
+ f"pick_redirect_uris: response_type={response_type}, response_mode={_response_mode}, "
+ f"redirect_uri={redirect_uri}")
else:
- request_args['redirect_uri'] = _context.redirect_uris[0]
+ logger.error("No redirect_uri")
+ raise MissingRequiredAttribute('redirect_uri')
+ return redirect_uri
+
+
+def pre_construct_pick_redirect_uri(request_args: Optional[Union[Message, dict]] = None,
+ service: Optional[Service] = None, **kwargs):
+ _context = service.client_get("service_context")
+ request_args["redirect_uri"] = pick_redirect_uri(_context,
+ request_args=request_args)
return request_args, {}
diff --git a/src/oidcrp/oidc/__init__.py b/src/oidcrp/oidc/__init__.py
index cb000e4..df9db5e 100755
--- a/src/oidcrp/oidc/__init__.py
+++ b/src/oidcrp/oidc/__init__.py
@@ -2,6 +2,7 @@
import logging
from oidcrp.client_auth import BearerHeader
+from oidcrp.oidc.registration import CALLBACK_URIS
try:
from json import JSONDecodeError
@@ -112,6 +113,15 @@ def __init__(self, client_authn_factory=None,
keyjar=keyjar, verify_ssl=verify_ssl, config=config,
httplib=httplib, services=_srvs, httpc_params=httpc_params)
+ _context = self.get_service_context()
+ if _context.callback is None:
+ _context.callback = {}
+
+ for _cb in CALLBACK_URIS:
+ _uri = config.get(_cb)
+ if _uri:
+ _context.callback[_cb] = _uri
+
def fetch_distributed_claims(self, userinfo, callback=None):
"""
diff --git a/src/oidcrp/oidc/access_token.py b/src/oidcrp/oidc/access_token.py
index 03a87f1..7828b69 100644
--- a/src/oidcrp/oidc/access_token.py
+++ b/src/oidcrp/oidc/access_token.py
@@ -1,7 +1,9 @@
import logging
from typing import Optional
+from typing import Union
from oidcmsg import oidc
+from oidcmsg.message import Message
from oidcmsg.oidc import verified_claim_name
from oidcmsg.time_util import time_sans_frac
@@ -26,7 +28,9 @@ def __init__(self,
access_token.AccessToken.__init__(self, client_get,
client_authn_factory=client_authn_factory, conf=conf)
- def gather_verify_arguments(self):
+ def gather_verify_arguments(self,
+ response: Optional[Union[dict, Message]] = None,
+ behaviour_args: Optional[dict] = None):
"""
Need to add some information before running verify()
diff --git a/src/oidcrp/oidc/authorization.py b/src/oidcrp/oidc/authorization.py
index b95601b..5845640 100644
--- a/src/oidcrp/oidc/authorization.py
+++ b/src/oidcrp/oidc/authorization.py
@@ -1,14 +1,18 @@
import logging
+from typing import Optional
+from typing import Union
+from oidcmsg import oauth2
from oidcmsg import oidc
+from oidcmsg.exception import MissingRequiredAttribute
+from oidcmsg.message import Message
from oidcmsg.oidc import make_openid_request
from oidcmsg.oidc import verified_claim_name
from oidcmsg.time_util import time_sans_frac
from oidcmsg.time_util import utc_time_sans_frac
-from oidcrp.exception import ParameterError
from oidcrp.oauth2 import authorization
-from oidcrp.oauth2.utils import pick_redirect_uris
+from oidcrp.oauth2.utils import pre_construct_pick_redirect_uri
from oidcrp.oidc import IDT2REG
from oidcrp.oidc.utils import construct_request_uri
from oidcrp.oidc.utils import request_object_encryption
@@ -28,7 +32,7 @@ def __init__(self, client_get, client_authn_factory=None, conf=None):
authorization.Authorization.__init__(self, client_get,
client_authn_factory, conf=conf)
self.default_request_args = {'scope': ['openid']}
- self.pre_construct = [self.set_state, pick_redirect_uris,
+ self.pre_construct = [self.set_state, pre_construct_pick_redirect_uri,
self.oidc_pre_construct]
self.post_construct = [self.oidc_post_construct]
@@ -47,25 +51,33 @@ def set_state(self, request_args, **kwargs):
def update_service_context(self, resp, key='', **kwargs):
_context = self.client_get("service_context")
- try:
- _idt = resp[verified_claim_name('id_token')]
- except KeyError:
- pass
- else:
- # If there is a verified ID Token then we have to do nonce
- # verification
- try:
- if _context.state.get_state_by_nonce(_idt['nonce']) != key:
- raise ParameterError('Someone has messed with "nonce"')
- except KeyError:
- raise ValueError('Missing nonce value')
-
- _context.state.store_sub2state(_idt['sub'], key)
if 'expires_in' in resp:
resp['__expires_at'] = time_sans_frac() + int(resp['expires_in'])
_context.state.store_item(resp.to_json(), 'auth_response', key)
+ def get_request_from_response(self, response):
+ _context = self.client_get("service_context")
+ return _context.state.get_item(oauth2.AuthorizationRequest, 'auth_request',
+ response["state"])
+
+ def post_parse_response(self, response, **kwargs):
+ response = authorization.Authorization.post_parse_response(self, response, **kwargs)
+
+ _idt = response.get(verified_claim_name('id_token'))
+ if _idt:
+ # If there is a verified ID Token then we have to do nonce
+ # verification.
+ _request = self.get_request_from_response(response)
+ _req_nonce = _request.get('nonce')
+ if _req_nonce:
+ _id_token_nonce = _idt.get('nonce')
+ if not _id_token_nonce:
+ raise MissingRequiredAttribute('nonce')
+ elif _req_nonce != _id_token_nonce:
+ raise ValueError('Invalid nonce')
+ return response
+
def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs):
_context = self.client_get("service_context")
if request_args is None:
@@ -148,9 +160,9 @@ def store_request_on_file(self, req, **kwargs):
fid.close()
return _webname
- def construct_request_parameter(self, req, request_method, audience=None, expires_in=0,
+ def construct_request_parameter(self, req, request_param, audience=None, expires_in=0,
**kwargs):
- """Construct a request parameter"""
+ """ Construct a request parameter """
alg = self.get_request_object_signing_alg(**kwargs)
kwargs["request_object_signing_alg"] = alg
@@ -158,6 +170,9 @@ def construct_request_parameter(self, req, request_method, audience=None, expire
if "keys" not in kwargs and alg and alg != "none":
kwargs["keys"] = _context.keyjar
+ if alg == "none":
+ kwargs["keys"] = []
+
_srv_cntx = _context
# This is the issuer of the JWT, that is me !
@@ -176,12 +191,15 @@ def construct_request_parameter(self, req, request_method, audience=None, expire
if expires_in:
req['exp'] = utc_time_sans_frac() + int(expires_in)
- _req = make_openid_request(req, **kwargs)
+ _mor_args = {k: kwargs[k] for k in ["keys", "issuer", "request_object_signing_alg", "recv",
+ "with_jti", "lifetime"] if k in kwargs}
+
+ _req = make_openid_request(req, **_mor_args)
# Should the request be encrypted
_req = request_object_encryption(_req, _context, **kwargs)
- if request_method == "request":
+ if request_param == "request":
req["request"] = _req
else: # MUST be request_uri
req["request_uri"] = self.store_request_on_file(_req, **kwargs)
@@ -204,19 +222,28 @@ def oidc_post_construct(self, req, **kwargs):
if 'prompt' not in req:
req['prompt'] = 'consent'
- try:
- _request_method = kwargs['request_param']
- except KeyError:
- pass
- else:
- del kwargs['request_param']
+ _context.state.store_item(req, 'auth_request', req['state'])
- self.construct_request_parameter(req, _request_method, **kwargs)
+ _request_param = kwargs.get('request_param')
+ if _request_param:
+ del kwargs['request_param']
+ # local_dir, base_path
+ _config = _context.get('config')
+ kwargs["local_dir"] = _config.get('local_dir', './requests')
+ kwargs["base_path"] = _context.get('base_url') + '/' + "requests"
+ self.construct_request_parameter(req, _request_param, **kwargs)
+ # removed all arguments except request/request_uri and the required
+ _leave = ['request', 'request_uri']
+ _leave.extend(req.required_parameters())
+ _keys = [k for k in req.keys() if k not in _leave]
+ for k in _keys:
+ del req[k]
- _context.state.store_item(req, 'auth_request', req['state'])
return req
- def gather_verify_arguments(self):
+ def gather_verify_arguments(self,
+ response: Optional[Union[dict, Message]] = None,
+ behaviour_args: Optional[dict] = None):
"""
Need to add some information before running verify()
diff --git a/src/oidcrp/oidc/end_session.py b/src/oidcrp/oidc/end_session.py
index a58c91d..df4d2d4 100644
--- a/src/oidcrp/oidc/end_session.py
+++ b/src/oidcrp/oidc/end_session.py
@@ -54,13 +54,17 @@ def get_id_token_hint(self, request_args=None, **kwargs):
def add_post_logout_redirect_uri(self, request_args=None, **kwargs):
if 'post_logout_redirect_uri' not in request_args:
- try:
- request_args[
- 'post_logout_redirect_uri'
- ] = self.client_get("service_context").register_args[
- 'post_logout_redirect_uris'][0]
- except KeyError:
- pass
+ _context = self.client_get("service_context")
+ _uri = _context.register_args.get('post_logout_redirect_uris')
+ if _uri:
+ if isinstance(_uri, str):
+ request_args['post_logout_redirect_uri'] = _uri
+ else: # assume list
+ request_args['post_logout_redirect_uri'] = _uri[0]
+ else:
+ _uri = _context.callback.get("post_logout_redirect_uris")
+ if _uri:
+ request_args['post_logout_redirect_uri'] = _uri[0]
return request_args, {}
diff --git a/src/oidcrp/oidc/registration.py b/src/oidcrp/oidc/registration.py
index 925cc66..06a3ea9 100644
--- a/src/oidcrp/oidc/registration.py
+++ b/src/oidcrp/oidc/registration.py
@@ -1,9 +1,12 @@
+import hashlib
import logging
+from typing import List
+from typing import Optional
+from cryptojwt.utils import as_bytes
from oidcmsg import oidc
from oidcmsg.oauth2 import ResponseMessage
-from oidcrp.oidc.provider_info_discovery import add_redirect_uris
from oidcrp.service import Service
__author__ = 'Roland Hedberg'
@@ -37,19 +40,133 @@ def response_types_to_grant_types(response_types):
return list(_res)
-def add_request_uri(request_args=None, service=None, **kwargs):
- _context = service.client_get("service_context")
- if _context.requests_dir:
- _pi = _context.provider_info
- if _pi:
- _req = _pi.get('require_request_uri_registration', False)
- if _req is True:
- request_args['request_uris'] = _context.generate_request_uris(_context.requests_dir)
+def create_callbacks(issuer: str,
+ hash_seed: str,
+ base_url: str,
+ code: Optional[bool] = False,
+ implicit: Optional[bool] = False,
+ form_post: Optional[bool] = False,
+ request_uris: Optional[bool] = False,
+ backchannel_logout_uri: Optional[bool] = False,
+ frontchannel_logout_uri: Optional[bool] = False):
+ """
+ To mitigate some security issues the redirect_uris should be OP/AS
+ specific. This method creates a set of redirect_uris unique to the
+ OP/AS.
+
+ :param frontchannel_logout_uri: Whether a front-channel logout uri should be constructed
+ :param backchannel_logout_uri: Whether a back-channel logout uri should be constructed
+ :param request_uri: Whether a request_uri should be constructed
+ :param issuer: Issuer ID
+ :return: A set of redirect_uris
+ """
+ _hash = hashlib.sha256()
+ _hash.update(hash_seed)
+ _hash.update(as_bytes(issuer))
+ _hex = _hash.hexdigest()
+
+ res = {'__hex': _hex}
+
+ if code:
+ res['code'] = f"{base_url}/authz_cb/{_hex}"
+
+ if implicit:
+ res['implicit'] = f"{base_url}/authz_im_cb/{_hex}"
+
+ if form_post:
+ res['form_post'] = f"{base_url}/authz_fp_cb/{_hex}"
+
+ if request_uris:
+ res["request_uris"] = f"{base_url}/req_uri/{_hex}"
+
+ if backchannel_logout_uri or frontchannel_logout_uri:
+ res["post_logout_redirect_uris"] = [f"{base_url}/session_logout/{_hex}"]
+
+ if backchannel_logout_uri:
+ res["backchannel_logout_uri"] = f"{base_url}/bc_logout/{_hex}"
+
+ if frontchannel_logout_uri:
+ res["frontchannel_logout_uri"] = f"{base_url}/fc_logout/{_hex}"
+
+ logger.debug(f"Created callback URIs: {res}")
+ return res
+
+
+def _cmp(a, b):
+ if b is None: # Don't care about the value as long as there is one
+ return True
+ elif isinstance(a, str) and a == b:
+ return True
+ elif isinstance(a, list) and b in a:
+ return True
+
+ return a == b
- return request_args, {}
+def _in_config_or_client_preferences(config, attr, val):
+ _val = config.get("client_preferences", {}).get(attr)
+ if _cmp(_val, val):
+ return True
+ _val = config.get(attr)
+ return _cmp(_val, val)
-def add_post_logout_redirect_uris(request_args=None, service=None, **kwargs):
+
+def add_callbacks(context, ignore: Optional[List[str]] = None):
+ if ignore is None:
+ ignore = []
+ _iss = context.get('issuer')
+
+ _uris = {}
+
+ _pi = context.get('provider_info')
+ _cp = context.config.get("client_preferences")
+
+ if "redirect_uris" not in ignore:
+ # code and/or implicit
+ if _in_config_or_client_preferences(context.config, "response_types", "code"):
+ _uris['code'] = True
+ for rt in ["id_token", "id_token token", "code id_token token", "code idtoken",
+ "code token"]:
+ if _in_config_or_client_preferences(context.config, "response_types", rt):
+ _uris["implicit"] = True
+ break
+
+ if "form_post" not in ignore:
+ if _in_config_or_client_preferences(context.config, "form_post_usable", True):
+ _uris["form_post"] = True
+
+ if "request_uris" not in ignore:
+ if 'require_request_uri_registration' in _pi and _in_config_or_client_preferences(
+ context.config, "request_uri_usable", True):
+ _uris['request_uris'] = True
+
+ if "frontchannel_logout_uri" not in ignore:
+ if 'frontchannel_logout_supported' in _pi and _in_config_or_client_preferences(
+ context.config, "frontchannel_logout_usable", True):
+ _uris["frontchannel_logout_uri"] = True
+
+ if "backchannel_logout_uri" not in ignore:
+ if 'backchannel_logout_supported' in _pi and _in_config_or_client_preferences(
+ context.config, "backchannel_logout_usable", True):
+ _uris["backchannel_logout_uri"] = True
+
+ callbacks = create_callbacks(_iss,
+ hash_seed=context.get('hash_seed'),
+ base_url=context.get("base_url"),
+ **_uris)
+ context.hash2issuer[callbacks['__hex']] = _iss
+
+ if "redirect_uris" not in ignore:
+ _redirect_uris = [v for k, v in callbacks.items() if k in ["code", "implicit", "form_post"]]
+ callbacks["redirect_uris"] = _redirect_uris
+ context.set('callback', callbacks)
+
+
+CALLBACK_URIS = ["post_logout_redirect_uris", "backchannel_logout_uri", "frontchannel_logout_uri",
+ "request_uris", 'redirect_uris']
+
+
+def add_callback_uris(request_args=None, service=None, **kwargs):
"""
:param request_args:
@@ -59,10 +176,17 @@ def add_post_logout_redirect_uris(request_args=None, service=None, **kwargs):
:return:
"""
- if "post_logout_redirect_uris" not in request_args:
- _uris = service.client_get("service_context").register_args.get("post_logout_redirect_uris")
- if _uris:
- request_args["post_logout_redirect_uris"] = _uris
+ _context = service.client_get("service_context")
+ _ignore = [k for k in list(request_args.keys()) if k in CALLBACK_URIS]
+ add_callbacks(_context, ignore=_ignore)
+ for _key in CALLBACK_URIS:
+ _req_val = request_args.get(_key)
+ if not _req_val:
+ _uri = _context.register_args.get(_key)
+ if not _uri:
+ _uri = _context.callback.get(_key)
+ if _uri:
+ request_args[_key] = _uri
return request_args, {}
@@ -107,8 +231,8 @@ def __init__(self, client_get, client_authn_factory=None, conf=None):
client_authn_factory=client_authn_factory,
conf=conf)
self.pre_construct = [self.add_client_behaviour_preference,
- add_redirect_uris, add_request_uri,
- add_post_logout_redirect_uris,
+ #add_redirect_uris,
+ add_callback_uris,
add_jwks_uri_or_jwks]
self.post_construct = [self.oidc_post_construct]
diff --git a/src/oidcrp/oidc/userinfo.py b/src/oidcrp/oidc/userinfo.py
index 2b5852d..6c93b80 100644
--- a/src/oidcrp/oidc/userinfo.py
+++ b/src/oidcrp/oidc/userinfo.py
@@ -1,4 +1,6 @@
import logging
+from typing import Optional
+from typing import Union
from oidcmsg import oidc
from oidcmsg.exception import MissingSigningKey
@@ -105,10 +107,16 @@ def post_parse_response(self, response, **kwargs):
"url": spec["endpoint"]
}
+ # Extension point
+ for meth in self.post_parse_process:
+ response = meth(response, _state_interface, kwargs['state'])
+
_state_interface.store_item(response, 'user_info', kwargs['state'])
return response
- def gather_verify_arguments(self):
+ def gather_verify_arguments(self,
+ response: Optional[Union[dict, Message]] = None,
+ behaviour_args: Optional[dict] = None):
"""
Need to add some information before running verify()
diff --git a/src/oidcrp/oidc/utils.py b/src/oidcrp/oidc/utils.py
index 40b9b90..9605733 100644
--- a/src/oidcrp/oidc/utils.py
+++ b/src/oidcrp/oidc/utils.py
@@ -83,5 +83,8 @@ def construct_request_uri(local_dir, base_path, **kwargs):
while os.path.exists(filename):
_name = rndstr(10)
filename = os.path.join(_filedir, _name)
- _webname = "%s%s" % (_webpath, _name)
+ if _webpath.endswith("/"):
+ _webname = f"{_webpath}{_name}"
+ else:
+ _webname = f"{_webpath}/{_name}"
return filename, _webname
diff --git a/src/oidcrp/rp_handler.py b/src/oidcrp/rp_handler.py
index e359939..957ab2d 100644
--- a/src/oidcrp/rp_handler.py
+++ b/src/oidcrp/rp_handler.py
@@ -1,4 +1,4 @@
-import hashlib
+from inspect import currentframe
import logging
import sys
import traceback
@@ -10,6 +10,7 @@
from cryptojwt.utils import as_bytes
from oidcmsg import verified_claim_name
from oidcmsg.exception import MessageException
+from oidcmsg.exception import MissingRequiredAttribute
from oidcmsg.exception import NotForMe
from oidcmsg.oauth2 import ResponseMessage
from oidcmsg.oauth2 import is_error_message
@@ -18,6 +19,7 @@
from oidcmsg.oidc import AuthorizationResponse
from oidcmsg.oidc import Claims
from oidcmsg.oidc import OpenIDSchema
+from oidcmsg.oidc import RegistrationRequest
from oidcmsg.oidc.session import BackChannelLogoutRequest
from oidcmsg.time_util import time_sans_frac
@@ -27,8 +29,9 @@
from .defaults import DEFAULT_RP_KEY_DEFS
from .exception import OidcServiceError
from .oauth2 import Client
+from .oauth2 import dynamic_provider_info_discovery
+from .oauth2.utils import pick_redirect_uri
from .util import add_path
-from .util import dynamic_provider_info_discovery
from .util import load_registration_response
from .util import rndstr
@@ -85,7 +88,7 @@ def __init__(self, base_url, client_configs=None, services=None, keyjar=None,
else:
self.client_configs = client_configs
- # keep track on which RP instance that serves with OP
+ # keep track on which RP instance that serves which OP
self.issuer2rp = {}
self.hash2issuer = {}
self.httplib = http_lib
@@ -149,6 +152,9 @@ def init_client(self, issuer):
:param issuer: An issuer ID
:return: A Client instance
"""
+
+ logger.debug(20 * "*" + " init_client " + 20 * "*")
+
try:
_cnf = self.pick_config(issuer)
except KeyError:
@@ -180,16 +186,21 @@ def init_client(self, issuer):
_context.jwks_uri = self.jwks_uri
return client
- def do_provider_info(self, client=None, state=''):
+ def do_provider_info(self,
+ client: Optional[Client]=None,
+ state: Optional[str]='',
+ behaviour_args: Optional[dict]=None) -> str:
"""
Either get the provider info from configuration or through dynamic
discovery.
+ :param behaviour_args:
:param client: A Client instance
:param state: A key by which the state of the session can be
retrieved
:return: issuer ID
"""
+ logger.debug(20 * "*" + " do_provider_info " + 20 * "*")
if not client:
if state:
@@ -199,7 +210,7 @@ def do_provider_info(self, client=None, state=''):
_context = client.client_get("service_context")
if not _context.get('provider_info'):
- dynamic_provider_info_discovery(client)
+ dynamic_provider_info_discovery(client, behaviour_args=behaviour_args)
return _context.get('provider_info')['issuer']
else:
_pi = _context.get('provider_info')
@@ -235,15 +246,23 @@ def do_provider_info(self, client=None, state=''):
except KeyError:
return _context.get('issuer')
- def do_client_registration(self, client=None, iss_id='', state=''):
+ def do_client_registration(self, client=None,
+ iss_id: Optional[str] = '',
+ state: Optional[str] = '',
+ request_args: Optional[dict] = None,
+ behaviour_args: Optional[dict] = None):
"""
Prepare for and do client registration if configured to do so
+ :param iss_id: Issuer ID
+ :param behaviour_args: To fine tune behaviour
:param client: A Client instance
:param state: A key by which the state of the session can be
retrieved
"""
+ logger.debug(20 * "*" + " do_client_registration " + 20 * "*")
+
if not client:
if state:
client = self.get_client_from_session_key(state)
@@ -252,32 +271,41 @@ def do_client_registration(self, client=None, iss_id='', state=''):
_context = client.client_get("service_context")
_iss = _context.get('issuer')
- if not _context.get('redirect_uris'):
- # Create the necessary callback URLs
- # as a side effect self.hash2issuer is set
- callbacks = self.create_callbacks(_iss)
-
- _context.set('redirect_uris', [
- v for k, v in callbacks.items() if not k.startswith('__')])
- _context.set('callbacks', callbacks)
- else:
- self.hash2issuer[iss_id] = _iss
+ self.hash2issuer[iss_id] = _iss
# This should only be interesting if the client supports Single Log Out
- if _context.post_logout_redirect_uris is None:
- _context.post_logout_redirect_uris = [self.base_url]
+ # if _context.callback.get("post_logout_redirect_uri") is None:
+ # _context.callback["post_logout_redirect_uri"] = [self.base_url]
- if not _context.client_id:
- load_registration_response(client)
+ if not _context.client_id: # means I have to do dynamic client registration
+ if request_args is None:
+ request_args = {}
- def add_callbacks(self, service_context):
- _callbacks = self.create_callbacks(service_context.get('provider_info')['issuer'])
- service_context.set('redirect_uris', [
- v for k, v in _callbacks.items() if not k.startswith('__')])
- service_context.set('callbacks', _callbacks)
- return _callbacks
+ if behaviour_args:
+ _params = RegistrationRequest().parameters()
+ request_args.update({k: v for k, v in behaviour_args.items() if k in _params})
- def client_setup(self, iss_id='', user=''):
+ load_registration_response(client, request_args=request_args)
+
+ def do_webfinger(self, user: str) -> Client:
+ """
+ Does OpenID Provider Issuer discovery using webfinger.
+
+ :param user: Identifier for the target End-User that is the subject of the discovery
+ request.
+ :return: A Client instance
+ """
+
+ logger.debug(20 * "*" + " do_webfinger " + 20 * "*")
+
+ temporary_client = self.init_client('')
+ temporary_client.do_request('webfinger', resource=user)
+ return temporary_client
+
+ def client_setup(self,
+ iss_id: Optional[str] = '',
+ user: Optional[str] = '',
+ behaviour_args: Optional[dict] = None) -> Client:
"""
First if no issuer ID is given then the identifier for the user is
used by the webfinger service to try to find the issuer ID.
@@ -286,11 +314,14 @@ def client_setup(self, iss_id='', user=''):
the necessary information for the client to be able to communicate
with the OP/AS that has the provided issuer ID.
+ :param behaviour_args: To fine tune behaviour
:param iss_id: The issuer ID
:param user: A user identifier
:return: A :py:class:`oidcrp.oidc.Client` instance
"""
+ logger.debug(20 * "*" + " client_setup " + 20 * "*")
+
logger.info('client_setup: iss_id={}, user={}'.format(iss_id, user))
if not iss_id:
@@ -298,8 +329,7 @@ def client_setup(self, iss_id='', user=''):
raise ValueError('Need issuer or user')
logger.debug("Connecting to previously unknown OP")
- temporary_client = self.init_client('')
- temporary_client.do_request('webfinger', resource=user)
+ temporary_client = self.do_webfinger(user)
else:
temporary_client = None
@@ -315,46 +345,39 @@ def client_setup(self, iss_id='', user=''):
return client
logger.debug("Get provider info")
- issuer = self.do_provider_info(client)
+ issuer = self.do_provider_info(client, behaviour_args=behaviour_args)
logger.debug("Do client registration")
- self.do_client_registration(client, iss_id)
+ self.do_client_registration(client, iss_id, behaviour_args=behaviour_args)
self.issuer2rp[issuer] = client
return client
- def create_callbacks(self, issuer):
- """
- To mitigate some security issues the redirect_uris should be OP/AS
- specific. This method creates a set of redirect_uris unique to the
- OP/AS.
-
- :param issuer: Issuer ID
- :return: A set of redirect_uris
- """
- _hash = hashlib.sha256()
- _hash.update(self.hash_seed)
- _hash.update(as_bytes(issuer))
- _hex = _hash.hexdigest()
- self.hash2issuer[_hex] = issuer
- return {
- 'code': "{}/authz_cb/{}".format(self.base_url, _hex),
- 'implicit': "{}/authz_im_cb/{}".format(self.base_url, _hex),
- 'form_post': "{}/authz_fp_cb/{}".format(self.base_url, _hex),
- '__hex': _hex
- }
+ def _get_response_type(self, context, req_args: Optional[dict] = None):
+ if req_args:
+ return req_args.get("response_type", context.get('behaviour')['response_types'][0])
+ else:
+ return context.get('behaviour')['response_types'][0]
- def init_authorization(self, client=None, state='', req_args=None):
+ def init_authorization(self,
+ client: Optional[Client] = None,
+ state: Optional[str] = '',
+ req_args: Optional[dict] = None,
+ behaviour_args: Optional[dict] = None) -> dict:
"""
Constructs the URL that will redirect the user to the authorization
endpoint of the OP/AS.
+ :param behaviour_args:
+ :param state:
:param client: A Client instance
:param req_args: Non-default Request arguments
:return: A dictionary with 2 keys: **url** The authorization redirect
URL and **state** the key to the session information in the
state data store.
"""
+
+ logger.debug(20 * "*" + " init_authorization " + 20 * "*")
if not client:
if state:
client = self.get_client_from_session_key(state)
@@ -364,10 +387,13 @@ def init_authorization(self, client=None, state='', req_args=None):
_context = client.client_get("service_context")
_nonce = rndstr(24)
+ _response_type = self._get_response_type(_context, req_args)
request_args = {
- 'redirect_uri': _context.get('redirect_uris')[0],
+ 'redirect_uri': pick_redirect_uri(_context,
+ request_args=req_args,
+ response_type=_response_type),
'scope': _context.get('behaviour')['scope'],
- 'response_type': _context.get('behaviour')['response_types'][0],
+ 'response_type': _response_type,
'nonce': _nonce
}
@@ -387,12 +413,16 @@ def init_authorization(self, client=None, state='', req_args=None):
logger.debug('Authorization request args: {}'.format(request_args))
+ # if behaviour_args and "request_param" not in behaviour_args:
+ # _pi = _context.get("provider_info")
+
_srv = client.get_service('authorization')
- _info = _srv.get_request_parameters(request_args=request_args)
+ _info = _srv.get_request_parameters(request_args=request_args,
+ behaviour_args=behaviour_args)
logger.debug('Authorization info: {}'.format(_info))
return {'url': _info['url'], 'state': _state}
- def begin(self, issuer_id='', user_id=''):
+ def begin(self, issuer_id='', user_id='', req_args=None, behaviour_args=None):
"""
This is the first of the 3 high level methods that most users of this
library should confine them self to use.
@@ -401,6 +431,8 @@ def begin(self, issuer_id='', user_id=''):
Once it has the client it will construct an Authorization
request.
+ :param behaviour_args:
+ :param req_args:
:param issuer_id: Issuer ID
:param user_id: A user identifier
:return: A dictionary containing **url** the URL that will redirect the
@@ -409,10 +441,10 @@ def begin(self, issuer_id='', user_id=''):
"""
# Get the client instance that has been assigned to this issuer
- client = self.client_setup(issuer_id, user_id)
+ client = self.client_setup(issuer_id, user_id, behaviour_args=behaviour_args)
try:
- res = self.init_authorization(client)
+ res = self.init_authorization(client, req_args=req_args, behaviour_args=behaviour_args)
except Exception:
message = traceback.format_exception(*sys.exc_info())
logger.error(message)
@@ -447,7 +479,8 @@ def get_client_authn_method(client, endpoint):
"""
if endpoint == 'token_endpoint':
try:
- am = client.client_get("service_context").get('behaviour')['token_endpoint_auth_method']
+ am = client.client_get("service_context").get('behaviour')[
+ 'token_endpoint_auth_method']
except KeyError:
return ''
else:
@@ -456,7 +489,7 @@ def get_client_authn_method(client, endpoint):
else: # a list
return am[0]
- def get_access_token(self, state, client: Optional[Client] = None):
+ def get_tokens(self, state, client: Optional[Client] = None):
"""
Use the 'accesstoken' service to get an access token from the OP/AS.
@@ -466,7 +499,7 @@ def get_access_token(self, state, client: Optional[Client] = None):
:return: A :py:class:`oidcmsg.oidc.AccessTokenResponse` or
:py:class:`oidcmsg.oauth2.AuthorizationResponse`
"""
- logger.debug('get_accesstoken')
+ logger.debug(20 * "*" + " get_tokens " + 20 * "*")
if client is None:
client = self.get_client_from_session_key(state)
@@ -512,6 +545,9 @@ def refresh_access_token(self, state, client=None, scope=''):
:param scope: What the returned token should be valid for.
:return: A :py:class:`oidcmsg.oidc.AccessTokenResponse` instance
"""
+
+ logger.debug(20 * "*" + " refresh_access_token " + 20 * "*")
+
if scope:
req_args = {'scope': scope}
else:
@@ -548,6 +584,9 @@ def get_user_info(self, state, client=None, access_token='',
:param kwargs: Extra keyword arguments
:return: A :py:class:`oidcmsg.oidc.OpenIDSchema` instance
"""
+
+ logger.debug(20 * "*" + " get_user_info " + 20 * "*")
+
if client is None:
client = self.get_client_from_session_key(state)
@@ -574,33 +613,36 @@ def userinfo_in_id_token(id_token):
:param id_token: An :py:class:`oidcmsg.oidc.IDToken` instance
:return: A dictionary with user information
"""
- res = dict([(k, id_token[k]) for k in OpenIDSchema.c_param.keys() if
- k in id_token])
+ res = dict([(k, id_token[k]) for k in OpenIDSchema.c_param.keys() if k in id_token])
res.update(id_token.extra())
return res
- def finalize_auth(self, client, issuer, response):
+ def finalize_auth(self, client, issuer: str, response: dict,
+ behaviour_args: Optional[dict] = None):
"""
Given the response returned to the redirect_uri, parse and verify it.
+ :param behaviour_args: For fine tuning behaviour
:param client: A Client instance
:param issuer: An Issuer ID
:param response: The authorization response as a dictionary
:return: An :py:class:`oidcmsg.oidc.AuthorizationResponse` or
:py:class:`oidcmsg.oauth2.AuthorizationResponse` instance.
"""
+
+ logger.debug(20 * "*" + " finalize_auth " + 20 * "*")
+
_srv = client.get_service('authorization')
try:
- authorization_response = _srv.parse_response(response,
- sformat='dict')
+ authorization_response = _srv.parse_response(response, sformat='dict',
+ behaviour_args=behaviour_args)
except Exception as err:
logger.error('Parsing authorization_response: {}'.format(err))
message = traceback.format_exception(*sys.exc_info())
logger.error(message)
raise
else:
- logger.debug(
- 'Authz response: {}'.format(authorization_response.to_dict()))
+ logger.debug('Authz response: {}'.format(authorization_response.to_dict()))
if is_error_message(authorization_response):
return authorization_response
@@ -621,13 +663,16 @@ def finalize_auth(self, client, issuer, response):
authorization_response['state'])
return authorization_response
- def get_access_and_id_token(self, authorization_response=None, state='',
- client=None):
+ def get_access_and_id_token(self, authorization_response=None,
+ state: Optional[str] = '',
+ client: Optional[object] = None,
+ behaviour_args: Optional[dict] = None):
"""
There are a number of services where access tokens and ID tokens can
occur in the response. This method goes through the possible places
based on the response_type the client uses.
+ :param behaviour_args: For fine tuning behaviour
:param authorization_response: The Authorization response
:param state: The state key (the state parameter in the
authorization request)
@@ -636,6 +681,8 @@ def get_access_and_id_token(self, authorization_response=None, state='',
was returned otherwise None.
"""
+ logger.debug(20 * "*" + " get_access_and_id_token " + 20 * "*")
+
if client is None:
client = self.get_client_from_session_key(state)
@@ -652,8 +699,7 @@ def get_access_and_id_token(self, authorization_response=None, state='',
if not state:
state = authorization_response['state']
- authreq = _context.state.get_item(
- AuthorizationRequest, 'auth_request', state)
+ authreq = _context.state.get_item(AuthorizationRequest, 'auth_request', state)
_resp_type = set(authreq['response_type'])
access_token = None
@@ -665,24 +711,31 @@ def get_access_and_id_token(self, authorization_response=None, state='',
if _resp_type in [{'token'}, {'id_token', 'token'}, {'code', 'token'},
{'code', 'id_token', 'token'}]:
access_token = authorization_response["access_token"]
- elif _resp_type in [{'code'}, {'code', 'id_token'}]:
+ if behaviour_args:
+ if behaviour_args.get("collect_tokens", False):
+ # get what you can from the token endpoint
+ token_resp = self.get_tokens(state, client=client)
+ if is_error_message(token_resp):
+ return False, "Invalid response %s." % token_resp["error"]
+ # Now which access_token should I use
+ access_token = token_resp["access_token"]
+ # May or may not get an ID Token
+ id_token = token_resp.get('__verified_id_token')
+ elif _resp_type in [{'code'}, {'code', 'id_token'}]:
# get the access token
- token_resp = self.get_access_token(state, client=client)
+ token_resp = self.get_tokens(state, client=client)
if is_error_message(token_resp):
return False, "Invalid response %s." % token_resp["error"]
access_token = token_resp["access_token"]
-
- try:
- id_token = token_resp['__verified_id_token']
- except KeyError:
- pass
+ # May or may not get an ID Token
+ id_token = token_resp.get('__verified_id_token')
return {'access_token': access_token, 'id_token': id_token}
# noinspection PyUnusedLocal
- def finalize(self, issuer, response):
+ def finalize(self, issuer, response, behaviour_args: Optional[dict] = None):
"""
The third of the high level methods that a user of this Class should
know about.
@@ -690,6 +743,7 @@ def finalize(self, issuer, response):
callback URL there might be a number of services that the client should
use. Which one those are are defined by the client configuration.
+ :param behaviour_args: For fine tuning
:param issuer: Who sent the response
:param response: The Authorization response as a dictionary
:returns: A dictionary with two claims:
@@ -701,6 +755,9 @@ def finalize(self, issuer, response):
client = self.issuer2rp[issuer]
+ if behaviour_args:
+ logger.debug(f"Finalize behaviour args: {behaviour_args}")
+
authorization_response = self.finalize_auth(client, issuer, response)
if is_error_message(authorization_response):
return {
@@ -709,8 +766,10 @@ def finalize(self, issuer, response):
}
_state = authorization_response['state']
- token = self.get_access_and_id_token(authorization_response,
- state=_state, client=client)
+ token = self.get_access_and_id_token(authorization_response, state=_state, client=client,
+ behaviour_args=behaviour_args)
+ _id_token = token.get("id_token")
+ logger.debug(f"ID Token: {_id_token}")
if client.client_get("service", "userinfo") and token['access_token']:
inforesp = self.get_user_info(
@@ -723,8 +782,8 @@ def finalize(self, issuer, response):
'state': _state
}
- elif token['id_token']: # look for it in the ID Token
- inforesp = self.userinfo_in_id_token(token['id_token'])
+ elif _id_token: # look for it in the ID Token
+ inforesp = self.userinfo_in_id_token(_id_token)
else:
inforesp = {}
@@ -732,8 +791,7 @@ def finalize(self, issuer, response):
_context = client.client_get("service_context")
try:
- _sid_support = _context.get('provider_info')[
- 'backchannel_logout_session_supported']
+ _sid_support = _context.get('provider_info')['backchannel_logout_session_supported']
except KeyError:
try:
_sid_support = _context.get('provider_info')[
@@ -741,21 +799,25 @@ def finalize(self, issuer, response):
except:
_sid_support = False
- if _sid_support:
+ if _sid_support and _id_token:
try:
- sid = token['id_token']['sid']
+ sid = _id_token['sid']
except KeyError:
pass
else:
_context.state.store_sid2state(sid, _state)
- _context.state.store_sub2state(token['id_token']['sub'], _state)
+ if _id_token:
+ _context.state.store_sub2state(_id_token['sub'], _state)
+ else:
+ _context.state.store_sub2state(inforesp['sub'], _state)
return {
'userinfo': inforesp,
'state': authorization_response['state'],
'token': token['access_token'],
- 'id_token': token['id_token']
+ 'id_token': _id_token,
+ 'session_state': authorization_response.get('session_state', '')
}
def has_active_authentication(self, state):
@@ -823,7 +885,9 @@ def get_valid_access_token(self, state):
else:
raise OidcServiceError('No valid access token')
- def logout(self, state, client=None, post_logout_redirect_uri=''):
+ def logout(self, state: str,
+ client: Optional[Client] = None,
+ post_logout_redirect_uri: Optional[str] = '') -> dict:
"""
Does a RP initiated logout from an OP. After logout the user will be
redirect by the OP to a URL of choice (post_logout_redirect_uri).
@@ -834,6 +898,9 @@ def logout(self, state, client=None, post_logout_redirect_uri=''):
should be used
:return: A US
"""
+
+ logger.debug(20 * "*" + " logout " + 20 * "*")
+
if client is None:
client = self.get_client_from_session_key(state)
@@ -852,8 +919,23 @@ def logout(self, state, client=None, post_logout_redirect_uri=''):
resp = srv.get_request_parameters(state=state,
request_args=request_args)
+ logger.debug(f"EndSession Request: {resp}")
return resp
+ def close(self, state: str,
+ issuer: Optional[str] = '',
+ post_logout_redirect_uri: Optional[str] = '') -> dict:
+
+ logger.debug(20 * "*" + " close " + 20 * "*")
+
+ if issuer:
+ client = self.issuer2rp[issuer]
+ else:
+ client = self.get_client_from_session_key(state)
+
+ return self.logout(state=state, client=client,
+ post_logout_redirect_uri=post_logout_redirect_uri)
+
def clear_session(self, state):
client = self.get_client_from_session_key(state)
client.client_get("service_context").state.remove_state(state)
@@ -867,8 +949,10 @@ def backchannel_logout(client, request='', request_args=None):
"""
if request:
req = BackChannelLogoutRequest().from_urlencoded(as_unicode(request))
- else:
+ elif request_args:
req = BackChannelLogoutRequest(**request_args)
+ else:
+ raise MissingRequiredAttribute('logout_token')
_context = client.client_get("service_context")
kwargs = {
@@ -879,10 +963,13 @@ def backchannel_logout(client, request='', request_args=None):
"id_token_signed_response_alg", "RS256")
}
+ logger.debug(f"(backchannel_logout) Verifying request using: {kwargs}")
try:
req.verify(**kwargs)
except (MessageException, ValueError, NotForMe) as err:
raise MessageException('Bogus logout request: {}'.format(err))
+ else:
+ logger.debug("Request verified OK")
# Find the subject through 'sid' or 'sub'
sub = req[verified_claim_name('logout_token')].get('sub')
diff --git a/src/oidcrp/service.py b/src/oidcrp/service.py
index 0e350e2..3eaa025 100644
--- a/src/oidcrp/service.py
+++ b/src/oidcrp/service.py
@@ -88,8 +88,9 @@ def __init__(self,
self.pre_construct = []
self.post_construct = []
self.construct_extra_headers = []
+ self.post_parse_process = []
- def gather_request_args(self, **kwargs):
+ def gather_request_args(self,**kwargs):
"""
Go through the attributes that the message class can contain and
add values if they are missing but exists in the client info or
@@ -204,8 +205,7 @@ def construct(self, request_args=None, **kwargs):
# run the pre_construct methods. Will return a possibly new
# set of request arguments but also a set of arguments to
# be used by the post_construct methods.
- request_args, post_args = self.do_pre_construct(request_args,
- **kwargs)
+ request_args, post_args = self.do_pre_construct(request_args, **kwargs)
# If 'state' appears among the keyword argument and is not
# expected to appear in the request, remove it.
@@ -222,6 +222,10 @@ def construct(self, request_args=None, **kwargs):
# message type
request = self.msg_type(**_args)
+ _behaviour_args = kwargs.get("behaviour_args")
+ if _behaviour_args:
+ post_args.update(_behaviour_args)
+
return self.do_post_construct(request, **post_args)
def init_authentication_method(self, request, authn_method,
@@ -340,7 +344,7 @@ def get_headers(self,
return _headers
def get_request_parameters(self, request_args=None, method="",
- request_body_type="", authn_method='', **kwargs):
+ request_body_type="", authn_method='', **kwargs) -> dict:
"""
Builds the request message and constructs the HTTP headers.
@@ -440,7 +444,9 @@ def post_parse_response(self, response, **kwargs):
"""
return response
- def gather_verify_arguments(self):
+ def gather_verify_arguments(self,
+ response: Optional[Union[dict, Message]] = None,
+ behaviour_args: Optional[dict] = None):
"""
Need to add some information before running verify()
@@ -497,7 +503,11 @@ def _do_response(self, info, sformat, **kwargs):
raise
return resp
- def parse_response(self, info, sformat="", state="", **kwargs):
+ def parse_response(self, info,
+ sformat: Optional[str] = "",
+ state: Optional[str] = "",
+ behaviour_args: Optional[dict] = None,
+ **kwargs):
"""
This the start of a pipeline that will:
@@ -509,8 +519,8 @@ def parse_response(self, info, sformat="", state="", **kwargs):
3 runs the do_post_parse_response method iff the response was not
an error response.
- :param info: The response, can be either in a JSON or an urlencoded
- format
+ :param behaviour_args:
+ :param info: The response, can be either in a JSON or an urlencoded format
:param sformat: Which serialization that was used
:param state: The state
:param kwargs: Extra key word arguments
@@ -550,11 +560,10 @@ def parse_response(self, info, sformat="", state="", **kwargs):
if is_error_message(resp):
LOGGER.debug('Error response: %s', resp)
else:
- vargs = self.gather_verify_arguments()
+ vargs = self.gather_verify_arguments(response=resp, behaviour_args=behaviour_args)
LOGGER.debug("Verify response with %s", vargs)
try:
- # verify the message. If something is wrong an exception is
- # thrown
+ # verify the message. If something is wrong an exception is thrown
resp.verify(**vargs)
except Exception as err:
LOGGER.error(
diff --git a/src/oidcrp/service_context.py b/src/oidcrp/service_context.py
index 0247162..7a6753c 100644
--- a/src/oidcrp/service_context.py
+++ b/src/oidcrp/service_context.py
@@ -11,6 +11,7 @@
from cryptojwt.utils import as_bytes
from oidcmsg.context import OidcContext
from oidcmsg.oidc import RegistrationRequest
+from oidcmsg.util import rndstr
from oidcrp.state_interface import StateInterface
@@ -97,10 +98,12 @@ class ServiceContext(OidcContext):
"client_secret_expires_at": 0,
'clock_skew': None,
"config": None,
+ "hash_seed": b'',
+ "hash2issuer": None,
"httpc_params": None,
'issuer': None,
"kid": None,
- "post_logout_redirect_uris": [],
+ "post_logout_redirect_uris": None,
'provider_info': None,
'redirect_uris': None,
"requests_dir": None,
@@ -126,6 +129,7 @@ def __init__(self, base_url="", keyjar=None, config=None, state=None, **kwargs):
self.client_preferences = {}
self.args = {}
self.add_on = {}
+ self.hash2issuer = {}
self.httpc_params = {}
self.issuer = ""
self.client_id = ""
@@ -145,7 +149,7 @@ def __init__(self, base_url="", keyjar=None, config=None, state=None, **kwargs):
'behaviour', 'callback', 'issuer']:
_val = config.get(param, _def_value[param])
self.set(param, _val)
- if param == 'client_secret':
+ if param == 'client_secret' and _val:
self.keyjar.add_symmetric('', _val)
if not self.issuer:
@@ -156,6 +160,12 @@ def __init__(self, base_url="", keyjar=None, config=None, state=None, **kwargs):
except KeyError:
self.clock_skew = 15
+ _seed = config.get("hash_seed")
+ if _seed:
+ self.hash_seed = as_bytes(_seed)
+ else:
+ self.hash_seed = as_bytes(rndstr(32))
+
for key, val in kwargs.items():
setattr(self, key, val)
diff --git a/src/oidcrp/state_interface.py b/src/oidcrp/state_interface.py
index a377da1..bfc7100 100644
--- a/src/oidcrp/state_interface.py
+++ b/src/oidcrp/state_interface.py
@@ -382,7 +382,7 @@ def remove_state(self, state):
:param state: Key to the state
"""
del self._db[state]
- refs = json.loads(self._db_db.get("ref{}ref".format(state)))
+ refs = json.loads(self._db.get("ref{}ref".format(state)))
if refs:
for xtyp, _val in refs.items():
del self._db[KEY_PATTERN[xtyp].format(_val)]
diff --git a/src/oidcrp/util.py b/src/oidcrp/util.py
index eb774cb..17e0d51 100755
--- a/src/oidcrp/util.py
+++ b/src/oidcrp/util.py
@@ -9,6 +9,7 @@
import ssl
import string
import sys
+from typing import Optional
from urllib.parse import parse_qs
from urllib.parse import urlsplit
from urllib.parse import urlunsplit
@@ -534,17 +535,17 @@ def add_path(url, path):
return '{}/{}'.format(url, path)
-def load_registration_response(client):
+def load_registration_response(client, request_args=None):
"""
If the client has been statically registered that information
must be provided during the configuration. If expected to be
- done dynamically. This method will do dynamic client registration.
+ done dynamically this method will do dynamic client registration.
:param client: A :py:class:`oidcrp.oidc.Client` instance
"""
if not client.client_get("service_context").get('client_id'):
try:
- response = client.do_request('registration')
+ response = client.do_request('registration', request_args=request_args)
except KeyError:
raise ConfigurationError('No registration info')
except Exception as err:
@@ -553,26 +554,3 @@ def load_registration_response(client):
else:
if 'error' in response:
raise OidcServiceError(response.to_json())
-
-
-def dynamic_provider_info_discovery(client):
- """
- This is about performing dynamic Provider Info discovery
-
- :param client: A :py:class:`oidcrp.oidc.Client` instance
- """
- try:
- client.get_service('provider_info')
- except KeyError:
- raise ConfigurationError(
- 'Can not do dynamic provider info discovery')
- else:
- _context = client.client_get("service_context")
- try:
- _context.set('issuer', _context.config['srv_discovery_url'])
- except KeyError:
- pass
-
- response = client.do_request('provider_info')
- if is_error_message(response):
- raise OidcServiceError(response['error'])
diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks
index a57e904..d16a636 100644
--- a/tests/pub_client.jwks
+++ b/tests/pub_client.jwks
@@ -1 +1 @@
-{"keys": [{"kty": "RSA", "use": "sig", "kid": "SUswNi1MRFlDT0Y2YjU1Z1RfQlo2S3dEa3FTTkV3LThFcnhDTHF5elk2VQ", "n": "0UkUx2ewKyc-XJ1o0ToyGjws_JybAMZj2oYjsPyyvQ_T5dhZ2VmRRRkhsaVJ2xE_GGc7mSG0IjmGFyXp5y0w4mJBcsAEE5-8eBTvQdYIryjW74r3jt6Fi4Hlm1yFMTie3apv8mw79BUj-jT0kh3_m-FiKKUvLsq45DcLtTJ4cx7Ize37dl1sFSpQcoYMk7eiUEM8fiNboiVwvBYNAWVMkUM-LnVUPm3UjvKp0LihYEkZFWOxmuQmj2x25SFUkjus38ERrRqJQBZduxdBHFrWtWg8yOA53BkMU0FFg_r0H3ctl-5GaKw-BWlogU4qXnsq85xy0EoenRk7FPV8g_ulJw", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "NC1pdGRQN002bWM3bk1xX2R0SktscElqbFdtN29ITDV2WVd2b0hOYzREVQ", "crv": "P-256", "x": "kK7Qp1woSerI7rUOAwW_4sU6ZmwV3wwXKX3VU-v2fMI", "y": "iPWd_Pjq6EjxYy08KNFZ3PxhEwgWHgAQTTknlKMKJA0"}]}
\ No newline at end of file
+{"keys": [{"kty": "RSA", "use": "sig", "kid": "SUswNi1MRFlDT0Y2YjU1Z1RfQlo2S3dEa3FTTkV3LThFcnhDTHF5elk2VQ", "e": "AQAB", "n": "0UkUx2ewKyc-XJ1o0ToyGjws_JybAMZj2oYjsPyyvQ_T5dhZ2VmRRRkhsaVJ2xE_GGc7mSG0IjmGFyXp5y0w4mJBcsAEE5-8eBTvQdYIryjW74r3jt6Fi4Hlm1yFMTie3apv8mw79BUj-jT0kh3_m-FiKKUvLsq45DcLtTJ4cx7Ize37dl1sFSpQcoYMk7eiUEM8fiNboiVwvBYNAWVMkUM-LnVUPm3UjvKp0LihYEkZFWOxmuQmj2x25SFUkjus38ERrRqJQBZduxdBHFrWtWg8yOA53BkMU0FFg_r0H3ctl-5GaKw-BWlogU4qXnsq85xy0EoenRk7FPV8g_ulJw"}, {"kty": "EC", "use": "sig", "kid": "NC1pdGRQN002bWM3bk1xX2R0SktscElqbFdtN29ITDV2WVd2b0hOYzREVQ", "crv": "P-256", "x": "kK7Qp1woSerI7rUOAwW_4sU6ZmwV3wwXKX3VU-v2fMI", "y": "iPWd_Pjq6EjxYy08KNFZ3PxhEwgWHgAQTTknlKMKJA0"}]}
\ No newline at end of file
diff --git a/tests/request123456.jwt b/tests/request123456.jwt
index f72cdd1..2e8824e 100644
--- a/tests/request123456.jwt
+++ b/tests/request123456.jwt
@@ -1 +1 @@
-eyJhbGciOiJSUzI1NiIsImtpZCI6IlNVc3dOaTFNUkZsRFQwWTJZalUxWjFSZlFsbzJTM2RFYTNGVFRrVjNMVGhGY25oRFRIRjVlbGsyVlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJWNVFGVUxMWGpBS0lOS3V3Y2JlVU43OHRuWkZ2MGloYyIsICJjbGllbnRfaWQiOiAiY2xpZW50X2lkIiwgImlzcyI6ICJjbGllbnRfaWQiLCAiaWF0IjogMTYxNzg3MDI0NiwgImF1ZCI6IFsiaHR0cHM6Ly9leGFtcGxlLmNvbSJdfQ.itHjMHec6T2py2zDxaAS11tFAYdsTZ-SY3AV2j5SCTHPfk9CkXa6g6s_t0oVvcdzLVXaqQ1iU9WqOeirwj4UDxCTHulPzBOcBuC3WbCki2HT9EPI88Lov-kCuz_4juw97lIyU1BkaofaZJSjRcEW_fOY0KuP7BKIDmthTLTWxEGMoMXmXcs_QD13tz0IWjrkqjwbcKjbUxrvGHJbOnOBGwgHPPm46otDMO-hQrtvTOGz4PbdD5XZ4imDV0bJkx72ITpAfM8iODmd9sKrOWkEZEhsRG1ugXa8RgOPNLsLTLjTpzWiVeczJHOiE5H4EC-uwzXiRDiShq-q7VbTKocyYw
\ No newline at end of file
+eyJhbGciOiJSUzI1NiIsImtpZCI6IlNVc3dOaTFNUkZsRFQwWTJZalUxWjFSZlFsbzJTM2RFYTNGVFRrVjNMVGhGY25oRFRIRjVlbGsyVlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJvWkpBNTRnZTVaUndNalkwOVVLVnpwYkx5MEdNUEwwaCIsICJjbGllbnRfaWQiOiAiY2xpZW50X2lkIiwgImlzcyI6ICJjbGllbnRfaWQiLCAiaWF0IjogMTYzMzU5NTc4OSwgImF1ZCI6IFsiaHR0cHM6Ly9leGFtcGxlLmNvbSJdfQ.KVMPK6leJ5pEXnJ0jXiXu21U176IU9iwkT4FkQV_33jGYTsgdqCqXw5XHR1ciixdcH2cWf0SzTPOgIzGsI4NJiPNdR9xOusYRyYKZciXHq85nrM7fr7dEPaVntWCU6uadH0MNHWCcq2FyBdz2YYDuiFPUXoxkFbfWZoo_jVMAWLxGQtGEitniI49qo0zbeSFck4hBmEtQTUOrGQvg_CjkSZb5oNb5rt_X5T-ZSK9y3AeKru4HLSQRkWj-oD-Fgd60Sm3XqfLQXrx26lk4a8ORah01BMmMsi5jeIUbOTthhhglZhMwoI9xCZ57I4SF7870-PrinIByW8d2keA1-LipQ
\ No newline at end of file
diff --git a/tests/test_13_oidc_service.py b/tests/test_13_oidc_service.py
index c3a30dc..e3f4450 100644
--- a/tests/test_13_oidc_service.py
+++ b/tests/test_13_oidc_service.py
@@ -7,6 +7,7 @@
from cryptojwt.jwt import JWT
from cryptojwt.key_jar import build_keyjar
from cryptojwt.key_jar import init_key_jar
+from oidcmsg.exception import MissingRequiredAttribute
from oidcmsg.oidc import AccessTokenRequest
from oidcmsg.oidc import AccessTokenResponse
from oidcmsg.oidc import AuthorizationRequest
@@ -68,7 +69,7 @@ def create_request(self):
}
entity = Entity(services=DEFAULT_OIDC_SERVICES, keyjar=CLI_KEY, config=client_config)
entity.client_get("service_context").issuer = 'https://example.com'
- self.service = entity.client_get("service",'authorization')
+ self.service = entity.client_get("service", 'authorization')
def test_construct(self):
req_args = {
@@ -137,13 +138,12 @@ def test_request_init(self):
def test_request_init_request_method(self):
req_args = {'response_type': 'code', 'state': 'state'}
self.service.endpoint = 'https://example.com/authorize'
- _info = self.service.get_request_parameters(request_args=req_args,
- request_method='value')
+ _info = self.service.get_request_parameters(request_args=req_args, request_method='value')
assert set(_info.keys()) == {'url', 'method', 'request'}
msg = AuthorizationRequest().from_urlencoded(
self.service.get_urlinfo(_info['url']))
- assert set(msg.to_dict()) == {'client_id', 'redirect_uri', 'request',
- 'response_type', 'state', 'scope', 'nonce'}
+ assert set(msg.to_dict()) == {'client_id', 'redirect_uri', 'request', 'response_type',
+ 'scope'}
_jws = jws.factory(msg['request'])
assert _jws
_resp = _jws.verify_compact(
@@ -203,7 +203,7 @@ def test_update_service_context_with_idtoken_wrong_nonce(self):
idt = JWT(ISS_KEY, iss=ISS, lifetime=3600)
payload = {
'sub': '123456789', 'aud': ['client_id'],
- 'nonce': 'nonce'
+ 'nonce': 'noice'
}
# have to calculate c_hash
alg = 'RS256'
@@ -212,9 +212,8 @@ def test_update_service_context_with_idtoken_wrong_nonce(self):
_idt = idt.pack(payload)
resp = AuthorizationResponse(state='state', code='code', id_token=_idt)
- resp = self.service.parse_response(resp.to_urlencoded())
- with pytest.raises(ParameterError):
- self.service.update_service_context(resp, 'state2')
+ with pytest.raises(ValueError):
+ self.service.parse_response(resp.to_urlencoded())
def test_update_service_context_with_idtoken_missing_nonce(self):
req_args = {'response_type': 'code', 'state': 'state', 'nonce': 'nonce'}
@@ -230,9 +229,8 @@ def test_update_service_context_with_idtoken_missing_nonce(self):
_idt = idt.pack(payload)
resp = AuthorizationResponse(state='state', code='code', id_token=_idt)
- resp = self.service.parse_response(resp.to_urlencoded())
- with pytest.raises(ValueError):
- self.service.update_service_context(resp, 'state')
+ with pytest.raises(MissingRequiredAttribute):
+ self.service.parse_response(resp.to_urlencoded())
@pytest.mark.parametrize("allow_sign_alg_none", [True, False])
def test_allow_unsigned_idtoken(self, allow_sign_alg_none):
@@ -241,14 +239,14 @@ def test_allow_unsigned_idtoken(self, allow_sign_alg_none):
self.service.get_request_parameters(request_args=req_args)
# Build an ID Token
idt = JWT(ISS_KEY, iss=ISS, lifetime=3600, sign_alg='none')
- payload = {'sub': '123456789', 'aud': ['client_id']}
+ payload = {'sub': '123456789', 'aud': ['client_id'], 'nonce': req_args['nonce']}
_idt = idt.pack(payload)
self.service.client_get("service_context").behaviour["verify_args"] = {
"allow_sign_alg_none": allow_sign_alg_none
}
resp = AuthorizationResponse(state='state', code='code', id_token=_idt)
if allow_sign_alg_none:
- resp = self.service.parse_response(resp.to_urlencoded())
+ self.service.parse_response(resp.to_urlencoded())
else:
with pytest.raises(UnsupportedAlgorithm):
self.service.parse_response(resp.to_urlencoded())
@@ -267,7 +265,7 @@ def create_request(self):
}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES)
entity.client_get("service_context").issuer = 'https://example.com'
- self.service = entity.client_get("service",'authorization')
+ self.service = entity.client_get("service", 'authorization')
def test_construct_code(self):
req_args = {
@@ -318,7 +316,7 @@ def create_request(self):
}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES)
entity.client_get("service_context").issuer = 'https://example.com'
- self.service = entity.client_get("service",'accesstoken')
+ self.service = entity.client_get("service", 'accesstoken')
# add some history
auth_request = AuthorizationRequest(
@@ -409,7 +407,7 @@ def create_service(self):
}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES)
entity.client_get("service_context").issuer = 'https://example.com'
- self.service = entity.client_get("service",'provider_info')
+ self.service = entity.client_get("service", 'provider_info')
def test_construct(self):
_req = self.service.construct()
@@ -601,7 +599,7 @@ def create_request(self):
}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES)
entity.client_get("service_context").issuer = 'https://example.com'
- self.service = entity.client_get("service",'registration')
+ self.service = entity.client_get("service", 'registration')
def test_construct(self):
_req = self.service.construct()
@@ -616,14 +614,59 @@ def test_config_with_post_logout(self):
assert len(_req) == 5
assert 'post_logout_redirect_uris' in _req
- def test_config_with_required_request_uri(self):
- _pi = self.service.client_get("service_context").provider_info
- _pi['require_request_uri_registration'] = True
- self.service.client_get("service_context").provider_info = _pi
- _req = self.service.construct()
- assert isinstance(_req, RegistrationRequest)
- assert len(_req) == 5
- assert 'request_uris' in _req
+
+def test_config_with_required_request_uri():
+ client_config = {
+ 'client_id': 'client_id', 'client_secret': 'a longesh password',
+ 'redirect_uris': ['https://example.com/cli/authz_cb'],
+ 'issuer': ISS,
+ 'requests_dir': 'requests',
+ 'base_url': 'https://example.com/cli/',
+ 'client_preferences': {
+ "request_uri_usable": True
+ }
+ }
+ entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES)
+ entity.client_get("service_context").issuer = 'https://example.com'
+ service = entity.client_get("service", 'registration')
+ _context = service.client_get("service_context")
+
+ _pi = _context.provider_info
+ _pi['require_request_uri_registration'] = True
+ _context.config["client_preferences"]["request_uri_usable"] = True
+ _req = service.construct()
+ assert isinstance(_req, RegistrationRequest)
+ assert len(_req) == 5
+ assert 'request_uris' in _req
+
+
+def test_config_logout_uri():
+ client_config = {
+ 'client_id': 'client_id', 'client_secret': 'a longesh password',
+ 'redirect_uris': ['https://example.com/cli/authz_cb'],
+ 'issuer': ISS,
+ 'requests_dir': 'requests',
+ 'base_url': 'https://example.com/cli/',
+ 'client_preferences': {
+ "request_uri_usable": True
+ }
+ }
+ entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES)
+ entity.client_get("service_context").issuer = 'https://example.com'
+ service = entity.client_get("service", 'registration')
+ _context = service.client_get("service_context")
+
+ _pi = _context.provider_info
+ _pi['require_request_uri_registration'] = True
+ _pi['frontchannel_logout_supported'] = True
+ _context.config["client_preferences"]["request_uri_usable"] = True
+ _context.config["client_preferences"]["frontchannel_logout_usable"] = True
+ _req = service.construct()
+ assert isinstance(_req, RegistrationRequest)
+ assert len(_req) == 7
+ assert 'request_uris' in _req
+ assert 'frontchannel_logout_uri' in _req
+ assert 'post_logout_redirect_uris' in _req
class TestUserInfo(object):
@@ -638,7 +681,7 @@ def create_request(self):
}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES)
entity.client_get("service_context").issuer = 'https://example.com'
- self.service = entity.client_get("service",'userinfo')
+ self.service = entity.client_get("service", 'userinfo')
entity.client_get("service_context").behaviour = {
'userinfo_signed_response_alg': 'RS256',
@@ -780,7 +823,7 @@ def create_request(self):
}}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=services)
entity.client_get("service_context").issuer = 'https://example.com'
- self.service = entity.client_get("service",'check_session')
+ self.service = entity.client_get("service", 'check_session')
def test_construct(self):
_state_interface = self.service.client_get("service_context").state
@@ -809,7 +852,7 @@ def create_request(self):
}}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=services)
entity.client_get("service_context").issuer = 'https://example.com'
- self.service = entity.client_get("service",'check_id')
+ self.service = entity.client_get("service", 'check_id')
def test_construct(self):
_state_interface = self.service.client_get("service_context").state
@@ -839,7 +882,7 @@ def create_request(self):
}}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=services)
entity.client_get("service_context").issuer = 'https://example.com'
- self.service = entity.client_get("service",'end_session')
+ self.service = entity.client_get("service", 'end_session')
def test_construct(self):
self.service.client_get("service_context").state.store_item(
@@ -847,9 +890,8 @@ def test_construct(self):
'token_response', 'abcde')
_req = self.service.construct(state='abcde')
assert isinstance(_req, EndSessionRequest)
- assert len(_req) == 3
- assert set(_req.keys()) == {'state', 'id_token_hint',
- 'post_logout_redirect_uri'}
+ assert len(_req) == 2
+ assert set(_req.keys()) == {'state', 'id_token_hint'}
def test_authz_service_conf():
@@ -880,7 +922,7 @@ def test_authz_service_conf():
}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=services)
entity.client_get("service_context").issuer = 'https://example.com'
- service = entity.client_get("service",'authorization')
+ service = entity.client_get("service", 'authorization')
req = service.construct()
assert 'claims' in req
@@ -900,7 +942,7 @@ def test_add_jwks_uri_or_jwks_0():
}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES)
entity.client_get("service_context").issuer = 'https://example.com'
- service = entity.client_get("service",'registration')
+ service = entity.client_get("service", 'registration')
req_args, post_args = add_jwks_uri_or_jwks({}, service)
assert req_args['jwks_uri'] == 'https://example.com/jwks/jwks.json'
@@ -919,7 +961,7 @@ def test_add_jwks_uri_or_jwks_1():
}
}
entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES)
- service = entity.client_get("service",'registration')
+ service = entity.client_get("service", 'registration')
req_args, post_args = add_jwks_uri_or_jwks({}, service)
assert req_args['jwks_uri'] == 'https://example.com/jwks/jwks.json'
@@ -938,7 +980,7 @@ def test_add_jwks_uri_or_jwks_2():
}
entity = Entity(keyjar=CLI_KEY, config=client_config,
jwks_uri='https://example.com/jwks/jwks.json', services=DEFAULT_OIDC_SERVICES)
- service = entity.client_get("service",'registration')
+ service = entity.client_get("service", 'registration')
req_args, post_args = add_jwks_uri_or_jwks({}, service)
assert req_args['jwks_uri'] == 'https://example.com/jwks/jwks.json'
diff --git a/tests/test_14_oidc.py b/tests/test_14_oidc.py
index 5ebc41c..42de74b 100755
--- a/tests/test_14_oidc.py
+++ b/tests/test_14_oidc.py
@@ -108,8 +108,7 @@ def test_construct_refresh_token_request(self):
token_response = AccessTokenResponse(refresh_token="refresh_with_me",
access_token="access")
- _context.state.store_item(token_response,
- 'token_response', 'ABCDE')
+ _context.state.store_item(token_response, 'token_response', 'ABCDE')
req_args = {}
msg = self.client.client_get("service",'refresh_token').construct(
@@ -131,17 +130,14 @@ def test_do_userinfo_request_init(self):
state='state'
)
- _context.state.store_item(auth_request,
- 'auth_request', 'ABCDE')
+ _context.state.store_item(auth_request, 'auth_request', 'ABCDE')
auth_response = AuthorizationResponse(code='access_code')
- _context.state.store_item(auth_response,
- 'auth_response', 'ABCDE')
+ _context.state.store_item(auth_response, 'auth_response', 'ABCDE')
token_response = AccessTokenResponse(refresh_token="refresh_with_me",
access_token="access")
- _context.state.store_item(token_response,
- 'token_response', 'ABCDE')
+ _context.state.store_item(token_response, 'token_response', 'ABCDE')
_srv = self.client.client_get("service",'userinfo')
_srv.endpoint = "https://example.com/userinfo"
diff --git a/tests/test_20_rp_handler_oidc.py b/tests/test_20_rp_handler_oidc.py
index b74b588..1191cd9 100644
--- a/tests/test_20_rp_handler_oidc.py
+++ b/tests/test_20_rp_handler_oidc.py
@@ -17,6 +17,7 @@
import responses
from oidcrp.entity import Entity
+from oidcrp.oidc.registration import add_callbacks
from oidcrp.rp_handler import RPHandler
BASE_URL = 'https://example.com/rp'
@@ -29,7 +30,8 @@
"code id_token token", "code token"],
"scope": ["openid", "profile", "email", "address", "phone"],
"token_endpoint_auth_method": "client_secret_basic",
- "verify_args": {"allow_sign_alg_none": True}
+ "verify_args": {"allow_sign_alg_none": True},
+ "request_parameter_preference": ["request_uri", "request"]
}
CLIENT_CONFIG = {
@@ -143,6 +145,40 @@
"userinfo_endpoint":
"https://api.github.com/user"
},
+ 'services': {
+ 'authorization': {
+ 'class': 'oidcrp.oidc.authorization.Authorization',
+ },
+ 'access_token': {
+ 'class': 'oidcrp.oidc.access_token.AccessToken'
+ },
+ 'userinfo': {
+ 'class': 'oidcrp.oidc.userinfo.UserInfo',
+ 'kwargs': {'conf': {'default_authn_method': ''}}
+ },
+ 'refresh_access_token': {
+ 'class': 'oidcrp.oidc.refresh_access_token.RefreshAccessToken'
+ }
+ }
+ },
+ 'github2': {
+ "issuer": "https://github.com/login/oauth/authorize",
+ 'client_id': 'eeeeeeeee',
+ 'client_secret': 'aaaaaaaaaaaaaaaaaaaa',
+ "redirect_uris": ["{}/authz_cb/github".format(BASE_URL)],
+ "behaviour": {
+ "response_types": ["code"],
+ "scope": ["user", "public_repo"],
+ "token_endpoint_auth_method": '',
+ "verify_args": {"allow_sign_alg_none": True}
+ },
+ "provider_info": {
+ "authorization_endpoint": "https://github.com/login/oauth/authorize",
+ "token_endpoint": "https://github.com/login/oauth/access_token",
+ "userinfo_endpoint": "https://api.github.com/user",
+ "request_parameter_supported": True,
+ "request_uri_parameter_supported": True
+ },
'services': {
'authorization': {
'class': 'oidcrp.oidc.authorization.Authorization'
@@ -155,8 +191,7 @@
'kwargs': {'conf': {'default_authn_method': ''}}
},
'refresh_access_token': {
- 'class': 'oidcrp.oidc.refresh_access_token'
- '.RefreshAccessToken'
+ 'class': 'oidcrp.oidc.refresh_access_token.RefreshAccessToken'
}
}
}
@@ -265,7 +300,7 @@ def test_do_provider_info(self):
# Make sure the service endpoints are set
for service_type in ['authorization', 'accesstoken', 'userinfo']:
- _srv = client.client_get("service",service_type)
+ _srv = client.client_get("service", service_type)
_endp = client.client_get("service_context").get('provider_info')[_srv.endpoint_name]
assert _srv.endpoint == _endp
@@ -277,7 +312,7 @@ def test_do_client_registration(self):
# only 2 things should have happened
assert self.rph.hash2issuer['github'] == issuer
- assert client.client_get("service_context").post_logout_redirect_uris == []
+ assert client.client_get("service_context").callback.get("post_logout_redirect_uris") is None
def test_do_client_setup(self):
client = self.rph.client_setup('github')
@@ -296,25 +331,29 @@ def test_do_client_setup(self):
assert len(keys) == 2
for service_type in ['authorization', 'accesstoken', 'userinfo']:
- _srv = client.client_get("service",service_type)
+ _srv = client.client_get("service", service_type)
_endp = _srv.client_get("service_context").get('provider_info')[_srv.endpoint_name]
assert _srv.endpoint == _endp
assert self.rph.hash2issuer['github'] == _context.get('issuer')
def test_create_callbacks(self):
- cb = self.rph.create_callbacks('https://op.example.com/')
+ client = self.rph.init_client('https://op.example.com/')
+ _srv = client.client_get("service", "registration")
+ _context = _srv.client_get("service_context")
+ add_callbacks(_context, [])
+
+ cb = _srv.client_get("service_context").callback
- assert set(cb.keys()) == {'code', 'implicit', 'form_post', '__hex'}
+ assert set(cb.keys()) == {'redirect_uris', 'code', 'implicit', '__hex'}
_hash = cb['__hex']
- assert cb['code'] == 'https://example.com/rp/authz_cb/{}'.format(_hash)
- assert cb['implicit'] == 'https://example.com/rp/authz_im_cb/{}'.format(_hash)
- assert cb['form_post'] == 'https://example.com/rp/authz_fp_cb/{}'.format(_hash)
+ assert cb['code'] == f'https://example.com/rp/authz_cb/{_hash}'
+ assert cb['implicit'] == f'https://example.com/rp/authz_im_cb/{_hash}'
- assert list(self.rph.hash2issuer.keys()) == [_hash]
+ assert list(_context.hash2issuer.keys()) == [_hash]
- assert self.rph.hash2issuer[_hash] == 'https://op.example.com/'
+ assert _context.hash2issuer[_hash] == 'https://op.example.com/'
def test_begin(self):
res = self.rph.begin(issuer_id='github')
@@ -367,8 +406,9 @@ def test_finalize_auth(self):
state=res['state'])
resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict())
assert set(resp.keys()) == {'state', 'code'}
- aresp = client.client_get("service_context").state.get_item(AuthorizationResponse, 'auth_response',
- res['state'])
+ aresp = client.client_get("service_context").state.get_item(AuthorizationResponse,
+ 'auth_response',
+ res['state'])
assert set(aresp.keys()) == {'state', 'code'}
def test_get_client_authn_method(self):
@@ -385,7 +425,7 @@ def test_get_client_authn_method(self):
'token_endpoint')
assert authn_method == 'client_secret_post'
- def test_get_access_token(self):
+ def test_get_tokens(self):
res = self.rph.begin(issuer_id='github')
_session = self.rph.get_session_information(res['state'])
client = self.rph.issuer2rp[_session['iss']]
@@ -417,14 +457,14 @@ def test_get_access_token(self):
with responses.RequestsMock() as rsps:
rsps.add("POST", _url, body=at.to_json(),
adding_headers={"Content-Type": "application/json"}, status=200)
- client.client_get("service",'accesstoken').endpoint = _url
+ client.client_get("service", 'accesstoken').endpoint = _url
auth_response = AuthorizationResponse(code='access_code',
state=res['state'])
resp = self.rph.finalize_auth(client, _session['iss'],
auth_response.to_dict())
- resp = self.rph.get_access_token(res['state'], client)
+ resp = self.rph.get_tokens(res['state'], client)
assert set(resp.keys()) == {'access_token', 'expires_in', 'id_token',
'token_type', '__verified_id_token',
'__expires_at'}
@@ -466,7 +506,7 @@ def test_access_and_id_token(self):
with responses.RequestsMock() as rsps:
rsps.add("POST", _url, body=at.to_json(),
adding_headers={"Content-Type": "application/json"}, status=200)
- client.client_get("service",'accesstoken').endpoint = _url
+ client.client_get("service", 'accesstoken').endpoint = _url
_response = AuthorizationResponse(code='access_code',
state=res['state'])
@@ -507,7 +547,7 @@ def test_access_and_id_token_by_reference(self):
with responses.RequestsMock() as rsps:
rsps.add("POST", _url, body=at.to_json(),
adding_headers={"Content-Type": "application/json"}, status=200)
- client.client_get("service",'accesstoken').endpoint = _url
+ client.client_get("service", 'accesstoken').endpoint = _url
_response = AuthorizationResponse(code='access_code',
state=res['state'])
@@ -548,7 +588,7 @@ def test_get_user_info(self):
with responses.RequestsMock() as rsps:
rsps.add("POST", _url, body=at.to_json(),
adding_headers={"Content-Type": "application/json"}, status=200)
- client.client_get("service",'accesstoken').endpoint = _url
+ client.client_get("service", 'accesstoken').endpoint = _url
_response = AuthorizationResponse(code='access_code',
state=res['state'])
@@ -562,7 +602,7 @@ def test_get_user_info(self):
with responses.RequestsMock() as rsps:
rsps.add("GET", _url, body='{"sub":"EndUserSubject"}',
adding_headers={"Content-Type": "application/json"}, status=200)
- client.client_get("service",'userinfo').endpoint = _url
+ client.client_get("service", 'userinfo').endpoint = _url
userinfo_resp = self.rph.get_user_info(res['state'], client,
token_resp['access_token'])
@@ -595,7 +635,7 @@ def test_get_provider_specific_service():
}
}
entity = Entity(services=srv_desc)
- assert entity.client_get("service",'accesstoken').response_body_type == 'urlencoded'
+ assert entity.client_get("service", 'accesstoken').response_body_type == 'urlencoded'
class TestRPHandlerTier2(object):
@@ -634,7 +674,7 @@ def rphandler_setup(self):
rsps.add("POST", _url, body=at.to_json(),
adding_headers={"Content-Type": "application/json"}, status=200)
- client.client_get("service",'accesstoken').endpoint = _url
+ client.client_get("service", 'accesstoken').endpoint = _url
_response = AuthorizationResponse(code='access_code',
state=res['state'])
@@ -649,7 +689,7 @@ def rphandler_setup(self):
rsps.add("GET", _url, body='{"sub":"EndUserSubject"}',
adding_headers={"Content-Type": "application/json"}, status=200)
- client.client_get("service",'userinfo').endpoint = _url
+ client.client_get("service", 'userinfo').endpoint = _url
self.rph.get_user_info(res['state'], client,
token_resp['access_token'])
self.state = res['state']
@@ -677,7 +717,7 @@ def test_refresh_access_token(self):
rsps.add("POST", _url, body=at.to_json(),
adding_headers={"Content-Type": "application/json"}, status=200)
- client.client_get("service",'refresh_token').endpoint = _url
+ client.client_get("service", 'refresh_token').endpoint = _url
res = self.rph.refresh_access_token(self.state, client, 'openid email')
assert res['access_token'] == '2nd_accessTok'
@@ -689,7 +729,7 @@ def test_get_user_info(self):
with responses.RequestsMock() as rsps:
rsps.add("GET", _url, body='{"sub":"EndUserSubject", "mail":"foo@example.com"}',
adding_headers={"Content-Type": "application/json"}, status=200)
- client.client_get("service",'userinfo').endpoint = _url
+ client.client_get("service", 'userinfo').endpoint = _url
resp = self.rph.get_user_info(self.state, client)
assert set(resp.keys()) == {'sub', 'mail'}
@@ -722,8 +762,7 @@ def __init__(self, issuer, keyjar=None):
self.post_response = {}
self.register_post_response('default', 'OK', 200)
- def register_get_response(self, path, data, status_code=200,
- headers=None):
+ def register_get_response(self, path, data, status_code=200, headers=None):
_headers = headers or {}
self.get_response[path] = MockResponse(status_code, data, _headers)
@@ -787,13 +826,32 @@ def registration_callback(data):
return json.dumps(_req)
+def test_rphandler_request_uri():
+ rph = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY)
+ res = rph.begin(issuer_id='github2', behaviour_args={"request_param": "request_uri"})
+ _session = rph.get_session_information(res['state'])
+ _url = res["url"]
+ _qp = parse_qs(urlparse(_url).query)
+ assert 'request_uri' in _qp
+
+
+def test_rphandler_request():
+ rph = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY)
+ res = rph.begin(issuer_id='github2',
+ behaviour_args={"request_param": "request"})
+ _session = rph.get_session_information(res['state'])
+ _url = res["url"]
+ _qp = parse_qs(urlparse(_url).query)
+ assert 'request' in _qp
+
+
class TestRPHandlerWithMockOP(object):
@pytest.fixture(autouse=True)
def rphandler_setup(self):
self.issuer = 'https://github.com/login/oauth/authorize'
self.mock_op = MockOP(issuer=self.issuer)
self.rph = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG,
- http_lib=self.mock_op, keyjar=KeyJar())
+ http_lib=self.mock_op, keyjar=CLI_KEY)
def test_finalize(self):
auth_query = self.rph.begin(issuer_id='github')
@@ -845,7 +903,7 @@ def test_finalize(self):
# assume code flow
resp = self.rph.finalize(_session['iss'], auth_response.to_dict())
- assert set(resp.keys()) == {'userinfo', 'state', 'token', 'id_token'}
+ assert set(resp.keys()) == {'userinfo', 'state', 'token', 'id_token', 'session_state'}
def test_dynamic_setup(self):
user_id = 'acct:foobar@example.com'
@@ -863,32 +921,22 @@ def test_dynamic_setup(self):
"issuer": "https://server.example.com",
"subject_types_supported": ['public'],
"token_endpoint": "https://server.example.com/connect/token",
- "token_endpoint_auth_methods_supported": ["client_secret_basic",
- "private_key_jwt"],
+ "token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"],
"userinfo_endpoint": "https://server.example.com/connect/user",
"check_id_endpoint": "https://server.example.com/connect/check_id",
- "refresh_session_endpoint":
- "https://server.example.com/connect/refresh_session",
- "end_session_endpoint":
- "https://server.example.com/connect/end_session",
+ "refresh_session_endpoint": "https://server.example.com/connect/refresh_session",
+ "end_session_endpoint": "https://server.example.com/connect/end_session",
"jwks_uri": "https://server.example.com/jwk.json",
- "registration_endpoint":
- "https://server.example.com/connect/register",
- "scopes_supported": ["openid", "profile", "email", "address",
- "phone"],
- "response_types_supported": ["code", "code id_token",
- "token id_token"],
- "acrs_supported": ["1", "2",
- "http://id.incommon.org/assurance/bronze"],
+ "registration_endpoint": "https://server.example.com/connect/register",
+ "scopes_supported": ["openid", "profile", "email", "address", "phone"],
+ "response_types_supported": ["code", "code id_token", "token id_token"],
+ "acrs_supported": ["1", "2", "http://id.incommon.org/assurance/bronze"],
"user_id_types_supported": ["public", "pairwise"],
- "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW",
- "RSA1_5"],
+ "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"],
"id_token_signing_alg_values_supported": ["HS256", "RS256",
"A128CBC", "A128KW",
"RSA1_5"],
- "request_object_algs_supported": ["HS256", "RS256", "A128CBC",
- "A128KW",
- "RSA1_5"]
+ "request_object_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"]
}
pcr = ProviderConfigurationResponse(**resp)
@@ -902,3 +950,58 @@ def test_dynamic_setup(self):
auth_query = self.rph.begin(user_id=user_id)
assert auth_query
+
+ def test_dynamic_setup_redirect_uri(self):
+ user_id = 'acct:foobar@example.com'
+ _link = Link(rel="http://openid.net/specs/connect/1.0/issuer",
+ href="https://server.example.com")
+ webfinger_response = JRD(subject=user_id, links=[_link])
+ self.mock_op.register_get_response(
+ '/.well-known/webfinger', webfinger_response.to_json(), 200,
+ {'content-type': "application/json"})
+
+ resp = {
+ "authorization_endpoint":
+ "https://server.example.com/connect/authorize",
+ "issuer": "https://server.example.com",
+ "subject_types_supported": ['public'],
+ "token_endpoint": "https://server.example.com/connect/token",
+ "token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"],
+ "userinfo_endpoint": "https://server.example.com/connect/user",
+ "check_id_endpoint": "https://server.example.com/connect/check_id",
+ "refresh_session_endpoint": "https://server.example.com/connect/refresh_session",
+ "end_session_endpoint": "https://server.example.com/connect/end_session",
+ "jwks_uri": "https://server.example.com/jwk.json",
+ "registration_endpoint": "https://server.example.com/connect/register",
+ "scopes_supported": ["openid", "profile", "email", "address", "phone"],
+ "response_types_supported": ["code", "code id_token", "token id_token"],
+ "acrs_supported": ["1", "2", "http://id.incommon.org/assurance/bronze"],
+ "user_id_types_supported": ["public", "pairwise"],
+ "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"],
+ "id_token_signing_alg_values_supported": ["HS256", "RS256",
+ "A128CBC", "A128KW",
+ "RSA1_5"],
+ "request_object_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", "RSA1_5"],
+ "request_parameter_supported": True,
+ "request_uri_parameter_supported": True,
+ "require_request_uri_registration": True
+ }
+
+ pcr = ProviderConfigurationResponse(**resp)
+ self.mock_op.register_get_response(
+ '/.well-known/openid-configuration', pcr.to_json(), 200,
+ {'content-type': "application/json"})
+
+ self.mock_op.register_post_response(
+ '/connect/register', registration_callback, 200,
+ {'content-type': "application/json"})
+
+ res = self.rph.begin(user_id=user_id,
+ behaviour_args={
+ "request_param": "request",
+ "request_object_signing_alg": "RS256"})
+ assert res
+
+ _url = res["url"]
+ _qp = parse_qs(urlparse(_url).query)
+ assert 'request' in _qp
diff --git a/tests/test_21_rph_defaults.py b/tests/test_21_rph_defaults.py
index ac07e72..ced2cad 100644
--- a/tests/test_21_rph_defaults.py
+++ b/tests/test_21_rph_defaults.py
@@ -72,7 +72,6 @@ def test_begin(self):
_context = client.client_get("service_context")
# Calculating request so I can build a reasonable response
- self.rph.add_callbacks(_context)
_req = client.client_get("service",'registration').construct_request()
with responses.RequestsMock() as rsps:
@@ -129,7 +128,6 @@ def test_begin_2(self):
_context = client.client_get("service_context")
# Calculating request so I can build a reasonable response
- self.rph.add_callbacks(_context)
# Publishing a JWKS instead of a JWKS_URI
_context.jwks_uri = ''
_context.jwks = _context.keyjar.export_jwks()
diff --git a/tests/test_40_rp_handler_persistent.py b/tests/test_40_rp_handler_persistent.py
index 2dd6faf..737365b 100644
--- a/tests/test_40_rp_handler_persistent.py
+++ b/tests/test_40_rp_handler_persistent.py
@@ -247,7 +247,7 @@ def test_do_client_registration(self):
# only 2 things should have happened
assert rph_1.hash2issuer['github'] == issuer
- assert not client.client_get("service_context").post_logout_redirect_uris
+ assert not client.client_get("service_context").callback.get('post_logout_redirect_uris')
def test_do_client_setup(self):
rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG,
@@ -275,23 +275,6 @@ def test_do_client_setup(self):
assert rph_1.hash2issuer['github'] == _context.get('issuer')
- def test_create_callbacks(self):
- rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG,
- keyjar=CLI_KEY, module_dirs=['oidc'])
-
- cb = rph_1.create_callbacks('https://op.example.com/')
-
- assert set(cb.keys()) == {'code', 'implicit', 'form_post', '__hex'}
- _hash = cb['__hex']
-
- assert cb['code'] == 'https://example.com/rp/authz_cb/{}'.format(_hash)
- assert cb['implicit'] == 'https://example.com/rp/authz_im_cb/{}'.format(_hash)
- assert cb['form_post'] == 'https://example.com/rp/authz_fp_cb/{}'.format(_hash)
-
- assert list(rph_1.hash2issuer.keys()) == [_hash]
-
- assert rph_1.hash2issuer[_hash] == 'https://op.example.com/'
-
def test_begin(self):
rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG,
keyjar=CLI_KEY, module_dirs=['oidc'])
@@ -376,7 +359,7 @@ def test_get_client_authn_method(self):
'token_endpoint')
assert authn_method == 'client_secret_post'
- def test_get_access_token(self):
+ def test_get_tokens(self):
rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG,
keyjar=CLI_KEY, module_dirs=['oidc'])
@@ -418,7 +401,7 @@ def test_get_access_token(self):
resp = rph_1.finalize_auth(client, _session['iss'],
auth_response.to_dict())
- resp = rph_1.get_access_token(res['state'], client)
+ resp = rph_1.get_tokens(res['state'], client)
assert set(resp.keys()) == {'access_token', 'expires_in', 'id_token',
'token_type', '__verified_id_token',
'__expires_at'}