diff --git a/doc/source/add_on/pkce.rst b/doc/source/add_on/pkce.rst index 36e85c2..0400d9a 100644 --- a/doc/source/add_on/pkce.rst +++ b/doc/source/add_on/pkce.rst @@ -8,8 +8,6 @@ Proof Key for Code Exchange Introduction ------------ - - OAuth 2.0 public clients utilizing the Authorization Code Grant are susceptible to the authorization code interception attack. `RFC7636`_ describes the attack as well as a technique to mitigate diff --git a/example/flask_rp/wsgi.py b/example/flask_rp/wsgi.py index be25906..5377046 100755 --- a/example/flask_rp/wsgi.py +++ b/example/flask_rp/wsgi.py @@ -3,9 +3,10 @@ import os import sys +from oidcmsg.configure import create_from_config_file + from oidcrp.configure import Configuration from oidcrp.configure import RPConfiguration -from oidcrp.configure import create_from_config_file from oidcrp.util import create_context try: diff --git a/setup.py b/setup.py index 17cc3e1..ba39a70 100755 --- a/setup.py +++ b/setup.py @@ -74,7 +74,7 @@ def run_tests(self): "Programming Language :: Python :: 3.9", "Topic :: Software Development :: Libraries :: Python Modules"], install_requires=[ - 'oidcmsg>=1.5.3', + 'oidcmsg==1.5.4', 'pyyaml>=5.1.2', 'responses' ], diff --git a/src/oidcrp/__init__.py b/src/oidcrp/__init__.py index 1f6bf90..0a60c74 100644 --- a/src/oidcrp/__init__.py +++ b/src/oidcrp/__init__.py @@ -1,7 +1,7 @@ import logging __author__ = 'Roland Hedberg' -__version__ = '2.1.1' +__version__ = '2.1.2' logger = logging.getLogger(__name__) diff --git a/src/oidcrp/configure.py b/src/oidcrp/configure.py index 84ac367..800d24f 100755 --- a/src/oidcrp/configure.py +++ b/src/oidcrp/configure.py @@ -7,6 +7,8 @@ from typing import List from typing import Optional +from oidcmsg.configure import Base + from oidcrp.logging import configure_logging from oidcrp.util import load_yaml_config from oidcrp.util import lower_or_upper @@ -16,89 +18,6 @@ except ImportError: from cryptojwt import rndstr as rnd_token -DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key', 'server_cert', 'filename', 'template_dir', - 'private_path', 'public_path', 'db_file'] - - -def add_base_path(conf: dict, base_path: str, file_attributes: List[str]): - for key, val in conf.items(): - if key in file_attributes: - if val.startswith("/"): - continue - elif val == "": - conf[key] = "./" + val - else: - conf[key] = os.path.join(base_path, val) - if isinstance(val, dict): - conf[key] = add_base_path(val, base_path, file_attributes) - - return conf - - -def set_domain_and_port(conf: dict, uris: List[str], domain: str, port: int): - for key, val in conf.items(): - if key in uris: - if not val: - continue - - if isinstance(val, list): - _new = [v.format(domain=domain, port=port) for v in val] - else: - _new = val.format(domain=domain, port=port) - conf[key] = _new - elif isinstance(val, dict): - conf[key] = set_domain_and_port(val, uris, domain, port) - return conf - - -class Base: - """ Configuration base class """ - - def __init__(self, - conf: Dict, - base_path: str = '', - file_attributes: Optional[List[str]] = None, - ): - - if file_attributes is None: - file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES - - if base_path and file_attributes: - # this adds a base path to all paths in the configuration - add_base_path(conf, base_path, file_attributes) - - def __getitem__(self, item): - if item in self.__dict__: - return self.__dict__[item] - else: - raise KeyError - - def get(self, item, default=None): - return getattr(self, item, default) - - def __contains__(self, item): - return item in self.__dict__ - - def items(self): - for key in self.__dict__: - if key.startswith('__') and key.endswith('__'): - continue - yield key, getattr(self, key) - - def extend(self, entity_conf, conf, base_path, file_attributes, domain, port): - for econf in entity_conf: - _path = econf.get("path") - _cnf = conf - if _path: - for step in _path: - _cnf = _cnf[step] - _attr = econf["attr"] - _cls = econf["class"] - setattr(self, _attr, - _cls(_cnf, base_path=base_path, file_attributes=file_attributes, - domain=domain, port=port)) - - URIS = [ "redirect_uris", 'post_logout_redirect_uris', 'frontchannel_logout_uri', 'backchannel_logout_uri', 'issuer', 'base_url'] @@ -112,23 +31,17 @@ def __init__(self, domain: Optional[str] = "127.0.0.1", port: Optional[int] = 80, file_attributes: Optional[List[str]] = None, + dir_attributes: Optional[List[str]] = None, ): - Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes) - - _keys_conf = lower_or_upper(conf, 'rp_keys') - if _keys_conf is None: - _keys_conf = lower_or_upper(conf, 'oidc_keys') # legacy - - self.keys = _keys_conf + Base.__init__(self, conf, + base_path=base_path, + domain=domain, + port=port, + file_attributes=file_attributes, + dir_attributes=dir_attributes) - if not domain: - domain = conf.get("domain", "127.0.0.1") - - if not port: - port = conf.get("port", 80) - - conf = set_domain_and_port(conf, URIS, domain, port) + self.key_conf = lower_or_upper(conf, 'rp_keys') or lower_or_upper(conf, 'oidc_keys') self.clients = lower_or_upper(conf, "clients") hash_seed = lower_or_upper(conf, 'hash_seed') @@ -155,8 +68,10 @@ def __init__(self, file_attributes: Optional[List[str]] = None, domain: Optional[str] = "", port: Optional[int] = 0, + dir_attributes: Optional[List[str]] = None, ): - Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes) + Base.__init__(self, conf, base_path=base_path, file_attributes=file_attributes, + dir_attributes=dir_attributes) log_conf = conf.get('logging') if log_conf: @@ -166,40 +81,35 @@ def __init__(self, self.web_conf = lower_or_upper(conf, "webserver") - # entity info - if not domain: - domain = conf.get("domain", "127.0.0.1") - - if not port: - port = conf.get("port", 80) - if entity_conf: self.extend(entity_conf=entity_conf, conf=conf, base_path=base_path, - file_attributes=file_attributes, domain=domain, port=port) - - -def create_from_config_file(cls, - filename: str, - base_path: Optional[str] = '', - entity_conf: Optional[List[dict]] = None, - file_attributes: Optional[List[str]] = None, - domain: Optional[str] = "", - port: Optional[int] = 0): - if filename.endswith(".yaml"): - """Load configuration as YAML""" - _cnf = load_yaml_config(filename) - elif filename.endswith(".json"): - _str = open(filename).read() - _cnf = json.loads(_str) - elif filename.endswith(".py"): - head, tail = os.path.split(filename) - tail = tail[:-3] - module = importlib.import_module(tail) - _cnf = getattr(module, "CONFIG") - else: - raise ValueError("Unknown file type") - - return cls(_cnf, - entity_conf=entity_conf, - base_path=base_path, file_attributes=file_attributes, - domain=domain, port=port) + file_attributes=file_attributes, domain=domain, port=port, + dir_attributes=dir_attributes) + + +# def create_from_config_file(cls, +# filename: str, +# base_path: Optional[str] = '', +# entity_conf: Optional[List[dict]] = None, +# file_attributes: Optional[List[str]] = None, +# dir_attributes: Optional[List[str]] = None, +# domain: Optional[str] = "", +# port: Optional[int] = 0): +# if filename.endswith(".yaml"): +# """Load configuration as YAML""" +# _cnf = load_yaml_config(filename) +# elif filename.endswith(".json"): +# _str = open(filename).read() +# _cnf = json.loads(_str) +# elif filename.endswith(".py"): +# head, tail = os.path.split(filename) +# tail = tail[:-3] +# module = importlib.import_module(tail) +# _cnf = getattr(module, "CONFIG") +# else: +# raise ValueError("Unknown file type") +# +# return cls(_cnf, +# entity_conf=entity_conf, +# base_path=base_path, file_attributes=file_attributes, +# domain=domain, port=port, dir_attributes=dir_attributes) diff --git a/src/oidcrp/rp_handler.py b/src/oidcrp/rp_handler.py index 957ab2d..63356ae 100644 --- a/src/oidcrp/rp_handler.py +++ b/src/oidcrp/rp_handler.py @@ -21,7 +21,7 @@ from oidcmsg.oidc import OpenIDSchema from oidcmsg.oidc import RegistrationRequest from oidcmsg.oidc.session import BackChannelLogoutRequest -from oidcmsg.time_util import time_sans_frac +from oidcmsg.time_util import utc_time_sans_frac from . import oidc from .defaults import DEFAULT_CLIENT_CONFIGS @@ -836,7 +836,7 @@ def has_active_authentication(self, state): ['auth_response', 'token_response', 'refresh_token_response']) if _arg: - _now = time_sans_frac() + _now = utc_time_sans_frac() exp = _arg['__verified_id_token']['exp'] return _now < exp else: @@ -854,7 +854,7 @@ def get_valid_access_token(self, state): exp = 0 token = None indefinite = [] - now = time_sans_frac() + now = utc_time_sans_frac() client = self.get_client_from_session_key(state) _context = client.client_get("service_context") diff --git a/tests/pub_client.jwks b/tests/pub_client.jwks index d16a636..a57e904 100644 --- a/tests/pub_client.jwks +++ b/tests/pub_client.jwks @@ -1 +1 @@ -{"keys": [{"kty": "RSA", "use": "sig", "kid": "SUswNi1MRFlDT0Y2YjU1Z1RfQlo2S3dEa3FTTkV3LThFcnhDTHF5elk2VQ", "e": "AQAB", "n": "0UkUx2ewKyc-XJ1o0ToyGjws_JybAMZj2oYjsPyyvQ_T5dhZ2VmRRRkhsaVJ2xE_GGc7mSG0IjmGFyXp5y0w4mJBcsAEE5-8eBTvQdYIryjW74r3jt6Fi4Hlm1yFMTie3apv8mw79BUj-jT0kh3_m-FiKKUvLsq45DcLtTJ4cx7Ize37dl1sFSpQcoYMk7eiUEM8fiNboiVwvBYNAWVMkUM-LnVUPm3UjvKp0LihYEkZFWOxmuQmj2x25SFUkjus38ERrRqJQBZduxdBHFrWtWg8yOA53BkMU0FFg_r0H3ctl-5GaKw-BWlogU4qXnsq85xy0EoenRk7FPV8g_ulJw"}, {"kty": "EC", "use": "sig", "kid": "NC1pdGRQN002bWM3bk1xX2R0SktscElqbFdtN29ITDV2WVd2b0hOYzREVQ", "crv": "P-256", "x": "kK7Qp1woSerI7rUOAwW_4sU6ZmwV3wwXKX3VU-v2fMI", "y": "iPWd_Pjq6EjxYy08KNFZ3PxhEwgWHgAQTTknlKMKJA0"}]} \ No newline at end of file +{"keys": [{"kty": "RSA", "use": "sig", "kid": "SUswNi1MRFlDT0Y2YjU1Z1RfQlo2S3dEa3FTTkV3LThFcnhDTHF5elk2VQ", "n": "0UkUx2ewKyc-XJ1o0ToyGjws_JybAMZj2oYjsPyyvQ_T5dhZ2VmRRRkhsaVJ2xE_GGc7mSG0IjmGFyXp5y0w4mJBcsAEE5-8eBTvQdYIryjW74r3jt6Fi4Hlm1yFMTie3apv8mw79BUj-jT0kh3_m-FiKKUvLsq45DcLtTJ4cx7Ize37dl1sFSpQcoYMk7eiUEM8fiNboiVwvBYNAWVMkUM-LnVUPm3UjvKp0LihYEkZFWOxmuQmj2x25SFUkjus38ERrRqJQBZduxdBHFrWtWg8yOA53BkMU0FFg_r0H3ctl-5GaKw-BWlogU4qXnsq85xy0EoenRk7FPV8g_ulJw", "e": "AQAB"}, {"kty": "EC", "use": "sig", "kid": "NC1pdGRQN002bWM3bk1xX2R0SktscElqbFdtN29ITDV2WVd2b0hOYzREVQ", "crv": "P-256", "x": "kK7Qp1woSerI7rUOAwW_4sU6ZmwV3wwXKX3VU-v2fMI", "y": "iPWd_Pjq6EjxYy08KNFZ3PxhEwgWHgAQTTknlKMKJA0"}]} \ No newline at end of file diff --git a/tests/test_11_oauth2.py b/tests/test_11_oauth2.py index d3d6b3a..fca1d93 100644 --- a/tests/test_11_oauth2.py +++ b/tests/test_11_oauth2.py @@ -4,6 +4,7 @@ from cryptojwt.jwk.rsa import import_private_rsa_key_from_file from cryptojwt.key_bundle import KeyBundle +from oidcmsg.configure import create_from_config_file from oidcmsg.oauth2 import AccessTokenRequest from oidcmsg.oauth2 import AccessTokenResponse from oidcmsg.oauth2 import AuthorizationRequest @@ -14,6 +15,7 @@ from oidcmsg.time_util import utc_time_sans_frac import pytest +from oidcrp.configure import RPConfiguration from oidcrp.exception import OidcServiceError from oidcrp.exception import ParseError from oidcrp.oauth2 import Client @@ -60,7 +62,7 @@ def test_construct_authorization_request(self): } self.client.client_get("service_context").state.create_state('issuer', key='ABCDE') - msg = self.client.client_get("service",'authorization').construct(request_args=req_args) + msg = self.client.client_get("service", 'authorization').construct(request_args=req_args) assert isinstance(msg, AuthorizationRequest) assert msg['client_id'] == 'client_1' assert msg['redirect_uri'] == 'https://example.com/auth_cb' @@ -81,9 +83,9 @@ def test_construct_accesstoken_request(self): auth_response = AuthorizationResponse(code='access_code') self.client.client_get("service_context").state.store_item(auth_response, - 'auth_response', 'ABCDE') + 'auth_response', 'ABCDE') - msg = self.client.client_get("service",'accesstoken').construct( + msg = self.client.client_get("service", 'accesstoken').construct( request_args=req_args, state='ABCDE') assert isinstance(msg, AccessTokenRequest) @@ -105,11 +107,11 @@ def test_construct_refresh_token_request(self): state='state' ) - _context.state.store_item(auth_request, 'auth_request','ABCDE') + _context.state.store_item(auth_request, 'auth_request', 'ABCDE') auth_response = AuthorizationResponse(code='access_code') - _context.state.store_item(auth_response,'auth_response', 'ABCDE') + _context.state.store_item(auth_response, 'auth_response', 'ABCDE') token_response = AccessTokenResponse(refresh_token="refresh_with_me", access_token="access") @@ -117,7 +119,7 @@ def test_construct_refresh_token_request(self): _context.state.store_item(token_response, 'token_response', 'ABCDE') req_args = {} - msg = self.client.client_get("service",'refresh_token').construct( + msg = self.client.client_get("service", 'refresh_token').construct( request_args=req_args, state='ABCDE') assert isinstance(msg, RefreshAccessTokenRequest) assert msg.to_dict() == { @@ -131,7 +133,7 @@ def test_error_response(self): err = ResponseMessage(error='Illegal') http_resp = MockResponse(400, err.to_urlencoded()) resp = self.client.parse_request_response( - self.client.client_get("service",'authorization'), http_resp) + self.client.client_get("service", 'authorization'), http_resp) assert resp['error'] == 'Illegal' assert resp['status_code'] == 400 @@ -141,7 +143,7 @@ def test_error_response_500(self): http_resp = MockResponse(500, err.to_urlencoded()) with pytest.raises(ParseError): self.client.parse_request_response( - self.client.client_get("service",'authorization'), http_resp) + self.client.client_get("service", 'authorization'), http_resp) def test_error_response_2(self): err = ResponseMessage(error='Illegal') @@ -151,4 +153,42 @@ def test_error_response_2(self): with pytest.raises(OidcServiceError): self.client.parse_request_response( - self.client.client_get("service",'authorization'), http_resp) + self.client.client_get("service", 'authorization'), http_resp) + + +class TestClient2(object): + @pytest.fixture(autouse=True) + def create_client(self): + self.redirect_uri = "http://example.com/redirect" + KEYSPEC = [ + {"type": "RSA", "use": ["sig"]}, + {"type": "EC", "crv": "P-256", "use": ["sig"]}, + ] + + conf = { + 'redirect_uris': ['https://example.com/cli/authz_cb'], + 'client_id': 'client_1', + 'client_secret': 'abcdefghijklmnop', + 'rp_keys': { + 'private_path': 'private/jwks.json', + 'key_defs': KEYSPEC, + 'public_path': 'static/jwks.json', + # this will create the jwks files if they are absent + 'read_only': False + } + } + rp_conf = RPConfiguration(conf) + self.client = Client(config=rp_conf) + assert self.client + + def test_keyjar(self): + req_args = { + 'state': 'ABCDE', + 'redirect_uri': 'https://example.com/auth_cb', + 'response_type': ['code'] + } + + _context = self.client.client_get("service_context") + assert len(_context.keyjar) == 1 # one issuer + assert len(_context.keyjar[""]) == 2 + assert len(_context.keyjar.get("sig")) == 2 \ No newline at end of file diff --git a/tests/test_17_read_registration.py b/tests/test_17_read_registration.py index ca01f0e..3e6bb84 100644 --- a/tests/test_17_read_registration.py +++ b/tests/test_17_read_registration.py @@ -1,15 +1,13 @@ import json import time -import pytest -import responses from cryptojwt.utils import as_bytes from oidcmsg.oidc import RegistrationResponse +import pytest +import responses from oidcrp.entity import Entity import requests -from oidcrp.service_context import ServiceContext -from oidcrp.service_factory import service_factory ISS = "https://example.com" RP_BASEURL = "https://example.com/rp" @@ -44,8 +42,8 @@ def create_request(self): self.entity = Entity(config=client_config, services=services) - self.reg_service = self.entity.client_get("service",'registration') - self.read_service = self.entity.client_get("service",'registration_read') + self.reg_service = self.entity.client_get("service", 'registration') + self.read_service = self.entity.client_get("service", 'registration_read') def test_construct(self): self.reg_service.endpoint = "{}/registration".format(ISS) @@ -70,7 +68,8 @@ def test_construct(self): }) with responses.RequestsMock() as rsps: - rsps.add(_param["method"], _param["url"], body=_client_registration_response, status=200) + rsps.add(_param["method"], _param["url"], body=_client_registration_response, + status=200) _resp = requests.request( _param["method"], _param["url"], data=as_bytes(_param["body"]), diff --git a/tests/test_20_rp_handler_oidc.py b/tests/test_20_rp_handler_oidc.py index 1191cd9..8c4dde1 100644 --- a/tests/test_20_rp_handler_oidc.py +++ b/tests/test_20_rp_handler_oidc.py @@ -4,7 +4,6 @@ from urllib.parse import urlparse from urllib.parse import urlsplit -from cryptojwt.key_jar import KeyJar from cryptojwt.key_jar import init_key_jar from oidcmsg.oidc import AccessTokenResponse from oidcmsg.oidc import AuthorizationResponse @@ -312,7 +311,8 @@ def test_do_client_registration(self): # only 2 things should have happened assert self.rph.hash2issuer['github'] == issuer - assert client.client_get("service_context").callback.get("post_logout_redirect_uris") is None + assert client.client_get("service_context").callback.get( + "post_logout_redirect_uris") is None def test_do_client_setup(self): client = self.rph.client_setup('github') diff --git a/tests/test_22_config.py b/tests/test_22_config.py index f9dc91f..bc29a52 100644 --- a/tests/test_22_config.py +++ b/tests/test_22_config.py @@ -1,8 +1,9 @@ import os +from oidcmsg.configure import create_from_config_file + from oidcrp.configure import Configuration from oidcrp.configure import RPConfiguration -from oidcrp.configure import create_from_config_file _dirname = os.path.dirname(os.path.abspath(__file__))