diff --git a/doc/source/rp_handler.rst b/doc/source/rp_handler.rst index 2fe8e2f..2f139cc 100644 --- a/doc/source/rp_handler.rst +++ b/doc/source/rp_handler.rst @@ -14,7 +14,8 @@ that some of the functions the service provides needs access to some user related resources on a resource server. That's when you need OpenID Connect (OIDC) or Oauth2. -The RPHandler as implemented in :py:class:`oidcrp.RPHandler` is a service within +The RPHandler as implemented in :py:class:`oidcrp.rp_handler.RPHandler` is a +service within the web service that handles user authentication and access authorization on behalf of the web service. @@ -127,7 +128,7 @@ Tier 1 API The high-level methods you have access to (in the order they are to be used) are: -:py:meth:`oidcrp.RPHandler.begin` +:py:meth:`oidcrp.rp_handler.RPHandler.begin` This method will initiate a RP/Client instance if none exists for the OP/AS in question. It will then run service 1 if needed, services 2 and 3 according to configuration and finally will construct the authorization @@ -152,7 +153,7 @@ like this:: After the RP has received this response the processing continues with: -:py:meth:`oidcrp.RPHandler.get_session_information` +:py:meth:`oidcrp.rp_handler.RPHandler.get_session_information` In the authorization response there MUST be a state parameter. The value of that parameter is the key into a data store that will provide you with information about the session so far. @@ -161,7 +162,7 @@ After the RP has received this response the processing continues with: session_info = rph.state_db_interface.get_state(kwargs['state']) -:py:meth:`oidcrp.RPHandler.finalize` +:py:meth:`oidcrp.rp_handler.RPHandler.finalize` Will parse the authorization response and depending on the configuration run the services 5 and 6. @@ -177,14 +178,14 @@ The tier 1 API is good for getting you started with authenticating a user and getting user information but if you're look at a long-term engagement you need a finer grained set of methods. These I call the tier 2 API: -:py:meth:`oidcrp.RPHandler.do_provider_info` +:py:meth:`oidcrp.rp_handler.RPHandler.do_provider_info` Either get the provider info from configuration or through dynamic discovery. Will overwrite previously saved provider metadata. -:py:meth:`oidcrp.RPHandler.do_client_registration` +:py:meth:`oidcrp.rp_handler.RPHandler.do_client_registration` Do dynamic client registration is configured to do so and the OP supports it. -:py:meth:`oidcrp.RPHandler.init_authorization` +:py:meth:`oidcrp.rp_handler.RPHandler.init_authorization` Initialize an authorization/authentication event. If the user has a previous session stored this will not overwrite that but will create a new one. @@ -197,7 +198,7 @@ a finer grained set of methods. These I call the tier 2 API: The state_key you see mentioned here and below is the value of the state parameter in the authorization request. -:py:meth:`oidcrp.RPHandler.get_access_token` +:py:meth:`oidcrp.rp_handler.RPHandler.get_access_token` Will use an access code received as the response to an authentication/authorization to get an access token from the OP/AS. Access codes can only be used once. @@ -206,7 +207,7 @@ parameter in the authorization request. res = self.rph.get_access_token(state_key) -:py:meth:`oidcrp.RPHandler.refresh_access_token` +:py:meth:`oidcrp.rp_handler.RPHandler.refresh_access_token` If the client has received a refresh token this method can be used to get a new access token. @@ -218,7 +219,7 @@ You may change the set of scopes that are bound to the new access token but that change can only be a downgrade from what was specified in the authorization request and accepted by the user. -:py:meth:`oidcrp.RPHandler.get_user_info` +:py:meth:`oidcrp.rp_handler.RPHandler.get_user_info` If the client is allowed to do so, it can refresh the user info by requesting user information from the userinfo endpoint. @@ -226,7 +227,7 @@ authorization request and accepted by the user. resp = self.rph.get_user_info(state_key) -:py:meth:`oidcrp.RPHandler.has_active_authentication` +:py:meth:`oidcrp.rp_handler.RPHandler.has_active_authentication` After a while when the user returns after having been away for a while you may want to know if you should let her reauthenticate or not. This method will tell you if the last done authentication is still @@ -238,7 +239,7 @@ authorization request and accepted by the user. response will be True or False depending in the state of the authentication. -:py:meth:`oidcrp.RPHandler.get_valid_access_token` +:py:meth:`oidcrp.rp_handler.RPHandler.get_valid_access_token` When you are issued a access token it normally comes with a life time. After that time you are expected to use the refresh token to get a new access token. There are 2 ways of finding out if the access token you have is @@ -289,7 +290,7 @@ these 2 together then defines the base_url. which is normally defined as:: logging How the process should log -http_params +httpc_params Defines how the process performs HTTP requests to other entities. Parameters here are typically **verify** which controls whether the http client will verify the server TLS certificate or not. @@ -302,9 +303,6 @@ rp_keys Definition of the private keys that all RPs are going to use in the OIDC protocol exchange. -jwks_uri - Where the OP/AS can find the RPs public keys - There might be other parameters that you need dependent on which web framework you chose to use. @@ -339,8 +337,22 @@ redirect_uris the use back to this URL after the authorization/authentication has completed. These URLs should be OP/AS specific. -behavior - Information about how the RP should behave towards the OP/AS +behaviour + Information about how the RP should behave towards the OP/AS. This is + a set of attributes with values. The attributes taken from the + `client metadata`_ specification. *behaviour* is used when the client + has been registered statically and it is know what the client wants to + use and the OP supports. + + Usage example:: + + "behaviour": { + "response_types": ["code"], + "scope": ["openid", "profile", "email"], + "token_endpoint_auth_method": ["client_secret_basic", + 'client_secret_post'] + } + rp_keys If the OP doesn't support dynamic provider discovery it may still want to @@ -356,7 +368,11 @@ rp_keys If the provider info discovery is done dynamically you need this client_preferences - How the RP should prefer to behave against the OP/AS + How the RP should prefer to behave against the OP/AS. The content are the + same as for *behaviour*. The difference is that this is specified if the + RP is expected to do dynamic client registration which means that at the + point of writing the configuration it is only known what the RP can and + wants to do but unknown what the OP supports. issuer The Issuer ID of the OP. @@ -368,6 +384,8 @@ allow in the provider info is not the same as the URL you used to fetch the information. +.. _client metadata: https://openid.net/specs/openid-connect-registration-1_0.html#ClientMetadata + ------------------------- RP configuration - Google ------------------------- @@ -380,7 +398,7 @@ with dummy values:: "client_id": "xxxxxxxxx.apps.googleusercontent.com", "client_secret": "2222222222", "redirect_uris": ["{}/authz_cb/google".format(BASEURL)], - "client_prefs": { + "behaviour": { "response_types": ["code"], "scope": ["openid", "profile", "email"], "token_endpoint_auth_method": ["client_secret_basic", @@ -415,7 +433,7 @@ right now supports 2 variants both listed here. The RP will by default pick the first if a list of possible values. Which in this case means the RP will authenticate using the *client_secret_basic* if allowed by Google:: - "client_prefs": { + "behaviour": { "response_types": ["code"], "scope": ["openid", "profile", "email"], "token_endpoint_auth_method": ["client_secret_basic", @@ -447,7 +465,7 @@ Configuration that allows you to use a Microsoft OP as identity provider:: 'client_id': '242424242424', 'client_secret': 'ipipipippipipippi', "redirect_uris": ["{}/authz_cb/microsoft".format(BASEURL)], - "client_prefs": { + "behaviour": { "response_types": ["id_token"], "scope": ["openid"], "token_endpoint_auth_method": ['client_secret_post'], @@ -465,7 +483,7 @@ Configuration that allows you to use a Microsoft OP as identity provider:: One piece at the time. Microsoft has something called a tenant. Either you specify your RP to only one tenant in which case the issuer returned as *iss* in the id_token will be the same as the *issuer*. If our RP -is expected to work in a multi-tenant environment then the *iss* will never +is expected to work in a multi-tenant environment then the *iss* will **never** match issuer. Let's assume our RP works in a single-tenant context:: 'issuer': 'https://login.microsoftonline.com//v2.0', @@ -486,7 +504,7 @@ response not in the fragment of the redirect URL which is the default but instead using the response_mode *form_post*. *client_secret_post* is a client authentication that Microsoft supports at the token enpoint:: - "client_prefs": { + "behaviour": { "response_types": ["id_token"], "scope": ["openid"], "token_endpoint_auth_method": ['client_secret_post'], @@ -574,7 +592,7 @@ can be used to access user info at the userinfo endpoint. GitHub deviates from the standard in a number of way. First the Oauth2 standard doesn't mention anything like an userinfo endpoint, that is OIDC. So GitHub has implemented something that is in between OAuth2 and OIDC. -What's more disturbing is that the accesstoken response by default is not +What's more disturbing is that the access token response by default is not encoded as a JSON document which the standard say but instead it's urlencoded. Lucky for us, we can deal with both these things by configuration rather then writing code.:: @@ -584,3 +602,4 @@ rather then writing code.:: 'AccessToken': {'response_body_type': 'urlencoded'}, 'UserInfo': {'default_authn_method': ''} } + diff --git a/flask_rp/__init__.py b/example/flask_rp/__init__.py similarity index 100% rename from flask_rp/__init__.py rename to example/flask_rp/__init__.py diff --git a/flask_rp/application.py b/example/flask_rp/application.py similarity index 78% rename from flask_rp/application.py rename to example/flask_rp/application.py index 6cda469..a2aa16b 100644 --- a/flask_rp/application.py +++ b/example/flask_rp/application.py @@ -5,8 +5,7 @@ from cryptojwt.key_jar import init_key_jar from flask.app import Flask -from oidcrp import RPHandler -from oidcrp.configure import Configuration +from oidcrp.rp_handler import RPHandler dir_path = os.path.dirname(os.path.realpath(__file__)) @@ -14,9 +13,9 @@ def init_oidc_rp_handler(app): _rp_conf = app.rp_config - if _rp_conf.rp_keys: - _kj = init_key_jar(**_rp_conf.rp_keys) - _path = _rp_conf.rp_keys['public_path'] + if _rp_conf.keys: + _kj = init_key_jar(**_rp_conf.keys) + _path = _rp_conf.keys['public_path'] # removes ./ and / from the begin of the string _path = re.sub('^(.)/', '', _path) else: @@ -31,11 +30,11 @@ def init_oidc_rp_handler(app): return rph -def oidc_provider_init_app(config_file, name=None, **kwargs): +def oidc_provider_init_app(config, name=None, **kwargs): name = name or __name__ app = Flask(name, static_url_path='', **kwargs) - app.rp_config = Configuration.create_from_config_file(config_file) + app.rp_config = config # Session key for the application session app.config['SECRET_KEY'] = os.urandom(12).hex() diff --git a/chrp/certs/cert.pem b/example/flask_rp/certs/cert.pem similarity index 100% rename from chrp/certs/cert.pem rename to example/flask_rp/certs/cert.pem diff --git a/chrp/certs/key.pem b/example/flask_rp/certs/key.pem similarity index 100% rename from chrp/certs/key.pem rename to example/flask_rp/certs/key.pem diff --git a/example/flask_rp/conf.json b/example/flask_rp/conf.json new file mode 100644 index 0000000..7696d67 --- /dev/null +++ b/example/flask_rp/conf.json @@ -0,0 +1,216 @@ +{ + "logging": { + "version": 1, + "disable_existing_loggers": false, + "root": { + "handlers": [ + "console", + "file" + ], + "level": "DEBUG" + }, + "loggers": { + "idp": { + "level": "DEBUG" + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + "formatter": "default" + }, + "file": { + "class": "logging.FileHandler", + "filename": "debug.log", + "formatter": "default" + } + }, + "formatters": { + "default": { + "format": "%(asctime)s %(name)s %(levelname)s %(message)s" + } + } + }, + "port": 8090, + "domain": "127.0.0.1", + "base_url": "https://{domain}:{port}", + "httpc_params": { + "verify": false + }, + "rp_keys": { + "private_path": "private/jwks.json", + "key_defs": [ + { + "type": "RSA", + "key": "", + "use": [ + "sig" + ] + }, + { + "type": "EC", + "crv": "P-256", + "use": [ + "sig" + ] + } + ], + "public_path": "static/jwks.json", + "read_only": false + }, + "services": { + "discovery": { + "class": "oidcrp.oidc.provider_info_discovery.ProviderInfoDiscovery", + "kwargs": {} + }, + "registration": { + "class": "oidcrp.oidc.registration.Registration", + "kwargs": {} + }, + "authorization": { + "class": "oidcrp.oidc.authorization.Authorization", + "kwargs": {} + }, + "accesstoken": { + "class": "oidcrp.oidc.access_token.AccessToken", + "kwargs": {} + }, + "userinfo": { + "class": "oidcrp.oidc.userinfo.UserInfo", + "kwargs": {} + }, + "end_session": { + "class": "oidcrp.oidc.end_session.EndSession", + "kwargs": {} + } + }, + "clients": { + "": { + "client_preferences": { + "application_name": "rphandler", + "application_type": "web", + "contacts": [ + "ops@example.com" + ], + "response_types": [ + "code" + ], + "scope": [ + "openid", + "profile", + "email", + "address", + "phone" + ], + "token_endpoint_auth_method": [ + "client_secret_basic", + "client_secret_post" + ] + }, + "redirect_uris": "None", + "services": { + "discovery": { + "class": "oidcrp.oidc.provider_info_discovery.ProviderInfoDiscovery", + "kwargs": {} + }, + "registration": { + "class": "oidcrp.oidc.registration.Registration", + "kwargs": {} + }, + "authorization": { + "class": "oidcrp.oidc.authorization.Authorization", + "kwargs": {} + }, + "accesstoken": { + "class": "oidcrp.oidc.access_token.AccessToken", + "kwargs": {} + }, + "userinfo": { + "class": "oidcrp.oidc.userinfo.UserInfo", + "kwargs": {} + }, + "end_session": { + "class": "oidcrp.oidc.end_session.EndSession", + "kwargs": {} + } + } + }, + "local": { + "client_preferences": { + "application_name": "rphandler", + "application_type": "web", + "contacts": [ + "ops@example.com" + ], + "response_types": [ + "code" + ], + "scope": [ + "openid", + "profile", + "email", + "address", + "phone" + ], + "token_endpoint_auth_method": [ + "client_secret_basic", + "client_secret_post" + ] + }, + "issuer": "https://127.0.0.1:5000/", + "redirect_uris": [ + "https://{domain}:{port}/authz_cb/local" + ], + "post_logout_redirect_uris": [ + "https://{domain}:{port}/session_logout/local" + ], + "frontchannel_logout_uri": "https://{domain}:{port}/fc_logout/local", + "frontchannel_logout_session_required": true, + "backchannel_logout_uri": "https://{domain}:{port}/bc_logout/local", + "backchannel_logout_session_required": true, + "services": { + "discovery": { + "class": "oidcrp.oidc.provider_info_discovery.ProviderInfoDiscovery", + "kwargs": {} + }, + "registration": { + "class": "oidcrp.oidc.registration.Registration", + "kwargs": {} + }, + "authorization": { + "class": "oidcrp.oidc.authorization.Authorization", + "kwargs": {} + }, + "accesstoken": { + "class": "oidcrp.oidc.access_token.AccessToken", + "kwargs": {} + }, + "userinfo": { + "class": "oidcrp.oidc.userinfo.UserInfo", + "kwargs": {} + }, + "end_session": { + "class": "oidcrp.oidc.end_session.EndSession", + "kwargs": {} + } + }, + "add_ons": { + "pkce": { + "function": "oidcrp.oauth2.add_on.pkce.add_support", + "kwargs": { + "code_challenge_length": 64, + "code_challenge_method": "S256" + } + } + } + } + }, + "webserver": { + "port": 8090, + "domain": "127.0.0.1", + "server_cert": "certs/cert.pem", + "server_key": "certs/key.pem", + "debug": true + } +} diff --git a/example/flask_rp/conf.yaml b/example/flask_rp/conf.yaml new file mode 100644 index 0000000..ea23a1a --- /dev/null +++ b/example/flask_rp/conf.yaml @@ -0,0 +1,129 @@ +logging: + version: 1 + disable_existing_loggers: False + root: + handlers: + - console + - file + level: DEBUG + loggers: + idp: + level: DEBUG + handlers: + console: + class: logging.StreamHandler + stream: 'ext://sys.stdout' + formatter: default + file: + class: logging.FileHandler + filename: 'debug.log' + formatter: default + formatters: + default: + format: '%(asctime)s %(name)s %(levelname)s %(message)s' + +port: &port 8090 +domain: &domain 127.0.0.1 +base_url: "https://{domain}:{port}" + +httpc_params: + # This is just for testing an local usage. In all other cases it MUST be True + verify: false + # Client side + #client_cert: "certs/client.crt" + #client_key: "certs/client.key" + +keydefs: &keydef + - "type": "RSA" + "key": '' + "use": ["sig"] + - "type": "EC" + "crv": "P-256" + "use": ["sig"] + +rp_keys: + 'private_path': 'private/jwks.json' + 'key_defs': *keydef + 'public_path': 'static/jwks.json' + # this will create the jwks files if they are absent + 'read_only': False + +client_preferences: &id001 + application_name: rphandler + application_type: web + contacts: + - ops@example.com + response_types: + - code + scope: + - openid + - profile + - email + - address + - phone + token_endpoint_auth_method: + - client_secret_basic + - client_secret_post + +services: &id002 + discovery: &disc + class: oidcrp.oidc.provider_info_discovery.ProviderInfoDiscovery + kwargs: {} + registration: ®ist + class: oidcrp.oidc.registration.Registration + kwargs: {} + authorization: &authz + class: oidcrp.oidc.authorization.Authorization + kwargs: {} + accesstoken: &acctok + class: oidcrp.oidc.access_token.AccessToken + kwargs: {} + userinfo: &userinfo + class: oidcrp.oidc.userinfo.UserInfo + kwargs: {} + end_session: &sess + class: oidcrp.oidc.end_session.EndSession + kwargs: {} + +clients: + "": + client_preferences: *id001 + redirect_uris: None + services: *id002 + local: + client_preferences: *id001 + issuer: https://127.0.0.1:5000/ + redirect_uris: + - 'https://{domain}:{port}/authz_cb/local' + post_logout_redirect_uris: + - "https://{domain}:{port}/session_logout/local" + frontchannel_logout_uri: "https://{domain}:{port}/fc_logout/local" + frontchannel_logout_session_required: True + backchannel_logout_uri: "https://{domain}:{port}/bc_logout/local" + backchannel_logout_session_required: True + services: + discovery: *disc + registration: *regist + authorization: *authz + accesstoken: *acctok + userinfo: *userinfo + end_session: *sess + add_ons: + pkce: + function: oidcrp.oauth2.add_on.pkce.add_support + kwargs: + code_challenge_length: 64 + code_challenge_method: S256 + + +webserver: + port: *port + domain: *domain + # If BASE is https these has to be specified + server_cert: "certs/cert.pem" + server_key: "certs/key.pem" + # If you want the clients cert to be verified + # verify_user: optional + # The you also need + # ca_bundle: '' + debug: true diff --git a/flask_rp/example_conf.yaml b/example/flask_rp/example_conf.yaml similarity index 100% rename from flask_rp/example_conf.yaml rename to example/flask_rp/example_conf.yaml diff --git a/flask_rp/run.sh b/example/flask_rp/run.sh similarity index 51% rename from flask_rp/run.sh rename to example/flask_rp/run.sh index 71414f4..6ecdd29 100755 --- a/flask_rp/run.sh +++ b/example/flask_rp/run.sh @@ -1,3 +1,3 @@ #!/usr/bin/env bash -./wsgi.py bc_conf.py \ No newline at end of file +./wsgi.py conf.json \ No newline at end of file diff --git a/flask_rp/templates/opbyuid.html b/example/flask_rp/templates/opbyuid.html similarity index 100% rename from flask_rp/templates/opbyuid.html rename to example/flask_rp/templates/opbyuid.html diff --git a/flask_rp/templates/opresult.html b/example/flask_rp/templates/opresult.html similarity index 100% rename from flask_rp/templates/opresult.html rename to example/flask_rp/templates/opresult.html diff --git a/flask_rp/templates/repost_fragment.html b/example/flask_rp/templates/repost_fragment.html similarity index 100% rename from flask_rp/templates/repost_fragment.html rename to example/flask_rp/templates/repost_fragment.html diff --git a/flask_rp/templates/rp_iframe.html b/example/flask_rp/templates/rp_iframe.html similarity index 100% rename from flask_rp/templates/rp_iframe.html rename to example/flask_rp/templates/rp_iframe.html diff --git a/flask_rp/templates/session_status.html b/example/flask_rp/templates/session_status.html similarity index 100% rename from flask_rp/templates/session_status.html rename to example/flask_rp/templates/session_status.html diff --git a/flask_rp/views.py b/example/flask_rp/views.py similarity index 86% rename from flask_rp/views.py rename to example/flask_rp/views.py index 7ac5b99..15abf2d 100644 --- a/flask_rp/views.py +++ b/example/flask_rp/views.py @@ -1,7 +1,6 @@ import logging from urllib.parse import parse_qs -import werkzeug from flask import Blueprint from flask import current_app from flask import redirect @@ -10,9 +9,10 @@ from flask import session from flask.helpers import make_response from flask.helpers import send_from_directory -from oidcservice.exception import OidcServiceError +import werkzeug -import oidcrp +from oidcrp import rp_handler +from oidcrp.exception import OidcServiceError logger = logging.getLogger(__name__) @@ -96,12 +96,13 @@ def finalize(op_hash, request_args): logger.error(rp.response[0].decode()) return rp.response[0], rp.status_code - session['client_id'] = rp.service_context.get('client_id') + _context = rp.client_get("service_context") + session['client_id'] = _context.get('client_id') session['state'] = request_args.get('state') if session['state']: - iss = rp.session_interface.get_iss(session['state']) + iss = _context.state.get_iss(session['state']) else: return make_response('Unknown state', 400) @@ -118,8 +119,9 @@ def finalize(op_hash, request_args): raise excp if 'userinfo' in res: + _context = rp.client_get("service_context") endpoints = {} - for k, v in rp.service_context.get('provider_info').items(): + for k, v in _context.provider_info.items(): if k.endswith('_endpoint'): endp = k.replace('_', ' ') endp = endp.capitalize() @@ -128,16 +130,16 @@ def finalize(op_hash, request_args): kwargs = {} # Do I support session status checking ? - _status_check_info = rp.service_context.add_on.get('status_check') + _status_check_info = _context.add_on.get('status_check') if _status_check_info: # Does the OP support session status checking ? - _chk_iframe = rp.service_context.get('provider_info').get('check_session_iframe') + _chk_iframe = _context.get('provider_info').get('check_session_iframe') if _chk_iframe: kwargs['check_session_iframe'] = _chk_iframe kwargs["status_check_iframe"] = _status_check_info['rp_iframe_path'] # Where to go if the user clicks on logout - kwargs['logout_url'] = "{}/logout".format(rp.service_context.base_url) + kwargs['logout_url'] = "{}/logout".format(_context.base_url) return render_template('opresult.html', endpoints=endpoints, userinfo=res['userinfo'], @@ -175,7 +177,8 @@ def session_iframe(): # session management logger.debug('session_iframe request_args: {}'.format(request.args)) _rp = get_rp(session['op_hash']) - session_change_url = "{}/session_change".format(_rp.service_context.base_url) + _context = _rp.client_get("service_context") + session_change_url = "{}/session_change".format(_context.base_url) _issuer = current_app.rph.hash2issuer[session['op_hash']] args = { @@ -185,7 +188,7 @@ def session_iframe(): # session management 'session_change_url': session_change_url } logger.debug('rp_iframe args: {}'.format(args)) - _template = _rp.service_context.add_on["status_check"]["session_iframe_template_file"] + _template = _context.add_on["status_check"]["session_iframe_template_file"] return render_template(_template, **args) @@ -195,7 +198,7 @@ def session_change(): _rp = get_rp(session['op_hash']) # If there is an ID token send it along as a id_token_hint - _aserv = _rp.service_context.service['authorization'] + _aserv = _rp.client_get("service", 'authorization') request_args = {"prompt": "none"} request_args = _aserv.multiple_extend_request_args( @@ -214,7 +217,7 @@ def session_change(): def session_logout(op_hash): _rp = get_rp(op_hash) logger.debug('post_logout') - return "Post logout from {}".format(_rp.service_context.get('issuer')) + return "Post logout from {}".format(_rp.client_get("service_context").issuer) # RP initiated logout @@ -230,7 +233,7 @@ def logout(): def backchannel_logout(op_hash): _rp = get_rp(op_hash) try: - _state = oidcrp.backchannel_logout(request.data, _rp) + _state = rp_handler.backchannel_logout(_rp, request.data) except Exception as err: logger.error('Exception: {}'.format(err)) return 'System error!', 400 @@ -244,7 +247,7 @@ def frontchannel_logout(op_hash): _rp = get_rp(op_hash) sid = request.args['sid'] _iss = request.args['iss'] - if _iss != _rp.service_context.get('issuer'): + if _iss != _rp.client_get("service_context").get('issuer'): return 'Bad request', 400 _state = _rp.session_interface.get_state_by_sid(sid) _rp.session_interface.remove_state(_state) diff --git a/example/flask_rp/wsgi.py b/example/flask_rp/wsgi.py new file mode 100755 index 0000000..be25906 --- /dev/null +++ b/example/flask_rp/wsgi.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +import os +import sys + +from oidcrp.configure import Configuration +from oidcrp.configure import RPConfiguration +from oidcrp.configure import create_from_config_file +from oidcrp.util import create_context + +try: + from . import application +except ImportError: + import application + +dir_path = os.path.dirname(os.path.realpath(__file__)) + +if __name__ == "__main__": + conf = sys.argv[1] + name = 'oidc_rp' + template_dir = os.path.join(dir_path, 'templates') + + _config = create_from_config_file(Configuration, + entity_conf=[{"class": RPConfiguration, "attr": "rp"}], + filename=conf) + + app = application.oidc_provider_init_app(_config.rp, name, template_folder=template_dir) + _web_conf = _config.web_conf + context = create_context(dir_path, _web_conf) + + debug = _web_conf.get('debug', True) + app.run(host=_web_conf["domain"], port=_web_conf["port"], + debug=_web_conf.get("debug", False), ssl_context=context) diff --git a/flask_rp/wsgi.py b/flask_rp/wsgi.py deleted file mode 100755 index 8c0de63..0000000 --- a/flask_rp/wsgi.py +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env python3 - -import logging -import os -import sys - -from oidcrp.util import create_context - -try: - from . import application -except ImportError: - import application - -# logger = logging.getLogger("") -# RP_LOGFILE_NAME = os.environ.get('RP_LOGFILE_NAME', 'flrp.log') -# -# hdlr = logging.FileHandler(RP_LOGFILE_NAME) -# log_format = ("%(asctime)s %(name)s:%(levelname)s " -# "%(message)s [%(name)s.%(funcName)s:%(lineno)s]") -# base_formatter = logging.Formatter(log_format) -# -# hdlr.setFormatter(base_formatter) -# logger.addHandler(hdlr) -# logger.setLevel(logging.DEBUG) -# -# stdout = logging.StreamHandler() -# stdout.setFormatter(base_formatter) -# logger.addHandler(stdout) - -dir_path = os.path.dirname(os.path.realpath(__file__)) - -if __name__ == "__main__": - conf = sys.argv[1] - name = 'oidc_rp' - template_dir = os.path.join(dir_path, 'templates') - app = application.oidc_provider_init_app(conf, name, - template_folder=template_dir) - _web_conf = app.rp_config.web_conf - context = create_context(dir_path, _web_conf) - - debug = _web_conf.get('debug', True) - app.run(host=app.rp_config.domain, port=app.rp_config.port, - debug=_web_conf.get("debug", False), ssl_context=context) diff --git a/setup.py b/setup.py index 5c5432c..b79b836 100755 --- a/setup.py +++ b/setup.py @@ -55,7 +55,8 @@ def run_tests(self): author_email="roland@catalogix.se", license="Apache 2.0", url='https://github.com/IdentityPython/oicrp/', - packages=["oidcrp", "oidcrp/provider", "oidcrp/oidc", "oidcrp/oauth2"], + packages=["oidcrp", "oidcrp/provider", "oidcrp/oidc", "oidcrp/oauth2", + "oidcrp/oauth2/add_on", "oidcrp/oauth2/client_credentials"], package_dir={"": "src"}, classifiers=[ "Development Status :: 4 - Beta", @@ -66,14 +67,13 @@ def run_tests(self): "Programming Language :: Python :: 3.9", "Topic :: Software Development :: Libraries :: Python Modules"], install_requires=[ - 'oidcservice>=1.1.0', - 'oidcmsg>=1.1.3', + 'oidcmsg>=1.3.0', 'pyyaml', 'responses' ], tests_require=[ 'pytest', - 'pytest-localserver', + 'pytest-localserver' ], zip_safe=False, cmdclass={'test': PyTest}, diff --git a/src/oidcrp/__init__.py b/src/oidcrp/__init__.py index 8886208..be0e9d9 100644 --- a/src/oidcrp/__init__.py +++ b/src/oidcrp/__init__.py @@ -1,1025 +1,11 @@ -import hashlib import logging -import sys -import traceback - -from cryptojwt.key_bundle import keybundle_from_local_file -from cryptojwt.key_jar import init_key_jar -from cryptojwt.utils import as_bytes -from cryptojwt.utils import as_unicode -from oidcmsg.exception import MessageException -from oidcmsg.exception import NotForMe -from oidcmsg.oauth2 import ResponseMessage -from oidcmsg.oauth2 import is_error_message -from oidcmsg.oidc import AccessTokenResponse -from oidcmsg.oidc import AuthorizationRequest -from oidcmsg.oidc import AuthorizationResponse -from oidcmsg.oidc import Claims -from oidcmsg.oidc import OpenIDSchema -from oidcmsg.oidc import verified_claim_name -from oidcmsg.oidc.session import BackChannelLogoutRequest -from oidcmsg.time_util import time_sans_frac -from oidcservice import rndstr -from oidcservice.exception import OidcServiceError -from oidcservice.state_interface import InMemoryStateDataBase - -from oidcrp import oauth2 -from oidcrp import oidc -from oidcrp import provider __author__ = 'Roland Hedberg' -__version__ = '0.8.1' +__version__ = '2.0.0' logger = logging.getLogger(__name__) SUCCESSFUL = [200, 201, 202, 203, 204, 205, 206] -class HandlerError(Exception): - pass - - -class ConfigurationError(Exception): - pass - - -class HttpError(OidcServiceError): - pass - - -def token_secret_key(sid): - return "token_secret_%s" % sid - - -SERVICE_NAME = "OIC" -CLIENT_CONFIG = {} - -DEFAULT_SEVICES = { - 'web_finger': {'class': 'oidcservice.oidc.webfinger.WebFinger'}, - 'discovery': {'class': 'oidcservice.oidc.provider_info_discovery.ProviderInfoDiscovery'}, - 'registration': {'class': 'oidcservice.oidc.registration.Registration'}, - 'authorization': {'class': 'oidcservice.oidc.authorization.Authorization'}, - 'access_token': {'class': 'oidcservice.oidc.access_token.AccessToken'}, - 'refresh_access_token': {'class': 'oidcservice.oidc.refresh_access_token.RefreshAccessToken'}, - 'userinfo': {'class': 'oidcservice.oidc.userinfo.UserInfo'} -} - -DEFAULT_CLIENT_PREFS = { - 'application_type': 'web', - 'application_name': 'rphandler', - 'response_types': ['code', 'id_token', 'id_token token', 'code id_token', 'code id_token token', - 'code token'], - 'scope': ['openid'], - 'token_endpoint_auth_method': 'client_secret_basic' -} - -# Using PKCE is default -DEFAULT_CLIENT_CONFIGS = { - "": { - "client_preferences": DEFAULT_CLIENT_PREFS, - "add_ons": { - "pkce": { - "function": "oidcservice.oidc.add_on.pkce.add_pkce_support", - "kwargs": { - "code_challenge_length": 64, - "code_challenge_method": "S256" - } - } - } - } -} - -DEFAULT_KEY_DEFS = [ - {"type": "RSA", "use": ["sig"]}, - {"type": "EC", "crv": "P-256", "use": ["sig"]}, -] - -DEFAULT_RP_KEY_DEFS = { - 'private_path': 'private/jwks.json', - 'key_defs': DEFAULT_KEY_DEFS, - 'public_path': 'static/jwks.json', - 'read_only': False -} - - -def add_path(url, path): - if url.endswith('/'): - if path.startswith('/'): - return '{}{}'.format(url, path[1:]) - else: - return '{}{}'.format(url, path) - else: - if path.startswith('/'): - return '{}{}'.format(url, path) - else: - return '{}/{}'.format(url, path) - - -def load_registration_response(client): - """ - If the client has been statically registered that information - must be provided during the configuration. If expected to be - done dynamically. This method will do dynamic client registration. - - :param client: A :py:class:`oidcservice.oidc.Client` instance - """ - if not client.service_context.get('client_id'): - try: - response = client.do_request('registration') - except KeyError: - raise ConfigurationError('No registration info') - except Exception as err: - logger.error(err) - raise - else: - if 'error' in response: - raise OidcServiceError(response.to_json()) - - -def dynamic_provider_info_discovery(client): - """ - This is about performing dynamic Provider Info discovery - - :param client: A :py:class:`oidcservice.oidc.Client` instance - """ - try: - client.service['provider_info'] - except KeyError: - raise ConfigurationError( - 'Can not do dynamic provider info discovery') - else: - try: - client.service_context.set( - 'issuer', - client.service_context.config['srv_discovery_url']) - except KeyError: - pass - - response = client.do_request('provider_info') - if is_error_message(response): - raise OidcServiceError(response['error']) - - -class RPHandler(object): - def __init__(self, base_url, client_configs=None, services=None, keyjar=None, - hash_seed="", verify_ssl=True, client_authn_factory=None, - client_cls=None, state_db=None, http_lib=None, httpc_params=None, - **kwargs): - - self.base_url = base_url - if hash_seed: - self.hash_seed = as_bytes(hash_seed) - else: - self.hash_seed = as_bytes(rndstr(32)) - - _jwks_path = kwargs.get('jwks_path') - if keyjar is None: - self.keyjar = init_key_jar(**DEFAULT_RP_KEY_DEFS, issuer_id='') - self.keyjar.import_jwks_as_json(self.keyjar.export_jwks_as_json(True, ''), base_url) - if _jwks_path is None: - _jwks_path = DEFAULT_RP_KEY_DEFS['public_path'] - else: - self.keyjar = keyjar - - if _jwks_path: - self.jwks_uri = add_path(base_url, _jwks_path) - else: - self.jwks_uri = "" - if len(self.keyjar): - self.jwks = self.keyjar.export_jwks() - else: - self.jwks = {} - - if state_db: - self.state_db = state_db - else: - self.state_db = InMemoryStateDataBase() - - # self.session_interface = StateInterface(self.state_db) - - self.extra = kwargs - - self.client_cls = client_cls or oidc.RP - if services is None: - self.services = DEFAULT_SEVICES - else: - self.services = services - - self.client_authn_factory = client_authn_factory - - if client_configs is None: - self.client_configs = DEFAULT_CLIENT_CONFIGS - else: - self.client_configs = client_configs - - # keep track on which RP instance that serves with OP - self.issuer2rp = {} - self.hash2issuer = {} - self.httplib = http_lib - - if not httpc_params: - self.httpc_params = {'verify': verify_ssl} - else: - self.httpc_params = httpc_params - - if not self.keyjar.httpc_params: - self.keyjar.httpc_params = self.httpc_params - - def state2issuer(self, state): - """ - Given the state value find the Issuer ID of the OP/AS that state value - was used against. - Will raise a KeyError if the state is unknown. - - :param state: The state value - :return: An Issuer ID - """ - for _rp in self.issuer2rp.values(): - try: - _iss = _rp.session_interface.get_iss(state) - except KeyError: - continue - else: - if _iss: - return _iss - return None - - def pick_config(self, issuer): - """ - From the set of client configurations pick one based on the issuer ID. - Will raise a KeyError if issuer is unknown. - - :param issuer: Issuer ID - :return: A client configuration - """ - return self.client_configs[issuer] - - def get_session_information(self, key, client=None): - """ - This is the second of the methods users of this class should know about. - It will return the complete session information as an - :py:class:`oidcservice.state_interface.State` instance. - - :param key: The session key (state) - :return: A State instance - """ - if not client: - client = self.get_client_from_session_key(key) - - return client.session_interface.get_state(key) - - def init_client(self, issuer): - """ - Initiate a Client instance. Specifically which Client class is used - is decided by configuration. - - :param issuer: An issuer ID - :return: A Client instance - """ - try: - _cnf = self.pick_config(issuer) - except KeyError: - _cnf = self.pick_config('') - _cnf['issuer'] = issuer - - try: - _services = _cnf['services'] - except KeyError: - _services = self.services - - try: - client = self.client_cls( - client_authn_factory=self.client_authn_factory, - services=_services, config=_cnf, httplib=self.httplib, - httpc_params=self.httpc_params) - except Exception as err: - logger.error('Failed initiating client: {}'.format(err)) - message = traceback.format_exception(*sys.exc_info()) - logger.error(message) - raise - - # If non persistent - client.service_context.keyjar.load(self.keyjar.dump()) - # If persistent nothings has to be copied - - client.service_context.base_url = self.base_url - client.service_context.jwks_uri = self.jwks_uri - return client - - def do_provider_info(self, client=None, state=''): - """ - Either get the provider info from configuration or through dynamic - discovery. - - :param client: A Client instance - :param state: A key by which the state of the session can be - retrieved - :return: issuer ID - """ - - if not client: - if state: - client = self.get_client_from_session_key(state) - else: - raise ValueError('Missing state/session key') - - if not client.service_context.get('provider_info'): - dynamic_provider_info_discovery(client) - return client.service_context.get('provider_info')['issuer'] - else: - _pi = client.service_context.get('provider_info') - for key, val in _pi.items(): - # All service endpoint parameters in the provider info has - # a name ending in '_endpoint' so I can look specifically - # for those - if key.endswith("_endpoint"): - for _srv in client.service_context.service.values(): - # Every service has an endpoint_name assigned - # when initiated. This name *MUST* match the - # endpoint names used in the provider info - if _srv.endpoint_name == key: - _srv.endpoint = val - - if 'keys' in _pi: - _kj = client.service_context.keyjar - for typ, _spec in _pi['keys'].items(): - if typ == 'url': - for _iss, _url in _spec.items(): - _kj.add_url(_iss, _url) - elif typ == 'file': - for kty, _name in _spec.items(): - if kty == 'jwks': - _kj.import_jwks_from_file(_name, - client.service_context.get('issuer')) - elif kty == 'rsa': # PEM file - _kb = keybundle_from_local_file(_name, "der", ["sig"]) - _kj.add_kb(client.service_context.get('issuer'), _kb) - else: - raise ValueError('Unknown provider JWKS type: {}'.format(typ)) - try: - return client.service_context.get('provider_info')['issuer'] - except KeyError: - return client.service_context.get('issuer') - - def do_client_registration(self, client=None, iss_id='', state=''): - """ - Prepare for and do client registration if configured to do so - - :param client: A Client instance - :param state: A key by which the state of the session can be - retrieved - """ - - if not client: - if state: - client = self.get_client_from_session_key(state) - else: - raise ValueError('Missing state/session key') - - _iss = client.service_context.get('issuer') - if not client.service_context.get('redirect_uris'): - # Create the necessary callback URLs - # as a side effect self.hash2issuer is set - callbacks = self.create_callbacks(_iss) - - client.service_context.set('redirect_uris', [ - v for k, v in callbacks.items() if not k.startswith('__')]) - client.service_context.set('callbacks', callbacks) - else: - self.hash2issuer[iss_id] = _iss - - # This should only be interesting if the client supports Single Log Out - - if client.service_context.get('post_logout_redirect_uris') is not None: - client.service_context.set('post_logout_redirect_uris', [self.base_url]) - - if not client.service_context.get('client_id'): - load_registration_response(client) - - def add_callbacks(self, service_context): - _callbacks = self.create_callbacks(service_context.get('provider_info')['issuer']) - service_context.set('redirect_uris', [ - v for k, v in _callbacks.items() if not k.startswith('__')]) - service_context.set('callbacks', _callbacks) - return _callbacks - - def client_setup(self, iss_id='', user=''): - """ - First if no issuer ID is given then the identifier for the user is - used by the webfinger service to try to find the issuer ID. - Once the method has an issuer ID if no client is bound to this issuer - one is created and initiated with - the necessary information for the client to be able to communicate - with the OP/AS that has the provided issuer ID. - - :param iss_id: The issuer ID - :param user: A user identifier - :return: A :py:class:`oidcservice.oidc.Client` instance - """ - - logger.info('client_setup: iss_id={}, user={}'.format(iss_id, user)) - - if not iss_id: - if not user: - raise ValueError('Need issuer or user') - - logger.debug("Connecting to previously unknown OP") - temporary_client = self.init_client('') - temporary_client.do_request('webfinger', resource=user) - else: - temporary_client = None - - try: - client = self.issuer2rp[iss_id] - except KeyError: - if temporary_client: - client = temporary_client - else: - logger.debug("Creating new client: %s", iss_id) - client = self.init_client(iss_id) - else: - return client - - logger.debug("Get provider info") - issuer = self.do_provider_info(client) - - logger.debug("Do client registration") - self.do_client_registration(client, iss_id) - - self.issuer2rp[issuer] = client - return client - - def create_callbacks(self, issuer): - """ - To mitigate some security issues the redirect_uris should be OP/AS - specific. This method creates a set of redirect_uris unique to the - OP/AS. - - :param issuer: Issuer ID - :return: A set of redirect_uris - """ - _hash = hashlib.sha256() - _hash.update(self.hash_seed) - _hash.update(as_bytes(issuer)) - _hex = _hash.hexdigest() - self.hash2issuer[_hex] = issuer - return { - 'code': "{}/authz_cb/{}".format(self.base_url, _hex), - 'implicit': "{}/authz_im_cb/{}".format(self.base_url, _hex), - 'form_post': "{}/authz_fp_cb/{}".format(self.base_url, _hex), - '__hex': _hex - } - - def init_authorization(self, client=None, state='', req_args=None): - """ - Constructs the URL that will redirect the user to the authorization - endpoint of the OP/AS. - - :param client: A Client instance - :param req_args: Non-default Request arguments - :return: A dictionary with 2 keys: **url** The authorization redirect - URL and **state** the key to the session information in the - state data store. - """ - if not client: - if state: - client = self.get_client_from_session_key(state) - else: - raise ValueError('Missing state/session key') - - service_context = client.service_context - - _nonce = rndstr(24) - request_args = { - 'redirect_uri': service_context.get('redirect_uris')[0], - 'scope': service_context.get('behaviour')['scope'], - 'response_type': service_context.get('behaviour')['response_types'][0], - 'nonce': _nonce - } - - _req_args = service_context.config.get("request_args") - if _req_args: - if 'claims' in _req_args: - _req_args["claims"] = Claims(**_req_args["claims"]) - request_args.update(_req_args) - - if req_args is not None: - request_args.update(req_args) - - # Need a new state for a new authorization request - _state = client.session_interface.create_state(service_context.get('issuer')) - request_args['state'] = _state - client.session_interface.store_nonce2state(_nonce, _state) - - logger.debug('Authorization request args: {}'.format(request_args)) - - _srv = client.service['authorization'] - _info = _srv.get_request_parameters(request_args=request_args) - logger.debug('Authorization info: {}'.format(_info)) - return {'url': _info['url'], 'state': _state} - - def begin(self, issuer_id='', user_id=''): - """ - This is the first of the 3 high level methods that most users of this - library should confine them self to use. - If will use client_setup to produce a Client instance ready to be used - against the OP/AS the user wants to use. - Once it has the client it will construct an Authorization - request. - - :param issuer_id: Issuer ID - :param user_id: A user identifier - :return: A dictionary containing **url** the URL that will redirect the - user to the OP/AS and **state** the session key which will - allow higher level code to access session information. - """ - - # Get the client instance that has been assigned to this issuer - client = self.client_setup(issuer_id, user_id) - - try: - res = self.init_authorization(client) - except Exception: - message = traceback.format_exception(*sys.exc_info()) - logger.error(message) - raise - else: - return res - - # ---------------------------------------------------------------------- - - def get_client_from_session_key(self, state): - return self.issuer2rp[self.state2issuer(state)] - - @staticmethod - def get_response_type(client): - """ - Return the response_type a specific client wants to use. - - :param client: A Client instance - :return: The response_type - """ - return client.service_context.get('behaviour')['response_types'][0] - - @staticmethod - def get_client_authn_method(client, endpoint): - """ - Return the client authentication method a client wants to use a - specific endpoint - - :param client: A Client instance - :param endpoint: The endpoint at which the client has to authenticate - :return: The client authentication method - """ - if endpoint == 'token_endpoint': - try: - am = client.service_context.get('behaviour')['token_endpoint_auth_method'] - except KeyError: - return '' - else: - if isinstance(am, str): - return am - else: # a list - return am[0] - - def get_access_token(self, state, client=None): - """ - Use the 'accesstoken' service to get an access token from the OP/AS. - - :param state: The state key (the state parameter in the - authorization request) - :param client: A Client instance - :return: A :py:class:`oidcmsg.oidc.AccessTokenResponse` or - :py:class:`oidcmsg.oauth2.AuthorizationResponse` - """ - logger.debug('get_accesstoken') - - if client is None: - client = self.get_client_from_session_key(state) - - authorization_response = client.session_interface.get_item( - AuthorizationResponse, 'auth_response', state) - authorization_request = client.session_interface.get_item( - AuthorizationRequest, 'auth_request', state) - - req_args = { - 'code': authorization_response['code'], - 'state': state, - 'redirect_uri': authorization_request['redirect_uri'], - 'grant_type': 'authorization_code', - 'client_id': client.service_context.get('client_id'), - 'client_secret': client.service_context.get('client_secret') - } - logger.debug('request_args: {}'.format(req_args)) - try: - tokenresp = client.do_request( - 'accesstoken', request_args=req_args, - authn_method=self.get_client_authn_method(client, - "token_endpoint"), - state=state - ) - except Exception as err: - message = traceback.format_exception(*sys.exc_info()) - logger.error(message) - raise - else: - if is_error_message(tokenresp): - raise OidcServiceError(tokenresp['error']) - - return tokenresp - - def refresh_access_token(self, state, client=None, scope=''): - """ - Refresh an access token using a refresh_token. When asking for a new - access token the RP can ask for another scope for the new token. - - :param client: A Client instance - :param state: The state key (the state parameter in the - authorization request) - :param scope: What the returned token should be valid for. - :return: A :py:class:`oidcmsg.oidc.AccessTokenResponse` instance - """ - if scope: - req_args = {'scope': scope} - else: - req_args = {} - - if client is None: - client = self.get_client_from_session_key(state) - - try: - tokenresp = client.do_request( - 'refresh_token', - authn_method=self.get_client_authn_method(client, - "token_endpoint"), - state=state, request_args=req_args - ) - except Exception as err: - message = traceback.format_exception(*sys.exc_info()) - logger.error(message) - raise - else: - if is_error_message(tokenresp): - raise OidcServiceError(tokenresp['error']) - - return tokenresp - - def get_user_info(self, state, client=None, access_token='', - **kwargs): - """ - use the access token previously acquired to get some userinfo - - :param client: A Client instance - :param state: The state value, this is the key into the session - data store - :param access_token: An access token - :param kwargs: Extra keyword arguments - :return: A :py:class:`oidcmsg.oidc.OpenIDSchema` instance - """ - if client is None: - client = self.get_client_from_session_key(state) - - if not access_token: - _arg = client.session_interface.multiple_extend_request_args( - {}, state, ['access_token'], - ['auth_response', 'token_response', 'refresh_token_response']) - - request_args = {'access_token': access_token} - - resp = client.do_request('userinfo', state=state, - request_args=request_args, **kwargs) - if is_error_message(resp): - raise OidcServiceError(resp['error']) - - return resp - - @staticmethod - def userinfo_in_id_token(id_token): - """ - Given an verified ID token return all the claims that may been user - information. - - :param id_token: An :py:class:`oidcmsg.oidc.IDToken` instance - :return: A dictionary with user information - """ - res = dict([(k, id_token[k]) for k in OpenIDSchema.c_param.keys() if - k in id_token]) - res.update(id_token.extra()) - return res - - def finalize_auth(self, client, issuer, response): - """ - Given the response returned to the redirect_uri, parse and verify it. - - :param client: A Client instance - :param issuer: An Issuer ID - :param response: The authorization response as a dictionary - :return: An :py:class:`oidcmsg.oidc.AuthorizationResponse` or - :py:class:`oidcmsg.oauth2.AuthorizationResponse` instance. - """ - _srv = client.service['authorization'] - try: - authorization_response = _srv.parse_response(response, - sformat='dict') - except Exception as err: - logger.error('Parsing authorization_response: {}'.format(err)) - message = traceback.format_exception(*sys.exc_info()) - logger.error(message) - raise - else: - logger.debug( - 'Authz response: {}'.format(authorization_response.to_dict())) - - if is_error_message(authorization_response): - return authorization_response - - try: - _iss = client.session_interface.get_iss(authorization_response['state']) - except KeyError: - raise KeyError('Unknown state value') - - if _iss != issuer: - logger.error('Issuer problem: {} != {}'.format(_iss, issuer)) - # got it from the wrong bloke - raise ValueError('Impersonator {}'.format(issuer)) - - _srv.update_service_context(authorization_response, key=authorization_response['state']) - client.session_interface.store_item(authorization_response, "auth_response", - authorization_response['state']) - return authorization_response - - def get_access_and_id_token(self, authorization_response=None, state='', - client=None): - """ - There are a number of services where access tokens and ID tokens can - occur in the response. This method goes through the possible places - based on the response_type the client uses. - - :param authorization_response: The Authorization response - :param state: The state key (the state parameter in the - authorization request) - :return: A dictionary with 2 keys: **access_token** with the access - token as value and **id_token** with a verified ID Token if one - was returned otherwise None. - """ - - if client is None: - client = self.get_client_from_session_key(state) - - if authorization_response is None: - if state: - authorization_response = client.session_interface.get_item( - AuthorizationResponse, 'auth_response', state) - else: - raise ValueError( - 'One of authorization_response or state must be provided') - - if not state: - state = authorization_response['state'] - - authreq = client.session_interface.get_item( - AuthorizationRequest, 'auth_request', state) - _resp_type = set(authreq['response_type']) - - access_token = None - id_token = None - if _resp_type in [{'id_token'}, {'id_token', 'token'}, - {'code', 'id_token', 'token'}]: - id_token = authorization_response['__verified_id_token'] - - if _resp_type in [{'token'}, {'id_token', 'token'}, {'code', 'token'}, - {'code', 'id_token', 'token'}]: - access_token = authorization_response["access_token"] - elif _resp_type in [{'code'}, {'code', 'id_token'}]: - - # get the access token - token_resp = self.get_access_token(state, client=client) - if is_error_message(token_resp): - return False, "Invalid response %s." % token_resp["error"] - - access_token = token_resp["access_token"] - - try: - id_token = token_resp['__verified_id_token'] - except KeyError: - pass - - return {'access_token': access_token, 'id_token': id_token} - - # noinspection PyUnusedLocal - def finalize(self, issuer, response): - """ - The third of the high level methods that a user of this Class should - know about. - Once the consumer has redirected the user back to the - callback URL there might be a number of services that the client should - use. Which one those are are defined by the client configuration. - - :param issuer: Who sent the response - :param response: The Authorization response as a dictionary - :returns: A dictionary with two claims: - **state** The key under which the session information is - stored in the data store and - **error** and encountered error or - **userinfo** The collected user information - """ - - client = self.issuer2rp[issuer] - - authorization_response = self.finalize_auth(client, issuer, response) - if is_error_message(authorization_response): - return { - 'state': authorization_response['state'], - 'error': authorization_response['error'] - } - - _state = authorization_response['state'] - token = self.get_access_and_id_token(authorization_response, - state=_state, client=client) - - if 'userinfo' in client.service and token['access_token']: - inforesp = self.get_user_info( - state=authorization_response['state'], client=client, - access_token=token['access_token']) - - if isinstance(inforesp, ResponseMessage) and 'error' in inforesp: - return { - 'error': "Invalid response %s." % inforesp["error"], - 'state': _state - } - - elif token['id_token']: # look for it in the ID Token - inforesp = self.userinfo_in_id_token(token['id_token']) - else: - inforesp = {} - - logger.debug("UserInfo: %s", inforesp) - - try: - _sid_support = client.service_context.get('provider_info')[ - 'backchannel_logout_session_supported'] - except KeyError: - try: - _sid_support = client.service_context.get('provider_info')[ - 'frontchannel_logout_session_supported'] - except: - _sid_support = False - - if _sid_support: - try: - sid = token['id_token']['sid'] - except KeyError: - pass - else: - client.session_interface.store_sid2state(sid, _state) - - client.session_interface.store_sub2state(token['id_token']['sub'], _state) - - return { - 'userinfo': inforesp, - 'state': authorization_response['state'], - 'token': token['access_token'], - 'id_token': token['id_token'] - } - - def has_active_authentication(self, state): - """ - Find out if the user has an active authentication - - :param state: - :return: True/False - """ - - client = self.get_client_from_session_key(state) - - # Look for Id Token in all the places where it can be - _arg = client.session_interface.multiple_extend_request_args( - {}, state, ['__verified_id_token'], - ['auth_response', 'token_response', 'refresh_token_response']) - - if _arg: - _now = time_sans_frac() - exp = _arg['__verified_id_token']['exp'] - return _now < exp - else: - return False - - def get_valid_access_token(self, state): - """ - Find a valid access token. - - :param state: - :return: An access token if a valid one exists and when it - expires. Otherwise raise exception. - """ - - exp = 0 - token = None - indefinite = [] - now = time_sans_frac() - - client = self.get_client_from_session_key(state) - - for cls, typ in [(AccessTokenResponse, 'refresh_token_response'), - (AccessTokenResponse, 'token_response'), - (AuthorizationResponse, 'auth_response')]: - try: - response = client.session_interface.get_item(cls, typ, state) - except KeyError: - pass - else: - if 'access_token' in response: - access_token = response["access_token"] - try: - _exp = response['__expires_at'] - except KeyError: # No expiry date, lives for ever - indefinite.append((access_token, 0)) - else: - if _exp > now and _exp > exp: # expires sometime in the future - exp = _exp - token = (access_token, _exp) - - if indefinite: - return indefinite[0] - else: - if token: - return token - else: - raise OidcServiceError('No valid access token') - - def logout(self, state, client=None, post_logout_redirect_uri=''): - """ - Does a RP initiated logout from an OP. After logout the user will be - redirect by the OP to a URL of choice (post_logout_redirect_uri). - - :param state: Key to an active session - :param client: Which client to use - :param post_logout_redirect_uri: If a special post_logout_redirect_uri - should be used - :return: A US - """ - if client is None: - client = self.get_client_from_session_key(state) - - try: - srv = client.service['end_session'] - except KeyError: - raise OidcServiceError("Does not know how to logout") - - if post_logout_redirect_uri: - request_args = { - "post_logout_redirect_uri": post_logout_redirect_uri - } - else: - request_args = {} - - resp = srv.get_request_parameters(state=state, - request_args=request_args) - - return resp - - def clear_session(self, state): - client = self.get_client_from_session_key(state) - client.session_interface.remove_state(state) - - -def backchannel_logout(client, request='', request_args=None): - """ - - :param request: URL encoded logout request - :return: - """ - - if request: - req = BackChannelLogoutRequest().from_urlencoded(as_unicode(request)) - else: - req = BackChannelLogoutRequest(**request_args) - - kwargs = { - 'aud': client.service_context.get('client_id'), - 'iss': client.service_context.get('issuer'), - 'keyjar': client.service_context.keyjar, - 'allowed_sign_alg': client.service_context.get('registration_response').get( - "id_token_signed_response_alg", "RS256") - } - - try: - req.verify(**kwargs) - except (MessageException, ValueError, NotForMe) as err: - raise MessageException('Bogus logout request: {}'.format(err)) - - # Find the subject through 'sid' or 'sub' - - try: - sub = req[verified_claim_name('logout_token')]['sub'] - except KeyError: - try: - sid = req[verified_claim_name('logout_token')]['sid'] - except KeyError: - raise MessageException('Neither "sid" nor "sub"') - else: - _state = client.session_interface.get_state_by_sid(sid) - else: - _state = client.session_interface.get_state_by_sub(sub) - return _state diff --git a/src/oidcrp/client_auth.py b/src/oidcrp/client_auth.py new file mode 100755 index 0000000..f585ec7 --- /dev/null +++ b/src/oidcrp/client_auth.py @@ -0,0 +1,621 @@ +"""Implementation of a number of client authentication methods.""" +import base64 +import logging +from urllib.parse import quote_plus + +from cryptojwt.exception import MissingKey +from cryptojwt.exception import UnsupportedAlgorithm +from cryptojwt.jws.jws import SIGNER_ALGS +from cryptojwt.jws.utils import alg2keytype +from oidcmsg.message import VREQUIRED +from oidcmsg.oauth2 import AccessTokenRequest +from oidcmsg.oauth2 import SINGLE_OPTIONAL_STRING +from oidcmsg.oidc import AuthnToken +from oidcmsg.time_util import utc_time_sans_frac + +from oidcrp.util import rndstr +from oidcrp.util import sanitize +from .defaults import DEF_SIGN_ALG +from .defaults import JWT_BEARER + +LOGGER = logging.getLogger(__name__) + +__author__ = 'roland hedberg' + + +class AuthnFailure(Exception): + """Unspecified Authentication failure""" + + +class UnknownAuthnMethod(Exception): + """Unknown Authentication method.""" + + +# ======================================================================== +def assertion_jwt(client_id, keys, audience, algorithm, lifetime=600): + """ + Create a signed Json Web Token containing some information. + + :param client_id: The Client ID + :param keys: Signing keys + :param audience: Who is the receivers for this assertion + :param algorithm: Signing algorithm + :param lifetime: The lifetime of the signed Json Web Token + :return: A Signed Json Web Token + """ + _now = utc_time_sans_frac() + + _token = AuthnToken(iss=client_id, sub=client_id, + aud=audience, jti=rndstr(32), + exp=_now + lifetime, iat=_now) + LOGGER.debug('AuthnToken: %s', _token.to_dict()) + return _token.to_jwt(key=keys, algorithm=algorithm) + + +class ClientAuthnMethod: + """ + Basic Client Authentication Method class. + Only has one public method: *construct* + """ + + def construct(self, request, service=None, http_args=None, **kwargs): + """ Add authentication information to a request""" + raise NotImplementedError() + + def modify_request(self, request, service, **kwargs): + """ + Modify the request if necessary. + + :param request: The request + :param service: The service using this authentication method. + """ + + +class ClientSecretBasic(ClientAuthnMethod): + """ + Clients that have received a client_secret value from the Authorization + Server, may authenticate with the Authorization Server in accordance with + Section 3.2.1 of OAuth 2.0 [RFC6749] using HTTP Basic authentication scheme. + + The upshot of this is to construct an Authorization header that has the + value 'Basic ' where is username and password concatenated + together with a ':' in between and then URL safe base64 encoded. + + Note that both username and password + """ + + @staticmethod + def _get_passwd(request, service, **kwargs): + try: + passwd = kwargs["password"] + except KeyError: + try: + passwd = request["client_secret"] + except KeyError: + passwd = service.client_get("service_context").client_secret + return passwd + + @staticmethod + def _get_user(service, **kwargs): + try: + user = kwargs["user"] + except KeyError: + user = service.client_get("service_context").client_id + return user + + def _get_authentication_token(self, request, service, **kwargs): + """ + Return authentication Token. + + The credential is username and password concatenated with a ':' + in between and then base 64 encoded becomes the authentication token. + :param request: The request + :param service: A :py:class:`oidcrp.service.Service` instance + :param kwargs: Extra key word arguments + :return: An authentication token + """ + passwd = self._get_passwd(request, service, **kwargs) + user = self._get_user(service, **kwargs) + + credentials = "{}:{}".format(quote_plus(user), quote_plus(passwd)) + return base64.urlsafe_b64encode(credentials.encode("utf-8")).decode("utf-8") + + @staticmethod + def _with_or_without_client_id(request, service): + """ Add or delete client_id from request. + + If we're doing an access token request with an authorization code + then we should add client_id to the request if it's not already there. + :param request: A request + :param service: A :py:class:`oidcrp.service.Service` instance + """ + if isinstance(request, AccessTokenRequest) and request[ + 'grant_type'] == 'authorization_code': + if 'client_id' not in request: + try: + request['client_id'] = service.client_get("service_context").client_id + except AttributeError: + pass + else: + # remove client_id if not required by the request definition + try: + _req = request.c_param["client_id"][VREQUIRED] + except (KeyError, AttributeError): + _req = False + + # if it's not required remove it + if not _req: + try: + del request["client_id"] + except KeyError: + pass + + def modify_request(self, request, service, **kwargs): + """ + Modify the request if necessary. + + :param request: The request + :param service: The service using this authentication method. + """ + # If client_secret was part of the request message instance remove it + try: + del request["client_secret"] + except (KeyError, TypeError): + pass + + # Modifies the request + self._with_or_without_client_id(request, service) + + def construct(self, request, service=None, http_args=None, **kwargs): + """ + Construct a dictionary to be added to the HTTP request headers + + :param request: The request + :param service: A :py:class:`oidcrp.service.Service` instance + :param http_args: HTTP arguments + :return: dictionary of HTTP arguments + """ + + if http_args is None: + http_args = {} + + if "headers" not in http_args: + http_args["headers"] = {} + + _token = self._get_authentication_token(request, service, **kwargs) + + http_args["headers"]["Authorization"] = "Basic {}".format(_token) + + self.modify_request(request, service) + + return http_args + + +class ClientSecretPost(ClientSecretBasic): + """ + Clients that have received a client_secret value from the Authorization + Server, authenticate with the Authorization Server in accordance with + Section 3.2.1 of OAuth 2.0 [RFC6749] by including the Client Credentials in + the request body. + + These means putting both client_secret and client_id in the request body. + """ + + def modify_request(self, request, service, **kwargs): + """ + I MUST have a client_secret, there are 3 possible places + where I can find it. In the request, as an argument in http_args + or among the client information. + + :param request: The request + :param service: The service that is using this authentication method + """ + _context = service.client_get("service_context") + if "client_secret" not in request: + try: + request["client_secret"] = kwargs["client_secret"] + except (KeyError, TypeError): + if _context.client_secret: + request["client_secret"] = _context.client_secret + else: + raise AuthnFailure("Missing client secret") + + # Set the client_id in the the request + request["client_id"] = _context.client_id + + def construct(self, request, service=None, http_args=None, **kwargs): + """ + Does not add any authentication information to the HTTP arguments. + Adds authentication information to the request. + + :param request: The request + :param service: The service that is using this authentication method + :param http_args: HTTP arguments + :param kwargs: Extra keyword arguments. + """ + self.modify_request(request, service, **kwargs) + return http_args + + +def find_token(request, token_type, service, **kwargs): + """ + The access token can be in a number of places. + There are priority rules as to which one to use, abide by those: + + 1 If it's among the request parameters use that + 2 If among the extra keyword arguments + 3 Acquired by a previous run service. + + :param request: + :param token_type: + :param service: + :param kwargs: + :return: + """ + if request is not None: + try: + _token = request[token_type] + except KeyError: + pass + else: + del request[token_type] + # Required under certain circumstances :-) not under other + request.c_param[token_type] = SINGLE_OPTIONAL_STRING + return _token + + try: + return kwargs["access_token"] + except KeyError: + # I should pick the latest acquired token, this should be the right + # order for that. + _arg = service.client_get("service_context").state.multiple_extend_request_args( + {}, kwargs['key'], ['access_token'], + ['auth_response', 'token_response', 'refresh_token_response']) + return _arg['access_token'] + + +class BearerHeader(ClientAuthnMethod): + """The bearer header authentication method.""" + + def construct(self, request=None, service=None, http_args=None, + **kwargs): + """ + Constructing the Authorization header. The value of + the Authorization header is "Bearer ". + + :param request: Request class instance + :param service: Service + :param http_args: HTTP header arguments + :param kwargs: extra keyword arguments + :return: + """ + + if service.service_name == 'refresh_token': + _acc_token = find_token(request, 'refresh_token', service, **kwargs) + else: + _acc_token = find_token(request, 'access_token', service, **kwargs) + + if not _acc_token: + raise KeyError('No access or refresh token available') + + # The authorization value starts with 'Bearer' when bearer tokens + # are used + _bearer = "Bearer {}".format(_acc_token) + + # Add 'Authorization' to the headers + if http_args is None: + http_args = {"headers": {}} + http_args["headers"]["Authorization"] = _bearer + else: + try: + http_args["headers"]["Authorization"] = _bearer + except KeyError: + http_args["headers"] = {"Authorization": _bearer} + + return http_args + + +class BearerBody(ClientAuthnMethod): + """The bearer body authentication method.""" + + def modify_request(self, request, service, **kwargs): + """ + Modify the request if necessary. + + :param request: The request + :param service: The service using this authentication method. + :param kwargs: Extra keyword arguments + """ + _acc_token = '' + for _token_type in ['access_token', 'refresh_token']: + _acc_token = find_token(request, _token_type, service, **kwargs) + if _acc_token: + break + + if not _acc_token: + raise KeyError('No access or refresh token available') + + request["access_token"] = _acc_token + + def construct(self, request, service=None, http_args=None, **kwargs): + """ + Will add a token to the request if not present + + :param request: The request + :param service: The service that handles these kind of things. + :param http_args: HTTP arguments + :param kwargs: extra keyword arguments + :return: A possibly modified dictionary with HTTP arguments. + """ + + self.modify_request(request, service, **kwargs) + + return http_args + + +def bearer_auth(request, authn): + """ + Pick out the access token, either in HTTP_Authorization header or + in request body. + + :param request: The request + :param authn: The value of the Authorization header + :return: An access token + """ + + try: + return request["access_token"] + except KeyError: + if not authn.startswith("Bearer "): + raise ValueError('Not a bearer token') + return authn[7:] + + +class JWSAuthnMethod(ClientAuthnMethod): + """ + Base class for client authentication methods that uses signed JSON + Web Tokens. + """ + + @staticmethod + def choose_algorithm(context, **kwargs): + """ + Pick signing algorithm + + :param context: Signing context + :param kwargs: extra keyword arguments + :return: Name of a signing algorithm + """ + try: + algorithm = kwargs["algorithm"] + except KeyError: + # different contexts uses different signing algorithms + algorithm = DEF_SIGN_ALG[context] + if not algorithm: + raise AuthnFailure("Missing algorithm specification") + return algorithm + + @staticmethod + def get_signing_key_from_keyjar(algorithm, service_context): + """ + Pick signing key based on signing algorithm to be used + + :param algorithm: Signing algorithm + :param service_context: A :py:class:`oidcrp.service_context.ServiceContext` instance + :return: A key + """ + return service_context.keyjar.get_signing_key( + alg2keytype(algorithm), alg=algorithm) + + @staticmethod + def _get_key_by_kid(kid, algorithm, service_context): + """ + Pick a key that matches a given key ID and signing algorithm. + + :param kid: Key ID + :param algorithm: Signing algorithm + :param service_context: A + :py:class:`oidcrp.service_context.ServiceContext` instance + :return: A matching key + """ + # signing so using my keys + for _key in service_context.keyjar.get_issuer_keys(""): + if kid == _key.kid: + ktype = alg2keytype(algorithm) + if _key.kty != ktype: + raise MissingKey("Wrong key type") + + return _key + + raise MissingKey("No key with kid:%s" % kid) + + def _get_signing_key(self, algorithm, context, kid=None): + ktype = alg2keytype(algorithm) + try: + if kid: + signing_key = [self._get_key_by_kid(kid, algorithm, context)] + elif ktype in context.kid["sig"]: + try: + signing_key = [self._get_key_by_kid( + context.kid["sig"][ktype], algorithm, context)] + except KeyError: + signing_key = self.get_signing_key_from_keyjar(algorithm, context) + else: + signing_key = self.get_signing_key_from_keyjar(algorithm, context) + except (MissingKey,) as err: + LOGGER.error("%s", sanitize(err)) + raise + + return signing_key + + def _get_audience_and_algorithm(self, context, **kwargs): + algorithm = None + + # audience for the signed JWT depends on which endpoint + # we're talking to. + if 'authn_endpoint' in kwargs and kwargs['authn_endpoint'] in ['token_endpoint']: + reg_resp = context.registration_response + if reg_resp: + algorithm = reg_resp["token_endpoint_auth_signing_alg"] + else: + algorithm = context.client_preferences.get("token_endpoint_auth_signing_alg") + if algorithm is None: + _pi = context.provider_info + try: + algs = _pi["token_endpoint_auth_signing_alg_values_supported"] + except KeyError: + algorithm = "RS256" # default + else: + for alg in algs: # pick the first one I support and have keys for + if alg in SIGNER_ALGS and self.get_signing_key_from_keyjar(alg, + context): + algorithm = alg + break + + audience = context.provider_info['token_endpoint'] + else: + audience = context.provider_info['issuer'] + + if not algorithm: + algorithm = self.choose_algorithm(**kwargs) + return audience, algorithm + + def _construct_client_assertion(self, service, **kwargs): + _context = service.client_get("service_context") + + audience, algorithm = self._get_audience_and_algorithm(_context, **kwargs) + + if 'kid' in kwargs: + signing_key = self._get_signing_key(algorithm, _context, kid=kwargs['kid']) + else: + signing_key = self._get_signing_key(algorithm, _context) + + if not signing_key: + raise UnsupportedAlgorithm(algorithm) + + try: + _args = {'lifetime': kwargs['lifetime']} + except KeyError: + _args = {} + + # construct the signed JWT with the assertions and add + # it as value to the 'client_assertion' claim of the request + return assertion_jwt(_context.client_id, signing_key, audience, algorithm, **_args) + + def modify_request(self, request, service, **kwargs): + """ + Modify the request if necessary. + + :param request: The request + :param service: The service using this authentication method. + :param kwargs: Extra keyword arguments + """ + if 'client_assertion' in kwargs: + request["client_assertion"] = kwargs['client_assertion'] + if 'client_assertion_type' in kwargs: + request[ + 'client_assertion_type'] = kwargs['client_assertion_type'] + else: + request["client_assertion_type"] = JWT_BEARER + elif 'client_assertion' in request: + if 'client_assertion_type' not in request: + request["client_assertion_type"] = JWT_BEARER + else: + request["client_assertion"] = self._construct_client_assertion(service, **kwargs) + request["client_assertion_type"] = JWT_BEARER + + try: + del request["client_secret"] + except KeyError: + pass + + # If client_id is not required to be present, remove it. + if not request.c_param["client_id"][VREQUIRED]: + try: + del request["client_id"] + except KeyError: + pass + + def construct(self, request, service=None, http_args=None, **kwargs): + """ + Constructs a client assertion and signs it with a key. + The request is modified as a side effect. + + :param request: The request + :param service: A :py:class:`oidcrp.service.Service` instance + :param http_args: HTTP arguments + :param kwargs: Extra arguments + :return: Constructed HTTP arguments, in this case none + """ + self.modify_request(request, service, **kwargs) + + return {} + + +class ClientSecretJWT(JWSAuthnMethod): + """ + Clients that have received a client_secret value from the Authorization + Server can create a signed JWT using an HMAC SHA algorithm, such as + HMAC SHA-256. + The HMAC (Hash-based Message Authentication Code) is calculated using the + bytes of the UTF-8 representation of the client_secret as the shared key. + """ + + def choose_algorithm(self, context="client_secret_jwt", **kwargs): + return JWSAuthnMethod.choose_algorithm(context, **kwargs) + + def get_signing_key_from_keyjar(self, algorithm, service_context): + return service_context.keyjar.get_signing_key(alg2keytype(algorithm), alg=algorithm) + + +class PrivateKeyJWT(JWSAuthnMethod): + """ + Clients that have registered a public key can sign a JWT using that key. + """ + + def choose_algorithm(self, context="private_key_jwt", **kwargs): + return JWSAuthnMethod.choose_algorithm(context, **kwargs) + + def get_signing_key_from_keyjar(self, algorithm, service_context=None): + return service_context.keyjar.get_signing_key(alg2keytype(algorithm), "", alg=algorithm) + + +# Map from client authentication identifiers to corresponding class +CLIENT_AUTHN_METHOD = { + "client_secret_basic": ClientSecretBasic, + "client_secret_post": ClientSecretPost, + "bearer_header": BearerHeader, + "bearer_body": BearerBody, + "client_secret_jwt": ClientSecretJWT, + "private_key_jwt": PrivateKeyJWT, +} + +TYPE_METHOD = [(JWT_BEARER, JWSAuthnMethod)] + + +def valid_service_context(service_context, when=0): + """ + Check if the client_secret has expired + + :param service_context: A + :py:class:`oidcrp.service_context.ServiceContext` instance + :param when: A time stamp against which the expiration time is to be checked + :return: True if the client_secret is still valid + """ + eta = service_context.client_secret_expires_at + now = when or utc_time_sans_frac() + if eta != 0 and eta < now: + return False + return True + + +def factory(auth_method): + """Return an instance of a client authentication class. + + :param auth_method: The name of the client authentication method + """ + try: + return CLIENT_AUTHN_METHOD[auth_method]() + except KeyError: + LOGGER.error('Unknown client authentication method: %s', auth_method) + raise ValueError(auth_method) diff --git a/src/oidcrp/configure.py b/src/oidcrp/configure.py index 8b89a39..d38cc76 100755 --- a/src/oidcrp/configure.py +++ b/src/oidcrp/configure.py @@ -1,112 +1,205 @@ """Configuration management for RP""" +import importlib +import json import logging +import os from typing import Dict +from typing import List from typing import Optional -from oidcmsg import add_base_path - from oidcrp.logging import configure_logging -from oidcrp.util import get_http_params from oidcrp.util import load_yaml_config from oidcrp.util import lower_or_upper -from oidcrp.util import replace -from oidcrp.util import set_param try: from secrets import token_urlsafe as rnd_token except ImportError: from oidcendpoint import rndstr as rnd_token -DEFAULT_ITEM_PATHS = { - "webserver": ['server_key', 'server_cert'], - "rp_keys": ["public_path", "private_path"], - "oidc_keys": ["public_path", "private_path"], - "httpc_params": ["client_cert", "client_key"], - "db_conf": { - "keyjar": ["fdir"], - "default": ["fdir"], - "state": ["fdir"] - }, - "logging": { - "handlers": { - "file": ["filename"] - } - } -} - - -class Configuration: - """RP Configuration""" +DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename', 'template_dir', + 'private_path', 'public_path', 'db_file'] + + +def add_base_path(conf: dict, base_path: str, file_attributes: List[str]): + for key, val in conf.items(): + if key in file_attributes: + if val.startswith("/"): + continue + elif val == "": + conf[key] = "./" + val + else: + conf[key] = os.path.join(base_path, val) + if isinstance(val, dict): + conf[key] = add_base_path(val, base_path, file_attributes) + + return conf + + +def set_domain_and_port(conf: dict, uris: List[str], domain: str, port: int): + for key, val in conf.items(): + if key in uris: + if not val: + continue + + if isinstance(val, list): + _new = [v.format(domain=domain, port=port) for v in val] + else: + _new = val.format(domain=domain, port=port) + conf[key] = _new + elif isinstance(val, dict): + conf[key] = set_domain_and_port(val, uris, domain, port) + return conf + + +class Base: + """ Configuration base class """ - def __init__(self, conf: Dict, base_path: str = '', item_paths: Optional[dict] = None) -> None: - if item_paths is None: - item_paths = DEFAULT_ITEM_PATHS + def __init__(self, + conf: Dict, + base_path: str = '', + file_attributes: Optional[List[str]] = None, + ): - if base_path and item_paths: + if file_attributes is None: + file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES + + if base_path and file_attributes: # this adds a base path to all paths in the configuration - add_base_path(conf, item_paths, base_path) + add_base_path(conf, base_path, file_attributes) - log_conf = conf.get('logging') - if log_conf: - self.logger = configure_logging(config=log_conf).getChild(__name__) + def __getitem__(self, item): + if item in self.__dict__: + return self.__dict__[item] else: - self.logger = logging.getLogger('oidcrp') + raise KeyError - # server info - self.domain = lower_or_upper(conf, "domain") - self.port = lower_or_upper(conf, "port") - if self.port: - format_args = {'domain': self.domain, 'port': self.port} - else: - format_args = {'domain': self.domain, "port": ""} + def get(self, item, default=None): + return getattr(self, item, default) - for param in ["server_name", "base_url"]: - set_param(self, conf, param, **format_args) + def __contains__(self, item): + return item in self.__dict__ - # HTTP params - _params = get_http_params(conf.get("httpc_params")) - if _params: - self.httpc_params = _params - else: - _params = {'verify', lower_or_upper(conf, "verify_ssl", True)} + def items(self): + for key in self.__dict__: + if key.startswith('__') and key.endswith('__'): + continue + yield key, getattr(self, key) - self.web_conf = lower_or_upper(conf, "webserver") + def extend(self, entity_conf, conf, base_path, file_attributes, domain, port): + for econf in entity_conf: + _path = econf.get("path") + _cnf = conf + if _path: + for step in _path: + _cnf = _cnf[step] + _attr = econf["attr"] + _cls = econf["class"] + setattr(self, _attr, + _cls(_cnf, base_path=base_path, file_attributes=file_attributes, + domain=domain, port=port)) + + +URIS = [ + "redirect_uris", 'post_logout_redirect_uris', 'frontchannel_logout_uri', + 'backchannel_logout_uri', 'issuer', 'base_url'] - # diverse - for param in ["html_home", "session_cookie_name", "preferred_url_scheme", - "services", "federation"]: - set_param(self, conf, param) - rp_keys_conf = lower_or_upper(conf, 'rp_keys') - if rp_keys_conf is None: - rp_keys_conf = lower_or_upper(conf, 'oidc_keys') +class RPConfiguration(Base): + def __init__(self, + conf: Dict, + base_path: Optional[str] = '', + entity_conf: Optional[List[dict]] = None, + domain: Optional[str] = "127.0.0.1", + port: Optional[int] = 80, + file_attributes: Optional[List[str]] = None, + ): - setattr(self, "rp_keys", rp_keys_conf) + Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes) - _clients = lower_or_upper(conf, "clients") - if _clients: - for key, spec in _clients.items(): - if key == "": - continue - # if not spec.get("redirect_uris"): - # continue + _keys_conf = lower_or_upper(conf, 'rp_keys') + if _keys_conf is None: + _keys_conf = lower_or_upper(conf, 'oidc_keys') # legacy - for uri in ['redirect_uris', 'post_logout_redirect_uris', 'frontchannel_logout_uri', - 'backchannel_logout_uri', 'issuer']: - replace(spec, uri, **format_args) + self.keys = _keys_conf - setattr(self, "clients", _clients) + if not domain: + domain = conf.get("domain", "127.0.0.1") + + if not port: + port = conf.get("port", 80) + + conf = set_domain_and_port(conf, URIS, domain, port) + self.clients = lower_or_upper(conf, "clients") hash_seed = lower_or_upper(conf, 'hash_seed') if not hash_seed: hash_seed = rnd_token(32) - setattr(self, "hash_seed", hash_seed) - self.load_extension(conf) + self.hash_seed = hash_seed + + self.services = lower_or_upper(conf, "services") + self.base_url = lower_or_upper(conf, "base_url") + self.httpc_params = lower_or_upper(conf, "httpc_params", {"verify": True}) + + if entity_conf: + self.extend(entity_conf=entity_conf, conf=conf, base_path=base_path, + file_attributes=file_attributes, domain=domain, port=port) + + +class Configuration(Base): + """RP Configuration""" + + def __init__(self, + conf: Dict, + base_path: str = '', + entity_conf: Optional[List[dict]] = None, + file_attributes: Optional[List[str]] = None, + domain: Optional[str] = "", + port: Optional[int] = 0, + ): + Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes) + + log_conf = conf.get('logging') + if log_conf: + self.logger = configure_logging(config=log_conf).getChild(__name__) + else: + self.logger = logging.getLogger('oidcrp') + + self.web_conf = lower_or_upper(conf, "webserver") + + # entity info + if not domain: + domain = conf.get("domain", "127.0.0.1") + + if not port: + port = conf.get("port", 80) + + if entity_conf: + self.extend(entity_conf=entity_conf, conf=conf, base_path=base_path, + file_attributes=file_attributes, domain=domain, port=port) - def load_extension(self, conf): - pass - @classmethod - def create_from_config_file(cls, filename: str, base_path: str = ''): +def create_from_config_file(cls, + filename: str, + base_path: Optional[str] = '', + entity_conf: Optional[List[dict]] = None, + file_attributes: Optional[List[str]] = None, + domain: Optional[str] = "", + port: Optional[int] = 0): + if filename.endswith(".yaml"): """Load configuration as YAML""" - return cls(load_yaml_config(filename), base_path) + _cnf = load_yaml_config(filename) + elif filename.endswith(".json"): + _str = open(filename).read() + _cnf = json.loads(_str) + elif filename.endswith(".py"): + head, tail = os.path.split(filename) + tail = tail[:-3] + module = importlib.import_module(tail) + _cnf = getattr(module, "CONFIG") + else: + raise ValueError("Unknown file type") + + return cls(_cnf, + entity_conf=entity_conf, + base_path=base_path, file_attributes=file_attributes, + domain=domain, port=port) diff --git a/src/oidcrp/cookie.py b/src/oidcrp/cookie.py index 357630a..783257b 100755 --- a/src/oidcrp/cookie.py +++ b/src/oidcrp/cookie.py @@ -1,29 +1,26 @@ import base64 import hashlib import hmac +from http.cookies import SimpleCookie import logging import os import sys import time -from http.cookies import SimpleCookie - from cryptography.hazmat.primitives.ciphers.aead import AESGCM - -from cryptojwt.utils import as_bytes -from cryptojwt.utils import as_unicode from cryptojwt.jwe.exception import JWEException from cryptojwt.jwe.utils import split_ctx_and_tag - -from oidcservice import rndstr -from oidcservice.exception import ImproperlyConfigured +from cryptojwt.utils import as_bytes +from cryptojwt.utils import as_unicode from oidcmsg import time_util +from oidcrp.exception import ImproperlyConfigured +from oidcrp.util import rndstr + __author__ = 'Roland Hedberg' logger = logging.getLogger(__name__) - CORS_HEADERS = [ ("Access-Control-Allow-Origin", "*"), ("Access-Control-Allow-Methods", "GET"), @@ -45,6 +42,7 @@ def safe_str_cmp(a, b): r |= ord(c) ^ ord(d) return r == 0 + def _expiration(timeout, time_format=None): """ Return an expiration time @@ -161,7 +159,7 @@ def make_cookie(name, load, seed, expire=0, domain="", path="", timestamp="", # to the top level APIs. key = _make_hashed_key((enc_key, seed)) - #key = AESGCM.generate_key(bit_length=128) + # key = AESGCM.generate_key(bit_length=128) aesgcm = AESGCM(key) iv = os.urandom(12) @@ -279,6 +277,7 @@ class CookieDealer(object): Functionality that an entity that deals with cookies need to have access to. """ + def _get_server(self): return self._srv diff --git a/src/oidcrp/defaults.py b/src/oidcrp/defaults.py new file mode 100644 index 0000000..bcbdfb2 --- /dev/null +++ b/src/oidcrp/defaults.py @@ -0,0 +1,73 @@ +import hashlib +import string + +SERVICE_NAME = "OIC" +CLIENT_CONFIG = {} + +DEFAULT_OIDC_SERVICES = { + 'web_finger': {'class': 'oidcrp.oidc.webfinger.WebFinger'}, + 'discovery': {'class': 'oidcrp.oidc.provider_info_discovery.ProviderInfoDiscovery'}, + 'registration': {'class': 'oidcrp.oidc.registration.Registration'}, + 'authorization': {'class': 'oidcrp.oidc.authorization.Authorization'}, + 'access_token': {'class': 'oidcrp.oidc.access_token.AccessToken'}, + 'refresh_access_token': {'class': 'oidcrp.oidc.refresh_access_token.RefreshAccessToken'}, + 'userinfo': {'class': 'oidcrp.oidc.userinfo.UserInfo'} +} + +DEFAULT_CLIENT_PREFS = { + 'application_type': 'web', + 'application_name': 'rphandler', + 'response_types': ['code', 'id_token', 'id_token token', 'code id_token', 'code id_token token', + 'code token'], + 'scope': ['openid'], + 'token_endpoint_auth_method': 'client_secret_basic' +} + +# Using PKCE is default +DEFAULT_CLIENT_CONFIGS = { + "": { + "client_preferences": DEFAULT_CLIENT_PREFS, + "add_ons": { + "pkce": { + "function": "oidcrp.oauth2.add_on.pkce.add_support", + "kwargs": { + "code_challenge_length": 64, + "code_challenge_method": "S256" + } + } + } + } +} + +DEFAULT_KEY_DEFS = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +DEFAULT_RP_KEY_DEFS = { + 'private_path': 'private/jwks.json', + 'key_defs': DEFAULT_KEY_DEFS, + 'public_path': 'static/jwks.json', + 'read_only': False +} + +OIDCONF_PATTERN = "{}/.well-known/openid-configuration" +CC_METHOD = { + 'S256': hashlib.sha256, + 'S384': hashlib.sha384, + 'S512': hashlib.sha512, +} + +# Map the signing context to a signing algorithm +DEF_SIGN_ALG = {"id_token": "RS256", + "userinfo": "RS256", + "request_object": "RS256", + "client_secret_jwt": "HS256", + "private_key_jwt": "RS256"} + +HTTP_ARGS = ["headers", "redirections", "connection_type"] + +JWT_BEARER = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" +SAML2_BEARER_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:saml2-bearer" + +BASECHR = string.ascii_letters + string.digits diff --git a/src/oidcrp/entity.py b/src/oidcrp/entity.py new file mode 100644 index 0000000..4c7702a --- /dev/null +++ b/src/oidcrp/entity.py @@ -0,0 +1,80 @@ +from typing import Callable +from typing import Optional +from typing import Union + +from cryptojwt import KeyJar + +from oidcrp.client_auth import factory +from oidcrp.configure import Configuration +from oidcrp.service import init_services +from oidcrp.service_context import ServiceContext + +DEFAULT_SERVICES = { + "discovery": { + 'class': 'oidcrp.oauth2.provider_info_discovery.ProviderInfoDiscovery' + }, + 'authorization': { + 'class': 'oidcrp.oauth2.authorization.Authorization' + }, + 'access_token': { + 'class': 'oidcrp.oauth2.access_token.AccessToken' + }, + 'refresh_access_token': { + 'class': 'oidcrp.oauth2.refresh_access_token.RefreshAccessToken' + } +} + + +class Entity(): + def __init__(self, + client_authn_factory: Optional[Callable] = None, + keyjar: Optional[KeyJar] = None, + config: Optional[Union[dict, Configuration]] = None, + services: Optional[dict] = None, + jwks_uri: Optional[str] = '', + httpc_params: Optional[dict] = None): + + if httpc_params: + self.httpc_params = httpc_params + else: + self.httpc_params = {"verify": True} + + self._service_context = ServiceContext(keyjar=keyjar, config=config, + jwks_uri=jwks_uri, httpc_params=self.httpc_params) + + _cid = self._service_context.get('client_id') + if _cid: + self.client_id = _cid + + _cam = client_authn_factory or factory + + _srvs = services or DEFAULT_SERVICES + + self._service = init_services(service_definitions=_srvs, + client_get=self.client_get, + client_authn_factory=_cam) + + def client_get(self, what, *arg): + _func = getattr(self, "get_{}".format(what), None) + if _func: + return _func(*arg) + return None + + def get_services(self, *arg): + return self._service + + def get_service_context(self, *arg): + return self._service_context + + def get_service(self, service_name, *arg): + try: + return self._service[service_name] + except KeyError: + return None + + def get_service_by_endpoint_name(self, endpoint_name, *arg): + for service in self._service.values(): + if service.endpoint_name == endpoint_name: + return service + + return None diff --git a/src/oidcrp/exception.py b/src/oidcrp/exception.py new file mode 100755 index 0000000..c9beb38 --- /dev/null +++ b/src/oidcrp/exception.py @@ -0,0 +1,124 @@ + +__author__ = 'roland' + + +# The base exception class for oidc service specific exceptions +class OidcServiceError(Exception): + def __init__(self, errmsg, content_type="", *args): + Exception.__init__(self, errmsg, *args) + self.content_type = content_type + + +class MissingRequiredAttribute(OidcServiceError): + pass + + +class VerificationError(OidcServiceError): + pass + + +class ResponseError(OidcServiceError): + pass + + +class TimeFormatError(OidcServiceError): + pass + + +class CapabilitiesMisMatch(OidcServiceError): + pass + + +class MissingEndpoint(OidcServiceError): + pass + + +class TokenError(OidcServiceError): + pass + + +class GrantError(OidcServiceError): + pass + + +class ParseError(OidcServiceError): + pass + + +class OtherError(OidcServiceError): + pass + + +class NoClientInfoReceivedError(OidcServiceError): + pass + + +class InvalidRequest(OidcServiceError): + pass + + +class NonFatalException(OidcServiceError): + """ + :param resp: A response that the function/method would return on non-error + :param msg: A message describing what error has occurred. + """ + + def __init__(self, resp, msg): + self.resp = resp + self.msg = msg + + +class Unsupported(OidcServiceError): + pass + + +class UnsupportedResponseType(Unsupported): + pass + + +class AccessDenied(OidcServiceError): + pass + + +class ImproperlyConfigured(OidcServiceError): + pass + + +class UnsupportedMethod(OidcServiceError): + pass + + +class AuthzError(OidcServiceError): + pass + + +class AuthnToOld(OidcServiceError): + pass + + +class ParameterError(OidcServiceError): + pass + + +class SubMismatch(OidcServiceError): + pass + + +class ConfigurationError(OidcServiceError): + pass + + +class WrongContentType(OidcServiceError): + pass + + +class WebFingerError(OidcServiceError): + pass + + +class HandlerError(Exception): + pass + + +class HttpError(OidcServiceError): + pass diff --git a/src/oidcrp/http.py b/src/oidcrp/http.py index 4fc9dcd..ef77c49 100644 --- a/src/oidcrp/http.py +++ b/src/oidcrp/http.py @@ -1,14 +1,13 @@ import copy -import logging -import requests - from http.cookiejar import FileCookieJar from http.cookies import CookieError from http.cookies import SimpleCookie +import logging -from oidcservice import sanitize -from oidcservice.exception import NonFatalException +import requests +from oidcrp.exception import NonFatalException +from oidcrp.util import sanitize from oidcrp.util import set_cookie __author__ = 'roland' diff --git a/src/oidcrp/logging.py b/src/oidcrp/logging.py index f5134de..baf1db9 100755 --- a/src/oidcrp/logging.py +++ b/src/oidcrp/logging.py @@ -3,6 +3,7 @@ import os import logging from logging.config import dictConfig +from typing import Optional import yaml @@ -29,8 +30,9 @@ } -def configure_logging(debug: bool = False, config: dict = None, - filename: str = LOGGING_CONF) -> logging.Logger: +def configure_logging(debug: Optional[bool] = False, + config: Optional[dict] = None, + filename: Optional[str] = LOGGING_CONF) -> logging.Logger: """Configure logging""" if config is not None: diff --git a/src/oidcrp/oauth2/__init__.py b/src/oidcrp/oauth2/__init__.py index 51be362..52fcf79 100755 --- a/src/oidcrp/oauth2/__init__.py +++ b/src/oidcrp/oauth2/__init__.py @@ -1,20 +1,14 @@ -import logging from json import JSONDecodeError +import logging -from cryptojwt.key_jar import KeyJar from oidcmsg.exception import FormatError -from oidcservice.client_auth import factory as ca_factory -from oidcservice.exception import OidcServiceError -from oidcservice.exception import ParseError -from oidcservice.oauth2 import DEFAULT_SERVICES -from oidcservice.service import REQUEST_INFO -from oidcservice.service import SUCCESSFUL -from oidcservice.service import init_services -from oidcservice.service_context import ServiceContext -from oidcservice.state_interface import StateInterface -from oidcservice.util import importer +from oidcrp.entity import Entity +from oidcrp.exception import OidcServiceError +from oidcrp.exception import ParseError from oidcrp.http import HTTPLib +from oidcrp.service import REQUEST_INFO +from oidcrp.service import SUCCESSFUL from oidcrp.util import do_add_ons from oidcrp.util import get_deserialization_method @@ -29,10 +23,26 @@ class ExpiredToken(Exception): pass +DEFAULT_OAUTH2_SERVICES = { + "discovery": { + 'class': 'oidcrp.oauth2.provider_info_discovery.ProviderInfoDiscovery' + }, + 'authorization': { + 'class': 'oidcrp.oauth2.authorization.Authorization' + }, + 'access_token': { + 'class': 'oidcrp.oauth2.access_token.AccessToken' + }, + 'refresh_access_token': { + 'class': 'oidcrp.oauth2.refresh_access_token.RefreshAccessToken' + } +} + + # ============================================================================= -class Client(object): +class Client(Entity): def __init__(self, client_authn_factory=None, keyjar=None, verify_ssl=True, config=None, httplib=None, services=None, jwks_uri='', httpc_params=None): """ @@ -41,7 +51,7 @@ def __init__(self, client_authn_factory=None, keyjar=None, verify_ssl=True, conf initiate a client authentication class. :param keyjar: A py:class:`oidcmsg.key_jar.KeyJar` instance :param config: Configuration information passed on to the - :py:class:`oidcservice.service_context.ServiceContext` + :py:class:`oidcrp.service_context.ServiceContext` initialization :param httplib: A HTTP client to use :param services: A list of service definitions @@ -50,48 +60,22 @@ def __init__(self, client_authn_factory=None, keyjar=None, verify_ssl=True, conf :return: Client instance """ - if httpc_params is None: - httpc_params = {"verify": True} + Entity.__init__(self, client_authn_factory=client_authn_factory, keyjar=keyjar, + config=config, services=services, jwks_uri=jwks_uri, + httpc_params=httpc_params) self.http = httplib or HTTPLib(httpc_params) - # db_conf = config.get('db_conf') - # if db_conf: - # _storage_cls_name = db_conf.get('abstract_storage_cls') - # self._storage_cls = importer(_storage_cls_name) - # self.db = self._storage_cls(db_conf.get('default')) - # if not keyjar: - # key_db_conf = db_conf.get('keyjar', db_conf.get('default')) - # keyjar = KeyJar(abstract_storage_cls=self._storage_cls, storage_conf=key_db_conf) - # keyjar.verify_ssl = verify_ssl - - self.events = None - self.service_context = ServiceContext(keyjar, config=config, - jwks_uri=jwks_uri, - httpc_params=httpc_params) - - self.session_interface = StateInterface(self.service_context.state_db) - - if self.service_context.get('client_id'): - self.client_id = self.service_context.get('client_id') - - _cam = client_authn_factory or ca_factory - - _srvs = services or DEFAULT_SERVICES - - self.service = init_services(_srvs, self.service_context, _cam) - if 'add_ons' in config: - do_add_ons(config['add_ons'], self.service) + do_add_ons(config['add_ons'], self._service) - self.service_context.service = self.service # just ignore verify_ssl until it goes away - self.verify_ssl = httpc_params.get("verify", True) + self.verify_ssl = self.httpc_params.get("verify", True) def do_request(self, request_type, response_body_type="", request_args=None, **kwargs): - _srv = self.service[request_type] + _srv = self._service[request_type] _info = _srv.get_request_parameters(request_args=request_args, **kwargs) @@ -109,7 +93,7 @@ def do_request(self, request_type, response_body_type="", request_args=None, def set_client_id(self, client_id): self.client_id = client_id - self.service_context.set('client_id', client_id) + self._service_context.set('client_id', client_id) def get_response(self, service, url, method="GET", body=None, response_body_type="", headers=None, **kwargs): @@ -134,7 +118,7 @@ def get_response(self, service, url, method="GET", body=None, response_body_type if resp.status_code < 300: if "keyjar" not in kwargs: - kwargs["keyjar"] = service.service_context.keyjar + kwargs["keyjar"] = service.client_get("service_context").keyjar if not response_body_type: response_body_type = service.response_body_type @@ -197,7 +181,7 @@ def parse_request_response(self, service, reqresp, response_body_type='', - text (The text version of the response) - url (The calling URL) - :param service: A :py:class:`oidcservice.service.Service` instance + :param service: A :py:class:`oidcrp.service.Service` instance :param reqresp: The HTTP request response :param response_body_type: If response in body one of 'json', 'jwt' or 'urlencoded' diff --git a/src/oidcrp/oauth2/access_token.py b/src/oidcrp/oauth2/access_token.py new file mode 100644 index 0000000..285a51f --- /dev/null +++ b/src/oidcrp/oauth2/access_token.py @@ -0,0 +1,63 @@ +"""Implements the service that talks to the Access Token endpoint.""" +import logging + +from oidcmsg import oauth2 +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.time_util import time_sans_frac + +from oidcrp.oauth2.utils import get_state_parameter +from oidcrp.service import Service + +LOGGER = logging.getLogger(__name__) + + +class AccessToken(Service): + """The access token service.""" + msg_type = oauth2.AccessTokenRequest + response_cls = oauth2.AccessTokenResponse + error_msg = ResponseMessage + endpoint_name = 'token_endpoint' + synchronous = True + service_name = 'accesstoken' + default_authn_method = 'client_secret_basic' + http_method = 'POST' + request_body_type = 'urlencoded' + response_body_type = 'json' + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, conf=conf) + self.pre_construct.append(self.oauth_pre_construct) + + def update_service_context(self, resp, key='', **kwargs): + if 'expires_in' in resp: + resp['__expires_at'] = time_sans_frac() + int(resp['expires_in']) + self.client_get("service_context").state.store_item(resp, 'token_response', key) + + def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs): + """ + + :param request_args: Initial set of request arguments + :param kwargs: Extra keyword arguments + :return: Request arguments + """ + _state = get_state_parameter(request_args, kwargs) + parameters = list(self.msg_type.c_param.keys()) + + _context = self.client_get("service_context") + _args = _context.state.extend_request_args({}, oauth2.AuthorizationRequest, + 'auth_request', _state, parameters) + + _args = _context.state.extend_request_args(_args, oauth2.AuthorizationResponse, + 'auth_response', _state, parameters) + + if "grant_type" not in _args: + _args["grant_type"] = "authorization_code" + + if request_args is None: + request_args = _args + else: + _args.update(request_args) + request_args = _args + + return request_args, post_args diff --git a/src/oidcrp/oauth2/add_on/__init__.py b/src/oidcrp/oauth2/add_on/__init__.py new file mode 100644 index 0000000..daf6fd2 --- /dev/null +++ b/src/oidcrp/oauth2/add_on/__init__.py @@ -0,0 +1,7 @@ +from oidcrp.util import importer + + +def do_add_ons(add_ons, services): + for key, spec in add_ons.items(): + _func = importer(spec['function']) + _func(services, **spec['kwargs']) diff --git a/src/oidcrp/oauth2/add_on/dpop.py b/src/oidcrp/oauth2/add_on/dpop.py new file mode 100644 index 0000000..654db97 --- /dev/null +++ b/src/oidcrp/oauth2/add_on/dpop.py @@ -0,0 +1,163 @@ +from typing import Optional +from typing import Union +import uuid + +from cryptojwt.jwk.jwk import key_from_jwk_dict +from cryptojwt.jws.jws import JWS +from cryptojwt.jws.jws import factory +from cryptojwt.key_bundle import key_by_alg +from oidcmsg.message import Message +from oidcmsg.message import SINGLE_REQUIRED_INT +from oidcmsg.message import SINGLE_REQUIRED_JSON +from oidcmsg.message import SINGLE_REQUIRED_STRING +from oidcmsg.time_util import utc_time_sans_frac + +from oidcrp.service_context import ServiceContext + + +class DPoPProof(Message): + c_param = { + # header + "typ": SINGLE_REQUIRED_STRING, + "alg": SINGLE_REQUIRED_STRING, + "jwk": SINGLE_REQUIRED_JSON, + # body + "jti": SINGLE_REQUIRED_STRING, + "htm": SINGLE_REQUIRED_STRING, + "htu": SINGLE_REQUIRED_STRING, + "iat": SINGLE_REQUIRED_INT + } + header_params = {"typ", "alg", "jwk"} + body_params = {"jti", "htm", "htu", "iat"} + + def __init__(self, set_defaults=True, **kwargs): + self.key = None + Message.__init__(self, set_defaults=set_defaults, **kwargs) + + if self.key: + pass + elif "jwk" in self: + self.key = key_from_jwk_dict(self["jwk"]) + self.key.deserialize() + + def from_dict(self, dictionary, **kwargs): + Message.from_dict(self, dictionary, **kwargs) + + if "jwk" in self: + self.key = key_from_jwk_dict(self["jwk"]) + self.key.deserialize() + + return self + + def verify(self, **kwargs): + Message.verify(self, **kwargs) + if self["typ"] != "dpop+jwt": + raise ValueError("Wrong type") + if self["alg"] == "none": + raise ValueError("'none' is not allowed as signing algorithm") + + def create_header(self) -> str: + payload = {k: self[k] for k in self.body_params} + _jws = JWS(payload, alg=self["alg"]) + _jws_headers = {k: self[k] for k in self.header_params} + _signed_jwt = _jws.sign_compact(keys=[self.key], **_jws_headers) + return _signed_jwt + + def verify_header(self, dpop_header) -> Optional["DPoPProof"]: + _jws = factory(dpop_header) + if _jws: + _jwt = _jws.jwt + if "jwk" in _jwt.headers: + _pub_key = key_from_jwk_dict(_jwt.headers["jwk"]) + _pub_key.deserialize() + _info = _jws.verify_compact(keys=[_pub_key], sigalg=_jwt.headers["alg"]) + for k, v in _jwt.headers.items(): + self[k] = v + + for k, v in _info.items(): + self[k] = v + else: + raise Exception() + + return self + else: + return None + + +def dpop_header(service_context: ServiceContext, + request: Union[dict, Message], + service_endpoint: str, + http_method: str, + headers: Optional[dict] = None, + authn_method: Optional[str] = "", + **kwargs) -> dict: + """ + + :param service_context: + :param request: + :param service_endpoint: + :param http_method: + :param headers: + :param authn_method: + :param kwargs: + :return: + """ + + provider_info = service_context.provider_info + dpop_key = service_context.add_on['dpop'].get('key') + + if not dpop_key: + algs_supported = provider_info["dpop_signing_alg_values_supported"] + if not algs_supported: # does not support DPoP + return headers + + chosen_alg = '' + for alg in service_context.add_on['dpop']["sign_algs"]: + if alg in algs_supported: + chosen_alg = alg + break + + if not chosen_alg: + return headers + + # Mint a new key + dpop_key = key_by_alg(chosen_alg) + service_context.add_on['dpop']['key'] = dpop_key + service_context.add_on['dpop']['alg'] = chosen_alg + + header_dict = { + "typ": "dpop+jwt", + "alg": service_context.add_on['dpop']['alg'], + "jwk": dpop_key.serialize(), + "jti": uuid.uuid4().hex, + "htm": http_method, + "htu": provider_info[service_endpoint], + "iat": utc_time_sans_frac() + } + + _dpop = DPoPProof(**header_dict) + _dpop.key = dpop_key + jws = _dpop.create_header() + + if headers is None: + headers = {"dpop": jws} + else: + headers["dpop"] = jws + + return headers + + +def add_support(services, signing_algorithms: Optional[list] = None): + """ + Add the necessary pieces to make pushed authorization happen. + + :param services: A dictionary with all the services the client has access to. + :param signing_algorithms: + """ + + _service = services["accesstoken"] + _service.client_get("service_context").add_on['dpop'] = { + # "key": key_by_alg(signing_algorithm), + "sign_algs": signing_algorithms + } + _service.construct_extra_headers.append(dpop_header) diff --git a/src/oidcrp/oauth2/add_on/pkce.py b/src/oidcrp/oauth2/add_on/pkce.py new file mode 100644 index 0000000..0b4632d --- /dev/null +++ b/src/oidcrp/oauth2/add_on/pkce.py @@ -0,0 +1,111 @@ +import logging + +from cryptojwt.utils import b64e +from oidcmsg.message import Message + +from oidcrp.defaults import CC_METHOD +from oidcrp.exception import Unsupported +from oidcrp.oauth2.utils import get_state_parameter +from oidcrp.util import unreserved + +logger = logging.getLogger(__name__) + + +def add_code_challenge(request_args, service, **kwargs): + """ + PKCE RFC 7636 support + To be added as a post_construct method to an + :py:class:`oidcrp.oidc.service.Authorization` instance + + :param service: The service that uses this function + :param request_args: Set of request arguments + :param kwargs: Extra set of keyword arguments + :return: Updated set of request arguments + """ + _context = service.client_get("service_context") + _kwargs = _context.add_on["pkce"] + + try: + cv_len = _kwargs['code_challenge_length'] + except KeyError: + cv_len = 64 # Use default + + # code_verifier: string of length cv_len + code_verifier = unreserved(cv_len) + _cv = code_verifier.encode() + + try: + _method = _kwargs['code_challenge_method'] + except KeyError: + _method = 'S256' + + try: + # Pick hash method + _hash_method = CC_METHOD[_method] + # Use it on the code_verifier + _hv = _hash_method(_cv).digest() + # base64 encode the hash value + code_challenge = b64e(_hv).decode('ascii') + except KeyError: + raise Unsupported( + 'PKCE Transformation method:{}'.format(_method)) + + _item = Message(code_verifier=code_verifier, code_challenge_method=_method) + _context.state.store_item(_item, 'pkce', request_args['state']) + + request_args.update( + { + "code_challenge": code_challenge, + "code_challenge_method": _method + }) + return request_args, {} + + +def add_code_verifier(request_args, service, **kwargs): + """ + PKCE RFC 7636 support + To be added as a post_construct method to an + :py:class:`oidcrp.oidc.service.AccessToken` instance + + :param service: The service that uses this function + :param request_args: Set of request arguments + :return: updated set of request arguments + """ + _state = request_args.get('state') + if _state is None: + _state = kwargs.get('state') + _item = service.client_get("service_context").state.get_item(Message, 'pkce', _state) + request_args.update({'code_verifier': _item['code_verifier']}) + return request_args + + +def put_state_in_post_args(request_args, **kwargs): + state = get_state_parameter(request_args, kwargs) + return request_args, {'state': state} + + +def add_support(service, code_challenge_length, code_challenge_method): + """ + PKCE support can only be considered if this client can access authorization and + access token services. + + :param service: Dictionary of services + :param code_challenge_length: + :param code_challenge_method: + :return: + """ + if "authorization" in service and "accesstoken" in service: + _service = service["authorization"] + _context = _service.client_get("service_context") + _context.add_on['pkce'] = { + "code_challenge_length": code_challenge_length, + "code_challenge_method": code_challenge_method + } + + _service.pre_construct.append(add_code_challenge) + + token_service = service['accesstoken'] + token_service.pre_construct.append(put_state_in_post_args) + token_service.post_construct.append(add_code_verifier) + else: + logger.warning("PKCE support could NOT be added") diff --git a/src/oidcrp/oauth2/add_on/pushed_authorization.py b/src/oidcrp/oauth2/add_on/pushed_authorization.py new file mode 100644 index 0000000..4378041 --- /dev/null +++ b/src/oidcrp/oauth2/add_on/pushed_authorization.py @@ -0,0 +1,74 @@ +import logging + +from cryptojwt import JWT +from oidcmsg.message import Message +from oidcmsg.oauth2 import JWTSecuredAuthorizationRequest +import requests + +logger = logging.getLogger(__name__) + + +def push_authorization(request_args, service, **kwargs): + """ + :param request_args: All the request arguments as a AuthorizationRequest instance + :param service: The service to which this post construct method is applied. + :param kwargs: Extra keyword arguments. + """ + + _context = service.client_get("service_context") + method_args = _context.add_on["pushed_authorization"] + + # construct the message body + if method_args["body_format"] == "urlencoded": + _body = request_args.to_urlencoded() + else: + _jwt = JWT(key_jar=_context.keyjar, iss=_context.base_url) + _jws = _jwt.pack(request_args.to_dict()) + + _msg = Message(request=_jws) + if method_args["merge_rule"] == "lax": + for param in request_args.required_parameters(): + _msg[param] = request_args.get(param) + + _body = _msg.to_urlencoded() + + # Send it to the Pushed Authorization Request Endpoint + resp = method_args["http_client"].get( + _context.provider_info["pushed_authorization_request_endpoint"], data=_body + ) + + if resp.status_code == 200: + _resp = Message().from_json(resp.text) + _req = JWTSecuredAuthorizationRequest(request_uri=_resp["request_uri"]) + if method_args["merge_rule"] == "lax": + for param in request_args.required_parameters(): + _req[param] = request_args.get(param) + request_args = _req + + return request_args + + +def add_support(services, body_format="jws", signing_algorithm="RS256", + http_client=None, merge_rule="strict"): + """ + Add the necessary pieces to make pushed authorization happen. + + :param merge_rule: + :param http_client: + :param signing_algorithm: + :param services: A dictionary with all the services the client has access to. + :param body_format: jws or urlencoded + """ + + if http_client is None: + http_client = requests + + _service = services["authorization"] + _service.client_get("service_context").add_on['pushed_authorization'] = { + "body_format": body_format, + "signing_algorithm": signing_algorithm, + "http_client": http_client, + "merge_rule": merge_rule + } + + _service.post_construct.append(push_authorization) diff --git a/src/oidcrp/oauth2/authorization.py b/src/oidcrp/oauth2/authorization.py new file mode 100644 index 0000000..366b223 --- /dev/null +++ b/src/oidcrp/oauth2/authorization.py @@ -0,0 +1,83 @@ +"""The service that talks to the OAuth2 Authorization endpoint.""" +import logging + +from oidcmsg import oauth2 +from oidcmsg.exception import MissingParameter +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.time_util import time_sans_frac + +from oidcrp.oauth2.utils import get_state_parameter +from oidcrp.oauth2.utils import pick_redirect_uris +from oidcrp.oauth2.utils import set_state_parameter +from oidcrp.service import Service + +LOGGER = logging.getLogger(__name__) + + +class Authorization(Service): + """The service that talks to the OAuth2 Authorization endpoint.""" + msg_type = oauth2.AuthorizationRequest + response_cls = oauth2.AuthorizationResponse + error_msg = ResponseMessage + endpoint_name = 'authorization_endpoint' + synchronous = False + service_name = 'authorization' + response_body_type = 'urlencoded' + + # parameter = Service.parameter.copy() + # parameter.update({ + # "endpoint": "" + # }) + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, conf=conf) + self.pre_construct.extend([pick_redirect_uris, set_state_parameter]) + self.post_construct.append(self.store_auth_request) + + def update_service_context(self, resp, key='', **kwargs): + if 'expires_in' in resp: + resp['__expires_at'] = time_sans_frac() + int(resp['expires_in']) + self.client_get("service_context").state.store_item(resp, 'auth_response', key) + + def store_auth_request(self, request_args=None, **kwargs): + """Store the authorization request in the state DB.""" + _key = get_state_parameter(request_args, kwargs) + self.client_get("service_context").state.store_item(request_args, 'auth_request', _key) + return request_args + + def gather_request_args(self, **kwargs): + ar_args = Service.gather_request_args(self, **kwargs) + + if 'redirect_uri' not in ar_args: + try: + ar_args['redirect_uri'] = self.client_get("service_context").redirect_uris[0] + except (KeyError, AttributeError): + raise MissingParameter('redirect_uri') + + return ar_args + + def post_parse_response(self, response, **kwargs): + """ + Add scope claim to response, from the request, if not present in the + response + + :param response: The response + :param kwargs: Extra Keyword arguments + :return: A possibly augmented response + """ + + if "scope" not in response: + try: + _key = kwargs['state'] + except KeyError: + pass + else: + if _key: + item = self.client_get("service_context").state.get_item(oauth2.AuthorizationRequest, + 'auth_request', _key) + try: + response["scope"] = item["scope"] + except KeyError: + pass + return response diff --git a/chrp/utils.py b/src/oidcrp/oauth2/client_credentials/__init__.py similarity index 100% rename from chrp/utils.py rename to src/oidcrp/oauth2/client_credentials/__init__.py diff --git a/src/oidcrp/oauth2/client_credentials/cc_access_token.py b/src/oidcrp/oauth2/client_credentials/cc_access_token.py new file mode 100644 index 0000000..0bec1bb --- /dev/null +++ b/src/oidcrp/oauth2/client_credentials/cc_access_token.py @@ -0,0 +1,27 @@ +from oidcmsg import oauth2 +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.time_util import time_sans_frac + +from oidcrp.service import Service + + +class CCAccessToken(Service): + msg_type = oauth2.CCAccessTokenRequest + response_cls = oauth2.AccessTokenResponse + error_msg = ResponseMessage + endpoint_name = 'token_endpoint' + synchronous = True + service_name = 'accesstoken' + default_authn_method = 'client_secret_basic' + http_method = 'POST' + request_body_type = 'urlencoded' + response_body_type = 'json' + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, conf=conf) + + def update_service_context(self, resp, key='cc', **kwargs): + if 'expires_in' in resp: + resp['__expires_at'] = time_sans_frac() + int(resp['expires_in']) + self.client_get('service_context').state.store_item(resp, 'token_response', key) diff --git a/src/oidcrp/oauth2/client_credentials/cc_refresh_access_token.py b/src/oidcrp/oauth2/client_credentials/cc_refresh_access_token.py new file mode 100644 index 0000000..13468a4 --- /dev/null +++ b/src/oidcrp/oauth2/client_credentials/cc_refresh_access_token.py @@ -0,0 +1,55 @@ +from oidcmsg import oauth2 +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.time_util import time_sans_frac + +from oidcrp.service import Service + + +class CCRefreshAccessToken(Service): + msg_type = oauth2.RefreshAccessTokenRequest + response_cls = oauth2.AccessTokenResponse + error_msg = ResponseMessage + endpoint_name = 'token_endpoint' + synchronous = True + service_name = 'refresh_token' + default_authn_method = 'bearer_header' + http_method = 'POST' + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, conf=conf) + self.pre_construct.append(self.cc_pre_construct) + self.post_construct.append(self.cc_post_construct) + + def cc_pre_construct(self, request_args=None, **kwargs): + _state_id = kwargs.get("state", "cc") + parameters = ['refresh_token'] + _state_interface = self.client_get("service_context").state + _args = _state_interface.extend_request_args({}, oauth2.AccessTokenResponse, + 'token_response', _state_id, parameters) + + _args = _state_interface.extend_request_args(_args, oauth2.AccessTokenResponse, + 'refresh_token_response', _state_id, + parameters) + + if request_args is None: + request_args = _args + else: + _args.update(request_args) + request_args = _args + + return request_args, {} + + def cc_post_construct(self, request_args, **kwargs): + for attr in ['client_id', 'client_secret']: + try: + del request_args[attr] + except KeyError: + pass + + return request_args + + def update_service_context(self, resp, key='cc', **kwargs): + if 'expires_in' in resp: + resp['__expires_at'] = time_sans_frac() + int(resp['expires_in']) + self.client_get("service_context").state.store_item(resp, 'token_response', key) diff --git a/src/oidcrp/oauth2/provider_info_discovery.py b/src/oidcrp/oauth2/provider_info_discovery.py new file mode 100644 index 0000000..4aa4033 --- /dev/null +++ b/src/oidcrp/oauth2/provider_info_discovery.py @@ -0,0 +1,130 @@ +"""The service that talks to the OAuth2 provider info discovery endpoint.""" +import logging + +from cryptojwt.key_jar import KeyJar +from oidcmsg import oauth2 +from oidcmsg.oauth2 import ResponseMessage + +from oidcrp.defaults import OIDCONF_PATTERN +from oidcrp.exception import OidcServiceError +from oidcrp.service import Service + +LOGGER = logging.getLogger(__name__) + + +class ProviderInfoDiscovery(Service): + """The service that talks to the OAuth2 provider info discovery endpoint.""" + msg_type = oauth2.Message + response_cls = oauth2.ASConfigurationResponse + error_msg = ResponseMessage + synchronous = True + service_name = 'provider_info' + http_method = 'GET' + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, conf=conf) + + def get_endpoint(self): + """ + Find the issuer ID and from it construct the service endpoint + + :return: Service endpoint + """ + try: + _iss = self.client_get("service_context").issuer + except AttributeError: + _iss = self.endpoint + + if _iss.endswith('/'): + return OIDCONF_PATTERN.format(_iss[:-1]) + + return OIDCONF_PATTERN.format(_iss) + + def get_request_parameters(self, method="GET", **kwargs): + """ + The Provider info discovery version of get_request_parameters() + + :param method: + :param kwargs: + :return: + """ + return {'url': self.get_endpoint(), 'method': method} + + def _verify_issuer(self, resp, issuer): + _pcr_issuer = resp["issuer"] + if resp["issuer"].endswith("/"): + if issuer.endswith("/"): + _issuer = issuer + else: + _issuer = issuer + "/" + else: + if issuer.endswith("/"): + _issuer = issuer[:-1] + else: + _issuer = issuer + + # In some cases we can live with the two URLs not being + # the same. But this is an excepted that has to be explicit + try: + self.client_get("service_context").allow['issuer_mismatch'] + except KeyError: + if _issuer != _pcr_issuer: + raise OidcServiceError( + "provider info issuer mismatch '%s' != '%s'" % ( + _issuer, _pcr_issuer)) + return _issuer + + def _set_endpoints(self, resp): + """ + If there are services defined set the service endpoint to be + the URLs specified in the provider information.""" + for key, val in resp.items(): + # All service endpoint parameters in the provider info has + # a name ending in '_endpoint' so I can look specifically + # for those + if key.endswith("_endpoint"): + _srv = self.client_get("service_by_endpoint_name", key) + if _srv: + _srv.endpoint = val + + def _update_service_context(self, resp): + """ + Deal with Provider Config Response. Based on the provider info + response a set of parameters in different places needs to be set. + + :param resp: The provider info response + :param service_context: Information collected/used by services + """ + + _context = self.client_get("service_context") + # Verify that the issuer value received is the same as the + # url that was used as service endpoint (without the .well-known part) + if "issuer" in resp: + _pcr_issuer = self._verify_issuer(resp, _context.issuer) + else: # No prior knowledge + _pcr_issuer = _context.issuer + + _context.issuer = _pcr_issuer + _context.provider_info = resp + + self._set_endpoints(resp) + + # If I already have a Key Jar then I'll add then provider keys to + # that. Otherwise a new Key Jar is minted + try: + _keyjar = _context.keyjar + except KeyError: + _keyjar = KeyJar() + + # Load the keys. Note that this only means that the key specification + # is loaded not necessarily that any keys are fetched. + if 'jwks_uri' in resp: + _keyjar.load_keys(_pcr_issuer, jwks_uri=resp['jwks_uri']) + elif 'jwks' in resp: + _keyjar.load_keys(_pcr_issuer, jwks=resp['jwks']) + + _context.keyjar = _keyjar + + def update_service_context(self, resp, **kwargs): + return self._update_service_context(resp) diff --git a/src/oidcrp/oauth2/refresh_access_token.py b/src/oidcrp/oauth2/refresh_access_token.py new file mode 100644 index 0000000..4b5efa6 --- /dev/null +++ b/src/oidcrp/oauth2/refresh_access_token.py @@ -0,0 +1,54 @@ +"""The service that talks to the OAuth2 refresh access token endpoint.""" +import logging + +from oidcmsg import oauth2 +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.time_util import time_sans_frac + +from oidcrp.oauth2.utils import get_state_parameter +from oidcrp.service import Service + +LOGGER = logging.getLogger(__name__) + + +class RefreshAccessToken(Service): + """The service that talks to the OAuth2 refresh access token endpoint.""" + msg_type = oauth2.RefreshAccessTokenRequest + response_cls = oauth2.AccessTokenResponse + error_msg = ResponseMessage + endpoint_name = 'token_endpoint' + synchronous = True + service_name = 'refresh_token' + default_authn_method = 'bearer_header' + http_method = 'POST' + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, conf=conf) + self.pre_construct.append(self.oauth_pre_construct) + + def update_service_context(self, resp, key='', **kwargs): + if 'expires_in' in resp: + resp['__expires_at'] = time_sans_frac() + int(resp['expires_in']) + self.client_get("service_context").state.store_item(resp, 'token_response', key) + + def oauth_pre_construct(self, request_args=None, **kwargs): + """Preconstructor of request arguments""" + _state = get_state_parameter(request_args, kwargs) + parameters = list(self.msg_type.c_param.keys()) + + _si = self.client_get("service_context").state + _args = _si.extend_request_args({}, oauth2.AccessTokenResponse, + 'token_response', _state, parameters) + + _args = _si.extend_request_args(_args, oauth2.AccessTokenResponse, + 'refresh_token_response', _state, + parameters) + + if request_args is None: + request_args = _args + else: + _args.update(request_args) + request_args = _args + + return request_args, {} diff --git a/src/oidcrp/oauth2/utils.py b/src/oidcrp/oauth2/utils.py new file mode 100644 index 0000000..2393766 --- /dev/null +++ b/src/oidcrp/oauth2/utils.py @@ -0,0 +1,52 @@ +from oidcmsg.exception import MissingParameter + + +def get_state_parameter(request_args, kwargs): + """Find a state value from a set of possible places.""" + try: + _state = kwargs['state'] + except KeyError: + try: + _state = request_args['state'] + except KeyError: + raise MissingParameter('state') + + return _state + + +def pick_redirect_uris(request_args=None, service=None, **kwargs): + """Pick one redirect_uri base on response_mode out of a list of such.""" + _context = service.client_get("service_context") + + if 'redirect_uri' in request_args: + return request_args, {} + + _callback = _context.callback + if _callback: + try: + _response_type = request_args['response_type'] + except KeyError: + _response_type = _context.behaviour['response_types'][0] + request_args['response_type'] = _response_type + + try: + _response_mode = request_args['response_mode'] + except KeyError: + _response_mode = '' + + if _response_mode == 'form_post': + request_args['redirect_uri'] = _callback['form_post'] + elif _response_type == 'code': + request_args['redirect_uri'] = _callback['code'] + else: + request_args['redirect_uri'] = _callback['implicit'] + else: + request_args['redirect_uri'] = _context.redirect_uris[0] + + return request_args, {} + + +def set_state_parameter(request_args=None, **kwargs): + """Assigned a state value.""" + request_args['state'] = get_state_parameter(request_args, kwargs) + return request_args, {'state': request_args['state']} diff --git a/src/oidcrp/oidc/__init__.py b/src/oidcrp/oidc/__init__.py index 647a67a..34dc2ab 100755 --- a/src/oidcrp/oidc/__init__.py +++ b/src/oidcrp/oidc/__init__.py @@ -1,8 +1,7 @@ import json import logging -from oidcservice.client_auth import BearerHeader -from oidcservice.oidc import DEFAULT_SERVICES +from oidcrp.client_auth import BearerHeader try: from json import JSONDecodeError @@ -18,6 +17,46 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- +# + +DEFAULT_SERVICES = { + "discovery": { + 'class': 'oidcrp.oidc.provider_info_discovery' + '.ProviderInfoDiscovery' + }, + 'registration': { + 'class': 'oidcrp.oidc.registration.Registration' + }, + 'authorization': { + 'class': 'oidcrp.oidc.authorization.Authorization' + }, + 'access_token': { + 'class': 'oidcrp.oidc.access_token.AccessToken' + }, + 'refresh_access_token': { + 'class': 'oidcrp.oidc.refresh_access_token.RefreshAccessToken' + }, + 'userinfo': { + 'class': 'oidcrp.oidc.userinfo.UserInfo' + } +} + +WF_URL = "https://{}/.well-known/webfinger" +OIC_ISSUER = "http://openid.net/specs/connect/1.0/issuer" + +IDT2REG = { + 'sigalg': 'id_token_signed_response_alg', + 'encalg': 'id_token_encrypted_response_alg', + 'encenc': 'id_token_encrypted_response_enc' +} + +ENDPOINT2SERVICE = { + 'authorization': ['authorization'], + 'token': ['accesstoken', 'refresh_token'], + 'userinfo': ['userinfo'], + 'registration': ['registration'], + 'end_sesssion': ['end_session'] +} # This should probably be part of the configuration MAX_AUTHENTICATION_AGE = 86400 @@ -89,7 +128,7 @@ def fetch_distributed_claims(self, userinfo, callback=None): if "access_token" in spec: cauth = BearerHeader() httpc_params = cauth.construct( - service=self.service['userinfo'], + service=self.client_get("service", 'userinfo'), access_token=spec['access_token']) _resp = self.http.send(spec["endpoint"], 'GET', **httpc_params) @@ -98,7 +137,7 @@ def fetch_distributed_claims(self, userinfo, callback=None): token = callback(spec['endpoint']) cauth = BearerHeader() httpc_params = cauth.construct( - service=self.service['userinfo'], + service=self.client_get("service",'userinfo'), access_token=token) _resp = self.http.send( spec["endpoint"], 'GET', **httpc_params) diff --git a/src/oidcrp/oidc/access_token.py b/src/oidcrp/oidc/access_token.py new file mode 100644 index 0000000..03a87f1 --- /dev/null +++ b/src/oidcrp/oidc/access_token.py @@ -0,0 +1,91 @@ +import logging +from typing import Optional + +from oidcmsg import oidc +from oidcmsg.oidc import verified_claim_name +from oidcmsg.time_util import time_sans_frac + +from oidcrp.exception import ParameterError +from oidcrp.oauth2 import access_token +from oidcrp.oidc import IDT2REG + +__author__ = 'Roland Hedberg' + +LOGGER = logging.getLogger(__name__) + + +class AccessToken(access_token.AccessToken): + msg_type = oidc.AccessTokenRequest + response_cls = oidc.AccessTokenResponse + error_msg = oidc.ResponseMessage + + def __init__(self, + client_get, + client_authn_factory=None, + conf: Optional[dict]=None): + access_token.AccessToken.__init__(self, client_get, + client_authn_factory=client_authn_factory, conf=conf) + + def gather_verify_arguments(self): + """ + Need to add some information before running verify() + + :return: dictionary with arguments to the verify call + """ + _context = self.client_get("service_context") + # Default is RS256 + + kwargs = { + 'client_id': _context.client_id, + 'iss': _context.issuer, + 'keyjar': _context.keyjar, + 'verify': True, + 'skew': _context.clock_skew, + } + + _reg_resp = _context.registration_response + if _reg_resp: + for attr, param in IDT2REG.items(): + try: + kwargs[attr] = _reg_resp[param] + except KeyError: + pass + + try: + kwargs['allow_missing_kid'] = _context.allow['missing_kid'] + except KeyError: + pass + + _verify_args = _context.behaviour.get("verify_args") + if _verify_args: + if _verify_args: + kwargs.update(_verify_args) + + return kwargs + + def update_service_context(self, resp, key='', **kwargs): + _state_interface = self.client_get("service_context").state + try: + _idt = resp[verified_claim_name('id_token')] + except KeyError: + pass + else: + try: + if _state_interface.get_state_by_nonce(_idt['nonce']) != key: + raise ParameterError('Someone has messed with "nonce"') + except KeyError: + raise ValueError('Invalid nonce value') + + _state_interface.store_sub2state(_idt['sub'], key) + + if 'expires_in' in resp: + resp['__expires_at'] = time_sans_frac() + int( + resp['expires_in']) + + _state_interface.store_item(resp, 'token_response', key) + + def get_authn_method(self): + try: + return self.client_get("service_context").behaviour['token_endpoint_auth_method'] + except KeyError: + return self.default_authn_method diff --git a/src/oidcrp/oidc/authorization.py b/src/oidcrp/oidc/authorization.py new file mode 100644 index 0000000..b95601b --- /dev/null +++ b/src/oidcrp/oidc/authorization.py @@ -0,0 +1,253 @@ +import logging + +from oidcmsg import oidc +from oidcmsg.oidc import make_openid_request +from oidcmsg.oidc import verified_claim_name +from oidcmsg.time_util import time_sans_frac +from oidcmsg.time_util import utc_time_sans_frac + +from oidcrp.exception import ParameterError +from oidcrp.oauth2 import authorization +from oidcrp.oauth2.utils import pick_redirect_uris +from oidcrp.oidc import IDT2REG +from oidcrp.oidc.utils import construct_request_uri +from oidcrp.oidc.utils import request_object_encryption +from oidcrp.util import rndstr + +__author__ = 'Roland Hedberg' + +LOGGER = logging.getLogger(__name__) + + +class Authorization(authorization.Authorization): + msg_type = oidc.AuthorizationRequest + response_cls = oidc.AuthorizationResponse + error_msg = oidc.ResponseMessage + + def __init__(self, client_get, client_authn_factory=None, conf=None): + authorization.Authorization.__init__(self, client_get, + client_authn_factory, conf=conf) + self.default_request_args = {'scope': ['openid']} + self.pre_construct = [self.set_state, pick_redirect_uris, + self.oidc_pre_construct] + self.post_construct = [self.oidc_post_construct] + + def set_state(self, request_args, **kwargs): + try: + _state = kwargs['state'] + except KeyError: + try: + _state = request_args['state'] + except KeyError: + _state = '' + + _context = self.client_get("service_context") + request_args['state'] = _context.state.create_state(_context.issuer, _state) + return request_args, {} + + def update_service_context(self, resp, key='', **kwargs): + _context = self.client_get("service_context") + try: + _idt = resp[verified_claim_name('id_token')] + except KeyError: + pass + else: + # If there is a verified ID Token then we have to do nonce + # verification + try: + if _context.state.get_state_by_nonce(_idt['nonce']) != key: + raise ParameterError('Someone has messed with "nonce"') + except KeyError: + raise ValueError('Missing nonce value') + + _context.state.store_sub2state(_idt['sub'], key) + + if 'expires_in' in resp: + resp['__expires_at'] = time_sans_frac() + int(resp['expires_in']) + _context.state.store_item(resp.to_json(), 'auth_response', key) + + def oidc_pre_construct(self, request_args=None, post_args=None, **kwargs): + _context = self.client_get("service_context") + if request_args is None: + request_args = {} + + try: + _response_types = [request_args["response_type"]] + except KeyError: + _response_types = _context.behaviour.get('response_types') + if _response_types: + request_args["response_type"] = _response_types[0] + else: + request_args["response_type"] = "code" + + # For OIDC 'openid' is required in scope + if 'scope' not in request_args: + request_args['scope'] = _context.behaviour.get("scope", ["openid"]) + elif 'openid' not in request_args['scope']: + request_args['scope'].append('openid') + + # 'code' and/or 'id_token' in response_type means an ID Roken + # will eventually be returnedm, hence the need for a nonce + if "code" in _response_types or "id_token" in _response_types: + if "nonce" not in request_args: + request_args["nonce"] = rndstr(32) + + if post_args is None: + post_args = {} + + for attr in ["request_object_signing_alg", "algorithm", 'sig_kid']: + try: + post_args[attr] = kwargs[attr] + except KeyError: + pass + else: + del kwargs[attr] + + if "request_method" in kwargs: + if kwargs["request_method"] == "reference": + post_args['request_param'] = "request_uri" + else: + post_args['request_param'] = "request" + del kwargs["request_method"] + + return request_args, post_args + + def get_request_object_signing_alg(self, **kwargs): + alg = '' + for arg in ["request_object_signing_alg", "algorithm"]: + try: # Trumps everything + alg = kwargs[arg] + except KeyError: + pass + else: + break + + if not alg: + try: + alg = self.client_get("service_context").behaviour["request_object_signing_alg"] + except KeyError: # Use default + alg = "RS256" + return alg + + def store_request_on_file(self, req, **kwargs): + """ + Stores the request parameter in a file. + :param req: The request + :param kwargs: Extra keyword arguments + :return: The URL the OP should use to access the file + """ + _context = self.client_get("service_context") + try: + _webname = _context.registration_response['request_uris'][0] + filename = _context.filename_from_webname(_webname) + except KeyError: + filename, _webname = construct_request_uri(**kwargs) + + fid = open(filename, mode="w") + fid.write(req) + fid.close() + return _webname + + def construct_request_parameter(self, req, request_method, audience=None, expires_in=0, + **kwargs): + """Construct a request parameter""" + alg = self.get_request_object_signing_alg(**kwargs) + kwargs["request_object_signing_alg"] = alg + + _context = self.client_get("service_context") + if "keys" not in kwargs and alg and alg != "none": + kwargs["keys"] = _context.keyjar + + _srv_cntx = _context + + # This is the issuer of the JWT, that is me ! + _issuer = kwargs.get("issuer") + if _issuer is None: + kwargs['issuer'] = _srv_cntx.client_id + + if kwargs.get("recv") is None: + try: + kwargs['recv'] = _srv_cntx.provider_info['issuer'] + except KeyError: + kwargs['recv'] = _srv_cntx.issuer + + del kwargs['service'] + + if expires_in: + req['exp'] = utc_time_sans_frac() + int(expires_in) + + _req = make_openid_request(req, **kwargs) + + # Should the request be encrypted + _req = request_object_encryption(_req, _context, **kwargs) + + if request_method == "request": + req["request"] = _req + else: # MUST be request_uri + req["request_uri"] = self.store_request_on_file(_req, **kwargs) + + def oidc_post_construct(self, req, **kwargs): + """ + Modify the request arguments. + + :param req: The request + :param kwargs: Extra keyword arguments + :return: A possibly modified request. + """ + _context = self.client_get("service_context") + if 'openid' in req['scope']: + _response_type = req['response_type'][0] + if 'id_token' in _response_type or 'code' in _response_type: + _context.state.store_nonce2state(req['nonce'], req['state']) + + if 'offline_access' in req['scope']: + if 'prompt' not in req: + req['prompt'] = 'consent' + + try: + _request_method = kwargs['request_param'] + except KeyError: + pass + else: + del kwargs['request_param'] + + self.construct_request_parameter(req, _request_method, **kwargs) + + _context.state.store_item(req, 'auth_request', req['state']) + return req + + def gather_verify_arguments(self): + """ + Need to add some information before running verify() + + :return: dictionary with arguments to the verify call + """ + _context = self.client_get("service_context") + kwargs = { + 'iss': _context.issuer, + 'keyjar': _context.keyjar, 'verify': True, + 'skew': _context.clock_skew + } + + _client_id = _context.client_id + if _client_id: + kwargs['client_id'] = _client_id + + _reg_res = _context.registration_response + if _reg_res: + for attr, param in IDT2REG.items(): + try: + kwargs[attr] = _reg_res[param] + except KeyError: + pass + + try: + kwargs['allow_missing_kid'] = _context.allow['missing_kid'] + except KeyError: + pass + + _verify_args = _context.behaviour.get("verify_args") + if _verify_args: + kwargs.update(_verify_args) + + return kwargs diff --git a/src/oidcrp/oidc/check_id.py b/src/oidcrp/oidc/check_id.py new file mode 100644 index 0000000..8ef8dd8 --- /dev/null +++ b/src/oidcrp/oidc/check_id.py @@ -0,0 +1,31 @@ +import logging + +from oidcmsg.oauth2 import Message, ResponseMessage +from oidcmsg.oidc import session + +from oidcrp.service import Service + +__author__ = 'Roland Hedberg' + +logger = logging.getLogger(__name__) + + +class CheckID(Service): + msg_type = session.CheckIDRequest + response_cls = Message + error_msg = ResponseMessage + endpoint_name = '' + synchronous = True + service_name = 'check_id' + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, + conf=conf) + self.pre_construct = [self.oidc_pre_construct] + + def oidc_pre_construct(self, request_args=None, **kwargs): + request_args = self.client_get("service_context").state.multiple_extend_request_args( + request_args, kwargs['state'], ['id_token'], + ['auth_response', 'token_response', 'refresh_token_response']) + return request_args, {} diff --git a/src/oidcrp/oidc/check_session.py b/src/oidcrp/oidc/check_session.py new file mode 100644 index 0000000..bf7661a --- /dev/null +++ b/src/oidcrp/oidc/check_session.py @@ -0,0 +1,31 @@ +import logging + +from oidcmsg.oauth2 import Message, ResponseMessage +from oidcmsg.oidc import session + +from oidcrp.service import Service + +__author__ = 'Roland Hedberg' + +logger = logging.getLogger(__name__) + + +class CheckSession(Service): + msg_type = session.CheckSessionRequest + response_cls = Message + error_msg = ResponseMessage + endpoint_name = '' + synchronous = True + service_name = 'check_session' + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, + conf=conf) + self.pre_construct = [self.oidc_pre_construct] + + def oidc_pre_construct(self, request_args=None, **kwargs): + request_args = self.client_get("service_context").state.multiple_extend_request_args( + request_args, kwargs['state'], ['id_token'], + ['auth_response', 'token_response', 'refresh_token_response']) + return request_args, {} diff --git a/src/oidcrp/oidc/end_session.py b/src/oidcrp/oidc/end_session.py new file mode 100644 index 0000000..a58c91d --- /dev/null +++ b/src/oidcrp/oidc/end_session.py @@ -0,0 +1,75 @@ +import logging + +from oidcmsg.oauth2 import Message +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.oidc import session + +from oidcrp.service import Service +from oidcrp.util import rndstr + +__author__ = 'Roland Hedberg' + +logger = logging.getLogger(__name__) + + +class EndSession(Service): + msg_type = session.EndSessionRequest + response_cls = Message + error_msg = ResponseMessage + endpoint_name = 'end_session_endpoint' + synchronous = True + service_name = 'end_session' + response_body_type = 'html' + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, + conf=conf) + self.pre_construct = [self.get_id_token_hint, + self.add_post_logout_redirect_uri, + self.add_state] + + def get_id_token_hint(self, request_args=None, **kwargs): + """ + Add id_token_hint to request + + :param request_args: + :param kwargs: + :return: + """ + request_args = self.client_get("service_context").state.multiple_extend_request_args( + request_args, kwargs['state'], ['id_token'], + ['auth_response', 'token_response', 'refresh_token_response'], + orig=True + ) + + try: + request_args['id_token_hint'] = request_args['id_token'] + except KeyError: + pass + else: + del request_args['id_token'] + + return request_args, {} + + def add_post_logout_redirect_uri(self, request_args=None, **kwargs): + if 'post_logout_redirect_uri' not in request_args: + try: + request_args[ + 'post_logout_redirect_uri' + ] = self.client_get("service_context").register_args[ + 'post_logout_redirect_uris'][0] + except KeyError: + pass + + return request_args, {} + + def add_state(self, request_args=None, **kwargs): + if 'state' not in request_args: + request_args['state'] = rndstr(32) + + # As a side effect bind logout state to session state + self.client_get("service_context").state.store_logout_state2state(request_args['state'], + kwargs['state']) + + return request_args, {} diff --git a/src/oidcrp/oidc/provider_info_discovery.py b/src/oidcrp/oidc/provider_info_discovery.py new file mode 100644 index 0000000..230ce27 --- /dev/null +++ b/src/oidcrp/oidc/provider_info_discovery.py @@ -0,0 +1,176 @@ +import logging + +from oidcmsg import oidc +from oidcmsg.oauth2 import ResponseMessage + +from oidcrp.exception import ConfigurationError +from oidcrp.oauth2 import provider_info_discovery + +__author__ = 'Roland Hedberg' + +logger = logging.getLogger(__name__) + +PREFERENCE2PROVIDER = { + # "require_signed_request_object": "request_object_algs_supported", + "request_object_signing_alg": "request_object_signing_alg_values_supported", + "request_object_encryption_alg": + "request_object_encryption_alg_values_supported", + "request_object_encryption_enc": + "request_object_encryption_enc_values_supported", + "userinfo_signed_response_alg": "userinfo_signing_alg_values_supported", + "userinfo_encrypted_response_alg": + "userinfo_encryption_alg_values_supported", + "userinfo_encrypted_response_enc": + "userinfo_encryption_enc_values_supported", + "id_token_signed_response_alg": "id_token_signing_alg_values_supported", + "id_token_encrypted_response_alg": + "id_token_encryption_alg_values_supported", + "id_token_encrypted_response_enc": + "id_token_encryption_enc_values_supported", + "default_acr_values": "acr_values_supported", + "subject_type": "subject_types_supported", + "token_endpoint_auth_method": "token_endpoint_auth_methods_supported", + "token_endpoint_auth_signing_alg": + "token_endpoint_auth_signing_alg_values_supported", + "response_types": "response_types_supported", + 'grant_types': 'grant_types_supported', + 'scope': 'scopes_supported' +} + +PROVIDER2PREFERENCE = dict([(v, k) for k, v in PREFERENCE2PROVIDER.items()]) + +PROVIDER_DEFAULT = { + "token_endpoint_auth_method": "client_secret_basic", + "id_token_signed_response_alg": "RS256", +} + + +def add_redirect_uris(request_args, service=None, **kwargs): + """ + Add redirect_uris to the request arguments. + + :param request_args: Incomming request arguments + :param service: A link to the service + :param kwargs: Possible extra keyword arguments + :return: A possibly augmented set of request arguments. + """ + _context = service.client_get("service_context") + if "redirect_uris" not in request_args: + # Callbacks is a dictionary with callback type 'code', 'implicit', + # 'form_post' as keys. + _cbs = _context.callback + if _cbs: + # Filter out local additions. + _uris = [v for k, v in _cbs.items() if not k.startswith('__')] + request_args['redirect_uris'] = _uris + else: + request_args['redirect_uris'] = _context.redirect_uris + + return request_args, {} + + +class ProviderInfoDiscovery(provider_info_discovery.ProviderInfoDiscovery): + msg_type = oidc.Message + response_cls = oidc.ProviderConfigurationResponse + error_msg = ResponseMessage + + def __init__(self, client_get, client_authn_factory=None, conf=None): + provider_info_discovery.ProviderInfoDiscovery.__init__( + self, client_get, client_authn_factory=client_authn_factory, + conf=conf) + + def update_service_context(self, resp, **kwargs): + _context = self.client_get("service_context") + self._update_service_context(resp) + self.match_preferences(resp, _context.issuer) + if 'pre_load_keys' in self.conf and self.conf['pre_load_keys']: + _jwks = _context.keyjar.export_jwks_as_json( + issuer=resp['issuer']) + logger.info( + 'Preloaded keys for {}: {}'.format(resp['issuer'], _jwks)) + + def match_preferences(self, pcr=None, issuer=None): + """ + Match the clients preferences against what the provider can do. + This is to prepare for later client registration and or what + functionality the client actually will use. + In the client configuration the client preferences are expressed. + These are then compared with the Provider Configuration information. + If the Provider has left some claims out, defaults specified in the + standard will be used. + + :param pcr: Provider configuration response if available + :param issuer: The issuer identifier + """ + _context = self.client_get("service_context") + if not pcr: + pcr = _context.provider_info + + regreq = oidc.RegistrationRequest + + _behaviour = _context.behaviour + + for _pref, _prov in PREFERENCE2PROVIDER.items(): + try: + vals = _context.client_preferences[_pref] + except KeyError: + continue + + try: + _pvals = pcr[_prov] + except KeyError: + try: + # If the provider have not specified use what the + # standard says is mandatory if at all. + _pvals = PROVIDER_DEFAULT[_pref] + except KeyError: + logger.info( + 'No info from provider on {} and no default'.format( + _pref)) + _pvals = vals + + if isinstance(vals, str): + if vals in _pvals: + _behaviour[_pref] = vals + else: + try: + vtyp = regreq.c_param[_pref] + except KeyError: + # Allow non standard claims + if isinstance(vals, list): + _behaviour[_pref] = [v for v in vals if v in _pvals] + elif vals in _pvals: + _behaviour[_pref] = vals + else: + if isinstance(vtyp[0], list): + _behaviour[_pref] = [] + for val in vals: + if val in _pvals: + _behaviour[_pref].append( + val) + else: + for val in vals: + if val in _pvals: + _behaviour[_pref] = val + break + + if _pref not in _behaviour: + raise ConfigurationError("OP couldn't match preference:%s" % _pref, pcr) + + for key, val in _context.client_preferences.items(): + if key in _behaviour: + continue + + try: + vtyp = regreq.c_param[key] + if isinstance(vtyp[0], list): + pass + elif isinstance(val, list) and not isinstance(val, str): + val = val[0] + except KeyError: + pass + if key not in PREFERENCE2PROVIDER: + _behaviour[key] = val + + _context.behaviour= _behaviour + logger.debug('service_context behaviour: {}'.format(_behaviour)) diff --git a/src/oidcrp/oidc/read_registration.py b/src/oidcrp/oidc/read_registration.py new file mode 100644 index 0000000..f1c4aef --- /dev/null +++ b/src/oidcrp/oidc/read_registration.py @@ -0,0 +1,46 @@ +import logging + +from oidcmsg import oidc +from oidcmsg.message import Message +from oidcmsg.oauth2 import ResponseMessage + +from oidcrp.service import Service + +LOGGER = logging.getLogger(__name__) + + +class RegistrationRead(Service): + msg_type = Message + response_cls = oidc.RegistrationResponse + error_msg = ResponseMessage + synchronous = True + service_name = 'registration_read' + http_method = 'GET' + default_authn_method = 'client_secret_basic' + + def get_endpoint(self): + try: + return self.client_get("service_context").registration_response["registration_client_uri"] + except KeyError: + return '' + + def get_authn_header(self, request, authn_method, **kwargs): + """ + Construct an authorization specification to be sent in the + HTTP header. + + :param request: The service request + :param authn_method: Which authentication/authorization method to use + :param kwargs: Extra keyword arguments + :return: A set of keyword arguments to be sent in the HTTP header. + """ + headers = {} + + if authn_method == "client_secret_basic": + LOGGER.debug("Client authn method: %s", authn_method) + headers["Authorization"] = "Bearer {}".format( + self.client_get("service_context").registration_response[ + "registration_access_token"] + ) + + return headers \ No newline at end of file diff --git a/src/oidcrp/oidc/refresh_access_token.py b/src/oidcrp/oidc/refresh_access_token.py new file mode 100644 index 0000000..d6af5a3 --- /dev/null +++ b/src/oidcrp/oidc/refresh_access_token.py @@ -0,0 +1,15 @@ +from oidcmsg import oidc + +from oidcrp.oauth2 import refresh_access_token + + +class RefreshAccessToken(refresh_access_token.RefreshAccessToken): + msg_type = oidc.RefreshAccessTokenRequest + response_cls = oidc.AccessTokenResponse + error_msg = oidc.ResponseMessage + + def get_authn_method(self): + try: + return self.client_get("service_context").behaviour['token_endpoint_auth_method'] + except KeyError: + return self.default_authn_method diff --git a/src/oidcrp/oidc/registration.py b/src/oidcrp/oidc/registration.py new file mode 100644 index 0000000..925cc66 --- /dev/null +++ b/src/oidcrp/oidc/registration.py @@ -0,0 +1,171 @@ +import logging + +from oidcmsg import oidc +from oidcmsg.oauth2 import ResponseMessage + +from oidcrp.oidc.provider_info_discovery import add_redirect_uris +from oidcrp.service import Service + +__author__ = 'Roland Hedberg' + +logger = logging.getLogger(__name__) + +rt2gt = { + 'code': ['authorization_code'], + 'id_token': ['implicit'], + 'id_token token': ['implicit'], + 'code id_token': ['authorization_code', 'implicit'], + 'code token': ['authorization_code', 'implicit'], + 'code id_token token': ['authorization_code', 'implicit'] +} + + +def response_types_to_grant_types(response_types): + _res = set() + + for response_type in response_types: + _rt = response_type.split(' ') + _rt.sort() + try: + _gt = rt2gt[" ".join(_rt)] + except KeyError: + logger.warning( + 'No such response type combination: {}'.format(response_types)) + else: + _res.update(set(_gt)) + + return list(_res) + + +def add_request_uri(request_args=None, service=None, **kwargs): + _context = service.client_get("service_context") + if _context.requests_dir: + _pi = _context.provider_info + if _pi: + _req = _pi.get('require_request_uri_registration', False) + if _req is True: + request_args['request_uris'] = _context.generate_request_uris(_context.requests_dir) + + return request_args, {} + + +def add_post_logout_redirect_uris(request_args=None, service=None, **kwargs): + """ + + :param request_args: + :param service: pointer to the :py:class:`oidcrp.service.Service` + instance that is running this function + :param kwargs: parameters to the registration request + :return: + """ + + if "post_logout_redirect_uris" not in request_args: + _uris = service.client_get("service_context").register_args.get("post_logout_redirect_uris") + if _uris: + request_args["post_logout_redirect_uris"] = _uris + + return request_args, {} + + +def add_jwks_uri_or_jwks(request_args=None, service=None, **kwargs): + if 'jwks_uri' in request_args: + if 'jwks' in request_args: + del request_args['jwks'] # only one of jwks_uri and jwks allowed + return request_args, {} + elif 'jwks' in request_args: + return request_args, {} + + for attr in ['jwks_uri', 'jwks']: + _val = getattr(service.client_get("service_context"), attr, 0) + if _val: + request_args[attr] = _val + break + else: + try: + _val = service.client_get("service_context").config[attr] + except KeyError: + pass + else: + request_args[attr] = _val + break + + return request_args, {} + + +class Registration(Service): + msg_type = oidc.RegistrationRequest + response_cls = oidc.RegistrationResponse + error_msg = ResponseMessage + endpoint_name = 'registration_endpoint' + synchronous = True + service_name = 'registration' + request_body_type = 'json' + http_method = 'POST' + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, + conf=conf) + self.pre_construct = [self.add_client_behaviour_preference, + add_redirect_uris, add_request_uri, + add_post_logout_redirect_uris, + add_jwks_uri_or_jwks] + self.post_construct = [self.oidc_post_construct] + + def add_client_behaviour_preference(self, request_args=None, **kwargs): + _context = self.client_get("service_context") + for prop in self.msg_type.c_param.keys(): + if prop in request_args: + continue + + try: + request_args[prop] = _context.behaviour[prop] + except KeyError: + try: + request_args[ + prop] = _context.client_preferences[prop] + except KeyError: + pass + return request_args, {} + + def oidc_post_construct(self, request_args=None, **kwargs): + try: + request_args['grant_types'] = response_types_to_grant_types( + request_args['response_types']) + except KeyError: + pass + + # If a Client can use jwks_uri, it MUST NOT use jwks. + if 'jwks_uri' in request_args and 'jwks' in request_args: + del request_args['jwks'] + + return request_args + + def update_service_context(self, resp, key='', **kwargs): + if "token_endpoint_auth_method" not in resp: + resp["token_endpoint_auth_method"] = "client_secret_basic" + + _context = self.client_get("service_context") + _context.registration_response = resp + _client_id = resp.get("client_id") + if _client_id: + _context.client_id = _client_id + if _client_id not in _context.keyjar: + _context.keyjar.import_jwks( + _context.keyjar.export_jwks(True, ''), + issuer_id=_client_id + ) + _client_secret = resp.get("client_secret") + if _client_secret: + _context.client_secret = _client_secret + _context.keyjar.add_symmetric('', _client_secret) + _context.keyjar.add_symmetric(_client_id, _client_secret) + try: + _context.client_secret_expires_at = resp["client_secret_expires_at"] + except KeyError: + pass + + try: + _context.registration_access_token = resp["registration_access_token"] + except KeyError: + pass diff --git a/src/oidcrp/oidc/userinfo.py b/src/oidcrp/oidc/userinfo.py new file mode 100644 index 0000000..2b5852d --- /dev/null +++ b/src/oidcrp/oidc/userinfo.py @@ -0,0 +1,139 @@ +import logging + +from oidcmsg import oidc +from oidcmsg.exception import MissingSigningKey +from oidcmsg.message import Message + +from oidcrp.oauth2.utils import get_state_parameter +from oidcrp.service import Service + +logger = logging.getLogger(__name__) + +UI2REG = { + 'sigalg': 'userinfo_signed_response_alg', + 'encalg': 'userinfo_encrypted_response_alg', + 'encenc': 'userinfo_encrypted_response_enc' +} + + +def carry_state(request_args=None, **kwargs): + """ + Make sure post_construct_methods have access to state + + :param request_args: + :param kwargs: + :return: The value of the state parameter + """ + return request_args, {'state': get_state_parameter(request_args, kwargs)} + + +class UserInfo(Service): + msg_type = Message + response_cls = oidc.OpenIDSchema + error_msg = oidc.ResponseMessage + endpoint_name = 'userinfo_endpoint' + synchronous = True + service_name = 'userinfo' + default_authn_method = 'bearer_header' + http_method = 'GET' + + def __init__(self, client_get, client_authn_factory=None, conf=None): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, + conf=conf) + self.pre_construct = [self.oidc_pre_construct, carry_state] + + def oidc_pre_construct(self, request_args=None, **kwargs): + if request_args is None: + request_args = {} + + if "access_token" in request_args: + pass + else: + request_args = self.client_get("service_context").state.multiple_extend_request_args( + request_args, kwargs['state'], ['access_token'], + ['auth_response', 'token_response', 'refresh_token_response'] + ) + + return request_args, {} + + def post_parse_response(self, response, **kwargs): + _context = self.client_get("service_context") + _state_interface = _context.state + _args = _state_interface.multiple_extend_request_args( + {}, kwargs['state'], ['id_token'], + ['auth_response', 'token_response', 'refresh_token_response'] + ) + + try: + _sub = _args['id_token']['sub'] + except KeyError: + logger.warning("Can not verify value on sub") + else: + if response['sub'] != _sub: + raise ValueError('Incorrect "sub" value') + + try: + _csrc = response["_claim_sources"] + except KeyError: + pass + else: + for csrc, spec in _csrc.items(): + if "JWT" in spec: + try: + aggregated_claims = Message().from_jwt( + spec["JWT"].encode("utf-8"), + keyjar=_context.keyjar) + except MissingSigningKey as err: + logger.warning( + 'Error encountered while unpacking aggregated ' + 'claims'.format(err)) + else: + claims = [value for value, src in + response["_claim_names"].items() if + src == csrc] + + for key in claims: + response[key] = aggregated_claims[key] + elif 'endpoint' in spec: + _info = { + "headers": self.get_authn_header( + {}, self.default_authn_method, + authn_endpoint=self.endpoint_name, + key=kwargs["state"] + ), + "url": spec["endpoint"] + } + + _state_interface.store_item(response, 'user_info', kwargs['state']) + return response + + def gather_verify_arguments(self): + """ + Need to add some information before running verify() + + :return: dictionary with arguments to the verify call + """ + _context = self.client_get("service_context") + kwargs = { + 'client_id': _context.client_id, + 'iss': _context.issuer, + 'keyjar': _context.keyjar, 'verify': True, + 'skew': _context.clock_skew + } + + _reg_resp = _context.registration_response + if _reg_resp: + for attr, param in UI2REG.items(): + try: + kwargs[attr] = _reg_resp[param] + except KeyError: + pass + + try: + kwargs['allow_missing_kid'] = _context.allow['missing_kid'] + except KeyError: + pass + + return kwargs + diff --git a/src/oidcrp/oidc/utils.py b/src/oidcrp/oidc/utils.py new file mode 100644 index 0000000..40b9b90 --- /dev/null +++ b/src/oidcrp/oidc/utils.py @@ -0,0 +1,87 @@ +import os + +from cryptojwt.jwe.jwe import JWE +from cryptojwt.jwe.utils import alg2keytype +from oidcmsg.exception import MissingRequiredAttribute + +from oidcrp.util import rndstr + + +def request_object_encryption(msg, service_context, **kwargs): + """ + Created an encrypted JSON Web token with *msg* as body. + + :param msg: The mesaqg + :param service_context: + :param kwargs: + :return: + """ + try: + encalg = kwargs["request_object_encryption_alg"] + except KeyError: + try: + encalg = service_context.behaviour[ + "request_object_encryption_alg"] + except KeyError: + return msg + + if not encalg: + return msg + + try: + encenc = kwargs["request_object_encryption_enc"] + except KeyError: + try: + encenc = service_context.behaviour["request_object_encryption_enc"] + except KeyError: + raise MissingRequiredAttribute( + "No request_object_encryption_enc specified") + + if not encenc: + raise MissingRequiredAttribute( + "No request_object_encryption_enc specified") + + _jwe = JWE(msg, alg=encalg, enc=encenc) + _kty = alg2keytype(encalg) + + try: + _kid = kwargs["enc_kid"] + except KeyError: + _kid = "" + + if "target" not in kwargs: + raise MissingRequiredAttribute("No target specified") + + if _kid: + _keys = service_context.keyjar.get_encrypt_key(_kty, + issuer_id=kwargs["target"], + kid=_kid) + _jwe["kid"] = _kid + else: + _keys = service_context.keyjar.get_encrypt_key(_kty, + issuer_id=kwargs["target"]) + + return _jwe.encrypt(_keys) + + +def construct_request_uri(local_dir, base_path, **kwargs): + """ + Constructs a special redirect_uri to be used when communicating with + one OP. Each OP should get their own redirect_uris. + + :param local_dir: Local directory in which to place the file + :param base_path: Base URL to start with + :param kwargs: + :return: 2-tuple with (filename, url) + """ + _filedir = local_dir + if not os.path.isdir(_filedir): + os.makedirs(_filedir) + _webpath = base_path + _name = rndstr(10) + ".jwt" + filename = os.path.join(_filedir, _name) + while os.path.exists(filename): + _name = rndstr(10) + filename = os.path.join(_filedir, _name) + _webname = "%s%s" % (_webpath, _name) + return filename, _webname diff --git a/src/oidcrp/oidc/webfinger.py b/src/oidcrp/oidc/webfinger.py new file mode 100644 index 0000000..a442a36 --- /dev/null +++ b/src/oidcrp/oidc/webfinger.py @@ -0,0 +1,167 @@ +import logging +from urllib.parse import urlsplit +from urllib.parse import urlunsplit + +from oidcmsg import oidc +from oidcmsg.exception import MissingRequiredAttribute +from oidcmsg.oauth2 import Message +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.oidc import JRD + +from oidcrp.oidc import OIC_ISSUER +from oidcrp.oidc import WF_URL +from oidcrp.service import Service + +__author__ = 'Roland Hedberg' + +logger = logging.getLogger(__name__) + +SCHEME = 0 +NETLOC = 1 +PATH = 2 +QUERY = 3 +FRAGMENT = 4 + + +class WebFinger(Service): + """ + Implements RFC 7033 + """ + msg_type = Message + response_cls = JRD + error_msg = ResponseMessage + synchronous = True + service_name = 'webfinger' + http_method = 'GET' + response_body_type = 'json' + + def __init__(self, client_get, client_authn_factory=None, + conf=None, rel='', **kwargs): + Service.__init__(self, client_get, + client_authn_factory=client_authn_factory, + conf=conf, **kwargs) + + self.rel = rel or OIC_ISSUER + + def update_service_context(self, resp, key='', **kwargs): + try: + links = resp['links'] + except KeyError: + raise MissingRequiredAttribute('links') + else: + for link in links: + if link['rel'] == self.rel: + _href = link['href'] + try: + _http_allowed = self.get_conf_attr( + 'allow', default={})['http_links'] + except KeyError: + _http_allowed = False + + if _href.startswith('http://') and not _http_allowed: + raise ValueError( + 'http link not allowed ({})'.format(_href)) + + self.client_get("service_context").issuer = link['href'] + break + return resp + + @staticmethod + def create_url(part, ignore): + res = [] + for a in range(0, 5): + if a in ignore: + res.append('') + else: + res.append(part[a]) + return urlunsplit(tuple(res)) + + def query(self, resource): + """ + Given a resource identifier find the domain specifier and then + construct the webfinger request. Implements + http://openid.net/specs/openid-connect-discovery-1_0.html#NormalizationSteps + + :param resource: + """ + if resource[0] in ['=', '@', '!']: # Have no process for handling these + raise ValueError('Not allowed resource identifier') + + try: + part = urlsplit(resource) + except Exception: + raise ValueError('Unparsable resource') + else: + if not part[SCHEME]: + if not part[NETLOC]: + _path = part[PATH] + if not part[QUERY] and not part[FRAGMENT]: + if '/' in _path or ':' in _path: + resource = "https://{}".format(resource) + part = urlsplit(resource) + authority = part[NETLOC] + else: + if '@' in _path: + authority = _path.split('@')[1] + else: + authority = _path + resource = 'acct:{}'.format(_path) + elif part[QUERY]: + resource = "https://{}?{}".format(_path, part[QUERY]) + parts = urlsplit(resource) + authority = parts[NETLOC] + else: + resource = "https://{}".format(_path) + part = urlsplit(resource) + authority = part[NETLOC] + else: + raise ValueError('Missing netloc') + else: + _scheme = part[SCHEME] + if _scheme not in ['http', 'https', 'acct']: + # assume it to be a hostname port combo, + # eg. example.com:8080 + resource = 'https://{}'.format(resource) + part = urlsplit(resource) + authority = part[NETLOC] + resource = self.create_url(part, [FRAGMENT]) + elif _scheme in ['http', 'https'] and not part[NETLOC]: + raise ValueError( + 'No authority part in the resource specification') + elif _scheme == 'acct': + _path = part[PATH] + for c in ['/', '?']: + _path = _path.split(c)[0] + + if '@' in _path: + authority = _path.split('@')[1] + else: + raise ValueError( + 'No authority part in the resource specification') + authority = authority.split('#')[0] + resource = self.create_url(part, [FRAGMENT]) + else: + authority = part[NETLOC] + resource = self.create_url(part, [FRAGMENT]) + + location = WF_URL.format(authority) + return oidc.WebFingerRequest( + resource=resource, rel=OIC_ISSUER).request(location) + + def get_request_parameters(self, request_args=None, **kwargs): + + if request_args is None: + request_args = {} + + try: + _resource = request_args['resource'] + except KeyError: + try: + _resource = kwargs['resource'] + except KeyError: + try: + _resource = self.client_get("service_context").config['resource'] + except KeyError: + raise MissingRequiredAttribute('resource') + + return {'url': self.query(_resource), 'method': 'GET'} diff --git a/src/oidcrp/provider/github.py b/src/oidcrp/provider/github.py index 820c327..b6eed08 100644 --- a/src/oidcrp/provider/github.py +++ b/src/oidcrp/provider/github.py @@ -3,8 +3,8 @@ from oidcmsg.message import SINGLE_OPTIONAL_STRING from oidcmsg.message import SINGLE_REQUIRED_STRING from oidcmsg.oauth2 import ResponseMessage -from oidcservice.oauth2 import access_token -from oidcservice.oidc import userinfo +from oidcrp.oauth2 import access_token +from oidcrp.oidc import userinfo class AccessTokenResponse(Message): diff --git a/src/oidcrp/provider/linkedin.py b/src/oidcrp/provider/linkedin.py index 0eb2aa6..f5c564d 100644 --- a/src/oidcrp/provider/linkedin.py +++ b/src/oidcrp/provider/linkedin.py @@ -5,8 +5,8 @@ from oidcmsg.message import SINGLE_REQUIRED_INT from oidcmsg.message import SINGLE_REQUIRED_STRING -from oidcservice.oauth2 import access_token -from oidcservice.oidc import userinfo +from oidcrp.oauth2 import access_token +from oidcrp.oidc import userinfo class AccessTokenResponse(Message): diff --git a/src/oidcrp/rp_handler.py b/src/oidcrp/rp_handler.py new file mode 100644 index 0000000..5dc6750 --- /dev/null +++ b/src/oidcrp/rp_handler.py @@ -0,0 +1,900 @@ +import hashlib +import logging +import sys +import traceback +from typing import Optional + +from cryptojwt import as_unicode +from cryptojwt.key_bundle import keybundle_from_local_file +from cryptojwt.key_jar import init_key_jar +from cryptojwt.utils import as_bytes +from oidcmsg import verified_claim_name +from oidcmsg.exception import MessageException +from oidcmsg.exception import NotForMe +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.oauth2 import is_error_message +from oidcmsg.oidc import AccessTokenResponse +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import AuthorizationResponse +from oidcmsg.oidc import Claims +from oidcmsg.oidc import OpenIDSchema +from oidcmsg.oidc.session import BackChannelLogoutRequest +from oidcmsg.time_util import time_sans_frac + +from . import oidc +from .defaults import DEFAULT_CLIENT_CONFIGS +from .defaults import DEFAULT_OIDC_SERVICES +from .defaults import DEFAULT_RP_KEY_DEFS +from .exception import OidcServiceError +from .oauth2 import Client +from .util import add_path +from .util import dynamic_provider_info_discovery +from .util import load_registration_response +from .util import rndstr + +logger = logging.getLogger(__name__) + + +class RPHandler(object): + def __init__(self, base_url, client_configs=None, services=None, keyjar=None, + hash_seed="", verify_ssl=True, client_authn_factory=None, + client_cls=None, state_db=None, http_lib=None, httpc_params=None, + **kwargs): + + self.base_url = base_url + if hash_seed: + self.hash_seed = as_bytes(hash_seed) + else: + self.hash_seed = as_bytes(rndstr(32)) + + _jwks_path = kwargs.get('jwks_path') + if keyjar is None: + self.keyjar = init_key_jar(**DEFAULT_RP_KEY_DEFS, issuer_id='') + self.keyjar.import_jwks_as_json(self.keyjar.export_jwks_as_json(True, ''), base_url) + if _jwks_path is None: + _jwks_path = DEFAULT_RP_KEY_DEFS['public_path'] + else: + self.keyjar = keyjar + + if _jwks_path: + self.jwks_uri = add_path(base_url, _jwks_path) + else: + self.jwks_uri = "" + if len(self.keyjar): + self.jwks = self.keyjar.export_jwks() + else: + self.jwks = {} + + if state_db: + self.state_db = state_db + else: + self.state_db = {} + + self.extra = kwargs + + self.client_cls = client_cls or oidc.RP + if services is None: + self.services = DEFAULT_OIDC_SERVICES + else: + self.services = services + + self.client_authn_factory = client_authn_factory + + if client_configs is None: + self.client_configs = DEFAULT_CLIENT_CONFIGS + else: + self.client_configs = client_configs + + # keep track on which RP instance that serves with OP + self.issuer2rp = {} + self.hash2issuer = {} + self.httplib = http_lib + + if not httpc_params: + self.httpc_params = {'verify': verify_ssl} + else: + self.httpc_params = httpc_params + + if not self.keyjar.httpc_params: + self.keyjar.httpc_params = self.httpc_params + + def state2issuer(self, state): + """ + Given the state value find the Issuer ID of the OP/AS that state value + was used against. + Will raise a KeyError if the state is unknown. + + :param state: The state value + :return: An Issuer ID + """ + for _rp in self.issuer2rp.values(): + try: + _iss = _rp.client_get("service_context").state.get_iss(state) + except KeyError: + continue + else: + if _iss: + return _iss + return None + + def pick_config(self, issuer): + """ + From the set of client configurations pick one based on the issuer ID. + Will raise a KeyError if issuer is unknown. + + :param issuer: Issuer ID + :return: A client configuration + """ + return self.client_configs[issuer] + + def get_session_information(self, key, client=None): + """ + This is the second of the methods users of this class should know about. + It will return the complete session information as an + :py:class:`oidcrp.state_interface.State` instance. + + :param key: The session key (state) + :return: A State instance + """ + if not client: + client = self.get_client_from_session_key(key) + + return client.client_get("service_context").state.get_state(key) + + def init_client(self, issuer): + """ + Initiate a Client instance. Specifically which Client class is used + is decided by configuration. + + :param issuer: An issuer ID + :return: A Client instance + """ + try: + _cnf = self.pick_config(issuer) + except KeyError: + _cnf = self.pick_config('') + _cnf['issuer'] = issuer + + try: + _services = _cnf['services'] + except KeyError: + _services = self.services + + try: + client = self.client_cls( + client_authn_factory=self.client_authn_factory, + services=_services, config=_cnf, httplib=self.httplib, + httpc_params=self.httpc_params) + except Exception as err: + logger.error('Failed initiating client: {}'.format(err)) + message = traceback.format_exception(*sys.exc_info()) + logger.error(message) + raise + + _context = client.client_get("service_context") + # If non persistent + _context.keyjar.load(self.keyjar.dump()) + # If persistent nothings has to be copied + + _context.base_url = self.base_url + _context.jwks_uri = self.jwks_uri + return client + + def do_provider_info(self, client=None, state=''): + """ + Either get the provider info from configuration or through dynamic + discovery. + + :param client: A Client instance + :param state: A key by which the state of the session can be + retrieved + :return: issuer ID + """ + + if not client: + if state: + client = self.get_client_from_session_key(state) + else: + raise ValueError('Missing state/session key') + + _context = client.client_get("service_context") + if not _context.get('provider_info'): + dynamic_provider_info_discovery(client) + return _context.get('provider_info')['issuer'] + else: + _pi = _context.get('provider_info') + for key, val in _pi.items(): + # All service endpoint parameters in the provider info has + # a name ending in '_endpoint' so I can look specifically + # for those + if key.endswith("_endpoint"): + for _srv in client.client_get("services").values(): + # Every service has an endpoint_name assigned + # when initiated. This name *MUST* match the + # endpoint names used in the provider info + if _srv.endpoint_name == key: + _srv.endpoint = val + + if 'keys' in _pi: + _kj = _context.keyjar + for typ, _spec in _pi['keys'].items(): + if typ == 'url': + for _iss, _url in _spec.items(): + _kj.add_url(_iss, _url) + elif typ == 'file': + for kty, _name in _spec.items(): + if kty == 'jwks': + _kj.import_jwks_from_file(_name, _context.get('issuer')) + elif kty == 'rsa': # PEM file + _kb = keybundle_from_local_file(_name, "der", ["sig"]) + _kj.add_kb(_context.get('issuer'), _kb) + else: + raise ValueError('Unknown provider JWKS type: {}'.format(typ)) + try: + return _context.get('provider_info')['issuer'] + except KeyError: + return _context.get('issuer') + + def do_client_registration(self, client=None, iss_id='', state=''): + """ + Prepare for and do client registration if configured to do so + + :param client: A Client instance + :param state: A key by which the state of the session can be + retrieved + """ + + if not client: + if state: + client = self.get_client_from_session_key(state) + else: + raise ValueError('Missing state/session key') + + _context = client.client_get("service_context") + _iss = _context.get('issuer') + if not _context.get('redirect_uris'): + # Create the necessary callback URLs + # as a side effect self.hash2issuer is set + callbacks = self.create_callbacks(_iss) + + _context.set('redirect_uris', [ + v for k, v in callbacks.items() if not k.startswith('__')]) + _context.set('callbacks', callbacks) + else: + self.hash2issuer[iss_id] = _iss + + # This should only be interesting if the client supports Single Log Out + if _context.post_logout_redirect_uris is None: + _context.post_logout_redirect_uris = [self.base_url] + + if not _context.client_id: + load_registration_response(client) + + def add_callbacks(self, service_context): + _callbacks = self.create_callbacks(service_context.get('provider_info')['issuer']) + service_context.set('redirect_uris', [ + v for k, v in _callbacks.items() if not k.startswith('__')]) + service_context.set('callbacks', _callbacks) + return _callbacks + + def client_setup(self, iss_id='', user=''): + """ + First if no issuer ID is given then the identifier for the user is + used by the webfinger service to try to find the issuer ID. + Once the method has an issuer ID if no client is bound to this issuer + one is created and initiated with + the necessary information for the client to be able to communicate + with the OP/AS that has the provided issuer ID. + + :param iss_id: The issuer ID + :param user: A user identifier + :return: A :py:class:`oidcrp.oidc.Client` instance + """ + + logger.info('client_setup: iss_id={}, user={}'.format(iss_id, user)) + + if not iss_id: + if not user: + raise ValueError('Need issuer or user') + + logger.debug("Connecting to previously unknown OP") + temporary_client = self.init_client('') + temporary_client.do_request('webfinger', resource=user) + else: + temporary_client = None + + try: + client = self.issuer2rp[iss_id] + except KeyError: + if temporary_client: + client = temporary_client + else: + logger.debug("Creating new client: %s", iss_id) + client = self.init_client(iss_id) + else: + return client + + logger.debug("Get provider info") + issuer = self.do_provider_info(client) + + logger.debug("Do client registration") + self.do_client_registration(client, iss_id) + + self.issuer2rp[issuer] = client + return client + + def create_callbacks(self, issuer): + """ + To mitigate some security issues the redirect_uris should be OP/AS + specific. This method creates a set of redirect_uris unique to the + OP/AS. + + :param issuer: Issuer ID + :return: A set of redirect_uris + """ + _hash = hashlib.sha256() + _hash.update(self.hash_seed) + _hash.update(as_bytes(issuer)) + _hex = _hash.hexdigest() + self.hash2issuer[_hex] = issuer + return { + 'code': "{}/authz_cb/{}".format(self.base_url, _hex), + 'implicit': "{}/authz_im_cb/{}".format(self.base_url, _hex), + 'form_post': "{}/authz_fp_cb/{}".format(self.base_url, _hex), + '__hex': _hex + } + + def init_authorization(self, client=None, state='', req_args=None): + """ + Constructs the URL that will redirect the user to the authorization + endpoint of the OP/AS. + + :param client: A Client instance + :param req_args: Non-default Request arguments + :return: A dictionary with 2 keys: **url** The authorization redirect + URL and **state** the key to the session information in the + state data store. + """ + if not client: + if state: + client = self.get_client_from_session_key(state) + else: + raise ValueError('Missing state/session key') + + _context = client.client_get("service_context") + + _nonce = rndstr(24) + request_args = { + 'redirect_uri': _context.get('redirect_uris')[0], + 'scope': _context.get('behaviour')['scope'], + 'response_type': _context.get('behaviour')['response_types'][0], + 'nonce': _nonce + } + + _req_args = _context.config.get("request_args") + if _req_args: + if 'claims' in _req_args: + _req_args["claims"] = Claims(**_req_args["claims"]) + request_args.update(_req_args) + + if req_args is not None: + request_args.update(req_args) + + # Need a new state for a new authorization request + _state = _context.state.create_state(_context.get('issuer')) + request_args['state'] = _state + _context.state.store_nonce2state(_nonce, _state) + + logger.debug('Authorization request args: {}'.format(request_args)) + + _srv = client.get_service('authorization') + _info = _srv.get_request_parameters(request_args=request_args) + logger.debug('Authorization info: {}'.format(_info)) + return {'url': _info['url'], 'state': _state} + + def begin(self, issuer_id='', user_id=''): + """ + This is the first of the 3 high level methods that most users of this + library should confine them self to use. + If will use client_setup to produce a Client instance ready to be used + against the OP/AS the user wants to use. + Once it has the client it will construct an Authorization + request. + + :param issuer_id: Issuer ID + :param user_id: A user identifier + :return: A dictionary containing **url** the URL that will redirect the + user to the OP/AS and **state** the session key which will + allow higher level code to access session information. + """ + + # Get the client instance that has been assigned to this issuer + client = self.client_setup(issuer_id, user_id) + + try: + res = self.init_authorization(client) + except Exception: + message = traceback.format_exception(*sys.exc_info()) + logger.error(message) + raise + else: + return res + + # ---------------------------------------------------------------------- + + def get_client_from_session_key(self, state): + return self.issuer2rp[self.state2issuer(state)] + + @staticmethod + def get_response_type(client): + """ + Return the response_type a specific client wants to use. + + :param client: A Client instance + :return: The response_type + """ + return client.service_context.get('behaviour')['response_types'][0] + + @staticmethod + def get_client_authn_method(client, endpoint): + """ + Return the client authentication method a client wants to use a + specific endpoint + + :param client: A Client instance + :param endpoint: The endpoint at which the client has to authenticate + :return: The client authentication method + """ + if endpoint == 'token_endpoint': + try: + am = client.client_get("service_context").get('behaviour')['token_endpoint_auth_method'] + except KeyError: + return '' + else: + if isinstance(am, str): + return am + else: # a list + return am[0] + + def get_access_token(self, state, client: Optional[Client] = None): + """ + Use the 'accesstoken' service to get an access token from the OP/AS. + + :param state: The state key (the state parameter in the + authorization request) + :param client: A Client instance + :return: A :py:class:`oidcmsg.oidc.AccessTokenResponse` or + :py:class:`oidcmsg.oauth2.AuthorizationResponse` + """ + logger.debug('get_accesstoken') + + if client is None: + client = self.get_client_from_session_key(state) + + _context = client.client_get("service_context") + authorization_response = _context.state.get_item(AuthorizationResponse, 'auth_response', + state) + authorization_request = _context.state.get_item(AuthorizationRequest, 'auth_request', state) + + req_args = { + 'code': authorization_response['code'], + 'state': state, + 'redirect_uri': authorization_request['redirect_uri'], + 'grant_type': 'authorization_code', + 'client_id': _context.get('client_id'), + 'client_secret': _context.get('client_secret') + } + logger.debug('request_args: {}'.format(req_args)) + try: + tokenresp = client.do_request( + 'accesstoken', request_args=req_args, + authn_method=self.get_client_authn_method(client, + "token_endpoint"), + state=state + ) + except Exception as err: + message = traceback.format_exception(*sys.exc_info()) + logger.error(message) + raise + else: + if is_error_message(tokenresp): + raise OidcServiceError(tokenresp['error']) + + return tokenresp + + def refresh_access_token(self, state, client=None, scope=''): + """ + Refresh an access token using a refresh_token. When asking for a new + access token the RP can ask for another scope for the new token. + + :param client: A Client instance + :param state: The state key (the state parameter in the + authorization request) + :param scope: What the returned token should be valid for. + :return: A :py:class:`oidcmsg.oidc.AccessTokenResponse` instance + """ + if scope: + req_args = {'scope': scope} + else: + req_args = {} + + if client is None: + client = self.get_client_from_session_key(state) + + try: + tokenresp = client.do_request( + 'refresh_token', + authn_method=self.get_client_authn_method(client, "token_endpoint"), + state=state, request_args=req_args + ) + except Exception as err: + message = traceback.format_exception(*sys.exc_info()) + logger.error(message) + raise + else: + if is_error_message(tokenresp): + raise OidcServiceError(tokenresp['error']) + + return tokenresp + + def get_user_info(self, state, client=None, access_token='', + **kwargs): + """ + use the access token previously acquired to get some userinfo + + :param client: A Client instance + :param state: The state value, this is the key into the session + data store + :param access_token: An access token + :param kwargs: Extra keyword arguments + :return: A :py:class:`oidcmsg.oidc.OpenIDSchema` instance + """ + if client is None: + client = self.get_client_from_session_key(state) + + if not access_token: + _arg = client.client_get("service_context").state.multiple_extend_request_args( + {}, state, ['access_token'], + ['auth_response', 'token_response', 'refresh_token_response']) + + request_args = {'access_token': access_token} + + resp = client.do_request('userinfo', state=state, + request_args=request_args, **kwargs) + if is_error_message(resp): + raise OidcServiceError(resp['error']) + + return resp + + @staticmethod + def userinfo_in_id_token(id_token): + """ + Given an verified ID token return all the claims that may been user + information. + + :param id_token: An :py:class:`oidcmsg.oidc.IDToken` instance + :return: A dictionary with user information + """ + res = dict([(k, id_token[k]) for k in OpenIDSchema.c_param.keys() if + k in id_token]) + res.update(id_token.extra()) + return res + + def finalize_auth(self, client, issuer, response): + """ + Given the response returned to the redirect_uri, parse and verify it. + + :param client: A Client instance + :param issuer: An Issuer ID + :param response: The authorization response as a dictionary + :return: An :py:class:`oidcmsg.oidc.AuthorizationResponse` or + :py:class:`oidcmsg.oauth2.AuthorizationResponse` instance. + """ + _srv = client.get_service('authorization') + try: + authorization_response = _srv.parse_response(response, + sformat='dict') + except Exception as err: + logger.error('Parsing authorization_response: {}'.format(err)) + message = traceback.format_exception(*sys.exc_info()) + logger.error(message) + raise + else: + logger.debug( + 'Authz response: {}'.format(authorization_response.to_dict())) + + if is_error_message(authorization_response): + return authorization_response + + _context = client.client_get("service_context") + try: + _iss = _context.state.get_iss(authorization_response['state']) + except KeyError: + raise KeyError('Unknown state value') + + if _iss != issuer: + logger.error('Issuer problem: {} != {}'.format(_iss, issuer)) + # got it from the wrong bloke + raise ValueError('Impersonator {}'.format(issuer)) + + _srv.update_service_context(authorization_response, key=authorization_response['state']) + _context.state.store_item(authorization_response, "auth_response", + authorization_response['state']) + return authorization_response + + def get_access_and_id_token(self, authorization_response=None, state='', + client=None): + """ + There are a number of services where access tokens and ID tokens can + occur in the response. This method goes through the possible places + based on the response_type the client uses. + + :param authorization_response: The Authorization response + :param state: The state key (the state parameter in the + authorization request) + :return: A dictionary with 2 keys: **access_token** with the access + token as value and **id_token** with a verified ID Token if one + was returned otherwise None. + """ + + if client is None: + client = self.get_client_from_session_key(state) + + _context = client.client_get("service_context") + + if authorization_response is None: + if state: + authorization_response = _context.state.get_item( + AuthorizationResponse, 'auth_response', state) + else: + raise ValueError( + 'One of authorization_response or state must be provided') + + if not state: + state = authorization_response['state'] + + authreq = _context.state.get_item( + AuthorizationRequest, 'auth_request', state) + _resp_type = set(authreq['response_type']) + + access_token = None + id_token = None + if _resp_type in [{'id_token'}, {'id_token', 'token'}, + {'code', 'id_token', 'token'}]: + id_token = authorization_response['__verified_id_token'] + + if _resp_type in [{'token'}, {'id_token', 'token'}, {'code', 'token'}, + {'code', 'id_token', 'token'}]: + access_token = authorization_response["access_token"] + elif _resp_type in [{'code'}, {'code', 'id_token'}]: + + # get the access token + token_resp = self.get_access_token(state, client=client) + if is_error_message(token_resp): + return False, "Invalid response %s." % token_resp["error"] + + access_token = token_resp["access_token"] + + try: + id_token = token_resp['__verified_id_token'] + except KeyError: + pass + + return {'access_token': access_token, 'id_token': id_token} + + # noinspection PyUnusedLocal + def finalize(self, issuer, response): + """ + The third of the high level methods that a user of this Class should + know about. + Once the consumer has redirected the user back to the + callback URL there might be a number of services that the client should + use. Which one those are are defined by the client configuration. + + :param issuer: Who sent the response + :param response: The Authorization response as a dictionary + :returns: A dictionary with two claims: + **state** The key under which the session information is + stored in the data store and + **error** and encountered error or + **userinfo** The collected user information + """ + + client = self.issuer2rp[issuer] + + authorization_response = self.finalize_auth(client, issuer, response) + if is_error_message(authorization_response): + return { + 'state': authorization_response['state'], + 'error': authorization_response['error'] + } + + _state = authorization_response['state'] + token = self.get_access_and_id_token(authorization_response, + state=_state, client=client) + + if client.client_get("service", "userinfo") and token['access_token']: + inforesp = self.get_user_info( + state=authorization_response['state'], client=client, + access_token=token['access_token']) + + if isinstance(inforesp, ResponseMessage) and 'error' in inforesp: + return { + 'error': "Invalid response %s." % inforesp["error"], + 'state': _state + } + + elif token['id_token']: # look for it in the ID Token + inforesp = self.userinfo_in_id_token(token['id_token']) + else: + inforesp = {} + + logger.debug("UserInfo: %s", inforesp) + + _context = client.client_get("service_context") + try: + _sid_support = _context.get('provider_info')[ + 'backchannel_logout_session_supported'] + except KeyError: + try: + _sid_support = _context.get('provider_info')[ + 'frontchannel_logout_session_supported'] + except: + _sid_support = False + + if _sid_support: + try: + sid = token['id_token']['sid'] + except KeyError: + pass + else: + _context.state.store_sid2state(sid, _state) + + _context.state.store_sub2state(token['id_token']['sub'], _state) + + return { + 'userinfo': inforesp, + 'state': authorization_response['state'], + 'token': token['access_token'], + 'id_token': token['id_token'] + } + + def has_active_authentication(self, state): + """ + Find out if the user has an active authentication + + :param state: + :return: True/False + """ + + client = self.get_client_from_session_key(state) + + # Look for Id Token in all the places where it can be + _arg = client.client_get("service_context").state.multiple_extend_request_args( + {}, state, ['__verified_id_token'], + ['auth_response', 'token_response', 'refresh_token_response']) + + if _arg: + _now = time_sans_frac() + exp = _arg['__verified_id_token']['exp'] + return _now < exp + else: + return False + + def get_valid_access_token(self, state): + """ + Find a valid access token. + + :param state: + :return: An access token if a valid one exists and when it + expires. Otherwise raise exception. + """ + + exp = 0 + token = None + indefinite = [] + now = time_sans_frac() + + client = self.get_client_from_session_key(state) + _context = client.client_get("service_context") + for cls, typ in [(AccessTokenResponse, 'refresh_token_response'), + (AccessTokenResponse, 'token_response'), + (AuthorizationResponse, 'auth_response')]: + try: + response = _context.state.get_item(cls, typ, state) + except KeyError: + pass + else: + if 'access_token' in response: + access_token = response["access_token"] + try: + _exp = response['__expires_at'] + except KeyError: # No expiry date, lives for ever + indefinite.append((access_token, 0)) + else: + if _exp > now and _exp > exp: # expires sometime in the future + exp = _exp + token = (access_token, _exp) + + if indefinite: + return indefinite[0] + else: + if token: + return token + else: + raise OidcServiceError('No valid access token') + + def logout(self, state, client=None, post_logout_redirect_uri=''): + """ + Does a RP initiated logout from an OP. After logout the user will be + redirect by the OP to a URL of choice (post_logout_redirect_uri). + + :param state: Key to an active session + :param client: Which client to use + :param post_logout_redirect_uri: If a special post_logout_redirect_uri + should be used + :return: A US + """ + if client is None: + client = self.get_client_from_session_key(state) + + try: + srv = client.client_get('service', 'end_session') + except KeyError: + raise OidcServiceError("Does not know how to logout") + + if post_logout_redirect_uri: + request_args = { + "post_logout_redirect_uri": post_logout_redirect_uri + } + else: + request_args = {} + + resp = srv.get_request_parameters(state=state, + request_args=request_args) + + return resp + + def clear_session(self, state): + client = self.get_client_from_session_key(state) + client.client_get("service_context").state.remove_state(state) + + +def backchannel_logout(client, request='', request_args=None): + """ + + :param request: URL encoded logout request + :return: + """ + if request: + req = BackChannelLogoutRequest().from_urlencoded(as_unicode(request)) + else: + req = BackChannelLogoutRequest(**request_args) + + _context = client.client_get("service_context") + kwargs = { + 'aud': _context.get('client_id'), + 'iss': _context.get('issuer'), + 'keyjar': _context.keyjar, + 'allowed_sign_alg': _context.get('registration_response').get( + "id_token_signed_response_alg", "RS256") + } + + try: + req.verify(**kwargs) + except (MessageException, ValueError, NotForMe) as err: + raise MessageException('Bogus logout request: {}'.format(err)) + + # Find the subject through 'sid' or 'sub' + sub = req[verified_claim_name('logout_token')].get('sub') + sid = None + if not sub: + sid = req[verified_claim_name('logout_token')].get('sid') + + if not sub and not sid: + raise MessageException('Neither "sid" nor "sub"') + elif sub: + _state = _context.state.get_state_by_sub(sub) + elif sid: + _state = _context.state.get_state_by_sid(sid) + return _state diff --git a/src/oidcrp/service.py b/src/oidcrp/service.py new file mode 100644 index 0000000..0e350e2 --- /dev/null +++ b/src/oidcrp/service.py @@ -0,0 +1,642 @@ +""" The basic Service class upon which all the specific services are built. """ +import logging +from typing import Callable +from typing import Optional +from typing import Union +from urllib.parse import urlparse + +from cryptojwt.jwt import JWT +from cryptojwt.utils import qualified_name +from oidcmsg.impexp import ImpExp +from oidcmsg.item import DLDict +from oidcmsg.message import Message +from oidcmsg.oauth2 import ResponseMessage +from oidcmsg.oauth2 import is_error_message + +from oidcrp import util +from oidcrp.client_auth import factory as ca_factory +from oidcrp.configure import Configuration +from oidcrp.exception import ResponseError +from oidcrp.util import JOSE_ENCODED +from oidcrp.util import JSON_ENCODED +from oidcrp.util import URL_ENCODED +from oidcrp.util import get_http_body +from oidcrp.util import get_http_url + +__author__ = 'Roland Hedberg' + +LOGGER = logging.getLogger(__name__) + +SUCCESSFUL = [200, 201, 202, 203, 204, 205, 206] + +SPECIAL_ARGS = ['authn_endpoint', 'algs'] + +REQUEST_INFO = 'Doing request with: URL:{}, method:{}, data:{}, https_args:{}' + + +class Service(ImpExp): + """The basic Service class.""" + msg_type = Message + response_cls = Message + error_msg = ResponseMessage + endpoint_name = '' + endpoint = '' + service_name = '' + synchronous = True + default_authn_method = '' + http_method = 'GET' + request_body_type = 'urlencoded' + response_body_type = 'json' + + parameter = { + 'default_authn_method': None, + 'endpoint': "", + 'error_msg': object, + 'http_method': None, + 'msg_type': object, + 'request_body_type': None, + 'response_body_type': None, + 'response_cls': object + } + + init_args = ["client_get"] + + def __init__(self, + client_get: Callable, + conf: Optional[Union[dict, Configuration]] = None, + client_authn_factory: Optional[Callable] = None, + **kwargs): + ImpExp.__init__(self) + if client_authn_factory is None: + self.client_authn_factory = ca_factory + else: + self.client_authn_factory = client_authn_factory + + self.client_get = client_get + self.default_request_args = {} + if conf: + self.conf = conf + for param in ['msg_type', 'response_cls', 'error_msg', + 'default_authn_method', 'http_method', + 'request_body_type', 'response_body_type']: + if param in conf: + setattr(self, param, conf[param]) + else: + self.conf = {} + + # pull in all the modifiers + self.pre_construct = [] + self.post_construct = [] + self.construct_extra_headers = [] + + def gather_request_args(self, **kwargs): + """ + Go through the attributes that the message class can contain and + add values if they are missing but exists in the client info or + when there are default values. + + :param kwargs: Initial set of attributes. + :return: Possibly augmented set of attributes + """ + ar_args = kwargs.copy() + + _context = self.client_get("service_context") + # Go through the list of claims defined for the message class + # there are a couple of places where informtation can be found + # access them in the order of priority + # 1. A keyword argument + # 2. configured set of default attribute values + # 3. default attribute values defined in the OIDC standard document + for prop in self.msg_type.c_param: + if prop in ar_args: + continue + + val = _context.get(prop) + if not val: + if "request_args" in self.conf: + val = self.conf['request_args'].get(prop) + if not val: + val = _context.register_args.get(prop) + if not val: + val = self.default_request_args.get(prop) + if not val: + val = _context.behaviour.get(prop) + + if val: + ar_args[prop] = val + + return ar_args + + def method_args(self, context, **kwargs): + """ + Collect the set of arguments that should be used by a set of methods + + :param context: Which service we're working for + :param kwargs: A set of keyword arguments that are added at run-time. + :return: A set of keyword arguments + """ + try: + _args = self.conf[context].copy() + except KeyError: + _args = kwargs + else: + _args.update(kwargs) + return _args + + def do_pre_construct(self, request_args, **kwargs): + """ + Will run the pre_construct methods one by one in the order given. + + :param request_args: Request arguments + :param kwargs: Extra key word arguments + :return: A tuple of request_args and post_args. post_args are to be + used by the post_construct methods. + """ + + _args = self.method_args('pre_construct', **kwargs) + post_args = {} + for meth in self.pre_construct: + request_args, _post_args = meth(request_args, service=self, post_args=post_args, + **_args) + # Not necessarily independent + # post_args.update(_post_args) + + return request_args, post_args + + def do_post_construct(self, request_args, **kwargs): + """ + Will run the post_construct methods one at the time in order. + + :param request_args: Request arguments + :param kwargs: Arguments used by the post_construct method + :return: Possible modified set of request arguments. + """ + _args = self.method_args('post_construct', **kwargs) + + for meth in self.post_construct: + request_args = meth(request_args, service=self, **_args) + + return request_args + + def update_service_context(self, resp, key='', **kwargs): + """ + A method run after the response has been parsed and verified. + + :param resp: The response as a :py:class:`oidcmsg.Message` instance + :param key: The key under which the response should be stored + :param kwargs: Extra key word arguments + """ + pass + + def construct(self, request_args=None, **kwargs): + """ + Instantiate the request as a message class instance with + attribute values gathered in a pre_construct method or in the + gather_request_args method. + + :param request_args: + :param kwargs: extra keyword arguments + :return: message class instance + """ + if request_args is None: + request_args = {} + + # run the pre_construct methods. Will return a possibly new + # set of request arguments but also a set of arguments to + # be used by the post_construct methods. + request_args, post_args = self.do_pre_construct(request_args, + **kwargs) + + # If 'state' appears among the keyword argument and is not + # expected to appear in the request, remove it. + if 'state' in self.msg_type.c_param and 'state' in kwargs: + # Don't overwrite something put there by the constructor + if 'state' not in request_args: + request_args['state'] = kwargs['state'] + + # logger.debug("request_args: %s" % sanitize(request_args)) + _args = self.gather_request_args(**request_args) + + # logger.debug("kwargs: %s" % sanitize(kwargs)) + # initiate the request as in an instance of the self.msg_type + # message type + request = self.msg_type(**_args) + + return self.do_post_construct(request, **post_args) + + def init_authentication_method(self, request, authn_method, + http_args=None, **kwargs): + """ + Will run the proper client authentication method. + Each such method will place the necessary information in the necessary + place. A method may modify the request. + + :param request: The request, a Message class instance + :param authn_method: Client authentication method + :param http_args: HTTP header arguments + :param kwargs: Extra keyword arguments + :return: Extended set of HTTP header arguments + """ + if http_args is None: + http_args = {} + + if authn_method: + LOGGER.debug('Client authn method: %s', authn_method) + return self.client_authn_factory(authn_method).construct( + request, self, http_args=http_args, **kwargs) + + return http_args + + def construct_request(self, request_args=None, **kwargs): + """ + The method where everything is setup for sending the request. + The request information is gathered and the where and how of sending the + request is decided. + + :param request_args: Initial request arguments as a dictionary + :param kwargs: Extra keyword arguments + :return: A dictionary with the keys 'url' and possibly 'body', 'kwargs', + 'request' and 'ht_args'. + """ + if request_args is None: + request_args = {} + + return self.construct(request_args, **kwargs) + + def get_endpoint(self): + """ + Find the service endpoint + + :return: The service endpoint (a URL) + """ + if self.endpoint: + return self.endpoint + + return self.client_get("service_context").provider_info[self.endpoint_name] + + def get_authn_header(self, + request: Union[dict, Message], + authn_method: Optional[str] = '', + **kwargs) -> dict: + """ + Construct an authorization specification to be sent in the + HTTP header. + + :param request: The service request + :param authn_method: Which authentication/authorization method to use + :param kwargs: Extra keyword arguments + :return: A set of keyword arguments to be sent in the HTTP header. + """ + headers = {} + # If I should deal with client authentication + if authn_method: + h_arg = self.init_authentication_method(request, authn_method, + **kwargs) + try: + headers = h_arg['headers'] + except KeyError: + pass + + return headers + + def get_authn_method(self) -> str: + """ + Find the method that the client should use to authenticate against a + service. + + :return: The authn/authz method + """ + return self.default_authn_method + + def get_headers(self, + request: Union[dict, Message], + http_method: str, + authn_method: Optional[str] = '', + **kwargs) -> dict: + """ + + :param request: + :param authn_method: + :param kwargs: + :return: + """ + if not authn_method: + authn_method = self.get_authn_method() + + _headers = self.get_authn_header(request, + authn_method=authn_method, + authn_endpoint=self.endpoint_name, + **kwargs) + + for meth in self.construct_extra_headers: + _headers = meth(self.client_get("service_context"), + headers=_headers, + request=request, + authn_method=authn_method, + service_endpoint=self.endpoint_name, + http_method=http_method, + **kwargs) + + return _headers + + def get_request_parameters(self, request_args=None, method="", + request_body_type="", authn_method='', **kwargs): + """ + Builds the request message and constructs the HTTP headers. + + This is the starting point for a pipeline that will: + + - construct the request message + - add/remove information to/from the request message in the way a + specific client authentication method requires. + - gather a set of HTTP headers like Content-type and Authorization. + - serialize the request message into the necessary format (JSON, + urlencoded, signed JWT) + + :param request_body_type: Which serialization to use for the HTTP body + :param method: HTTP method used. + :param authn_method: Client authentication method + :param request_args: Message arguments + :param kwargs: extra keyword arguments + :return: Dictionary with the necessary information for the HTTP + request + """ + if not method: + method = self.http_method + if not authn_method: + authn_method = self.get_authn_method() + if not request_body_type: + request_body_type = self.request_body_type + + request = self.construct_request(request_args=request_args, **kwargs) + + LOGGER.debug("Request: %s", request) + _info = {'method': method, "request": request} + + _args = kwargs.copy() + _context = self.client_get("service_context") + if _context.issuer: + _args['iss'] = _context.issuer + + # Client authentication by usage of the Authorization HTTP header + # or by modifying the request object + _headers = self.get_headers(request, http_method=method, + authn_method=authn_method, **_args) + + # Find out where to send this request + try: + endpoint_url = kwargs['endpoint'] + except KeyError: + endpoint_url = self.get_endpoint() + + _info['url'] = get_http_url(endpoint_url, request, method=method) + + # If there is to be a body part + if method == 'POST': + # How should it be serialized + if request_body_type == 'urlencoded': + content_type = URL_ENCODED + elif request_body_type in ['jws', 'jwe', 'jose']: + content_type = JOSE_ENCODED + else: # request_body_type == 'json' + content_type = JSON_ENCODED + + _info['body'] = get_http_body(request, content_type) + _headers.update({'Content-Type': content_type}) + + if _headers: + _info['headers'] = _headers + + return _info + + # ------------------ response handling ----------------------- + + @staticmethod + def get_urlinfo(info): + """ + Pick out the fragment or query part from a URL. + + :param info: A URL possibly containing a query or a fragment part + :return: the query/fragment part + """ + # If info is a whole URL pick out the query or fragment part + if '?' in info or '#' in info: + parts = urlparse(info) + # either query of fragment + if parts.query: + info = parts.query + else: + info = parts.fragment + return info + + def post_parse_response(self, response, **kwargs): + """ + This method does post processing of the service response. + Each service have their own version of this method. + + :param response: The service response + :param kwargs: A set of keyword arguments + :return: The possibly modified response + """ + return response + + def gather_verify_arguments(self): + """ + Need to add some information before running verify() + + :return: dictionary with arguments to the verify call + """ + + _context = self.client_get("service_context") + kwargs = { + 'iss': _context.issuer, + 'keyjar': _context.keyjar, + 'verify': True + } + + _client_id = _context.client_id + if _client_id: + kwargs['client_id'] = _client_id + + if self.service_name == "provider_info": + if _context.issuer.startswith("http://"): + kwargs["allow_http"] = True + + return kwargs + + def _do_jwt(self, info): + _context = self.client_get("service_context") + args = {'allowed_sign_algs': _context.get_sign_alg(self.service_name)} + enc_algs = _context.get_enc_alg_enc(self.service_name) + args['allowed_enc_algs'] = enc_algs['alg'] + args['allowed_enc_encs'] = enc_algs['enc'] + _jwt = JWT(key_jar=_context.keyjar, **args) + _jwt.iss = _context.client_id + return _jwt.unpack(info) + + def _do_response(self, info, sformat, **kwargs): + _context = self.client_get("service_context") + + try: + resp = self.response_cls().deserialize( + info, sformat, iss=_context.issuer, **kwargs) + except Exception as err: + resp = None + if sformat == 'json': + # Could be JWS or JWE but wrongly tagged + # Adding issuer is just a fail-safe. If one things was wrong + # then two can be. + try: + resp = self.response_cls().deserialize( + info, 'jwt', iss=_context.issuer, **kwargs) + except Exception: + pass + + if resp is None: + LOGGER.error('Error while deserializing: %s', err) + raise + return resp + + def parse_response(self, info, sformat="", state="", **kwargs): + """ + This the start of a pipeline that will: + + 1 Deserializes a response into it's response message class. + Or :py:class:`oidcmsg.oauth2.ErrorResponse` if it's an error + message + 2 verifies the correctness of the response by running the + verify method belonging to the message class used. + 3 runs the do_post_parse_response method iff the response was not + an error response. + + :param info: The response, can be either in a JSON or an urlencoded + format + :param sformat: Which serialization that was used + :param state: The state + :param kwargs: Extra key word arguments + :return: The parsed and to some extend verified response + """ + + if not sformat: + sformat = self.response_body_type + + LOGGER.debug('response format: %s', sformat) + + if sformat in ['jose', 'jws', 'jwe']: + resp = self.post_parse_response(info, state=state) + + if not resp: + LOGGER.error('Missing or faulty response') + raise ResponseError("Missing or faulty response") + + return resp + + # If format is urlencoded 'info' may be a URL + # in which case I have to get at the query/fragment part + if sformat == "urlencoded": + info = self.get_urlinfo(info) + + if sformat == 'jwt': + info = self._do_jwt(info) + sformat = "dict" + + LOGGER.debug('response_cls: %s', self.response_cls.__name__) + + resp = self._do_response(info, sformat, **kwargs) + + LOGGER.debug('Initial response parsing => "%s"', resp.to_dict()) + + # is this an error message + if is_error_message(resp): + LOGGER.debug('Error response: %s', resp) + else: + vargs = self.gather_verify_arguments() + LOGGER.debug("Verify response with %s", vargs) + try: + # verify the message. If something is wrong an exception is + # thrown + resp.verify(**vargs) + except Exception as err: + LOGGER.error( + 'Got exception while verifying response: %s', err) + raise + + resp = self.post_parse_response(resp, state=state) + + if not resp: + LOGGER.error('Missing or faulty response') + raise ResponseError("Missing or faulty response") + + return resp + + def get_conf_attr(self, attr, default=None): + """ + Get the value of a attribute in the configuration + + :param attr: The attribute + :param default: If the attribute doesn't appear in the configuration + return this value + :return: The value of attribute in the configuration or the default + value + """ + if attr in self.conf: + return self.conf[attr] + + return default + + +def gather_constructors(service_methods, construct): + """Loads the construct methods that are defined.""" + try: + _methods = service_methods + except KeyError: + pass + else: + for meth in _methods: + try: + func = meth['function'] + except KeyError: + pass + else: + construct.append(util.importer(func)) + + +def init_services(service_definitions, client_get, client_authn_factory=None): + """ + Initiates a set of services + + :param service_definitions: A dictionary containing service definitions + :param client_get: A function that returns different things from the base entity. + :param client_authn_factory: A list of methods the services can use to + authenticate the client to a service. + :return: A dictionary, with service name as key and the service instance as + value. + """ + service = DLDict() + for service_name, service_configuration in service_definitions.items(): + try: + kwargs = service_configuration['kwargs'] + except KeyError: + kwargs = {} + + kwargs.update({ + 'client_get': client_get, + 'client_authn_factory': client_authn_factory + }) + + if isinstance(service_configuration['class'], str): + _value_cls = service_configuration['class'] + _cls = util.importer(service_configuration['class']) + _srv = _cls(**kwargs) + else: + _value_cls = qualified_name(service_configuration['class']) + _srv = service_configuration['class'](**kwargs) + + if 'post_functions' in service_configuration: + gather_constructors(service_configuration['post_functions'], _srv.post_construct) + if 'pre_functions' in service_configuration: + gather_constructors(service_configuration['pre_functions'], _srv.pre_construct) + + service[_srv.service_name] = _srv + + return service diff --git a/src/oidcrp/service_context.py b/src/oidcrp/service_context.py new file mode 100644 index 0000000..0247162 --- /dev/null +++ b/src/oidcrp/service_context.py @@ -0,0 +1,299 @@ +""" +Implements a service context. A Service context is used to keep information that are +common between all the services that are used by OAuth2 client or OpenID Connect Relying Party. +""" +import copy +import hashlib +import os + +from cryptojwt.jwk.rsa import RSAKey, import_private_rsa_key_from_file +from cryptojwt.key_bundle import KeyBundle +from cryptojwt.utils import as_bytes +from oidcmsg.context import OidcContext +from oidcmsg.oidc import RegistrationRequest + +from oidcrp.state_interface import StateInterface + +CLI_REG_MAP = { + "userinfo": { + "sign": "userinfo_signed_response_alg", + "alg": "userinfo_encrypted_response_alg", + "enc": "userinfo_encrypted_response_enc" + }, + "id_token": { + "sign": "id_token_signed_response_alg", + "alg": "id_token_encrypted_response_alg", + "enc": "id_token_encrypted_response_enc" + }, + "request_object": { + "sign": "request_object_signing_alg", + "alg": "request_object_encryption_alg", + "enc": "request_object_encryption_enc" + } +} + +PROVIDER_INFO_MAP = { + "id_token": { + "sign": "id_token_signing_alg_values_supported", + "alg": "id_token_encryption_alg_values_supported", + "enc": "id_token_encryption_enc_values_supported" + }, + "userinfo": { + "sign": "userinfo_signing_alg_values_supported", + "alg": "userinfo_encryption_alg_values_supported", + "enc": "userinfo_encryption_enc_values_supported" + }, + "request_object": { + "sign": "request_object_signing_alg_values_supported", + "alg": "request_object_encryption_alg_values_supported", + "enc": "request_object_encryption_enc_values_supported" + }, + "token_enpoint_auth": { + "sign": "token_endpoint_auth_signing_alg_values_supported" + } +} + +DEFAULT_VALUE = { + 'client_secret': '', + 'client_id': '', + 'redirect_uris': [], + 'provider_info': {}, + 'behaviour': {}, + 'callback': {}, + 'issuer': '' +} + + +# def add_issuer(conf, issuer): +# res = {} +# for key, val in conf.items(): +# if key == 'abstract_storage_cls': +# res[key] = val +# else: +# _val = copy.deepcopy(val) +# _val['issuer'] = issuer +# res[key] = _val +# return res + + +class ServiceContext(OidcContext): + """ + This class keeps information that a client needs to be able to talk + to a server. Some of this information comes from configuration and some + from dynamic provider info discovery or client registration. + But information is also picked up during the conversation with a server. + """ + parameter = OidcContext.parameter.copy() + parameter.update({ + "add_on": None, + "allow": None, + "args": None, + "base_url": None, + 'behaviour': None, + 'callback': None, + 'client_id': None, + "client_preferences": None, + 'client_secret': None, + "client_secret_expires_at": 0, + 'clock_skew': None, + "config": None, + "httpc_params": None, + 'issuer': None, + "kid": None, + "post_logout_redirect_uris": [], + 'provider_info': None, + 'redirect_uris': None, + "requests_dir": None, + "register_args": None, + 'registration_response': None, + 'state': StateInterface, + 'verify_args': None + }) + + def __init__(self, base_url="", keyjar=None, config=None, state=None, **kwargs): + if config is None: + config = {} + self.config = config + + OidcContext.__init__(self, config, keyjar, entity_id=config.get('client_id', '')) + self.state = state or StateInterface() + + self.kid = {"sig": {}, "enc": {}} + + self.base_url = base_url + # Below so my IDE won't complain + self.allow = {} + self.client_preferences = {} + self.args = {} + self.add_on = {} + self.httpc_params = {} + self.issuer = "" + self.client_id = "" + self.client_secret = "" + self.client_secret_expires_at = 0 + self.behaviour = {} + self.provider_info = {} + self.post_logout_redirect_uris = [] + self.redirect_uris = [] + self.register_args = {} + self.registration_response = {} + self.requests_dir = '' + + _def_value = copy.deepcopy(DEFAULT_VALUE) + # Dynamic information + for param in ['client_secret', 'client_id', 'redirect_uris', 'provider_info', + 'behaviour', 'callback', 'issuer']: + _val = config.get(param, _def_value[param]) + self.set(param, _val) + if param == 'client_secret': + self.keyjar.add_symmetric('', _val) + + if not self.issuer: + self.issuer = self.provider_info.get("issuer", "") + + try: + self.clock_skew = config['clock_skew'] + except KeyError: + self.clock_skew = 15 + + for key, val in kwargs.items(): + setattr(self, key, val) + + for attr in ['base_url', 'requests_dir', 'allow', 'client_preferences', 'verify_args']: + try: + setattr(self, attr, config[attr]) + except KeyError: + pass + + for attr in RegistrationRequest.c_param: + try: + self.register_args[attr] = config[attr] + except KeyError: + pass + + if self.requests_dir: + # make sure the path exists. If not, then make it. + if not os.path.isdir(self.requests_dir): + os.makedirs(self.requests_dir) + + try: + self.import_keys(config['keys']) + except KeyError: + pass + + def __setitem__(self, key, value): + setattr(self, key, value) + + def filename_from_webname(self, webname): + """ + A 1<->1 map is maintained between a URL pointing to a file and + the name of the file in the file system. + + As an example if the base_url is 'https://example.com' and a jwks_uri + is 'https://example.com/jwks_uri.json' then the filename of the + corresponding file on the local filesystem would be 'jwks_uri'. + Relative to the directory from which the RP instance is run. + + :param webname: The published URL + :return: local filename + """ + if not webname.startswith(self.base_url): + raise ValueError("Webname doesn't match base_url") + + _name = webname[len(self.base_url):] + if _name.startswith('/'): + return _name[1:] + + return _name + + def generate_request_uris(self, path): + """ + Need to generate a redirect_uri path that is unique for a OP/RP combo + This is to counter the mix-up attack. + + :param path: Leading path + :return: A list of one unique URL + """ + _hash = hashlib.sha256() + try: + _hash.update(as_bytes(self.provider_info['issuer'])) + except KeyError: + _hash.update(as_bytes(self.issuer)) + _hash.update(as_bytes(self.base_url)) + + if not path.startswith('/'): + redirs = ['{}/{}/{}'.format(self.base_url, path, _hash.hexdigest())] + else: + redirs = ['{}{}/{}'.format(self.base_url, path, _hash.hexdigest())] + + self.set('redirect_uris', redirs) + return redirs + + def import_keys(self, keyspec): + """ + The client needs it's own set of keys. It can either dynamically + create them or load them from local storage. + This method can also fetch other entities keys provided the + URL points to a JWKS. + + :param keyspec: + """ + for where, spec in keyspec.items(): + if where == 'file': + for typ, files in spec.items(): + if typ == 'rsa': + for fil in files: + _key = RSAKey( + priv_key=import_private_rsa_key_from_file(fil), + use='sig') + _bundle = KeyBundle() + _bundle.append(_key) + self.keyjar.add_kb('', _bundle) + elif where == 'url': + for iss, url in spec.items(): + _bundle = KeyBundle(source=url) + self.keyjar.add_kb(iss, _bundle) + + def get_sign_alg(self, typ): + """ + + :param typ: ['id_token', 'userinfo', 'request_object'] + :return: + """ + + try: + return self.behaviour[CLI_REG_MAP[typ]['sign']] + except KeyError: + try: + return self.provider_info[PROVIDER_INFO_MAP[typ]['sign']] + except (KeyError, TypeError): + pass + + return None + + def get_enc_alg_enc(self, typ): + """ + + :param typ: + :return: + """ + + res = {} + for attr in ['enc', 'alg']: + try: + _alg = self.behaviour[CLI_REG_MAP[typ][attr]] + except KeyError: + try: + _alg = self.provider_info[PROVIDER_INFO_MAP[typ][attr]] + except KeyError: + _alg = None + + res[attr] = _alg + + return res + + def get(self, key, default=None): + return getattr(self, key, default) + + def set(self, key, value): + setattr(self, key, value) diff --git a/src/oidcrp/service_factory.py b/src/oidcrp/service_factory.py new file mode 100644 index 0000000..796823f --- /dev/null +++ b/src/oidcrp/service_factory.py @@ -0,0 +1,32 @@ +from glob import glob +import inspect +from os.path import basename +from os.path import dirname +from os.path import join +import sys + +from oidcrp.service import Service + + +def service_factory(req_name, module_dirs, **kwargs): + pwd = dirname(__file__) + if pwd not in sys.path: + sys.path.insert(0, pwd) + + for dir in module_dirs: + for x in glob(join(pwd, dir, '*.py')): + _mod = basename(x)[:-3] + if not _mod.startswith('__'): + if '/' in dir: + dir = dir.replace('/', '.') + _dir_mod = '{}.{}'.format(dir, basename(x)[:-3]) + if _dir_mod not in sys.modules: + __import__(_dir_mod, globals(), locals()) + + for name, obj in inspect.getmembers(sys.modules[_dir_mod]): + if inspect.isclass(obj) and issubclass(obj, Service): + try: + if obj.__name__ == req_name: + return obj(**kwargs) + except AttributeError: + pass diff --git a/src/oidcrp/state_interface.py b/src/oidcrp/state_interface.py new file mode 100644 index 0000000..a377da1 --- /dev/null +++ b/src/oidcrp/state_interface.py @@ -0,0 +1,388 @@ +"""A database interface for storing state information.""" +import json + +from oidcmsg.impexp import ImpExp +from oidcmsg.message import Message +from oidcmsg.message import SINGLE_OPTIONAL_JSON +from oidcmsg.message import SINGLE_REQUIRED_STRING +from oidcmsg.oidc import verified_claim_name + +from oidcrp.util import rndstr + + +class State(Message): + """A structure to keep information about previous events.""" + c_param = { + 'iss': SINGLE_REQUIRED_STRING, + 'auth_request': SINGLE_OPTIONAL_JSON, + 'auth_response': SINGLE_OPTIONAL_JSON, + 'token_response': SINGLE_OPTIONAL_JSON, + 'refresh_token_request': SINGLE_OPTIONAL_JSON, + 'refresh_token_response': SINGLE_OPTIONAL_JSON, + 'user_info': SINGLE_OPTIONAL_JSON + } + + +KEY_PATTERN = { + 'nonce': '__{}__', + 'logout state': '::{}::', + 'session id': '..{}..', + 'subject id': '=={}==' +} + + +class InMemoryStateDataBase: + """The simplest possible implementation of the state database.""" + + def __init__(self): + self._db = {} + + def set(self, key, value): + """Assign a value to a key.""" + self._db[key] = value + + def get(self, key): + """Return the value bound to a key.""" + try: + return self._db[key] + except KeyError: + return None + + def delete(self, key): + """Delete a key and its value.""" + try: + del self._db[key] + except KeyError: + pass + + def __setitem__(self, key, value): + """Assign a value to a key.""" + self._db[key] = value + + def __getitem__(self, key): + """Return the value bound to a key.""" + try: + return self._db[key] + except KeyError: + return None + + def __delitem__(self, key): + """Delete a key and its value.""" + try: + del self._db[key] + except KeyError: + pass + + +class StateInterface(ImpExp): + """A more powerful interface to a state DB.""" + + parameter = { + "_db": None + } + + def __init__(self): + ImpExp.__init__(self) + self._db = {} + + def get_state(self, key): + """ + Get the state connected to a given key. + + :param key: Key into the state database + :return: A :py:class:´oidcrp.state_interface.State` instance + """ + _data = self._db.get(key) + if not _data: + raise KeyError(key) + + return State().from_json(_data) + + def store_item(self, item, item_type, key): + """ + Store a service response. + + :param item: The item as a :py:class:`oidcmsg.message.Message` + subclass instance or a JSON document. + :param item_type: The type of request or response + :param key: The key under which the information should be stored in + the state database + """ + try: + _state = self.get_state(key) + except KeyError: + _state = State() + + try: + _state[item_type] = item.to_json() + except AttributeError: + _state[item_type] = item + + self._db[key] = _state.to_json() + + def get_iss(self, key): + """ + Get the Issuer ID + + :param key: Key to the information in the state database + :return: The issuer ID + """ + _state = self.get_state(key) + if not _state: + raise KeyError(key) + return _state['iss'] + + def get_item(self, item_cls, item_type, key): + """ + Get a piece of information (a request or a response) from the state + database. + + :param item_cls: The :py:class:`oidcmsg.message.Message` subclass + that described the item. + :param item_type: Which request/response that is wanted + :param key: The key to the information in the state database + :return: A :py:class:`oidcmsg.message.Message` instance + """ + _state = self.get_state(key) + try: + return item_cls(**_state[item_type]) + except TypeError: + return item_cls().from_json(_state[item_type]) + + def extend_request_args(self, args, item_cls, item_type, key, + parameters, orig=False): + """ + Add a set of parameters and their value to a set of request arguments. + + :param args: A dictionary + :param item_cls: The :py:class:`oidcmsg.message.Message` subclass + that describes the item + :param item_type: The type of item, this is one of the parameter + names in the :py:class:`oidcrp.state_interface.State` class. + :param key: The key to the information in the database + :param parameters: A list of parameters who's values this method + will return. + :param orig: Where the value of a claim is a signed JWT return + that. + :return: A dictionary with keys from the list of parameters and + values being the values of those parameters in the item. + If the parameter does not a appear in the item it will not appear + in the returned dictionary. + """ + try: + item = self.get_item(item_cls, item_type, key) + except KeyError: + pass + else: + for parameter in parameters: + if orig: + try: + args[parameter] = item[parameter] + except KeyError: + pass + else: + try: + args[parameter] = item[verified_claim_name(parameter)] + except KeyError: + try: + args[parameter] = item[parameter] + except KeyError: + pass + + return args + + def multiple_extend_request_args(self, args, key, parameters, item_types, + orig=False): + """ + Go through a set of items (by their type) and add the attribute-value + that match the list of parameters to the arguments + If the same parameter occurs in 2 different items then the value in + the later one will be the one used. + + :param args: Initial set of arguments + :param key: Key to the State information in the state database + :param parameters: A list of parameters that we're looking for + :param item_types: A list of item_type specifying which items we + are interested in. + :param orig: Where the value of a claim is a signed JWT return + that. + :return: A possibly augmented set of arguments. + """ + _state = self.get_state(key) + + for typ in item_types: + try: + _item = Message(**_state[typ]) + except KeyError: + continue + + for parameter in parameters: + if orig: + try: + args[parameter] = _item[parameter] + except KeyError: + pass + else: + try: + args[parameter] = _item[verified_claim_name(parameter)] + except KeyError: + try: + args[parameter] = _item[parameter] + except KeyError: + pass + + return args + + def store_x2state(self, value, state, xtyp): + """ + Store the connection between some value (x) and a state value. + This allows us later in the game to find the state if we have x. + + :param value: The value + :param state: The state value + :param xtyp: The type of value x is (e.g. nonce, ...) + """ + self._db[KEY_PATTERN[xtyp].format(value)] = state + try: + _val = self._db.get("ref{}ref".format(state)) + except KeyError: + _val = None + + if _val is None: + refs = {xtyp: value} + else: + refs = json.loads(_val) + refs[xtyp] = value + self._db["ref{}ref".format(state)] = json.dumps(refs) + + def get_state_by_x(self, value, xtyp): + """ + Find the state value by providing the x value. + Will raise an exception if the x value is absent from the state + data base. + + :param value: The value + :return: The state value + """ + _state = self._db.get(KEY_PATTERN[xtyp].format(value)) + if _state: + return _state + + raise KeyError('Unknown {}: "{}"'.format(xtyp, value)) + + def store_nonce2state(self, nonce, state): + """ + Store the connection between a nonce value and a state value. + This allows us later in the game to find the state if we have the nonce. + + :param nonce: The nonce value + :param state: The state value + """ + self.store_x2state(nonce, state, 'nonce') + + def get_state_by_nonce(self, nonce): + """ + Find the state value by providing the nonce value. + Will raise an exception if the nonce value is absent from the state + data base. + + :param nonce: The nonce value + :return: The state value + """ + return self.get_state_by_x(nonce, 'nonce') + + def store_logout_state2state(self, logout_state, state): + """ + Store the connection between a logout state value and a state value. + This allows us later in the game to find the state if we have the + logout state value. + + :param logout_state: The logout state value + :param state: The state value + """ + self.store_x2state(logout_state, state, 'logout state') + + def get_state_by_logout_state(self, logout_state): + """ + Find the state value by providing the logout state value. + Will raise an exception if the logout state value is absent from the + state data base. + + :param logout_state: The logout state value + :return: The state value + """ + return self.get_state_by_x(logout_state, 'logout state') + + def store_sid2state(self, sid, state): + """ + Store the connection between a session id (sid) value and a state value. + This allows us later in the game to find the state if we have the + sid value. + + :param sid: The session ID value + :param state: The state value + """ + self.store_x2state(sid, state, 'session id') + + def get_state_by_sid(self, sid): + """ + Find the state value by providing the logout state value. + Will raise an exception if the logout state value is absent from the + state data base. + + :param sid: The session ID value + :return: The state value + """ + return self.get_state_by_x(sid, 'session id') + + def store_sub2state(self, sub, state): + """ + Store the connection between a subject id (sub) value and a state value. + This allows us later in the game to find the state if we have the + sub value. + + :param sub: The Subject ID value + :param state: The state value + """ + self.store_x2state(sub, state, 'subject id') + + def get_state_by_sub(self, sub): + """ + Find the state value by providing the subject id value. + Will raise an exception if the subject id value is absent from the + state data base. + + :param sub: The Subject ID value + :return: The state value + """ + return self.get_state_by_x(sub, 'subject id') + + def create_state(self, iss, key=''): + """ + Create a State and assign some value to it. + + :param iss: The issuer + :param key: A key to use to access the state + """ + if not key: + key = rndstr(32) + else: + if key.startswith('__') and key.endswith('__'): + raise ValueError( + 'Invalid format. Leading and trailing "__" not allowed') + + _state = State(iss=iss) + self._db[key] = _state.to_json() + return key + + def remove_state(self, state): + """ + Remove a state. + + :param state: Key to the state + """ + del self._db[state] + refs = json.loads(self._db_db.get("ref{}ref".format(state))) + if refs: + for xtyp, _val in refs.items(): + del self._db[KEY_PATTERN[xtyp].format(_val)] diff --git a/src/oidcrp/util.py b/src/oidcrp/util.py index fa9d17f..d33bef2 100755 --- a/src/oidcrp/util.py +++ b/src/oidcrp/util.py @@ -1,25 +1,41 @@ +"""Utilities""" +from http.cookiejar import Cookie +from http.cookiejar import http2time import importlib import io import json import logging import os import ssl +import string import sys -from http.cookiejar import Cookie -from http.cookiejar import http2time +from urllib.parse import parse_qs +from urllib.parse import urlsplit +from urllib.parse import urlunsplit +from oidcmsg.exception import UnSupported +from oidcmsg.oauth2 import is_error_message import yaml -from oidcservice import sanitize -from oidcservice.exception import TimeFormatError -from oidcservice.exception import WrongContentType -from oidcservice.util import importer -logger = logging.getLogger(__name__) +# Since SystemRandom is not available on all systems +try: + import SystemRandom as rnd +except ImportError: + import random as rnd -__author__ = 'roland' +from oidcrp.defaults import BASECHR +from oidcrp.exception import ConfigurationError +from oidcrp.exception import OidcServiceError +from oidcrp.exception import TimeFormatError +from oidcrp.exception import WrongContentType URL_ENCODED = 'application/x-www-form-urlencoded' JSON_ENCODED = "application/json" +JOSE_ENCODED = "application/jose" + +logger = logging.getLogger(__name__) + +__author__ = 'roland' DEFAULT_POST_CONTENT_TYPE = URL_ENCODED @@ -50,6 +66,115 @@ } +def token_secret_key(sid): + return "token_secret_%s" % sid + + +def rndstr(size=16): + """ + Returns a string of random ascii characters or digits + + :param size: The length of the string + :return: string + """ + chars = string.ascii_letters + string.digits + return "".join(rnd.choice(chars) for i in range(size)) + + +def unreserved(size=64): + """ + Returns a string of random ascii characters, digits and unreserved + characters + + :param size: The length of the string + :return: string + """ + + return "".join([rnd.choice(BASECHR) for _ in range(size)]) + + +def sanitize(str): + return str + + +def get_http_url(url, req, method='GET'): + """ + Add a query part representing the request to a url that may already contain + a query part. Only done if the HTTP method used is 'GET' or 'DELETE'. + + :param url: The URL + :param req: The request as a :py:class:`oidcmsg.message.Message` instance + :param method: The HTTP method + :return: A possibly modified URL + """ + if method in ["GET", "DELETE"]: + if req.keys(): + _req = req.copy() + comp = urlsplit(str(url)) + if comp.query: + _req.update(parse_qs(comp.query)) + + _query = str(_req.to_urlencoded()) + return urlunsplit((comp.scheme, comp.netloc, comp.path, + _query, comp.fragment)) + + return url + + return url + + +def get_http_body(req, content_type=URL_ENCODED): + """ + Get the message into the format that should be places in the body part + of a HTTP request. + + :param req: The service request as a :py:class:`oidcmsg.message.Message` + instance + :param content_type: The format of the body part. + :return: The correctly formatet service request. + """ + if URL_ENCODED in content_type: + return req.to_urlencoded() + + if JSON_ENCODED in content_type: + return req.to_json() + + if JOSE_ENCODED in content_type: + return req # already packaged + + raise UnSupported( + "Unsupported content type: '%s'" % content_type) + + +def load_yaml_config(filename): + """Load a YAML configuration file.""" + with open(filename, "rt", encoding='utf-8') as file: + config_dict = yaml.safe_load(file) + return config_dict + + +def modsplit(name): + """Split importable""" + if ':' in name: + _part = name.split(':') + if len(_part) != 2: + raise ValueError("Syntax error: {s}") + return _part[0], _part[1] + + _part = name.split('.') + if len(_part) < 2: + raise ValueError("Syntax error: {s}") + + return '.'.join(_part[:-1]), _part[-1] + + +def importer(name): + """Import by name""" + _part = modsplit(name) + module = importlib.import_module(_part[0]) + return getattr(module, _part[1]) + + def match_to_(val, vlist): if isinstance(vlist, str): if vlist.startswith(val): @@ -267,12 +392,6 @@ def load_json(file_name): return js -def load_yaml_config(file_name): - with open(file_name) as fp: - c = yaml.safe_load(fp) - return c - - def yaml_to_py_stream(file_name): d = load_yaml_config(file_name) fstream = io.StringIO() @@ -402,3 +521,60 @@ def get_http_params(config): params['cert'] = _cert return params + + +def add_path(url, path): + if url.endswith('/'): + if path.startswith('/'): + return '{}{}'.format(url, path[1:]) + else: + return '{}{}'.format(url, path) + else: + if path.startswith('/'): + return '{}{}'.format(url, path) + else: + return '{}/{}'.format(url, path) + + +def load_registration_response(client): + """ + If the client has been statically registered that information + must be provided during the configuration. If expected to be + done dynamically. This method will do dynamic client registration. + + :param client: A :py:class:`oidcrp.oidc.Client` instance + """ + if not client.client_get("service_context").get('client_id'): + try: + response = client.do_request('registration') + except KeyError: + raise ConfigurationError('No registration info') + except Exception as err: + logger.error(err) + raise + else: + if 'error' in response: + raise OidcServiceError(response.to_json()) + + +def dynamic_provider_info_discovery(client): + """ + This is about performing dynamic Provider Info discovery + + :param client: A :py:class:`oidcrp.oidc.Client` instance + """ + try: + client.get_service('provider_info') + except KeyError: + raise ConfigurationError( + 'Can not do dynamic provider info discovery') + else: + _context = client.client_get("service_context") + try: + _context.set('issuer', _context.config['srv_discovery_url']) + except KeyError: + pass + + response = client.do_request('provider_info') + if is_error_message(response): + raise OidcServiceError(response['error']) diff --git a/tests/request123456.jwt b/tests/request123456.jwt new file mode 100644 index 0000000..f72cdd1 --- /dev/null +++ b/tests/request123456.jwt @@ -0,0 +1 @@ +eyJhbGciOiJSUzI1NiIsImtpZCI6IlNVc3dOaTFNUkZsRFQwWTJZalUxWjFSZlFsbzJTM2RFYTNGVFRrVjNMVGhGY25oRFRIRjVlbGsyVlEifQ.eyJyZXNwb25zZV90eXBlIjogImNvZGUiLCAic3RhdGUiOiAic3RhdGUiLCAicmVkaXJlY3RfdXJpIjogImh0dHBzOi8vZXhhbXBsZS5jb20vY2xpL2F1dGh6X2NiIiwgInNjb3BlIjogIm9wZW5pZCIsICJub25jZSI6ICJWNVFGVUxMWGpBS0lOS3V3Y2JlVU43OHRuWkZ2MGloYyIsICJjbGllbnRfaWQiOiAiY2xpZW50X2lkIiwgImlzcyI6ICJjbGllbnRfaWQiLCAiaWF0IjogMTYxNzg3MDI0NiwgImF1ZCI6IFsiaHR0cHM6Ly9leGFtcGxlLmNvbSJdfQ.itHjMHec6T2py2zDxaAS11tFAYdsTZ-SY3AV2j5SCTHPfk9CkXa6g6s_t0oVvcdzLVXaqQ1iU9WqOeirwj4UDxCTHulPzBOcBuC3WbCki2HT9EPI88Lov-kCuz_4juw97lIyU1BkaofaZJSjRcEW_fOY0KuP7BKIDmthTLTWxEGMoMXmXcs_QD13tz0IWjrkqjwbcKjbUxrvGHJbOnOBGwgHPPm46otDMO-hQrtvTOGz4PbdD5XZ4imDV0bJkx72ITpAfM8iODmd9sKrOWkEZEhsRG1ugXa8RgOPNLsLTLjTpzWiVeczJHOiE5H4EC-uwzXiRDiShq-q7VbTKocyYw \ No newline at end of file diff --git a/flask_rp/conf.yaml b/tests/rp_conf.yaml similarity index 77% rename from flask_rp/conf.yaml rename to tests/rp_conf.yaml index d83ac1a..d1b418c 100644 --- a/flask_rp/conf.yaml +++ b/tests/rp_conf.yaml @@ -1,33 +1,9 @@ -logging: - version: 1 - disable_existing_loggers: False - root: - handlers: - - console - - file - level: DEBUG - loggers: - idp: - level: DEBUG - handlers: - console: - class: logging.StreamHandler - stream: 'ext://sys.stdout' - formatter: default - file: - class: logging.FileHandler - filename: 'debug.log' - formatter: default - formatters: - default: - format: '%(asctime)s %(name)s %(levelname)s %(message)s' - port: &port 8090 domain: &domain 127.0.0.1 base_url: "https://{domain}:{port}" httpc_params: - # This is just for testing an local usage. In all other cases it MUST be True + # This is just for testing a local usage. In all other cases it MUST be True verify: false # Client side #client_cert: "certs/client.crt" @@ -134,16 +110,3 @@ clients: acr: essential: true - - -webserver: - port: *port - domain: *domain - # If BASE is https these has to be specified - server_cert: "certs/cert.pem" - server_key: "certs/key.pem" - # If you want the clients cert to be verified - # verify_user: optional - # The you also need - # ca_bundle: '' - debug: true diff --git a/tests/salesforce.key b/tests/salesforce.key new file mode 100644 index 0000000..d34432d --- /dev/null +++ b/tests/salesforce.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXAIBAAKBgQD6vqn19W/VB215DBADRakfPmCtFBf8/+YyhGqixWIwDiEl/L6L +w5HKZCUPVgrC0ADhJfvAbn4fte5MWBCTkqgepKL3BySMA0LMaBF12pbHlPSUbmQG +BJmTX4NNXuUel6TbPYJAU2Nh5Nan0Mb7Bmb8QpFvS0Hw7qZRW8y2eIttfwIDAQAB +AoGBAJVf9FxkRKUB8cOE3h006JWGUY2KROghgn9hxy0ErYO3RyQcN1+HuFh75GAI +gAyiYYO/XwS6TkSR2057wBRJ8ABzcL3+v5g+16Vbh0BjXVE+cv1WGdNGujyzl6ji +jlyF4cb6tXDyqWTLkMAtV20NfO/CGsfii6YEkZb2P90usthRAkEA/oG7a9EvQ7eR +gSEqppzW7KCwidPjnZTr/ROIZQU33nwkIJ0ElTjMNYKP8DerSuixR9skw2ZY8Q8I +1PTBnocHwwJBAPw3SAQYwxZwQMu1trVPMNOGIbSY4rQlMZGXrCZSu/TnozczFLA8 +qNM84g5veyJOzHKmYkIsMG1gwg5VNniG45UCQF6SlLOW0upl70K9sVyiUVcyywcc +Xqty6FJtjLSFQOKC3OXlkwtkRLXpo1UPSq6WUzIxY7LceFZzUMPZg41F/gMCQHNr +POqbBlPzZMOUUZthNP/nhu8lc8Fqr+dnmGElRVxK0JdHKfWInN2mI/DlNV064Dar +S5XqsPKs78EtX7MCT40CQFQZiry8m7ROubOU4+HDG9o1w9zcKXCkmbD9hBCGvTAj +BQNuGE0DtC6FEWTs8bXybLM5yBRq1XiKLdmi5N+3n4g= +-----END RSA PRIVATE KEY----- diff --git a/tests/test_01_base.py b/tests/test_01_base.py index 19a1299..64d058e 100644 --- a/tests/test_01_base.py +++ b/tests/test_01_base.py @@ -1,5 +1,5 @@ -from oidcrp import add_path -from oidcrp import load_registration_response +from oidcrp.util import add_path +from oidcrp.util import load_registration_response from oidcrp.oidc import RP diff --git a/tests/test_01_service_context.py b/tests/test_01_service_context.py new file mode 100644 index 0000000..88b35af --- /dev/null +++ b/tests/test_01_service_context.py @@ -0,0 +1,269 @@ +import os +from urllib.parse import urlsplit + +import pytest +import responses +from cryptojwt.key_jar import build_keyjar + +from oidcrp.service_context import ServiceContext + +BASE_URL = "https://entity.example.org" + +def test_client_info_init(): + config = { + 'client_id': 'client_id', 'issuer': 'issuer', + 'client_secret': 'client_secret_wordplay', + 'base_url': 'https://example.com', + 'requests_dir': 'requests' + } + ci = ServiceContext(BASE_URL, config=config) + + srvcnx = ServiceContext().load(ci.dump()) + + for attr in config.keys(): + try: + val = getattr(srvcnx, attr) + except AttributeError: + val = srvcnx.get(attr) + + assert val == config[attr] + + +def test_set_and_get_client_secret(): + service_context = ServiceContext() + service_context.client_secret = 'longenoughsupersecret' + assert service_context.client_secret == 'longenoughsupersecret' + + +def test_set_and_get_client_id(): + ci = ServiceContext() + ci.client_id = 'myself' + assert ci.client_id == 'myself' + + +def test_client_filename(): + config = { + 'client_id': 'client_id', 'issuer': 'issuer', + 'client_secret': 'longenoughsupersecret', 'base_url': 'https://example.com', + 'requests_dir': 'requests' + } + ci = ServiceContext(config=config) + fname = ci.filename_from_webname('https://example.com/rq12345') + assert fname == 'rq12345' + + +def verify_alg_support(service_context, alg, usage, typ): + """ + Verifies that the algorithm to be used are supported by the other side. + This will look at provider information either statically configured or + obtained through dynamic provider info discovery. + + :param alg: The algorithm specification + :param usage: In which context the 'alg' will be used. + The following contexts are supported: + - userinfo + - id_token + - request_object + - token_endpoint_auth + :param typ: Type of algorithm + - signing_alg + - encryption_alg + - encryption_enc + :return: True or False + """ + + supported = service_context.provider_info[ + "{}_{}_values_supported".format(usage, typ)] + + if alg in supported: + return True + else: + return False + + +class TestClientInfo(object): + @pytest.fixture(autouse=True) + def create_client_info_instance(self): + config = { + 'client_id': 'client_id', 'issuer': 'issuer', + 'client_secret': 'longenoughsupersecret', + 'base_url': 'https://example.com', + 'requests_dir': 'requests' + } + self.service_context = ServiceContext(config=config) + + def test_registration_userinfo_sign_enc_algs(self): + self.service_context.behaviour= { + "application_type": "web", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256" + } + + assert self.service_context.get_sign_alg('userinfo') is None + assert self.service_context.get_enc_alg_enc('userinfo') == { + 'alg': 'RSA1_5', 'enc': 'A128CBC-HS256'} + + def test_registration_request_object_sign_enc_algs(self): + self.service_context.behaviour= { + "application_type": "web", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256", + "request_object_signing_alg": "RS384" + } + + res = self.service_context.get_enc_alg_enc('userinfo') + # 'sign':'RS256' is an added default + assert res == {'alg': 'RSA1_5', 'enc': 'A128CBC-HS256'} + res = self.service_context.get_sign_alg('request_object') + assert res == 'RS384' + + def test_registration_id_token_sign_enc_algs(self): + self.service_context.behaviour= { + "application_type": "web", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256", + "request_object_signing_alg": "RS384", + 'id_token_encrypted_response_alg': 'ECDH-ES', + 'id_token_encrypted_response_enc': "A128GCM", + 'id_token_signed_response_alg': "ES384", + } + + res = self.service_context.get_enc_alg_enc('userinfo') + # 'sign':'RS256' is an added default + assert res == {'alg': 'RSA1_5', 'enc': 'A128CBC-HS256'} + res = self.service_context.get_sign_alg('request_object') + assert res == 'RS384' + res = self.service_context.get_enc_alg_enc('id_token') + assert res == {'alg': 'ECDH-ES', 'enc': 'A128GCM'} + + def test_verify_alg_support(self): + self.service_context.provider_info= { + "version": "3.0", + "issuer": "https://server.example.com", + "authorization_endpoint": + "https://server.example.com/connect/authorize", + "token_endpoint": "https://server.example.com/connect/token", + "token_endpoint_auth_methods_supported": ["client_secret_basic", + "private_key_jwt"], + "token_endpoint_auth_signing_alg_values_supported": ["RS256", + "ES256"], + "userinfo_endpoint": "https://server.example.com/connect/userinfo", + "check_session_iframe": + "https://server.example.com/connect/check_session", + "end_session_endpoint": + "https://server.example.com/connect/end_session", + "jwks_uri": "https://server.example.com/jwks.json", + "registration_endpoint": + "https://server.example.com/connect/register", + "scopes_supported": ["openid", "profile", "email", "address", + "phone", "offline_access"], + "response_types_supported": ["code", "code id_token", "id_token", + "token id_token"], + "acr_values_supported": ["urn:mace:incommon:iap:silver", + "urn:mace:incommon:iap:bronze"], + "subject_types_supported": ["public", "pairwise"], + "userinfo_signing_alg_values_supported": ["RS256", "ES256", + "HS256"], + "userinfo_encryption_alg_values_supported": ["RSA1_5", "A128KW"], + "userinfo_encryption_enc_values_supported": ["A128CBC+HS256", + "A128GCM"], + "id_token_signing_alg_values_supported": ["RS256", "ES256", + "HS256"], + "id_token_encryption_alg_values_supported": ["RSA1_5", "A128KW"], + "id_token_encryption_enc_values_supported": ["A128CBC+HS256", + "A128GCM"], + "request_object_signing_alg_values_supported": ["none", "RS256", + "ES256"], + "display_values_supported": ["page", "popup"], + "claim_types_supported": ["normal", "distributed"], + "claims_supported": ["sub", "iss", "auth_time", "acr", "name", + "given_name", "family_name", "nickname", + "profile", + "picture", "website", "email", + "email_verified", + "locale", "zoneinfo", + "http://example.info/claims/groups"], + "claims_parameter_supported": True, + "service_documentation": + "http://server.example.com/connect/service_documentation.html", + "ui_locales_supported": ["en-US", "en-GB", "en-CA", "fr-FR", + "fr-CA"] + } + + assert verify_alg_support(self.service_context, 'RS256', 'id_token', + 'signing_alg') + assert verify_alg_support(self.service_context, + 'RS512', 'id_token', 'signing_alg') is False + + assert verify_alg_support(self.service_context, 'RSA1_5', 'userinfo', + 'encryption_alg') + + # token_endpoint_auth_signing_alg_values_supported + assert verify_alg_support(self.service_context, 'ES256', + 'token_endpoint_auth', 'signing_alg') + + def test_verify_requests_uri(self): + self.service_context.provider_info= {'issuer': 'https://example.com/'} + url_list = self.service_context.generate_request_uris('/leading') + sp = urlsplit(url_list[0]) + p = sp.path.split('/') + assert p[0] == '' + assert p[1] == 'leading' + assert len(p) == 3 + + # different for different OPs + self.service_context.provider_info= {'issuer': 'https://op.example.org/'} + url_list = self.service_context.generate_request_uris('/leading') + sp = urlsplit(url_list[0]) + np = sp.path.split('/') + assert np[0] == '' + assert np[1] == 'leading' + assert len(np) == 3 + + assert np[2] != p[2] + + def test_import_keys_file(self): + # Should only be one and that a symmetric key (client_secret) usable + # for signing and encryption + assert len(self.service_context.keyjar.get_issuer_keys('')) == 1 + + file_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), 'salesforce.key')) + + keyspec = {'file': {'rsa': [file_path]}} + self.service_context.import_keys(keyspec) + + # Now there should be 2, the second a RSA key for signing + assert len(self.service_context.keyjar.get_issuer_keys('')) == 2 + + def test_import_keys_url(self): + assert len(self.service_context.keyjar.get_issuer_keys('')) == 1 + + # One EC key for signing + key_def = [{"type": "EC", "crv": "P-256", "use": ["sig"]}] + + keyjar = build_keyjar(key_def) + + with responses.RequestsMock() as rsps: + _jwks_url = 'https://foobar.com/jwks.json' + rsps.add("GET", _jwks_url, body=keyjar.export_jwks_as_json(), status=200, + adding_headers={"Content-Type": "application/json"}) + keyspec = {'url': {'https://foobar.com': _jwks_url}} + self.service_context.import_keys(keyspec) + self.service_context.keyjar.update() + + # Now there should be one belonging to https://example.com + assert len(self.service_context.keyjar.get_issuer_keys( + 'https://foobar.com')) == 1 diff --git a/tests/test_01_service_context_impexp.py b/tests/test_01_service_context_impexp.py new file mode 100644 index 0000000..09a42ac --- /dev/null +++ b/tests/test_01_service_context_impexp.py @@ -0,0 +1,302 @@ +import json +import os +from urllib.parse import urlsplit + +import pytest +import responses +from cryptojwt.key_jar import build_keyjar + +from oidcrp.service_context import ServiceContext + + +def test_client_info_init(): + config = { + 'client_id': 'client_id', 'issuer': 'issuer', + 'client_secret': 'client_secret_wordplay', + 'base_url': 'https://example.com', + 'requests_dir': 'requests' + } + ci = ServiceContext(config=config) + + srvcnx = ServiceContext().load(ci.dump()) + + for attr in config.keys(): + try: + val = getattr(srvcnx, attr) + except AttributeError: + val = srvcnx.get(attr) + + assert val == config[attr] + + +def test_set_and_get_client_secret(): + service_context = ServiceContext() + service_context.client_secret = 'longenoughsupersecret' + + srvcnx2 = ServiceContext().load(service_context.dump()) + + assert srvcnx2.client_secret == 'longenoughsupersecret' + + +def test_set_and_get_client_id(): + service_context = ServiceContext() + service_context.client_id = 'myself' + srvcnx2 = ServiceContext().load(service_context.dump()) + assert srvcnx2.client_id == 'myself' + + +def test_client_filename(): + config = { + 'client_id': 'client_id', 'issuer': 'issuer', + 'client_secret': 'longenoughsupersecret', 'base_url': 'https://example.com', + 'requests_dir': 'requests' + } + service_context = ServiceContext(config=config) + srvcnx2 = ServiceContext().load(service_context.dump()) + fname = srvcnx2.filename_from_webname('https://example.com/rq12345') + assert fname == 'rq12345' + + +def verify_alg_support(service_context, alg, usage, typ): + """ + Verifies that the algorithm to be used are supported by the other side. + This will look at provider information either statically configured or + obtained through dynamic provider info discovery. + + :param alg: The algorithm specification + :param usage: In which context the 'alg' will be used. + The following contexts are supported: + - userinfo + - id_token + - request_object + - token_endpoint_auth + :param typ: Type of algorithm + - signing_alg + - encryption_alg + - encryption_enc + :return: True or False + """ + + supported = service_context.provider_info[ + "{}_{}_values_supported".format(usage, typ)] + + if alg in supported: + return True + else: + return False + + +class TestClientInfo(object): + @pytest.fixture(autouse=True) + def create_client_info_instance(self): + config = { + 'client_id': 'client_id', 'issuer': 'issuer', + 'client_secret': 'longenoughsupersecret', + 'base_url': 'https://example.com', + 'requests_dir': 'requests' + } + self.service_context = ServiceContext(config=config) + + def test_registration_userinfo_sign_enc_algs(self): + self.service_context.set( + 'behaviour', { + "application_type": "web", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256" + }) + + srvcntx = ServiceContext().load( + self.service_context.dump(exclude_attributes=["service_context"])) + assert srvcntx.get_sign_alg('userinfo') is None + assert srvcntx.get_enc_alg_enc('userinfo') == {'alg': 'RSA1_5', 'enc': 'A128CBC-HS256'} + + def test_registration_request_object_sign_enc_algs(self): + self.service_context.behaviour= { + "application_type": "web", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256", + "request_object_signing_alg": "RS384" + } + + srvcntx = ServiceContext().load( + self.service_context.dump(exclude_attributes=["service_context"])) + res = srvcntx.get_enc_alg_enc('userinfo') + # 'sign':'RS256' is an added default + assert res == {'alg': 'RSA1_5', 'enc': 'A128CBC-HS256'} + assert srvcntx.get_sign_alg('request_object') == 'RS384' + + def test_registration_id_token_sign_enc_algs(self): + self.service_context.behaviour= { + "application_type": "web", + "redirect_uris": ["https://client.example.org/callback", + "https://client.example.org/callback2"], + "token_endpoint_auth_method": "client_secret_basic", + "jwks_uri": "https://client.example.org/my_public_keys.jwks", + "userinfo_encrypted_response_alg": "RSA1_5", + "userinfo_encrypted_response_enc": "A128CBC-HS256", + "request_object_signing_alg": "RS384", + 'id_token_encrypted_response_alg': 'ECDH-ES', + 'id_token_encrypted_response_enc': "A128GCM", + 'id_token_signed_response_alg': "ES384", + } + + srvcntx = ServiceContext().load( + self.service_context.dump(exclude_attributes=["service_context"])) + + # 'sign':'RS256' is an added default + assert srvcntx.get_enc_alg_enc('userinfo') == {'alg': 'RSA1_5', 'enc': 'A128CBC-HS256'} + assert srvcntx.get_sign_alg('request_object') == 'RS384' + assert srvcntx.get_enc_alg_enc('id_token') == {'alg': 'ECDH-ES', 'enc': 'A128GCM'} + + def test_verify_alg_support(self): + self.service_context.provider_info= { + "version": "3.0", + "issuer": "https://server.example.com", + "authorization_endpoint": + "https://server.example.com/connect/authorize", + "token_endpoint": "https://server.example.com/connect/token", + "token_endpoint_auth_methods_supported": ["client_secret_basic", + "private_key_jwt"], + "token_endpoint_auth_signing_alg_values_supported": ["RS256", + "ES256"], + "userinfo_endpoint": "https://server.example.com/connect/userinfo", + "check_session_iframe": + "https://server.example.com/connect/check_session", + "end_session_endpoint": + "https://server.example.com/connect/end_session", + "jwks_uri": "https://server.example.com/jwks.json", + "registration_endpoint": + "https://server.example.com/connect/register", + "scopes_supported": ["openid", "profile", "email", "address", + "phone", "offline_access"], + "response_types_supported": ["code", "code id_token", "id_token", + "token id_token"], + "acr_values_supported": ["urn:mace:incommon:iap:silver", + "urn:mace:incommon:iap:bronze"], + "subject_types_supported": ["public", "pairwise"], + "userinfo_signing_alg_values_supported": ["RS256", "ES256", + "HS256"], + "userinfo_encryption_alg_values_supported": ["RSA1_5", "A128KW"], + "userinfo_encryption_enc_values_supported": ["A128CBC+HS256", + "A128GCM"], + "id_token_signing_alg_values_supported": ["RS256", "ES256", + "HS256"], + "id_token_encryption_alg_values_supported": ["RSA1_5", "A128KW"], + "id_token_encryption_enc_values_supported": ["A128CBC+HS256", + "A128GCM"], + "request_object_signing_alg_values_supported": ["none", "RS256", + "ES256"], + "display_values_supported": ["page", "popup"], + "claim_types_supported": ["normal", "distributed"], + "claims_supported": ["sub", "iss", "auth_time", "acr", "name", + "given_name", "family_name", "nickname", + "profile", + "picture", "website", "email", + "email_verified", + "locale", "zoneinfo", + "http://example.info/claims/groups"], + "claims_parameter_supported": True, + "service_documentation": + "http://server.example.com/connect/service_documentation.html", + "ui_locales_supported": ["en-US", "en-GB", "en-CA", "fr-FR", + "fr-CA"] + } + + srvcntx = ServiceContext().load( + self.service_context.dump(exclude_attributes=["service_context"])) + + assert verify_alg_support(srvcntx, 'RS256', 'id_token', 'signing_alg') + assert verify_alg_support(srvcntx, 'RS512', 'id_token', 'signing_alg') is False + assert verify_alg_support(srvcntx, 'RSA1_5', 'userinfo', 'encryption_alg') + + # token_endpoint_auth_signing_alg_values_supported + assert verify_alg_support(srvcntx, 'ES256', 'token_endpoint_auth', 'signing_alg') + + def test_verify_requests_uri(self): + self.service_context.provider_info= {'issuer': 'https://example.com/'} + url_list = self.service_context.generate_request_uris('/leading') + sp = urlsplit(url_list[0]) + p = sp.path.split('/') + assert p[0] == '' + assert p[1] == 'leading' + assert len(p) == 3 + + srvcntx = ServiceContext().load( + self.service_context.dump(exclude_attributes=["service_context"])) + + # different for different OPs + srvcntx.provider_info= {'issuer': 'https://op.example.org/'} + url_list = srvcntx.generate_request_uris('/leading') + sp = urlsplit(url_list[0]) + np = sp.path.split('/') + assert np[0] == '' + assert np[1] == 'leading' + assert len(np) == 3 + + assert np[2] != p[2] + + def test_import_keys_file(self): + # Should only be one and that a symmetric key (client_secret) usable + # for signing and encryption + assert len(self.service_context.keyjar.get_issuer_keys('')) == 1 + + file_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), 'salesforce.key')) + + keyspec = {'file': {'rsa': [file_path]}} + self.service_context.import_keys(keyspec) + + srvcntx = ServiceContext().load( + self.service_context.dump(exclude_attributes=["service_context"])) + + # Now there should be 2, the second a RSA key for signing + assert len(srvcntx.keyjar.get_issuer_keys('')) == 2 + + def test_import_keys_file_json(self): + # Should only be one and that a symmetric key (client_secret) usable + # for signing and encryption + assert len(self.service_context.keyjar.get_issuer_keys('')) == 1 + + file_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), 'salesforce.key')) + + keyspec = {'file': {'rsa': [file_path]}} + self.service_context.import_keys(keyspec) + + _sc_state = self.service_context.dump(exclude_attributes=["service_context"]) + _jsc_state = json.dumps(_sc_state) + _o_state = json.loads(_jsc_state) + srvcntx = ServiceContext().load(_o_state) + + # Now there should be 2, the second a RSA key for signing + assert len(srvcntx.keyjar.get_issuer_keys('')) == 2 + + def test_import_keys_url(self): + assert len(self.service_context.keyjar.get_issuer_keys('')) == 1 + + # One EC key for signing + key_def = [{"type": "EC", "crv": "P-256", "use": ["sig"]}] + + keyjar = build_keyjar(key_def) + + with responses.RequestsMock() as rsps: + _jwks_url = 'https://foobar.com/jwks.json' + rsps.add("GET", _jwks_url, body=keyjar.export_jwks_as_json(), status=200, + adding_headers={"Content-Type": "application/json"}) + keyspec = {'url': {'https://foobar.com': _jwks_url}} + self.service_context.import_keys(keyspec) + self.service_context.keyjar.update() + + srvcntx = ServiceContext().load( + self.service_context.dump(exclude_attributes=["service_context"])) + + # Now there should be one belonging to https://example.com + assert len(srvcntx.keyjar.get_issuer_keys('https://foobar.com')) == 1 diff --git a/tests/test_02_cookie.py b/tests/test_02_cookie.py index 2bd8627..35e93f8 100644 --- a/tests/test_02_cookie.py +++ b/tests/test_02_cookie.py @@ -2,15 +2,15 @@ from http.cookies import SimpleCookie import pytest -from oidcrp.cookie import make_cookie -from oidcservice.exception import ImproperlyConfigured from oidcrp.cookie import CookieDealer from oidcrp.cookie import InvalidCookieSign from oidcrp.cookie import cookie_parts from oidcrp.cookie import cookie_signature +from oidcrp.cookie import make_cookie from oidcrp.cookie import parse_cookie from oidcrp.cookie import verify_cookie_signature +from oidcrp.exception import ImproperlyConfigured __author__ = 'roland' @@ -51,7 +51,8 @@ def test_cookie_dealer_improperly_configured(self): class BadServer(): def __init__(self): self.symkey = "" - with pytest.raises(ImproperlyConfigured) as err: + + with pytest.raises(ImproperlyConfigured): CookieDealer(BadServer()) def test_cookie_dealer_with_domain(self): diff --git a/tests/test_03_util.py b/tests/test_03_util.py index 838e366..a32bcec 100644 --- a/tests/test_03_util.py +++ b/tests/test_03_util.py @@ -1,15 +1,20 @@ -import pytest - from http.cookiejar import FileCookieJar from http.cookiejar import http2time from http.cookies import SimpleCookie +import json from urllib.parse import parse_qs from urllib.parse import urlparse +from urllib.parse import urlsplit -from oidcrp.util import get_deserialization_method, URL_ENCODED -from oidcservice.exception import WrongContentType +from oidcmsg.oauth2 import AccessTokenRequest +from oidcmsg.oauth2 import AuthorizationRequest +import pytest from oidcrp import util +from oidcrp.exception import WrongContentType +from oidcrp.util import JSON_ENCODED +from oidcrp.util import URL_ENCODED +from oidcrp.util import get_deserialization_method __author__ = 'Roland Hedberg' @@ -175,3 +180,88 @@ def test_verify_no_content_type(): resp = FakeResponse('text/html') del resp.headers['content-type'] assert util.verify_header(resp, 'txt') == 'txt' + + +def test_get_http_url(): + url = u'https://localhost:8092/authorization' + method = 'GET' + values = {'acr_values': 'PASSWORD', + 'state': 'urn:uuid:92d81fb3-72e8-4e6c-9173-c360b782148a', + 'redirect_uri': + 'https://localhost:8666/919D3F697FDAAF138124B83E09ECB0B7', + 'response_type': 'code', 'client_id': 'ok8tx7ulVlNV', + 'scope': 'openid profile email address phone'} + request = AuthorizationRequest(**values) + + _url = util.get_http_url(url, request, method) + _part = urlsplit(_url) + _req = parse_qs(_part.query) + assert set(_req.keys()) == {'acr_values', 'state', 'redirect_uri', + 'response_type', 'client_id', 'scope'} + + +def test_get_http_body_default_encoding(): + values = { + 'redirect_uri': + 'https://localhost:8666/919D3F697FDAAF138124B83E09ECB0B7', + 'code': 'Je1iKfPN1vCiN7L43GiXAuAWGAnm0mzA7QIjl' + '/YLBBZDB9wefNExQlLDUIIDM2rT' + '2t+gwuoRoapEXJyY2wrvg9cWTW2vxsZU+SuWzZlMDXc=', + 'grant_type': 'authorization_code'} + request = AccessTokenRequest(**values) + + body = util.get_http_body(request) + + _req = parse_qs(body) + assert set(_req.keys()) == {'code', 'grant_type', 'redirect_uri'} + + +def test_get_http_body_url_encoding(): + values = { + 'redirect_uri': + 'https://localhost:8666/919D3F697FDAAF138124B83E09ECB0B7', + 'code': 'Je1iKfPN1vCiN7L43GiXAuAWGAnm0mzA7QIjl' + '/YLBBZDB9wefNExQlLDUIIDM2rT' + '2t+gwuoRoapEXJyY2wrvg9cWTW2vxsZU+SuWzZlMDXc=', + 'grant_type': 'authorization_code'} + request = AccessTokenRequest(**values) + + body = util.get_http_body(request, URL_ENCODED) + + _req = parse_qs(body) + assert set(_req.keys()) == {'code', 'grant_type', 'redirect_uri'} + + +def test_get_http_body_json(): + values = { + 'redirect_uri': + 'https://localhost:8666/919D3F697FDAAF138124B83E09ECB0B7', + 'code': 'Je1iKfPN1vCiN7L43GiXAuAWGAnm0mzA7QIjl' + '/YLBBZDB9wefNExQlLDUIIDM2rT' + '2t+gwuoRoapEXJyY2wrvg9cWTW2vxsZU+SuWzZlMDXc=', + 'grant_type': 'authorization_code'} + request = AccessTokenRequest(**values) + + body = util.get_http_body(request, JSON_ENCODED) + + _req = json.loads(body) + assert set(_req.keys()) == {'code', 'grant_type', 'redirect_uri'} + + +def test_get_http_url_with_qp(): + url = u'https://localhost:8092/authorization?test=testslice' + method = 'GET' + values = {'acr_values': 'PASSWORD', + 'state': 'urn:uuid:92d81fb3-72e8-4e6c-9173-c360b782148a', + 'redirect_uri': + 'https://localhost:8666/919D3F697FDAAF138124B83E09ECB0B7', + 'response_type': 'code', 'client_id': 'ok8tx7ulVlNV', + 'scope': 'openid profile email address phone'} + request = AuthorizationRequest(**values) + + _url = util.get_http_url(url, request, method) + _part = urlsplit(_url) + _req = parse_qs(_part.query) + assert set(_req.keys()) == {'acr_values', 'state', 'redirect_uri', + 'response_type', 'client_id', 'scope', + 'test'} diff --git a/tests/test_07_service.py b/tests/test_07_service.py new file mode 100644 index 0000000..7740ae5 --- /dev/null +++ b/tests/test_07_service.py @@ -0,0 +1,93 @@ +from oidcmsg.oauth2 import Message +from oidcmsg.oauth2 import SINGLE_OPTIONAL_INT +from oidcmsg.oauth2 import SINGLE_OPTIONAL_STRING +from oidcmsg.oauth2 import SINGLE_REQUIRED_STRING +import pytest + +from oidcrp.entity import Entity +from oidcrp.service import Service +from oidcrp.service_context import ServiceContext +from oidcrp.state_interface import InMemoryStateDataBase +from oidcrp.state_interface import State + + +class DummyMessage(Message): + c_param = { + "req_str": SINGLE_REQUIRED_STRING, + "opt_str": SINGLE_OPTIONAL_STRING, + "opt_int": SINGLE_OPTIONAL_INT, + } + + +class Response(object): + def __init__(self, status_code, text, headers=None): + self.status_code = status_code + self.text = text + self.headers = headers or {"content-type": "text/plain"} + + +class DummyService(Service): + msg_type = DummyMessage + + +class TestDummyService(object): + @pytest.fixture(autouse=True) + def create_service(self): + config = { + "issuer": 'https://www.example.org/as', + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'behaviour': {'response_types': ['code']} + } + service = { + "dummy": { + "class": DummyService + } + } + + entity = Entity(config=config, services=service) + self.service = DummyService(client_get=entity.client_get, conf={}) + + def test_construct(self): + req_args = {'foo': 'bar'} + _req = self.service.construct(request_args=req_args) + assert isinstance(_req, Message) + assert list(_req.keys()) == ['foo'] + + def test_construct_service_context(self): + req_args = {'foo': 'bar', 'req_str': 'some string'} + _req = self.service.construct(request_args=req_args) + assert isinstance(_req, Message) + assert set(_req.keys()) == {'foo', 'req_str'} + + def test_get_request_parameters(self): + req_args = {'foo': 'bar', 'req_str': 'some string'} + self.service.endpoint = 'https://example.com/authorize' + _info = self.service.get_request_parameters(request_args=req_args) + assert set(_info.keys()) == {'url', 'method', "request"} + msg = DummyMessage().from_urlencoded( + self.service.get_urlinfo(_info['url'])) + + def test_request_init(self): + req_args = {'foo': 'bar', 'req_str': 'some string'} + self.service.endpoint = 'https://example.com/authorize' + _info = self.service.get_request_parameters(request_args=req_args) + assert set(_info.keys()) == {'url', 'method', "request"} + msg = DummyMessage().from_urlencoded( + self.service.get_urlinfo(_info['url'])) + assert msg.to_dict() == {'foo': 'bar', 'req_str': 'some string'} + + +# class TestRequest(object): +# @pytest.fixture(autouse=True) +# def create_service(self): +# entity = Entity() +# service_context = entity.get_service_context() +# self.service = Service(service_context, client_authn_method=None) +# +# def test_construct(self): +# req_args = {'foo': 'bar'} +# _req = self.service.construct(request_args=req_args) +# assert isinstance(_req, Message) +# assert list(_req.keys()) == ['foo'] diff --git a/tests/test_08_webfinger.py b/tests/test_08_webfinger.py new file mode 100644 index 0000000..931431c --- /dev/null +++ b/tests/test_08_webfinger.py @@ -0,0 +1,314 @@ +import json +from urllib.parse import parse_qs, unquote_plus, urlsplit + +from oidcrp.entity import Entity +import pytest +from oidcmsg.exception import MissingRequiredAttribute +from oidcmsg.oidc import JRD, Link + +from oidcrp.oidc import OIC_ISSUER +from oidcrp.oidc.webfinger import WebFinger +from oidcrp.service_context import ServiceContext + +__author__ = 'Roland Hedberg' + +SERVICE_CONTEXT = ServiceContext() + + +ENTITY = Entity(config={}) + + +def test_query(): + rel = 'http%3A%2F%2Fopenid.net%2Fspecs%2Fconnect%2F1.0%2Fissuer' + pattern = 'https://{}/.well-known/webfinger?rel={}&resource={}' + example_oidc = { + 'example.com': ('example.com', rel, 'acct%3Aexample.com'), + 'joe@example.com': ('example.com', rel, 'acct%3Ajoe%40example.com'), + 'example.com/joe': ('example.com', rel, + 'https%3A%2F%2Fexample.com%2Fjoe'), + 'example.com:8080': ('example.com:8080', rel, + 'https%3A%2F%2Fexample.com%3A8080'), + 'Jane.Doe@example.com': ('example.com', rel, + 'acct%3AJane.Doe%40example.com'), + 'alice@example.com:8080': ('alice@example.com:8080', rel, + 'https%3A%2F%2Falice%40example.com%3A8080'), + 'https://example.com': ('example.com', rel, + 'https%3A%2F%2Fexample.com'), + 'https://example.com/joe': ( + 'example.com', rel, 'https%3A%2F%2Fexample.com%2Fjoe'), + 'https://joe@example.com:8080': ( + 'joe@example.com:8080', rel, + 'https%3A%2F%2Fjoe%40example.com%3A8080'), + 'acct:joe@example.com': ('example.com', rel, + 'acct%3Ajoe%40example.com') + } + + wf = WebFinger(ENTITY.client_get) + for key, args in example_oidc.items(): + _q = wf.query(key) + p = urlsplit(_q) + assert p.netloc == args[0] + qs = parse_qs(p.query) + assert qs['resource'][0] == unquote_plus(args[2]) + assert qs['rel'][0] == unquote_plus(args[1]) + + +def test_query_2(): + rel = 'http%3A%2F%2Fopenid.net%2Fspecs%2Fconnect%2F1.0%2Fissuer' + pattern = 'https://{}/.well-known/webfinger?rel={}&resource={}' + example_oidc = { + # below are identifiers that are slightly off + "example.com?query": ( + 'example.com', rel, 'https%3A%2F%2Fexample.com%3Fquery'), + "example.com#fragment": ( + 'example.com', rel, 'https%3A%2F%2Fexample.com'), + "example.com:8080/path?query#fragment": + ('example.com:8080', + rel, 'https%3A%2F%2Fexample.com%3A8080%2Fpath%3Fquery'), + "http://example.com/path": ( + 'example.com', rel, 'http%3A%2F%2Fexample.com%2Fpath'), + "http://example.com?query": ( + 'example.com', rel, 'http%3A%2F%2Fexample.com%3Fquery'), + "http://example.com#fragment": ( + 'example.com', rel, 'http%3A%2F%2Fexample.com'), + "http://example.com:8080/path?query#fragment": ( + 'example.com:8080', rel, + 'http%3A%2F%2Fexample.com%3A8080%2Fpath%3Fquery'), + "nov@example.com:8080": ( + "nov@example.com:8080", rel, + "https%3A%2F%2Fnov%40example.com%3A8080"), + "nov@example.com/path": ( + "nov@example.com", rel, + "https%3A%2F%2Fnov%40example.com%2Fpath"), + "nov@example.com?query": ( + "nov@example.com", rel, + "https%3A%2F%2Fnov%40example.com%3Fquery"), + "nov@example.com#fragment": ( + "nov@example.com", rel, + "https%3A%2F%2Fnov%40example.com"), + "nov@example.com:8080/path?query#fragment": ( + "nov@example.com:8080", rel, + "https%3A%2F%2Fnov%40example.com%3A8080%2Fpath%3Fquery"), + "acct:nov@example.com:8080": ( + "example.com:8080", rel, + "acct%3Anov%40example.com%3A8080" + ), + "acct:nov@example.com/path": ( + "example.com", rel, + "acct%3Anov%40example.com%2Fpath" + ), + "acct:nov@example.com?query": ( + "example.com", rel, + "acct%3Anov%40example.com%3Fquery" + ), + "acct:nov@example.com#fragment": ( + "example.com", rel, + "acct%3Anov%40example.com" + ), + "acct:nov@example.com:8080/path?query#fragment": ( + "example.com:8080", rel, + "acct%3Anov%40example.com%3A8080%2Fpath%3Fquery" + ) + } + + wf = WebFinger(ENTITY.client_get) + for key, args in example_oidc.items(): + _q = wf.query(key) + p = urlsplit(_q) + assert p.netloc == args[0] + qs = parse_qs(p.query) + assert qs['resource'][0] == unquote_plus(args[2]) + assert qs['rel'][0] == unquote_plus(args[1]) + + +def test_link1(): + link = Link( + rel="http://webfinger.net/rel/avatar", + type="image/jpeg", + href="http://www.example.com/~bob/bob.jpg" + ) + + assert set(link.keys()) == {'rel', 'type', 'href'} + assert link['rel'] == "http://webfinger.net/rel/avatar" + assert link['type'] == "image/jpeg" + assert link['href'] == "http://www.example.com/~bob/bob.jpg" + + +def test_link2(): + link = Link(rel="blog", type="text/html", + href="http://blogs.example.com/bob/", + titles={ + "en-us": "The Magical World of Bob", + "fr": "Le monde magique de Bob" + }) + + assert set(link.keys()) == {'rel', 'type', 'href', 'titles'} + assert link['rel'] == "blog" + assert link['type'] == "text/html" + assert link['href'] == "http://blogs.example.com/bob/" + assert set(link['titles'].keys()) == {'en-us', 'fr'} + + +def test_link3(): + link = Link(rel="http://webfinger.net/rel/profile-page", + href="http://www.example.com/~bob/") + + assert set(link.keys()) == {'rel', 'href'} + assert link['rel'] == "http://webfinger.net/rel/profile-page" + assert link['href'] == "http://www.example.com/~bob/" + + +def test_jrd(): + jrd = JRD( + subject="acct:bob@example.com", + aliases=[ + "http://www.example.com/~bob/" + ], + properties={ + "http://example.com/ns/role/": "employee" + }, + links=[ + Link( + rel="http://webfinger.net/rel/avatar", + type="image/jpeg", + href="http://www.example.com/~bob/bob.jpg" + ), + Link( + rel="http://webfinger.net/rel/profile-page", + href="http://www.example.com/~bob/" + )]) + + assert set(jrd.keys()) == {'subject', 'aliases', 'properties', 'links'} + + +def test_jrd2(): + ex0 = { + "subject": "acct:bob@example.com", + "aliases": [ + "http://www.example.com/~bob/" + ], + "properties": { + "http://example.com/ns/role/": "employee" + }, + "links": [ + { + "rel": "http://webfinger.net/rel/avatar", + "type": "image/jpeg", + "href": "http://www.example.com/~bob/bob.jpg" + }, + { + "rel": "http://webfinger.net/rel/profile-page", + "href": "http://www.example.com/~bob/" + }, + { + "rel": "blog", + "type": "text/html", + "href": "http://blogs.example.com/bob/", + "titles": { + "en-us": "The Magical World of Bob", + "fr": "Le monde magique de Bob" + } + }, + { + "rel": "vcard", + "href": "https://www.example.com/~bob/bob.vcf" + } + ] + } + + jrd0 = JRD().from_json(json.dumps(ex0)) + + for link in jrd0["links"]: + if link["rel"] == "blog": + assert link["href"] == "http://blogs.example.com/bob/" + break + + +def test_extra_member_response(): + ex = { + "subject": "acct:bob@example.com", + "aliases": [ + "http://www.example.com/~bob/" + ], + "properties": { + "http://example.com/ns/role/": "employee" + }, + 'dummy': 'foo', + "links": [ + { + "rel": "http://webfinger.net/rel/avatar", + "type": "image/jpeg", + "href": "http://www.example.com/~bob/bob.jpg" + }] + } + + _resp = JRD().from_json(json.dumps(ex)) + assert _resp['dummy'] == 'foo' + + +class TestWebFinger(object): + def test_query_device(self): + wf = WebFinger(ENTITY.client_get) + request_args = {'resource': "p1.example.com"} + _info = wf.get_request_parameters(request_args) + p = urlsplit(_info['url']) + assert p.netloc == request_args["resource"] + qs = parse_qs(p.query) + assert qs['resource'][0] == "acct:p1.example.com" + assert qs['rel'][0] == "http://openid.net/specs/connect/1.0/issuer" + + def test_query_rel(self): + wf = WebFinger(ENTITY.client_get) + request_args = {'resource': "acct:bob@example.com"} + _info = wf.get_request_parameters(request_args) + p = urlsplit(_info['url']) + assert p.netloc == "example.com" + qs = parse_qs(p.query) + assert qs['resource'][0] == "acct:bob@example.com" + assert qs['rel'][0] == "http://openid.net/specs/connect/1.0/issuer" + + def test_query_acct(self): + wf = WebFinger(ENTITY.client_get, rel=OIC_ISSUER) + request_args = {'resource': "acct:carol@example.com"} + _info = wf.get_request_parameters(request_args=request_args) + + p = urlsplit(_info['url']) + assert p.netloc == "example.com" + qs = parse_qs(p.query) + assert qs['resource'][0] == "acct:carol@example.com" + assert qs['rel'][0] == "http://openid.net/specs/connect/1.0/issuer" + + def test_query_acct_resource_kwargs(self): + wf = WebFinger(ENTITY.client_get, rel=OIC_ISSUER) + request_args = {} + _info = wf.get_request_parameters(request_args=request_args, + resource="acct:carol@example.com") + + p = urlsplit(_info['url']) + assert p.netloc == "example.com" + qs = parse_qs(p.query) + assert qs['resource'][0] == "acct:carol@example.com" + assert qs['rel'][0] == "http://openid.net/specs/connect/1.0/issuer" + + def test_query_acct_resource_config(self): + wf = WebFinger(ENTITY.client_get, rel=OIC_ISSUER) + wf.client_get("service_context").config['resource'] = "acct:carol@example.com" + request_args = {} + _info = wf.get_request_parameters(request_args=request_args) + + p = urlsplit(_info['url']) + assert p.netloc == "example.com" + qs = parse_qs(p.query) + assert qs['resource'][0] == "acct:carol@example.com" + assert qs['rel'][0] == "http://openid.net/specs/connect/1.0/issuer" + + def test_query_acct_no_resource(self): + wf = WebFinger(ENTITY.client_get, rel=OIC_ISSUER) + try: + del wf.client_get("service_context").config['resource'] + except KeyError: + pass + request_args = {} + + with pytest.raises(MissingRequiredAttribute): + wf.get_request_parameters(request_args=request_args) diff --git a/tests/test_09_client_auth.py b/tests/test_09_client_auth.py new file mode 100755 index 0000000..fe2b86f --- /dev/null +++ b/tests/test_09_client_auth.py @@ -0,0 +1,512 @@ +import base64 +import os +from urllib.parse import quote_plus + +from cryptojwt.exception import MissingKey +from cryptojwt.jwk.rsa import new_rsa_key +from cryptojwt.jws.jws import JWS +from cryptojwt.jws.jws import factory +from cryptojwt.jwt import JWT +from cryptojwt.key_bundle import KeyBundle +from cryptojwt.key_jar import KeyJar +from oidcmsg.message import Message +from oidcmsg.oauth2 import AccessTokenRequest +from oidcmsg.oauth2 import AccessTokenResponse +from oidcmsg.oauth2 import AuthorizationRequest +from oidcmsg.oauth2 import AuthorizationResponse +from oidcmsg.oauth2 import CCAccessTokenRequest +from oidcmsg.oauth2 import ResourceRequest +import pytest + +from oidcrp.client_auth import AuthnFailure +from oidcrp.client_auth import BearerBody +from oidcrp.client_auth import BearerHeader +from oidcrp.client_auth import ClientSecretBasic +from oidcrp.client_auth import ClientSecretJWT +from oidcrp.client_auth import ClientSecretPost +from oidcrp.client_auth import PrivateKeyJWT +from oidcrp.client_auth import assertion_jwt +from oidcrp.client_auth import bearer_auth +from oidcrp.client_auth import valid_service_context +from oidcrp.defaults import JWT_BEARER +from oidcrp.entity import Entity + +BASE_PATH = os.path.abspath(os.path.dirname(__file__)) +CLIENT_ID = "A" + +CLIENT_CONF = {'issuer': 'https://example.com/as', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'client_secret': 'white boarding pass', + 'client_id': CLIENT_ID} + + +def _eq(l1, l2): + return set(l1) == set(l2) + + +@pytest.fixture +def entity(): + return Entity(config=CLIENT_CONF) + + +def test_quote(): + csb = ClientSecretBasic() + http_args = csb.construct( + Message(), + password='MKEM/A7Pkn7JuU0LAcxyHVKvwdczsugaPU0BieLb4CbQAgQj+ypcanFOCb0/FA5h', + user='796d8fae-a42f-4e4f-ab25-d6205b6d4fa2') + + assert http_args['headers'][ + 'Authorization'] == 'Basic ' \ + 'Nzk2ZDhmYWUtYTQyZi00ZTRmLWFiMjUtZDYyMDViNmQ0ZmEyOk1LRU0lMkZBN1BrbjdKdVUwTEFjeHlIVkt2d2RjenN1Z2FQVTBCaWVMYjRDYlFBZ1FqJTJCeXBjYW5GT0NiMCUyRkZBNWg=' + + +class TestClientSecretBasic(object): + def test_construct(self, entity): + _token_service = entity.client_get("service", "accesstoken") + request = _token_service.construct(redirect_uri="http://example.com", + state='ABCDE') + + csb = ClientSecretBasic() + http_args = csb.construct(request, _token_service) + + credentials = "{}:{}".format(quote_plus('A'), quote_plus('white boarding pass')) + + assert http_args == {"headers": {"Authorization": "Basic {}".format( + base64.urlsafe_b64encode(credentials.encode("utf-8")).decode( + "utf-8"))}} + + def test_does_not_remove_padding(self): + request = AccessTokenRequest(code="foo", + redirect_uri="http://example.com") + + csb = ClientSecretBasic() + http_args = csb.construct(request, user="ab", password="c") + + assert http_args["headers"]["Authorization"].endswith("==") + + def test_construct_cc(self): + """CC == Client Credentials, the 4th OAuth2 flow""" + request = CCAccessTokenRequest(grant_type="client_credentials") + + csb = ClientSecretBasic() + http_args = csb.construct(request, user="service1", password="secret") + + assert http_args["headers"]["Authorization"].startswith('Basic ') + + +class TestBearerHeader(object): + def test_construct(self, entity): + request = ResourceRequest(access_token="Sesame") + bh = BearerHeader() + http_args = bh.construct(request, + service=entity.client_get("service", "accesstoken")) + + assert http_args == {"headers": {"Authorization": "Bearer Sesame"}} + + def test_construct_with_http_args(self, entity): + request = ResourceRequest(access_token="Sesame") + bh = BearerHeader() + # Any HTTP args should just be passed on + http_args = bh.construct(request, + service=entity.client_get("service", "accesstoken"), + http_args={"foo": "bar"}) + + assert _eq(http_args.keys(), ["foo", "headers"]) + assert http_args["headers"] == {"Authorization": "Bearer Sesame"} + + def test_construct_with_headers_in_http_args(self, entity): + request = ResourceRequest(access_token="Sesame") + + bh = BearerHeader() + http_args = bh.construct(request, + service=entity.client_get("service", "accesstoken"), + http_args={"headers": {"x-foo": "bar"}}) + + assert _eq(http_args.keys(), ["headers"]) + assert _eq(http_args["headers"].keys(), ["Authorization", "x-foo"]) + assert http_args["headers"]["Authorization"] == "Bearer Sesame" + + def test_construct_with_resource_request(self, entity): + bh = BearerHeader() + request = ResourceRequest(access_token="Sesame") + + http_args = bh.construct(request, + service=entity.client_get("service", "accesstoken")) + + assert "access_token" not in request + assert http_args == {"headers": {"Authorization": "Bearer Sesame"}} + + def test_construct_with_token(self, entity): + authz_service = entity.client_get("service", 'authorization') + srv_cntx = authz_service.client_get("service_context") + _state = srv_cntx.state.create_state('Issuer') + req = AuthorizationRequest(state=_state, response_type='code', + redirect_uri='https://example.com', + scope=['openid']) + srv_cntx.state.store_item(req, 'auth_request', _state) + + # Add a state and bind a code to it + resp1 = AuthorizationResponse(code="auth_grant", state=_state) + response = authz_service.parse_response( + resp1.to_urlencoded(), "urlencoded") + authz_service.update_service_context(response, key=_state) + + # based on state find the code and then get an access token + resp2 = AccessTokenResponse(access_token="token1", + token_type="Bearer", expires_in=0, + state=_state) + _token_service = entity.client_get("service", 'accesstoken') + response = _token_service.parse_response( + resp2.to_urlencoded(), "urlencoded") + + _token_service.update_service_context(response, key=_state) + + # and finally use the access token, bound to a state, to + # construct the authorization header + http_args = BearerHeader().construct( + ResourceRequest(), _token_service, key=_state) + assert http_args == {"headers": {"Authorization": "Bearer token1"}} + + +class TestBearerBody(object): + def test_construct(self, entity): + _token_service = entity.client_get("service", 'accesstoken') + request = ResourceRequest(access_token="Sesame") + http_args = BearerBody().construct(request, service=_token_service) + + assert request["access_token"] == "Sesame" + assert http_args is None + + def test_construct_with_state(self, entity): + _auth_service = entity.client_get("service", 'authorization') + _cntx = _auth_service.client_get("service_context") + _key = _cntx.state.create_state(iss='Issuer') + + resp = AuthorizationResponse(code="code", state=_key) + _cntx.state.store_item(resp, 'auth_response', _key) + + atr = AccessTokenResponse(access_token="2YotnFZFEjr1zCsicMWpAA", + token_type="example", + refresh_token="tGzv3JOkF0XG5Qx2TlKWIA", + example_parameter="example_value", + scope=["inner", "outer"]) + _cntx.state.store_item(atr, 'token_response', _key) + + request = ResourceRequest() + http_args = BearerBody().construct(request, service=_auth_service, key=_key) + assert request["access_token"] == "2YotnFZFEjr1zCsicMWpAA" + assert http_args is None + + def test_construct_with_request(self, entity): + authz_service = entity.client_get("service", 'authorization') + _cntx = authz_service.client_get("service_context") + + _key = _cntx.state.create_state(iss='Issuer') + resp1 = AuthorizationResponse(code="auth_grant", state=_key) + response = authz_service.parse_response(resp1.to_urlencoded(), + "urlencoded") + authz_service.update_service_context(response, key=_key) + + resp2 = AccessTokenResponse(access_token="token1", + token_type="Bearer", expires_in=0, + state=_key) + _token_service = entity.client_get("service", 'accesstoken') + response = _token_service.parse_response(resp2.to_urlencoded(), "urlencoded") + _token_service.update_service_context(response, key=_key) + + request = ResourceRequest() + BearerBody().construct(request, service=authz_service, key=_key) + + assert "access_token" in request + assert request["access_token"] == "token1" + + +class TestClientSecretPost(object): + def test_construct(self, entity): + _token_service = entity.client_get("service", 'accesstoken') + request = _token_service.construct(redirect_uri="http://example.com", + state='ABCDE') + csp = ClientSecretPost() + http_args = csp.construct(request, service=_token_service) + + assert request["client_id"] == "A" + assert request["client_secret"] == "white boarding pass" + assert http_args is None + + request = AccessTokenRequest(code="foo", + redirect_uri="http://example.com") + http_args = csp.construct(request, service=_token_service, + client_secret="another") + assert request["client_id"] == "A" + assert request["client_secret"] == "another" + assert http_args is None + + def test_modify_1(self, entity): + token_service = entity.client_get("service", 'accesstoken') + request = token_service.construct(redirect_uri="http://example.com", + state='ABCDE') + csp = ClientSecretPost() + # client secret not in request or kwargs + del request["client_secret"] + http_args = csp.construct(request, service=token_service) + assert "client_secret" in request + + def test_modify_2(self, entity): + token_service = entity.client_get("service", 'accesstoken') + request = token_service.construct(redirect_uri="http://example.com", + state='ABCDE') + csp = ClientSecretPost() + # client secret not in request or kwargs + del request["client_secret"] + token_service.client_get("service_context").client_secret = "" + # this will fail + with pytest.raises(AuthnFailure): + http_args = csp.construct(request, service=token_service) + + +class TestPrivateKeyJWT(object): + def test_construct(self, entity): + token_service = entity.client_get("service", 'accesstoken') + kb_rsa = KeyBundle(source='file://{}'.format( + os.path.join(BASE_PATH, "data/keys/rsa.key")), fileformat='der') + + for key in kb_rsa: + key.add_kid() + + _context = token_service.client_get("service_context") + _context.keyjar.add_kb('', kb_rsa) + _context.provider_info = { + 'issuer': 'https://example.com/', + 'token_endpoint': "https://example.com/token"} + _context.registration_response = { + 'token_endpoint_auth_signing_alg': 'RS256'} + token_service.endpoint = "https://example.com/token" + + request = AccessTokenRequest() + pkj = PrivateKeyJWT() + http_args = pkj.construct(request, service=token_service, authn_endpoint='token_endpoint') + assert http_args == {} + cas = request["client_assertion"] + + _kj = KeyJar() + _kj.add_kb(_context.client_id, kb_rsa) + jso = JWT(key_jar=_kj).unpack(cas) + assert _eq(jso.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) + # assert _jwt.headers == {'alg': 'RS256'} + assert jso['aud'] == [_context.provider_info['token_endpoint']] + + def test_construct_client_assertion(self, entity): + token_service = entity.client_get("service", 'accesstoken') + + kb_rsa = KeyBundle(source='file://{}'.format( + os.path.join(BASE_PATH, "data/keys/rsa.key")), fileformat='der') + + request = AccessTokenRequest() + pkj = PrivateKeyJWT() + _ca = assertion_jwt( + token_service.client_get("service_context").client_id, kb_rsa.get('RSA'), + "https://example.com/token", 'RS256') + http_args = pkj.construct(request, client_assertion=_ca) + assert http_args == {} + assert request['client_assertion'] == _ca + assert request['client_assertion_type'] == JWT_BEARER + + +class TestClientSecretJWT_TE(object): + def test_client_secret_jwt(self, entity): + _service_context = entity.client_get("service_context") + _service_context.token_endpoint = "https://example.com/token" + + _service_context.provider_info = { + 'issuer': 'https://example.com/', + 'token_endpoint': "https://example.com/token"} + + _service_context.registration_response = { + 'token_endpoint_auth_signing_alg': "HS256"} + + csj = ClientSecretJWT() + request = AccessTokenRequest() + + csj.construct(request, + service=entity.client_get("service", 'accesstoken'), + authn_endpoint='token_endpoint') + assert request["client_assertion_type"] == JWT_BEARER + assert "client_assertion" in request + cas = request["client_assertion"] + + _kj = KeyJar() + _kj.add_symmetric(_service_context.client_id, _service_context.client_secret, ['sig']) + jso = JWT(key_jar=_kj, sign_alg='HS256').unpack(cas) + assert _eq(jso.keys(), ["aud", "iss", "sub", "exp", "iat", 'jti']) + + _rj = JWS(alg='HS256') + info = _rj.verify_compact( + cas, _kj.get_signing_key(issuer_id=_service_context.client_id)) + + assert _eq(info.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) + assert info['aud'] == [_service_context.provider_info['token_endpoint']] + + def test_get_key_by_kid(self, entity): + _service_context = entity.client_get("service_context") + _service_context.token_endpoint = "https://example.com/token" + + _service_context.provider_info = { + 'issuer': 'https://example.com/', + 'token_endpoint': "https://example.com/token"} + + _service_context.registration_response = { + 'token_endpoint_auth_signing_alg': "HS256"} + + csj = ClientSecretJWT() + request = AccessTokenRequest() + + # get a kid + _keys = _service_context.keyjar.get_issuer_keys("") + kid = _keys[0].kid + token_service = entity.client_get("service", 'accesstoken') + csj.construct(request, service=token_service, + authn_endpoint='token_endpoint', kid=kid) + assert "client_assertion" in request + + def test_get_key_by_kid_fail(self, entity): + token_service = entity.client_get("service", 'accesstoken') + _service_context = token_service.client_get("service_context") + _service_context.token_endpoint = "https://example.com/token" + + _service_context.provider_info = { + 'issuer': 'https://example.com/', + 'token_endpoint': "https://example.com/token"} + + _service_context.registration_response = { + 'token_endpoint_auth_signing_alg': "HS256"} + + csj = ClientSecretJWT() + request = AccessTokenRequest() + + # get a kid + kid = "abcdefgh" + with pytest.raises(MissingKey): + csj.construct(request, service=token_service, + authn_endpoint='token_endpoint', kid=kid) + + def test_get_audience_and_algorithm_default_alg(self, entity): + _service_context = entity.client_get("service_context") + _service_context.token_endpoint = "https://example.com/token" + + _service_context.provider_info = { + 'issuer': 'https://example.com/', + 'token_endpoint': "https://example.com/token"} + + _service_context.registration_response = { + 'token_endpoint_auth_signing_alg': "HS256"} + + csj = ClientSecretJWT() + request = AccessTokenRequest() + + _service_context.registration_response = {} + + token_service = entity.client_get("service", 'accesstoken') + + # Add a RSA key to be able to handle default + _kb = KeyBundle() + _rsa_key = new_rsa_key() + _kb.append(_rsa_key) + _service_context.keyjar.add_kb("", _kb) + # Since I have a RSA key this doesn't fail + csj.construct(request, service=token_service, authn_endpoint='token_endpoint') + + _jws = factory(request["client_assertion"]) + assert _jws.jwt.headers["alg"] == "RS256" + assert _jws.jwt.headers["kid"] == _rsa_key.kid + + # By client preferences + request = AccessTokenRequest() + _service_context.client_preferences = {"token_endpoint_auth_signing_alg": "RS512"} + csj.construct(request, service=token_service, authn_endpoint='token_endpoint') + + _jws = factory(request["client_assertion"]) + assert _jws.jwt.headers["alg"] == "RS512" + assert _jws.jwt.headers["kid"] == _rsa_key.kid + + # Use provider information is everything else fails + request = AccessTokenRequest() + _service_context.client_preferences = {} + _service_context.provider_info["token_endpoint_auth_signing_alg_values_supported"] = [ + "ES256", "RS256"] + csj.construct(request, service=token_service, authn_endpoint='token_endpoint') + + _jws = factory(request["client_assertion"]) + # Should be RS256 since I have no key for ES256 + assert _jws.jwt.headers["alg"] == "RS256" + assert _jws.jwt.headers["kid"] == _rsa_key.kid + + +class TestClientSecretJWT_UI(object): + def test_client_secret_jwt(self, entity): + access_token_service = entity.client_get("service", 'accesstoken') + + _service_context = access_token_service.client_get("service_context") + _service_context.token_endpoint = "https://example.com/token" + _service_context.provider_info = {'issuer': 'https://example.com/', + 'token_endpoint': "https://example.com/token"} + + csj = ClientSecretJWT() + request = AccessTokenRequest() + + csj.construct(request, service=access_token_service, + algorithm="HS256", authn_endpoint='userinfo') + assert request["client_assertion_type"] == JWT_BEARER + assert "client_assertion" in request + cas = request["client_assertion"] + + _kj = KeyJar() + _kj.add_symmetric(_service_context.client_id, + _service_context.client_secret, + usage=['sig']) + jso = JWT(key_jar=_kj, sign_alg='HS256').unpack(cas) + assert _eq(jso.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) + + _rj = JWS(alg='HS256') + info = _rj.verify_compact( + cas, + _kj.get_signing_key(issuer_id=_service_context.client_id)) + + assert _eq(info.keys(), ["aud", "iss", "sub", "jti", "exp", "iat"]) + assert info['aud'] == [_service_context.provider_info['issuer']] + + +class TestValidClientInfo(object): + def test_valid_service_context(self, entity): + _service_context = entity.client_get("service_context") + + _now = 123456 # At some time + # Expiration time missing or 0, client_secret never expires + # service_context.client_secret_expires_at + assert valid_service_context(_service_context, _now) + assert valid_service_context(_service_context, _now) + # Expired secret + _service_context.client_secret_expires_at = 1 + assert valid_service_context(_service_context, _now) is not True + + _service_context.client_secret_expires_at = 123455 + assert valid_service_context(_service_context, _now) is not True + + # Valid secret + _service_context.client_secret_expires_at = 123460 + assert valid_service_context(_service_context, _now) + + +def test_bearer_auth(): + request = ResourceRequest(access_token="12345678") + authn = "" + assert bearer_auth(request, authn) == "12345678" + + request = ResourceRequest() + authn = "Bearer abcdefghijklm" + assert bearer_auth(request, authn) == "abcdefghijklm" + + request = ResourceRequest() + authn = "" + with pytest.raises(ValueError): + bearer_auth(request, authn) diff --git a/tests/test_10_oauth2_service.py b/tests/test_10_oauth2_service.py new file mode 100644 index 0000000..6e5baa7 --- /dev/null +++ b/tests/test_10_oauth2_service.py @@ -0,0 +1,264 @@ +from oidcmsg.oauth2 import AccessTokenRequest +from oidcmsg.oauth2 import AccessTokenResponse +from oidcmsg.oauth2 import AuthorizationRequest +from oidcmsg.oauth2 import AuthorizationResponse +from oidcmsg.oauth2 import Message +import pytest + +from oidcrp.entity import Entity + + +class Response(object): + def __init__(self, status_code, text, headers=None): + self.status_code = status_code + self.text = text + self.headers = headers or {"content-type": "text/plain"} + + +CLIENT_CONF = { + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'behaviour': {'response_types': ['code']} +} + + +class TestAuthorization(object): + @pytest.fixture(autouse=True) + def create_service(self): + self.entity = Entity(config=CLIENT_CONF) + self.auth_service = self.entity.client_get("service",'authorization') + + def test_construct(self): + req_args = {'foo': 'bar'} + _req = self.auth_service.construct(request_args=req_args, state='state') + assert isinstance(_req, AuthorizationRequest) + assert set(_req.keys()) == {'client_id', 'redirect_uri', 'foo', 'state'} + _context = self.entity.client_get("service_context") + assert _context.state.get_state('state') + _item = _context.state.get_item(AuthorizationRequest, 'auth_request', 'state') + assert _item.to_dict() == { + 'foo': 'bar', 'redirect_uri': 'https://example.com/cli/authz_cb', + 'state': 'state', 'client_id': 'client_id' + } + + def test_get_request_parameters(self): + req_args = {'response_type': 'code'} + self.auth_service.endpoint = 'https://example.com/authorize' + _info = self.auth_service.get_request_parameters(request_args=req_args, + state='state') + assert set(_info.keys()) == {'url', 'method', 'request'} + msg = AuthorizationRequest().from_urlencoded( + self.auth_service.get_urlinfo(_info['url'])) + assert msg.to_dict() == { + 'client_id': 'client_id', + 'redirect_uri': 'https://example.com/cli/authz_cb', + 'response_type': 'code', 'state': 'state' + } + + def test_request_init(self): + req_args = {'response_type': 'code', 'state': "state"} + self.auth_service.endpoint = 'https://example.com/authorize' + _info = self.auth_service.get_request_parameters(request_args=req_args) + assert set(_info.keys()) == {'url', 'method', 'request'} + msg = AuthorizationRequest().from_urlencoded( + self.auth_service.get_urlinfo(_info['url'])) + assert msg.to_dict() == { + 'client_id': 'client_id', + 'redirect_uri': 'https://example.com/cli/authz_cb', + 'response_type': 'code', 'state': 'state' + } + + def test_response(self): + _state = "today" + req_args = {'response_type': 'code', 'state': _state} + self.auth_service.endpoint = 'https://example.com/authorize' + _info = self.auth_service.get_request_parameters(request_args=req_args) + assert set(_info.keys()) == {'url', 'method', 'request'} + msg = AuthorizationRequest().from_urlencoded( + self.auth_service.get_urlinfo(_info['url'])) + self.auth_service.client_get("service_context").state.store_item(msg, "auth_request", _state) + + resp1 = AuthorizationResponse(code="auth_grant", state=_state) + response = self.auth_service.parse_response( + resp1.to_urlencoded(), "urlencoded", state=_state) + self.auth_service.update_service_context(response, key=_state) + assert self.auth_service.client_get("service_context").state.get_state(_state) + + +class TestAccessTokenRequest(object): + @pytest.fixture(autouse=True) + def create_service(self): + client_config = { + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'] + } + entity = Entity(config=client_config) + self.token_service = entity.client_get("service", "accesstoken") + auth_request = AuthorizationRequest( + redirect_uri='https://example.com/cli/authz_cb', + state='state' + ) + auth_response = AuthorizationResponse(code='access_code') + _state = self.token_service.client_get("service_context").state + _state.store_item(auth_request, 'auth_request', 'state') + _state.store_item(auth_response, 'auth_response', 'state') + + def test_construct(self): + req_args = {'foo': 'bar', 'state': 'state'} + + _req = self.token_service.construct(request_args=req_args) + assert isinstance(_req, AccessTokenRequest) + assert set(_req.keys()) == {'client_id', 'foo', 'grant_type', + 'client_secret', 'code', 'state', + 'redirect_uri'} + + def test_construct_2(self): + # Note that state as a argument means it will not end up in the + # request + req_args = {'foo': 'bar'} + + _req = self.token_service.construct(request_args=req_args, + state='state') + assert isinstance(_req, AccessTokenRequest) + assert set(_req.keys()) == {'client_id', 'foo', 'grant_type', + 'client_secret', 'code', 'state', + 'redirect_uri'} + + def test_get_request_parameters(self): + req_args = { + 'redirect_uri': 'https://example.com/cli/authz_cb', + 'code': 'access_code' + } + self.token_service.endpoint = 'https://example.com/authorize' + _info = self.token_service.get_request_parameters( + request_args=req_args, state='state', + authn_method='client_secret_basic') + assert set(_info.keys()) == {'headers', 'body', 'url', 'method', 'request'} + assert _info['url'] == 'https://example.com/authorize' + assert 'Authorization' in _info['headers'] + msg = AccessTokenRequest().from_urlencoded( + self.token_service.get_urlinfo(_info['body'])) + assert msg.to_dict() == { + 'client_id': 'client_id', 'code': 'access_code', + 'grant_type': 'authorization_code', 'state': 'state', + 'redirect_uri': 'https://example.com/cli/authz_cb' + } + assert 'client_secret' not in msg + + def test_request_init(self): + req_args = { + 'redirect_uri': 'https://example.com/cli/authz_cb', + 'code': 'access_code' + } + self.token_service.endpoint = 'https://example.com/authorize' + + _info = self.token_service.get_request_parameters(request_args=req_args, + state='state') + assert set(_info.keys()) == {'body', 'url', 'headers', 'method', 'request'} + assert _info['url'] == 'https://example.com/authorize' + msg = AccessTokenRequest().from_urlencoded( + self.token_service.get_urlinfo(_info['body'])) + assert msg.to_dict() == { + 'client_id': 'client_id', 'state': 'state', + 'code': 'access_code', 'grant_type': 'authorization_code', + 'redirect_uri': 'https://example.com/cli/authz_cb' + } + + +class TestProviderInfo(object): + @pytest.fixture(autouse=True) + def create_service(self): + self._iss = 'https://example.com/as' + + client_config = { + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + "client_preferences": + { + "application_type": "web", + "application_name": "rphandler", + "contacts": ["ops@example.org"], + "response_types": ["code"], + "scope": ["openid", "profile", "email", "address", "phone"], + "token_endpoint_auth_method": "client_secret_basic", + }, + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'issuer': self._iss + } + entity = Entity(config=client_config) + self.auth_service = entity.client_get("service",'provider_info') + self.auth_service.endpoint = '{}/.well-known/openid-configuration'.format(self._iss) + + def test_construct(self): + _req = self.auth_service.construct() + assert isinstance(_req, Message) + assert len(_req) == 0 + + def test_get_request_parameters(self): + _info = self.auth_service.get_request_parameters() + assert set(_info.keys()) == {'url', 'method'} + assert _info['url'] == '{}/.well-known/openid-configuration'.format( + self._iss) + + +class TestRefreshAccessTokenRequest(object): + @pytest.fixture(autouse=True) + def create_service(self): + client_config = { + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'] + } + entity = Entity(config=client_config) + self.refresh_service = entity.client_get("service",'refresh_token') + auth_response = AuthorizationResponse(code='access_code') + token_response = AccessTokenResponse(access_token='bearer_token', + refresh_token='refresh') + _state = self.refresh_service.client_get("service_context").state + _state.store_item(auth_response, 'auth_response', 'abcdef') + _state.store_item(token_response, 'token_response', 'abcdef') + self.refresh_service.endpoint = 'https://example.com/token' + + def test_construct(self): + _req = self.refresh_service.construct(state='abcdef') + assert isinstance(_req, Message) + assert len(_req) == 4 + assert set(_req.keys()) == {'client_id', 'client_secret', 'grant_type', + 'refresh_token'} + + def test_get_request_parameters(self): + _info = self.refresh_service.get_request_parameters(state='abcdef') + assert set(_info.keys()) == {'url', 'body', 'headers', 'method', 'request'} + + +def test_access_token_srv_conf(): + client_config = { + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'] + } + entity = Entity(config=client_config) + token_service = entity.client_get("service",'accesstoken') + + _state_interface = token_service.client_get("service_context").state + _state_val = _state_interface.create_state(token_service.client_get("service_context").issuer) + auth_request = AuthorizationRequest(redirect_uri='https://example.com/cli/authz_cb', + state=_state_val) + + _state_interface.store_item(auth_request, "auth_request", _state_val) + auth_response = AuthorizationResponse(code='access_code') + _state_interface.store_item(auth_response, "auth_response", _state_val) + + req_args = { + 'redirect_uri': 'https://example.com/cli/authz_cb', + 'code': 'access_code' + } + token_service.endpoint = 'https://example.com/authorize' + _info = token_service.get_request_parameters(request_args=req_args, state=_state_val) + + assert _info + msg = AccessTokenRequest().from_urlencoded(_info['body']) + # client_secret_basic by default + assert 'client_secret' not in msg diff --git a/tests/test_11_oauth2.py b/tests/test_11_oauth2.py index 14d709a..d3d6b3a 100644 --- a/tests/test_11_oauth2.py +++ b/tests/test_11_oauth2.py @@ -2,7 +2,6 @@ import sys import time -import pytest from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.key_bundle import KeyBundle from oidcmsg.oauth2 import AccessTokenRequest @@ -13,9 +12,10 @@ from oidcmsg.oauth2 import ResponseMessage from oidcmsg.oidc import IdToken from oidcmsg.time_util import utc_time_sans_frac -from oidcservice.exception import OidcServiceError -from oidcservice.exception import ParseError +import pytest +from oidcrp.exception import OidcServiceError +from oidcrp.exception import ParseError from oidcrp.oauth2 import Client sys.path.insert(0, '.') @@ -33,7 +33,6 @@ iat=time.time()) - class MockResponse(): def __init__(self, status_code, text, headers=None): self.status_code = status_code @@ -60,9 +59,8 @@ def test_construct_authorization_request(self): 'response_type': ['code'] } - self.client.session_interface.create_state('issuer', key='ABCDE') - msg = self.client.service['authorization'].construct( - request_args=req_args) + self.client.client_get("service_context").state.create_state('issuer', key='ABCDE') + msg = self.client.client_get("service",'authorization').construct(request_args=req_args) assert isinstance(msg, AuthorizationRequest) assert msg['client_id'] == 'client_1' assert msg['redirect_uri'] == 'https://example.com/auth_cb' @@ -70,60 +68,56 @@ def test_construct_authorization_request(self): def test_construct_accesstoken_request(self): # Bind access code to state req_args = {} - - self.client.session_interface.create_state('issuer', 'ABCDE') + _context = self.client.client_get("service_context") + _context.state.create_state('issuer', 'ABCDE') auth_request = AuthorizationRequest( redirect_uri='https://example.com/cli/authz_cb', state='ABCDE' ) - self.client.session_interface.store_item(auth_request, 'auth_request', - 'ABCDE') + _context.state.store_item(auth_request, 'auth_request', 'ABCDE') auth_response = AuthorizationResponse(code='access_code') - self.client.session_interface.store_item(auth_response, + self.client.client_get("service_context").state.store_item(auth_response, 'auth_response', 'ABCDE') - msg = self.client.service['accesstoken'].construct( + msg = self.client.client_get("service",'accesstoken').construct( request_args=req_args, state='ABCDE') assert isinstance(msg, AccessTokenRequest) assert msg.to_dict() == { 'client_id': 'client_1', - 'code': 'access_code', 'client_secret': 'abcdefghijklmnop', 'grant_type': 'authorization_code', - 'redirect_uri': - 'https://example.com/cli/authz_cb', - 'state': 'ABCDE' + 'state': 'ABCDE', + 'code': 'access_code', + 'redirect_uri': 'https://example.com/cli/authz_cb' } def test_construct_refresh_token_request(self): - self.client.session_interface.create_state('issuer', 'ABCDE') + _context = self.client.client_get("service_context") + _context.state.create_state('issuer', 'ABCDE') auth_request = AuthorizationRequest( redirect_uri='https://example.com/cli/authz_cb', state='state' ) - self.client.session_interface.store_item(auth_request, 'auth_request', - 'ABCDE') + _context.state.store_item(auth_request, 'auth_request','ABCDE') auth_response = AuthorizationResponse(code='access_code') - self.client.session_interface.store_item(auth_response, - 'auth_response', 'ABCDE') + _context.state.store_item(auth_response,'auth_response', 'ABCDE') token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - self.client.session_interface.store_item(token_response, - 'token_response', 'ABCDE') + _context.state.store_item(token_response, 'token_response', 'ABCDE') req_args = {} - msg = self.client.service['refresh_token'].construct( + msg = self.client.client_get("service",'refresh_token').construct( request_args=req_args, state='ABCDE') assert isinstance(msg, RefreshAccessTokenRequest) assert msg.to_dict() == { @@ -137,7 +131,7 @@ def test_error_response(self): err = ResponseMessage(error='Illegal') http_resp = MockResponse(400, err.to_urlencoded()) resp = self.client.parse_request_response( - self.client.service['authorization'], http_resp) + self.client.client_get("service",'authorization'), http_resp) assert resp['error'] == 'Illegal' assert resp['status_code'] == 400 @@ -147,7 +141,7 @@ def test_error_response_500(self): http_resp = MockResponse(500, err.to_urlencoded()) with pytest.raises(ParseError): self.client.parse_request_response( - self.client.service['authorization'], http_resp) + self.client.client_get("service",'authorization'), http_resp) def test_error_response_2(self): err = ResponseMessage(error='Illegal') @@ -157,4 +151,4 @@ def test_error_response_2(self): with pytest.raises(OidcServiceError): self.client.parse_request_response( - self.client.service['authorization'], http_resp) + self.client.client_get("service",'authorization'), http_resp) diff --git a/tests/test_13_oidc_service.py b/tests/test_13_oidc_service.py new file mode 100644 index 0000000..c3a30dc --- /dev/null +++ b/tests/test_13_oidc_service.py @@ -0,0 +1,945 @@ +import json +import os + +from cryptojwt.exception import UnsupportedAlgorithm +from cryptojwt.jws import jws +from cryptojwt.jws.utils import left_hash +from cryptojwt.jwt import JWT +from cryptojwt.key_jar import build_keyjar +from cryptojwt.key_jar import init_key_jar +from oidcmsg.oidc import AccessTokenRequest +from oidcmsg.oidc import AccessTokenResponse +from oidcmsg.oidc import AuthorizationRequest +from oidcmsg.oidc import AuthorizationResponse +from oidcmsg.oidc import IdToken +from oidcmsg.oidc import Message +from oidcmsg.oidc import OpenIDSchema +from oidcmsg.oidc import RegistrationRequest +from oidcmsg.oidc import verified_claim_name +from oidcmsg.oidc.session import CheckIDRequest +from oidcmsg.oidc.session import CheckSessionRequest +from oidcmsg.oidc.session import EndSessionRequest +import pytest +import responses + +from oidcrp.defaults import DEFAULT_OIDC_SERVICES +from oidcrp.entity import Entity +from oidcrp.exception import ParameterError +from oidcrp.oidc.registration import add_jwks_uri_or_jwks +from oidcrp.oidc.registration import response_types_to_grant_types + + +class Response(object): + def __init__(self, status_code, text, headers=None): + self.status_code = status_code + self.text = text + self.headers = headers or {"content-type": "text/plain"} + + +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +_dirname = os.path.dirname(os.path.abspath(__file__)) + +ISS = 'https://example.com' + +CLI_KEY = init_key_jar(public_path='{}/pub_client.jwks'.format(_dirname), + private_path='{}/priv_client.jwks'.format(_dirname), + key_defs=KEYSPEC, issuer_id='client_id', read_only=False) + +ISS_KEY = init_key_jar(public_path='{}/pub_iss.jwks'.format(_dirname), + private_path='{}/priv_iss.jwks'.format(_dirname), + key_defs=KEYSPEC, issuer_id=ISS, read_only=False) + +ISS_KEY.import_jwks_as_json(open('{}/pub_client.jwks'.format(_dirname)).read(), + 'client_id') + +CLI_KEY.import_jwks_as_json(open('{}/pub_iss.jwks'.format(_dirname)).read(), ISS) + + +class TestAuthorization(object): + @pytest.fixture(autouse=True) + def create_request(self): + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'] + } + entity = Entity(services=DEFAULT_OIDC_SERVICES, keyjar=CLI_KEY, config=client_config) + entity.client_get("service_context").issuer = 'https://example.com' + self.service = entity.client_get("service",'authorization') + + def test_construct(self): + req_args = { + 'foo': 'bar', 'response_type': 'code', + 'state': 'state' + } + _req = self.service.construct(request_args=req_args) + assert isinstance(_req, AuthorizationRequest) + assert set(_req.keys()) == {'redirect_uri', 'foo', 'client_id', + 'response_type', 'scope', 'state', + 'nonce'} + + def test_construct_missing_openid_scope(self): + req_args = { + 'foo': 'bar', 'response_type': 'code', + 'state': 'state', 'scope': ['email'] + } + _req = self.service.construct(request_args=req_args) + assert isinstance(_req, AuthorizationRequest) + assert set(_req.keys()) == {'redirect_uri', 'foo', 'client_id', + 'response_type', 'scope', 'state', + 'nonce'} + assert _req['scope'] == ['email', 'openid'] + + def test_construct_token(self): + req_args = { + 'foo': 'bar', 'response_type': 'token', + 'state': 'state' + } + _req = self.service.construct(request_args=req_args) + assert isinstance(_req, AuthorizationRequest) + assert set(_req.keys()) == {'redirect_uri', 'foo', 'client_id', + 'response_type', 'scope', 'state'} + + def test_construct_token_nonce(self): + req_args = { + 'foo': 'bar', 'response_type': 'token', 'nonce': 'nonce', + 'state': 'state' + } + _req = self.service.construct(request_args=req_args) + assert isinstance(_req, AuthorizationRequest) + assert set(_req.keys()) == {'redirect_uri', 'foo', 'client_id', + 'response_type', 'scope', 'state', 'nonce'} + assert _req['nonce'] == 'nonce' + + def test_get_request_parameters(self): + req_args = {'response_type': 'code', 'state': 'state'} + self.service.endpoint = 'https://example.com/authorize' + _info = self.service.get_request_parameters(request_args=req_args) + assert set(_info.keys()) == {'url', 'method', 'request'} + msg = AuthorizationRequest().from_urlencoded( + self.service.get_urlinfo(_info['url'])) + assert set(msg.keys()) == {'response_type', 'state', 'client_id', + 'nonce', 'redirect_uri', 'scope'} + + def test_request_init(self): + req_args = {'response_type': 'code', 'state': 'state'} + self.service.endpoint = 'https://example.com/authorize' + _info = self.service.get_request_parameters(request_args=req_args) + assert set(_info.keys()) == {'url', 'method', 'request'} + msg = AuthorizationRequest().from_urlencoded( + self.service.get_urlinfo(_info['url'])) + assert set(msg.keys()) == {'client_id', 'scope', 'response_type', + 'state', 'redirect_uri', 'nonce'} + + def test_request_init_request_method(self): + req_args = {'response_type': 'code', 'state': 'state'} + self.service.endpoint = 'https://example.com/authorize' + _info = self.service.get_request_parameters(request_args=req_args, + request_method='value') + assert set(_info.keys()) == {'url', 'method', 'request'} + msg = AuthorizationRequest().from_urlencoded( + self.service.get_urlinfo(_info['url'])) + assert set(msg.to_dict()) == {'client_id', 'redirect_uri', 'request', + 'response_type', 'state', 'scope', 'nonce'} + _jws = jws.factory(msg['request']) + assert _jws + _resp = _jws.verify_compact( + msg['request'], + keys=ISS_KEY.get_signing_key(key_type='RSA', + issuer_id='client_id')) + assert _resp + assert set(_resp.keys()) == {'response_type', 'client_id', 'scope', + 'redirect_uri', 'state', 'nonce', 'iss', 'aud', 'iat'} + + def test_request_param(self): + req_args = {'response_type': 'code', 'state': 'state'} + self.service.endpoint = 'https://example.com/authorize' + + assert os.path.isfile(os.path.join(_dirname, 'request123456.jwt')) + + _context = self.service.client_get("service_context") + _context.registration_response = { + 'redirect_uris': ['https://example.com/cb'], + 'request_uris': ['https://example.com/request123456.jwt'] + } + _context.base_url = 'https://example.com/' + _info = self.service.get_request_parameters(request_args=req_args, + request_method='reference') + + assert set(_info.keys()) == {'url', 'method', 'request'} + + def test_update_service_context_no_idtoken(self): + req_args = {'response_type': 'code', 'state': 'state'} + self.service.endpoint = 'https://example.com/authorize' + _info = self.service.get_request_parameters(request_args=req_args) + resp = AuthorizationResponse(state='state', code='code') + self.service.update_service_context(resp, 'state') + + def test_update_service_context_with_idtoken(self): + req_args = {'response_type': 'code', 'state': 'state', 'nonce': 'nonce'} + self.service.endpoint = 'https://example.com/authorize' + _info = self.service.get_request_parameters(request_args=req_args) + # Build an ID Token + idt = JWT(key_jar=ISS_KEY, iss=ISS, lifetime=3600) + payload = {'sub': '123456789', 'aud': ['client_id'], 'nonce': 'nonce'} + # have to calculate c_hash + alg = 'RS256' + halg = "HS%s" % alg[-3:] + payload["c_hash"] = left_hash('code', halg) + + _idt = idt.pack(payload) + resp = AuthorizationResponse(state='state', code='code', id_token=_idt) + resp = self.service.parse_response(resp.to_urlencoded()) + self.service.update_service_context(resp, 'state') + + def test_update_service_context_with_idtoken_wrong_nonce(self): + req_args = {'response_type': 'code', 'state': 'state', 'nonce': 'nonce'} + self.service.endpoint = 'https://example.com/authorize' + _info = self.service.get_request_parameters(request_args=req_args) + # Build an ID Token + idt = JWT(ISS_KEY, iss=ISS, lifetime=3600) + payload = { + 'sub': '123456789', 'aud': ['client_id'], + 'nonce': 'nonce' + } + # have to calculate c_hash + alg = 'RS256' + halg = "HS%s" % alg[-3:] + payload["c_hash"] = left_hash('code', halg) + + _idt = idt.pack(payload) + resp = AuthorizationResponse(state='state', code='code', id_token=_idt) + resp = self.service.parse_response(resp.to_urlencoded()) + with pytest.raises(ParameterError): + self.service.update_service_context(resp, 'state2') + + def test_update_service_context_with_idtoken_missing_nonce(self): + req_args = {'response_type': 'code', 'state': 'state', 'nonce': 'nonce'} + self.service.endpoint = 'https://example.com/authorize' + self.service.get_request_parameters(request_args=req_args) + # Build an ID Token + idt = JWT(ISS_KEY, iss=ISS, lifetime=3600) + payload = {'sub': '123456789', 'aud': ['client_id']} + # have to calculate c_hash + alg = 'RS256' + halg = "HS%s" % alg[-3:] + payload["c_hash"] = left_hash('code', halg) + + _idt = idt.pack(payload) + resp = AuthorizationResponse(state='state', code='code', id_token=_idt) + resp = self.service.parse_response(resp.to_urlencoded()) + with pytest.raises(ValueError): + self.service.update_service_context(resp, 'state') + + @pytest.mark.parametrize("allow_sign_alg_none", [True, False]) + def test_allow_unsigned_idtoken(self, allow_sign_alg_none): + req_args = {'response_type': 'code', 'state': 'state', 'nonce': 'nonce'} + self.service.endpoint = 'https://example.com/authorize' + self.service.get_request_parameters(request_args=req_args) + # Build an ID Token + idt = JWT(ISS_KEY, iss=ISS, lifetime=3600, sign_alg='none') + payload = {'sub': '123456789', 'aud': ['client_id']} + _idt = idt.pack(payload) + self.service.client_get("service_context").behaviour["verify_args"] = { + "allow_sign_alg_none": allow_sign_alg_none + } + resp = AuthorizationResponse(state='state', code='code', id_token=_idt) + if allow_sign_alg_none: + resp = self.service.parse_response(resp.to_urlencoded()) + else: + with pytest.raises(UnsupportedAlgorithm): + self.service.parse_response(resp.to_urlencoded()) + + +class TestAuthorizationCallback(object): + @pytest.fixture(autouse=True) + def create_request(self): + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'callback': { + 'code': 'https://example.com/cli/authz_cb', + 'implicit': 'https://example.com/cli/authz_im_cb', + 'form_post': 'https://example.com/cli/authz_fp_cb' + } + } + entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) + entity.client_get("service_context").issuer = 'https://example.com' + self.service = entity.client_get("service",'authorization') + + def test_construct_code(self): + req_args = { + 'foo': 'bar', 'response_type': 'code', + 'state': 'state' + } + _req = self.service.construct(request_args=req_args) + assert isinstance(_req, AuthorizationRequest) + assert set(_req.keys()) == {'redirect_uri', 'foo', 'client_id', + 'response_type', 'scope', 'state', + 'nonce'} + assert _req['redirect_uri'] == 'https://example.com/cli/authz_cb' + + def test_construct_implicit(self): + req_args = { + 'foo': 'bar', + 'response_type': 'id_token token', + 'state': 'state', + 'nonce': "nonce" + } + _req = self.service.construct(request_args=req_args) + assert isinstance(_req, AuthorizationRequest) + assert set(_req.keys()) == {'redirect_uri', 'foo', 'client_id', + 'response_type', 'scope', 'state', + 'nonce'} + assert _req['redirect_uri'] == 'https://example.com/cli/authz_im_cb' + + def test_construct_form_post(self): + req_args = { + 'foo': 'bar', 'response_type': 'code id_token token', + 'state': 'state', 'response_mode': 'form_post', + 'nonce': "nonce" + } + _req = self.service.construct(request_args=req_args) + assert isinstance(_req, AuthorizationRequest) + assert set(_req.keys()) == {'redirect_uri', 'foo', 'client_id', + 'response_type', 'scope', 'state', + 'nonce', 'response_mode'} + assert _req['redirect_uri'] == 'https://example.com/cli/authz_fp_cb' + + +class TestAccessTokenRequest(object): + @pytest.fixture(autouse=True) + def create_request(self): + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'] + } + entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) + entity.client_get("service_context").issuer = 'https://example.com' + self.service = entity.client_get("service",'accesstoken') + + # add some history + auth_request = AuthorizationRequest( + redirect_uri='https://example.com/cli/authz_cb', + state='state', response_type='code').to_json() + + _stat_interface = entity.client_get("service_context").state + _stat_interface.store_item(auth_request, "auth_request", 'state') + + auth_response = AuthorizationResponse(code='access_code').to_json() + _stat_interface.store_item(auth_response, "auth_response", 'state') + + def test_construct(self): + req_args = {'foo': 'bar'} + + _req = self.service.construct(request_args=req_args, state='state') + assert isinstance(_req, AccessTokenRequest) + assert set(_req.keys()) == {'client_id', 'foo', 'grant_type', + 'client_secret', 'code', 'state', + 'redirect_uri'} + + def test_get_request_parameters(self): + req_args = { + 'redirect_uri': 'https://example.com/cli/authz_cb', + 'code': 'access_code' + } + self.service.endpoint = 'https://example.com/authorize' + _info = self.service.get_request_parameters(request_args=req_args, + state='state', + authn_method='client_secret_basic') + assert set(_info.keys()) == {'body', 'url', 'headers', 'method', 'request'} + assert _info['url'] == 'https://example.com/authorize' + msg = AccessTokenRequest().from_urlencoded( + self.service.get_urlinfo(_info['body'])) + assert msg.to_dict() == { + 'client_id': 'client_id', 'code': 'access_code', + 'grant_type': 'authorization_code', 'state': 'state', + 'redirect_uri': 'https://example.com/cli/authz_cb' + } + + def test_request_init(self): + req_args = { + 'redirect_uri': 'https://example.com/cli/authz_cb', + 'code': 'access_code' + } + self.service.endpoint = 'https://example.com/authorize' + + _info = self.service.get_request_parameters(request_args=req_args, + state='state') + assert set(_info.keys()) == {'body', 'url', 'headers', 'method', 'request'} + assert _info['url'] == 'https://example.com/authorize' + msg = AccessTokenRequest().from_urlencoded( + self.service.get_urlinfo(_info['body'])) + assert msg.to_dict() == { + 'client_id': 'client_id', 'code': 'access_code', + 'grant_type': 'authorization_code', 'state': 'state', + 'redirect_uri': 'https://example.com/cli/authz_cb' + } + + def test_id_token_nonce_match(self): + _state_interface = self.service.client_get("service_context").state + _state_interface.store_nonce2state('nonce', 'state') + resp = AccessTokenResponse() + resp[verified_claim_name('id_token')] = {'nonce': 'nonce'} + _state_interface.store_nonce2state('nonce2', 'state2') + with pytest.raises(ParameterError): + self.service.update_service_context(resp, key='state2') + + +class TestProviderInfo(object): + @pytest.fixture(autouse=True) + def create_service(self): + self._iss = ISS + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'issuer': self._iss, + 'client_preferences': + { + "application_type": "web", + "application_name": "rphandler", + "contacts": ["ops@example.org"], + "response_types": ["code"], + "scope": ["openid", "profile", "email", + "address", "phone"], + "token_endpoint_auth_method": "client_secret_basic", + } + } + entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) + entity.client_get("service_context").issuer = 'https://example.com' + self.service = entity.client_get("service",'provider_info') + + def test_construct(self): + _req = self.service.construct() + assert isinstance(_req, Message) + assert len(_req) == 0 + + def test_get_request_parameters(self): + _info = self.service.get_request_parameters() + assert set(_info.keys()) == {'url', 'method'} + assert _info['url'] == '{}/.well-known/openid-configuration'.format( + self._iss) + + def test_post_parse(self): + OP_BASEURL = ISS + + provider_info_response = { + "version": "3.0", + "token_endpoint_auth_methods_supported": [ + "client_secret_post", "client_secret_basic", + "client_secret_jwt", "private_key_jwt"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, + "require_request_uri_registration": True, + "grant_types_supported": ["authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token"], + "response_types_supported": ["code", "id_token", + "id_token token", + "code id_token", + "code token", + "code id_token token"], + "response_modes_supported": ["query", "fragment", + "form_post"], + "subject_types_supported": ["public", "pairwise"], + "claim_types_supported": ["normal", "aggregated", + "distributed"], + "claims_supported": ["birthdate", "address", + "nickname", "picture", "website", + "email", "gender", "sub", + "phone_number_verified", + "given_name", "profile", + "phone_number", "updated_at", + "middle_name", "name", "locale", + "email_verified", + "preferred_username", "zoneinfo", + "family_name"], + "scopes_supported": ["openid", "profile", "email", + "address", "phone", + "offline_access", "openid"], + "userinfo_signing_alg_values_supported": [ + "RS256", "RS384", "RS512", + "ES256", "ES384", "ES512", + "HS256", "HS384", "HS512", + "PS256", "PS384", "PS512", "none"], + "id_token_signing_alg_values_supported": [ + "RS256", "RS384", "RS512", + "ES256", "ES384", "ES512", + "HS256", "HS384", "HS512", + "PS256", "PS384", "PS512", "none"], + "request_object_signing_alg_values_supported": [ + "RS256", "RS384", "RS512", "ES256", "ES384", + "ES512", "HS256", "HS384", "HS512", "PS256", + "PS384", "PS512", "none"], + "token_endpoint_auth_signing_alg_values_supported": [ + "RS256", "RS384", "RS512", "ES256", "ES384", + "ES512", "HS256", "HS384", "HS512", "PS256", + "PS384", "PS512"], + "userinfo_encryption_alg_values_supported": [ + "RSA1_5", "RSA-OAEP", "RSA-OAEP-256", + "A128KW", "A192KW", "A256KW", + "ECDH-ES", "ECDH-ES+A128KW", "ECDH-ES+A192KW", + "ECDH-ES+A256KW"], + "id_token_encryption_alg_values_supported": [ + "RSA1_5", "RSA-OAEP", "RSA-OAEP-256", + "A128KW", "A192KW", "A256KW", + "ECDH-ES", "ECDH-ES+A128KW", "ECDH-ES+A192KW", + "ECDH-ES+A256KW"], + "request_object_encryption_alg_values_supported": [ + "RSA1_5", "RSA-OAEP", "RSA-OAEP-256", "A128KW", + "A192KW", "A256KW", "ECDH-ES", "ECDH-ES+A128KW", + "ECDH-ES+A192KW", "ECDH-ES+A256KW"], + "userinfo_encryption_enc_values_supported": [ + "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", + "A128GCM", "A192GCM", "A256GCM"], + "id_token_encryption_enc_values_supported": [ + "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", + "A128GCM", "A192GCM", "A256GCM"], + "request_object_encryption_enc_values_supported": [ + "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", + "A128GCM", "A192GCM", "A256GCM"], + "acr_values_supported": ["PASSWORD"], + "issuer": OP_BASEURL, + "jwks_uri": "{}/static/jwks_tE2iLbOAqXhe8bqh.json".format(OP_BASEURL), + "authorization_endpoint": "{}/authorization".format(OP_BASEURL), + "token_endpoint": "{}/token".format(OP_BASEURL), + "userinfo_endpoint": "{}/userinfo".format(OP_BASEURL), + "registration_endpoint": "{}/registration".format(OP_BASEURL), + "end_session_endpoint": "{}/end_session".format(OP_BASEURL) + } + assert self.service.client_get("service_context").behaviour == {} + resp = self.service.post_parse_response(provider_info_response) + + iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) + with responses.RequestsMock() as rsps: + rsps.add("GET", resp["jwks_uri"], + body=iss_jwks, status=200) + + self.service.update_service_context(resp) + + assert self.service.client_get("service_context").behaviour == { + 'token_endpoint_auth_method': 'client_secret_basic', + 'response_types': ['code'], + 'application_type': 'web', + 'application_name': 'rphandler', + 'contacts': ['ops@example.org'], + 'scope': ['openid', 'profile', 'email', 'address', 'phone'] + } + + def test_post_parse_2(self): + OP_BASEURL = ISS + + provider_info_response = { + "version": "3.0", + "token_endpoint_auth_methods_supported": [ + "client_secret_post", "client_secret_basic", + "client_secret_jwt", "private_key_jwt"], + "issuer": OP_BASEURL, + "jwks_uri": "{}/static/jwks_tE2iLbOAqXhe8bqh.json".format(OP_BASEURL), + "authorization_endpoint": "{}/authorization".format(OP_BASEURL), + "token_endpoint": "{}/token".format(OP_BASEURL), + "userinfo_endpoint": "{}/userinfo".format(OP_BASEURL), + "registration_endpoint": "{}/registration".format(OP_BASEURL), + "end_session_endpoint": "{}/end_session".format(OP_BASEURL) + } + assert self.service.client_get("service_context").behaviour == {} + resp = self.service.post_parse_response(provider_info_response) + + iss_jwks = ISS_KEY.export_jwks_as_json(issuer_id=ISS) + with responses.RequestsMock() as rsps: + rsps.add("GET", resp["jwks_uri"], + body=iss_jwks, status=200) + + self.service.update_service_context(resp) + + assert self.service.client_get("service_context").behaviour == { + 'token_endpoint_auth_method': 'client_secret_basic', + 'response_types': ['code'], + 'application_type': 'web', + 'application_name': 'rphandler', + 'contacts': ['ops@example.org'], + 'scope': ['openid', 'profile', 'email', 'address', 'phone'] + } + + +def test_response_types_to_grant_types(): + req_args = ['code'] + assert set( + response_types_to_grant_types(req_args)) == {'authorization_code'} + req_args = ['code', 'code id_token'] + assert set( + response_types_to_grant_types(req_args)) == {'authorization_code', + 'implicit'} + req_args = ['code', 'id_token code', 'code token id_token'] + assert set( + response_types_to_grant_types(req_args)) == {'authorization_code', + 'implicit'} + + +def create_jws(val): + lifetime = 3600 + + idts = IdToken(**val) + + return idts.to_jwt(key=ISS_KEY.get_signing_key('ec', issuer_id=ISS), + algorithm="ES256", lifetime=lifetime) + + +class TestRegistration(object): + @pytest.fixture(autouse=True) + def create_request(self): + self._iss = ISS + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'issuer': self._iss, 'requests_dir': 'requests', + 'base_url': 'https://example.com/cli/' + } + entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) + entity.client_get("service_context").issuer = 'https://example.com' + self.service = entity.client_get("service",'registration') + + def test_construct(self): + _req = self.service.construct() + assert isinstance(_req, RegistrationRequest) + assert len(_req) == 4 + + def test_config_with_post_logout(self): + self.service.client_get("service_context").register_args[ + 'post_logout_redirect_uris'] = ['https://example.com/post_logout'] + _req = self.service.construct() + assert isinstance(_req, RegistrationRequest) + assert len(_req) == 5 + assert 'post_logout_redirect_uris' in _req + + def test_config_with_required_request_uri(self): + _pi = self.service.client_get("service_context").provider_info + _pi['require_request_uri_registration'] = True + self.service.client_get("service_context").provider_info = _pi + _req = self.service.construct() + assert isinstance(_req, RegistrationRequest) + assert len(_req) == 5 + assert 'request_uris' in _req + + +class TestUserInfo(object): + @pytest.fixture(autouse=True) + def create_request(self): + self._iss = ISS + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'issuer': self._iss, 'requests_dir': 'requests', + 'base_url': 'https://example.com/cli/' + } + entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) + entity.client_get("service_context").issuer = 'https://example.com' + self.service = entity.client_get("service",'userinfo') + + entity.client_get("service_context").behaviour = { + 'userinfo_signed_response_alg': 'RS256', + "userinfo_encrypted_response_alg": "RSA-OAEP", + "userinfo_encrypted_response_enc": "A256GCM" + } + + _state_interface = self.service.client_get("service_context").state + # Add history + auth_response = AuthorizationResponse(code='access_code').to_json() + _state_interface.store_item(auth_response, 'auth_response', 'abcde') + + idtval = { + 'nonce': 'KUEYfRM2VzKDaaKD', 'sub': 'diana', + 'iss': ISS, 'aud': 'client_id' + } + idt = create_jws(idtval) + + ver_idt = IdToken().from_jwt(idt, CLI_KEY) + + token_response = AccessTokenResponse( + access_token='access_token', id_token=idt, + __verified_id_token=ver_idt).to_json() + _state_interface.store_item(token_response, 'token_response', 'abcde') + + def test_construct(self): + _req = self.service.construct(state='abcde') + assert isinstance(_req, Message) + assert len(_req) == 1 + assert 'access_token' in _req + + def test_unpack_simple_response(self): + resp = OpenIDSchema(sub='diana', given_name='Diana', + family_name='krall') + _resp = self.service.parse_response(resp.to_json(), + state='abcde') + assert _resp + + def test_unpack_aggregated_response(self): + claims = { + "address": { + "street_address": "1234 Hollywood Blvd.", + "locality": "Los Angeles", + "region": "CA", + "postal_code": "90210", + "country": "US" + }, + "phone_number": "+1 (555) 123-4567" + } + + srv = JWT(ISS_KEY, iss=ISS, sign_alg='ES256') + _jwt = srv.pack(payload=claims) + + resp = OpenIDSchema(sub='diana', given_name='Diana', + family_name='krall', + _claim_names={ + 'address': 'src1', + 'phone_number': 'src1' + }, + _claim_sources={'src1': {'JWT': _jwt}}) + + _resp = self.service.parse_response(resp.to_json(), state='abcde') + _resp = self.service.post_parse_response(_resp, state='abcde') + assert set(_resp.keys()) == {'sub', 'given_name', 'family_name', + '_claim_names', '_claim_sources', + 'address', 'phone_number'} + + def test_unpack_aggregated_response_missing_keys(self): + claims = { + "address": { + "street_address": "1234 Hollywood Blvd.", + "locality": "Los Angeles", + "region": "CA", + "postal_code": "90210", + "country": "US" + }, + "phone_number": "+1 (555) 123-4567" + } + + _keyjar = build_keyjar(KEYSPEC) + + srv = JWT(_keyjar, iss=ISS, sign_alg='ES256') + _jwt = srv.pack(payload=claims) + + resp = OpenIDSchema(sub='diana', given_name='Diana', + family_name='krall', + _claim_names={ + 'address': 'src1', + 'phone_number': 'src1' + }, + _claim_sources={'src1': {'JWT': _jwt}}) + + _resp = self.service.parse_response(resp.to_json(), state='abcde') + assert _resp + + def test_unpack_signed_response(self): + resp = OpenIDSchema(sub='diana', given_name='Diana', + family_name='krall', iss=ISS) + sk = ISS_KEY.get_signing_key('rsa', issuer_id=ISS) + alg = self.service.client_get("service_context").get_sign_alg('userinfo') + _resp = self.service.parse_response(resp.to_jwt(sk, algorithm=alg), + state='abcde', sformat='jwt') + assert _resp + + def test_unpack_encrypted_response(self): + # Add encryption key + _kj = build_keyjar([{"type": "RSA", "use": ["enc"]}], issuer_id='') + # Own key jar gets the private key + self.service.client_get("service_context").keyjar.import_jwks( + _kj.export_jwks(private=True), issuer_id='client_id') + # opponent gets the public key + ISS_KEY.import_jwks(_kj.export_jwks(), issuer_id='client_id') + + resp = OpenIDSchema(sub='diana', given_name='Diana', + family_name='krall', iss=ISS, aud='client_id') + enckey = ISS_KEY.get_encrypt_key('rsa', issuer_id='client_id') + algspec = self.service.client_get("service_context").get_enc_alg_enc( + self.service.service_name) + + enc_resp = resp.to_jwe(enckey, **algspec) + _resp = self.service.parse_response(enc_resp, state='abcde', + sformat='jwt') + assert _resp + + +class TestCheckSession(object): + @pytest.fixture(autouse=True) + def create_request(self): + self._iss = ISS + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'issuer': self._iss, 'requests_dir': 'requests', + 'base_url': 'https://example.com/cli/' + } + services = { + "checksession": { + 'class': 'oidcrp.oidc.check_session.CheckSession' + }} + entity = Entity(keyjar=CLI_KEY, config=client_config, services=services) + entity.client_get("service_context").issuer = 'https://example.com' + self.service = entity.client_get("service",'check_session') + + def test_construct(self): + _state_interface = self.service.client_get("service_context").state + _state_interface.store_item(json.dumps({'id_token': 'a.signed.jwt'}), + 'token_response', 'abcde') + _req = self.service.construct(state='abcde') + assert isinstance(_req, CheckSessionRequest) + assert len(_req) == 1 + assert "id_token" in _req + assert _req["id_token"] == 'a.signed.jwt' + + +class TestCheckID(object): + @pytest.fixture(autouse=True) + def create_request(self): + self._iss = ISS + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'issuer': self._iss, 'requests_dir': 'requests', + 'base_url': 'https://example.com/cli/' + } + services = { + "checksession": { + 'class': 'oidcrp.oidc.check_id.CheckID' + }} + entity = Entity(keyjar=CLI_KEY, config=client_config, services=services) + entity.client_get("service_context").issuer = 'https://example.com' + self.service = entity.client_get("service",'check_id') + + def test_construct(self): + _state_interface = self.service.client_get("service_context").state + _state_interface.store_item(json.dumps({'id_token': 'a.signed.jwt'}), + 'token_response', 'abcde') + _req = self.service.construct(state='abcde') + assert isinstance(_req, CheckIDRequest) + assert len(_req) == 1 + assert "id_token" in _req + assert _req["id_token"] == 'a.signed.jwt' + + +class TestEndSession(object): + @pytest.fixture(autouse=True) + def create_request(self): + self._iss = ISS + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'issuer': self._iss, 'requests_dir': 'requests', + 'base_url': 'https://example.com/cli/', + 'post_logout_redirect_uris': ['https://example.com/post_logout'] + } + services = { + "checksession": { + 'class': 'oidcrp.oidc.end_session.EndSession' + }} + entity = Entity(keyjar=CLI_KEY, config=client_config, services=services) + entity.client_get("service_context").issuer = 'https://example.com' + self.service = entity.client_get("service",'end_session') + + def test_construct(self): + self.service.client_get("service_context").state.store_item( + json.dumps({'id_token': 'a.signed.jwt'}), + 'token_response', 'abcde') + _req = self.service.construct(state='abcde') + assert isinstance(_req, EndSessionRequest) + assert len(_req) == 3 + assert set(_req.keys()) == {'state', 'id_token_hint', + 'post_logout_redirect_uri'} + + +def test_authz_service_conf(): + client_config = { + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'behaviour': {'response_types': ['code']} + } + + services = { + "authz": { + 'class': 'oidcrp.oidc.authorization.Authorization', + 'kwargs': { + 'conf': { + 'request_args': { + 'claims': { + "id_token": + { + "auth_time": {"essential": True}, + "acr": {"values": ["urn:mace:incommon:iap:silver"]} + } + } + } + } + } + } + } + entity = Entity(keyjar=CLI_KEY, config=client_config, services=services) + entity.client_get("service_context").issuer = 'https://example.com' + service = entity.client_get("service",'authorization') + + req = service.construct() + assert 'claims' in req + assert set(req['claims'].keys()) == {'id_token'} + + +def test_add_jwks_uri_or_jwks_0(): + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'jwks_uri': 'https://example.com/jwks/jwks.json', + 'issuer': ISS, + 'client_preferences': { + 'id_token_signed_response_alg': 'RS384', + 'userinfo_signed_response_alg': 'RS384' + } + } + entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) + entity.client_get("service_context").issuer = 'https://example.com' + service = entity.client_get("service",'registration') + + req_args, post_args = add_jwks_uri_or_jwks({}, service) + assert req_args['jwks_uri'] == 'https://example.com/jwks/jwks.json' + + +def test_add_jwks_uri_or_jwks_1(): + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'jwks_uri': 'https://example.com/jwks/jwks.json', + 'jwks': {"keys": []}, + 'issuer': ISS, + 'client_preferences': { + 'id_token_signed_response_alg': 'RS384', + 'userinfo_signed_response_alg': 'RS384' + } + } + entity = Entity(keyjar=CLI_KEY, config=client_config, services=DEFAULT_OIDC_SERVICES) + service = entity.client_get("service",'registration') + + req_args, post_args = add_jwks_uri_or_jwks({}, service) + assert req_args['jwks_uri'] == 'https://example.com/jwks/jwks.json' + assert set(req_args.keys()) == {'jwks_uri'} + + +def test_add_jwks_uri_or_jwks_2(): + client_config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'issuer': ISS, + 'client_preferences': { + 'id_token_signed_response_alg': 'RS384', + 'userinfo_signed_response_alg': 'RS384' + } + } + entity = Entity(keyjar=CLI_KEY, config=client_config, + jwks_uri='https://example.com/jwks/jwks.json', services=DEFAULT_OIDC_SERVICES) + service = entity.client_get("service",'registration') + + req_args, post_args = add_jwks_uri_or_jwks({}, service) + assert req_args['jwks_uri'] == 'https://example.com/jwks/jwks.json' + assert set(req_args.keys()) == {'jwks_uri'} diff --git a/tests/test_14_oidc.py b/tests/test_14_oidc.py index 747c138..5ebc41c 100755 --- a/tests/test_14_oidc.py +++ b/tests/test_14_oidc.py @@ -3,8 +3,6 @@ import sys import time -import pytest -import responses from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.key_bundle import KeyBundle from oidcmsg.oauth2 import AccessTokenRequest @@ -15,6 +13,8 @@ from oidcmsg.oidc import IdToken from oidcmsg.oidc import OpenIDSchema from oidcmsg.time_util import utc_time_sans_frac +import pytest +import responses from oidcrp.oidc import RP @@ -38,7 +38,6 @@ def access_token_callback(endpoint): return 'access_token' - class TestClient(object): @pytest.fixture(autouse=True) def create_client(self): @@ -54,65 +53,66 @@ def test_construct_authorization_request(self): req_args = { 'state': 'ABCDE', 'redirect_uri': 'https://example.com/auth_cb', - 'response_type': ['code'] + 'response_type': ['code'], + 'nonce': 'nonce' } - self.client.session_interface.create_state('issuer', 'ABCDE') + self.client.client_get("service_context").state.create_state('issuer', 'ABCDE') - msg = self.client.service['authorization'].construct( + msg = self.client.client_get("service",'authorization').construct( request_args=req_args) assert isinstance(msg, AuthorizationRequest) assert msg['redirect_uri'] == 'https://example.com/auth_cb' def test_construct_accesstoken_request(self): - auth_request = AuthorizationRequest( - redirect_uri='https://example.com/cli/authz_cb', - state='ABCDE' - ) + _context = self.client.client_get("service_context") + auth_request = AuthorizationRequest(redirect_uri='https://example.com/cli/authz_cb') - self.client.session_interface.store_item(auth_request, - 'auth_request', 'ABCDE') + _state = _context.state.create_state('issuer') + auth_request["state"] = _state + + _context.state.store_item(auth_request, 'auth_request', _state) auth_response = AuthorizationResponse(code='access_code') - self.client.session_interface.store_item(auth_response, - 'auth_response', 'ABCDE') + _context.state.store_item(auth_response, 'auth_response', _state) # Bind access code to state req_args = {} - msg = self.client.service['accesstoken'].construct( - request_args=req_args, state='ABCDE') + msg = self.client.client_get("service",'accesstoken').construct(request_args=req_args, state=_state) assert isinstance(msg, AccessTokenRequest) assert msg.to_dict() == { - 'client_id': 'client_1', 'code': 'access_code', + 'client_id': 'client_1', 'client_secret': 'abcdefghijklmnop', 'grant_type': 'authorization_code', - 'redirect_uri': 'https://example.com/cli/authz_cb', - 'state': 'ABCDE' + 'state': _state, + 'code': 'access_code', + 'redirect_uri': 'https://example.com/cli/authz_cb' } def test_construct_refresh_token_request(self): - self.client.session_interface.create_state('issuer', 'ABCDE') + _context = self.client.client_get("service_context") + _context.state.create_state('issuer', 'ABCDE') auth_request = AuthorizationRequest( redirect_uri='https://example.com/cli/authz_cb', state='state' ) - self.client.session_interface.store_item(auth_request, + _context.state.store_item(auth_request, 'auth_request', 'ABCDE') auth_response = AuthorizationResponse(code='access_code') - self.client.session_interface.store_item(auth_response, + _context.state.store_item(auth_response, 'auth_response', 'ABCDE') token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - self.client.session_interface.store_item(token_response, + _context.state.store_item(token_response, 'token_response', 'ABCDE') req_args = {} - msg = self.client.service['refresh_token'].construct( + msg = self.client.client_get("service",'refresh_token').construct( request_args=req_args, state='ABCDE') assert isinstance(msg, RefreshAccessTokenRequest) assert msg.to_dict() == { @@ -123,26 +123,27 @@ def test_construct_refresh_token_request(self): } def test_do_userinfo_request_init(self): - self.client.session_interface.create_state('issuer', 'ABCDE') + _context = self.client.client_get("service_context") + _context.state.create_state('issuer', 'ABCDE') auth_request = AuthorizationRequest( redirect_uri='https://example.com/cli/authz_cb', state='state' ) - self.client.session_interface.store_item(auth_request, + _context.state.store_item(auth_request, 'auth_request', 'ABCDE') auth_response = AuthorizationResponse(code='access_code') - self.client.session_interface.store_item(auth_response, + _context.state.store_item(auth_response, 'auth_response', 'ABCDE') token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - self.client.session_interface.store_item(token_response, + _context.state.store_item(token_response, 'token_response', 'ABCDE') - _srv = self.client.service['userinfo'] + _srv = self.client.client_get("service",'userinfo') _srv.endpoint = "https://example.com/userinfo" _info = _srv.get_request_parameters(state='ABCDE') assert _info diff --git a/tests/test_14_pkce.py b/tests/test_14_pkce.py new file mode 100644 index 0000000..77eaae5 --- /dev/null +++ b/tests/test_14_pkce.py @@ -0,0 +1,128 @@ +import os + +from cryptojwt.key_jar import init_key_jar +from oidcmsg.message import Message +from oidcmsg.message import SINGLE_REQUIRED_STRING +from oidcmsg.oauth2 import AuthorizationResponse +import pytest + +from oidcrp.entity import Entity +from oidcrp.oauth2 import DEFAULT_OAUTH2_SERVICES +from oidcrp.oauth2.add_on import do_add_ons +from oidcrp.oauth2.add_on.pkce import add_code_challenge +from oidcrp.oauth2.add_on.pkce import add_code_verifier +from oidcrp.service import Service + + +class DummyMessage(Message): + c_param = { + "req_str": SINGLE_REQUIRED_STRING, + } + + +class DummyService(Service): + msg_type = DummyMessage + + +_dirname = os.path.dirname(os.path.abspath(__file__)) + +ISS = 'https://example.com' + +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLI_KEY = init_key_jar(public_path='{}/pub_client.jwks'.format(_dirname), + private_path='{}/priv_client.jwks'.format(_dirname), + key_defs=KEYSPEC, issuer_id='client_id') + + +class TestPKCE256: + @pytest.fixture(autouse=True) + def create_client(self): + config = { + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'behaviour': {'response_types': ['code']}, + 'add_ons': { + "pkce": { + "function": "oidcrp.oauth2.add_on.pkce.add_support", + "kwargs": { + "code_challenge_length": 64, + "code_challenge_method": "S256" + } + } + } + } + self.entity = Entity(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) + + if 'add_ons' in config: + do_add_ons(config['add_ons'], self.entity.client_get("services")) + + def test_add_code_challenge_default_values(self): + auth_serv = self.entity.client_get("service","authorization") + _state_key = self.entity.client_get("service_context").state.create_state(iss="Issuer") + request_args, _ = add_code_challenge({'state': _state_key}, auth_serv) + + # default values are length:64 method:S256 + assert set(request_args.keys()) == {'code_challenge', 'code_challenge_method', + 'state'} + assert request_args['code_challenge_method'] == 'S256' + + request_args = add_code_verifier({}, auth_serv, state=_state_key) + assert len(request_args['code_verifier']) == 64 + + def test_authorization_and_pkce(self): + auth_serv = self.entity.client_get("service","authorization") + _state = self.entity.client_get("service_context").state.create_state(iss='Issuer') + + request = auth_serv.construct_request({"state": _state, "response_type": "code"}) + assert set(request.keys()) == {'client_id', 'code_challenge', + 'code_challenge_method', 'state', + 'redirect_uri', 'response_type'} + + def test_access_token_and_pkce(self): + authz_service = self.entity.client_get("service","authorization") + request = authz_service.construct_request({"state": 'state', "response_type": "code"}) + _state = request['state'] + auth_response = AuthorizationResponse(code='access code') + self.entity.client_get("service_context").state.store_item(auth_response, 'auth_response', _state) + + token_service = self.entity.client_get("service","accesstoken") + request = token_service.construct_request(state=_state) + assert set(request.keys()) == {'client_id', 'redirect_uri', 'grant_type', + 'client_secret', 'code_verifier', 'code', + 'state'} + + +class TestPKCE384: + @pytest.fixture(autouse=True) + def create_client(self): + config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'add_ons': { + "pkce": { + "function": "oidcrp.oauth2.add_on.pkce.add_support", + "kwargs": { + "code_challenge_length": 128, + "code_challenge_method": "S384" + } + } + } + } + self.entity = Entity(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) + if 'add_ons' in config: + do_add_ons(config['add_ons'], self.entity.client_get("services")) + + def test_add_code_challenge_spec_values(self): + auth_serv = self.entity.client_get("service","authorization") + request_args, _ = add_code_challenge({'state': 'state'}, auth_serv) + assert set(request_args.keys()) == {'code_challenge', 'code_challenge_method', + 'state'} + assert request_args['code_challenge_method'] == 'S384' + + request_args = add_code_verifier({}, auth_serv, state='state') + assert len(request_args['code_verifier']) == 128 diff --git a/tests/test_15_oic_utils.py b/tests/test_15_oic_utils.py new file mode 100644 index 0000000..bbb3bf4 --- /dev/null +++ b/tests/test_15_oic_utils.py @@ -0,0 +1,49 @@ +from cryptojwt.jwe.jwe import factory +from cryptojwt.key_jar import build_keyjar +from oidcmsg.oidc import AuthorizationRequest + +from oidcrp.oidc.utils import construct_request_uri +from oidcrp.oidc.utils import request_object_encryption +from oidcrp.service_context import ServiceContext + +KEYSPEC = [ + {"type": "RSA", "use": ["enc"]}, + {"type": "EC", "crv": "P-256", "use": ["enc"]}, +] + +RECEIVER = 'https://example.org/op' + +KEYJAR = build_keyjar(KEYSPEC, issuer_id=RECEIVER) + + +def test_request_object_encryption(): + msg = AuthorizationRequest(state='ABCDE', + redirect_uri='https://example.com/cb', + response_type='code') + + conf = { + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'client_id': 'client_1', + 'client_secret': 'abcdefghijklmnop', + } + service_context = ServiceContext(keyjar=KEYJAR, config=conf) + _behav = service_context.behaviour + _behav["request_object_encryption_alg"] = 'RSA1_5' + _behav["request_object_encryption_enc"] = "A128CBC-HS256" + service_context.behaviour = _behav + + _jwe = request_object_encryption(msg.to_json(), service_context, target=RECEIVER) + assert _jwe + + _decryptor = factory(_jwe) + + assert _decryptor.jwt.verify_headers(alg='RSA1_5', enc='A128CBC-HS256') + + +def test_construct_request_uri(): + local_dir = 'home' + base_path = 'https://example.com/' + a, b = construct_request_uri(local_dir, base_path) + assert a.startswith('home') and a.endswith('.jwt') + d, f = a.split('/') + assert b == '{}{}'.format(base_path, f) diff --git a/tests/test_16_cc_oauth2_service.py b/tests/test_16_cc_oauth2_service.py new file mode 100644 index 0000000..ec5861b --- /dev/null +++ b/tests/test_16_cc_oauth2_service.py @@ -0,0 +1,163 @@ +from oidcmsg.oauth2 import AccessTokenResponse +import pytest + +from oidcrp.entity import Entity +from oidcrp.util import rndstr + +KEYDEF = [{"type": "EC", "crv": "P-256", "use": ["sig"]}] + + +class TestRP(): + @pytest.fixture(autouse=True) + def create_service(self): + client_config = { + 'client_id': 'client_id', + 'client_secret': 'another password' + } + services = { + 'token': { + 'class': 'oidcrp.oauth2.client_credentials.cc_access_token.CCAccessToken' + }, + 'refresh_token': { + 'class': 'oidcrp.oauth2.client_credentials.cc_refresh_access_token' + '.CCRefreshAccessToken' + } + } + + self.entity = Entity(config=client_config, services=services) + + self.entity.client_get("service",'accesstoken').endpoint = 'https://example.com/token' + self.entity.client_get("service",'refresh_token').endpoint = 'https://example.com/token' + + def test_token_get_request(self): + request_args = {'grant_type': 'client_credentials'} + _srv = self.entity.client_get("service",'accesstoken') + _info = _srv.get_request_parameters(request_args=request_args) + assert _info['method'] == 'POST' + assert _info['url'] == 'https://example.com/token' + assert _info['body'] == 'grant_type=client_credentials' + assert _info['headers'] == { + 'Authorization': 'Basic Y2xpZW50X2lkOmFub3RoZXIrcGFzc3dvcmQ=', + 'Content-Type': 'application/x-www-form-urlencoded' + } + + def test_token_parse_response(self): + request_args = {'grant_type': 'client_credentials'} + _srv = self.entity.client_get("service",'accesstoken') + _request_info = _srv.get_request_parameters(request_args=request_args) + + response = AccessTokenResponse(**{ + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "expires_in": 3600, + "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", + "example_parameter": "example_value" + }) + + _response = _srv.parse_response(response.to_json(), sformat="json") + # since no state attribute is involved, a key is minted + _key = rndstr(16) + _srv.update_service_context(_response, key=_key) + info = _srv.client_get("service_context").state.get_item(AccessTokenResponse, 'token_response', _key) + assert '__expires_at' in info + + def test_refresh_token_get_request(self): + _srv = self.entity.client_get("service",'accesstoken') + _srv.update_service_context({ + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "expires_in": 3600, + "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", + "example_parameter": "example_value" + }) + _srv = self.entity.client_get("service",'refresh_token') + _id = rndstr(16) + _info = _srv.get_request_parameters(state_id=_id) + assert _info['method'] == 'POST' + assert _info['url'] == 'https://example.com/token' + assert _info[ + 'body'] == 'grant_type=refresh_token' + assert _info['headers'] == { + 'Authorization': 'Bearer tGzv3JOkF0XG5Qx2TlKWIA', + 'Content-Type': 'application/x-www-form-urlencoded' + } + + def test_refresh_token_parse_response(self): + request_args = {'grant_type': 'client_credentials'} + _srv = self.entity.client_get("service",'accesstoken') + _request_info = _srv.get_request_parameters(request_args=request_args) + + response = AccessTokenResponse(**{ + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "expires_in": 3600, + "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", + "example_parameter": "example_value" + }) + + _response = _srv.parse_response(response.to_json(), sformat="json") + # since no state attribute is involved, a key is minted + _key = rndstr(16) + _srv.update_service_context(_response, key=_key) + info = _srv.client_get("service_context").state.get_item(AccessTokenResponse, 'token_response', _key) + assert '__expires_at' in info + + # Move from token to refresh token service + + _srv = self.entity.client_get("service",'refresh_token') + _request_info = _srv.get_request_parameters(request_args=request_args, state=_key) + + refresh_response = AccessTokenResponse(**{ + "access_token": 'wy4R01DmMoB5xkI65nNkVv1l', + "token_type": "example", + "expires_in": 3600, + "refresh_token": 'lhNX9LSG8w1QuD6tSgc6CPfJ', + }) + + _response = _srv.parse_response(refresh_response.to_json(), sformat="json") + _srv.update_service_context(_response, key=_key) + info = _srv.client_get("service_context").state.get_item(AccessTokenResponse, 'token_response', _key) + assert '__expires_at' in info + + def test_2nd_refresh_token_parse_response(self): + request_args = {'grant_type': 'client_credentials'} + _srv = self.entity.client_get("service",'accesstoken') + _request_info = _srv.get_request_parameters(request_args=request_args) + + response = AccessTokenResponse(**{ + "access_token": "2YotnFZFEjr1zCsicMWpAA", + "token_type": "example", + "expires_in": 3600, + "refresh_token": "tGzv3JOkF0XG5Qx2TlKWIA", + "example_parameter": "example_value" + }) + + _response = _srv.parse_response(response.to_json(), sformat="json") + # since no state attribute is involved, a key is minted + _key = rndstr(16) + _srv.update_service_context(_response, key=_key) + info = _srv.client_get("service_context").state.get_item(AccessTokenResponse, 'token_response', _key) + assert '__expires_at' in info + + # Move from token to refresh token service + + _srv = self.entity.client_get("service",'refresh_token') + _request_info = _srv.get_request_parameters(request_args=request_args, state=_key) + + refresh_response = AccessTokenResponse(**{ + "access_token": 'wy4R01DmMoB5xkI65nNkVv1l', + "token_type": "example", + "expires_in": 3600, + "refresh_token": 'lhNX9LSG8w1QuD6tSgc6CPfJ', + }) + + _response = _srv.parse_response(refresh_response.to_json(), sformat="json") + _srv.update_service_context(_response, key=_key) + info = _srv.client_get("service_context").state.get_item(AccessTokenResponse, 'token_response', _key) + assert '__expires_at' in info + + _request_info = _srv.get_request_parameters(request_args=request_args, state=_key) + assert _request_info['headers'] == { + 'Authorization': 'Bearer {}'.format(refresh_response["refresh_token"]), + 'Content-Type': 'application/x-www-form-urlencoded' + } diff --git a/tests/test_17_read_registration.py b/tests/test_17_read_registration.py new file mode 100644 index 0000000..ca01f0e --- /dev/null +++ b/tests/test_17_read_registration.py @@ -0,0 +1,98 @@ +import json +import time + +import pytest +import responses +from cryptojwt.utils import as_bytes +from oidcmsg.oidc import RegistrationResponse + +from oidcrp.entity import Entity +import requests +from oidcrp.service_context import ServiceContext +from oidcrp.service_factory import service_factory + +ISS = "https://example.com" +RP_BASEURL = "https://example.com/rp" + + +class TestRegistrationRead(object): + @pytest.fixture(autouse=True) + def create_request(self): + self._iss = ISS + client_config = { + "redirect_uris": ["https://example.com/cli/authz_cb"], + "issuer": self._iss, "requests_dir": "requests", + "base_url": "https://example.com/cli/", + "client_preferences": { + "application_type": "web", + "response_types": ["code"], + "contacts": ["ops@example.org"], + "jwks_uri": "https://example.com/rp/static/jwks.json", + "redirect_uris": ["{}/authz_cb".format(RP_BASEURL)], + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code"] + } + } + services = { + 'registration': { + 'class': 'oidcrp.oidc.registration.Registration' + }, + 'read_registration': { + 'class': 'oidcrp.oidc.read_registration.RegistrationRead' + } + } + + self.entity = Entity(config=client_config, services=services) + + self.reg_service = self.entity.client_get("service",'registration') + self.read_service = self.entity.client_get("service",'registration_read') + + def test_construct(self): + self.reg_service.endpoint = "{}/registration".format(ISS) + + _param = self.reg_service.get_request_parameters() + + now = int(time.time()) + + _client_registration_response = json.dumps({ + "client_id": "zls2qhN1jO6A", + "client_secret": "c8434f28cf9375d9a7", + "registration_access_token": "NdGrGR7LCuzNtixvBFnDphGXv7wRcONn", + "registration_client_uri": "{}/registration_api?client_id=zls2qhN1jO6A".format(ISS), + "client_secret_expires_at": now + 3600, + "client_id_issued_at": now, + "application_type": "web", + "response_types": ["code"], + "contacts": ["ops@example.com"], + "redirect_uris": ["{}/authz_cb".format(RP_BASEURL)], + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code"] + }) + + with responses.RequestsMock() as rsps: + rsps.add(_param["method"], _param["url"], body=_client_registration_response, status=200) + _resp = requests.request( + _param["method"], _param["url"], + data=as_bytes(_param["body"]), + headers=_param["headers"], + verify=False + ) + + resp = self.reg_service.parse_response(_resp.text) + self.reg_service.update_service_context(resp) + + assert resp + + _read_param = self.read_service.get_request_parameters() + with responses.RequestsMock() as rsps: + rsps.add(_param["method"], _param["url"], body=_client_registration_response, + adding_headers={"Content-Type": "application/json"}, status=200) + _resp = requests.request( + _param["method"], + _param["url"], + headers=_param["headers"], + verify=False + ) + + read_resp = self.reg_service.parse_response(_resp.text) + assert isinstance(read_resp, RegistrationResponse) diff --git a/tests/test_20_conversation.py b/tests/test_20_conversation.py new file mode 100644 index 0000000..9659aa4 --- /dev/null +++ b/tests/test_20_conversation.py @@ -0,0 +1,427 @@ +#!/usr/bin/env python3 +import json +import time +from urllib.parse import parse_qs +from urllib.parse import urlparse + +from cryptojwt.jwt import JWT +from cryptojwt.key_jar import KeyJar +from oidcmsg.oidc import AccessTokenResponse +from oidcmsg.oidc import AuthorizationResponse +from oidcmsg.oidc import JRD +from oidcmsg.oidc import Link +from oidcmsg.oidc import OpenIDSchema +from oidcmsg.oidc import ProviderConfigurationResponse + +from oidcrp.defaults import DEFAULT_OIDC_SERVICES +from oidcrp.entity import Entity +from oidcrp.oidc.webfinger import WebFinger + +# ================== SETUP =========================== + +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +JWKS_OP = { + 'keys': [{ + 'd': 'mcAW1xeNsjzyV1M7F7_cUHz0MIR' + '-tcnKFJnbbo5UXxMRUPu17qwRHr8ttep1Ie64r2L9QlphcT9BjYd0KQ8ll3flIzLtiJv__MNPQVjk5bsYzb_erQRzSwLJU-aCcNFB8dIyQECzu-p44UVEPQUGzykImsSShvMQhcvrKiqqg7NlijJuEKHaKynV9voPsjwKYSqk6lH8kMloCaVS-dOkK-r7bZtbODUxx9GJWnxhX0JWXcdrPZRb29y9cdthrMcEaCXG23AxnMEfp-enDqarLHYTQrCBJXs_b-9k2d8v9zLm7E-Pf-0YGmaoJtX89lwQkO_SmFF3sXsnI2cFreqU3Q', + 'e': 'AQAB', + 'kid': 'c19uYlBJXzVfNjNZeGVnYmxncHZwUzZTZDVwUFdxdVJLU3AxQXdwaFdfbw', + 'kty': 'RSA', + 'n': '3ZblhNL2CjRktLM9vyDn8jnA4G1B1HCpPh' + '-gv2AK4m9qDBZPYZGOGqzeW3vanvLTBlqnPm0GHg4rOrfMEwwLrfMcgmg1y4GD0vVU8G9HP1' + '-oUPtKUqaKOp313tFKzFh9_OHGQ6EmhxG7gegPR9kQXduTDXqBFi81MzRplIQ8DHLM3-n2CyDW1V-dhRVh' + '-AM0ZcJyzR_DvZ3mhG44DysPdHQOSeWnpdn1d81' + '-PriqZfhAF9tn1ihgtjXd5swf1HTSjLd7xv1hitGf2245Xmr' + '-V2pQFzeMukLM3JKbTYbElsB7Zm0wZx49hZMtgx35XMoO04bifdbO3yLtTA5ovXN3fQ', + 'p': '88aNu59aBn0elksaVznzoVKkdbT5B4euhOIEqJoFvFbEocw9mC4k' + '-yozIAQSV5FEakoSPOl8lrymCoM3Q1fVHfaM9Rbb9RCRlsV1JOeVVZOE05HUdz8zOIqLBDEGM_oQqDwF_kp' + '-4nDTZ1-dtnGdTo4Cf7QRuApzE_dwVabUCTc', + 'q': + '6LOHuM7H_0kDrMTwUEX7Aubzr792GoJ6EgTKIQY25SAFTZpYwuC3NnqlAdy8foIa3d7eGU2yICRbBG0S_ITcooDFrOa7nZ6enMUclMTxW8FwwvBXeIHo9cIsrKYtOThGplz43Cvl73MK5M58ZRmuhaNYa6Mk4PL4UokARfEiDus', + 'use': 'sig' + }, + { + 'crv': 'P-256', + 'd': 'N2dg0-DAROBF8owQA4-uY5s0Ab-Fep_42kEFQG4BNVQ', + 'kid': 'UnpYbi0tWC1HaEtyRFMtSmkyZDVHUHZVNDF0d21KTVk1dzEwYmhpNlVtQQ', + 'kty': 'EC', + 'use': 'sig', + 'x': 'Ls8SqX8Ti5QAKtw3rdGr5K537-tqQCIbhyebeE_2C38', + 'y': 'S-BrbPQkh8HVFLWg5Wid_5OAk4ewn5skHlHtG08ShaA' + } + ] +} + +OP_KEYJAR = KeyJar() +OP_KEYJAR.import_jwks(JWKS_OP, '') +OP_PUBLIC_JWKS = OP_KEYJAR.export_jwks() +OP_BASEURL = "https://example.org/op" + +RP_JWKS = { + "keys": [{ + "kty": "RSA", "use": "sig", + "kid": "Mk0yN2w0N3BZLWtyOEpQWGFmNDZvQi1hbDl2azR3ai1WNElGdGZQSFd6MA", + "e": "AQAB", + "n": "yPrOADZtGoa9jxFCmDsJ1nAYmzgznUxCtUlb_ty33" + "-AFNEqzW_pSLr5g6RQAPGsvVQqbsb9AB18QNgz" + "-eG7cnvKIIR7JXWCuGv_Q9MwoRD0-zaYGRbRvFoTZokZMB6euBfMo6kijJ" + "-gdKuSaxIE84X_Fcf1ESAKJ0EX6Cxdm8hKkBelGIDPMW5z7EHQ8OuLCQtTJnDvbjEOk9sKzkKqVj53XFs5vjd4WUhxS6xIDcWE-lTafUpm0BsobklLePidHxyAMGOunL_Pt3RCLZGlWeWOO9fZhLtydiDWiZlcNR0FQEX_mfV1kCOHHBFN1VKOY2pyJpjp9djdtHxPZ9fP35w", + "d": + "aRBTqGDLYFaXuba4LYSPe_5Vnq8erFg1dzfGU9Fmfi5KCjAS2z5cv_reBnpiNTODJt3Izn7AJhpYCyl3zdWGl8EJ0OabNalY2txoi9A-LI4nyrHEDaRpfkgszVwaWtYZbxrShMc8I5x_wvCGx7sX7Hoy6YgQreRFzw8Fy86MDncpmcUwQTnXVUMLgioeYz5gW6rwXkqj_NVyuHPiheykJG026cXFNBWplCk4ET1bvf_6ZB9QmLwO16Pu2O-dtu1HHDOqI7y6-YgKIC6mcLrQrF9-FO7NkilcOB7zODNiYzhDBQ2YJAbcdn_3M_lkhaFwR-n4WB7vCM0vNqz7lEg6QQ", + "p": + "_STNoJFkX9_uw8whytVmTrHP5K7vcZBIH9nuCTvj137lC48ZpR1UARx4qShxHLfK7DrufHd7TYnJkEMNUHFmdKvkaVQMY0_BsBSvCrUl10gzxsI08hg53L17E1Pe73iZp3f5nA4eB-1YB-km1Cc-Xs10OPWedJHf9brlCPDLAb8", + "q": + "yz9T0rPEc0ZPjSi45gsYiQL2KJ3UsPHmLrgOHq0D4UvsB6UFtUtOWh7A1UpQdmBuHjIJz" + "-Iq7VH4kzlI6VxoXhwE69oxBXr4I7fBudZRvlLuIJS9M2wvsTVouj0DBYSR6ZlAQHCCou89P2P6zQCEaqu7bWXNcpyTixbbvOU1w9k" + }, + { + "kty": "EC", "use": "sig", + "kid": "ME9NV3VQV292OTA4T1pNLXZoVjd2TldVSjNrNEkycjU2ZjkycldQOTcyUQ", + "crv": "P-256", + "x": "WWoO_Exim-LOD1k8QPi_CdU8M_VUSF7DkJCKR7PFWhQ", + "y": "EpxHNZp6ykyeLiS6r7l9ly2in1Zju7hnLk7RFraklxE", + "d": "pepDloEcTyHnoEuqFirZ8hpt861piMDgiuvHIhhRSpM" + }] +} + +RP_KEYJAR = KeyJar() +RP_KEYJAR.import_jwks(RP_JWKS, '') +RP_KEYJAR.import_jwks(OP_PUBLIC_JWKS, OP_BASEURL) +RP_BASEURL = "https://example.com/rp" + +SERVICE_PUBLIC_JWKS = RP_KEYJAR.export_jwks('') +OP_KEYJAR.import_jwks(SERVICE_PUBLIC_JWKS, RP_BASEURL) + + +# --------------------------------------------------- + +def test_conversation(): + config = { + "client_preferences": + { + "application_type": "web", + "application_name": "rphandler", + "contacts": ["ops@example.org"], + "response_types": ["code"], + "scope": ["openid", "profile", "email", "address", "phone"], + "token_endpoint_auth_method": "client_secret_basic", + }, + "redirect_uris": ["{}/authz_cb".format(RP_BASEURL)], + "jwks_uri": "{}/static/jwks.json".format(RP_BASEURL) + } + + service_spec = DEFAULT_OIDC_SERVICES.copy() + service_spec['WebFinger'] = {'class': WebFinger} + + entity = Entity(config=config, services=service_spec, keyjar=RP_KEYJAR) + + assert set(entity.client_get("services").keys()) == {'accesstoken', 'authorization', + 'webfinger', + 'registration', 'refresh_token', + 'userinfo', + 'provider_info'} + service_context = entity.client_get("service_context") + + # ======================== WebFinger ======================== + + webfinger_service = entity.client_get("service",'webfinger') + info = webfinger_service.get_request_parameters( + request_args={'resource': 'foobar@example.org'}) + + assert info[ + 'url'] == 'https://example.org/.well-known/webfinger?rel=http' \ + '%3A%2F' \ + '%2Fopenid.net%2Fspecs%2Fconnect%2F1.0%2Fissuer' \ + '&resource' \ + '=acct%3Afoobar%40example.org' + + webfinger_response = json.dumps({ + "subject": "acct:foobar@example.org", + "links": [{ + "rel": "http://openid.net/specs/connect/1.0/issuer", + "href": "https://example.org/op" + }], + "expires": "2018-02-04T11:08:41Z" + }) + + response = webfinger_service.parse_response(webfinger_response) + + assert isinstance(response, JRD) + assert set(response.keys()) == {'subject', 'links', 'expires'} + assert response['links'] == [ + Link(rel='http://openid.net/specs/connect/1.0/issuer', + href='https://example.org/op')] + + webfinger_service.update_service_context(resp=response) + entity.client_get("service_context").issuer = OP_BASEURL + + # =================== Provider info discovery ==================== + provider_info_service = entity.client_get("service",'provider_info') + info = provider_info_service.get_request_parameters() + + assert info[ + 'url'] == 'https://example.org/op/.well-known/openid' \ + '-configuration' + + provider_info_response = json.dumps({ + "version": "3.0", + "token_endpoint_auth_methods_supported": [ + "client_secret_post", "client_secret_basic", + "client_secret_jwt", "private_key_jwt"], + "claims_parameter_supported": True, + "request_parameter_supported": True, + "request_uri_parameter_supported": True, + "require_request_uri_registration": True, + "grant_types_supported": ["authorization_code", + "implicit", + "urn:ietf:params:oauth:grant-type:jwt-bearer", + "refresh_token"], + "response_types_supported": ["code", "id_token", + "id_token token", + "code id_token", + "code token", + "code id_token token"], + "response_modes_supported": ["query", "fragment", + "form_post"], + "subject_types_supported": ["public", "pairwise"], + "claim_types_supported": ["normal", "aggregated", + "distributed"], + "claims_supported": ["birthdate", "address", + "nickname", "picture", "website", + "email", "gender", "sub", + "phone_number_verified", + "given_name", "profile", + "phone_number", "updated_at", + "middle_name", "name", "locale", + "email_verified", + "preferred_username", "zoneinfo", + "family_name"], + "scopes_supported": ["openid", "profile", "email", + "address", "phone", + "offline_access", "openid"], + "userinfo_signing_alg_values_supported": [ + "RS256", "RS384", "RS512", + "ES256", "ES384", "ES512", + "HS256", "HS384", "HS512", + "PS256", "PS384", "PS512", "none"], + "id_token_signing_alg_values_supported": [ + "RS256", "RS384", "RS512", + "ES256", "ES384", "ES512", + "HS256", "HS384", "HS512", + "PS256", "PS384", "PS512", "none"], + "request_object_signing_alg_values_supported": [ + "RS256", "RS384", "RS512", "ES256", "ES384", + "ES512", "HS256", "HS384", "HS512", "PS256", + "PS384", "PS512", "none"], + "token_endpoint_auth_signing_alg_values_supported": [ + "RS256", "RS384", "RS512", "ES256", "ES384", + "ES512", "HS256", "HS384", "HS512", "PS256", + "PS384", "PS512"], + "userinfo_encryption_alg_values_supported": [ + "RSA1_5", "RSA-OAEP", "RSA-OAEP-256", + "A128KW", "A192KW", "A256KW", + "ECDH-ES", "ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"], + "id_token_encryption_alg_values_supported": [ + "RSA1_5", "RSA-OAEP", "RSA-OAEP-256", + "A128KW", "A192KW", "A256KW", + "ECDH-ES", "ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"], + "request_object_encryption_alg_values_supported": [ + "RSA1_5", "RSA-OAEP", "RSA-OAEP-256", "A128KW", + "A192KW", "A256KW", "ECDH-ES", "ECDH-ES+A128KW", + "ECDH-ES+A192KW", "ECDH-ES+A256KW"], + "userinfo_encryption_enc_values_supported": [ + "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", + "A128GCM", "A192GCM", "A256GCM"], + "id_token_encryption_enc_values_supported": [ + "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", + "A128GCM", "A192GCM", "A256GCM"], + "request_object_encryption_enc_values_supported": [ + "A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512", + "A128GCM", "A192GCM", "A256GCM"], + "acr_values_supported": ["PASSWORD"], + "issuer": OP_BASEURL, + "jwks_uri": "{}/static/jwks_tE2iLbOAqXhe8bqh.json".format(OP_BASEURL), + "authorization_endpoint": "{}/authorization".format(OP_BASEURL), + "token_endpoint": "{}/token".format(OP_BASEURL), + "userinfo_endpoint": "{}/userinfo".format(OP_BASEURL), + "registration_endpoint": "{}/registration".format(OP_BASEURL), + "end_session_endpoint": "{}/end_session".format(OP_BASEURL) + }) + + resp = provider_info_service.parse_response(provider_info_response) + + assert isinstance(resp, ProviderConfigurationResponse) + provider_info_service.update_service_context(resp) + + _pi = entity.client_get("service_context").provider_info + assert _pi['issuer'] == OP_BASEURL + assert _pi['authorization_endpoint'] == 'https://example.org/op/authorization' + assert _pi['registration_endpoint'] == 'https://example.org/op/registration' + + # =================== Client registration ==================== + registration_service = entity.client_get("service",'registration') + info = registration_service.get_request_parameters() + + assert info['url'] == 'https://example.org/op/registration' + _body = json.loads(info['body']) + assert _body == { + "application_type": "web", + "response_types": ["code"], + "contacts": ["ops@example.org"], + "jwks_uri": "https://example.com/rp/static/jwks.json", + "redirect_uris": ["{}/authz_cb".format(RP_BASEURL)], + 'token_endpoint_auth_method': 'client_secret_basic', + "grant_types": ["authorization_code"] + } + assert info['headers'] == {'Content-Type': 'application/json'} + + now = int(time.time()) + + op_client_registration_response = json.dumps({ + "client_id": "zls2qhN1jO6A", + "client_secret": "c8434f28cf9375d9a7", + "registration_access_token": "NdGrGR7LCuzNtixvBFnDphGXv7wRcONn", + "registration_client_uri": "{}/registration?client_id=zls2qhN1jO6A".format( + RP_BASEURL), + "client_secret_expires_at": now + 3600, + "client_id_issued_at": now, + "application_type": "web", + "response_types": ["code"], + "contacts": ["ops@example.com"], + "redirect_uris": ["{}/authz_cb".format(RP_BASEURL)], + "token_endpoint_auth_method": "client_secret_basic", + "grant_types": ["authorization_code"] + }) + + response = registration_service.parse_response(op_client_registration_response) + + registration_service.update_service_context(response) + + assert service_context.client_id == 'zls2qhN1jO6A' + assert service_context.client_secret == 'c8434f28cf9375d9a7' + assert set(service_context.registration_response.keys()) == { + 'client_secret_expires_at', 'contacts', 'client_id', + 'token_endpoint_auth_method', 'redirect_uris', 'response_types', + 'client_id_issued_at', 'client_secret', 'application_type', + 'registration_client_uri', 'registration_access_token', + 'grant_types'} + + # =================== Authorization ==================== + + STATE = 'Oh3w3gKlvoM2ehFqlxI3HIK5' + NONCE = 'UvudLKz287YByZdsY3AJoPAlEXQkJ0dK' + + auth_service = entity.client_get("service",'authorization') + _state_interface = service_context.state + + info = auth_service.get_request_parameters(request_args={'state': STATE, 'nonce': NONCE}) + + p = urlparse(info['url']) + _query = parse_qs(p.query) + assert set(_query.keys()) == {'state', 'nonce', 'response_type', 'scope', + 'client_id', 'redirect_uri'} + assert _query['scope'] == ['openid profile email address phone'] + assert _query['nonce'] == [NONCE] + assert _query['state'] == [STATE] + + op_authz_resp = { + 'state': STATE, + 'scope': 'openid', + 'code': 'Z0FBQUFBQmFkdFFjUVpFWE81SHU5N1N4N01', + 'iss': OP_BASEURL, + 'client_id': 'zls2qhN1jO6A' + } + + _authz_rep = AuthorizationResponse(**op_authz_resp) + + _resp = auth_service.parse_response(_authz_rep.to_urlencoded()) + auth_service.update_service_context(_resp, key=STATE) + _item = _state_interface.get_item(AuthorizationResponse, 'auth_response', STATE) + assert _item['code'] == 'Z0FBQUFBQmFkdFFjUVpFWE81SHU5N1N4N01' + + # =================== Access token ==================== + + token_service = entity.client_get("service",'accesstoken') + request_args = { + 'state': STATE, + 'redirect_uri': service_context.redirect_uris[0] + } + + info = token_service.get_request_parameters(request_args=request_args) + + assert info['url'] == 'https://example.org/op/token' + _qp = parse_qs(info['body']) + assert _qp == { + 'grant_type': ['authorization_code'], + 'redirect_uri': ['https://example.com/rp/authz_cb'], + 'client_id': ['zls2qhN1jO6A'], + 'state': ['Oh3w3gKlvoM2ehFqlxI3HIK5'], + 'code': ['Z0FBQUFBQmFkdFFjUVpFWE81SHU5N1N4N01'] + } + assert info['headers'] == { + 'Authorization': 'Basic ' + 'emxzMnFoTjFqTzZBOmM4NDM0ZjI4Y2Y5Mzc1ZDlhNw==', + 'Content-Type': 'application/x-www-form-urlencoded' + } + + # create the IdToken + _jwt = JWT(OP_KEYJAR, OP_BASEURL, lifetime=3600, sign=True, + sign_alg='RS256') + payload = { + 'sub': '1b2fc9341a16ae4e30082965d537', 'acr': 'PASSWORD', + 'auth_time': 1517736988, 'nonce': NONCE + } + _jws = _jwt.pack(payload=payload, recv='zls2qhN1jO6A') + + _resp = { + "state": "Oh3w3gKlvoM2ehFqlxI3HIK5", + "scope": "openid", + "access_token": "Z0FBQUFBQmFkdFF", + "token_type": "Bearer", + 'expires_in': 600, + "id_token": _jws + } + + service_context.issuer = OP_BASEURL + _resp = token_service.parse_response(json.dumps(_resp), state=STATE) + + assert isinstance(_resp, AccessTokenResponse) + assert set(_resp['__verified_id_token'].keys()) == { + 'iss', 'nonce', 'acr', 'auth_time', 'aud', 'iat', 'exp', 'sub'} + + token_service.update_service_context(_resp, key=STATE) + + _item = _state_interface.get_item(AccessTokenResponse, + 'token_response', STATE) + + assert set(_item.keys()) == {'state', 'scope', 'access_token', + 'token_type', 'id_token', + '__verified_id_token', + 'expires_in', '__expires_at'} + + assert _item['token_type'] == 'Bearer' + assert _item['access_token'] == 'Z0FBQUFBQmFkdFF' + + # =================== User info ==================== + + userinfo_service = entity.client_get("service",'userinfo') + info = userinfo_service.get_request_parameters(state=STATE) + + assert info['url'] == 'https://example.org/op/userinfo' + assert info['headers'] == {'Authorization': 'Bearer Z0FBQUFBQmFkdFF'} + + op_resp = {"sub": "1b2fc9341a16ae4e30082965d537"} + + _resp = userinfo_service.parse_response(json.dumps(op_resp), state=STATE) + userinfo_service.update_service_context(_resp, key=STATE) + + assert isinstance(_resp, OpenIDSchema) + assert _resp.to_dict() == {'sub': '1b2fc9341a16ae4e30082965d537'} + + _item = _state_interface.get_item(OpenIDSchema, + 'user_info', STATE) + assert _item.to_dict() == {'sub': '1b2fc9341a16ae4e30082965d537'} diff --git a/tests/test_20_rp_handler.py b/tests/test_20_rp_handler_oidc.py similarity index 90% rename from tests/test_20_rp_handler.py rename to tests/test_20_rp_handler_oidc.py index 55cdfdc..b74b588 100644 --- a/tests/test_20_rp_handler.py +++ b/tests/test_20_rp_handler_oidc.py @@ -4,8 +4,6 @@ from urllib.parse import urlparse from urllib.parse import urlsplit -import pytest -import responses from cryptojwt.key_jar import KeyJar from cryptojwt.key_jar import init_key_jar from oidcmsg.oidc import AccessTokenResponse @@ -15,10 +13,11 @@ from oidcmsg.oidc import Link from oidcmsg.oidc import OpenIDSchema from oidcmsg.oidc import ProviderConfigurationResponse -from oidcservice.service import init_services -from oidcservice.service_context import ServiceContext +import pytest +import responses -from oidcrp import RPHandler +from oidcrp.entity import Entity +from oidcrp.rp_handler import RPHandler BASE_URL = 'https://example.com/rp' @@ -39,27 +38,27 @@ "redirect_uris": None, "services": { 'web_finger': { - 'class': 'oidcservice.oidc.webfinger.WebFinger' + 'class': 'oidcrp.oidc.webfinger.WebFinger' }, "discovery": { - 'class': 'oidcservice.oidc.provider_info_discovery' + 'class': 'oidcrp.oidc.provider_info_discovery' '.ProviderInfoDiscovery' }, 'registration': { - 'class': 'oidcservice.oidc.registration.Registration' + 'class': 'oidcrp.oidc.registration.Registration' }, 'authorization': { - 'class': 'oidcservice.oidc.authorization.Authorization' + 'class': 'oidcrp.oidc.authorization.Authorization' }, 'access_token': { - 'class': 'oidcservice.oidc.access_token.AccessToken' + 'class': 'oidcrp.oidc.access_token.AccessToken' }, 'refresh_access_token': { - 'class': 'oidcservice.oidc.refresh_access_token' + 'class': 'oidcrp.oidc.refresh_access_token' '.RefreshAccessToken' }, 'userinfo': { - 'class': 'oidcservice.oidc.userinfo.UserInfo' + 'class': 'oidcrp.oidc.userinfo.UserInfo' } } }, @@ -83,7 +82,7 @@ "userinfo_request_method": "GET", 'services': { 'authorization': { - 'class': 'oidcservice.oidc.authorization.Authorization' + 'class': 'oidcrp.oidc.authorization.Authorization' }, 'access_token': { 'class': 'oidcrp.provider.linkedin.AccessToken' @@ -113,14 +112,14 @@ }, 'services': { 'authorization': { - 'class': 'oidcservice.oidc.authorization.Authorization' + 'class': 'oidcrp.oidc.authorization.Authorization' }, 'access_token': { - 'class': 'oidcservice.oidc.access_token.AccessToken', + 'class': 'oidcrp.oidc.access_token.AccessToken', 'kwargs': {'conf': {'default_authn_method': ''}} }, 'userinfo': { - 'class': 'oidcservice.oidc.userinfo.UserInfo', + 'class': 'oidcrp.oidc.userinfo.UserInfo', 'kwargs': {'conf': {'default_authn_method': ''}} } } @@ -146,17 +145,17 @@ }, 'services': { 'authorization': { - 'class': 'oidcservice.oidc.authorization.Authorization' + 'class': 'oidcrp.oidc.authorization.Authorization' }, 'access_token': { - 'class': 'oidcservice.oidc.access_token.AccessToken' + 'class': 'oidcrp.oidc.access_token.AccessToken' }, 'userinfo': { - 'class': 'oidcservice.oidc.userinfo.UserInfo', + 'class': 'oidcrp.oidc.userinfo.UserInfo', 'kwargs': {'conf': {'default_authn_method': ''}} }, 'refresh_access_token': { - 'class': 'oidcservice.oidc.refresh_access_token' + 'class': 'oidcrp.oidc.refresh_access_token' '.RefreshAccessToken' } } @@ -223,10 +222,10 @@ def test_pick_config(self): def test_init_client(self): client = self.rph.init_client('github') - assert set(client.service.keys()) == {'authorization', 'accesstoken', - 'userinfo', 'refresh_token'} + assert set(client.client_get("services").keys()) == {'authorization', 'accesstoken', + 'userinfo', 'refresh_token'} - _context = client.service_context + _context = client.client_get("service_context") assert _context.get('client_id') == 'eeeeeeeee' assert _context.get('client_secret') == 'aaaaaaaaaaaaaaaaaaaa' @@ -266,8 +265,8 @@ def test_do_provider_info(self): # Make sure the service endpoints are set for service_type in ['authorization', 'accesstoken', 'userinfo']: - _srv = client.service[service_type] - _endp = client.service_context.get('provider_info')[_srv.endpoint_name] + _srv = client.client_get("service",service_type) + _endp = client.client_get("service_context").get('provider_info')[_srv.endpoint_name] assert _srv.endpoint == _endp def test_do_client_registration(self): @@ -278,12 +277,12 @@ def test_do_client_registration(self): # only 2 things should have happened assert self.rph.hash2issuer['github'] == issuer - assert client.service_context.get('post_logout_redirect_uris') is None + assert client.client_get("service_context").post_logout_redirect_uris == [] def test_do_client_setup(self): client = self.rph.client_setup('github') _github_id = iss_id('github') - _context = client.service_context + _context = client.client_get("service_context") assert _context.get('client_id') == 'eeeeeeeee' assert _context.get('client_secret') == 'aaaaaaaaaaaaaaaaaaaa' @@ -297,8 +296,8 @@ def test_do_client_setup(self): assert len(keys) == 2 for service_type in ['authorization', 'accesstoken', 'userinfo']: - _srv = client.service[service_type] - _endp = client.service_context.get('provider_info')[_srv.endpoint_name] + _srv = client.client_get("service",service_type) + _endp = _srv.client_get("service_context").get('provider_info')[_srv.endpoint_name] assert _srv.endpoint == _endp assert self.rph.hash2issuer['github'] == _context.get('issuer') @@ -324,7 +323,7 @@ def test_begin(self): client = self.rph.issuer2rp[_github_id] - assert client.service_context.get('issuer') == _github_id + assert client.client_get("service_context").issuer == _github_id part = urlsplit(res['url']) assert part.scheme == 'https' @@ -356,7 +355,7 @@ def test_get_client_from_session_key(self): # redo self.rph.do_provider_info(state=res['state']) # get new redirect_uris - cli2.service_context.redirect_uris = [] + cli2.client_get("service_context").redirect_uris = [] self.rph.do_client_registration(state=res['state']) def test_finalize_auth(self): @@ -368,8 +367,8 @@ def test_finalize_auth(self): state=res['state']) resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict()) assert set(resp.keys()) == {'state', 'code'} - aresp = client.service['authorization'].get_item(AuthorizationResponse, - 'auth_response', res['state']) + aresp = client.client_get("service_context").state.get_item(AuthorizationResponse, 'auth_response', + res['state']) assert set(aresp.keys()) == {'state', 'code'} def test_get_client_authn_method(self): @@ -392,7 +391,7 @@ def test_get_access_token(self): client = self.rph.issuer2rp[_session['iss']] _github_id = iss_id('github') - client.service_context.keyjar.import_jwks( + client.client_get("service_context").keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) _nonce = _session['auth_request']['nonce'] @@ -418,7 +417,7 @@ def test_get_access_token(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.service['accesstoken'].endpoint = _url + client.client_get("service",'accesstoken').endpoint = _url auth_response = AuthorizationResponse(code='access_code', state=res['state']) @@ -430,9 +429,8 @@ def test_get_access_token(self): 'token_type', '__verified_id_token', '__expires_at'} - atresp = client.service['accesstoken'].get_item(AccessTokenResponse, - 'token_response', - res['state']) + atresp = client.client_get("service_context").state.get_item( + AccessTokenResponse, 'token_response', res['state']) assert set(atresp.keys()) == {'access_token', 'expires_in', 'id_token', 'token_type', '__verified_id_token', '__expires_at'} @@ -450,7 +448,7 @@ def test_access_and_id_token(self): } _github_id = iss_id('github') - client.service_context.keyjar.import_jwks( + client.client_get("service_context").keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -468,7 +466,7 @@ def test_access_and_id_token(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.service['accesstoken'].endpoint = _url + client.client_get("service",'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) @@ -491,7 +489,7 @@ def test_access_and_id_token_by_reference(self): } _github_id = iss_id('github') - client.service_context.keyjar.import_jwks( + client.client_get("service_context").keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -509,7 +507,7 @@ def test_access_and_id_token_by_reference(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.service['accesstoken'].endpoint = _url + client.client_get("service",'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) @@ -532,7 +530,7 @@ def test_get_user_info(self): } _github_id = iss_id('github') - client.service_context.keyjar.import_jwks( + client.client_get("service_context").keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -550,7 +548,7 @@ def test_get_user_info(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.service['accesstoken'].endpoint = _url + client.client_get("service",'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) @@ -564,7 +562,7 @@ def test_get_user_info(self): with responses.RequestsMock() as rsps: rsps.add("GET", _url, body='{"sub":"EndUserSubject"}', adding_headers={"Content-Type": "application/json"}, status=200) - client.service['userinfo'].endpoint = _url + client.client_get("service",'userinfo').endpoint = _url userinfo_resp = self.rph.get_user_info(res['state'], client, token_resp['access_token']) @@ -590,16 +588,14 @@ def test_userinfo_in_id_token(self): 'occupation'} - def test_get_provider_specific_service(): - service_context = ServiceContext() srv_desc = { 'access_token': { 'class': 'oidcrp.provider.github.AccessToken' } } - _srv = init_services(srv_desc, service_context) - assert _srv['accesstoken'].response_body_type == 'urlencoded' + entity = Entity(services=srv_desc) + assert entity.client_get("service",'accesstoken').response_body_type == 'urlencoded' class TestRPHandlerTier2(object): @@ -618,7 +614,7 @@ def rphandler_setup(self): } _github_id = iss_id('github') - client.service_context.keyjar.import_jwks( + client.client_get("service_context").keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -638,7 +634,7 @@ def rphandler_setup(self): rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.service['accesstoken'].endpoint = _url + client.client_get("service",'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) @@ -653,7 +649,7 @@ def rphandler_setup(self): rsps.add("GET", _url, body='{"sub":"EndUserSubject"}', adding_headers={"Content-Type": "application/json"}, status=200) - client.service['userinfo'].endpoint = _url + client.client_get("service",'userinfo').endpoint = _url self.rph.get_user_info(res['state'], client, token_resp['access_token']) self.state = res['state'] @@ -681,7 +677,7 @@ def test_refresh_access_token(self): rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.service['refresh_token'].endpoint = _url + client.client_get("service",'refresh_token').endpoint = _url res = self.rph.refresh_access_token(self.state, client, 'openid email') assert res['access_token'] == '2nd_accessTok' @@ -693,7 +689,7 @@ def test_get_user_info(self): with responses.RequestsMock() as rsps: rsps.add("GET", _url, body='{"sub":"EndUserSubject", "mail":"foo@example.com"}', adding_headers={"Content-Type": "application/json"}, status=200) - client.service['userinfo'].endpoint = _url + client.client_get("service",'userinfo').endpoint = _url resp = self.rph.get_user_info(self.state, client) assert set(resp.keys()) == {'sub', 'mail'} @@ -842,7 +838,7 @@ def test_finalize(self): p.path, _info.to_json(), 200, {'content-type': "application/json"}) _github_id = iss_id('github') - client.service_context.keyjar.import_jwks(GITHUB_KEY.export_jwks( + client.client_get("service_context").keyjar.import_jwks(GITHUB_KEY.export_jwks( issuer_id=_github_id), _github_id) # do the rest (= get access token and user info) diff --git a/tests/test_21_pushed_auth.py b/tests/test_21_pushed_auth.py new file mode 100644 index 0000000..c3fb20b --- /dev/null +++ b/tests/test_21_pushed_auth.py @@ -0,0 +1,66 @@ +import json +import os + +from cryptojwt.key_jar import init_key_jar +import pytest +import responses + +from oidcrp.oauth2 import Client +from oidcrp.oauth2 import DEFAULT_OAUTH2_SERVICES + +_dirname = os.path.dirname(os.path.abspath(__file__)) + +ISS = 'https://example.com' + +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLI_KEY = init_key_jar(public_path='{}/pub_client.jwks'.format(_dirname), + private_path='{}/priv_client.jwks'.format(_dirname), + key_defs=KEYSPEC, issuer_id='') + + +class TestPushedAuth: + @pytest.fixture(autouse=True) + def create_client(self): + config = { + 'client_id': 'client_id', 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'behaviour': {'response_types': ['code']}, + 'add_ons': { + "pushed_authorization": { + "function": + "oidcrp.oauth2.add_on.pushed_authorization.add_support", + "kwargs": { + "body_format": "jws", + "signing_algorithm": "RS256", + "http_client": None, + "merge_rule": "lax" + } + } + } + } + self.entity = Client(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) + + self.entity.client_get("service_context").provider_info = { + "pushed_authorization_request_endpoint": "https://as.example.com/push" + } + + def test_authorization(self): + auth_service = self.entity.client_get("service","authorization") + req_args = {'foo': 'bar', "response_type": "code"} + with responses.RequestsMock() as rsps: + _resp = { + "request_uri": "urn:example:bwc4JK-ESC0w8acc191e-Y1LTC2", + "expires_in": 3600 + } + rsps.add("GET", + auth_service.client_get("service_context").provider_info[ + "pushed_authorization_request_endpoint"], + body=json.dumps(_resp), status=200) + + _req = auth_service.construct(request_args=req_args, state='state') + + assert set(_req.keys()) == {"request_uri", "response_type", "client_id"} diff --git a/tests/test_21_rph_defaults.py b/tests/test_21_rph_defaults.py index 1515fac..ac07e72 100644 --- a/tests/test_21_rph_defaults.py +++ b/tests/test_21_rph_defaults.py @@ -7,8 +7,8 @@ from oidcmsg.oidc import ProviderConfigurationResponse from oidcmsg.oidc import RegistrationResponse -from oidcrp import DEFAULT_KEY_DEFS -from oidcrp import RPHandler +from oidcrp.defaults import DEFAULT_KEY_DEFS +from oidcrp.rp_handler import RPHandler BASE_URL = "https://example.com" @@ -24,11 +24,11 @@ def test_pick_config(self): def test_init_client(self): client = self.rph.init_client('') - assert set(client.service.keys()) == { + assert set(client.client_get("services").keys()) == { 'registration', 'provider_info', 'webfinger', 'authorization', 'accesstoken', 'userinfo', 'refresh_token'} - _context = client.service_context + _context = client.client_get("service_context") assert _context.config['client_preferences'] == { 'application_type': 'web', @@ -69,12 +69,14 @@ def test_begin(self): issuer = self.rph.do_provider_info(client) + _context = client.client_get("service_context") + # Calculating request so I can build a reasonable response - self.rph.add_callbacks(client.service_context) - _req = client.service['registration'].construct_request() + self.rph.add_callbacks(_context) + _req = client.client_get("service",'registration').construct_request() with responses.RequestsMock() as rsps: - request_uri = client.service_context.get('provider_info')["registration_endpoint"] + request_uri = _context.get('provider_info')["registration_endpoint"] _jws = RegistrationResponse( client_id="client uno", client_secret="VerySecretAndLongEnough", **_req.to_dict() ).to_json() @@ -83,12 +85,12 @@ def test_begin(self): self.rph.issuer2rp[issuer] = client - assert set(client.service_context.get('behaviour').keys()) == { + assert set(_context.get('behaviour').keys()) == { 'token_endpoint_auth_method', 'response_types', 'scope', 'application_type', 'application_name'} - assert client.service_context.get('client_id') == "client uno" - assert client.service_context.get('client_secret') == "VerySecretAndLongEnough" - assert client.service_context.get('issuer') == ISS_ID + assert _context.get('client_id') == "client uno" + assert _context.get('client_secret') == "VerySecretAndLongEnough" + assert _context.get('issuer') == ISS_ID res = self.rph.init_authorization(client) assert set(res.keys()) == {'url', 'state'} @@ -125,20 +127,21 @@ def test_begin_2(self): issuer = self.rph.do_provider_info(client) + _context = client.client_get("service_context") # Calculating request so I can build a reasonable response - self.rph.add_callbacks(client.service_context) + self.rph.add_callbacks(_context) # Publishing a JWKS instead of a JWKS_URI - client.service_context.jwks_uri = '' - client.service_context.jwks = client.service_context.keyjar.export_jwks() + _context.jwks_uri = '' + _context.jwks = _context.keyjar.export_jwks() - _req = client.service['registration'].construct_request() + _req = client.client_get("service",'registration').construct_request() with responses.RequestsMock() as rsps: - request_uri = client.service_context.get('provider_info')["registration_endpoint"] + request_uri = _context.get('provider_info')["registration_endpoint"] _jws = RegistrationResponse( client_id="client uno", client_secret="VerySecretAndLongEnough", **_req.to_dict() ).to_json() rsps.add("POST", request_uri, body=_jws, status=200) self.rph.do_client_registration(client, ISS_ID) - assert 'jwks' in client.service_context.get('registration_response') \ No newline at end of file + assert 'jwks' in _context.get('registration_response') \ No newline at end of file diff --git a/tests/test_22_config.py b/tests/test_22_config.py index 9839f8a..f9dc91f 100644 --- a/tests/test_22_config.py +++ b/tests/test_22_config.py @@ -1,21 +1,36 @@ import os from oidcrp.configure import Configuration +from oidcrp.configure import RPConfiguration +from oidcrp.configure import create_from_config_file _dirname = os.path.dirname(os.path.abspath(__file__)) def test_yaml_config(): - c = Configuration.create_from_config_file(os.path.join(_dirname, 'conf.yaml')) + c = create_from_config_file(Configuration, + entity_conf=[{"class": RPConfiguration, "attr": "rp"}], + filename=os.path.join(_dirname, 'conf.yaml'), + base_path=_dirname) assert c - assert c.base_url == "https://127.0.0.1:8090" - assert c.domain == "127.0.0.1" - assert c.httpc_params == {"verify": False} - assert c.port == 8090 - assert set(c.services.keys()) == {'discovery', 'registration', 'authorization', 'accesstoken', - 'userinfo', 'end_session'} - assert c.web_conf == { - 'port': 8090, 'domain': '127.0.0.1', 'server_cert': 'certs/cert.pem', - 'server_key': 'certs/key.pem', 'debug': True - } - assert set(c.clients.keys()) == {'', 'bobcat', 'flop'} \ No newline at end of file + assert set(c.web_conf.keys()) == {'port', 'domain', 'server_cert', 'server_key', 'debug'} + + rp_config = c.rp + assert rp_config.base_url == "https://127.0.0.1:8090" + assert rp_config.httpc_params == {"verify": False} + assert set(rp_config.services.keys()) == {'discovery', 'registration', 'authorization', + 'accesstoken', 'userinfo', 'end_session'} + assert set(rp_config.clients.keys()) == {'', 'bobcat', 'flop'} + + +def test_dict(): + configuration = create_from_config_file(RPConfiguration, + filename=os.path.join(_dirname, 'rp_conf.yaml'), + base_path=_dirname) + assert configuration + + assert configuration.base_url == "https://127.0.0.1:8090" + assert configuration.httpc_params == {"verify": False} + assert set(configuration.services.keys()) == {'discovery', 'registration', 'authorization', + 'accesstoken', 'userinfo', 'end_session'} + assert set(configuration.clients.keys()) == {'', 'bobcat', 'flop'} diff --git a/tests/test_31_oauth2_persistent.py b/tests/test_31_oauth2_persistent.py index 61a977c..d7531e9 100644 --- a/tests/test_31_oauth2_persistent.py +++ b/tests/test_31_oauth2_persistent.py @@ -13,8 +13,8 @@ from oidcmsg.oauth2 import ResponseMessage from oidcmsg.oidc import IdToken from oidcmsg.time_util import utc_time_sans_frac -from oidcservice.exception import OidcServiceError -from oidcservice.exception import ParseError +from oidcrp.exception import OidcServiceError +from oidcrp.exception import ParseError from oidcrp.oauth2 import Client @@ -32,6 +32,12 @@ nonce="N0nce", iat=time.time()) +CONF = { + 'issuer': 'https://op.example.com', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'client_id': CLIENT_ID, + 'client_secret': 'abcdefghijklmnop' +} class MockResponse(): @@ -43,67 +49,30 @@ def __init__(self, status_code, text, headers=None): class TestClient(object): - @pytest.fixture(autouse=True) - def create_client(self): - self.redirect_uri = "http://example.com/redirect" - conf = { - 'issuer': 'https://op.example.com', - 'redirect_uris': ['https://example.com/cli/authz_cb'], - 'client_id': 'client_1', - 'client_secret': 'abcdefghijklmnop', - 'db_conf': { - 'keyjar': { - 'handler': 'oidcmsg.storage.abfile.LabeledAbstractFileSystem', - 'fdir': 'db/keyjar', - 'key_conv': 'oidcmsg.storage.converter.QPKey', - 'value_conv': 'cryptojwt.serialize.item.KeyIssuer', - 'label': 'keyjar' - }, - 'default': { - 'handler': 'oidcmsg.storage.abfile.AbstractFileSystem', - 'fdir': 'db', - 'key_conv': 'oidcmsg.storage.converter.QPKey', - 'value_conv': 'oidcmsg.storage.converter.JSON' - } - } - } - self.client = Client(config=conf) - - def test_construct_authorization_request(self): - req_args = { - 'state': 'ABCDE', - 'redirect_uri': 'https://example.com/auth_cb', - 'response_type': ['code'] - } - - self.client.session_interface.create_state('issuer', key='ABCDE') - msg = self.client.service['authorization'].construct( - request_args=req_args) - assert isinstance(msg, AuthorizationRequest) - assert msg['client_id'] == 'client_1' - assert msg['redirect_uri'] == 'https://example.com/auth_cb' - def test_construct_accesstoken_request(self): - # Bind access code to state - req_args = {} - - self.client.session_interface.create_state('issuer', 'ABCDE') + # Client 1 starts the chain of event + client_1 = Client(config=CONF) + _context_1 = client_1.client_get("service_context") + _state = _context_1.state.create_state('issuer') auth_request = AuthorizationRequest( redirect_uri='https://example.com/cli/authz_cb', - state='ABCDE' + state=_state ) - self.client.session_interface.store_item(auth_request, 'auth_request', - 'ABCDE') + _context_1.state.store_item(auth_request, 'auth_request', _state) - auth_response = AuthorizationResponse(code='access_code') + # Client 2 carries on + client_2 = Client(config=CONF) + _state_dump = _context_1.dump() - self.client.session_interface.store_item(auth_response, - 'auth_response', 'ABCDE') + _context2 = client_2.client_get("service_context") + _context2.load(_state_dump) + + auth_response = AuthorizationResponse(code='access_code') + _context2.state.store_item(auth_response,'auth_response', _state) - msg = self.client.service['accesstoken'].construct( - request_args=req_args, state='ABCDE') + msg = client_2.client_get("service",'accesstoken').construct(request_args={}, state=_state) assert isinstance(msg, AccessTokenRequest) assert msg.to_dict() == { @@ -113,34 +82,40 @@ def test_construct_accesstoken_request(self): 'grant_type': 'authorization_code', 'redirect_uri': 'https://example.com/cli/authz_cb', - 'state': 'ABCDE' + 'state': _state } def test_construct_refresh_token_request(self): - self.client.session_interface.create_state('issuer', 'ABCDE') + # Client 1 starts the chain event + client_1 = Client(config=CONF) + _state = client_1.client_get("service_context").state.create_state('issuer') auth_request = AuthorizationRequest( redirect_uri='https://example.com/cli/authz_cb', - state='state' + state=_state ) - self.client.session_interface.store_item(auth_request, 'auth_request', - 'ABCDE') + client_1.client_get("service_context").state.store_item(auth_request, 'auth_request', _state) - auth_response = AuthorizationResponse(code='access_code') + # Client 2 carries on + client_2 = Client(config=CONF) + _state_dump = client_1.client_get("service_context").dump() + client_2.client_get("service_context").load(_state_dump) - self.client.session_interface.store_item(auth_response, - 'auth_response', 'ABCDE') + auth_response = AuthorizationResponse(code='access_code') + client_2.client_get("service_context").state.store_item(auth_response, 'auth_response', _state) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - self.client.session_interface.store_item(token_response, - 'token_response', 'ABCDE') + client_2.client_get("service_context").state.store_item(token_response, 'token_response', _state) + + # Next up is Client 1 + _state_dump = client_2.client_get("service_context").dump() + client_1.client_get("service_context").load(_state_dump) req_args = {} - msg = self.client.service['refresh_token'].construct( - request_args=req_args, state='ABCDE') + msg = client_1.client_get("service",'refresh_token').construct(request_args=req_args, state=_state) assert isinstance(msg, RefreshAccessTokenRequest) assert msg.to_dict() == { 'client_id': 'client_1', @@ -148,29 +123,3 @@ def test_construct_refresh_token_request(self): 'grant_type': 'refresh_token', 'refresh_token': 'refresh_with_me' } - - def test_error_response(self): - err = ResponseMessage(error='Illegal') - http_resp = MockResponse(400, err.to_urlencoded()) - resp = self.client.parse_request_response( - self.client.service['authorization'], http_resp) - - assert resp['error'] == 'Illegal' - assert resp['status_code'] == 400 - - def test_error_response_500(self): - err = ResponseMessage(error='Illegal') - http_resp = MockResponse(500, err.to_urlencoded()) - with pytest.raises(ParseError): - self.client.parse_request_response( - self.client.service['authorization'], http_resp) - - def test_error_response_2(self): - err = ResponseMessage(error='Illegal') - http_resp = MockResponse( - 400, err.to_json(), - headers={'content-type': 'application/x-www-form-urlencoded'}) - - with pytest.raises(OidcServiceError): - self.client.parse_request_response( - self.client.service['authorization'], http_resp) diff --git a/tests/test_32_oidc_persistent.py b/tests/test_32_oidc_persistent.py index aa7cf2a..174b4aa 100755 --- a/tests/test_32_oidc_persistent.py +++ b/tests/test_32_oidc_persistent.py @@ -1,11 +1,7 @@ -import json import os -import shutil import sys import time -import pytest -import responses from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.key_bundle import KeyBundle from oidcmsg.oauth2 import AccessTokenRequest @@ -14,7 +10,6 @@ from oidcmsg.oauth2 import AuthorizationResponse from oidcmsg.oauth2 import RefreshAccessTokenRequest from oidcmsg.oidc import IdToken -from oidcmsg.oidc import OpenIDSchema from oidcmsg.time_util import utc_time_sans_frac from oidcrp.oidc import RP @@ -28,114 +23,89 @@ KC_RSA = KeyBundle({"priv_key": _key, "kty": "RSA", "use": "sig"}) CLIENT_ID = "client_1" -IDTOKEN = IdToken(iss="http://oidc.example.org/", sub="sub", +ISSUER = "http://op.example.com" + +IDTOKEN = IdToken(iss=ISSUER, sub="sub", aud=CLIENT_ID, exp=utc_time_sans_frac() + 86400, nonce="N0nce", iat=time.time()) +CONF = { + 'issuer': ISSUER, + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'client_id': CLIENT_ID, + 'client_secret': 'abcdefghijklmnop' +} + def access_token_callback(endpoint): if endpoint: return 'access_token' - class TestClient(object): - @pytest.fixture(autouse=True) - def create_client(self): - try: - shutil.rmtree('db') - except FileNotFoundError: - pass - - self.redirect_uri = "http://example.com/redirect" - conf = { - 'issuer': 'https://op.example.com', - 'redirect_uris': ['https://example.com/cli/authz_cb'], - 'client_id': 'client_1', - 'client_secret': 'abcdefghijklmnop', - 'db_conf': { - 'keyjar': { - 'handler': 'oidcmsg.storage.abfile.LabeledAbstractFileSystem', - 'fdir': 'db/keyjar', - 'key_conv': 'oidcmsg.storage.converter.QPKey', - 'value_conv': 'cryptojwt.serialize.item.KeyIssuer', - 'label': 'keyjar' - }, - 'default': { - 'handler': 'oidcmsg.storage.abfile.AbstractFileSystem', - 'fdir': 'db', - 'key_conv': 'oidcmsg.storage.converter.QPKey', - 'value_conv': 'oidcmsg.storage.converter.JSON' - } - } - } - self.client = RP(config=conf) - - def test_construct_authorization_request(self): - req_args = { - 'state': 'ABCDE', - 'redirect_uri': 'https://example.com/auth_cb', - 'response_type': ['code'] - } - - self.client.session_interface.create_state('issuer', 'ABCDE') - - msg = self.client.service['authorization'].construct( - request_args=req_args) - assert isinstance(msg, AuthorizationRequest) - assert msg['redirect_uri'] == 'https://example.com/auth_cb' - def test_construct_accesstoken_request(self): + # Client 1 starts + client_1 = RP(config=CONF) + _state = client_1.client_get("service_context").state.create_state(ISSUER) auth_request = AuthorizationRequest( redirect_uri='https://example.com/cli/authz_cb', - state='ABCDE' + state=_state ) + client_1.client_get("service_context").state.store_item(auth_request, 'auth_request', _state) - self.client.session_interface.store_item(auth_request, - 'auth_request', 'ABCDE') + # Client 2 carries on + client_2 = RP(config=CONF) + _state_dump = client_1.client_get("service_context").dump() + client_2.client_get("service_context").load(_state_dump) auth_response = AuthorizationResponse(code='access_code') - - self.client.session_interface.store_item(auth_response, - 'auth_response', 'ABCDE') + client_2.client_get("service_context").state.store_item(auth_response, 'auth_response', _state) # Bind access code to state req_args = {} - msg = self.client.service['accesstoken'].construct( - request_args=req_args, state='ABCDE') + msg = client_2.client_get("service",'accesstoken').construct( + request_args=req_args, state=_state) assert isinstance(msg, AccessTokenRequest) assert msg.to_dict() == { 'client_id': 'client_1', 'code': 'access_code', 'client_secret': 'abcdefghijklmnop', 'grant_type': 'authorization_code', 'redirect_uri': 'https://example.com/cli/authz_cb', - 'state': 'ABCDE' + 'state': _state } def test_construct_refresh_token_request(self): - self.client.session_interface.create_state('issuer', 'ABCDE') + # Client 1 starts + client_1 = RP(config=CONF) + _state = client_1.client_get("service_context").state.create_state(ISSUER) auth_request = AuthorizationRequest( redirect_uri='https://example.com/cli/authz_cb', - state='state' + state=_state ) - self.client.session_interface.store_item(auth_request, - 'auth_request', 'ABCDE') + client_1.client_get("service_context").state.store_item(auth_request, 'auth_request', _state) + + # Client 2 carries on + client_2 = RP(config=CONF) + _state_dump = client_1.client_get("service_context").dump() + client_2.client_get("service_context").load(_state_dump) auth_response = AuthorizationResponse(code='access_code') - self.client.session_interface.store_item(auth_response, - 'auth_response', 'ABCDE') + client_2.client_get("service_context").state.store_item(auth_response, 'auth_response', _state) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - self.client.session_interface.store_item(token_response, - 'token_response', 'ABCDE') + client_2.client_get("service_context").state.store_item(token_response, + 'token_response', _state) + + # Back to Client 1 + _state_dump = client_2.client_get("service_context").dump() + client_1.client_get("service_context").load(_state_dump) req_args = {} - msg = self.client.service['refresh_token'].construct( - request_args=req_args, state='ABCDE') + msg = client_1.client_get("service",'refresh_token').construct(request_args=req_args, state=_state) assert isinstance(msg, RefreshAccessTokenRequest) assert msg.to_dict() == { 'client_id': 'client_1', @@ -145,139 +115,34 @@ def test_construct_refresh_token_request(self): } def test_do_userinfo_request_init(self): - self.client.session_interface.create_state('issuer', 'ABCDE') + # Client 1 starts + client_1 = RP(config=CONF) + _state = client_1.client_get("service_context").state.create_state(ISSUER) auth_request = AuthorizationRequest( redirect_uri='https://example.com/cli/authz_cb', state='state' ) - self.client.session_interface.store_item(auth_request, - 'auth_request', 'ABCDE') + # Client 2 carries on + client_2 = RP(config=CONF) + _state_dump = client_1.client_get("service_context").dump() + client_2.client_get("service_context").load(_state_dump) auth_response = AuthorizationResponse(code='access_code') - self.client.session_interface.store_item(auth_response, - 'auth_response', 'ABCDE') + client_2.client_get("service_context").state.store_item(auth_response, 'auth_response', _state) token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") - self.client.session_interface.store_item(token_response, - 'token_response', 'ABCDE') + client_2.client_get("service_context").state.store_item(token_response, 'token_response', _state) - _srv = self.client.service['userinfo'] + # Back to Client 1 + _state_dump = client_2.client_get("service_context").dump() + client_1.client_get("service_context").load(_state_dump) + + _srv = client_1.client_get("service",'userinfo') _srv.endpoint = "https://example.com/userinfo" - _info = _srv.get_request_parameters(state='ABCDE') + _info = _srv.get_request_parameters(state=_state) assert _info assert _info['headers'] == {'Authorization': 'Bearer access'} assert _info['url'] == 'https://example.com/userinfo' - - def test_fetch_distributed_claims_1(self): - _url = "https://example.com/claims.json" - # split the example in 5.6.2.2 into two - uinfo = OpenIDSchema(**{ - "sub": 'jane_doe', - "name": "Jane Doe", - "given_name": "Jane", - "family_name": "Doe", - "email": "janedoe@example.com", - "birthdate": "0000-03-22", - "eye_color": "blue", - "_claim_names": { - "payment_info": "src1", - "shipping_address": "src1", - }, - "_claim_sources": { - "src1": { - "endpoint": _url - } - } - }) - - # Wrong set of claims. Actually extra claim - _info = { - "shipping_address": { - "street_address": "1234 Hollywood Blvd.", - "locality": "Los Angeles", - "region": "CA", - "postal_code": "90210", - "country": "US" - }, - "payment_info": "Some_Card 1234 5678 9012 3456", - "phone_number": "+1 (310) 123-4567" - } - - with responses.RequestsMock() as rsps: - rsps.add("GET", _url, body=json.dumps(_info), - adding_headers={"Content-Type": "application/json"}, status=200) - - res = self.client.fetch_distributed_claims(uinfo) - - assert 'payment_info' in res - assert 'shipping_address' in res - assert 'phone_number' not in res - - def test_fetch_distributed_claims_2(self): - _url = "https://example.com/claims.json" - - uinfo = OpenIDSchema(**{ - "sub": 'jane_doe', - "name": "Jane Doe", - "given_name": "Jane", - "family_name": "Doe", - "email": "janedoe@example.com", - "birthdate": "0000-03-22", - "eye_color": "blue", - "_claim_names": { - "credit_score": "src2" - }, - "_claim_sources": { - "src2": { - "endpoint": _url, - "access_token": "ksj3n283dke" - } - } - }) - - _claims = { - "credit_score": 650 - } - - with responses.RequestsMock() as rsps: - rsps.add("GET", _url, body=json.dumps(_claims), - adding_headers={"Content-Type": "application/json"}, status=200) - - res = self.client.fetch_distributed_claims(uinfo) - - assert 'credit_score' in res - - def test_fetch_distributed_claims_3(self, httpserver): - _url = "https://example.com/claims.json" - - uinfo = OpenIDSchema(**{ - "sub": 'jane_doe', - "name": "Jane Doe", - "given_name": "Jane", - "family_name": "Doe", - "email": "janedoe@example.com", - "birthdate": "0000-03-22", - "eye_color": "blue", - "_claim_names": { - "credit_score": "src2" - }, - "_claim_sources": { - "src2": { - "endpoint": _url, - } - } - }) - - _claims = {"credit_score": 650} - - with responses.RequestsMock() as rsps: - rsps.add("GET", _url, body=json.dumps(_claims), - adding_headers={"Content-Type": "application/json"}, status=200) - - res = self.client.fetch_distributed_claims( - uinfo, callback=access_token_callback) - - assert 'credit_score' in res diff --git a/tests/test_40_dpop.py b/tests/test_40_dpop.py new file mode 100644 index 0000000..490182a --- /dev/null +++ b/tests/test_40_dpop.py @@ -0,0 +1,73 @@ +import os + +import pytest +from cryptojwt.jws.jws import factory +from cryptojwt.key_jar import init_key_jar + +from oidcrp.client_auth import factory as ca_factory +from oidcrp.oauth2 import Client +from oidcrp.oauth2 import DEFAULT_OAUTH2_SERVICES +from oidcrp.oauth2.add_on import do_add_ons +from oidcrp.service import init_services +from oidcrp.service_context import ServiceContext + +_dirname = os.path.dirname(os.path.abspath(__file__)) + +KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, +] + +CLI_KEY = init_key_jar(public_path='{}/pub_client.jwks'.format(_dirname), + private_path='{}/priv_client.jwks'.format(_dirname), + key_defs=KEYSPEC, issuer_id='client_id') + + +class TestDPoP: + @pytest.fixture(autouse=True) + def create_client(self): + config = { + 'client_id': 'client_id', + 'client_secret': 'a longesh password', + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'behaviour': {'response_types': ['code']}, + 'add_ons': { + "dpop": { + "function": "oidcrp.oauth2.add_on.dpop.add_support", + "kwargs": { + "signing_algorithms": ["ES256", "ES512"] + } + } + } + } + + self.client = Client(keyjar=CLI_KEY, config=config, services=DEFAULT_OAUTH2_SERVICES) + + self.client.client_get("service_context").provider_info= { + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "dpop_signing_alg_values_supported": ["RS256", "ES256"] + } + + def test_add_header(self): + token_serv = self.client.client_get("service","accesstoken") + req_args = { + "grant_type": "authorization_code", + "code": "SplxlOBeZQQYbYS6WxSbIA", + "redirect_uri": "https://client/example.com/cb" + } + headers = token_serv.get_headers(request=req_args, http_method="POST") + assert headers + assert "dpop" in headers + + # Now for the content of the DPoP proof + _jws = factory(headers["dpop"]) + _payload = _jws.jwt.payload() + assert _payload["htu"] == "https://example.com/token" + assert _payload["htm"] == "POST" + _header = _jws.jwt.headers + assert "jwk" in _header + assert _header["typ"] == "dpop+jwt" + assert _header["alg"] == "ES256" + assert _header["jwk"]["kty"] == "EC" + assert _header["jwk"]["crv"] == "P-256" diff --git a/tests/test_40_rp_handler_persistent.py b/tests/test_40_rp_handler_persistent.py index e81cc61..b642511 100644 --- a/tests/test_40_rp_handler_persistent.py +++ b/tests/test_40_rp_handler_persistent.py @@ -1,25 +1,14 @@ -import json import os -import shutil from urllib.parse import parse_qs -from urllib.parse import urlparse from urllib.parse import urlsplit -import pytest -import responses -from cryptojwt.key_jar import KeyJar from cryptojwt.key_jar import init_key_jar from oidcmsg.oidc import AccessTokenResponse from oidcmsg.oidc import AuthorizationResponse from oidcmsg.oidc import IdToken -from oidcmsg.oidc import JRD -from oidcmsg.oidc import Link -from oidcmsg.oidc import OpenIDSchema -from oidcmsg.oidc import ProviderConfigurationResponse -from oidcservice.service import init_services -from oidcservice.service_context import ServiceContext +import responses -from oidcrp import RPHandler +from oidcrp.rp_handler import RPHandler BASE_URL = 'https://example.com/rp' @@ -34,7 +23,6 @@ "verify_args": {"allow_sign_alg_none": True} } - DB_CONF = { 'keyjar': { 'handler': 'oidcmsg.storage.abfile.LabeledAbstractFileSystem', @@ -51,34 +39,33 @@ } } - CLIENT_CONFIG = { "": { "client_preferences": CLIENT_PREFS, "redirect_uris": None, "services": { 'web_finger': { - 'class': 'oidcservice.oidc.webfinger.WebFinger' + 'class': 'oidcrp.oidc.webfinger.WebFinger' }, "discovery": { - 'class': 'oidcservice.oidc.provider_info_discovery' + 'class': 'oidcrp.oidc.provider_info_discovery' '.ProviderInfoDiscovery' }, 'registration': { - 'class': 'oidcservice.oidc.registration.Registration' + 'class': 'oidcrp.oidc.registration.Registration' }, 'authorization': { - 'class': 'oidcservice.oidc.authorization.Authorization' + 'class': 'oidcrp.oidc.authorization.Authorization' }, 'access_token': { - 'class': 'oidcservice.oidc.access_token.AccessToken' + 'class': 'oidcrp.oidc.access_token.AccessToken' }, 'refresh_access_token': { - 'class': 'oidcservice.oidc.refresh_access_token' + 'class': 'oidcrp.oidc.refresh_access_token' '.RefreshAccessToken' }, 'userinfo': { - 'class': 'oidcservice.oidc.userinfo.UserInfo' + 'class': 'oidcrp.oidc.userinfo.UserInfo' } } }, @@ -102,7 +89,7 @@ "userinfo_request_method": "GET", 'services': { 'authorization': { - 'class': 'oidcservice.oidc.authorization.Authorization' + 'class': 'oidcrp.oidc.authorization.Authorization' }, 'access_token': { 'class': 'oidcrp.provider.linkedin.AccessToken' @@ -133,14 +120,14 @@ }, 'services': { 'authorization': { - 'class': 'oidcservice.oidc.authorization.Authorization' + 'class': 'oidcrp.oidc.authorization.Authorization' }, 'access_token': { - 'class': 'oidcservice.oidc.access_token.AccessToken', + 'class': 'oidcrp.oidc.access_token.AccessToken', 'kwargs': {'conf': {'default_authn_method': ''}} }, 'userinfo': { - 'class': 'oidcservice.oidc.userinfo.UserInfo', + 'class': 'oidcrp.oidc.userinfo.UserInfo', 'kwargs': {'conf': {'default_authn_method': ''}} } }, @@ -167,17 +154,17 @@ }, 'services': { 'authorization': { - 'class': 'oidcservice.oidc.authorization.Authorization' + 'class': 'oidcrp.oidc.authorization.Authorization' }, 'access_token': { - 'class': 'oidcservice.oidc.access_token.AccessToken' + 'class': 'oidcrp.oidc.access_token.AccessToken' }, 'userinfo': { - 'class': 'oidcservice.oidc.userinfo.UserInfo', + 'class': 'oidcrp.oidc.userinfo.UserInfo', 'kwargs': {'conf': {'default_authn_method': ''}} }, 'refresh_access_token': { - 'class': 'oidcservice.oidc.refresh_access_token' + 'class': 'oidcrp.oidc.refresh_access_token' '.RefreshAccessToken' } }, @@ -225,92 +212,69 @@ def iss_id(iss): class TestRPHandler(object): - @pytest.fixture(autouse=True) - def rphandler_setup(self): - try: - shutil.rmtree('db') - except FileNotFoundError: - pass - - self.rph = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, - keyjar=CLI_KEY, module_dirs=['oidc']) - def test_pick_config(self): - cnf = self.rph.pick_config('facebook') + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + cnf = rph_1.pick_config('facebook') assert cnf['issuer'] == "https://www.facebook.com/v2.11/dialog/oauth" - cnf = self.rph.pick_config('linkedin') + cnf = rph_1.pick_config('linkedin') assert cnf['issuer'] == "https://www.linkedin.com/oauth/v2/" - cnf = self.rph.pick_config('github') + cnf = rph_1.pick_config('github') assert cnf['issuer'] == "https://github.com/login/oauth/authorize" - cnf = self.rph.pick_config('') + cnf = rph_1.pick_config('') assert 'issuer' not in cnf - def test_init_client(self): - client = self.rph.init_client('github') - assert set(client.service.keys()) == {'authorization', 'accesstoken', - 'userinfo', 'refresh_token'} - - _context = client.service_context - - assert _context.get('client_id') == 'eeeeeeeee' - assert _context.get('client_secret') == 'aaaaaaaaaaaaaaaaaaaa' - assert _context.get('issuer') == "https://github.com/login/oauth/authorize" - - assert _context.get('provider_info') is not None - assert set(_context.get('provider_info').keys()) == { - 'authorization_endpoint', 'token_endpoint', 'userinfo_endpoint' - } - - assert _context.get('behaviour') == { - "response_types": ["code"], - "scope": ["user", "public_repo"], - "token_endpoint_auth_method": '', - 'verify_args': {'allow_sign_alg_none': True} - } - - _github_id = iss_id('github') - _context.keyjar.import_jwks(GITHUB_KEY.export_jwks(issuer_id=_github_id), - _github_id) + def test_do_provider_info(self): + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) - # The key jar should only contain a symmetric key that is the clients - # secret. 2 because one is marked for encryption and the other signing - # usage. + client_1 = rph_1.init_client('github') + issuer = rph_1.do_provider_info(client_1) + assert issuer == iss_id('github') - assert list(_context.keyjar.owners()) == ['', _github_id] - keys = _context.keyjar.get_issuer_keys('') - assert len(keys) == 2 + # Make sure the service endpoints are set - assert _context.base_url == BASE_URL + rph_2 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) - def test_do_provider_info(self): - client = self.rph.init_client('github') - issuer = self.rph.do_provider_info(client) - assert issuer == iss_id('github') + client_2 = rph_2.init_client('github') - # Make sure the service endpoints are set + _context_dump = client_1.client_get("service_context").dump() + client_2.client_get("service_context").load(_context_dump) + _service_dump = client_1.client_get("services").dump() + client_2.client_get("services").load(_service_dump, + init_args={ + "client_get": client_2.client_get + }) for service_type in ['authorization', 'accesstoken', 'userinfo']: - _srv = client.service[service_type] - _endp = client.service_context.get('provider_info')[_srv.endpoint_name] + _srv = client_2.client_get("service",service_type) + _endp = client_2.client_get("service_context").provider_info[_srv.endpoint_name] assert _srv.endpoint == _endp def test_do_client_registration(self): - client = self.rph.init_client('github') - issuer = self.rph.do_provider_info(client) - self.rph.do_client_registration(client, 'github') + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + client = rph_1.init_client('github') + issuer = rph_1.do_provider_info(client) + rph_1.do_client_registration(client, 'github') # only 2 things should have happened - assert self.rph.hash2issuer['github'] == issuer - assert client.service_context.get('post_logout_redirect_uris') is None + assert rph_1.hash2issuer['github'] == issuer + assert not client.client_get("service_context").post_logout_redirect_uris def test_do_client_setup(self): - client = self.rph.client_setup('github') + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + client = rph_1.client_setup('github') _github_id = iss_id('github') - _context = client.service_context + _context = client.client_get("service_context") assert _context.get('client_id') == 'eeeeeeeee' assert _context.get('client_secret') == 'aaaaaaaaaaaaaaaaaaaa' @@ -324,14 +288,17 @@ def test_do_client_setup(self): assert len(keys) == 2 for service_type in ['authorization', 'accesstoken', 'userinfo']: - _srv = client.service[service_type] - _endp = client.service_context.get('provider_info')[_srv.endpoint_name] + _srv = client.client_get("service",service_type) + _endp = client.client_get("service_context").get('provider_info')[_srv.endpoint_name] assert _srv.endpoint == _endp - assert self.rph.hash2issuer['github'] == _context.get('issuer') + assert rph_1.hash2issuer['github'] == _context.get('issuer') def test_create_callbacks(self): - cb = self.rph.create_callbacks('https://op.example.com/') + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + cb = rph_1.create_callbacks('https://op.example.com/') assert set(cb.keys()) == {'code', 'implicit', 'form_post', '__hex'} _hash = cb['__hex'] @@ -340,18 +307,21 @@ def test_create_callbacks(self): assert cb['implicit'] == 'https://example.com/rp/authz_im_cb/{}'.format(_hash) assert cb['form_post'] == 'https://example.com/rp/authz_fp_cb/{}'.format(_hash) - assert list(self.rph.hash2issuer.keys()) == [_hash] + assert list(rph_1.hash2issuer.keys()) == [_hash] - assert self.rph.hash2issuer[_hash] == 'https://op.example.com/' + assert rph_1.hash2issuer[_hash] == 'https://op.example.com/' def test_begin(self): - res = self.rph.begin(issuer_id='github') + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + res = rph_1.begin(issuer_id='github') assert set(res.keys()) == {'url', 'state'} _github_id = iss_id('github') - client = self.rph.issuer2rp[_github_id] + client = rph_1.issuer2rp[_github_id] - assert client.service_context.get('issuer') == _github_id + assert client.client_get("service_context").get('issuer') == _github_id part = urlsplit(res['url']) assert part.scheme == 'https' @@ -370,56 +340,71 @@ def test_begin(self): assert query['scope'] == ['user public_repo openid'] def test_get_session_information(self): - res = self.rph.begin(issuer_id='github') - _session = self.rph.get_session_information(res['state']) - assert self.rph.client_configs['github']['issuer'] == _session['iss'] + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + res = rph_1.begin(issuer_id='github') + _session = rph_1.get_session_information(res['state']) + assert rph_1.client_configs['github']['issuer'] == _session['iss'] def test_get_client_from_session_key(self): - res = self.rph.begin(issuer_id='linkedin') - cli1 = self.rph.get_client_from_session_key(state=res['state']) - _session = self.rph.get_session_information(res['state']) - cli2 = self.rph.issuer2rp[_session['iss']] + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + res = rph_1.begin(issuer_id='linkedin') + cli1 = rph_1.get_client_from_session_key(state=res['state']) + _session = rph_1.get_session_information(res['state']) + cli2 = rph_1.issuer2rp[_session['iss']] assert cli1 == cli2 # redo - self.rph.do_provider_info(state=res['state']) + rph_1.do_provider_info(state=res['state']) # get new redirect_uris - cli2.service_context.redirect_uris = [] - self.rph.do_client_registration(state=res['state']) + cli2.client_get("service_context").redirect_uris = [] + rph_1.do_client_registration(state=res['state']) def test_finalize_auth(self): - res = self.rph.begin(issuer_id='linkedin') - _session = self.rph.get_session_information(res['state']) - client = self.rph.issuer2rp[_session['iss']] + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + res = rph_1.begin(issuer_id='linkedin') + _session = rph_1.get_session_information(res['state']) + client = rph_1.issuer2rp[_session['iss']] auth_response = AuthorizationResponse(code='access_code', state=res['state']) - resp = self.rph.finalize_auth(client, _session['iss'], auth_response.to_dict()) + resp = rph_1.finalize_auth(client, _session['iss'], auth_response.to_dict()) assert set(resp.keys()) == {'state', 'code'} - aresp = client.service['authorization'].get_item(AuthorizationResponse, - 'auth_response', res['state']) + aresp = client.client_get("service",'authorization').client_get("service_context").state.get_item( + AuthorizationResponse, 'auth_response', res['state']) assert set(aresp.keys()) == {'state', 'code'} def test_get_client_authn_method(self): - res = self.rph.begin(issuer_id='github') - _session = self.rph.get_session_information(res['state']) - client = self.rph.issuer2rp[_session['iss']] - authn_method = self.rph.get_client_authn_method(client, 'token_endpoint') + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + res = rph_1.begin(issuer_id='github') + _session = rph_1.get_session_information(res['state']) + client = rph_1.issuer2rp[_session['iss']] + authn_method = rph_1.get_client_authn_method(client, 'token_endpoint') assert authn_method == '' - res = self.rph.begin(issuer_id='linkedin') - _session = self.rph.get_session_information(res['state']) - client = self.rph.issuer2rp[_session['iss']] - authn_method = self.rph.get_client_authn_method(client, - 'token_endpoint') + res = rph_1.begin(issuer_id='linkedin') + _session = rph_1.get_session_information(res['state']) + client = rph_1.issuer2rp[_session['iss']] + authn_method = rph_1.get_client_authn_method(client, + 'token_endpoint') assert authn_method == 'client_secret_post' def test_get_access_token(self): - res = self.rph.begin(issuer_id='github') - _session = self.rph.get_session_information(res['state']) - client = self.rph.issuer2rp[_session['iss']] + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + res = rph_1.begin(issuer_id='github') + _session = rph_1.get_session_information(res['state']) + client = rph_1.issuer2rp[_session['iss']] _github_id = iss_id('github') - client.service_context.keyjar.import_jwks( + client.client_get("service_context").keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) _nonce = _session['auth_request']['nonce'] @@ -445,29 +430,31 @@ def test_get_access_token(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.service['accesstoken'].endpoint = _url + client.client_get("service",'accesstoken').endpoint = _url auth_response = AuthorizationResponse(code='access_code', state=res['state']) - resp = self.rph.finalize_auth(client, _session['iss'], - auth_response.to_dict()) + resp = rph_1.finalize_auth(client, _session['iss'], + auth_response.to_dict()) - resp = self.rph.get_access_token(res['state'], client) + resp = rph_1.get_access_token(res['state'], client) assert set(resp.keys()) == {'access_token', 'expires_in', 'id_token', 'token_type', '__verified_id_token', '__expires_at'} - atresp = client.service['accesstoken'].get_item(AccessTokenResponse, - 'token_response', - res['state']) + atresp = client.client_get("service",'accesstoken').client_get("service_context").state.get_item( + AccessTokenResponse, 'token_response', res['state']) assert set(atresp.keys()) == {'access_token', 'expires_in', 'id_token', 'token_type', '__verified_id_token', '__expires_at'} def test_access_and_id_token(self): - res = self.rph.begin(issuer_id='github') - _session = self.rph.get_session_information(res['state']) - client = self.rph.issuer2rp[_session['iss']] + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + res = rph_1.begin(issuer_id='github') + _session = rph_1.get_session_information(res['state']) + client = rph_1.issuer2rp[_session['iss']] _nonce = _session['auth_request']['nonce'] _iss = _session['iss'] _aud = client.client_id @@ -477,7 +464,7 @@ def test_access_and_id_token(self): } _github_id = iss_id('github') - client.service_context.keyjar.import_jwks( + client.client_get("service_context").keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -495,20 +482,23 @@ def test_access_and_id_token(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.service['accesstoken'].endpoint = _url + client.client_get("service",'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) - auth_response = self.rph.finalize_auth(client, _session['iss'], - _response.to_dict()) - resp = self.rph.get_access_and_id_token(auth_response, client=client) + auth_response = rph_1.finalize_auth(client, _session['iss'], + _response.to_dict()) + resp = rph_1.get_access_and_id_token(auth_response, client=client) assert resp['access_token'] == 'accessTok' assert isinstance(resp['id_token'], IdToken) def test_access_and_id_token_by_reference(self): - res = self.rph.begin(issuer_id='github') - _session = self.rph.get_session_information(res['state']) - client = self.rph.issuer2rp[_session['iss']] + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + res = rph_1.begin(issuer_id='github') + _session = rph_1.get_session_information(res['state']) + client = rph_1.issuer2rp[_session['iss']] _nonce = _session['auth_request']['nonce'] _iss = _session['iss'] _aud = client.client_id @@ -518,7 +508,7 @@ def test_access_and_id_token_by_reference(self): } _github_id = iss_id('github') - client.service_context.keyjar.import_jwks( + client.client_get("service_context").keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -536,20 +526,23 @@ def test_access_and_id_token_by_reference(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.service['accesstoken'].endpoint = _url + client.client_get("service",'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) - _ = self.rph.finalize_auth(client, _session['iss'], - _response.to_dict()) - resp = self.rph.get_access_and_id_token(state=res['state']) + _ = rph_1.finalize_auth(client, _session['iss'], + _response.to_dict()) + resp = rph_1.get_access_and_id_token(state=res['state']) assert resp['access_token'] == 'accessTok' assert isinstance(resp['id_token'], IdToken) def test_get_user_info(self): - res = self.rph.begin(issuer_id='github') - _session = self.rph.get_session_information(res['state']) - client = self.rph.issuer2rp[_session['iss']] + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + res = rph_1.begin(issuer_id='github') + _session = rph_1.get_session_information(res['state']) + client = rph_1.issuer2rp[_session['iss']] _nonce = _session['auth_request']['nonce'] _iss = _session['iss'] _aud = client.client_id @@ -559,7 +552,7 @@ def test_get_user_info(self): } _github_id = iss_id('github') - client.service_context.keyjar.import_jwks( + client.client_get("service_context").keyjar.import_jwks( GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) idts = IdToken(**idval) @@ -577,30 +570,33 @@ def test_get_user_info(self): with responses.RequestsMock() as rsps: rsps.add("POST", _url, body=at.to_json(), adding_headers={"Content-Type": "application/json"}, status=200) - client.service['accesstoken'].endpoint = _url + client.client_get("service",'accesstoken').endpoint = _url _response = AuthorizationResponse(code='access_code', state=res['state']) - auth_response = self.rph.finalize_auth(client, _session['iss'], - _response.to_dict()) + auth_response = rph_1.finalize_auth(client, _session['iss'], + _response.to_dict()) - token_resp = self.rph.get_access_and_id_token(auth_response, - client=client) + token_resp = rph_1.get_access_and_id_token(auth_response, + client=client) _url = "https://github.com/user_info" with responses.RequestsMock() as rsps: rsps.add("GET", _url, body='{"sub":"EndUserSubject"}', adding_headers={"Content-Type": "application/json"}, status=200) - client.service['userinfo'].endpoint = _url + client.client_get("service",'userinfo').endpoint = _url - userinfo_resp = self.rph.get_user_info(res['state'], client, - token_resp['access_token']) + userinfo_resp = rph_1.get_user_info(res['state'], client, + token_resp['access_token']) assert userinfo_resp def test_userinfo_in_id_token(self): - res = self.rph.begin(issuer_id='github') - _session = self.rph.get_session_information(res['state']) - client = self.rph.issuer2rp[_session['iss']] + rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, + keyjar=CLI_KEY, module_dirs=['oidc']) + + res = rph_1.begin(issuer_id='github') + _session = rph_1.get_session_information(res['state']) + client = rph_1.issuer2rp[_session['iss']] _nonce = _session['auth_request']['nonce'] _iss = _session['iss'] _aud = client.client_id @@ -612,326 +608,325 @@ def test_userinfo_in_id_token(self): idts = IdToken(**idval) - userinfo = self.rph.userinfo_in_id_token(idts) + userinfo = rph_1.userinfo_in_id_token(idts) assert set(userinfo.keys()) == {'sub', 'family_name', 'given_name', 'occupation'} - -def test_get_provider_specific_service(): - service_context = ServiceContext() - srv_desc = { - 'access_token': { - 'class': 'oidcrp.provider.github.AccessToken' - } - } - _srv = init_services(srv_desc, service_context) - assert _srv['accesstoken'].response_body_type == 'urlencoded' - - -class TestRPHandlerTier2(object): - @pytest.fixture(autouse=True) - def rphandler_setup(self): - self.rph = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY) - res = self.rph.begin(issuer_id='github') - _session = self.rph.get_session_information(res['state']) - client = self.rph.issuer2rp[_session['iss']] - _nonce = _session['auth_request']['nonce'] - _iss = _session['iss'] - _aud = client.client_id - idval = { - 'nonce': _nonce, 'sub': 'EndUserSubject', 'iss': _iss, - 'aud': _aud - } - - _github_id = iss_id('github') - client.service_context.keyjar.import_jwks( - GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) - - idts = IdToken(**idval) - _signed_jwt = idts.to_jwt( - key=GITHUB_KEY.get_signing_key('rsa', issuer_id=_github_id), - algorithm="RS256", lifetime=300) - - _info = { - "access_token": "accessTok", "id_token": _signed_jwt, - "token_type": "Bearer", "expires_in": 3600, - 'refresh_token': 'refreshing' - } - - at = AccessTokenResponse(**_info) - _url = "https://github.com/token" - with responses.RequestsMock() as rsps: - rsps.add("POST", _url, body=at.to_json(), - adding_headers={"Content-Type": "application/json"}, status=200) - - client.service['accesstoken'].endpoint = _url - - _response = AuthorizationResponse(code='access_code', - state=res['state']) - auth_response = self.rph.finalize_auth(client, _session['iss'], - _response.to_dict()) - - token_resp = self.rph.get_access_and_id_token(auth_response, - client=client) - - _url = "https://github.com/token" - with responses.RequestsMock() as rsps: - rsps.add("GET", _url, body='{"sub":"EndUserSubject"}', - adding_headers={"Content-Type": "application/json"}, status=200) - - client.service['userinfo'].endpoint = _url - self.rph.get_user_info(res['state'], client, - token_resp['access_token']) - self.state = res['state'] - - def test_init_authorization(self): - _session = self.rph.get_session_information(self.state) - client = self.rph.issuer2rp[_session['iss']] - res = self.rph.init_authorization( - client, req_args={'scope': ['openid', 'email']}) - part = urlsplit(res['url']) - _qp = parse_qs(part.query) - assert _qp['scope'] == ['openid email'] - - def test_refresh_access_token(self): - _session = self.rph.get_session_information(self.state) - client = self.rph.issuer2rp[_session['iss']] - - _info = { - "access_token": "2nd_accessTok", - "token_type": "Bearer", "expires_in": 3600 - } - at = AccessTokenResponse(**_info) - _url = "https://github.com/token" - with responses.RequestsMock() as rsps: - rsps.add("POST", _url, body=at.to_json(), - adding_headers={"Content-Type": "application/json"}, status=200) - - client.service['refresh_token'].endpoint = _url - res = self.rph.refresh_access_token(self.state, client, 'openid email') - assert res['access_token'] == '2nd_accessTok' - - def test_get_user_info(self): - _session = self.rph.get_session_information(self.state) - client = self.rph.issuer2rp[_session['iss']] - - _url = "https://github.com/userinfo" - with responses.RequestsMock() as rsps: - rsps.add("GET", _url, body='{"sub":"EndUserSubject", "mail":"foo@example.com"}', - adding_headers={"Content-Type": "application/json"}, status=200) - client.service['userinfo'].endpoint = _url - - resp = self.rph.get_user_info(self.state, client) - assert set(resp.keys()) == {'sub', 'mail'} - assert resp['mail'] == 'foo@example.com' - - def test_has_active_authentication(self): - assert self.rph.has_active_authentication(self.state) - - def test_get_valid_access_token(self): - (token, expires_at) = self.rph.get_valid_access_token(self.state) - assert token == 'accessTok' - assert expires_at > 0 - - -class MockResponse(): - def __init__(self, status_code, text, headers=None): - self.status_code = status_code - self.text = text - self.headers = headers or {} - - -class MockOP(object): - def __init__(self, issuer, keyjar=None): - self.keyjar = keyjar - self.issuer = issuer - self.state = '' - self.nonce = '' - self.get_response = {} - self.register_get_response('default', 'OK', 200) - self.post_response = {} - self.register_post_response('default', 'OK', 200) - - def register_get_response(self, path, data, status_code=200, - headers=None): - _headers = headers or {} - self.get_response[path] = MockResponse(status_code, data, _headers) - - def register_post_response(self, path, data, status_code=200, headers=None): - _headers = headers or {} - self.post_response[path] = MockResponse(status_code, data, _headers) - - def __call__(self, url, method="GET", data=None, headers=None, **kwargs): - if method == 'GET': - p = urlparse(url) - try: - _resp = self.get_response[p.path] - except KeyError: - _resp = self.get_response['default'] - - if callable(_resp.text): - _data = _resp.text(data) - _resp = MockResponse(_resp.status_code, _data, _resp.headers) - - return _resp - elif method == 'POST': - p = urlparse(url) - try: - _resp = self.post_response[p.path] - except KeyError: - _resp = self.post_response['default'] - - if callable(_resp.text): - _data = _resp.text(data) - _resp = MockResponse(_resp.status_code, _data, _resp.headers) - - return _resp - - -def construct_access_token_response(nonce, issuer, client_id, key_jar): - _aud = client_id - - idval = { - 'nonce': nonce, 'sub': 'EndUserSubject', 'iss': issuer, - 'aud': _aud - } - - idts = IdToken(**idval) - _signed_jwt = idts.to_jwt( - key=key_jar.get_signing_key('rsa', issuer_id=issuer), - algorithm="RS256", lifetime=300) - - _info = { - "access_token": "accessTok", "id_token": _signed_jwt, - "token_type": "Bearer", "expires_in": 3600 - } - - return AccessTokenResponse(**_info) - - -def registration_callback(data): - _req = json.loads(data) - # add client_id and client_secret - _req['client_id'] = 'client1' - _req['client_secret'] = "ClientSecretString" - return json.dumps(_req) - - -class TestRPHandlerWithMockOP(object): - @pytest.fixture(autouse=True) - def rphandler_setup(self): - self.issuer = 'https://github.com/login/oauth/authorize' - self.mock_op = MockOP(issuer=self.issuer) - self.rph = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, - http_lib=self.mock_op, keyjar=KeyJar()) - - def test_finalize(self): - auth_query = self.rph.begin(issuer_id='github') - # The authorization query is sent and after successful authentication - client = self.rph.get_client_from_session_key( - state=auth_query['state']) - # register a response - p = urlparse( - CLIENT_CONFIG['github']['provider_info']['authorization_endpoint']) - self.mock_op.register_get_response(p.path, 'Redirect', 302) - - _ = client.http(auth_query['url']) - - # the user is redirected back to the RP with a positive response - auth_response = AuthorizationResponse(code='access_code', - state=auth_query['state']) - - # need session information and the client instance - _session = self.rph.get_session_information(auth_response['state']) - client = self.rph.get_client_from_session_key( - state=auth_response['state']) - - # Faking - resp = construct_access_token_response( - _session['auth_request']['nonce'], issuer=self.issuer, - client_id=CLIENT_CONFIG['github']['client_id'], - key_jar=GITHUB_KEY) - - p = urlparse( - CLIENT_CONFIG['github']['provider_info']['token_endpoint']) - self.mock_op.register_post_response( - p.path, resp.to_json(), 200, {'content-type': "application/json"} - ) - - _info = OpenIDSchema(sub='EndUserSubject', - given_name='Diana', - family_name='Krall', - occupation='Jazz pianist') - p = urlparse( - CLIENT_CONFIG['github']['provider_info']['userinfo_endpoint']) - self.mock_op.register_get_response( - p.path, _info.to_json(), 200, {'content-type': "application/json"}) - - _github_id = iss_id('github') - client.service_context.keyjar.import_jwks(GITHUB_KEY.export_jwks( - issuer_id=_github_id), _github_id) - - # do the rest (= get access token and user info) - # assume code flow - resp = self.rph.finalize(_session['iss'], auth_response.to_dict()) - - assert set(resp.keys()) == {'userinfo', 'state', 'token', 'id_token'} - - def test_dynamic_setup(self): - user_id = 'acct:foobar@example.com' - _link = Link(rel="http://openid.net/specs/connect/1.0/issuer", - href="https://server.example.com") - webfinger_response = JRD(subject=user_id, - links=[_link]) - self.mock_op.register_get_response( - '/.well-known/webfinger', webfinger_response.to_json(), 200, - {'content-type': "application/json"}) - - resp = { - "authorization_endpoint": - "https://server.example.com/connect/authorize", - "issuer": "https://server.example.com", - "subject_types_supported": ['public'], - "token_endpoint": "https://server.example.com/connect/token", - "token_endpoint_auth_methods_supported": ["client_secret_basic", - "private_key_jwt"], - "userinfo_endpoint": "https://server.example.com/connect/user", - "check_id_endpoint": "https://server.example.com/connect/check_id", - "refresh_session_endpoint": - "https://server.example.com/connect/refresh_session", - "end_session_endpoint": - "https://server.example.com/connect/end_session", - "jwks_uri": "https://server.example.com/jwk.json", - "registration_endpoint": - "https://server.example.com/connect/register", - "scopes_supported": ["openid", "profile", "email", "address", - "phone"], - "response_types_supported": ["code", "code id_token", - "token id_token"], - "acrs_supported": ["1", "2", - "http://id.incommon.org/assurance/bronze"], - "user_id_types_supported": ["public", "pairwise"], - "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", - "RSA1_5"], - "id_token_signing_alg_values_supported": ["HS256", "RS256", - "A128CBC", "A128KW", - "RSA1_5"], - "request_object_algs_supported": ["HS256", "RS256", "A128CBC", - "A128KW", - "RSA1_5"] - } - - pcr = ProviderConfigurationResponse(**resp) - self.mock_op.register_get_response( - '/.well-known/openid-configuration', pcr.to_json(), 200, - {'content-type': "application/json"}) - - self.mock_op.register_post_response( - '/connect/register', registration_callback, 200, - {'content-type': "application/json"}) - - auth_query = self.rph.begin(user_id=user_id) - assert auth_query - client = self.rph.issuer2rp["https://server.example.com"] - assert len(client.service_context.keyjar.owners()) == 3 - assert 'client1' in client.service_context.keyjar +# def test_get_provider_specific_service(): +# service_context = ServiceContext() +# srv_desc = { +# 'access_token': { +# 'class': 'oidcrp.provider.github.AccessToken' +# } +# } +# _srv = init_services(srv_desc, service_context) +# assert _srv['accesstoken'].response_body_type == 'urlencoded' +# +# +# class TestRPHandlerTier2(object): +# @pytest.fixture(autouse=True) +# def rphandler_setup(self): +# rph_1 = RPHandler(BASE_URL, CLIENT_CONFIG, keyjar=CLI_KEY) +# res = rph_1.begin(issuer_id='github') +# _session = rph_1.get_session_information(res['state']) +# client = rph_1.issuer2rp[_session['iss']] +# _nonce = _session['auth_request']['nonce'] +# _iss = _session['iss'] +# _aud = client.client_id +# idval = { +# 'nonce': _nonce, 'sub': 'EndUserSubject', 'iss': _iss, +# 'aud': _aud +# } +# +# _github_id = iss_id('github') +# client.client_get("service_context").keyjar.import_jwks( +# GITHUB_KEY.export_jwks(issuer_id=_github_id), _github_id) +# +# idts = IdToken(**idval) +# _signed_jwt = idts.to_jwt( +# key=GITHUB_KEY.get_signing_key('rsa', issuer_id=_github_id), +# algorithm="RS256", lifetime=300) +# +# _info = { +# "access_token": "accessTok", "id_token": _signed_jwt, +# "token_type": "Bearer", "expires_in": 3600, +# 'refresh_token': 'refreshing' +# } +# +# at = AccessTokenResponse(**_info) +# _url = "https://github.com/token" +# with responses.RequestsMock() as rsps: +# rsps.add("POST", _url, body=at.to_json(), +# adding_headers={"Content-Type": "application/json"}, status=200) +# +# client.service['accesstoken'].endpoint = _url +# +# _response = AuthorizationResponse(code='access_code', +# state=res['state']) +# auth_response = rph_1.finalize_auth(client, _session['iss'], +# _response.to_dict()) +# +# token_resp = rph_1.get_access_and_id_token(auth_response, +# client=client) +# +# _url = "https://github.com/token" +# with responses.RequestsMock() as rsps: +# rsps.add("GET", _url, body='{"sub":"EndUserSubject"}', +# adding_headers={"Content-Type": "application/json"}, status=200) +# +# client.service['userinfo'].endpoint = _url +# rph_1.get_user_info(res['state'], client, +# token_resp['access_token']) +# self.state = res['state'] +# +# def test_init_authorization(self): +# _session = rph_1.get_session_information(self.state) +# client = rph_1.issuer2rp[_session['iss']] +# res = rph_1.init_authorization( +# client, req_args={'scope': ['openid', 'email']}) +# part = urlsplit(res['url']) +# _qp = parse_qs(part.query) +# assert _qp['scope'] == ['openid email'] +# +# def test_refresh_access_token(self): +# _session = rph_1.get_session_information(self.state) +# client = rph_1.issuer2rp[_session['iss']] +# +# _info = { +# "access_token": "2nd_accessTok", +# "token_type": "Bearer", "expires_in": 3600 +# } +# at = AccessTokenResponse(**_info) +# _url = "https://github.com/token" +# with responses.RequestsMock() as rsps: +# rsps.add("POST", _url, body=at.to_json(), +# adding_headers={"Content-Type": "application/json"}, status=200) +# +# client.service['refresh_token'].endpoint = _url +# res = rph_1.refresh_access_token(self.state, client, 'openid email') +# assert res['access_token'] == '2nd_accessTok' +# +# def test_get_user_info(self): +# _session = rph_1.get_session_information(self.state) +# client = rph_1.issuer2rp[_session['iss']] +# +# _url = "https://github.com/userinfo" +# with responses.RequestsMock() as rsps: +# rsps.add("GET", _url, body='{"sub":"EndUserSubject", "mail":"foo@example.com"}', +# adding_headers={"Content-Type": "application/json"}, status=200) +# client.service['userinfo'].endpoint = _url +# +# resp = rph_1.get_user_info(self.state, client) +# assert set(resp.keys()) == {'sub', 'mail'} +# assert resp['mail'] == 'foo@example.com' +# +# def test_has_active_authentication(self): +# assert rph_1.has_active_authentication(self.state) +# +# def test_get_valid_access_token(self): +# (token, expires_at) = rph_1.get_valid_access_token(self.state) +# assert token == 'accessTok' +# assert expires_at > 0 +# +# +# class MockResponse(): +# def __init__(self, status_code, text, headers=None): +# self.status_code = status_code +# self.text = text +# self.headers = headers or {} +# +# +# class MockOP(object): +# def __init__(self, issuer, keyjar=None): +# self.keyjar = keyjar +# self.issuer = issuer +# self.state = '' +# self.nonce = '' +# self.get_response = {} +# self.register_get_response('default', 'OK', 200) +# self.post_response = {} +# self.register_post_response('default', 'OK', 200) +# +# def register_get_response(self, path, data, status_code=200, +# headers=None): +# _headers = headers or {} +# self.get_response[path] = MockResponse(status_code, data, _headers) +# +# def register_post_response(self, path, data, status_code=200, headers=None): +# _headers = headers or {} +# self.post_response[path] = MockResponse(status_code, data, _headers) +# +# def __call__(self, url, method="GET", data=None, headers=None, **kwargs): +# if method == 'GET': +# p = urlparse(url) +# try: +# _resp = self.get_response[p.path] +# except KeyError: +# _resp = self.get_response['default'] +# +# if callable(_resp.text): +# _data = _resp.text(data) +# _resp = MockResponse(_resp.status_code, _data, _resp.headers) +# +# return _resp +# elif method == 'POST': +# p = urlparse(url) +# try: +# _resp = self.post_response[p.path] +# except KeyError: +# _resp = self.post_response['default'] +# +# if callable(_resp.text): +# _data = _resp.text(data) +# _resp = MockResponse(_resp.status_code, _data, _resp.headers) +# +# return _resp +# +# +# def construct_access_token_response(nonce, issuer, client_id, key_jar): +# _aud = client_id +# +# idval = { +# 'nonce': nonce, 'sub': 'EndUserSubject', 'iss': issuer, +# 'aud': _aud +# } +# +# idts = IdToken(**idval) +# _signed_jwt = idts.to_jwt( +# key=key_jar.get_signing_key('rsa', issuer_id=issuer), +# algorithm="RS256", lifetime=300) +# +# _info = { +# "access_token": "accessTok", "id_token": _signed_jwt, +# "token_type": "Bearer", "expires_in": 3600 +# } +# +# return AccessTokenResponse(**_info) +# +# +# def registration_callback(data): +# _req = json.loads(data) +# # add client_id and client_secret +# _req['client_id'] = 'client1' +# _req['client_secret'] = "ClientSecretString" +# return json.dumps(_req) +# +# +# class TestRPHandlerWithMockOP(object): +# @pytest.fixture(autouse=True) +# def rphandler_setup(self): +# self.issuer = 'https://github.com/login/oauth/authorize' +# self.mock_op = MockOP(issuer=self.issuer) +# rph_1 = RPHandler(BASE_URL, client_configs=CLIENT_CONFIG, +# http_lib=self.mock_op, keyjar=KeyJar()) +# +# def test_finalize(self): +# auth_query = rph_1.begin(issuer_id='github') +# # The authorization query is sent and after successful authentication +# client = rph_1.get_client_from_session_key( +# state=auth_query['state']) +# # register a response +# p = urlparse( +# CLIENT_CONFIG['github']['provider_info']['authorization_endpoint']) +# self.mock_op.register_get_response(p.path, 'Redirect', 302) +# +# _ = client.http(auth_query['url']) +# +# # the user is redirected back to the RP with a positive response +# auth_response = AuthorizationResponse(code='access_code', +# state=auth_query['state']) +# +# # need session information and the client instance +# _session = rph_1.get_session_information(auth_response['state']) +# client = rph_1.get_client_from_session_key( +# state=auth_response['state']) +# +# # Faking +# resp = construct_access_token_response( +# _session['auth_request']['nonce'], issuer=self.issuer, +# client_id=CLIENT_CONFIG['github']['client_id'], +# key_jar=GITHUB_KEY) +# +# p = urlparse( +# CLIENT_CONFIG['github']['provider_info']['token_endpoint']) +# self.mock_op.register_post_response( +# p.path, resp.to_json(), 200, {'content-type': "application/json"} +# ) +# +# _info = OpenIDSchema(sub='EndUserSubject', +# given_name='Diana', +# family_name='Krall', +# occupation='Jazz pianist') +# p = urlparse( +# CLIENT_CONFIG['github']['provider_info']['userinfo_endpoint']) +# self.mock_op.register_get_response( +# p.path, _info.to_json(), 200, {'content-type': "application/json"}) +# +# _github_id = iss_id('github') +# client.client_get("service_context").keyjar.import_jwks(GITHUB_KEY.export_jwks( +# issuer_id=_github_id), _github_id) +# +# # do the rest (= get access token and user info) +# # assume code flow +# resp = rph_1.finalize(_session['iss'], auth_response.to_dict()) +# +# assert set(resp.keys()) == {'userinfo', 'state', 'token', 'id_token'} +# +# def test_dynamic_setup(self): +# user_id = 'acct:foobar@example.com' +# _link = Link(rel="http://openid.net/specs/connect/1.0/issuer", +# href="https://server.example.com") +# webfinger_response = JRD(subject=user_id, +# links=[_link]) +# self.mock_op.register_get_response( +# '/.well-known/webfinger', webfinger_response.to_json(), 200, +# {'content-type': "application/json"}) +# +# resp = { +# "authorization_endpoint": +# "https://server.example.com/connect/authorize", +# "issuer": "https://server.example.com", +# "subject_types_supported": ['public'], +# "token_endpoint": "https://server.example.com/connect/token", +# "token_endpoint_auth_methods_supported": ["client_secret_basic", +# "private_key_jwt"], +# "userinfo_endpoint": "https://server.example.com/connect/user", +# "check_id_endpoint": "https://server.example.com/connect/check_id", +# "refresh_session_endpoint": +# "https://server.example.com/connect/refresh_session", +# "end_session_endpoint": +# "https://server.example.com/connect/end_session", +# "jwks_uri": "https://server.example.com/jwk.json", +# "registration_endpoint": +# "https://server.example.com/connect/register", +# "scopes_supported": ["openid", "profile", "email", "address", +# "phone"], +# "response_types_supported": ["code", "code id_token", +# "token id_token"], +# "acrs_supported": ["1", "2", +# "http://id.incommon.org/assurance/bronze"], +# "user_id_types_supported": ["public", "pairwise"], +# "userinfo_algs_supported": ["HS256", "RS256", "A128CBC", "A128KW", +# "RSA1_5"], +# "id_token_signing_alg_values_supported": ["HS256", "RS256", +# "A128CBC", "A128KW", +# "RSA1_5"], +# "request_object_algs_supported": ["HS256", "RS256", "A128CBC", +# "A128KW", +# "RSA1_5"] +# } +# +# pcr = ProviderConfigurationResponse(**resp) +# self.mock_op.register_get_response( +# '/.well-known/openid-configuration', pcr.to_json(), 200, +# {'content-type': "application/json"}) +# +# self.mock_op.register_post_response( +# '/connect/register', registration_callback, 200, +# {'content-type': "application/json"}) +# +# auth_query = rph_1.begin(user_id=user_id) +# assert auth_query +# client = rph_1.issuer2rp["https://server.example.com"] +# assert len(client.client_get("service_context").keyjar.owners()) == 3 +# assert 'client1' in client.client_get("service_context").keyjar diff --git a/chrp/README.txt b/unsupported/chrp/README.txt similarity index 100% rename from chrp/README.txt rename to unsupported/chrp/README.txt diff --git a/flask_rp/certs/cert.pem b/unsupported/chrp/certs/cert.pem similarity index 100% rename from flask_rp/certs/cert.pem rename to unsupported/chrp/certs/cert.pem diff --git a/flask_rp/certs/key.pem b/unsupported/chrp/certs/key.pem similarity index 100% rename from flask_rp/certs/key.pem rename to unsupported/chrp/certs/key.pem diff --git a/chrp/conf.py b/unsupported/chrp/conf.py similarity index 78% rename from chrp/conf.py rename to unsupported/chrp/conf.py index 9f24470..e2bfbdb 100644 --- a/chrp/conf.py +++ b/unsupported/chrp/conf.py @@ -26,19 +26,19 @@ SERVICES = ['ProviderInfoDiscovery', 'Registration', 'Authorization', 'AccessToken', 'RefreshAccessToken', 'UserInfo'] -SERVICES_DICT = {'accesstoken': {'class': 'oidcservice.oidc.access_token.AccessToken', +SERVICES_DICT = {'accesstoken': {'class': 'oidcrp.oidc.access_token.AccessToken', 'kwargs': {}}, - 'authorization': {'class': 'oidcservice.oidc.authorization.Authorization', + 'authorization': {'class': 'oidcrp.oidc.authorization.Authorization', 'kwargs': {}}, - 'discovery': {'class': 'oidcservice.oidc.provider_info_discovery.ProviderInfoDiscovery', + 'discovery': {'class': 'oidcrp.oidc.provider_info_discovery.ProviderInfoDiscovery', 'kwargs': {}}, - 'end_session': {'class': 'oidcservice.oidc.end_session.EndSession', + 'end_session': {'class': 'oidcrp.oidc.end_session.EndSession', 'kwargs': {}}, - 'refresh_accesstoken': {'class': 'oidcservice.oidc.refresh_access_token.RefreshAccessToken', + 'refresh_accesstoken': {'class': 'oidcrp.oidc.refresh_access_token.RefreshAccessToken', 'kwargs': {}}, - 'registration': {'class': 'oidcservice.oidc.registration.Registration', + 'registration': {'class': 'oidcrp.oidc.registration.Registration', 'kwargs': {}}, - 'userinfo': {'class': 'oidcservice.oidc.userinfo.UserInfo', 'kwargs': {}}} + 'userinfo': {'class': 'oidcrp.oidc.userinfo.UserInfo', 'kwargs': {}}} CLIENT_PREFS = { "application_type": "web", diff --git a/chrp/config.py b/unsupported/chrp/config.py similarity index 100% rename from chrp/config.py rename to unsupported/chrp/config.py diff --git a/chrp/cprp.py b/unsupported/chrp/cprp.py similarity index 100% rename from chrp/cprp.py rename to unsupported/chrp/cprp.py diff --git a/chrp/example_conf.py b/unsupported/chrp/example_conf.py similarity index 100% rename from chrp/example_conf.py rename to unsupported/chrp/example_conf.py diff --git a/chrp/html/opbyuid.html b/unsupported/chrp/html/opbyuid.html similarity index 100% rename from chrp/html/opbyuid.html rename to unsupported/chrp/html/opbyuid.html diff --git a/chrp/html/opresult.html b/unsupported/chrp/html/opresult.html similarity index 100% rename from chrp/html/opresult.html rename to unsupported/chrp/html/opresult.html diff --git a/chrp/html/repost_fragment.html b/unsupported/chrp/html/repost_fragment.html similarity index 100% rename from chrp/html/repost_fragment.html rename to unsupported/chrp/html/repost_fragment.html diff --git a/chrp/jwks_dir/jwks.json b/unsupported/chrp/jwks_dir/jwks.json similarity index 100% rename from chrp/jwks_dir/jwks.json rename to unsupported/chrp/jwks_dir/jwks.json diff --git a/chrp/make_opbyuid_html.py b/unsupported/chrp/make_opbyuid_html.py similarity index 100% rename from chrp/make_opbyuid_html.py rename to unsupported/chrp/make_opbyuid_html.py diff --git a/chrp/rp.py b/unsupported/chrp/rp.py similarity index 100% rename from chrp/rp.py rename to unsupported/chrp/rp.py diff --git a/chrp/static/jwks.json b/unsupported/chrp/static/jwks.json similarity index 100% rename from chrp/static/jwks.json rename to unsupported/chrp/static/jwks.json diff --git a/unsupported/chrp/utils.py b/unsupported/chrp/utils.py new file mode 100644 index 0000000..e69de29