diff --git a/.clang-format b/.clang-format
new file mode 100644
index 00000000..b2bb93db
--- /dev/null
+++ b/.clang-format
@@ -0,0 +1,17 @@
+# A clang-format style that approximates Python's PEP 7
+BasedOnStyle: Google
+AlwaysBreakAfterReturnType: All
+AllowShortIfStatementsOnASingleLine: false
+AlignAfterOpenBracket: Align
+BreakBeforeBraces: Stroustrup
+ColumnLimit: 95
+DerivePointerAlignment: false
+IndentWidth: 4
+Language: Cpp
+PointerAlignment: Right
+ReflowComments: true
+SpaceBeforeParens: ControlStatements
+SpacesInParentheses: false
+TabWidth: 4
+UseTab: Never
+SortIncludes: false
diff --git a/.clangd b/.clangd
new file mode 100644
index 00000000..6c88d686
--- /dev/null
+++ b/.clangd
@@ -0,0 +1,4 @@
+Diagnostics:
+ Includes:
+ IgnoreHeader:
+ - "pythoncapi_compat.*\\.h"
diff --git a/.coveragerc b/.coveragerc
deleted file mode 100644
index 081835d3..00000000
--- a/.coveragerc
+++ /dev/null
@@ -1,12 +0,0 @@
-[run]
-branch = True
-plugins = Cython.Coverage
-source =
- asyncpg/
- tests/
-omit =
- *.pxd
-
-[paths]
-source =
- asyncpg
diff --git a/.flake8 b/.flake8
index 9697fc96..d4e76b7a 100644
--- a/.flake8
+++ b/.flake8
@@ -1,3 +1,5 @@
[flake8]
+select = C90,E,F,W,Y0
ignore = E402,E731,W503,W504,E252
-exclude = .git,__pycache__,build,dist,.eggs,.github,.local
+exclude = .git,__pycache__,build,dist,.eggs,.github,.local,.venv*,.tox
+per-file-ignores = *.pyi: F401,F403,F405,F811,E127,E128,E203,E266,E301,E302,E305,E501,E701,E704,E741,B303,W503,W504
diff --git a/.github/release_log.py b/.github/release_log.py
index 0e3ee7f4..717cd6f6 100755
--- a/.github/release_log.py
+++ b/.github/release_log.py
@@ -45,10 +45,7 @@ def main():
print(f'* {first_line}')
print(f' (by {username} in {sha}', end='')
- if issue_num:
- print(f' for #{issue_num})')
- else:
- print(')')
+ print(')')
print()
diff --git a/.github/workflows/install-krb5.sh b/.github/workflows/install-krb5.sh
new file mode 100755
index 00000000..bdb5744d
--- /dev/null
+++ b/.github/workflows/install-krb5.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+set -Eexuo pipefail
+shopt -s nullglob
+
+if [[ $OSTYPE == linux* ]]; then
+ if [ "$(id -u)" = "0" ]; then
+ SUDO=
+ else
+ SUDO=sudo
+ fi
+
+ if [ -e /etc/os-release ]; then
+ source /etc/os-release
+ elif [ -e /etc/centos-release ]; then
+ ID="centos"
+ VERSION_ID=$(cat /etc/centos-release | cut -f3 -d' ' | cut -f1 -d.)
+ else
+ echo "install-krb5.sh: cannot determine which Linux distro this is" >&2
+ exit 1
+ fi
+
+ if [ "${ID}" = "debian" -o "${ID}" = "ubuntu" ]; then
+ export DEBIAN_FRONTEND=noninteractive
+
+ $SUDO apt-get update
+ $SUDO apt-get install -y --no-install-recommends \
+ libkrb5-dev krb5-user krb5-kdc krb5-admin-server
+ elif [ "${ID}" = "almalinux" ]; then
+ $SUDO dnf install -y krb5-server krb5-workstation krb5-libs krb5-devel
+ elif [ "${ID}" = "centos" ]; then
+ $SUDO yum install -y krb5-server krb5-workstation krb5-libs krb5-devel
+ elif [ "${ID}" = "alpine" ]; then
+ $SUDO apk add krb5 krb5-server krb5-dev
+ else
+ echo "install-krb5.sh: Unsupported linux distro: ${distro}" >&2
+ exit 1
+ fi
+else
+ echo "install-krb5.sh: unsupported OS: ${OSTYPE}" >&2
+ exit 1
+fi
diff --git a/.github/workflows/install-postgres.sh b/.github/workflows/install-postgres.sh
index 70d42f60..733c7033 100755
--- a/.github/workflows/install-postgres.sh
+++ b/.github/workflows/install-postgres.sh
@@ -3,42 +3,60 @@
set -Eexuo pipefail
shopt -s nullglob
-PGVERSION=${PGVERSION:-12}
+if [[ $OSTYPE == linux* ]]; then
+ PGVERSION=${PGVERSION:-12}
-if [ -e /etc/os-release ]; then
- source /etc/os-release
-elif [ -e /etc/centos-release ]; then
- ID="centos"
- VERSION_ID=$(cat /etc/centos-release | cut -f3 -d' ' | cut -f1 -d.)
-else
- echo "install-postgres.sh: cannot determine which Linux distro this is" >&2
- exit 1
-fi
+ if [ -e /etc/os-release ]; then
+ source /etc/os-release
+ elif [ -e /etc/centos-release ]; then
+ ID="centos"
+ VERSION_ID=$(cat /etc/centos-release | cut -f3 -d' ' | cut -f1 -d.)
+ else
+ echo "install-postgres.sh: cannot determine which Linux distro this is" >&2
+ exit 1
+ fi
-if [ "${ID}" = "debian" -o "${ID}" = "ubuntu" ]; then
- export DEBIAN_FRONTEND=noninteractive
-
- apt-get install -y --no-install-recommends curl gnupg ca-certificates
- curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
- mkdir -p /etc/apt/sources.list.d/
- echo "deb https://apt.postgresql.org/pub/repos/apt/ ${VERSION_CODENAME}-pgdg main" \
- >> /etc/apt/sources.list.d/pgdg.list
- apt-get update
- apt-get install -y --no-install-recommends \
- "postgresql-${PGVERSION}" \
- "postgresql-contrib-${PGVERSION}"
-elif [ "${ID}" = "centos" ]; then
- el="EL-${VERSION_ID}-$(arch)"
- baseurl="https://download.postgresql.org/pub/repos/yum/reporpms"
- yum install -y "${baseurl}/${el}/pgdg-redhat-repo-latest.noarch.rpm"
- if [ ${VERSION_ID} -ge 8 ]; then
- dnf -qy module disable postgresql
+ if [ "${ID}" = "debian" -o "${ID}" = "ubuntu" ]; then
+ export DEBIAN_FRONTEND=noninteractive
+
+ apt-get install -y --no-install-recommends curl gnupg ca-certificates
+ curl https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add -
+ mkdir -p /etc/apt/sources.list.d/
+ echo "deb https://apt.postgresql.org/pub/repos/apt/ ${VERSION_CODENAME}-pgdg main" \
+ >> /etc/apt/sources.list.d/pgdg.list
+ apt-get update
+ apt-get install -y --no-install-recommends \
+ "postgresql-${PGVERSION}" \
+ "postgresql-contrib-${PGVERSION}"
+ elif [ "${ID}" = "almalinux" ]; then
+ yum install -y \
+ "postgresql-server" \
+ "postgresql-devel" \
+ "postgresql-contrib"
+ elif [ "${ID}" = "centos" ]; then
+ el="EL-${VERSION_ID%.*}-$(arch)"
+ baseurl="https://download.postgresql.org/pub/repos/yum/reporpms"
+ yum install -y "${baseurl}/${el}/pgdg-redhat-repo-latest.noarch.rpm"
+ if [ ${VERSION_ID%.*} -ge 8 ]; then
+ dnf -qy module disable postgresql
+ fi
+ yum install -y \
+ "postgresql${PGVERSION}-server" \
+ "postgresql${PGVERSION}-contrib"
+ ln -s "/usr/pgsql-${PGVERSION}/bin/pg_config" "/usr/local/bin/pg_config"
+ elif [ "${ID}" = "alpine" ]; then
+ apk add shadow postgresql postgresql-dev postgresql-contrib
+ else
+ echo "install-postgres.sh: unsupported Linux distro: ${distro}" >&2
+ exit 1
fi
- yum install -y \
- "postgresql${PGVERSION}-server" \
- "postgresql${PGVERSION}-contrib"
- ln -s "/usr/pgsql-${PGVERSION}/bin/pg_config" "/usr/local/bin/pg_config"
+
+ useradd -m -s /bin/bash apgtest
+
+elif [[ $OSTYPE == darwin* ]]; then
+ brew install postgresql
+
else
- echo "install-postgres.sh: Unsupported distro: ${distro}" >&2
+ echo "install-postgres.sh: unsupported OS: ${OSTYPE}" >&2
exit 1
fi
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index e388e7bb..353ed824 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -22,7 +22,7 @@ jobs:
github_token: ${{ secrets.RELEASE_BOT_GITHUB_TOKEN }}
version_file: asyncpg/_version.py
version_line_pattern: |
- __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
+ __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
- name: Stop if not approved
if: steps.checkver.outputs.approved != 'true'
@@ -37,10 +37,10 @@ jobs:
mkdir -p dist/
echo "${VERSION}" > dist/VERSION
- - uses: actions/upload-artifact@v2
+ - uses: actions/upload-artifact@v4
with:
- name: dist
- path: dist/
+ name: dist-version
+ path: dist/VERSION
build-sdist:
needs: validate-release-request
@@ -50,32 +50,60 @@ jobs:
PIP_DISABLE_PIP_VERSION_CHECK: 1
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v5
with:
fetch-depth: 50
submodules: true
+ persist-credentials: false
- name: Set up Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v6
+ with:
+ python-version: "3.x"
- name: Build source distribution
run: |
pip install -U setuptools wheel pip
python setup.py sdist
- - uses: actions/upload-artifact@v2
+ - uses: actions/upload-artifact@v4
with:
- name: dist
+ name: dist-sdist
path: dist/*.tar.*
- build-wheels:
+ build-wheels-matrix:
needs: validate-release-request
+ runs-on: ubuntu-latest
+ outputs:
+ include: ${{ steps.set-matrix.outputs.include }}
+ steps:
+ - uses: actions/checkout@v5
+ with:
+ persist-credentials: false
+ - uses: actions/setup-python@v6
+ with:
+ python-version: "3.x"
+ - run: pip install cibuildwheel==3.3.0
+ - id: set-matrix
+ run: |
+ MATRIX_INCLUDE=$(
+ {
+ cibuildwheel --print-build-identifiers --platform linux --archs x86_64,aarch64 | grep cp | jq -nRc '{"only": inputs, "os": "ubuntu-latest"}' \
+ && cibuildwheel --print-build-identifiers --platform macos --archs x86_64,arm64 | grep cp | jq -nRc '{"only": inputs, "os": "macos-latest"}' \
+ && cibuildwheel --print-build-identifiers --platform windows --archs x86,AMD64 | grep cp | jq -nRc '{"only": inputs, "os": "windows-latest"}'
+ } | jq -sc
+ )
+ echo "include=$MATRIX_INCLUDE" >> $GITHUB_OUTPUT
+
+ build-wheels:
+ needs: build-wheels-matrix
runs-on: ${{ matrix.os }}
+ name: Build ${{ matrix.only }}
+
strategy:
+ fail-fast: false
matrix:
- os: [ubuntu-latest, macos-latest, windows-latest]
- cibw_python: ["cp37-*", "cp38-*", "cp39-*", "cp310-*"]
- cibw_arch: ["auto64"]
+ include: ${{ fromJson(needs.build-wheels-matrix.outputs.include) }}
defaults:
run:
@@ -85,37 +113,39 @@ jobs:
PIP_DISABLE_PIP_VERSION_CHECK: 1
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v5
with:
fetch-depth: 50
submodules: true
+ persist-credentials: false
- - uses: pypa/cibuildwheel@v2.1.1
+ - name: Set up QEMU
+ if: runner.os == 'Linux'
+ uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 # v3.6.0
+
+ - uses: pypa/cibuildwheel@63fd63b352a9a8bdcc24791c9dbee952ee9a8abc # v3.3.0
+ with:
+ only: ${{ matrix.only }}
env:
CIBW_BUILD_VERBOSITY: 1
- CIBW_BUILD: ${{ matrix.cibw_python }}
- CIBW_ARCHS: ${{ matrix.cibw_arch }}
- CIBW_BEFORE_ALL_LINUX: >
- yum -y install libffi-devel
- && env PGVERSION=12 .github/workflows/install-postgres.sh
- && useradd -m -s /bin/bash apgtest
- CIBW_TEST_EXTRAS: "test"
- CIBW_TEST_COMMAND: >
- python {project}/tests/__init__.py
- CIBW_TEST_COMMAND_WINDOWS: >
- python {project}\tests\__init__.py
- CIBW_TEST_COMMAND_LINUX: >
- PY=`which python`
- && chmod -R go+rX "$(dirname $(dirname $(dirname $PY)))"
- && su -p -l apgtest -c "$PY {project}/tests/__init__.py"
-
- - uses: actions/upload-artifact@v2
+
+ - uses: actions/upload-artifact@v4
with:
- name: dist
+ name: dist-wheels-${{ matrix.only }}
path: wheelhouse/*.whl
+ merge-artifacts:
+ runs-on: ubuntu-latest
+ needs: [build-sdist, build-wheels]
+ steps:
+ - name: Merge Artifacts
+ uses: actions/upload-artifact/merge@v4
+ with:
+ name: dist
+ delete-merged: true
+
publish-docs:
- needs: validate-release-request
+ needs: [build-sdist, build-wheels]
runs-on: ubuntu-latest
env:
@@ -123,27 +153,30 @@ jobs:
steps:
- name: Checkout source
- uses: actions/checkout@v2
+ uses: actions/checkout@v5
with:
fetch-depth: 5
submodules: true
+ persist-credentials: false
- name: Set up Python
- uses: actions/setup-python@v2
+ uses: actions/setup-python@v6
with:
- python-version: 3.8
+ python-version: "3.x"
- name: Build docs
run: |
- pip install -e .[dev]
+ pip install --group docs
+ pip install -e .
make htmldocs
- name: Checkout gh-pages
- uses: actions/checkout@v2
+ uses: actions/checkout@v5
with:
fetch-depth: 5
ref: gh-pages
path: docs/gh-pages
+ persist-credentials: false
- name: Sync docs
run: |
@@ -164,13 +197,23 @@ jobs:
needs: [build-sdist, build-wheels, publish-docs]
runs-on: ubuntu-latest
+ environment:
+ name: pypi
+ url: https://pypi.org/p/asyncpg
+ permissions:
+ id-token: write
+ attestations: write
+ contents: write
+ deployments: write
+
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v5
with:
fetch-depth: 5
submodules: false
+ persist-credentials: false
- - uses: actions/download-artifact@v2
+ - uses: actions/download-artifact@v4
with:
name: dist
path: dist/
@@ -179,7 +222,7 @@ jobs:
id: relver
run: |
set -e
- echo ::set-output name=version::$(cat dist/VERSION)
+ echo "version=$(cat dist/VERSION)" >> $GITHUB_OUTPUT
rm dist/VERSION
- name: Merge and tag the PR
@@ -205,9 +248,6 @@ jobs:
ls -al dist/
- name: Upload to PyPI
- uses: pypa/gh-action-pypi-publish@master
+ uses: pypa/gh-action-pypi-publish@release/v1
with:
- user: __token__
- password: ${{ secrets.PYPI_TOKEN }}
- # password: ${{ secrets.TEST_PYPI_TOKEN }}
- # repository_url: https://test.pypi.org/legacy/
+ attestations: true
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index defe9d7a..77e63738 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -17,19 +17,18 @@ jobs:
# job.
strategy:
matrix:
- python-version: [3.6, 3.7, 3.8, 3.9, 3.10.0-rc.1]
+ python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14", "3.14t"]
os: [ubuntu-latest, macos-latest, windows-latest]
loop: [asyncio, uvloop]
exclude:
- # uvloop does not support Python 3.6
- - loop: uvloop
- python-version: 3.6
# uvloop does not support windows
- loop: uvloop
os: windows-latest
runs-on: ${{ matrix.os }}
+ permissions: {}
+
defaults:
run:
shell: bash
@@ -38,10 +37,11 @@ jobs:
PIP_DISABLE_PIP_VERSION_CHECK: 1
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v5
with:
fetch-depth: 50
submodules: true
+ persist-credentials: false
- name: Check if release PR.
uses: edgedb/action-release/validate-pr@master
@@ -51,46 +51,58 @@ jobs:
missing_version_ok: yes
version_file: asyncpg/_version.py
version_line_pattern: |
- __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
+ __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
+
+ - name: Setup PostgreSQL
+ if: "!steps.release.outputs.is_release && matrix.os == 'macos-latest'"
+ run: |
+ POSTGRES_FORMULA="postgresql@18"
+ brew install "$POSTGRES_FORMULA"
+ echo "$(brew --prefix "$POSTGRES_FORMULA")/bin" >> $GITHUB_PATH
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- if: steps.release.outputs.version == 0
+ uses: actions/setup-python@v6
+ if: "!steps.release.outputs.is_release"
with:
python-version: ${{ matrix.python-version }}
- name: Install Python Deps
- if: steps.release.outputs.version == 0
+ if: "!steps.release.outputs.is_release"
run: |
+ [ "$RUNNER_OS" = "Linux" ] && .github/workflows/install-krb5.sh
python -m pip install -U pip setuptools wheel
- python -m pip install -e .[test]
+ python -m pip install --group test
+ python -m pip install -e .
- name: Test
- if: steps.release.outputs.version == 0
+ if: "!steps.release.outputs.is_release"
env:
LOOP_IMPL: ${{ matrix.loop }}
run: |
if [ "${LOOP_IMPL}" = "uvloop" ]; then
- env USE_UVLOOP=1 python setup.py test
+ env USE_UVLOOP=1 python -m unittest -v tests.suite
else
- python setup.py test
+ python -m unittest -v tests.suite
fi
test-postgres:
strategy:
matrix:
- postgres-version: [9.5, 9.6, 10, 11, 12, 13]
+ postgres-version: ["9.5", "9.6", "10", "11", "12", "13", "14", "15", "16", "17", "18"]
runs-on: ubuntu-latest
+ permissions: {}
+
env:
PIP_DISABLE_PIP_VERSION_CHECK: 1
steps:
- - uses: actions/checkout@v2
+ - uses: actions/checkout@v5
with:
fetch-depth: 50
submodules: true
+ persist-credentials: false
- name: Check if release PR.
uses: edgedb/action-release/validate-pr@master
@@ -100,10 +112,10 @@ jobs:
missing_version_ok: yes
version_file: asyncpg/_version.py
version_line_pattern: |
- __version__\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
+ __version__(?:\s*:\s*typing\.Final)?\s*=\s*(?:['"])([[:PEP440:]])(?:['"])
- name: Set up PostgreSQL
- if: steps.release.outputs.version == 0
+ if: "!steps.release.outputs.is_release"
env:
PGVERSION: ${{ matrix.postgres-version }}
DISTRO_NAME: focal
@@ -114,23 +126,25 @@ jobs:
>> "${GITHUB_ENV}"
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v2
- if: steps.release.outputs.version == 0
+ uses: actions/setup-python@v6
+ if: "!steps.release.outputs.is_release"
with:
- python-version: ${{ matrix.python-version }}
+ python-version: "3.x"
- name: Install Python Deps
- if: steps.release.outputs.version == 0
+ if: "!steps.release.outputs.is_release"
run: |
- python -m pip install -U pip setuptools
- pip install -e .[test]
+ [ "$RUNNER_OS" = "Linux" ] && .github/workflows/install-krb5.sh
+ python -m pip install -U pip setuptools wheel
+ python -m pip install --group test
+ python -m pip install -e .
- name: Test
- if: steps.release.outputs.version == 0
+ if: "!steps.release.outputs.is_release"
env:
PGVERSION: ${{ matrix.postgres-version }}
run: |
- python setup.py test
+ python -m unittest -v tests.suite
# This job exists solely to act as the test job aggregate to be
# targeted by branch policies.
@@ -138,6 +152,7 @@ jobs:
name: "Regression Tests"
needs: [test-platforms, test-postgres]
runs-on: ubuntu-latest
+ permissions: {}
steps:
- run: echo OK
diff --git a/.gitignore b/.gitignore
index 21286094..ec9c96ac 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,4 +33,8 @@ docs/_build
/.pytest_cache/
/.eggs
/.vscode
+/.zed
/.mypy_cache
+/.venv*
+/.tox
+/compile_commands.json
diff --git a/MANIFEST.in b/MANIFEST.in
index 2389f6fa..3eac0565 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,5 @@
recursive-include docs *.py *.rst Makefile *.css
recursive-include examples *.py
recursive-include tests *.py *.pem
-recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.c *.h
+recursive-include asyncpg *.pyx *.pxd *.pxi *.py *.pyi *.c *.h
include LICENSE README.rst Makefile performance.png .flake8
diff --git a/Makefile b/Makefile
index 9ad5d2e7..67417a3f 100644
--- a/Makefile
+++ b/Makefile
@@ -12,6 +12,7 @@ clean:
rm -fr dist/ doc/_build/
rm -fr asyncpg/pgproto/*.c asyncpg/pgproto/*.html
rm -fr asyncpg/pgproto/codecs/*.html
+ rm -fr asyncpg/pgproto/*.so
rm -fr asyncpg/protocol/*.c asyncpg/protocol/*.html
rm -fr asyncpg/protocol/*.so build *.egg-info
rm -fr asyncpg/protocol/codecs/*.html
@@ -19,17 +20,16 @@ clean:
compile:
- $(PYTHON) setup.py build_ext --inplace --cython-always
+ env ASYNCPG_BUILD_CYTHON_ALWAYS=1 $(PYTHON) -m pip install -e .
debug:
- ASYNCPG_DEBUG=1 $(PYTHON) setup.py build_ext --inplace
-
+ env ASYNCPG_DEBUG=1 $(PYTHON) -m pip install -e .
test:
- PYTHONASYNCIODEBUG=1 $(PYTHON) setup.py test
- $(PYTHON) setup.py test
- USE_UVLOOP=1 $(PYTHON) setup.py test
+ PYTHONASYNCIODEBUG=1 $(PYTHON) -m unittest -v tests.suite
+ $(PYTHON) -m unittest -v tests.suite
+ USE_UVLOOP=1 $(PYTHON) -m unittest -v tests.suite
testinstalled:
@@ -37,9 +37,9 @@ testinstalled:
quicktest:
- $(PYTHON) setup.py test
+ $(PYTHON) -m unittest -v tests.suite
htmldocs:
- $(PYTHON) setup.py build_ext --inplace
+ $(PYTHON) -m pip install -e .[docs]
$(MAKE) -C docs html
diff --git a/README.rst b/README.rst
index 2f5da7a4..1a37296d 100644
--- a/README.rst
+++ b/README.rst
@@ -13,9 +13,10 @@ of PostgreSQL server binary protocol for use with Python's ``asyncio``
framework. You can read more about asyncpg in an introductory
`blog post `_.
-asyncpg requires Python 3.6 or later and is supported for PostgreSQL
-versions 9.5 to 13. Older PostgreSQL versions or other databases implementing
-the PostgreSQL protocol *may* work, but are not being actively tested.
+asyncpg requires Python 3.9 or later and is supported for PostgreSQL
+versions 9.5 to 18. Other PostgreSQL versions or other databases
+implementing the PostgreSQL protocol *may* work, but are not being
+actively tested.
Documentation
@@ -28,15 +29,14 @@ The project documentation can be found
Performance
-----------
-In our testing asyncpg is, on average, **3x** faster than psycopg2
-(and its asyncio variant -- aiopg).
+In our testing asyncpg is, on average, **5x** faster than psycopg3.
-.. image:: https://raw.githubusercontent.com/MagicStack/asyncpg/master/performance.png
- :target: https://gistpreview.github.io/?b8eac294ac85da177ff82f784ff2cb60
+.. image:: https://raw.githubusercontent.com/MagicStack/asyncpg/master/performance.png?fddca40ab0
+ :target: https://gistpreview.github.io/?0ed296e93523831ea0918d42dd1258c2
The above results are a geometric mean of benchmarks obtained with PostgreSQL
`client driver benchmarking toolbench `_
-in November 2020 (click on the chart to see full details).
+in June 2023 (click on the chart to see full details).
Features
@@ -59,11 +59,18 @@ This enables asyncpg to have easy-to-use support for:
Installation
------------
-asyncpg is available on PyPI and has no dependencies.
-Use pip to install::
+asyncpg is available on PyPI. When not using GSSAPI/SSPI authentication it
+has no dependencies. Use pip to install::
$ pip install asyncpg
+If you need GSSAPI/SSPI authentication, use::
+
+ $ pip install 'asyncpg[gssauth]'
+
+For more details, please `see the documentation
+`_.
+
Basic Usage
-----------
@@ -82,8 +89,7 @@ Basic Usage
)
await conn.close()
- loop = asyncio.get_event_loop()
- loop.run_until_complete(run())
+ asyncio.run(run())
License
diff --git a/asyncpg/__init__.py b/asyncpg/__init__.py
index 01af7904..e8811a9d 100644
--- a/asyncpg/__init__.py
+++ b/asyncpg/__init__.py
@@ -4,6 +4,7 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+from __future__ import annotations
from .connection import connect, Connection # NOQA
from .exceptions import * # NOQA
@@ -14,8 +15,10 @@
from ._version import __version__ # NOQA
+from . import exceptions
-__all__ = (
- ('connect', 'create_pool', 'Pool', 'Record', 'Connection')
- + exceptions.__all__ # NOQA
+
+__all__: tuple[str, ...] = (
+ 'connect', 'create_pool', 'Pool', 'Record', 'Connection'
)
+__all__ += exceptions.__all__ # NOQA
diff --git a/asyncpg/_asyncio_compat.py b/asyncpg/_asyncio_compat.py
new file mode 100644
index 00000000..a211d0a9
--- /dev/null
+++ b/asyncpg/_asyncio_compat.py
@@ -0,0 +1,94 @@
+# Backports from Python/Lib/asyncio for older Pythons
+#
+# Copyright (c) 2001-2023 Python Software Foundation; All Rights Reserved
+#
+# SPDX-License-Identifier: PSF-2.0
+
+from __future__ import annotations
+
+import asyncio
+import functools
+import sys
+import typing
+
+if typing.TYPE_CHECKING:
+ from . import compat
+
+if sys.version_info < (3, 11):
+ from async_timeout import timeout as timeout_ctx
+else:
+ from asyncio import timeout as timeout_ctx
+
+_T = typing.TypeVar('_T')
+
+
+async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
+ """Wait for the single Future or coroutine to complete, with timeout.
+
+ Coroutine will be wrapped in Task.
+
+ Returns result of the Future or coroutine. When a timeout occurs,
+ it cancels the task and raises TimeoutError. To avoid the task
+ cancellation, wrap it in shield().
+
+ If the wait is cancelled, the task is also cancelled.
+
+    If the task suppresses the cancellation and returns a value instead,
+ that value is returned.
+
+ This function is a coroutine.
+ """
+ # The special case for timeout <= 0 is for the following case:
+ #
+ # async def test_waitfor():
+ # func_started = False
+ #
+ # async def func():
+ # nonlocal func_started
+ # func_started = True
+ #
+ # try:
+ # await asyncio.wait_for(func(), 0)
+ # except asyncio.TimeoutError:
+ # assert not func_started
+ # else:
+ # assert False
+ #
+ # asyncio.run(test_waitfor())
+
+ if timeout is not None and timeout <= 0:
+ fut = asyncio.ensure_future(fut)
+
+ if fut.done():
+ return fut.result()
+
+ await _cancel_and_wait(fut)
+ try:
+ return fut.result()
+ except asyncio.CancelledError as exc:
+ raise TimeoutError from exc
+
+ async with timeout_ctx(timeout):
+ return await fut
+
+
+async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
+ """Cancel the *fut* future or task and wait until it completes."""
+
+ loop = asyncio.get_running_loop()
+ waiter = loop.create_future()
+ cb = functools.partial(_release_waiter, waiter)
+ fut.add_done_callback(cb)
+
+ try:
+ fut.cancel()
+ # We cannot wait on *fut* directly to make
+ # sure _cancel_and_wait itself is reliably cancellable.
+ await waiter
+ finally:
+ fut.remove_done_callback(cb)
+
+
+def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None:
+ if not waiter.done():
+ waiter.set_result(None)
diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py
index 9944b20f..95775e11 100644
--- a/asyncpg/_testbase/__init__.py
+++ b/asyncpg/_testbase/__init__.py
@@ -117,10 +117,22 @@ def setUp(self):
self.__unhandled_exceptions = []
def tearDown(self):
- if self.__unhandled_exceptions:
+ excs = []
+ for exc in self.__unhandled_exceptions:
+ if isinstance(exc, ConnectionResetError):
+ texc = traceback.TracebackException.from_exception(
+ exc, lookup_lines=False)
+ if texc.stack[-1].name == "_call_connection_lost":
+ # On Windows calling socket.shutdown may raise
+ # ConnectionResetError, which happens in the
+ # finally block of _call_connection_lost.
+ continue
+ excs.append(exc)
+
+ if excs:
formatted = []
- for i, context in enumerate(self.__unhandled_exceptions):
+ for i, context in enumerate(excs):
formatted.append(self._format_loop_exception(context, i + 1))
self.fail(
@@ -214,13 +226,6 @@ def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
return cluster
-def _start_cluster(ClusterCls, cluster_kwargs, server_settings,
- initdb_options=None):
- cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options)
- cluster.start(port='dynamic', server_settings=server_settings)
- return cluster
-
-
def _get_initdb_options(initdb_options=None):
if not initdb_options:
initdb_options = {}
@@ -244,8 +249,12 @@ def _init_default_cluster(initdb_options=None):
_default_cluster = pg_cluster.RunningCluster()
else:
_default_cluster = _init_cluster(
- pg_cluster.TempCluster, cluster_kwargs={},
- initdb_options=_get_initdb_options(initdb_options))
+ pg_cluster.TempCluster,
+ cluster_kwargs={
+ "data_dir_suffix": ".apgtest",
+ },
+ initdb_options=_get_initdb_options(initdb_options),
+ )
return _default_cluster
@@ -262,6 +271,7 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=60.0,
+ connect=None,
setup=None,
init=None,
loop=None,
@@ -271,12 +281,18 @@ def create_pool(dsn=None, *,
**connect_kwargs):
return pool_class(
dsn,
- min_size=min_size, max_size=max_size,
- max_queries=max_queries, loop=loop, setup=setup, init=init,
+ min_size=min_size,
+ max_size=max_size,
+ max_queries=max_queries,
+ loop=loop,
+ connect=connect,
+ setup=setup,
+ init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
record_class=record_class,
- **connect_kwargs)
+ **connect_kwargs,
+ )
class ClusterTestCase(TestCase):
@@ -435,3 +451,93 @@ def tearDown(self):
self.con = None
finally:
super().tearDown()
+
+
+class HotStandbyTestCase(ClusterTestCase):
+
+ @classmethod
+ def setup_cluster(cls):
+ cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
+ cls.start_cluster(
+ cls.master_cluster,
+ server_settings={
+ 'max_wal_senders': 10,
+ 'wal_level': 'hot_standby'
+ }
+ )
+
+ con = None
+
+ try:
+ con = cls.loop.run_until_complete(
+ cls.master_cluster.connect(
+ database='postgres', user='postgres', loop=cls.loop))
+
+ cls.loop.run_until_complete(
+ con.execute('''
+ CREATE ROLE replication WITH LOGIN REPLICATION
+ '''))
+
+ cls.master_cluster.trust_local_replication_by('replication')
+
+ conn_spec = cls.master_cluster.get_connection_spec()
+
+ cls.standby_cluster = cls.new_cluster(
+ pg_cluster.HotStandbyCluster,
+ cluster_kwargs={
+ 'master': conn_spec,
+ 'replication_user': 'replication'
+ }
+ )
+ cls.start_cluster(
+ cls.standby_cluster,
+ server_settings={
+ 'hot_standby': True
+ }
+ )
+
+ finally:
+ if con is not None:
+ cls.loop.run_until_complete(con.close())
+
+ @classmethod
+ def get_cluster_connection_spec(cls, cluster, kwargs={}):
+ conn_spec = cluster.get_connection_spec()
+ if kwargs.get('dsn'):
+ conn_spec.pop('host')
+ conn_spec.update(kwargs)
+ if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
+ if 'database' not in conn_spec:
+ conn_spec['database'] = 'postgres'
+ if 'user' not in conn_spec:
+ conn_spec['user'] = 'postgres'
+ return conn_spec
+
+ @classmethod
+ def get_connection_spec(cls, kwargs={}):
+ primary_spec = cls.get_cluster_connection_spec(
+ cls.master_cluster, kwargs
+ )
+ standby_spec = cls.get_cluster_connection_spec(
+ cls.standby_cluster, kwargs
+ )
+ return {
+ 'host': [primary_spec['host'], standby_spec['host']],
+ 'port': [primary_spec['port'], standby_spec['port']],
+ 'database': primary_spec['database'],
+ 'user': primary_spec['user'],
+ **kwargs
+ }
+
+ @classmethod
+ def connect_primary(cls, **kwargs):
+ conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
+ return pg_connection.connect(**conn_spec, loop=cls.loop)
+
+ @classmethod
+ def connect_standby(cls, **kwargs):
+ conn_spec = cls.get_cluster_connection_spec(
+ cls.standby_cluster,
+ kwargs
+ )
+ return pg_connection.connect(**conn_spec, loop=cls.loop)
diff --git a/asyncpg/_testbase/fuzzer.py b/asyncpg/_testbase/fuzzer.py
index 5c0b870c..88745646 100644
--- a/asyncpg/_testbase/fuzzer.py
+++ b/asyncpg/_testbase/fuzzer.py
@@ -191,6 +191,12 @@ async def handle(self):
return_when=asyncio.FIRST_COMPLETED)
finally:
+ if self.proxy_to_backend_task is not None:
+ self.proxy_to_backend_task.cancel()
+
+ if self.proxy_from_backend_task is not None:
+ self.proxy_from_backend_task.cancel()
+
# Asyncio fails to properly remove the readers and writers
# when the task doing recv() or send() is cancelled, so
# we must remove the readers and writers manually before
diff --git a/asyncpg/_version.py b/asyncpg/_version.py
index eab825c7..738da168 100644
--- a/asyncpg/_version.py
+++ b/asyncpg/_version.py
@@ -10,4 +10,8 @@
# supported platforms, publish the packages on PyPI, merge the PR
# to the target branch, create a Git tag pointing to the commit.
-__version__ = '0.24.0.dev0'
+from __future__ import annotations
+
+import typing
+
+__version__: typing.Final = '0.32.0.dev0'
diff --git a/asyncpg/cluster.py b/asyncpg/cluster.py
index 0999e41c..606c2eae 100644
--- a/asyncpg/cluster.py
+++ b/asyncpg/cluster.py
@@ -9,9 +9,11 @@
import os
import os.path
import platform
+import random
import re
import shutil
import socket
+import string
import subprocess
import sys
import tempfile
@@ -45,6 +47,29 @@ def find_available_port():
sock.close()
+def _world_readable_mkdtemp(suffix=None, prefix=None, dir=None):
+ name = "".join(random.choices(string.ascii_lowercase, k=8))
+ if dir is None:
+ dir = tempfile.gettempdir()
+ if prefix is None:
+ prefix = tempfile.gettempprefix()
+ if suffix is None:
+ suffix = ""
+ fn = os.path.join(dir, prefix + name + suffix)
+ os.mkdir(fn, 0o755)
+ return fn
+
+
+def _mkdtemp(suffix=None, prefix=None, dir=None):
+ if _system == 'Windows' and os.environ.get("GITHUB_ACTIONS"):
+ # Due to mitigations introduced in python/cpython#118486
+ # when Python runs in a session created via an SSH connection
+ # tempfile.mkdtemp creates directories that are not accessible.
+ return _world_readable_mkdtemp(suffix, prefix, dir)
+ else:
+ return tempfile.mkdtemp(suffix, prefix, dir)
+
+
class ClusterError(Exception):
pass
@@ -122,9 +147,13 @@ def init(self, **settings):
else:
extra_args = []
+ os.makedirs(self._data_dir, exist_ok=True)
process = subprocess.run(
[self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
- stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ cwd=self._data_dir,
+ )
output = process.stdout
@@ -199,7 +228,10 @@ def start(self, wait=60, *, server_settings={}, **opts):
process = subprocess.run(
[self._pg_ctl, 'start', '-D', self._data_dir,
'-o', ' '.join(extra_args)],
- stdout=stdout, stderr=subprocess.STDOUT)
+ stdout=stdout,
+ stderr=subprocess.STDOUT,
+ cwd=self._data_dir,
+ )
if process.returncode != 0:
if process.stderr:
@@ -218,7 +250,10 @@ def start(self, wait=60, *, server_settings={}, **opts):
self._daemon_process = \
subprocess.Popen(
[self._postgres, '-D', self._data_dir, *extra_args],
- stdout=stdout, stderr=subprocess.STDOUT)
+ stdout=stdout,
+ stderr=subprocess.STDOUT,
+ cwd=self._data_dir,
+ )
self._daemon_pid = self._daemon_process.pid
@@ -232,7 +267,10 @@ def reload(self):
process = subprocess.run(
[self._pg_ctl, 'reload', '-D', self._data_dir],
- stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ cwd=self._data_dir,
+ )
stderr = process.stderr
@@ -245,7 +283,10 @@ def stop(self, wait=60):
process = subprocess.run(
[self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
'-m', 'fast'],
- stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ cwd=self._data_dir,
+ )
stderr = process.stderr
@@ -583,9 +624,9 @@ class TempCluster(Cluster):
def __init__(self, *,
data_dir_suffix=None, data_dir_prefix=None,
data_dir_parent=None, pg_config_path=None):
- self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix,
- prefix=data_dir_prefix,
- dir=data_dir_parent)
+ self._data_dir = _mkdtemp(suffix=data_dir_suffix,
+ prefix=data_dir_prefix,
+ dir=data_dir_parent)
super().__init__(self._data_dir, pg_config_path=pg_config_path)
@@ -626,7 +667,7 @@ def init(self, **settings):
'pg_basebackup init exited with status {:d}:\n{}'.format(
process.returncode, output.decode()))
- if self._pg_version <= (11, 0):
+ if self._pg_version < (12, 0):
with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
f.write(textwrap.dedent("""\
standby_mode = 'on'
diff --git a/asyncpg/compat.py b/asyncpg/compat.py
index 348b8caa..57eec650 100644
--- a/asyncpg/compat.py
+++ b/asyncpg/compat.py
@@ -4,23 +4,26 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+from __future__ import annotations
-import asyncio
+import enum
import pathlib
import platform
+import typing
import sys
+if typing.TYPE_CHECKING:
+ import asyncio
-PY_37 = sys.version_info >= (3, 7)
-SYSTEM = platform.uname().system
+SYSTEM: typing.Final = platform.uname().system
-if SYSTEM == 'Windows':
+if sys.platform == 'win32':
import ctypes.wintypes
- CSIDL_APPDATA = 0x001a
+ CSIDL_APPDATA: typing.Final = 0x001a
- def get_pg_home_directory() -> pathlib.Path:
+ def get_pg_home_directory() -> pathlib.Path | None:
# We cannot simply use expanduser() as that returns the user's
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
@@ -32,19 +35,14 @@ def get_pg_home_directory() -> pathlib.Path:
return pathlib.Path(buf.value) / 'postgresql'
else:
- def get_pg_home_directory() -> pathlib.Path:
- return pathlib.Path.home()
-
-
-if PY_37:
- def current_asyncio_task(loop):
- return asyncio.current_task(loop)
-else:
- def current_asyncio_task(loop):
- return asyncio.Task.current_task(loop)
+ def get_pg_home_directory() -> pathlib.Path | None:
+ try:
+ return pathlib.Path.home()
+ except (RuntimeError, KeyError):
+ return None
-async def wait_closed(stream):
+async def wait_closed(stream: asyncio.StreamWriter) -> None:
# Not all asyncio versions have StreamWriter.wait_closed().
if hasattr(stream, 'wait_closed'):
try:
@@ -55,17 +53,36 @@ async def wait_closed(stream):
pass
-# Workaround for https://bugs.python.org/issue37658
-async def wait_for(fut, timeout):
- if timeout is None:
- return await fut
+if sys.version_info < (3, 12):
+ def markcoroutinefunction(c): # type: ignore
+ pass
+else:
+ from inspect import markcoroutinefunction # noqa: F401
- fut = asyncio.ensure_future(fut)
- try:
- return await asyncio.wait_for(fut, timeout)
- except asyncio.CancelledError:
- if fut.done():
- return fut.result()
- else:
- raise
+if sys.version_info < (3, 12):
+ from ._asyncio_compat import wait_for as wait_for # noqa: F401
+else:
+ from asyncio import wait_for as wait_for # noqa: F401
+
+
+if sys.version_info < (3, 11):
+ from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
+else:
+ from asyncio import timeout as timeout # noqa: F401
+
+if sys.version_info < (3, 9):
+ from typing import ( # noqa: F401
+ Awaitable as Awaitable,
+ )
+else:
+ from collections.abc import ( # noqa: F401
+ Awaitable as Awaitable,
+ )
+
+if sys.version_info < (3, 11):
+ class StrEnum(str, enum.Enum):
+ __str__ = str.__str__
+ __repr__ = enum.Enum.__repr__
+else:
+ from enum import StrEnum as StrEnum # noqa: F401
diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py
index cd94b834..07c4fdde 100644
--- a/asyncpg/connect_utils.py
+++ b/asyncpg/connect_utils.py
@@ -4,21 +4,25 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+from __future__ import annotations
import asyncio
+import configparser
import collections
+from collections.abc import Callable
import enum
import functools
import getpass
import os
import pathlib
import platform
+import random
import re
import socket
import ssl as ssl_module
import stat
import struct
-import time
+import sys
import typing
import urllib.parse
import warnings
@@ -44,6 +48,11 @@ def parse(cls, sslmode):
return getattr(cls, sslmode.replace('-', '_'))
+class SSLNegotiation(compat.StrEnum):
+ postgres = "postgres"
+ direct = "direct"
+
+
_ConnectionParameters = collections.namedtuple(
'ConnectionParameters',
[
@@ -52,8 +61,11 @@ def parse(cls, sslmode):
'database',
'ssl',
'sslmode',
- 'connect_timeout',
+ 'ssl_negotiation',
'server_settings',
+ 'target_session_attrs',
+ 'krbsrvname',
+ 'gsslib',
])
@@ -76,6 +88,9 @@ def parse(cls, sslmode):
PGPASSFILE = '.pgpass'
+PG_SERVICEFILE = '.pg_service.conf'
+
+
def _read_password_file(passfile: pathlib.Path) \
-> typing.List[typing.Tuple[str, ...]]:
@@ -157,13 +172,15 @@ def _read_password_from_pgpass(
def _validate_port_spec(hosts, port):
- if isinstance(port, list):
+ if isinstance(port, list) and len(port) > 1:
# If there is a list of ports, its length must
# match that of the host list.
if len(port) != len(hosts):
- raise exceptions.InterfaceError(
+ raise exceptions.ClientConfigurationError(
'could not match {} port numbers to {} hosts'.format(
len(port), len(hosts)))
+ elif isinstance(port, list) and len(port) == 1:
+ port = [port[0] for _ in range(len(hosts))]
else:
port = [port for _ in range(len(hosts))]
@@ -196,11 +213,25 @@ def _parse_hostlist(hostlist, port, *, unquote=False):
port = _validate_port_spec(hostspecs, port)
for i, hostspec in enumerate(hostspecs):
- if not hostspec.startswith('/'):
- addr, _, hostspec_port = hostspec.partition(':')
- else:
+ if hostspec[0] == '/':
+ # Unix socket
addr = hostspec
hostspec_port = ''
+ elif hostspec[0] == '[':
+ # IPv6 address
+ m = re.match(r'(?:\[([^\]]+)\])(?::([0-9]+))?', hostspec)
+ if m:
+ addr = m.group(1)
+ hostspec_port = m.group(2)
+ else:
+ raise exceptions.ClientConfigurationError(
+ 'invalid IPv6 address in the connection URI: {!r}'.format(
+ hostspec
+ )
+ )
+ else:
+ # IPv4 address
+ addr, _, hostspec_port = hostspec.partition(':')
if unquote:
addr = urllib.parse.unquote(addr)
@@ -220,19 +251,71 @@ def _parse_hostlist(hostlist, port, *, unquote=False):
return hosts, port
+def _parse_tls_version(tls_version):
+ if tls_version.startswith('SSL'):
+ raise exceptions.ClientConfigurationError(
+ f"Unsupported TLS version: {tls_version}"
+ )
+ try:
+ return ssl_module.TLSVersion[tls_version.replace('.', '_')]
+ except KeyError:
+ raise exceptions.ClientConfigurationError(
+ f"No such TLS version: {tls_version}"
+ )
+
+
+def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
+ try:
+ homedir = pathlib.Path.home()
+ except (RuntimeError, KeyError):
+ return None
+
+ return (homedir / '.postgresql' / filename).resolve()
+
+
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
- connect_timeout, server_settings):
+ service, servicefile,
+ direct_tls, server_settings,
+ target_session_attrs, krbsrvname, gsslib):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
- sslcert = sslkey = sslrootcert = sslcrl = None
+ sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None
+ ssl_min_protocol_version = ssl_max_protocol_version = None
+ sslnegotiation = None
if dsn:
parsed = urllib.parse.urlparse(dsn)
+ query = None
+ if parsed.query:
+ query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
+ for key, val in query.items():
+ if isinstance(val, list):
+ query[key] = val[-1]
+
+ if 'service' in query:
+ val = query.pop('service')
+ if not service and val:
+ service = val
+
+ connection_service_file = servicefile
+
+ if connection_service_file is None:
+ connection_service_file = os.getenv('PGSERVICEFILE')
+
+ if connection_service_file is None:
+ homedir = compat.get_pg_home_directory()
+ if homedir:
+ connection_service_file = homedir / PG_SERVICEFILE
+ else:
+ connection_service_file = None
+ else:
+ connection_service_file = pathlib.Path(connection_service_file)
+
if parsed.scheme not in {'postgresql', 'postgres'}:
- raise ValueError(
+ raise exceptions.ClientConfigurationError(
'invalid DSN: scheme is expected to be either '
'"postgresql" or "postgres", got {!r}'.format(parsed.scheme))
@@ -265,11 +348,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if password is None and dsn_password:
password = urllib.parse.unquote(dsn_password)
- if parsed.query:
- query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
- for key, val in query.items():
- if isinstance(val, list):
- query[key] = val[-1]
+ if query:
if 'port' in query:
val = query.pop('port')
@@ -312,24 +391,54 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
ssl = val
if 'sslcert' in query:
- val = query.pop('sslcert')
- if sslcert is None:
- sslcert = val
+ sslcert = query.pop('sslcert')
if 'sslkey' in query:
- val = query.pop('sslkey')
- if sslkey is None:
- sslkey = val
+ sslkey = query.pop('sslkey')
if 'sslrootcert' in query:
- val = query.pop('sslrootcert')
- if sslrootcert is None:
- sslrootcert = val
+ sslrootcert = query.pop('sslrootcert')
+
+ if 'sslnegotiation' in query:
+ sslnegotiation = query.pop('sslnegotiation')
if 'sslcrl' in query:
- val = query.pop('sslcrl')
- if sslcrl is None:
- sslcrl = val
+ sslcrl = query.pop('sslcrl')
+
+ if 'sslpassword' in query:
+ sslpassword = query.pop('sslpassword')
+
+ if 'ssl_min_protocol_version' in query:
+ ssl_min_protocol_version = query.pop(
+ 'ssl_min_protocol_version'
+ )
+
+ if 'ssl_max_protocol_version' in query:
+ ssl_max_protocol_version = query.pop(
+ 'ssl_max_protocol_version'
+ )
+
+ if 'target_session_attrs' in query:
+ dsn_target_session_attrs = query.pop(
+ 'target_session_attrs'
+ )
+ if target_session_attrs is None:
+ target_session_attrs = dsn_target_session_attrs
+
+ if 'krbsrvname' in query:
+ val = query.pop('krbsrvname')
+ if krbsrvname is None:
+ krbsrvname = val
+
+ if 'gsslib' in query:
+ val = query.pop('gsslib')
+ if gsslib is None:
+ gsslib = val
+
+ if 'service' in query:
+ val = query.pop('service')
+ if service is None:
+ service = val
if query:
if server_settings is None:
@@ -337,6 +446,113 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
else:
server_settings = {**query, **server_settings}
+ if connection_service_file is not None and service is not None:
+ pg_service = configparser.ConfigParser()
+ pg_service.read(connection_service_file)
+ if service in pg_service.sections():
+ service_params = pg_service[service]
+ if 'port' in service_params:
+ val = service_params.pop('port')
+ if not port and val:
+ port = [int(p) for p in val.split(',')]
+
+ if 'host' in service_params:
+ val = service_params.pop('host')
+ if not host and val:
+ host, port = _parse_hostlist(val, port)
+
+ if 'dbname' in service_params:
+ val = service_params.pop('dbname')
+ if database is None:
+ database = val
+
+ if 'database' in service_params:
+ val = service_params.pop('database')
+ if database is None:
+ database = val
+
+ if 'user' in service_params:
+ val = service_params.pop('user')
+ if user is None:
+ user = val
+
+ if 'password' in service_params:
+ val = service_params.pop('password')
+ if password is None:
+ password = val
+
+ if 'passfile' in service_params:
+ val = service_params.pop('passfile')
+ if passfile is None:
+ passfile = val
+
+ if 'sslmode' in service_params:
+ val = service_params.pop('sslmode')
+ if ssl is None:
+ ssl = val
+
+ if 'sslcert' in service_params:
+ val = service_params.pop('sslcert')
+ if sslcert is None:
+ sslcert = val
+
+ if 'sslkey' in service_params:
+ val = service_params.pop('sslkey')
+ if sslkey is None:
+ sslkey = val
+
+ if 'sslrootcert' in service_params:
+ val = service_params.pop('sslrootcert')
+ if sslrootcert is None:
+ sslrootcert = val
+
+ if 'sslnegotiation' in service_params:
+ val = service_params.pop('sslnegotiation')
+ if sslnegotiation is None:
+ sslnegotiation = val
+
+ if 'sslcrl' in service_params:
+ val = service_params.pop('sslcrl')
+ if sslcrl is None:
+ sslcrl = val
+
+ if 'sslpassword' in service_params:
+ val = service_params.pop('sslpassword')
+ if sslpassword is None:
+ sslpassword = val
+
+ if 'ssl_min_protocol_version' in service_params:
+ val = service_params.pop(
+ 'ssl_min_protocol_version'
+ )
+ if ssl_min_protocol_version is None:
+ ssl_min_protocol_version = val
+
+ if 'ssl_max_protocol_version' in service_params:
+ val = service_params.pop(
+ 'ssl_max_protocol_version'
+ )
+ if ssl_max_protocol_version is None:
+ ssl_max_protocol_version = val
+
+ if 'target_session_attrs' in service_params:
+ dsn_target_session_attrs = service_params.pop(
+ 'target_session_attrs'
+ )
+ if target_session_attrs is None:
+ target_session_attrs = dsn_target_session_attrs
+
+ if 'krbsrvname' in service_params:
+ val = service_params.pop('krbsrvname')
+ if krbsrvname is None:
+ krbsrvname = val
+
+ if 'gsslib' in service_params:
+ val = service_params.pop('gsslib')
+ if gsslib is None:
+ gsslib = val
+ if not service:
+ service = os.environ.get('PGSERVICE')
if not host:
hostspec = os.environ.get('PGHOST')
if hostspec:
@@ -351,7 +567,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
host = ['/run/postgresql', '/var/run/postgresql',
'/tmp', '/private/tmp', 'localhost']
- if not isinstance(host, list):
+ if not isinstance(host, (list, tuple)):
host = [host]
if auth_hosts is None:
@@ -390,11 +606,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
database = user
if user is None:
- raise exceptions.InterfaceError(
+ raise exceptions.ClientConfigurationError(
'could not determine user name to connect with')
if database is None:
- raise exceptions.InterfaceError(
+ raise exceptions.ClientConfigurationError(
'could not determine database name to connect to')
if password is None:
@@ -430,7 +646,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
have_tcp_addrs = True
if not addrs:
- raise ValueError(
+ raise exceptions.InternalClientError(
'could not determine the database address to connect to')
if ssl is None:
@@ -439,46 +655,158 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if ssl is None and have_tcp_addrs:
ssl = 'prefer'
+ if direct_tls is not None:
+ sslneg = (
+ SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres
+ )
+ else:
+ if sslnegotiation is None:
+ sslnegotiation = os.environ.get("PGSSLNEGOTIATION")
+
+ if sslnegotiation is not None:
+ try:
+ sslneg = SSLNegotiation(sslnegotiation)
+ except ValueError:
+ modes = ', '.join(
+ m.name.replace('_', '-')
+ for m in SSLNegotiation
+ )
+ raise exceptions.ClientConfigurationError(
+ f'`sslnegotiation` parameter must be one of: {modes}'
+ ) from None
+ else:
+ sslneg = SSLNegotiation.postgres
+
if isinstance(ssl, (str, SSLMode)):
try:
sslmode = SSLMode.parse(ssl)
except AttributeError:
modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
- raise exceptions.InterfaceError(
- '`sslmode` parameter must be one of: {}'.format(modes))
+ raise exceptions.ClientConfigurationError(
+ '`sslmode` parameter must be one of: {}'.format(modes)
+ ) from None
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
if sslmode < SSLMode.allow:
ssl = False
else:
- ssl = ssl_module.create_default_context(
- ssl_module.Purpose.SERVER_AUTH)
+ ssl = ssl_module.SSLContext(ssl_module.PROTOCOL_TLS_CLIENT)
ssl.check_hostname = sslmode >= SSLMode.verify_full
- ssl.verify_mode = ssl_module.CERT_REQUIRED
- if sslmode <= SSLMode.require:
+ if sslmode < SSLMode.require:
ssl.verify_mode = ssl_module.CERT_NONE
+ else:
+ if sslrootcert is None:
+ sslrootcert = os.getenv('PGSSLROOTCERT')
+ if sslrootcert:
+ ssl.load_verify_locations(cafile=sslrootcert)
+ ssl.verify_mode = ssl_module.CERT_REQUIRED
+ else:
+ try:
+ sslrootcert = _dot_postgresql_path('root.crt')
+ if sslrootcert is not None:
+ ssl.load_verify_locations(cafile=sslrootcert)
+ else:
+ raise exceptions.ClientConfigurationError(
+ 'cannot determine location of user '
+ 'PostgreSQL configuration directory'
+ )
+ except (
+ exceptions.ClientConfigurationError,
+ FileNotFoundError,
+ NotADirectoryError,
+ ):
+ if sslmode > SSLMode.require:
+ if sslrootcert is None:
+ sslrootcert = '~/.postgresql/root.crt'
+ detail = (
+ 'Could not determine location of user '
+ 'home directory (HOME is either unset, '
+ 'inaccessible, or does not point to a '
+ 'valid directory)'
+ )
+ else:
+ detail = None
+ raise exceptions.ClientConfigurationError(
+ f'root certificate file "{sslrootcert}" does '
+ f'not exist or cannot be accessed',
+ hint='Provide the certificate file directly '
+ f'or make sure "{sslrootcert}" '
+ 'exists and is readable.',
+ detail=detail,
+ )
+ elif sslmode == SSLMode.require:
+ ssl.verify_mode = ssl_module.CERT_NONE
+ else:
+ assert False, 'unreachable'
+ else:
+ ssl.verify_mode = ssl_module.CERT_REQUIRED
- if sslcert is None:
- sslcert = os.getenv('PGSSLCERT')
+ if sslcrl is None:
+ sslcrl = os.getenv('PGSSLCRL')
+ if sslcrl:
+ ssl.load_verify_locations(cafile=sslcrl)
+ ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
+ else:
+ sslcrl = _dot_postgresql_path('root.crl')
+ if sslcrl is not None:
+ try:
+ ssl.load_verify_locations(cafile=sslcrl)
+ except (
+ FileNotFoundError,
+ NotADirectoryError,
+ ):
+ pass
+ else:
+ ssl.verify_flags |= \
+ ssl_module.VERIFY_CRL_CHECK_CHAIN
if sslkey is None:
sslkey = os.getenv('PGSSLKEY')
-
- if sslrootcert is None:
- sslrootcert = os.getenv('PGSSLROOTCERT')
-
- if sslcrl is None:
- sslcrl = os.getenv('PGSSLCRL')
-
+ if not sslkey:
+ sslkey = _dot_postgresql_path('postgresql.key')
+ if sslkey is not None and not sslkey.exists():
+ sslkey = None
+ if not sslpassword:
+ sslpassword = ''
+ if sslcert is None:
+ sslcert = os.getenv('PGSSLCERT')
if sslcert:
- ssl.load_cert_chain(sslcert, keyfile=sslkey)
-
- if sslrootcert:
- ssl.load_verify_locations(cafile=sslrootcert)
+ ssl.load_cert_chain(
+ sslcert, keyfile=sslkey, password=lambda: sslpassword
+ )
+ else:
+ sslcert = _dot_postgresql_path('postgresql.crt')
+ if sslcert is not None:
+ try:
+ ssl.load_cert_chain(
+ sslcert,
+ keyfile=sslkey,
+ password=lambda: sslpassword
+ )
+ except (FileNotFoundError, NotADirectoryError):
+ pass
+
+ # OpenSSL 1.1.1 keylog file, copied from create_default_context()
+ if hasattr(ssl, 'keylog_filename'):
+ keylogfile = os.environ.get('SSLKEYLOGFILE')
+ if keylogfile and not sys.flags.ignore_environment:
+ ssl.keylog_filename = keylogfile
+
+ if ssl_min_protocol_version is None:
+ ssl_min_protocol_version = os.getenv('PGSSLMINPROTOCOLVERSION')
+ if ssl_min_protocol_version:
+ ssl.minimum_version = _parse_tls_version(
+ ssl_min_protocol_version
+ )
+ else:
+ ssl.minimum_version = _parse_tls_version('TLSv1.2')
- if sslcrl:
- ssl.load_verify_locations(cafile=sslcrl)
- ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
+ if ssl_max_protocol_version is None:
+ ssl_max_protocol_version = os.getenv('PGSSLMAXPROTOCOLVERSION')
+ if ssl_max_protocol_version:
+ ssl.maximum_version = _parse_tls_version(
+ ssl_max_protocol_version
+ )
elif ssl is True:
ssl = ssl_module.create_default_context()
@@ -490,25 +818,55 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
not isinstance(server_settings, dict) or
not all(isinstance(k, str) for k in server_settings) or
not all(isinstance(v, str) for v in server_settings.values())):
- raise ValueError(
+ raise exceptions.ClientConfigurationError(
'server_settings is expected to be None or '
'a Dict[str, str]')
+ if target_session_attrs is None:
+ target_session_attrs = os.getenv(
+ "PGTARGETSESSIONATTRS", SessionAttribute.any
+ )
+ try:
+ target_session_attrs = SessionAttribute(target_session_attrs)
+ except ValueError:
+ raise exceptions.ClientConfigurationError(
+ "target_session_attrs is expected to be one of "
+ "{!r}"
+ ", got {!r}".format(
+ SessionAttribute.__members__.values, target_session_attrs
+ )
+ ) from None
+
+ if krbsrvname is None:
+ krbsrvname = os.getenv('PGKRBSRVNAME')
+
+ if gsslib is None:
+ gsslib = os.getenv('PGGSSLIB')
+ if gsslib is None:
+ gsslib = 'sspi' if _system == 'Windows' else 'gssapi'
+ if gsslib not in {'gssapi', 'sspi'}:
+ raise exceptions.ClientConfigurationError(
+ "gsslib parameter must be either 'gssapi' or 'sspi'"
+ ", got {!r}".format(gsslib))
+
params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
- sslmode=sslmode, connect_timeout=connect_timeout,
- server_settings=server_settings)
+ sslmode=sslmode, ssl_negotiation=sslneg,
+ server_settings=server_settings,
+ target_session_attrs=target_session_attrs,
+ krbsrvname=krbsrvname, gsslib=gsslib)
return addrs, params
def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
- database, timeout, command_timeout,
+ database, command_timeout,
statement_cache_size,
max_cached_statement_lifetime,
max_cacheable_statement_size,
- ssl, server_settings):
-
+ ssl, direct_tls, server_settings,
+ target_session_attrs, krbsrvname, gsslib,
+ service, servicefile):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
@@ -535,8 +893,11 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
addrs, params = _parse_connect_dsn_and_args(
dsn=dsn, host=host, port=port, user=user,
password=password, passfile=passfile, ssl=ssl,
- database=database, connect_timeout=timeout,
- server_settings=server_settings)
+ direct_tls=direct_tls, database=database,
+ server_settings=server_settings,
+ target_session_attrs=target_session_attrs,
+ krbsrvname=krbsrvname, gsslib=gsslib,
+ service=service, servicefile=servicefile)
config = _ClientConfiguration(
command_timeout=command_timeout,
@@ -548,14 +909,21 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
class TLSUpgradeProto(asyncio.Protocol):
- def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
+ def __init__(
+ self,
+ loop: asyncio.AbstractEventLoop,
+ host: str,
+ port: int,
+ ssl_context: ssl_module.SSLContext,
+ ssl_is_advisory: bool,
+ ) -> None:
self.on_data = _create_future(loop)
self.host = host
self.port = port
self.ssl_context = ssl_context
self.ssl_is_advisory = ssl_is_advisory
- def data_received(self, data):
+ def data_received(self, data: bytes) -> None:
if data == b'S':
self.on_data.set_result(True)
elif (self.ssl_is_advisory and
@@ -573,15 +941,30 @@ def data_received(self, data):
'rejected SSL upgrade'.format(
host=self.host, port=self.port)))
- def connection_lost(self, exc):
+ def connection_lost(self, exc: typing.Optional[Exception]) -> None:
if not self.on_data.done():
if exc is None:
exc = ConnectionError('unexpected connection_lost() call')
self.on_data.set_exception(exc)
-async def _create_ssl_connection(protocol_factory, host, port, *,
- loop, ssl_context, ssl_is_advisory=False):
+_ProtocolFactoryR = typing.TypeVar(
+    "_ProtocolFactoryR", bound=asyncio.protocols.Protocol
+)
+
+
+async def _create_ssl_connection(
+    # TODO: The return type is a specific combination of subclasses of
+    # asyncio.protocols.Protocol that we can't express. For now, having the
+    # return type depend on the signature of the factory is an improvement.
+    protocol_factory: Callable[[], _ProtocolFactoryR],
+    host: str,
+    port: int,
+    *,
+    loop: asyncio.AbstractEventLoop,
+    ssl_context: ssl_module.SSLContext,
+    ssl_is_advisory: bool = False,
+) -> typing.Tuple[asyncio.Transport, _ProtocolFactoryR]:
tr, pr = await loop.create_connection(
lambda: TLSUpgradeProto(loop, host, port,
@@ -601,6 +984,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
try:
new_tr = await loop.start_tls(
tr, pr, ssl_context, server_hostname=host)
+ assert new_tr is not None
except (Exception, asyncio.CancelledError):
tr.close()
raise
@@ -639,7 +1023,6 @@ async def _connect_addr(
*,
addr,
loop,
- timeout,
params,
config,
connection_class,
@@ -647,15 +1030,11 @@ async def _connect_addr(
):
assert loop is not None
- if timeout <= 0:
- raise asyncio.TimeoutError
-
params_input = params
if callable(params.password):
- if inspect.iscoroutinefunction(params.password):
- password = await params.password()
- else:
- password = params.password()
+ password = params.password()
+ if inspect.isawaitable(password):
+ password = await password
params = params._replace(password=password)
args = (addr, loop, config, connection_class, record_class, params_input)
@@ -668,21 +1047,16 @@ async def _connect_addr(
params_retry = params._replace(ssl=None)
else:
# skip retry if we don't have to
- return await __connect_addr(params, timeout, False, *args)
+ return await __connect_addr(params, False, *args)
# first attempt
- before = time.monotonic()
try:
- return await __connect_addr(params, timeout, True, *args)
+ return await __connect_addr(params, True, *args)
except _RetryConnectSignal:
pass
# second attempt
- timeout -= time.monotonic() - before
- if timeout <= 0:
- raise asyncio.TimeoutError
- else:
- return await __connect_addr(params_retry, timeout, False, *args)
+ return await __connect_addr(params_retry, False, *args)
class _RetryConnectSignal(Exception):
@@ -691,7 +1065,6 @@ class _RetryConnectSignal(Exception):
async def __connect_addr(
params,
- timeout,
retry,
addr,
loop,
@@ -708,6 +1081,14 @@ async def __connect_addr(
if isinstance(addr, str):
# UNIX socket
connector = loop.create_unix_connection(proto_factory, addr)
+
+ elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct:
+ # if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform
+ # direct SSL connection
+ connector = loop.create_connection(
+ proto_factory, *addr, ssl=params.ssl
+ )
+
elif params.ssl:
connector = _create_ssl_connection(
proto_factory, *addr, loop=loop, ssl_context=params.ssl,
@@ -715,15 +1096,10 @@ async def __connect_addr(
else:
connector = loop.create_connection(proto_factory, *addr)
- connector = asyncio.ensure_future(connector)
- before = time.monotonic()
- tr, pr = await compat.wait_for(connector, timeout=timeout)
- timeout -= time.monotonic() - before
+ tr, pr = await connector
try:
- if timeout <= 0:
- raise asyncio.TimeoutError
- await compat.wait_for(connected, timeout=timeout)
+ await connected
except (
exceptions.InvalidAuthorizationSpecificationError,
exceptions.ConnectionDoesNotExistError, # seen on Windows
@@ -762,32 +1138,118 @@ async def __connect_addr(
return con
-async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
+class SessionAttribute(str, enum.Enum):
+ any = 'any'
+ primary = 'primary'
+ standby = 'standby'
+ prefer_standby = 'prefer-standby'
+ read_write = "read-write"
+ read_only = "read-only"
+
+
+def _accept_in_hot_standby(should_be_in_hot_standby: bool):
+ """
+ If the server didn't report "in_hot_standby" at startup, we must determine
+ the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
+    If the server allows a connection and states it is in recovery, it
+    must be a replica/standby server.
+ """
+ async def can_be_used(connection):
+ settings = connection.get_settings()
+ hot_standby_status = getattr(settings, 'in_hot_standby', None)
+ if hot_standby_status is not None:
+ is_in_hot_standby = hot_standby_status == 'on'
+ else:
+ is_in_hot_standby = await connection.fetchval(
+ "SELECT pg_catalog.pg_is_in_recovery()"
+ )
+ return is_in_hot_standby == should_be_in_hot_standby
+
+ return can_be_used
+
+
+def _accept_read_only(should_be_read_only: bool):
+ """
+ Verify the server has not set default_transaction_read_only=True
+ """
+ async def can_be_used(connection):
+ settings = connection.get_settings()
+ is_readonly = getattr(settings, 'default_transaction_read_only', 'off')
+
+ if is_readonly == "on":
+ return should_be_read_only
+
+ return await _accept_in_hot_standby(should_be_read_only)(connection)
+ return can_be_used
+
+
+async def _accept_any(_):
+ return True
+
+
+target_attrs_check = {
+ SessionAttribute.any: _accept_any,
+ SessionAttribute.primary: _accept_in_hot_standby(False),
+ SessionAttribute.standby: _accept_in_hot_standby(True),
+ SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
+ SessionAttribute.read_write: _accept_read_only(False),
+ SessionAttribute.read_only: _accept_read_only(True),
+}
+
+
+async def _can_use_connection(connection, attr: SessionAttribute):
+ can_use = target_attrs_check[attr]
+ return await can_use(connection)
+
+
+async def _connect(*, loop, connection_class, record_class, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()
- addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
+ addrs, params, config = _parse_connect_arguments(**kwargs)
+ target_attr = params.target_session_attrs
+ candidates = []
+ chosen_connection = None
last_error = None
- addr = None
- for addr in addrs:
- before = time.monotonic()
- try:
- return await _connect_addr(
- addr=addr,
- loop=loop,
- timeout=timeout,
- params=params,
- config=config,
- connection_class=connection_class,
- record_class=record_class,
+ try:
+ for addr in addrs:
+ try:
+ conn = await _connect_addr(
+ addr=addr,
+ loop=loop,
+ params=params,
+ config=config,
+ connection_class=connection_class,
+ record_class=record_class,
+ )
+ candidates.append(conn)
+ if await _can_use_connection(conn, target_attr):
+ chosen_connection = conn
+ break
+ except OSError as ex:
+ last_error = ex
+ else:
+ if target_attr == SessionAttribute.prefer_standby and candidates:
+ chosen_connection = random.choice(candidates)
+ finally:
+
+ async def _close_candidates(conns, chosen):
+ await asyncio.gather(
+ *(c.close() for c in conns if c is not chosen),
+ return_exceptions=True
)
- except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
- last_error = ex
- finally:
- timeout -= time.monotonic() - before
+ if candidates:
+ asyncio.create_task(
+ _close_candidates(candidates, chosen_connection))
+
+ if chosen_connection:
+ return chosen_connection
- raise last_error
+ raise last_error or exceptions.TargetServerAttributeNotMatched(
+ 'None of the hosts match the target attribute requirement '
+ '{!r}'.format(target_attr)
+ )
async def _cancel(*, loop, addr, params: _ConnectionParameters,
diff --git a/asyncpg/connection.py b/asyncpg/connection.py
index 26249679..71fb04f8 100644
--- a/asyncpg/connection.py
+++ b/asyncpg/connection.py
@@ -9,6 +9,7 @@
import asyncpg
import collections
import collections.abc
+import contextlib
import functools
import itertools
import inspect
@@ -48,11 +49,12 @@ class Connection(metaclass=ConnectionMeta):
__slots__ = ('_protocol', '_transport', '_loop',
'_top_xact', '_aborted',
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
+ '_stmt_cache_enabled',
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
'_log_listeners', '_termination_listeners', '_cancellations',
- '_source_traceback', '__weakref__')
+ '_source_traceback', '_query_loggers', '__weakref__')
def __init__(self, protocol, transport, loop,
addr,
@@ -80,11 +82,13 @@ def __init__(self, protocol, transport, loop,
max_lifetime=config.max_cached_statement_lifetime)
self._stmts_to_close = set()
+ self._stmt_cache_enabled = config.statement_cache_size > 0
self._listeners = {}
self._log_listeners = set()
self._cancellations = set()
self._termination_listeners = set()
+ self._query_loggers = set()
settings = self._protocol.get_settings()
ver_string = settings.server_version
@@ -94,7 +98,10 @@ def __init__(self, protocol, transport, loop,
self._server_caps = _detect_server_capabilities(
self._server_version, settings)
- self._intro_query = introspection.INTRO_LOOKUP_TYPES
+ if self._server_version < (14, 0):
+ self._intro_query = introspection.INTRO_LOOKUP_TYPES_13
+ else:
+ self._intro_query = introspection.INTRO_LOOKUP_TYPES
self._reset_query = None
self._proxy = None
@@ -219,6 +226,29 @@ def remove_termination_listener(self, callback):
"""
self._termination_listeners.discard(_Callback.from_callable(callback))
+ def add_query_logger(self, callback):
+ """Add a logger that will be called when queries are executed.
+
+ :param callable callback:
+ A callable or a coroutine function receiving one argument:
+ **record**, a LoggedQuery containing `query`, `args`, `timeout`,
+ `elapsed`, `exception`, `conn_addr`, and `conn_params`.
+
+ .. versionadded:: 0.29.0
+ """
+ self._query_loggers.add(_Callback.from_callable(callback))
+
+ def remove_query_logger(self, callback):
+ """Remove a query logger callback.
+
+ :param callable callback:
+ The callable or coroutine function that was passed to
+ :meth:`Connection.add_query_logger`.
+
+ .. versionadded:: 0.29.0
+ """
+ self._query_loggers.discard(_Callback.from_callable(callback))
+
def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
return self._protocol.get_server_pid()
@@ -255,9 +285,9 @@ def transaction(self, *, isolation=None, readonly=False,
:param isolation: Transaction isolation mode, can be one of:
`'serializable'`, `'repeatable_read'`,
- `'read_committed'`. If not specified, the behavior
- is up to the server and session, which is usually
- ``read_committed``.
+ `'read_uncommitted'`, `'read_committed'`. If not
+ specified, the behavior is up to the server and
+ session, which is usually ``read_committed``.
:param readonly: Specifies whether or not this transaction is
read-only.
@@ -281,7 +311,12 @@ def is_in_transaction(self):
"""
return self._protocol.is_in_transaction()
- async def execute(self, query: str, *args, timeout: float=None) -> str:
+ async def execute(
+ self,
+ query: str,
+ *args,
+ timeout: typing.Optional[float]=None,
+ ) -> str:
"""Execute an SQL command (or commands).
This method can execute many SQL commands at once, when no arguments
@@ -312,7 +347,12 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
self._check_open()
if not args:
- return await self._protocol.query(query, timeout)
+ if self._query_loggers:
+ with self._time_and_log(query, args, timeout):
+ result = await self._protocol.query(query, timeout)
+ else:
+ result = await self._protocol.query(query, timeout)
+ return result
_, status, _ = await self._execute(
query,
@@ -323,7 +363,13 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
)
return status.decode()
- async def executemany(self, command: str, args, *, timeout: float=None):
+ async def executemany(
+ self,
+ command: str,
+ args,
+ *,
+ timeout: typing.Optional[float]=None,
+ ):
"""Execute an SQL *command* for each sequence of arguments in *args*.
Example:
@@ -359,8 +405,8 @@ async def _get_statement(
query,
timeout,
*,
- named: bool=False,
- use_cache: bool=True,
+ named: typing.Union[str, bool, None] = False,
+ use_cache=True,
ignore_custom_codec=False,
record_class=None
):
@@ -379,13 +425,17 @@ async def _get_statement(
# Only use the cache when:
# * `statement_cache_size` is greater than 0;
# * query size is less than `max_cacheable_statement_size`.
- use_cache = self._stmt_cache.get_max_size() > 0
- if (use_cache and
- self._config.max_cacheable_statement_size and
- len(query) > self._config.max_cacheable_statement_size):
- use_cache = False
+ use_cache = (
+ self._stmt_cache_enabled
+ and (
+ not self._config.max_cacheable_statement_size
+ or len(query) <= self._config.max_cacheable_statement_size
+ )
+ )
- if use_cache or named:
+ if isinstance(named, str):
+ stmt_name = named
+ elif use_cache or named:
stmt_name = self._get_unique_id('stmt')
else:
stmt_name = ''
@@ -430,14 +480,16 @@ async def _get_statement(
# for the statement.
statement._init_codecs()
- if need_reprepare:
- await self._protocol.prepare(
- stmt_name,
- query,
- timeout,
- state=statement,
- record_class=record_class,
- )
+ if (
+ need_reprepare
+ or (not statement.name and not self._stmt_cache_enabled)
+ ):
+ # Mark this anonymous prepared statement as "unprepared",
+ # causing it to get re-Parsed in next bind_execute.
+ # We always do this when stmt_cache_size is set to 0 assuming
+ # people are running PgBouncer which is mishandling implicit
+ # transactions.
+ statement.mark_unprepared()
if use_cache:
self._stmt_cache.put(
@@ -451,7 +503,26 @@ async def _get_statement(
return statement
async def _introspect_types(self, typeoids, timeout):
- return await self.__execute(
+ if self._server_caps.jit:
+ try:
+ cfgrow, _ = await self.__execute(
+ """
+ SELECT
+ current_setting('jit') AS cur,
+ set_config('jit', 'off', false) AS new
+ """,
+ (),
+ 0,
+ timeout,
+ ignore_custom_codec=True,
+ )
+ jit_state = cfgrow[0]['cur']
+ except exceptions.UndefinedObjectError:
+ jit_state = 'off'
+ else:
+ jit_state = 'off'
+
+ result = await self.__execute(
self._intro_query,
(list(typeoids),),
0,
@@ -459,28 +530,34 @@ async def _introspect_types(self, typeoids, timeout):
ignore_custom_codec=True,
)
- async def _introspect_type(self, typename, schema):
- if (
- schema == 'pg_catalog'
- and typename.lower() in protocol.BUILTIN_TYPE_NAME_MAP
- ):
- typeoid = protocol.BUILTIN_TYPE_NAME_MAP[typename.lower()]
- rows = await self._execute(
- introspection.TYPE_BY_OID,
- [typeoid],
- limit=0,
- timeout=None,
- ignore_custom_codec=True,
- )
- else:
- rows = await self._execute(
- introspection.TYPE_BY_NAME,
- [typename, schema],
- limit=1,
- timeout=None,
+ if jit_state != 'off':
+ await self.__execute(
+ """
+ SELECT
+ set_config('jit', $1, false)
+ """,
+ (jit_state,),
+ 0,
+ timeout,
ignore_custom_codec=True,
)
+ return result
+
+ async def _introspect_type(self, typename, schema):
+ if schema == 'pg_catalog' and not typename.endswith("[]"):
+ typeoid = protocol.BUILTIN_TYPE_NAME_MAP.get(typename.lower())
+ if typeoid is not None:
+ return introspection.TypeRecord((typeoid, None, b"b"))
+
+ rows = await self._execute(
+ introspection.TYPE_BY_NAME,
+ [typename, schema],
+ limit=1,
+ timeout=None,
+ ignore_custom_codec=True,
+ )
+
if not rows:
raise ValueError(
'unknown type: {}.{}'.format(schema, typename))
@@ -526,11 +603,21 @@ def cursor(
record_class,
)
- async def prepare(self, query, *, timeout=None, record_class=None):
+ async def prepare(
+ self,
+ query,
+ *,
+ name=None,
+ timeout=None,
+ record_class=None,
+ ):
"""Create a *prepared statement* for the specified query.
:param str query:
Text of the query to create a prepared statement for.
+ :param str name:
+ Optional name of the returned prepared statement. If not
+ specified, the name is auto-generated.
:param float timeout:
Optional timeout value in seconds.
:param type record_class:
@@ -544,11 +631,14 @@ async def prepare(self, query, *, timeout=None, record_class=None):
.. versionchanged:: 0.22.0
Added the *record_class* parameter.
+
+ .. versionchanged:: 0.25.0
+ Added the *name* parameter.
"""
return await self._prepare(
query,
+ name=name,
timeout=timeout,
- use_cache=False,
record_class=record_class,
)
@@ -556,15 +646,18 @@ async def _prepare(
self,
query,
*,
+ name: typing.Union[str, bool, None] = None,
timeout=None,
use_cache: bool=False,
record_class=None
):
self._check_open()
+ if name is None:
+ name = self._stmt_cache_enabled
stmt = await self._get_statement(
query,
timeout,
- named=True,
+ named=name,
use_cache=use_cache,
record_class=record_class,
)
@@ -667,6 +760,49 @@ async def fetchrow(
return None
return data[0]
+ async def fetchmany(
+ self,
+ query,
+ args,
+ *,
+ timeout: typing.Optional[float]=None,
+ record_class=None,
+ ):
+ """Run a query for each sequence of arguments in *args*
+ and return the results as a list of :class:`Record`.
+
+ :param query:
+ Query to execute.
+ :param args:
+ An iterable containing sequences of arguments for the query.
+ :param float timeout:
+ Optional timeout value in seconds.
+ :param type record_class:
+ If specified, the class to use for records returned by this method.
+ Must be a subclass of :class:`~asyncpg.Record`. If not specified,
+ a per-connection *record_class* is used.
+
+ :return list:
+ A list of :class:`~asyncpg.Record` instances. If specified, the
+ actual type of list elements would be *record_class*.
+
+ Example:
+
+ .. code-block:: pycon
+
+ >>> rows = await con.fetchmany('''
+ ... INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a;
+ ... ''', [('x', 1), ('y', 2), ('z', 3)])
+ >>> rows
+ [<Record a='x'>, <Record a='y'>, <Record a='z'>]
+
+ .. versionadded:: 0.30.0
+ """
+ self._check_open()
+ return await self._executemany(
+ query, args, timeout, return_rows=True, record_class=record_class
+ )
+
async def copy_from_table(self, table_name, *, output,
columns=None, schema_name=None, timeout=None,
format=None, oids=None, delimiter=None,
@@ -710,7 +846,7 @@ async def copy_from_table(self, table_name, *, output,
... output='file.csv', format='csv')
... print(result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
'COPY 100'
.. _`COPY statement documentation`:
@@ -779,7 +915,7 @@ async def copy_from_query(self, query, *args, output,
... output='file.csv', format='csv')
... print(result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
'COPY 10'
.. _`COPY statement documentation`:
@@ -807,7 +943,7 @@ async def copy_to_table(self, table_name, *, source,
delimiter=None, null=None, header=None,
quote=None, escape=None, force_quote=None,
force_not_null=None, force_null=None,
- encoding=None):
+ encoding=None, where=None):
"""Copy data to the specified table.
:param str table_name:
@@ -826,6 +962,15 @@ async def copy_to_table(self, table_name, *, source,
:param str schema_name:
An optional schema name to qualify the table.
+ :param str where:
+ An optional SQL expression used to filter rows when copying.
+
+ .. note::
+
+ Usage of this parameter requires support for the
+ ``COPY FROM ... WHERE`` syntax, introduced in
+ PostgreSQL version 12.
+
:param float timeout:
Optional timeout value in seconds.
@@ -846,13 +991,16 @@ async def copy_to_table(self, table_name, *, source,
... 'mytable', source='datafile.tbl')
... print(result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
'COPY 140000'
.. _`COPY statement documentation`:
https://www.postgresql.org/docs/current/static/sql-copy.html
.. versionadded:: 0.11.0
+
+ .. versionadded:: 0.29.0
+ Added the *where* parameter.
"""
tabname = utils._quote_ident(table_name)
if schema_name:
@@ -864,6 +1012,7 @@ async def copy_to_table(self, table_name, *, source,
else:
cols = ''
+ cond = self._format_copy_where(where)
opts = self._format_copy_opts(
format=format, oids=oids, freeze=freeze, delimiter=delimiter,
null=null, header=header, quote=quote, escape=escape,
@@ -871,14 +1020,14 @@ async def copy_to_table(self, table_name, *, source,
encoding=encoding
)
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
- tab=tabname, cols=cols, opts=opts)
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
+ tab=tabname, cols=cols, opts=opts, cond=cond)
return await self._copy_in(copy_stmt, source, timeout)
async def copy_records_to_table(self, table_name, *, records,
columns=None, schema_name=None,
- timeout=None):
+ timeout=None, where=None):
"""Copy a list of records to the specified table using binary COPY.
:param str table_name:
@@ -895,6 +1044,16 @@ async def copy_records_to_table(self, table_name, *, records,
:param str schema_name:
An optional schema name to qualify the table.
+ :param str where:
+ An optional SQL expression used to filter rows when copying.
+
+ .. note::
+
+ Usage of this parameter requires support for the
+ ``COPY FROM ... WHERE`` syntax, introduced in
+ PostgreSQL version 12.
+
+
:param float timeout:
Optional timeout value in seconds.
@@ -914,7 +1073,7 @@ async def copy_records_to_table(self, table_name, *, records,
... (2, 'ham', 'spam')])
... print(result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
'COPY 2'
Asynchronous record iterables are also supported:
@@ -932,13 +1091,16 @@ async def copy_records_to_table(self, table_name, *, records,
... 'mytable', records=record_gen(100))
... print(result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
'COPY 100'
.. versionadded:: 0.11.0
.. versionchanged:: 0.24.0
The ``records`` argument may be an asynchronous iterable.
+
+ .. versionadded:: 0.29.0
+ Added the *where* parameter.
"""
tabname = utils._quote_ident(table_name)
if schema_name:
@@ -954,16 +1116,29 @@ async def copy_records_to_table(self, table_name, *, records,
intro_query = 'SELECT {cols} FROM {tab} LIMIT 1'.format(
tab=tabname, cols=col_list)
- intro_ps = await self._prepare(intro_query, use_cache=True)
+ intro_ps = await self.prepare(intro_query)
+ cond = self._format_copy_where(where)
opts = '(FORMAT binary)'
- copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts}'.format(
- tab=tabname, cols=cols, opts=opts)
+ copy_stmt = 'COPY {tab}{cols} FROM STDIN {opts} {cond}'.format(
+ tab=tabname, cols=cols, opts=opts, cond=cond)
return await self._protocol.copy_in(
copy_stmt, None, None, records, intro_ps._state, timeout)
+ def _format_copy_where(self, where):
+ if where and not self._server_caps.sql_copy_from_where:
+ raise exceptions.UnsupportedServerFeatureError(
+ 'the `where` parameter requires PostgreSQL 12 or later')
+
+ if where:
+ where_clause = 'WHERE ' + where
+ else:
+ where_clause = ''
+
+ return where_clause
+
def _format_copy_opts(self, *, format=None, oids=None, freeze=None,
delimiter=None, null=None, header=None, quote=None,
escape=None, force_quote=None, force_not_null=None,
@@ -1135,6 +1310,9 @@ async def set_type_codec(self, typename, *,
| ``time with | (``microseconds``, |
| time zone`` | ``time zone offset in seconds``) |
+-----------------+---------------------------------------------+
+ | any composite | Composite value elements |
+ | type | |
+ +-----------------+---------------------------------------------+
:param encoder:
Callable accepting a Python object as a single argument and
@@ -1173,7 +1351,7 @@ async def set_type_codec(self, typename, *,
... print(result)
... print(datetime.datetime(2002, 1, 1) + result)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
relativedelta(years=+2, months=+3, days=+1)
2004-04-02 00:00:00
@@ -1189,6 +1367,10 @@ async def set_type_codec(self, typename, *,
The ``binary`` keyword argument was removed in favor of
``format``.
+ .. versionchanged:: 0.29.0
+ Custom codecs for composite types are now supported with
+ ``format='tuple'``.
+
.. note::
It is recommended to use the ``'binary'`` or ``'tuple'`` *format*
@@ -1199,11 +1381,28 @@ async def set_type_codec(self, typename, *,
codecs.
"""
self._check_open()
+ settings = self._protocol.get_settings()
typeinfo = await self._introspect_type(typename, schema)
- if not introspection.is_scalar_type(typeinfo):
+ full_typeinfos = []
+ if introspection.is_scalar_type(typeinfo):
+ kind = 'scalar'
+ elif introspection.is_composite_type(typeinfo):
+ if format != 'tuple':
+ raise exceptions.UnsupportedClientFeatureError(
+ 'only tuple-format codecs can be used on composite types',
+ hint="Use `set_type_codec(..., format='tuple')` and "
+ "pass/interpret data as a Python tuple. See an "
+ "example at https://magicstack.github.io/asyncpg/"
+ "current/usage.html#example-decoding-complex-types",
+ )
+ kind = 'composite'
+ full_typeinfos, _ = await self._introspect_types(
+ (typeinfo['oid'],), 10)
+ else:
raise exceptions.InterfaceError(
- 'cannot use custom codec on non-scalar type {}.{}'.format(
- schema, typename))
+ f'cannot use custom codec on type {schema}.{typename}: '
+ f'it is neither a scalar type nor a composite type'
+ )
if introspection.is_domain_type(typeinfo):
raise exceptions.UnsupportedClientFeatureError(
'custom codecs on domain types are not supported',
@@ -1215,8 +1414,8 @@ async def set_type_codec(self, typename, *,
)
oid = typeinfo['oid']
- self._protocol.get_settings().add_python_codec(
- oid, typename, schema, 'scalar',
+ settings.add_python_codec(
+ oid, typename, schema, full_typeinfos, kind,
encoder, decoder, format)
# Statement cache is no longer valid due to codec changes.
@@ -1325,11 +1524,10 @@ def terminate(self):
self._abort()
self._cleanup()
- async def reset(self, *, timeout=None):
+ async def _reset(self):
self._check_open()
self._listeners.clear()
self._log_listeners.clear()
- reset_query = self._get_reset_query()
if self._protocol.is_in_transaction() or self._top_xact is not None:
if self._top_xact is None or not self._top_xact._managed:
@@ -1341,10 +1539,36 @@ async def reset(self, *, timeout=None):
})
self._top_xact = None
- reset_query = 'ROLLBACK;\n' + reset_query
+ await self.execute("ROLLBACK")
- if reset_query:
- await self.execute(reset_query, timeout=timeout)
+ async def reset(self, *, timeout=None):
+ """Reset the connection state.
+
+ Calling this will reset the connection session state to a state
+ resembling that of a newly obtained connection. Namely, an open
+ transaction (if any) is rolled back, open cursors are closed,
+ all `LISTEN <https://www.postgresql.org/docs/current/sql-listen.html>`_
+ registrations are removed, all session configuration
+ variables are reset to their default values, and all advisory locks
+ are released.
+
+ Note that the above describes the default query returned by
+ :meth:`Connection.get_reset_query`. If one overloads the method
+ by subclassing ``Connection``, then this method will do whatever
+ the overloaded method returns, except open transactions are always
+ terminated and any callbacks registered by
+ :meth:`Connection.add_listener` or :meth:`Connection.add_log_listener`
+ are removed.
+
+ :param float timeout:
+ A timeout for resetting the connection. If not specified, defaults
+ to no timeout.
+ """
+ async with compat.timeout(timeout):
+ await self._reset()
+ reset_query = self.get_reset_query()
+ if reset_query:
+ await self.execute(reset_query)
def _abort(self):
# Put the connection into the aborted state.
@@ -1365,6 +1589,7 @@ def _cleanup(self):
self._mark_stmts_as_closed()
self._listeners.clear()
self._log_listeners.clear()
+ self._query_loggers.clear()
self._clean_tasks()
def _clean_tasks(self):
@@ -1397,6 +1622,7 @@ def _mark_stmts_as_closed(self):
def _maybe_gc_stmt(self, stmt):
if (
stmt.refs == 0
+ and stmt.name
and not self._stmt_cache.has(
(stmt.query, stmt.record_class, stmt.ignore_custom_codec)
)
@@ -1448,7 +1674,7 @@ async def _cancel(self, waiter):
waiter.set_exception(ex)
finally:
self._cancellations.discard(
- compat.current_asyncio_task(self._loop))
+ asyncio.current_task(self._loop))
if not waiter.done():
waiter.set_result(None)
@@ -1503,7 +1729,15 @@ def _unwrap(self):
con_ref = self._proxy
return con_ref
- def _get_reset_query(self):
+ def get_reset_query(self):
+ """Return the query sent to server on connection release.
+
+ The query returned by this method is used by :meth:`Connection.reset`,
+ which is, in turn, used by :class:`~asyncpg.pool.Pool` before making
+ the connection available to another acquirer.
+
+ .. versionadded:: 0.30.0
+ """
if self._reset_query is not None:
return self._reset_query
@@ -1617,7 +1851,7 @@ async def reload_schema_state(self):
... await con.execute('LOCK TABLE tbl')
... await change_type(con)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
.. versionadded:: 0.14.0
"""
@@ -1647,6 +1881,62 @@ async def _execute(
)
return result
+ @contextlib.contextmanager
+ def query_logger(self, callback):
+ """Context manager that adds `callback` to the list of query loggers,
+ and removes it upon exit.
+
+ :param callable callback:
+ A callable or a coroutine function receiving one argument:
+ **record**, a LoggedQuery containing `query`, `args`, `timeout`,
+ `elapsed`, `exception`, `conn_addr`, and `conn_params`.
+
+ Example:
+
+ .. code-block:: pycon
+
+ >>> class QuerySaver:
+ def __init__(self):
+ self.queries = []
+ def __call__(self, record):
+ self.queries.append(record.query)
+ >>> with con.query_logger(saver := QuerySaver()):
+ >>> await con.execute("SELECT 1")
+ >>> print(saver.queries)
+ ['SELECT 1']
+
+ .. versionadded:: 0.29.0
+ """
+ self.add_query_logger(callback)
+ yield
+ self.remove_query_logger(callback)
+
+ @contextlib.contextmanager
+ def _time_and_log(self, query, args, timeout):
+ start = time.monotonic()
+ exception = None
+ try:
+ yield
+ except BaseException as ex:
+ exception = ex
+ raise
+ finally:
+ elapsed = time.monotonic() - start
+ record = LoggedQuery(
+ query=query,
+ args=args,
+ timeout=timeout,
+ elapsed=elapsed,
+ exception=exception,
+ conn_addr=self._addr,
+ conn_params=self._params,
+ )
+ for cb in self._query_loggers:
+ if cb.is_async:
+ self._loop.create_task(cb.cb(record))
+ else:
+ self._loop.call_soon(cb.cb, record)
+
async def __execute(
self,
query,
@@ -1659,22 +1949,54 @@ async def __execute(
record_class=None
):
executor = lambda stmt, timeout: self._protocol.bind_execute(
- stmt, args, '', limit, return_status, timeout)
- timeout = self._protocol._get_timeout(timeout)
- return await self._do_execute(
- query,
- executor,
- timeout,
- record_class=record_class,
- ignore_custom_codec=ignore_custom_codec,
+ state=stmt,
+ args=args,
+ portal_name='',
+ limit=limit,
+ return_extra=return_status,
+ timeout=timeout,
)
+ timeout = self._protocol._get_timeout(timeout)
+ if self._query_loggers:
+ with self._time_and_log(query, args, timeout):
+ result, stmt = await self._do_execute(
+ query,
+ executor,
+ timeout,
+ record_class=record_class,
+ ignore_custom_codec=ignore_custom_codec,
+ )
+ else:
+ result, stmt = await self._do_execute(
+ query,
+ executor,
+ timeout,
+ record_class=record_class,
+ ignore_custom_codec=ignore_custom_codec,
+ )
+ return result, stmt
- async def _executemany(self, query, args, timeout):
+ async def _executemany(
+ self,
+ query,
+ args,
+ timeout,
+ return_rows=False,
+ record_class=None,
+ ):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
- stmt, args, '', timeout)
+ state=stmt,
+ args=args,
+ portal_name='',
+ timeout=timeout,
+ return_rows=return_rows,
+ )
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
- result, _ = await self._do_execute(query, executor, timeout)
+ with self._time_and_log(query, args, timeout):
+ result, _ = await self._do_execute(
+ query, executor, timeout, record_class=record_class
+ )
return result
async def _do_execute(
@@ -1761,6 +2083,8 @@ async def _do_execute(
async def connect(dsn=None, *,
host=None, port=None,
user=None, password=None, passfile=None,
+ service=None,
+ servicefile=None,
database=None,
loop=None,
timeout=60,
@@ -1769,9 +2093,13 @@ async def connect(dsn=None, *,
max_cacheable_statement_size=1024 * 15,
command_timeout=None,
ssl=None,
+ direct_tls=None,
connection_class=Connection,
record_class=protocol.Record,
- server_settings=None):
+ server_settings=None,
+ target_session_attrs=None,
+ krbsrvname=None,
+ gsslib=None):
r"""A coroutine to establish a connection to a PostgreSQL server.
The connection parameters may be specified either as a connection
@@ -1796,7 +2124,13 @@ async def connect(dsn=None, *,
.. note::
The URI must be *valid*, which means that all components must
- be properly quoted with :py:func:`urllib.parse.quote`.
+ be properly quoted with :py:func:`urllib.parse.quote_plus`, and
+ any literal IPv6 addresses must be enclosed in square brackets.
+ For example:
+
+ .. code-block:: text
+
+ postgres://dbuser@[fe80::1ff:fe23:4567:890a%25eth0]/dbname
:param host:
Database host address as one of the following:
@@ -1841,7 +2175,7 @@ async def connect(dsn=None, *,
If not specified, the value parsed from the *dsn* argument is used,
or the value of the ``PGDATABASE`` environment variable, or the
- operating system name of the user running the application.
+ computed value of the *user* argument.
:param password:
Password to be used for authentication, if the server requires
@@ -1860,6 +2194,14 @@ async def connect(dsn=None, *,
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
on Windows).
+ :param service:
+ The name of the postgres connection service stored in the postgres
+ connection service file.
+
+ :param servicefile:
+ The location of the connection service file used to store
+ connection parameters.
+
:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
@@ -1939,7 +2281,7 @@ async def connect(dsn=None, *,
... )
... con = await asyncpg.connect(user='postgres', ssl=sslctx)
... await con.close()
- >>> asyncio.run(run())
+ >>> asyncio.run(main())
Example of programmatic SSL context configuration that is equivalent
to ``sslmode=require`` (no server certificate or host verification):
@@ -1956,7 +2298,11 @@ async def connect(dsn=None, *,
... sslctx.verify_mode = ssl.CERT_NONE
... con = await asyncpg.connect(user='postgres', ssl=sslctx)
... await con.close()
- >>> asyncio.run(run())
+ >>> asyncio.run(main())
+
+ :param bool direct_tls:
+ Pass ``True`` to skip PostgreSQL STARTTLS mode and perform a direct
+ SSL connection. Must be used alongside ``ssl`` param.
:param dict server_settings:
An optional dict of server runtime parameters. Refer to
@@ -1972,6 +2318,31 @@ async def connect(dsn=None, *,
this connection object. Must be a subclass of
:class:`~asyncpg.Record`.
+ :param SessionAttribute target_session_attrs:
+ If specified, check that the host has the correct attribute.
+ Can be one of:
+
+ - ``"any"`` - the first successfully connected host
+ - ``"primary"`` - the host must NOT be in hot standby mode
+ - ``"standby"`` - the host must be in hot standby mode
+ - ``"read-write"`` - the host must allow writes
+ - ``"read-only"`` - the host must NOT allow writes
+ - ``"prefer-standby"`` - first try to find a standby host, but if
+ none of the listed hosts is a standby server,
+ return any of them.
+
+ If not specified, the value parsed from the *dsn* argument is used,
+ or the value of the ``PGTARGETSESSIONATTRS`` environment variable,
+ or ``"any"`` if neither is specified.
+
+ :param str krbsrvname:
+ Kerberos service name to use when authenticating with GSSAPI. This
+ must match the server configuration. Defaults to 'postgres'.
+
+ :param str gsslib:
+ GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi'
+ or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise.
+
:return: A :class:`~asyncpg.connection.Connection` instance.
Example:
@@ -1985,7 +2356,7 @@ async def connect(dsn=None, *,
... types = await con.fetch('SELECT * FROM pg_type')
... print(types)
...
- >>> asyncio.get_event_loop().run_until_complete(run())
+ >>> asyncio.run(run())
[<Record typname=...>, ...]
+ jit = server_version >= (11, 0)
+ sql_copy_from_where = server_version.major >= 12
return ServerCapabilities(
advisory_locks=advisory_locks,
notifications=notifications,
plpgsql=plpgsql,
sql_reset=sql_reset,
- sql_close_all=sql_close_all
+ sql_close_all=sql_close_all,
+ sql_copy_from_where=sql_copy_from_where,
+ jit=jit,
)
@@ -2331,8 +2751,8 @@ def _check_record_class(record_class):
and issubclass(record_class, protocol.Record)
):
if (
- record_class.__new__ is not object.__new__
- or record_class.__init__ is not object.__init__
+ record_class.__new__ is not protocol.Record.__new__
+ or record_class.__init__ is not protocol.Record.__init__
):
raise exceptions.InterfaceError(
'record_class must not redefine __new__ or __init__'
diff --git a/asyncpg/cursor.py b/asyncpg/cursor.py
index 7ec159ba..b4abeed1 100644
--- a/asyncpg/cursor.py
+++ b/asyncpg/cursor.py
@@ -158,6 +158,17 @@ async def _exec(self, n, timeout):
self._state, self._portal_name, n, True, timeout)
return buffer
+ async def _close_portal(self, timeout):
+ self._check_ready()
+
+ if not self._portal_name:
+ raise exceptions.InterfaceError(
+ 'cursor does not have an open portal')
+
+ protocol = self._connection._protocol
+ await protocol.close_portal(self._portal_name, timeout)
+ self._portal_name = None
+
def __repr__(self):
attrs = []
if self._exhausted:
@@ -219,7 +230,7 @@ async def __anext__(self):
)
self._state.attach()
- if not self._portal_name:
+ if not self._portal_name and not self._exhausted:
buffer = await self._bind_exec(self._prefetch, self._timeout)
self._buffer.extend(buffer)
@@ -227,6 +238,9 @@ async def __anext__(self):
buffer = await self._exec(self._prefetch, self._timeout)
self._buffer.extend(buffer)
+ if self._portal_name and self._exhausted:
+ await self._close_portal(self._timeout)
+
if self._buffer:
return self._buffer.popleft()
diff --git a/asyncpg/exceptions/__init__.py b/asyncpg/exceptions/__init__.py
index 446a71a8..752fd007 100644
--- a/asyncpg/exceptions/__init__.py
+++ b/asyncpg/exceptions/__init__.py
@@ -121,6 +121,10 @@ class StackedDiagnosticsAccessedWithoutActiveHandlerError(DiagnosticsError):
sqlstate = '0Z002'
+class InvalidArgumentForXqueryError(_base.PostgresError):
+ sqlstate = '10608'
+
+
class CaseNotFoundError(_base.PostgresError):
sqlstate = '20000'
@@ -337,6 +341,10 @@ class DuplicateJsonObjectKeyValueError(DataError):
sqlstate = '22030'
+class InvalidArgumentForSQLJsonDatetimeFunctionError(DataError):
+ sqlstate = '22031'
+
+
class InvalidJsonTextError(DataError):
sqlstate = '22032'
@@ -393,6 +401,10 @@ class SQLJsonScalarRequiredError(DataError):
sqlstate = '2203F'
+class SQLJsonItemCannotBeCastToTargetTypeError(DataError):
+ sqlstate = '2203G'
+
+
class IntegrityConstraintViolationError(_base.PostgresError):
sqlstate = '23000'
@@ -477,6 +489,10 @@ class IdleInTransactionSessionTimeoutError(InvalidTransactionStateError):
sqlstate = '25P03'
+class TransactionTimeoutError(InvalidTransactionStateError):
+ sqlstate = '25P04'
+
+
class InvalidSQLStatementNameError(_base.PostgresError):
sqlstate = '26000'
@@ -872,6 +888,10 @@ class DatabaseDroppedError(OperatorInterventionError):
sqlstate = '57P04'
+class IdleSessionTimeoutError(OperatorInterventionError):
+ sqlstate = '57P05'
+
+
class PostgresSystemError(_base.PostgresError):
sqlstate = '58000'
@@ -888,6 +908,10 @@ class DuplicateFileError(PostgresSystemError):
sqlstate = '58P02'
+class FileNameTooLongError(PostgresSystemError):
+ sqlstate = '58P03'
+
+
class SnapshotTooOldError(_base.PostgresError):
sqlstate = '72000'
@@ -1040,7 +1064,7 @@ class IndexCorruptedError(InternalServerError):
sqlstate = 'XX002'
-__all__ = _base.__all__ + (
+__all__ = (
'ActiveSQLTransactionError', 'AdminShutdownError',
'AmbiguousAliasError', 'AmbiguousColumnError',
'AmbiguousFunctionError', 'AmbiguousParameterError',
@@ -1083,11 +1107,11 @@ class IndexCorruptedError(InternalServerError):
'FDWTableNotFoundError', 'FDWTooManyHandlesError',
'FDWUnableToCreateExecutionError', 'FDWUnableToCreateReplyError',
'FDWUnableToEstablishConnectionError', 'FeatureNotSupportedError',
- 'ForeignKeyViolationError', 'FunctionExecutedNoReturnStatementError',
- 'GeneratedAlwaysError', 'GroupingError',
- 'HeldCursorRequiresSameIsolationLevelError',
- 'IdleInTransactionSessionTimeoutError', 'ImplicitZeroBitPadding',
- 'InFailedSQLTransactionError',
+ 'FileNameTooLongError', 'ForeignKeyViolationError',
+ 'FunctionExecutedNoReturnStatementError', 'GeneratedAlwaysError',
+ 'GroupingError', 'HeldCursorRequiresSameIsolationLevelError',
+ 'IdleInTransactionSessionTimeoutError', 'IdleSessionTimeoutError',
+ 'ImplicitZeroBitPadding', 'InFailedSQLTransactionError',
'InappropriateAccessModeForBranchTransactionError',
'InappropriateIsolationLevelForBranchTransactionError',
'IndeterminateCollationError', 'IndeterminateDatatypeError',
@@ -1098,7 +1122,9 @@ class IndexCorruptedError(InternalServerError):
'InvalidArgumentForNthValueFunctionError',
'InvalidArgumentForNtileFunctionError',
'InvalidArgumentForPowerFunctionError',
+ 'InvalidArgumentForSQLJsonDatetimeFunctionError',
'InvalidArgumentForWidthBucketFunctionError',
+ 'InvalidArgumentForXqueryError',
'InvalidAuthorizationSpecificationError',
'InvalidBinaryRepresentationError', 'InvalidCachedStatementError',
'InvalidCatalogNameError', 'InvalidCharacterValueForCastError',
@@ -1154,6 +1180,7 @@ class IndexCorruptedError(InternalServerError):
'ReadingExternalRoutineSQLDataNotPermittedError',
'ReadingSQLDataNotPermittedError', 'ReservedNameError',
'RestrictViolationError', 'SQLJsonArrayNotFoundError',
+ 'SQLJsonItemCannotBeCastToTargetTypeError',
'SQLJsonMemberNotFoundError', 'SQLJsonNumberNotFoundError',
'SQLJsonObjectNotFoundError', 'SQLJsonScalarRequiredError',
'SQLRoutineError', 'SQLStatementNotYetCompleteError',
@@ -1170,9 +1197,9 @@ class IndexCorruptedError(InternalServerError):
'TooManyJsonObjectMembersError', 'TooManyRowsError',
'TransactionIntegrityConstraintViolationError',
'TransactionResolutionUnknownError', 'TransactionRollbackError',
- 'TriggerProtocolViolatedError', 'TriggeredActionError',
- 'TriggeredDataChangeViolationError', 'TrimError',
- 'UndefinedColumnError', 'UndefinedFileError',
+ 'TransactionTimeoutError', 'TriggerProtocolViolatedError',
+ 'TriggeredActionError', 'TriggeredDataChangeViolationError',
+ 'TrimError', 'UndefinedColumnError', 'UndefinedFileError',
'UndefinedFunctionError', 'UndefinedObjectError',
'UndefinedParameterError', 'UndefinedTableError',
'UniqueViolationError', 'UnsafeNewEnumValueUsageError',
@@ -1180,3 +1207,5 @@ class IndexCorruptedError(InternalServerError):
'WindowingError', 'WithCheckOptionViolationError',
'WrongObjectTypeError', 'ZeroLengthCharacterStringError'
)
+
+__all__ += _base.__all__
diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py
index 783b5eb5..00e9699a 100644
--- a/asyncpg/exceptions/_base.py
+++ b/asyncpg/exceptions/_base.py
@@ -12,8 +12,10 @@
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
+ 'ClientConfigurationError',
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
- 'UnsupportedClientFeatureError')
+ 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
+ 'UnsupportedServerFeatureError')
def _is_asyncpg_class(cls):
@@ -220,6 +222,10 @@ def with_msg(self, msg):
)
+class ClientConfigurationError(InterfaceError, ValueError):
+ """An error caused by improper client configuration."""
+
+
class DataError(InterfaceError, ValueError):
"""An error caused by invalid query input."""
@@ -228,6 +234,10 @@ class UnsupportedClientFeatureError(InterfaceError):
"""Requested feature is unsupported by asyncpg."""
+class UnsupportedServerFeatureError(InterfaceError):
+ """Requested feature is unsupported by PostgreSQL server."""
+
+
class InterfaceWarning(InterfaceMessage, UserWarning):
"""A warning caused by an improper use of asyncpg API."""
@@ -244,6 +254,10 @@ class ProtocolError(InternalClientError):
"""Unexpected condition in the handling of PostgreSQL protocol input."""
+class TargetServerAttributeNotMatched(InternalClientError):
+ """Could not find a host that satisfies the target attribute requirement"""
+
+
class OutdatedSchemaCacheError(InternalClientError):
"""A value decoding error caused by a schema change before row fetching."""
diff --git a/asyncpg/introspection.py b/asyncpg/introspection.py
index 64508692..c3b4e60c 100644
--- a/asyncpg/introspection.py
+++ b/asyncpg/introspection.py
@@ -4,8 +4,16 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+from __future__ import annotations
-_TYPEINFO = '''\
+import typing
+from .protocol.protocol import _create_record # type: ignore
+
+if typing.TYPE_CHECKING:
+ from . import protocol
+
+
+_TYPEINFO_13: typing.Final = '''\
(
SELECT
t.oid AS oid,
@@ -82,6 +90,130 @@
'''
+INTRO_LOOKUP_TYPES_13 = '''\
+WITH RECURSIVE typeinfo_tree(
+ oid, ns, name, kind, basetype, elemtype, elemdelim,
+ range_subtype, attrtypoids, attrnames, depth)
+AS (
+ SELECT
+ ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
+ ti.elemtype, ti.elemdelim, ti.range_subtype,
+ ti.attrtypoids, ti.attrnames, 0
+ FROM
+ {typeinfo} AS ti
+ WHERE
+ ti.oid = any($1::oid[])
+
+ UNION ALL
+
+ SELECT
+ ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
+ ti.elemtype, ti.elemdelim, ti.range_subtype,
+ ti.attrtypoids, ti.attrnames, tt.depth + 1
+ FROM
+ {typeinfo} ti,
+ typeinfo_tree tt
+ WHERE
+ (tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype)
+ OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids))
+ OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype)
+ OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype)
+)
+
+SELECT DISTINCT
+ *,
+ basetype::regtype::text AS basetype_name,
+ elemtype::regtype::text AS elemtype_name,
+ range_subtype::regtype::text AS range_subtype_name
+FROM
+ typeinfo_tree
+ORDER BY
+ depth DESC
+'''.format(typeinfo=_TYPEINFO_13)
+
+
+_TYPEINFO: typing.Final = '''\
+ (
+ SELECT
+ t.oid AS oid,
+ ns.nspname AS ns,
+ t.typname AS name,
+ t.typtype AS kind,
+ (CASE WHEN t.typtype = 'd' THEN
+ (WITH RECURSIVE typebases(oid, depth) AS (
+ SELECT
+ t2.typbasetype AS oid,
+ 0 AS depth
+ FROM
+ pg_type t2
+ WHERE
+ t2.oid = t.oid
+
+ UNION ALL
+
+ SELECT
+ t2.typbasetype AS oid,
+ tb.depth + 1 AS depth
+ FROM
+ pg_type t2,
+ typebases tb
+ WHERE
+ tb.oid = t2.oid
+ AND t2.typbasetype != 0
+ ) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1)
+
+ ELSE NULL
+ END) AS basetype,
+ t.typelem AS elemtype,
+ elem_t.typdelim AS elemdelim,
+ COALESCE(
+ range_t.rngsubtype,
+ multirange_t.rngsubtype) AS range_subtype,
+ (CASE WHEN t.typtype = 'c' THEN
+ (SELECT
+ array_agg(ia.atttypid ORDER BY ia.attnum)
+ FROM
+ pg_attribute ia
+ INNER JOIN pg_class c
+ ON (ia.attrelid = c.oid)
+ WHERE
+ ia.attnum > 0 AND NOT ia.attisdropped
+ AND c.reltype = t.oid)
+
+ ELSE NULL
+ END) AS attrtypoids,
+ (CASE WHEN t.typtype = 'c' THEN
+ (SELECT
+ array_agg(ia.attname::text ORDER BY ia.attnum)
+ FROM
+ pg_attribute ia
+ INNER JOIN pg_class c
+ ON (ia.attrelid = c.oid)
+ WHERE
+ ia.attnum > 0 AND NOT ia.attisdropped
+ AND c.reltype = t.oid)
+
+ ELSE NULL
+ END) AS attrnames
+ FROM
+ pg_catalog.pg_type AS t
+ INNER JOIN pg_catalog.pg_namespace ns ON (
+ ns.oid = t.typnamespace)
+ LEFT JOIN pg_type elem_t ON (
+ t.typlen = -1 AND
+ t.typelem != 0 AND
+ t.typelem = elem_t.oid
+ )
+ LEFT JOIN pg_range range_t ON (
+ t.oid = range_t.rngtypid
+ )
+ LEFT JOIN pg_range multirange_t ON (
+ t.oid = multirange_t.rngmultitypid
+ )
+ )
+'''
+
+
INTRO_LOOKUP_TYPES = '''\
WITH RECURSIVE typeinfo_tree(
oid, ns, name, kind, basetype, elemtype, elemdelim,
@@ -109,6 +241,7 @@
(tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype)
OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids))
OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype)
+ OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype)
)
SELECT DISTINCT
@@ -123,7 +256,7 @@
'''.format(typeinfo=_TYPEINFO)
-TYPE_BY_NAME = '''\
+TYPE_BY_NAME: typing.Final = '''\
SELECT
t.oid,
t.typelem AS elemtype,
@@ -136,28 +269,28 @@
'''
-TYPE_BY_OID = '''\
-SELECT
- t.oid,
- t.typelem AS elemtype,
- t.typtype AS kind
-FROM
- pg_catalog.pg_type AS t
-WHERE
- t.oid = $1
-'''
+def TypeRecord(
+ rec: typing.Tuple[int, typing.Optional[int], bytes],
+) -> protocol.Record:
+ assert len(rec) == 3
+ return _create_record( # type: ignore
+ {"oid": 0, "elemtype": 1, "kind": 2}, rec)
# 'b' for a base type, 'd' for a domain, 'e' for enum.
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
-def is_scalar_type(typeinfo) -> bool:
+def is_scalar_type(typeinfo: protocol.Record) -> bool:
return (
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)
-def is_domain_type(typeinfo) -> bool:
- return typeinfo['kind'] == b'd'
+def is_domain_type(typeinfo: protocol.Record) -> bool:
+ return typeinfo['kind'] == b'd' # type: ignore[no-any-return]
+
+
+def is_composite_type(typeinfo: protocol.Record) -> bool:
+ return typeinfo['kind'] == b'c' # type: ignore[no-any-return]
diff --git a/asyncpg/pgproto b/asyncpg/pgproto
index 1720f8af..a29a6f6a 160000
--- a/asyncpg/pgproto
+++ b/asyncpg/pgproto
@@ -1 +1 @@
-Subproject commit 1720f8af63725d79454884cfa787202a50eb5430
+Subproject commit a29a6f6aaa09013cb33ffadb8dd57e21d671ab55
diff --git a/asyncpg/pool.py b/asyncpg/pool.py
index c868097c..5c7ea9ca 100644
--- a/asyncpg/pool.py
+++ b/asyncpg/pool.py
@@ -4,17 +4,20 @@
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
+from __future__ import annotations
import asyncio
+from collections.abc import Awaitable, Callable
import functools
import inspect
import logging
import time
+from types import TracebackType
+from typing import Any, Optional, Type
import warnings
from . import compat
from . import connection
-from . import connect_utils
from . import exceptions
from . import protocol
@@ -24,7 +27,14 @@
class PoolConnectionProxyMeta(type):
- def __new__(mcls, name, bases, dct, *, wrap=False):
+ def __new__(
+ mcls,
+ name: str,
+ bases: tuple[Type[Any], ...],
+ dct: dict[str, Any],
+ *,
+ wrap: bool = False,
+ ) -> PoolConnectionProxyMeta:
if wrap:
for attrname in dir(connection.Connection):
if attrname.startswith('_') or attrname in dct:
@@ -34,7 +44,8 @@ def __new__(mcls, name, bases, dct, *, wrap=False):
if not inspect.isfunction(meth):
continue
- wrapper = mcls._wrap_connection_method(attrname)
+ iscoroutine = inspect.iscoroutinefunction(meth)
+ wrapper = mcls._wrap_connection_method(attrname, iscoroutine)
wrapper = functools.update_wrapper(wrapper, meth)
dct[attrname] = wrapper
@@ -43,13 +54,11 @@ def __new__(mcls, name, bases, dct, *, wrap=False):
return super().__new__(mcls, name, bases, dct)
- def __init__(cls, name, bases, dct, *, wrap=False):
- # Needed for Python 3.5 to handle `wrap` class keyword argument.
- super().__init__(name, bases, dct)
-
@staticmethod
- def _wrap_connection_method(meth_name):
- def call_con_method(self, *args, **kwargs):
+ def _wrap_connection_method(
+ meth_name: str, iscoroutine: bool
+ ) -> Callable[..., Any]:
+ def call_con_method(self: Any, *args: Any, **kwargs: Any) -> Any:
# This method will be owned by PoolConnectionProxy class.
if self._con is None:
raise exceptions.InterfaceError(
@@ -60,6 +69,9 @@ def call_con_method(self, *args, **kwargs):
meth = getattr(self._con.__class__, meth_name)
return meth(self._con, *args, **kwargs)
+ if iscoroutine:
+ compat.markcoroutinefunction(call_con_method)
+
return call_con_method
@@ -69,17 +81,18 @@ class PoolConnectionProxy(connection._ConnectionProxy,
__slots__ = ('_con', '_holder')
- def __init__(self, holder: 'PoolConnectionHolder',
- con: connection.Connection):
+ def __init__(
+ self, holder: PoolConnectionHolder, con: connection.Connection
+ ) -> None:
self._con = con
self._holder = holder
con._set_proxy(self)
- def __getattr__(self, attr):
+ def __getattr__(self, attr: str) -> Any:
# Proxy all unresolved attributes to the wrapped Connection object.
return getattr(self._con, attr)
- def _detach(self) -> connection.Connection:
+ def _detach(self) -> Optional[connection.Connection]:
if self._con is None:
return
@@ -87,7 +100,7 @@ def _detach(self) -> connection.Connection:
con._set_proxy(None)
return con
- def __repr__(self):
+ def __repr__(self) -> str:
if self._con is None:
return '<{classname} [released] {id:#x}>'.format(
classname=self.__class__.__name__, id=id(self))
@@ -104,21 +117,34 @@ class PoolConnectionHolder:
'_inactive_callback', '_timeout',
'_generation')
- def __init__(self, pool, *, max_queries, setup, max_inactive_time):
+ def __init__(
+ self,
+ pool: "Pool",
+ *,
+ max_queries: float,
+ setup: Optional[Callable[[PoolConnectionProxy], Awaitable[None]]],
+ max_inactive_time: float,
+ ) -> None:
self._pool = pool
- self._con = None
- self._proxy = None
+ self._con: Optional[connection.Connection] = None
+ self._proxy: Optional[PoolConnectionProxy] = None
self._max_queries = max_queries
self._max_inactive_time = max_inactive_time
self._setup = setup
- self._inactive_callback = None
- self._in_use = None # type: asyncio.Future
- self._timeout = None
- self._generation = None
+ self._inactive_callback: Optional[Callable] = None
+ self._in_use: Optional[asyncio.Future] = None
+ self._timeout: Optional[float] = None
+ self._generation: Optional[int] = None
+
+ def is_connected(self) -> bool:
+ return self._con is not None and not self._con.is_closed()
- async def connect(self):
+ def is_idle(self) -> bool:
+ return not self._in_use
+
+ async def connect(self) -> None:
if self._con is not None:
raise exceptions.InternalClientError(
'PoolConnectionHolder.connect() called while another '
@@ -166,7 +192,7 @@ async def acquire(self) -> PoolConnectionProxy:
return proxy
- async def release(self, timeout):
+ async def release(self, timeout: Optional[float]) -> None:
if self._in_use is None:
raise exceptions.InternalClientError(
'PoolConnectionHolder.release() called on '
@@ -205,7 +231,12 @@ async def release(self, timeout):
if budget is not None:
budget -= time.monotonic() - started
- await self._con.reset(timeout=budget)
+ if self._pool._reset is not None:
+ async with compat.timeout(budget):
+ await self._con._reset()
+ await self._pool._reset(self._con)
+ else:
+ await self._con.reset(timeout=budget)
except (Exception, asyncio.CancelledError) as ex:
# If the `reset` call failed, terminate the connection.
# A new one will be created when `acquire` is called
@@ -224,25 +255,25 @@ async def release(self, timeout):
# Rearm the connection inactivity timer.
self._setup_inactive_callback()
- async def wait_until_released(self):
+ async def wait_until_released(self) -> None:
if self._in_use is None:
return
else:
await self._in_use
- async def close(self):
+ async def close(self) -> None:
if self._con is not None:
# Connection.close() will call _release_on_close() to
# finish holder cleanup.
await self._con.close()
- def terminate(self):
+ def terminate(self) -> None:
if self._con is not None:
# Connection.terminate() will call _release_on_close() to
# finish holder cleanup.
self._con.terminate()
- def _setup_inactive_callback(self):
+ def _setup_inactive_callback(self) -> None:
if self._inactive_callback is not None:
raise exceptions.InternalClientError(
'pool connection inactivity timer already exists')
@@ -251,12 +282,12 @@ def _setup_inactive_callback(self):
self._inactive_callback = self._pool._loop.call_later(
self._max_inactive_time, self._deactivate_inactive_connection)
- def _maybe_cancel_inactive_callback(self):
+ def _maybe_cancel_inactive_callback(self) -> None:
if self._inactive_callback is not None:
self._inactive_callback.cancel()
self._inactive_callback = None
- def _deactivate_inactive_connection(self):
+ def _deactivate_inactive_connection(self) -> None:
if self._in_use is not None:
raise exceptions.InternalClientError(
'attempting to deactivate an acquired connection')
@@ -270,12 +301,12 @@ def _deactivate_inactive_connection(self):
# so terminate() above will not call the below.
self._release_on_close()
- def _release_on_close(self):
+ def _release_on_close(self) -> None:
self._maybe_cancel_inactive_callback()
self._release()
self._con = None
- def _release(self):
+ def _release(self) -> None:
"""Release this connection holder."""
if self._in_use is None:
# The holder is not checked out.
@@ -308,8 +339,7 @@ class Pool:
__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
- '_init', '_connect_args', '_connect_kwargs',
- '_working_addr', '_working_config', '_working_params',
+ '_init', '_connect', '_reset', '_connect_args', '_connect_kwargs',
'_holders', '_initialized', '_initializing', '_closing',
'_closed', '_connection_class', '_record_class', '_generation',
'_setup', '_max_queries', '_max_inactive_connection_lifetime'
@@ -320,8 +350,10 @@ def __init__(self, *connect_args,
max_size,
max_queries,
max_inactive_connection_lifetime,
- setup,
- init,
+ connect=None,
+ setup=None,
+ init=None,
+ reset=None,
loop,
connection_class,
record_class,
@@ -375,28 +407,28 @@ def __init__(self, *connect_args,
self._initializing = False
self._queue = None
- self._working_addr = None
- self._working_config = None
- self._working_params = None
-
self._connection_class = connection_class
self._record_class = record_class
self._closing = False
self._closed = False
self._generation = 0
- self._init = init
+
+ self._connect = connect if connect is not None else connection.connect
self._connect_args = connect_args
self._connect_kwargs = connect_kwargs
self._setup = setup
+ self._init = init
+ self._reset = reset
+
self._max_queries = max_queries
self._max_inactive_connection_lifetime = \
max_inactive_connection_lifetime
async def _async__init__(self):
if self._initialized:
- return
+ return self
if self._initializing:
raise exceptions.InterfaceError(
'pool is being initialized in another task')
@@ -428,9 +460,8 @@ async def _initialize(self):
# first few connections in the queue, therefore we want to walk
# `self._holders` in reverse.
- # Connect the first connection holder in the queue so that it
- # can record `_working_addr` and `_working_opts`, which will
- # speed up successive connection attempts.
+ # Connect the first connection holder in the queue so that
+ # any connection issues are visible early.
first_ch = self._holders[-1] # type: PoolConnectionHolder
await first_ch.connect()
@@ -444,6 +475,41 @@ async def _initialize(self):
await asyncio.gather(*connect_tasks)
+ def is_closing(self):
+ """Return ``True`` if the pool is closing or is closed.
+
+ .. versionadded:: 0.28.0
+ """
+ return self._closed or self._closing
+
+ def get_size(self):
+ """Return the current number of connections in this pool.
+
+ .. versionadded:: 0.25.0
+ """
+ return sum(h.is_connected() for h in self._holders)
+
+ def get_min_size(self):
+ """Return the minimum number of connections in this pool.
+
+ .. versionadded:: 0.25.0
+ """
+ return self._minsize
+
+ def get_max_size(self):
+ """Return the maximum allowed number of connections in this pool.
+
+ .. versionadded:: 0.25.0
+ """
+ return self._maxsize
+
+ def get_idle_size(self):
+ """Return the current number of idle connections in this pool.
+
+ .. versionadded:: 0.25.0
+ """
+ return sum(h.is_connected() and h.is_idle() for h in self._holders)
+
def set_connect_args(self, dsn=None, **connect_kwargs):
r"""Set the new connection arguments for this pool.
@@ -467,35 +533,26 @@ def set_connect_args(self, dsn=None, **connect_kwargs):
self._connect_args = [dsn]
self._connect_kwargs = connect_kwargs
- self._working_addr = None
- self._working_config = None
- self._working_params = None
async def _get_new_connection(self):
- if self._working_addr is None:
- # First connection attempt on this pool.
- con = await connection.connect(
- *self._connect_args,
- loop=self._loop,
- connection_class=self._connection_class,
- record_class=self._record_class,
- **self._connect_kwargs)
-
- self._working_addr = con._addr
- self._working_config = con._config
- self._working_params = con._params
-
- else:
- # We've connected before and have a resolved address,
- # and parsed options and config.
- con = await connect_utils._connect_addr(
- loop=self._loop,
- addr=self._working_addr,
- timeout=self._working_params.connect_timeout,
- config=self._working_config,
- params=self._working_params,
- connection_class=self._connection_class,
- record_class=self._record_class,
+ con = await self._connect(
+ *self._connect_args,
+ loop=self._loop,
+ connection_class=self._connection_class,
+ record_class=self._record_class,
+ **self._connect_kwargs,
+ )
+ if not isinstance(con, self._connection_class):
+ good = self._connection_class
+ good_n = f'{good.__module__}.{good.__name__}'
+ bad = type(con)
+ if bad.__module__ == "builtins":
+ bad_n = bad.__name__
+ else:
+ bad_n = f'{bad.__module__}.{bad.__name__}'
+ raise exceptions.InterfaceError(
+ "expected pool connect callback to return an instance of "
+ f"'{good_n}', got " f"'{bad_n}'"
)
if self._init is not None:
@@ -517,7 +574,12 @@ async def _get_new_connection(self):
return con
- async def execute(self, query: str, *args, timeout: float=None) -> str:
+ async def execute(
+ self,
+ query: str,
+ *args,
+ timeout: Optional[float]=None,
+ ) -> str:
"""Execute an SQL command (or commands).
Pool performs this operation using one of its connections. Other than
@@ -529,7 +591,13 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
async with self.acquire() as con:
return await con.execute(query, *args, timeout=timeout)
- async def executemany(self, command: str, args, *, timeout: float=None):
+ async def executemany(
+ self,
+ command: str,
+ args,
+ *,
+ timeout: Optional[float]=None,
+ ):
"""Execute an SQL *command* for each sequence of arguments in *args*.
Pool performs this operation using one of its connections. Other than
@@ -542,7 +610,13 @@ async def executemany(self, command: str, args, *, timeout: float=None):
async with self.acquire() as con:
return await con.executemany(command, args, timeout=timeout)
- async def fetch(self, query, *args, timeout=None) -> list:
+ async def fetch(
+ self,
+ query,
+ *args,
+ timeout=None,
+ record_class=None
+ ) -> list:
"""Run a query and return the results as a list of :class:`Record`.
Pool performs this operation using one of its connections. Other than
@@ -552,7 +626,12 @@ async def fetch(self, query, *args, timeout=None) -> list:
.. versionadded:: 0.10.0
"""
async with self.acquire() as con:
- return await con.fetch(query, *args, timeout=timeout)
+ return await con.fetch(
+ query,
+ *args,
+ timeout=timeout,
+ record_class=record_class
+ )
async def fetchval(self, query, *args, column=0, timeout=None):
"""Run a query and return a value in the first row.
@@ -568,7 +647,7 @@ async def fetchval(self, query, *args, column=0, timeout=None):
return await con.fetchval(
query, *args, column=column, timeout=timeout)
- async def fetchrow(self, query, *args, timeout=None):
+ async def fetchrow(self, query, *args, timeout=None, record_class=None):
"""Run a query and return the first row.
Pool performs this operation using one of its connections. Other than
@@ -578,7 +657,28 @@ async def fetchrow(self, query, *args, timeout=None):
.. versionadded:: 0.10.0
"""
async with self.acquire() as con:
- return await con.fetchrow(query, *args, timeout=timeout)
+ return await con.fetchrow(
+ query,
+ *args,
+ timeout=timeout,
+ record_class=record_class
+ )
+
+ async def fetchmany(self, query, args, *, timeout=None, record_class=None):
+ """Run a query for each sequence of arguments in *args*
+ and return the results as a list of :class:`Record`.
+
+ Pool performs this operation using one of its connections. Other than
+ that, it behaves identically to
+    :meth:`Connection.fetchmany()
+    <asyncpg.connection.Connection.fetchmany>`.
+
+ .. versionadded:: 0.30.0
+ """
+ async with self.acquire() as con:
+ return await con.fetchmany(
+ query, args, timeout=timeout, record_class=record_class
+ )
async def copy_from_table(
self,
@@ -686,7 +786,8 @@ async def copy_to_table(
force_quote=None,
force_not_null=None,
force_null=None,
- encoding=None
+ encoding=None,
+ where=None
):
"""Copy data to the specified table.
@@ -715,7 +816,8 @@ async def copy_to_table(
force_quote=force_quote,
force_not_null=force_not_null,
force_null=force_null,
- encoding=encoding
+ encoding=encoding,
+ where=where
)
async def copy_records_to_table(
@@ -725,7 +827,8 @@ async def copy_records_to_table(
records,
columns=None,
schema_name=None,
- timeout=None
+ timeout=None,
+ where=None
):
"""Copy a list of records to the specified table using binary COPY.
@@ -742,7 +845,8 @@ async def copy_records_to_table(
records=records,
columns=columns,
schema_name=schema_name,
- timeout=timeout
+ timeout=timeout,
+ where=where
)
def acquire(self, *, timeout=None):
@@ -940,7 +1044,7 @@ class PoolAcquireContext:
__slots__ = ('timeout', 'connection', 'done', 'pool')
- def __init__(self, pool, timeout):
+ def __init__(self, pool: Pool, timeout: Optional[float]) -> None:
self.pool = pool
self.timeout = timeout
self.connection = None
@@ -952,7 +1056,12 @@ async def __aenter__(self):
self.connection = await self.pool._acquire(self.timeout)
return self.connection
- async def __aexit__(self, *exc):
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]] = None,
+ exc_val: Optional[BaseException] = None,
+ exc_tb: Optional[TracebackType] = None,
+ ) -> None:
self.done = True
con = self.connection
self.connection = None
@@ -968,8 +1077,10 @@ def create_pool(dsn=None, *,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
+ connect=None,
setup=None,
init=None,
+ reset=None,
loop=None,
connection_class=connection.Connection,
record_class=protocol.Record,
@@ -1050,9 +1161,16 @@ def create_pool(dsn=None, *,
Number of seconds after which inactive connections in the
pool will be closed. Pass ``0`` to disable this mechanism.
+ :param coroutine connect:
+ A coroutine that is called instead of
+ :func:`~asyncpg.connection.connect` whenever the pool needs to make a
+ new connection. Must return an instance of type specified by
+ *connection_class* or :class:`~asyncpg.connection.Connection` if
+ *connection_class* was not specified.
+
:param coroutine setup:
A coroutine to prepare a connection right before it is returned
-    from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
+ from :meth:`Pool.acquire()`. An example use
case would be to automatically set up notifications listeners for
all connections of a pool.
@@ -1064,6 +1182,25 @@ def create_pool(dsn=None, *,
or :meth:`Connection.set_type_codec() <\
asyncpg.connection.Connection.set_type_codec>`.
+ :param coroutine reset:
+ A coroutine to reset a connection before it is returned to the pool by
+ :meth:`Pool.release()`. The function is supposed
+ to reset any changes made to the database session so that the next
+ acquirer gets the connection in a well-defined state.
+
+ The default implementation calls :meth:`Connection.reset() <\
+ asyncpg.connection.Connection.reset>`, which runs the following::
+
+ SELECT pg_advisory_unlock_all();
+ CLOSE ALL;
+ UNLISTEN *;
+ RESET ALL;
+
+ The exact reset query is determined by detected server capabilities,
+ and a custom *reset* implementation can obtain the default query
+ by calling :meth:`Connection.get_reset_query() <\
+ asyncpg.connection.Connection.get_reset_query>`.
+
:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
@@ -1090,12 +1227,22 @@ def create_pool(dsn=None, *,
.. versionchanged:: 0.22.0
Added the *record_class* parameter.
+
+ .. versionchanged:: 0.30.0
+ Added the *connect* and *reset* parameters.
"""
return Pool(
dsn,
connection_class=connection_class,
record_class=record_class,
- min_size=min_size, max_size=max_size,
- max_queries=max_queries, loop=loop, setup=setup, init=init,
+ min_size=min_size,
+ max_size=max_size,
+ max_queries=max_queries,
+ loop=loop,
+ connect=connect,
+ setup=setup,
+ init=init,
+ reset=reset,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
- **connect_kwargs)
+ **connect_kwargs,
+ )
diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py
index eeb45367..0c2d335e 100644
--- a/asyncpg/prepared_stmt.py
+++ b/asyncpg/prepared_stmt.py
@@ -6,6 +6,7 @@
import json
+import typing
from . import connresource
from . import cursor
@@ -24,6 +25,14 @@ def __init__(self, connection, query, state):
state.attach()
self._last_status = None
+ @connresource.guarded
+ def get_name(self) -> str:
+ """Return the name of this prepared statement.
+
+ .. versionadded:: 0.25.0
+ """
+ return self._state.name
+
@connresource.guarded
def get_query(self) -> str:
"""Return the text of the query for this prepared statement.
@@ -139,8 +148,8 @@ async def explain(self, *args, analyze=False):
# will discard any output that a SELECT would return, other
# side effects of the statement will happen as usual. If you
# wish to use EXPLAIN ANALYZE on an INSERT, UPDATE, DELETE,
- # CREATE TABLE AS, or EXECUTE statement without letting the
- # command affect your data, use this approach:
+ # MERGE, CREATE TABLE AS, or EXECUTE statement without letting
+ # the command affect your data, use this approach:
# BEGIN;
# EXPLAIN ANALYZE ...;
# ROLLBACK;
@@ -203,7 +212,28 @@ async def fetchrow(self, *args, timeout=None):
return data[0]
@connresource.guarded
- async def executemany(self, args, *, timeout: float=None):
+ async def fetchmany(self, args, *, timeout=None):
+ """Execute the statement and return a list of :class:`Record` objects.
+
+ :param args: Query arguments.
+ :param float timeout: Optional timeout value in seconds.
+
+ :return: A list of :class:`Record` instances.
+
+ .. versionadded:: 0.30.0
+ """
+ return await self.__do_execute(
+ lambda protocol: protocol.bind_execute_many(
+ self._state,
+ args,
+ portal_name='',
+ timeout=timeout,
+ return_rows=True,
+ )
+ )
+
+ @connresource.guarded
+ async def executemany(self, args, *, timeout: typing.Optional[float]=None):
"""Execute the statement for each sequence of arguments in *args*.
:param args: An iterable containing sequences of arguments.
@@ -214,7 +244,12 @@ async def executemany(self, args, *, timeout: float=None):
"""
return await self.__do_execute(
lambda protocol: protocol.bind_execute_many(
- self._state, args, '', timeout))
+ self._state,
+ args,
+ portal_name='',
+ timeout=timeout,
+ return_rows=False,
+ ))
async def __do_execute(self, executor):
protocol = self._connection._protocol
diff --git a/asyncpg/protocol/__init__.py b/asyncpg/protocol/__init__.py
index 8b3e06a0..043454db 100644
--- a/asyncpg/protocol/__init__.py
+++ b/asyncpg/protocol/__init__.py
@@ -6,4 +6,7 @@
# flake8: NOQA
-from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP
+from __future__ import annotations
+
+from .protocol import Protocol, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP
+from .record import Record
diff --git a/asyncpg/protocol/codecs/array.pyx b/asyncpg/protocol/codecs/array.pyx
index 3c39e49c..f8f9b8dd 100644
--- a/asyncpg/protocol/codecs/array.pyx
+++ b/asyncpg/protocol/codecs/array.pyx
@@ -858,19 +858,7 @@ cdef arraytext_decode(ConnectionSettings settings, FRBuffer *buf):
return array_decode(settings, buf, &text_decode_ex, NULL)
-cdef anyarray_decode(ConnectionSettings settings, FRBuffer *buf):
- # Instances of anyarray (or any other polymorphic pseudotype) are
- # never supposed to be returned from actual queries.
- raise exceptions.ProtocolError(
- 'unexpected instance of \'anyarray\' type')
-
-
cdef init_array_codecs():
- register_core_codec(ANYARRAYOID,
- NULL,
- &anyarray_decode,
- PG_FORMAT_BINARY)
-
# oid[] and text[] are registered as core codecs
# to make type introspection query work
#
diff --git a/asyncpg/protocol/codecs/base.pxd b/asyncpg/protocol/codecs/base.pxd
index 79d7a695..f5492590 100644
--- a/asyncpg/protocol/codecs/base.pxd
+++ b/asyncpg/protocol/codecs/base.pxd
@@ -22,13 +22,26 @@ ctypedef object (*codec_decode_func)(Codec codec,
FRBuffer *buf)
+cdef class CodecMap:
+ cdef:
+ void** binary_codec_map
+ void** text_codec_map
+ dict extra_codecs
+
+ cdef inline void *get_binary_codec_ptr(self, uint32_t idx)
+ cdef inline void set_binary_codec_ptr(self, uint32_t idx, void *ptr)
+ cdef inline void *get_text_codec_ptr(self, uint32_t idx)
+ cdef inline void set_text_codec_ptr(self, uint32_t idx, void *ptr)
+
+
cdef enum CodecType:
- CODEC_UNDEFINED = 0
- CODEC_C = 1
- CODEC_PY = 2
- CODEC_ARRAY = 3
- CODEC_COMPOSITE = 4
- CODEC_RANGE = 5
+ CODEC_UNDEFINED = 0
+ CODEC_C = 1
+ CODEC_PY = 2
+ CODEC_ARRAY = 3
+ CODEC_COMPOSITE = 4
+ CODEC_RANGE = 5
+ CODEC_MULTIRANGE = 6
cdef enum ServerDataFormat:
@@ -56,6 +69,7 @@ cdef class Codec:
encode_func c_encoder
decode_func c_decoder
+ Codec base_codec
object py_encoder
object py_decoder
@@ -78,6 +92,7 @@ cdef class Codec:
CodecType type, ServerDataFormat format,
ClientExchangeFormat xformat,
encode_func c_encoder, decode_func c_decoder,
+ Codec base_codec,
object py_encoder, object py_decoder,
Codec element_codec, tuple element_type_oids,
object element_names, list element_codecs,
@@ -95,6 +110,9 @@ cdef class Codec:
cdef encode_range(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
+ cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf,
+ object obj)
+
cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf,
object obj)
@@ -109,6 +127,8 @@ cdef class Codec:
cdef decode_range(self, ConnectionSettings settings, FRBuffer *buf)
+ cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf)
+
cdef decode_composite(self, ConnectionSettings settings, FRBuffer *buf)
cdef decode_in_python(self, ConnectionSettings settings, FRBuffer *buf)
@@ -139,6 +159,12 @@ cdef class Codec:
str schema,
Codec element_codec)
+ @staticmethod
+ cdef Codec new_multirange_codec(uint32_t oid,
+ str name,
+ str schema,
+ Codec element_codec)
+
@staticmethod
cdef Codec new_composite_codec(uint32_t oid,
str name,
@@ -157,6 +183,7 @@ cdef class Codec:
object decoder,
encode_func c_encoder,
decode_func c_decoder,
+ Codec base_codec,
ServerDataFormat format,
ClientExchangeFormat xformat)
diff --git a/asyncpg/protocol/codecs/base.pyx b/asyncpg/protocol/codecs/base.pyx
index e4a767a9..009598a8 100644
--- a/asyncpg/protocol/codecs/base.pyx
+++ b/asyncpg/protocol/codecs/base.pyx
@@ -11,9 +11,33 @@ import asyncpg
from asyncpg import exceptions
-cdef void* binary_codec_map[(MAXSUPPORTEDOID + 1) * 2]
-cdef void* text_codec_map[(MAXSUPPORTEDOID + 1) * 2]
-cdef dict EXTRA_CODECS = {}
+# The class indirection is needed because Cython
+# does not (as of 3.1.0) store global cdef variables
+# in module state.
+@cython.final
+cdef class CodecMap:
+
+ def __cinit__(self):
+ self.extra_codecs = {}
+ self.binary_codec_map = cpython.PyMem_Calloc(
+ (MAXSUPPORTEDOID + 1) * 2, sizeof(void *))
+ self.text_codec_map = cpython.PyMem_Calloc(
+ (MAXSUPPORTEDOID + 1) * 2, sizeof(void *))
+
+ cdef inline void *get_binary_codec_ptr(self, uint32_t idx):
+ return self.binary_codec_map[idx]
+
+ cdef inline void set_binary_codec_ptr(self, uint32_t idx, void *ptr):
+ self.binary_codec_map[idx] = ptr
+
+ cdef inline void *get_text_codec_ptr(self, uint32_t idx):
+ return self.text_codec_map[idx]
+
+ cdef inline void set_text_codec_ptr(self, uint32_t idx, void *ptr):
+ self.text_codec_map[idx] = ptr
+
+
+codec_map = CodecMap()
@cython.final
@@ -23,14 +47,25 @@ cdef class Codec:
self.oid = oid
self.type = CODEC_UNDEFINED
- cdef init(self, str name, str schema, str kind,
- CodecType type, ServerDataFormat format,
- ClientExchangeFormat xformat,
- encode_func c_encoder, decode_func c_decoder,
- object py_encoder, object py_decoder,
- Codec element_codec, tuple element_type_oids,
- object element_names, list element_codecs,
- Py_UCS4 element_delimiter):
+ cdef init(
+ self,
+ str name,
+ str schema,
+ str kind,
+ CodecType type,
+ ServerDataFormat format,
+ ClientExchangeFormat xformat,
+ encode_func c_encoder,
+ decode_func c_decoder,
+ Codec base_codec,
+ object py_encoder,
+ object py_decoder,
+ Codec element_codec,
+ tuple element_type_oids,
+ object element_names,
+ list element_codecs,
+ Py_UCS4 element_delimiter,
+ ):
self.name = name
self.schema = schema
@@ -40,6 +75,7 @@ cdef class Codec:
self.xformat = xformat
self.c_encoder = c_encoder
self.c_decoder = c_decoder
+ self.base_codec = base_codec
self.py_encoder = py_encoder
self.py_decoder = py_decoder
self.element_codec = element_codec
@@ -48,8 +84,14 @@ cdef class Codec:
self.element_delimiter = element_delimiter
self.element_names = element_names
+ if base_codec is not None:
+ if c_encoder != NULL or c_decoder != NULL:
+ raise exceptions.InternalClientError(
+ 'base_codec is mutually exclusive with c_encoder/c_decoder'
+ )
+
if element_names is not None:
- self.record_desc = record.ApgRecordDesc_New(
+ self.record_desc = RecordDescriptor(
element_names, tuple(element_names))
else:
self.record_desc = None
@@ -71,6 +113,13 @@ cdef class Codec:
'range types is not supported'.format(schema, name))
self.encoder = &self.encode_range
self.decoder = &self.decode_range
+ elif type == CODEC_MULTIRANGE:
+ if format != PG_FORMAT_BINARY:
+ raise exceptions.UnsupportedClientFeatureError(
+ 'cannot decode type "{}"."{}": text encoding of '
+ 'range types is not supported'.format(schema, name))
+ self.encoder = &self.encode_multirange
+ self.decoder = &self.decode_multirange
elif type == CODEC_COMPOSITE:
if format != PG_FORMAT_BINARY:
raise exceptions.UnsupportedClientFeatureError(
@@ -91,7 +140,7 @@ cdef class Codec:
codec = Codec(self.oid)
codec.init(self.name, self.schema, self.kind,
self.type, self.format, self.xformat,
- self.c_encoder, self.c_decoder,
+ self.c_encoder, self.c_decoder, self.base_codec,
self.py_encoder, self.py_decoder,
self.element_codec,
self.element_type_oids, self.element_names,
@@ -122,6 +171,12 @@ cdef class Codec:
codec_encode_func_ex,
(self.element_codec))
+ cdef encode_multirange(self, ConnectionSettings settings, WriteBuffer buf,
+ object obj):
+ multirange_encode(settings, buf, obj, self.element_codec.oid,
+ codec_encode_func_ex,
+ (self.element_codec))
+
cdef encode_composite(self, ConnectionSettings settings, WriteBuffer buf,
object obj):
cdef:
@@ -183,7 +238,10 @@ cdef class Codec:
raise exceptions.InternalClientError(
'unexpected data format: {}'.format(self.format))
elif self.xformat == PG_XFORMAT_TUPLE:
- self.c_encoder(settings, buf, data)
+ if self.base_codec is not None:
+ self.base_codec.encode(settings, buf, data)
+ else:
+ self.c_encoder(settings, buf, data)
else:
raise exceptions.InternalClientError(
'unexpected exchange format: {}'.format(self.xformat))
@@ -209,6 +267,10 @@ cdef class Codec:
return range_decode(settings, buf, codec_decode_func_ex,
(self.element_codec))
+ cdef decode_multirange(self, ConnectionSettings settings, FRBuffer *buf):
+ return multirange_decode(settings, buf, codec_decode_func_ex,
+ (self.element_codec))
+
cdef decode_composite(self, ConnectionSettings settings,
FRBuffer *buf):
cdef:
@@ -233,7 +295,7 @@ cdef class Codec:
schema=self.schema,
data_type=self.name,
)
- result = record.ApgRecord_New(asyncpg.Record, self.record_desc, elem_count)
+ result = self.record_desc.make_record(asyncpg.Record, elem_count)
for i in range(elem_count):
elem_typ = self.element_type_oids[i]
received_elem_typ = hton.unpack_int32(frb_read(buf, 4))
@@ -263,7 +325,7 @@ cdef class Codec:
settings, frb_slice_from(&elem_buf, buf, elem_len))
cpython.Py_INCREF(elem)
- record.ApgRecord_SET_ITEM(result, i, elem)
+ recordcapi.ApgRecord_SET_ITEM(result, i, elem)
return result
@@ -278,7 +340,10 @@ cdef class Codec:
raise exceptions.InternalClientError(
'unexpected data format: {}'.format(self.format))
elif self.xformat == PG_XFORMAT_TUPLE:
- data = self.c_decoder(settings, buf)
+ if self.base_codec is not None:
+ data = self.base_codec.decode(settings, buf)
+ else:
+ data = self.c_decoder(settings, buf)
else:
raise exceptions.InternalClientError(
'unexpected exchange format: {}'.format(self.xformat))
@@ -294,7 +359,11 @@ cdef class Codec:
if self.c_encoder is not NULL or self.py_encoder is not None:
return True
- elif self.type == CODEC_ARRAY or self.type == CODEC_RANGE:
+ elif (
+ self.type == CODEC_ARRAY
+ or self.type == CODEC_RANGE
+ or self.type == CODEC_MULTIRANGE
+ ):
return self.element_codec.has_encoder()
elif self.type == CODEC_COMPOSITE:
@@ -312,7 +381,11 @@ cdef class Codec:
if self.c_decoder is not NULL or self.py_decoder is not None:
return True
- elif self.type == CODEC_ARRAY or self.type == CODEC_RANGE:
+ elif (
+ self.type == CODEC_ARRAY
+ or self.type == CODEC_RANGE
+ or self.type == CODEC_MULTIRANGE
+ ):
return self.element_codec.has_decoder()
elif self.type == CODEC_COMPOSITE:
@@ -342,8 +415,8 @@ cdef class Codec:
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'array', CODEC_ARRAY, element_codec.format,
- PG_XFORMAT_OBJECT, NULL, NULL, None, None, element_codec,
- None, None, None, element_delimiter)
+ PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
+ element_codec, None, None, None, element_delimiter)
return codec
@staticmethod
@@ -354,8 +427,20 @@ cdef class Codec:
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, 'range', CODEC_RANGE, element_codec.format,
- PG_XFORMAT_OBJECT, NULL, NULL, None, None, element_codec,
- None, None, None, 0)
+ PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
+ element_codec, None, None, None, 0)
+ return codec
+
+ @staticmethod
+ cdef Codec new_multirange_codec(uint32_t oid,
+ str name,
+ str schema,
+ Codec element_codec):
+ cdef Codec codec
+ codec = Codec(oid)
+ codec.init(name, schema, 'multirange', CODEC_MULTIRANGE,
+ element_codec.format, PG_XFORMAT_OBJECT, NULL, NULL, None,
+ None, None, element_codec, None, None, None, 0)
return codec
@staticmethod
@@ -370,7 +455,7 @@ cdef class Codec:
codec = Codec(oid)
codec.init(name, schema, 'composite', CODEC_COMPOSITE,
format, PG_XFORMAT_OBJECT, NULL, NULL, None, None, None,
- element_type_oids, element_names, element_codecs, 0)
+ None, element_type_oids, element_names, element_codecs, 0)
return codec
@staticmethod
@@ -382,12 +467,13 @@ cdef class Codec:
object decoder,
encode_func c_encoder,
decode_func c_decoder,
+ Codec base_codec,
ServerDataFormat format,
ClientExchangeFormat xformat):
cdef Codec codec
codec = Codec(oid)
codec.init(name, schema, kind, CODEC_PY, format, xformat,
- c_encoder, c_decoder, encoder, decoder,
+ c_encoder, c_decoder, base_codec, encoder, decoder,
None, None, None, None, 0)
return codec
@@ -421,7 +507,7 @@ cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:
cdef class DataCodecConfig:
- def __init__(self, cache_key):
+ def __init__(self):
# Codec instance cache for derived types:
# composites, arrays, ranges, domains and their combinations.
self._derived_type_codecs = {}
@@ -536,6 +622,21 @@ cdef class DataCodecConfig:
self._derived_type_codecs[oid, elem_codec.format] = \
Codec.new_range_codec(oid, name, schema, elem_codec)
+ elif ti['kind'] == b'm':
+ # Multirange type
+
+ if not range_subtype_oid:
+ raise exceptions.InternalClientError(
+ f'type record missing base type for multirange {oid}')
+
+ elem_codec = self.get_codec(range_subtype_oid, PG_FORMAT_ANY)
+ if elem_codec is None:
+ elem_codec = self.declare_fallback_codec(
+ range_subtype_oid, ti['range_subtype_name'], schema)
+
+ self._derived_type_codecs[oid, elem_codec.format] = \
+ Codec.new_multirange_codec(oid, name, schema, elem_codec)
+
elif ti['kind'] == b'e':
# Enum types are essentially text
self._set_builtin_type_codec(oid, name, schema, 'scalar',
@@ -544,17 +645,21 @@ cdef class DataCodecConfig:
self.declare_fallback_codec(oid, name, schema)
def add_python_codec(self, typeoid, typename, typeschema, typekind,
- encoder, decoder, format, xformat):
+ typeinfos, encoder, decoder, format, xformat):
cdef:
- Codec core_codec
+ Codec core_codec = None
encode_func c_encoder = NULL
decode_func c_decoder = NULL
+ Codec base_codec = None
uint32_t oid = pylong_as_oid(typeoid)
bint codec_set = False
# Clear all previous overrides (this also clears type cache).
self.remove_python_codec(typeoid, typename, typeschema)
+ if typeinfos:
+ self.add_types(typeinfos)
+
if format == PG_FORMAT_ANY:
formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY)
else:
@@ -562,16 +667,21 @@ cdef class DataCodecConfig:
for fmt in formats:
if xformat == PG_XFORMAT_TUPLE:
- core_codec = get_core_codec(oid, fmt, xformat)
- if core_codec is None:
- continue
- c_encoder = core_codec.c_encoder
- c_decoder = core_codec.c_decoder
+ if typekind == "scalar":
+ core_codec = get_core_codec(oid, fmt, xformat)
+ if core_codec is None:
+ continue
+ c_encoder = core_codec.c_encoder
+ c_decoder = core_codec.c_decoder
+ elif typekind == "composite":
+ base_codec = self.get_codec(oid, fmt)
+ if base_codec is None:
+ continue
self._custom_type_codecs[typeoid, fmt] = \
Codec.new_python_codec(oid, typename, typeschema, typekind,
encoder, decoder, c_encoder, c_decoder,
- fmt, xformat)
+ base_codec, fmt, xformat)
codec_set = True
if not codec_set:
@@ -725,9 +835,9 @@ cdef inline Codec get_core_codec(
if oid > MAXSUPPORTEDOID:
return None
if format == PG_FORMAT_BINARY:
- ptr = binary_codec_map[oid * xformat]
+ ptr = (codec_map).get_binary_codec_ptr(oid * xformat)
elif format == PG_FORMAT_TEXT:
- ptr = text_codec_map[oid * xformat]
+ ptr = (codec_map).get_text_codec_ptr(oid * xformat)
if ptr is NULL:
return None
@@ -753,7 +863,10 @@ cdef inline Codec get_any_core_codec(
cdef inline int has_core_codec(uint32_t oid):
- return binary_codec_map[oid] != NULL or text_codec_map[oid] != NULL
+ return (
+ (codec_map).get_binary_codec_ptr(oid) != NULL
+ or (codec_map).get_text_codec_ptr(oid) != NULL
+ )
cdef register_core_codec(uint32_t oid,
@@ -777,13 +890,13 @@ cdef register_core_codec(uint32_t oid,
codec = Codec(oid)
codec.init(name, 'pg_catalog', kind, CODEC_C, format, xformat,
- encode, decode, None, None, None, None, None, None, 0)
+ encode, decode, None, None, None, None, None, None, None, 0)
cpython.Py_INCREF(codec) # immortalize
if format == PG_FORMAT_BINARY:
- binary_codec_map[oid * xformat] = codec
+ (codec_map).set_binary_codec_ptr(oid * xformat, codec)
elif format == PG_FORMAT_TEXT:
- text_codec_map[oid * xformat] = codec
+ (codec_map).set_text_codec_ptr(oid * xformat, codec)
else:
raise exceptions.InternalClientError(
'invalid data format: {}'.format(format))
@@ -801,9 +914,9 @@ cdef register_extra_codec(str name,
codec = Codec(INVALIDOID)
codec.init(name, None, kind, CODEC_C, format, PG_XFORMAT_OBJECT,
- encode, decode, None, None, None, None, None, None, 0)
- EXTRA_CODECS[name, format] = codec
+ encode, decode, None, None, None, None, None, None, None, 0)
+ (codec_map).extra_codecs[name, format] = codec
cdef inline Codec get_extra_codec(str name, ServerDataFormat format):
- return EXTRA_CODECS.get((name, format))
+ return (codec_map).extra_codecs.get((name, format))
diff --git a/asyncpg/protocol/codecs/pgproto.pyx b/asyncpg/protocol/codecs/pgproto.pyx
index 11417d45..51d650d0 100644
--- a/asyncpg/protocol/codecs/pgproto.pyx
+++ b/asyncpg/protocol/codecs/pgproto.pyx
@@ -273,8 +273,9 @@ cdef init_pseudo_codecs():
FDW_HANDLEROID, TSM_HANDLEROID, INTERNALOID, OPAQUEOID,
ANYELEMENTOID, ANYNONARRAYOID, ANYCOMPATIBLEOID,
ANYCOMPATIBLEARRAYOID, ANYCOMPATIBLENONARRAYOID,
- ANYCOMPATIBLERANGEOID, PG_DDL_COMMANDOID, INDEX_AM_HANDLEROID,
- TABLE_AM_HANDLEROID,
+ ANYCOMPATIBLERANGEOID, ANYCOMPATIBLEMULTIRANGEOID,
+ ANYRANGEOID, ANYMULTIRANGEOID, ANYARRAYOID,
+ PG_DDL_COMMANDOID, INDEX_AM_HANDLEROID, TABLE_AM_HANDLEROID,
]
register_core_codec(ANYENUMOID,
@@ -330,6 +331,19 @@ cdef init_pseudo_codecs():
pgproto.bytea_decode,
PG_FORMAT_BINARY)
+ # These two are internal to BRIN index support and are unlikely
+ # to be sent, but since I/O functions for these exist, add decoders
+ # nonetheless.
+ register_core_codec(PG_BRIN_BLOOM_SUMMARYOID,
+ NULL,
+ pgproto.bytea_decode,
+ PG_FORMAT_BINARY)
+
+ register_core_codec(PG_BRIN_MINMAX_MULTI_SUMMARYOID,
+ NULL,
+ pgproto.bytea_decode,
+ PG_FORMAT_BINARY)
+
cdef init_text_codecs():
textoids = [
diff --git a/asyncpg/protocol/codecs/range.pyx b/asyncpg/protocol/codecs/range.pyx
index 2f598c1b..1038c18d 100644
--- a/asyncpg/protocol/codecs/range.pyx
+++ b/asyncpg/protocol/codecs/range.pyx
@@ -7,6 +7,8 @@
from asyncpg import types as apg_types
+from collections.abc import Sequence as SequenceABC
+
# defined in postgresql/src/include/utils/rangetypes.h
DEF RANGE_EMPTY = 0x01 # range is empty
DEF RANGE_LB_INC = 0x02 # lower bound is inclusive
@@ -139,11 +141,67 @@ cdef range_decode(ConnectionSettings settings, FRBuffer *buf,
empty=(flags & RANGE_EMPTY) != 0)
-cdef init_range_codecs():
- register_core_codec(ANYRANGEOID,
- NULL,
- pgproto.text_decode,
- PG_FORMAT_TEXT)
+cdef multirange_encode(ConnectionSettings settings, WriteBuffer buf,
+ object obj, uint32_t elem_oid,
+ encode_func_ex encoder, const void *encoder_arg):
+ cdef:
+ WriteBuffer elem_data
+ ssize_t elem_data_len
+ ssize_t elem_count
+
+ if not isinstance(obj, SequenceABC):
+ raise TypeError(
+ 'expected a sequence (got type {!r})'.format(type(obj).__name__)
+ )
+
+ elem_data = WriteBuffer.new()
+
+ for elem in obj:
+ range_encode(settings, elem_data, elem, elem_oid, encoder, encoder_arg)
+ elem_count = len(obj)
+ if elem_count > INT32_MAX:
+ raise OverflowError(f'too many elements in multirange value')
+
+ elem_data_len = elem_data.len()
+ if elem_data_len > INT32_MAX - 4:
+ raise OverflowError(
+ f'size of encoded multirange datum exceeds the maximum allowed'
+ f' {INT32_MAX - 4} bytes')
+
+ # Datum length
+ buf.write_int32(4 + elem_data_len)
+ # Number of elements in multirange
+ buf.write_int32(elem_count)
+ buf.write_buffer(elem_data)
+
+
+cdef multirange_decode(ConnectionSettings settings, FRBuffer *buf,
+ decode_func_ex decoder, const void *decoder_arg):
+ cdef:
+ int32_t nelems = hton.unpack_int32(frb_read(buf, 4))
+ FRBuffer elem_buf
+ int32_t elem_len
+ int i
+ list result
+
+ if nelems == 0:
+ return []
+
+ if nelems < 0:
+ raise exceptions.ProtocolError(
+ 'unexpected multirange size value: {}'.format(nelems))
+
+ result = cpython.PyList_New(nelems)
+ for i in range(nelems):
+ elem_len = hton.unpack_int32(frb_read(buf, 4))
+ if elem_len == -1:
+ raise exceptions.ProtocolError(
+ 'unexpected NULL element in multirange value')
+ else:
+ frb_slice_from(&elem_buf, buf, elem_len)
+ elem = range_decode(settings, &elem_buf, decoder, decoder_arg)
+ cpython.Py_INCREF(elem)
+ cpython.PyList_SET_ITEM(result, i, elem)
-init_range_codecs()
+ return result
diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd
index f21559b4..34c7c712 100644
--- a/asyncpg/protocol/coreproto.pxd
+++ b/asyncpg/protocol/coreproto.pxd
@@ -51,16 +51,6 @@ cdef enum AuthenticationMessage:
AUTH_SASL_FINAL = 12
-AUTH_METHOD_NAME = {
- AUTH_REQUIRED_KERBEROS: 'kerberosv5',
- AUTH_REQUIRED_PASSWORD: 'password',
- AUTH_REQUIRED_PASSWORDMD5: 'md5',
- AUTH_REQUIRED_GSS: 'gss',
- AUTH_REQUIRED_SASL: 'scram-sha-256',
- AUTH_REQUIRED_SSPI: 'sspi',
-}
-
-
cdef enum ResultType:
RESULT_OK = 1
RESULT_FAILED = 2
@@ -96,10 +86,13 @@ cdef class CoreProtocol:
object transport
+ object address
# Instance of _ConnectionParameters
object con_params
# Instance of SCRAMAuthentication
SCRAMAuthentication scram
+ # Instance of gssapi.SecurityContext or sspilib.SecurityContext
+ object gss_ctx
readonly int32_t backend_pid
readonly int32_t backend_secret
@@ -145,6 +138,10 @@ cdef class CoreProtocol:
cdef _auth_password_message_md5(self, bytes salt)
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
cdef _auth_password_message_sasl_continue(self, bytes server_response)
+ cdef _auth_gss_init_gssapi(self)
+ cdef _auth_gss_init_sspi(self, bint negotiate)
+ cdef _auth_gss_get_service(self)
+ cdef _auth_gss_step(self, bytes server_response)
cdef _write(self, buf)
cdef _writelines(self, list buffers)
@@ -167,13 +164,14 @@ cdef class CoreProtocol:
cdef _connect(self)
- cdef _prepare(self, str stmt_name, str query)
+ cdef _prepare_and_describe(self, str stmt_name, str query)
+ cdef _send_parse_message(self, str stmt_name, str query)
cdef _send_bind_message(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
- object bind_data)
+ object bind_data, bint return_rows)
cdef bint _bind_execute_many_more(self, bint first=*)
cdef _bind_execute_many_fail(self, object error, bint first=*)
cdef _bind(self, str portal_name, str stmt_name,
diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx
index e7d7c2bc..da96c412 100644
--- a/asyncpg/protocol/coreproto.pyx
+++ b/asyncpg/protocol/coreproto.pyx
@@ -5,15 +5,26 @@
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
-from hashlib import md5 as hashlib_md5 # for MD5 authentication
+import hashlib
include "scram.pyx"
+AUTH_METHOD_NAME = {
+ AUTH_REQUIRED_KERBEROS: 'kerberosv5',
+ AUTH_REQUIRED_PASSWORD: 'password',
+ AUTH_REQUIRED_PASSWORDMD5: 'md5',
+ AUTH_REQUIRED_GSS: 'gss',
+ AUTH_REQUIRED_SASL: 'scram-sha-256',
+ AUTH_REQUIRED_SSPI: 'sspi',
+}
+
+
cdef class CoreProtocol:
- def __init__(self, con_params):
+ def __init__(self, addr, con_params):
+ self.address = addr
# type of `con_params` is `_ConnectionParameters`
self.buffer = ReadBuffer()
self.user = con_params.user
@@ -26,6 +37,9 @@ cdef class CoreProtocol:
self.encoding = 'utf-8'
# type of `scram` is `SCRAMAuthentcation`
self.scram = None
+ # type of `gss_ctx` is `gssapi.SecurityContext` or
+ # `sspilib.SecurityContext`
+ self.gss_ctx = None
self._reset_result()
@@ -150,15 +164,28 @@ cdef class CoreProtocol:
cdef _process__auth(self, char mtype):
if mtype == b'R':
# Authentication...
- self._parse_msg_authentication()
- if self.result_type != RESULT_OK:
+ try:
+ self._parse_msg_authentication()
+ except Exception as ex:
+ # Exception in authentication parsing code
+ # is usually either malformed authentication data
+ # or missing support for cryptographic primitives
+ # in the hashlib module.
+ self.result_type = RESULT_FAILED
+ self.result = apg_exc.InternalClientError(
+ f"unexpected error while performing authentication: {ex}")
+ self.result.__cause__ = ex
self.con_status = CONNECTION_BAD
self._push_result()
+ else:
+ if self.result_type != RESULT_OK:
+ self.con_status = CONNECTION_BAD
+ self._push_result()
- elif self.auth_msg is not None:
- # Server wants us to send auth data, so do that.
- self._write(self.auth_msg)
- self.auth_msg = None
+ elif self.auth_msg is not None:
+ # Server wants us to send auth data, so do that.
+ self._write(self.auth_msg)
+ self.auth_msg = None
elif mtype == b'K':
# BackendKeyData
@@ -224,6 +251,10 @@ cdef class CoreProtocol:
# ErrorResponse
self._parse_msg_error_response(True)
+ elif mtype == b'1':
+ # ParseComplete, in case `_bind_execute()` is reparsing
+ self.buffer.discard_message()
+
elif mtype == b'2':
# BindComplete
self.buffer.discard_message()
@@ -256,6 +287,10 @@ cdef class CoreProtocol:
# ErrorResponse
self._parse_msg_error_response(True)
+ elif mtype == b'1':
+ # ParseComplete, in case `_bind_execute_many()` is reparsing
+ self.buffer.discard_message()
+
elif mtype == b'2':
# BindComplete
self.buffer.discard_message()
@@ -598,22 +633,35 @@ cdef class CoreProtocol:
'could not verify server signature for '
'SCRAM authentciation: scram-sha-256',
)
+ self.scram = None
- elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
- AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
- AUTH_REQUIRED_SSPI):
- self.result_type = RESULT_FAILED
- self.result = apg_exc.InterfaceError(
- 'unsupported authentication method requested by the '
- 'server: {!r}'.format(AUTH_METHOD_NAME[status]))
+ elif status in (AUTH_REQUIRED_GSS, AUTH_REQUIRED_SSPI):
+ # AUTH_REQUIRED_SSPI is the same as AUTH_REQUIRED_GSS, except that
+ # it uses protocol negotiation with SSPI clients. Both methods use
+ # AUTH_REQUIRED_GSS_CONTINUE for subsequent authentication steps.
+ if self.gss_ctx is not None:
+ self.result_type = RESULT_FAILED
+ self.result = apg_exc.InterfaceError(
+ 'duplicate GSSAPI/SSPI authentication request')
+ else:
+ if self.con_params.gsslib == 'gssapi':
+ self._auth_gss_init_gssapi()
+ else:
+ self._auth_gss_init_sspi(status == AUTH_REQUIRED_SSPI)
+ self.auth_msg = self._auth_gss_step(None)
+
+ elif status == AUTH_REQUIRED_GSS_CONTINUE:
+ server_response = self.buffer.consume_message()
+ self.auth_msg = self._auth_gss_step(server_response)
else:
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
'unsupported authentication method requested by the '
- 'server: {}'.format(status))
+ 'server: {!r}'.format(AUTH_METHOD_NAME.get(status, status)))
- if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
+ if status not in (AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
+ AUTH_REQUIRED_GSS_CONTINUE):
self.buffer.discard_message()
cdef _auth_password_message_cleartext(self):
@@ -621,7 +669,7 @@ cdef class CoreProtocol:
WriteBuffer msg
msg = WriteBuffer.new_message(b'p')
- msg.write_bytestring(self.password.encode('ascii'))
+ msg.write_bytestring(self.password.encode(self.encoding))
msg.end_message()
return msg
@@ -633,11 +681,11 @@ cdef class CoreProtocol:
msg = WriteBuffer.new_message(b'p')
# 'md5' + md5(md5(password + username) + salt))
- userpass = ((self.password or '') + (self.user or '')).encode('ascii')
- hash = hashlib_md5(hashlib_md5(userpass).hexdigest().\
- encode('ascii') + salt).hexdigest().encode('ascii')
+ userpass = (self.password or '') + (self.user or '')
+ md5_1 = hashlib.md5(userpass.encode(self.encoding)).hexdigest()
+ md5_2 = hashlib.md5(md5_1.encode('ascii') + salt).hexdigest()
- msg.write_bytestring(b'md5' + hash)
+ msg.write_bytestring(b'md5' + md5_2.encode('ascii'))
msg.end_message()
return msg
@@ -670,6 +718,59 @@ cdef class CoreProtocol:
return msg
+ cdef _auth_gss_init_gssapi(self):
+ try:
+ import gssapi
+ except ModuleNotFoundError:
+ raise apg_exc.InterfaceError(
+ 'gssapi module not found; please install asyncpg[gssauth] to '
+ 'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
+ ) from None
+
+ service_name, host = self._auth_gss_get_service()
+ self.gss_ctx = gssapi.SecurityContext(
+ name=gssapi.Name(
+ f'{service_name}@{host}', gssapi.NameType.hostbased_service),
+ usage='initiate')
+
+ cdef _auth_gss_init_sspi(self, bint negotiate):
+ try:
+ import sspilib
+ except ModuleNotFoundError:
+ raise apg_exc.InterfaceError(
+ 'sspilib module not found; please install asyncpg[gssauth] to '
+ 'use asyncpg with Kerberos/GSSAPI/SSPI authentication'
+ ) from None
+
+ service_name, host = self._auth_gss_get_service()
+ self.gss_ctx = sspilib.ClientSecurityContext(
+ target_name=f'{service_name}/{host}',
+ credential=sspilib.UserCredential(
+ protocol='Negotiate' if negotiate else 'Kerberos'))
+
+ cdef _auth_gss_get_service(self):
+ service_name = self.con_params.krbsrvname or 'postgres'
+ if isinstance(self.address, str):
+ raise apg_exc.InternalClientError(
+ 'GSSAPI/SSPI authentication is only supported for TCP/IP '
+ 'connections')
+
+ return service_name, self.address[0]
+
+ cdef _auth_gss_step(self, bytes server_response):
+ cdef:
+ WriteBuffer msg
+
+ token = self.gss_ctx.step(server_response)
+ if not token:
+ self.gss_ctx = None
+ return None
+ msg = WriteBuffer.new_message(b'p')
+ msg.write_bytes(token)
+ msg.end_message()
+
+ return msg
+
cdef _parse_msg_ready_for_query(self):
cdef char status = self.buffer.read_byte()
@@ -861,7 +962,15 @@ cdef class CoreProtocol:
outbuf.write_buffer(buf)
self._write(outbuf)
- cdef _prepare(self, str stmt_name, str query):
+ cdef _send_parse_message(self, str stmt_name, str query):
+ cdef:
+ WriteBuffer msg
+
+ self._ensure_connected()
+ msg = self._build_parse_message(stmt_name, query)
+ self._write(msg)
+
+ cdef _prepare_and_describe(self, str stmt_name, str query):
cdef:
WriteBuffer packet
WriteBuffer buf
@@ -911,12 +1020,12 @@ cdef class CoreProtocol:
self._send_bind_message(portal_name, stmt_name, bind_data, limit)
cdef bint _bind_execute_many(self, str portal_name, str stmt_name,
- object bind_data):
+ object bind_data, bint return_rows):
self._ensure_connected()
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
- self.result = None
- self._discard_data = True
+ self.result = [] if return_rows else None
+ self._discard_data = not return_rows
self._execute_iter = bind_data
self._execute_portal_name = portal_name
self._execute_stmt_name = stmt_name
@@ -1120,5 +1229,5 @@ cdef class CoreProtocol:
pass
-cdef bytes SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message())
-cdef bytes FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message())
+SYNC_MESSAGE = bytes(WriteBuffer.new_message(b'S').end_message())
+FLUSH_MESSAGE = bytes(WriteBuffer.new_message(b'H').end_message())
diff --git a/asyncpg/protocol/encodings.pyx b/asyncpg/protocol/encodings.pyx
index dcd692b7..1463dbe4 100644
--- a/asyncpg/protocol/encodings.pyx
+++ b/asyncpg/protocol/encodings.pyx
@@ -10,7 +10,7 @@
https://www.postgresql.org/docs/current/static/multibyte.html#CHARSET-TABLE
'''
-cdef dict ENCODINGS_MAP = {
+ENCODINGS_MAP = {
'abc': 'cp1258',
'alt': 'cp866',
'euc_cn': 'euccn',
diff --git a/asyncpg/protocol/pgtypes.pxi b/asyncpg/protocol/pgtypes.pxi
index d0cc22a6..86f8e663 100644
--- a/asyncpg/protocol/pgtypes.pxi
+++ b/asyncpg/protocol/pgtypes.pxi
@@ -101,6 +101,10 @@ DEF JSONPATHOID = 4072
DEF REGNAMESPACEOID = 4089
DEF REGROLEOID = 4096
DEF REGCOLLATIONOID = 4191
+DEF ANYMULTIRANGEOID = 4537
+DEF ANYCOMPATIBLEMULTIRANGEOID = 4538
+DEF PG_BRIN_BLOOM_SUMMARYOID = 4600
+DEF PG_BRIN_MINMAX_MULTI_SUMMARYOID = 4601
DEF PG_MCV_LISTOID = 5017
DEF PG_SNAPSHOTOID = 5038
DEF XID8OID = 5069
@@ -109,18 +113,20 @@ DEF ANYCOMPATIBLEARRAYOID = 5078
DEF ANYCOMPATIBLENONARRAYOID = 5079
DEF ANYCOMPATIBLERANGEOID = 5080
-cdef ARRAY_TYPES = (_TEXTOID, _OIDOID,)
+ARRAY_TYPES = {_TEXTOID, _OIDOID}
BUILTIN_TYPE_OID_MAP = {
ABSTIMEOID: 'abstime',
ACLITEMOID: 'aclitem',
ANYARRAYOID: 'anyarray',
ANYCOMPATIBLEARRAYOID: 'anycompatiblearray',
+ ANYCOMPATIBLEMULTIRANGEOID: 'anycompatiblemultirange',
ANYCOMPATIBLENONARRAYOID: 'anycompatiblenonarray',
ANYCOMPATIBLEOID: 'anycompatible',
ANYCOMPATIBLERANGEOID: 'anycompatiblerange',
ANYELEMENTOID: 'anyelement',
ANYENUMOID: 'anyenum',
+ ANYMULTIRANGEOID: 'anymultirange',
ANYNONARRAYOID: 'anynonarray',
ANYOID: 'any',
ANYRANGEOID: 'anyrange',
@@ -161,6 +167,8 @@ BUILTIN_TYPE_OID_MAP = {
OIDOID: 'oid',
OPAQUEOID: 'opaque',
PATHOID: 'path',
+ PG_BRIN_BLOOM_SUMMARYOID: 'pg_brin_bloom_summary',
+ PG_BRIN_MINMAX_MULTI_SUMMARYOID: 'pg_brin_minmax_multi_summary',
PG_DDL_COMMANDOID: 'pg_ddl_command',
PG_DEPENDENCIESOID: 'pg_dependencies',
PG_LSNOID: 'pg_lsn',
diff --git a/asyncpg/protocol/prepared_stmt.pxd b/asyncpg/protocol/prepared_stmt.pxd
index 4427bfdc..369db733 100644
--- a/asyncpg/protocol/prepared_stmt.pxd
+++ b/asyncpg/protocol/prepared_stmt.pxd
@@ -10,6 +10,7 @@ cdef class PreparedStatementState:
readonly str name
readonly str query
readonly bint closed
+ readonly bint prepared
readonly int refs
readonly type record_class
readonly bint ignore_custom_codec
@@ -29,7 +30,7 @@ cdef class PreparedStatementState:
bint have_text_cols
tuple rows_codecs
- cdef _encode_bind_msg(self, args)
+ cdef _encode_bind_msg(self, args, int seqno = ?)
cpdef _init_codecs(self)
cdef _ensure_rows_decoder(self)
cdef _ensure_args_encoder(self)
diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx
index 5f1820de..4145c664 100644
--- a/asyncpg/protocol/prepared_stmt.pyx
+++ b/asyncpg/protocol/prepared_stmt.pyx
@@ -27,6 +27,7 @@ cdef class PreparedStatementState:
self.args_num = self.cols_num = 0
self.cols_desc = None
self.closed = False
+ self.prepared = True
self.refs = 0
self.record_class = record_class
self.ignore_custom_codec = ignore_custom_codec
@@ -101,12 +102,31 @@ cdef class PreparedStatementState:
def mark_closed(self):
self.closed = True
- cdef _encode_bind_msg(self, args):
+ def mark_unprepared(self):
+ if self.name:
+ raise exceptions.InternalClientError(
+ "named prepared statements cannot be marked unprepared")
+ self.prepared = False
+
+ cdef _encode_bind_msg(self, args, int seqno = -1):
cdef:
int idx
WriteBuffer writer
Codec codec
+ if not cpython.PySequence_Check(args):
+ if seqno >= 0:
+ raise exceptions.DataError(
+ f'invalid input in executemany() argument sequence '
+ f'element #{seqno}: expected a sequence, got '
+ f'{type(args).__name__}'
+ )
+ else:
+ # Non executemany() callers do not pass user input directly,
+ # so bad input is a bug.
+ raise exceptions.InternalClientError(
+ f'Bind: expected a sequence, got {type(args).__name__}')
+
if len(args) > 32767:
raise exceptions.InterfaceError(
'the number of query arguments cannot exceed 32767')
@@ -122,7 +142,7 @@ cdef class PreparedStatementState:
# that the user tried to parametrize a statement that does
# not support parameters.
hint += (r' Note that parameters are supported only in'
- r' SELECT, INSERT, UPDATE, DELETE, and VALUES'
+ r' SELECT, INSERT, UPDATE, DELETE, MERGE and VALUES'
r' statements, and will *not* work in statements '
r' like CREATE VIEW or DECLARE CURSOR.')
@@ -138,7 +158,7 @@ cdef class PreparedStatementState:
writer.write_int16(self.args_num)
for idx in range(self.args_num):
codec = (self.args_codecs[idx])
- writer.write_int16(codec.format)
+ writer.write_int16(codec.format)
else:
# All arguments are in binary format
writer.write_int32(0x00010001)
@@ -159,25 +179,38 @@ cdef class PreparedStatementState:
except exceptions.InterfaceError as e:
# This is already a descriptive error, but annotate
# with argument name for clarity.
+ pos = f'${idx + 1}'
+ if seqno >= 0:
+ pos = (
+ f'{pos} in element #{seqno} of'
+ f' executemany() sequence'
+ )
raise e.with_msg(
- f'query argument ${idx + 1}: {e.args[0]}') from None
+ f'query argument {pos}: {e.args[0]}'
+ ) from None
except Exception as e:
# Everything else is assumed to be an encoding error
# due to invalid input.
+ pos = f'${idx + 1}'
+ if seqno >= 0:
+ pos = (
+ f'{pos} in element #{seqno} of'
+ f' executemany() sequence'
+ )
value_repr = repr(arg)
if len(value_repr) > 40:
value_repr = value_repr[:40] + '...'
raise exceptions.DataError(
- 'invalid input for query argument'
- ' ${n}: {v} ({msg})'.format(
- n=idx + 1, v=value_repr, msg=e)) from e
+ f'invalid input for query argument'
+ f' {pos}: {value_repr} ({e})'
+ ) from e
if self.have_text_cols:
writer.write_int16(self.cols_num)
for idx in range(self.cols_num):
codec = (self.rows_codecs[idx])
- writer.write_int16(codec.format)
+ writer.write_int16(codec.format)
else:
# All columns are in binary format
writer.write_int32(0x00010001)
@@ -197,7 +230,7 @@ cdef class PreparedStatementState:
return
if self.cols_num == 0:
- self.cols_desc = record.ApgRecordDesc_New({}, ())
+ self.cols_desc = RecordDescriptor({}, ())
return
cols_mapping = collections.OrderedDict()
@@ -219,7 +252,7 @@ cdef class PreparedStatementState:
codecs.append(codec)
- self.cols_desc = record.ApgRecordDesc_New(
+ self.cols_desc = RecordDescriptor(
cols_mapping, tuple(cols_names))
self.rows_codecs = tuple(codecs)
@@ -277,7 +310,7 @@ cdef class PreparedStatementState:
'different from what was described ({})'.format(
fnum, self.cols_num))
- dec_row = record.ApgRecord_New(self.record_class, self.cols_desc, fnum)
+ dec_row = self.cols_desc.make_record(self.record_class, fnum)
for i in range(fnum):
flen = hton.unpack_int32(frb_read(&rbuf, 4))
@@ -300,7 +333,7 @@ cdef class PreparedStatementState:
frb_set_len(&rbuf, bl - flen)
cpython.Py_INCREF(val)
- record.ApgRecord_SET_ITEM(dec_row, i, val)
+ recordcapi.ApgRecord_SET_ITEM(dec_row, i, val)
if frb_get_len(&rbuf) != 0:
raise BufferError('unexpected trailing {} bytes in buffer'.format(
diff --git a/asyncpg/protocol/protocol.pxd b/asyncpg/protocol/protocol.pxd
index 5f144e55..cd221fbb 100644
--- a/asyncpg/protocol/protocol.pxd
+++ b/asyncpg/protocol/protocol.pxd
@@ -31,7 +31,6 @@ cdef class BaseProtocol(CoreProtocol):
cdef:
object loop
- object address
ConnectionSettings settings
object cancel_sent_waiter
object cancel_waiter
@@ -39,8 +38,6 @@ cdef class BaseProtocol(CoreProtocol):
bint return_extra
object create_future
object timeout_handle
- object timeout_callback
- object completed_callback
object conref
type record_class
bint is_reading
diff --git a/asyncpg/protocol/protocol.pyi b/asyncpg/protocol/protocol.pyi
new file mode 100644
index 00000000..34db6440
--- /dev/null
+++ b/asyncpg/protocol/protocol.pyi
@@ -0,0 +1,282 @@
+import asyncio
+import asyncio.protocols
+import hmac
+from codecs import CodecInfo
+from collections.abc import Callable, Iterable, Sequence
+from hashlib import md5, sha256
+from typing import (
+ Any,
+ ClassVar,
+ Final,
+ Generic,
+ Literal,
+ NewType,
+ TypeVar,
+ final,
+ overload,
+)
+from typing_extensions import TypeAlias
+
+import asyncpg.pgproto.pgproto
+
+from ..connect_utils import _ConnectionParameters
+from ..pgproto.pgproto import WriteBuffer
+from ..types import Attribute, Type
+from .record import Record
+
+_Record = TypeVar('_Record', bound=Record)
+_OtherRecord = TypeVar('_OtherRecord', bound=Record)
+_PreparedStatementState = TypeVar(
+ '_PreparedStatementState', bound=PreparedStatementState[Any]
+)
+
+_NoTimeoutType = NewType('_NoTimeoutType', object)
+_TimeoutType: TypeAlias = float | None | _NoTimeoutType
+
+BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]]
+BUILTIN_TYPE_OID_MAP: Final[dict[int, str]]
+NO_TIMEOUT: Final[_NoTimeoutType]
+
+hashlib_md5 = md5
+
+@final
+class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext):
+ __pyx_vtable__: Any
+ def __init__(self, conn_key: object) -> None: ...
+ def add_python_codec(
+ self,
+ typeoid: int,
+ typename: str,
+ typeschema: str,
+ typeinfos: Iterable[object],
+ typekind: str,
+ encoder: Callable[[Any], Any],
+ decoder: Callable[[Any], Any],
+ format: object,
+ ) -> Any: ...
+ def clear_type_cache(self) -> None: ...
+ def get_data_codec(
+ self, oid: int, format: object = ..., ignore_custom_codec: bool = ...
+ ) -> Any: ...
+ def get_text_codec(self) -> CodecInfo: ...
+ def register_data_types(self, types: Iterable[object]) -> None: ...
+ def remove_python_codec(
+ self, typeoid: int, typename: str, typeschema: str
+ ) -> None: ...
+ def set_builtin_type_codec(
+ self,
+ typeoid: int,
+ typename: str,
+ typeschema: str,
+ typekind: str,
+ alias_to: str,
+ format: object = ...,
+ ) -> Any: ...
+ def __getattr__(self, name: str) -> Any: ...
+ def __reduce__(self) -> Any: ...
+
+@final
+class PreparedStatementState(Generic[_Record]):
+ closed: bool
+ prepared: bool
+ name: str
+ query: str
+ refs: int
+ record_class: type[_Record]
+ ignore_custom_codec: bool
+ __pyx_vtable__: Any
+ def __init__(
+ self,
+ name: str,
+ query: str,
+ protocol: BaseProtocol[Any],
+ record_class: type[_Record],
+ ignore_custom_codec: bool,
+ ) -> None: ...
+ def _get_parameters(self) -> tuple[Type, ...]: ...
+ def _get_attributes(self) -> tuple[Attribute, ...]: ...
+ def _init_types(self) -> set[int]: ...
+ def _init_codecs(self) -> None: ...
+ def attach(self) -> None: ...
+ def detach(self) -> None: ...
+ def mark_closed(self) -> None: ...
+ def mark_unprepared(self) -> None: ...
+ def __reduce__(self) -> Any: ...
+
+class CoreProtocol:
+ backend_pid: Any
+ backend_secret: Any
+ __pyx_vtable__: Any
+ def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ...
+ def is_in_transaction(self) -> bool: ...
+ def __reduce__(self) -> Any: ...
+
+class BaseProtocol(CoreProtocol, Generic[_Record]):
+ queries_count: Any
+ is_ssl: bool
+ __pyx_vtable__: Any
+ def __init__(
+ self,
+ addr: object,
+ connected_fut: object,
+ con_params: _ConnectionParameters,
+ record_class: type[_Record],
+ loop: object,
+ ) -> None: ...
+ def set_connection(self, connection: object) -> None: ...
+ def get_server_pid(self, *args: object, **kwargs: object) -> int: ...
+ def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ...
+ def get_record_class(self) -> type[_Record]: ...
+ def abort(self) -> None: ...
+ async def bind(
+ self,
+ state: PreparedStatementState[_OtherRecord],
+ args: Sequence[object],
+ portal_name: str,
+ timeout: _TimeoutType,
+ ) -> Any: ...
+ @overload
+ async def bind_execute(
+ self,
+ state: PreparedStatementState[_OtherRecord],
+ args: Sequence[object],
+ portal_name: str,
+ limit: int,
+ return_extra: Literal[False],
+ timeout: _TimeoutType,
+ ) -> list[_OtherRecord]: ...
+ @overload
+ async def bind_execute(
+ self,
+ state: PreparedStatementState[_OtherRecord],
+ args: Sequence[object],
+ portal_name: str,
+ limit: int,
+ return_extra: Literal[True],
+ timeout: _TimeoutType,
+ ) -> tuple[list[_OtherRecord], bytes, bool]: ...
+ @overload
+ async def bind_execute(
+ self,
+ state: PreparedStatementState[_OtherRecord],
+ args: Sequence[object],
+ portal_name: str,
+ limit: int,
+ return_extra: bool,
+ timeout: _TimeoutType,
+ ) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ...
+    async def bind_execute_many(
+        self,
+        state: PreparedStatementState[_OtherRecord],
+        args: Iterable[Sequence[object]],
+        portal_name: str,
+        timeout: _TimeoutType, return_rows: bool,
+    ) -> list[_OtherRecord] | None: ...
+ async def close(self, timeout: _TimeoutType) -> None: ...
+ def _get_timeout(self, timeout: _TimeoutType) -> float | None: ...
+ def _is_cancelling(self) -> bool: ...
+ async def _wait_for_cancellation(self) -> None: ...
+ async def close_statement(
+ self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType
+ ) -> Any: ...
+ async def copy_in(self, *args: object, **kwargs: object) -> str: ...
+ async def copy_out(self, *args: object, **kwargs: object) -> str: ...
+ async def execute(self, *args: object, **kwargs: object) -> Any: ...
+ def is_closed(self, *args: object, **kwargs: object) -> Any: ...
+ def is_connected(self, *args: object, **kwargs: object) -> Any: ...
+ def data_received(self, data: object) -> None: ...
+ def connection_made(self, transport: object) -> None: ...
+ def connection_lost(self, exc: Exception | None) -> None: ...
+ def pause_writing(self, *args: object, **kwargs: object) -> Any: ...
+ @overload
+ async def prepare(
+ self,
+ stmt_name: str,
+ query: str,
+ timeout: float | None = ...,
+ *,
+ state: _PreparedStatementState,
+ ignore_custom_codec: bool = ...,
+ record_class: None,
+ ) -> _PreparedStatementState: ...
+ @overload
+ async def prepare(
+ self,
+ stmt_name: str,
+ query: str,
+ timeout: float | None = ...,
+ *,
+ state: None = ...,
+ ignore_custom_codec: bool = ...,
+ record_class: type[_OtherRecord],
+ ) -> PreparedStatementState[_OtherRecord]: ...
+ async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ...
+ async def query(self, *args: object, **kwargs: object) -> str: ...
+ def resume_writing(self, *args: object, **kwargs: object) -> Any: ...
+ def __reduce__(self) -> Any: ...
+
+@final
+class Codec:
+ __pyx_vtable__: Any
+ def __reduce__(self) -> Any: ...
+
+class DataCodecConfig:
+ __pyx_vtable__: Any
+ def __init__(self) -> None: ...
+ def add_python_codec(
+ self,
+ typeoid: int,
+ typename: str,
+ typeschema: str,
+ typekind: str,
+ typeinfos: Iterable[object],
+ encoder: Callable[[ConnectionSettings, WriteBuffer, object], object],
+ decoder: Callable[..., object],
+ format: object,
+ xformat: object,
+ ) -> Any: ...
+ def add_types(self, types: Iterable[object]) -> Any: ...
+ def clear_type_cache(self) -> None: ...
+ def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ...
+ def remove_python_codec(
+ self, typeoid: int, typename: str, typeschema: str
+ ) -> Any: ...
+ def set_builtin_type_codec(
+ self,
+ typeoid: int,
+ typename: str,
+ typeschema: str,
+ typekind: str,
+ alias_to: str,
+ format: object = ...,
+ ) -> Any: ...
+ def __reduce__(self) -> Any: ...
+
+class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ...
+
+class Timer:
+ def __init__(self, budget: float | None) -> None: ...
+ def __enter__(self) -> None: ...
+ def __exit__(self, et: object, e: object, tb: object) -> None: ...
+ def get_remaining_budget(self) -> float: ...
+ def has_budget_greater_than(self, amount: float) -> bool: ...
+
+@final
+class SCRAMAuthentication:
+ AUTHENTICATION_METHODS: ClassVar[list[str]]
+ DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int]
+ DIGEST = sha256
+ REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]]
+ REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]]
+ SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]]
+ authentication_method: bytes
+ authorization_message: bytes | None
+ client_channel_binding: bytes
+ client_first_message_bare: bytes | None
+ client_nonce: bytes | None
+ client_proof: bytes | None
+ password_salt: bytes | None
+ password_iterations: int
+ server_first_message: bytes | None
+ server_key: hmac.HMAC | None
+ server_nonce: bytes | None
diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx
index dbe52e9e..acce4e9f 100644
--- a/asyncpg/protocol/protocol.pyx
+++ b/asyncpg/protocol/protocol.pyx
@@ -34,11 +34,11 @@ from asyncpg.pgproto.pgproto cimport (
from asyncpg.pgproto cimport pgproto
from asyncpg.protocol cimport cpythonx
-from asyncpg.protocol cimport record
+from asyncpg.protocol cimport recordcapi
from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
int32_t, uint32_t, int64_t, uint64_t, \
- UINT32_MAX
+ INT32_MAX, UINT32_MAX
from asyncpg.exceptions import _base as apg_exc_base
from asyncpg import compat
@@ -46,6 +46,7 @@ from asyncpg import types as apg_types
from asyncpg import exceptions as apg_exc
from asyncpg.pgproto cimport hton
+from asyncpg.protocol.record import Record, RecordDescriptor
include "consts.pxi"
@@ -75,7 +76,7 @@ NO_TIMEOUT = object()
cdef class BaseProtocol(CoreProtocol):
def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
# type of `con_params` is `_ConnectionParameters`
- CoreProtocol.__init__(self, con_params)
+ CoreProtocol.__init__(self, addr, con_params)
self.loop = loop
self.transport = None
@@ -83,8 +84,7 @@ cdef class BaseProtocol(CoreProtocol):
self.cancel_waiter = None
self.cancel_sent_waiter = None
- self.address = addr
- self.settings = ConnectionSettings((self.address, con_params.database))
+ self.settings = ConnectionSettings((addr, con_params.database))
self.record_class = record_class
self.statement = None
@@ -98,8 +98,6 @@ cdef class BaseProtocol(CoreProtocol):
self.writing_allowed.set()
self.timeout_handle = None
- self.timeout_callback = self._on_timeout
- self.completed_callback = self._on_waiter_completed
self.queries_count = 0
@@ -138,7 +136,6 @@ cdef class BaseProtocol(CoreProtocol):
self.is_reading = False
self.transport.pause_reading()
- @cython.iterable_coroutine
async def prepare(self, stmt_name, query, timeout,
*,
PreparedStatementState state=None,
@@ -155,7 +152,7 @@ cdef class BaseProtocol(CoreProtocol):
waiter = self._new_waiter(timeout)
try:
- self._prepare(stmt_name, query) # network op
+ self._prepare_and_describe(stmt_name, query) # network op
self.last_query = query
if state is None:
state = PreparedStatementState(
@@ -167,11 +164,15 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
- async def bind_execute(self, PreparedStatementState state, args,
- str portal_name, int limit, return_extra,
- timeout):
-
+ async def bind_execute(
+ self,
+ state: PreparedStatementState,
+ args,
+ portal_name: str,
+ limit: int,
+ return_extra: bool,
+ timeout,
+ ):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
@@ -184,6 +185,9 @@ cdef class BaseProtocol(CoreProtocol):
waiter = self._new_waiter(timeout)
try:
+ if not state.prepared:
+ self._send_parse_message(state.name, state.query)
+
self._bind_execute(
portal_name,
state.name,
@@ -200,10 +204,14 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
- async def bind_execute_many(self, PreparedStatementState state, args,
- str portal_name, timeout):
-
+ async def bind_execute_many(
+ self,
+ state: PreparedStatementState,
+ args,
+ portal_name: str,
+ timeout,
+ return_rows: bool,
+ ):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
@@ -217,15 +225,19 @@ cdef class BaseProtocol(CoreProtocol):
# Make sure the argument sequence is encoded lazily with
# this generator expression to keep the memory pressure under
# control.
- data_gen = (state._encode_bind_msg(b) for b in args)
+ data_gen = (state._encode_bind_msg(b, i) for i, b in enumerate(args))
arg_bufs = iter(data_gen)
waiter = self._new_waiter(timeout)
try:
+ if not state.prepared:
+ self._send_parse_message(state.name, state.query)
+
more = self._bind_execute_many(
portal_name,
state.name,
- arg_bufs) # network op
+ arg_bufs,
+ return_rows) # network op
self.last_query = state.query
self.statement = state
@@ -234,7 +246,7 @@ cdef class BaseProtocol(CoreProtocol):
while more:
with timer:
- await asyncio.wait_for(
+ await compat.wait_for(
self.writing_allowed.wait(),
timeout=timer.get_remaining_budget())
# On Windows the above event somehow won't allow context
@@ -253,7 +265,6 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
async def bind(self, PreparedStatementState state, args,
str portal_name, timeout):
@@ -282,7 +293,6 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
async def execute(self, PreparedStatementState state,
str portal_name, int limit, return_extra,
timeout):
@@ -312,7 +322,28 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
+ async def close_portal(self, str portal_name, timeout):
+
+ if self.cancel_waiter is not None:
+ await self.cancel_waiter
+ if self.cancel_sent_waiter is not None:
+ await self.cancel_sent_waiter
+ self.cancel_sent_waiter = None
+
+ self._check_state()
+ timeout = self._get_timeout_impl(timeout)
+
+ waiter = self._new_waiter(timeout)
+ try:
+ self._close(
+ portal_name,
+ True) # network op
+ except Exception as ex:
+ waiter.set_exception(ex)
+ self._coreproto_error()
+ finally:
+ return await waiter
+
async def query(self, query, timeout):
if self.cancel_waiter is not None:
await self.cancel_waiter
@@ -337,7 +368,6 @@ cdef class BaseProtocol(CoreProtocol):
finally:
return await waiter
- @cython.iterable_coroutine
async def copy_out(self, copy_stmt, sink, timeout):
if self.cancel_waiter is not None:
await self.cancel_waiter
@@ -368,7 +398,7 @@ cdef class BaseProtocol(CoreProtocol):
if buffer:
try:
with timer:
- await asyncio.wait_for(
+ await compat.wait_for(
sink(buffer),
timeout=timer.get_remaining_budget())
except (Exception, asyncio.CancelledError) as ex:
@@ -391,7 +421,6 @@ cdef class BaseProtocol(CoreProtocol):
return status_msg
- @cython.iterable_coroutine
async def copy_in(self, copy_stmt, reader, data,
records, PreparedStatementState record_stmt, timeout):
cdef:
@@ -496,7 +525,7 @@ cdef class BaseProtocol(CoreProtocol):
with timer:
await self.writing_allowed.wait()
with timer:
- chunk = await asyncio.wait_for(
+ chunk = await compat.wait_for(
iterator.__anext__(),
timeout=timer.get_remaining_budget())
self._write_copy_data_msg(chunk)
@@ -530,7 +559,6 @@ cdef class BaseProtocol(CoreProtocol):
return status_msg
- @cython.iterable_coroutine
async def close_statement(self, PreparedStatementState state, timeout):
if self.cancel_waiter is not None:
await self.cancel_waiter
@@ -569,8 +597,8 @@ cdef class BaseProtocol(CoreProtocol):
self._handle_waiter_on_connection_lost(None)
self._terminate()
self.transport.abort()
+ self.transport = None
- @cython.iterable_coroutine
async def close(self, timeout):
if self.closing:
return
@@ -609,7 +637,7 @@ cdef class BaseProtocol(CoreProtocol):
pass
finally:
self.waiter = None
- self.transport.abort()
+ self.transport.abort()
def _request_cancel(self):
self.cancel_waiter = self.create_future()
@@ -647,12 +675,12 @@ cdef class BaseProtocol(CoreProtocol):
self.waiter.set_exception(asyncio.TimeoutError())
def _on_waiter_completed(self, fut):
+ if self.timeout_handle:
+ self.timeout_handle.cancel()
+ self.timeout_handle = None
if fut is not self.waiter or self.cancel_waiter is not None:
return
if fut.cancelled():
- if self.timeout_handle:
- self.timeout_handle.cancel()
- self.timeout_handle = None
self._request_cancel()
def _create_future_fallback(self):
@@ -713,7 +741,6 @@ cdef class BaseProtocol(CoreProtocol):
self.cancel_sent_waiter is not None
)
- @cython.iterable_coroutine
async def _wait_for_cancellation(self):
if self.cancel_sent_waiter is not None:
await self.cancel_sent_waiter
@@ -739,8 +766,8 @@ cdef class BaseProtocol(CoreProtocol):
self.waiter = self.create_future()
if timeout is not None:
self.timeout_handle = self.loop.call_later(
- timeout, self.timeout_callback, self.waiter)
- self.waiter.add_done_callback(self.completed_callback)
+ timeout, self._on_timeout, self.waiter)
+ self.waiter.add_done_callback(self._on_waiter_completed)
return self.waiter
cdef _on_result__connect(self, object waiter):
@@ -1011,17 +1038,14 @@ def _create_record(object mapping, tuple elems):
int32_t i
if mapping is None:
- desc = record.ApgRecordDesc_New({}, ())
+ desc = RecordDescriptor({}, ())
else:
- desc = record.ApgRecordDesc_New(
+ desc = RecordDescriptor(
mapping, tuple(mapping) if mapping else ())
- rec = record.ApgRecord_New(Record, desc, len(elems))
+ rec = desc.make_record(Record, len(elems))
for i in range(len(elems)):
elem = elems[i]
cpython.Py_INCREF(elem)
- record.ApgRecord_SET_ITEM(rec, i, elem)
+ recordcapi.ApgRecord_SET_ITEM(rec, i, elem)
return rec
-
-
-Record =