diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 42c04d4..09288c5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,31 +2,39 @@ name: build on: [push, pull_request] jobs: build: - runs-on: ubuntu-24.04 + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-26] + runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: ankane/setup-postgres@v1 with: database: pgvector_cpp_test dev-files: true - run: | cd /tmp - git clone --branch v0.8.0 https://github.com/pgvector/pgvector.git + git clone --branch v0.8.2 https://github.com/pgvector/pgvector.git cd pgvector make sudo make install - - run: cmake -S . -B build -DBUILD_TESTING=ON -DCMAKE_CXX_STANDARD=17 + - run: cmake -S . -B build -DBUILD_TESTING=ON -DCMAKE_CXX_STANDARD=20 - run: cmake --build build - run: build/test - - run: cmake -S . -B build -DBUILD_TESTING=ON -DCMAKE_CXX_STANDARD=20 + - run: cmake -S . -B build -DBUILD_TESTING=ON -DCMAKE_CXX_STANDARD=23 - run: cmake --build build - run: build/test - - run: | + - if: ${{ runner.os == 'Linux' }} + run: | sudo apt-get install valgrind - valgrind --leak-check=yes build/test + valgrind --leak-check=yes --error-exitcode=1 build/test + + - if: ${{ runner.os == 'macOS' }} + run: /opt/homebrew/opt/llvm@20/bin/scan-build --status-bugs cmake --build build --clean-first # test install - run: rm -r build diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c57cbc..7eaa3a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,20 @@ +## 0.3.0 (2026-03-08) + +- Added support for libpqxx 8 +- Changed `HalfVector` to use `std::float16_t` or `_Float16` when available +- Replaced `std::vector` conversion with `values` function for `Vector` and `HalfVector` +- Removed default constructors (no longer needed for streaming) +- Dropped support for libpqxx 7 +- Dropped support for C++17 + +## 0.2.4 (2025-09-12) + +- Added `from_string` support for `SparseVector` + +## 0.2.3 (2025-07-13) + +- Fixed `duplicate symbol` errors + ## 0.2.2 (2025-02-23) - Added map constructor to `SparseVector` diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f5e98a..e7f5a95 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,13 +1,13 @@ cmake_minimum_required(VERSION 3.18) -project(pgvector VERSION 0.2.2 LANGUAGES CXX) +project(pgvector VERSION 0.3.0 LANGUAGES CXX) include(GNUInstallDirs) add_library(pgvector INTERFACE) add_library(pgvector::pgvector ALIAS pgvector) -target_compile_features(pgvector INTERFACE cxx_std_17) +target_compile_features(pgvector INTERFACE cxx_std_20) target_include_directories( pgvector @@ -26,13 +26,13 @@ if(CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) if(BUILD_TESTING) include(FetchContent) - FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.1) + FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 8.0.0) FetchContent_MakeAvailable(libpqxx) add_executable(test test/halfvec_test.cpp test/main.cpp test/pqxx_test.cpp test/sparsevec_test.cpp test/vector_test.cpp) target_link_libraries(test PRIVATE libpqxx::pqxx pgvector::pgvector) if(NOT MSVC) - target_compile_options(test PRIVATE -Wall -Wextra -Wpedantic -Werror) + target_compile_options(test PRIVATE -Wall -Wextra -Wpedantic -Wconversion -Werror) endif() endif() endif() diff --git a/LICENSE.txt b/LICENSE.txt index b612d6d..17e5210 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2021-2025 Andrew Kane +Copyright (c) 2021-2026 Andrew Kane Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 9879343..7e4aecd 100644 --- a/README.md +++ b/README.md @@ -8,14 +8,14 @@ Supports [libpqxx](https://github.com/jtv/libpqxx) ## Installation -Add [the headers](https://github.com/pgvector/pgvector-cpp/tree/v0.2.2/include) to your project (supports C++17 and greater). +Add [the headers](https://github.com/pgvector/pgvector-cpp/tree/v0.3.0/include) to your project (supports C++20 and greater). There is also support for CMake and FetchContent: ```cmake include(FetchContent) -FetchContent_Declare(pgvector GIT_REPOSITORY https://github.com/pgvector/pgvector-cpp.git GIT_TAG v0.2.2) +FetchContent_Declare(pgvector GIT_REPOSITORY https://github.com/pgvector/pgvector-cpp.git GIT_TAG v0.3.0) FetchContent_MakeAvailable(pgvector) target_link_libraries(app PRIVATE pgvector::pgvector) @@ -46,6 +46,8 @@ Include the header #include ``` +The latest version works libpqxx 8. For libpqxx 7, use version 0.2.4 and [this readme](https://github.com/pgvector/pgvector-cpp/blob/v0.2.4/README.md#libpqxx). + Enable the extension ```cpp @@ -61,7 +63,7 @@ tx.exec("CREATE TABLE items (id bigserial PRIMARY KEY, embedding vector(3))"); Insert a vector ```cpp -auto embedding = pgvector::Vector({1, 2, 3}); +pgvector::Vector embedding{{1, 2, 3}}; tx.exec("INSERT INTO items (embedding) VALUES ($1)", {embedding}); ``` @@ -74,7 +76,7 @@ pqxx::result r = tx.exec("SELECT * FROM items ORDER BY embedding <-> $1 LIMIT 5" Retrieve a vector ```cpp -auto row = tx.exec("SELECT embedding FROM items LIMIT 1").one_row(); +pqxx::row row = tx.exec("SELECT embedding FROM items LIMIT 1").one_row(); auto embedding = row[0].as(); ``` @@ -84,47 +86,69 @@ Use `std::optional` if the value could be `NULL` ### Vectors -Create a vector from a `std::vector` +Create a vector from a `std::vector` + +```cpp +pgvector::Vector vec{std::vector{1, 2, 3}}; +``` + +Or a span ```cpp -auto vec = pgvector::Vector({1, 2, 3}); +pgvector::Vector vec{std::span{{1, 2, 3}}}; ``` -Convert to a `std::vector` +Get a `std::vector` ```cpp -auto float_vec = static_cast>(vec); +const std::vector& values = vec.values(); ``` ### Half Vectors -Create a half vector from a `std::vector` +Create a half vector from a `std::vector` ```cpp -auto vec = pgvector::HalfVector({1, 2, 3}); +pgvector::HalfVector vec{std::vector{1, 2, 3}}; ``` -Convert to a `std::vector` +Note: `pgvector::Half` is `std::float16_t` or `_Float16` when available, or `float` otherwise + +Or a span ```cpp -auto float_vec = static_cast>(vec); +pgvector::HalfVector vec{std::span{{1, 2, 3}}}; +``` + +Get a `std::vector` + +```cpp +const std::vector& values = vec.values(); ``` ### Sparse Vectors -Create a sparse vector from a `std::vector` +Create a sparse vector from a `std::vector` ```cpp -auto vec = pgvector::SparseVector({1, 0, 2, 0, 3, 0}); +pgvector::SparseVector vec{std::vector{1, 0, 2, 0, 3, 0}}; +``` + +Or a span + +```cpp +pgvector::SparseVector vec{std::span{{1, 0, 2, 0, 3, 0}}}; ``` Or a map of non-zero elements ```cpp -std::unordered_map map = {{0, 1}, {2, 2}, {4, 3}}; -auto vec = pgvector::SparseVector(map, 6); +std::unordered_map map{{0, 1}, {2, 2}, {4, 3}}; +pgvector::SparseVector vec{map, 6}; ``` +Note: Indices start at 0 + Get the number of dimensions ```cpp @@ -134,13 +158,13 @@ int dim = vec.dimensions(); Get the indices of non-zero elements ```cpp -auto indices = vec.indices(); +const std::vector& indices = vec.indices(); ``` Get the values of non-zero elements ```cpp -auto values = vec.values(); +const std::vector& values = vec.values(); ``` ## History diff --git a/examples/citus/CMakeLists.txt b/examples/citus/CMakeLists.txt index cdaa70f..f05b59c 100644 --- a/examples/citus/CMakeLists.txt +++ b/examples/citus/CMakeLists.txt @@ -2,11 +2,11 @@ cmake_minimum_required(VERSION 3.18) project(example) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) include(FetchContent) -FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.1) +FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 8.0.0) FetchContent_MakeAvailable(libpqxx) add_subdirectory("${PROJECT_SOURCE_DIR}/../.." pgvector) diff --git a/examples/citus/example.cpp b/examples/citus/example.cpp index 7ac53c5..1222cb8 100644 --- a/examples/citus/example.cpp +++ b/examples/citus/example.cpp @@ -8,8 +8,8 @@ std::vector> random_embeddings(int rows, int dimensions) { std::random_device rd; - std::mt19937_64 prng(rd()); - std::uniform_real_distribution dist(0, 1); + std::mt19937_64 prng{rd()}; + std::uniform_real_distribution dist{0, 1}; std::vector> embeddings; embeddings.reserve(rows); @@ -26,8 +26,8 @@ std::vector> random_embeddings(int rows, int dimensions) { std::vector random_categories(int rows) { std::random_device rd; - std::mt19937_64 prng(rd()); - std::uniform_int_distribution dist(1, 100); + std::mt19937_64 prng{rd()}; + std::uniform_int_distribution dist{1, 100}; std::vector categories; categories.reserve(rows); @@ -41,13 +41,13 @@ int main() { // generate random data int rows = 100000; int dimensions = 128; - auto embeddings = random_embeddings(rows, dimensions); - auto categories = random_categories(rows); - auto queries = random_embeddings(10, dimensions); + std::vector> embeddings = random_embeddings(rows, dimensions); + std::vector categories = random_categories(rows); + std::vector> queries = random_embeddings(10, dimensions); // enable extensions - pqxx::connection conn("dbname=pgvector_citus"); - pqxx::nontransaction tx(conn); + pqxx::connection conn{"dbname=pgvector_citus"}; + pqxx::nontransaction tx{conn}; tx.exec("CREATE EXTENSION IF NOT EXISTS citus"); tx.exec("CREATE EXTENSION IF NOT EXISTS vector"); @@ -61,8 +61,8 @@ int main() { conn.close(); // reconnect for updated GUC variables to take effect - pqxx::connection conn2("dbname=pgvector_citus"); - pqxx::nontransaction tx2(conn2); + pqxx::connection conn2{"dbname=pgvector_citus"}; + pqxx::nontransaction tx2{conn2}; std::cout << "Creating distributed table" << std::endl; tx2.exec("DROP TABLE IF EXISTS items"); @@ -72,9 +72,9 @@ int main() { // libpqxx does not support binary COPY std::cout << "Loading data in parallel" << std::endl; - auto stream = pqxx::stream_to::table(tx2, {"items"}, {"embedding", "category_id"}); + pqxx::stream_to stream = pqxx::stream_to::table(tx2, {"items"}, {"embedding", "category_id"}); for (size_t i = 0; i < embeddings.size(); i++) { - stream << std::make_tuple(pgvector::Vector(embeddings[i]), categories[i]); + stream.write_values(pgvector::Vector{embeddings[i]}, categories[i]); } stream.complete(); @@ -82,10 +82,10 @@ int main() { tx2.exec("CREATE INDEX ON items USING hnsw (embedding vector_l2_ops)"); std::cout << "Running distributed queries" << std::endl; - for (auto& query : queries) { + for (const auto& query : queries) { pqxx::result result = tx2.exec( "SELECT id FROM items ORDER BY embedding <-> $1 LIMIT 10", - pqxx::params{pgvector::Vector(query)} + pqxx::params{pgvector::Vector{query}} ); for (const auto& row : result) { std::cout << row[0].as() << " "; diff --git a/examples/cohere/CMakeLists.txt b/examples/cohere/CMakeLists.txt index c2acad2..5c40779 100644 --- a/examples/cohere/CMakeLists.txt +++ b/examples/cohere/CMakeLists.txt @@ -2,13 +2,14 @@ cmake_minimum_required(VERSION 3.18) project(example) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) +set(CPR_USE_SYSTEM_CURL ON) include(FetchContent) -FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git GIT_TAG 1.11.1) -FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG v3.11.3) -FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.1) +FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git GIT_TAG 1.14.2) +FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.12.0/json.tar.xz DOWNLOAD_EXTRACT_TIMESTAMP TRUE) +FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 8.0.0) FetchContent_MakeAvailable(cpr json libpqxx) add_subdirectory("${PROJECT_SOURCE_DIR}/../.." pgvector) diff --git a/examples/cohere/example.cpp b/examples/cohere/example.cpp index 6660f45..e076476 100644 --- a/examples/cohere/example.cpp +++ b/examples/cohere/example.cpp @@ -12,9 +12,13 @@ using json = nlohmann::json; // https://docs.cohere.com/reference/embed -std::vector embed(const std::vector& texts, const std::string& input_type, char *api_key) { - std::string url = "https://api.cohere.com/v2/embed"; - json data = { +std::vector embed( + const std::vector& texts, + const std::string& input_type, + char* api_key +) { + std::string url{"https://api.cohere.com/v2/embed"}; + json data{ {"texts", texts}, {"model", "embed-v4.0"}, {"input_type", input_type}, @@ -28,12 +32,12 @@ std::vector embed(const std::vector& texts, const std: cpr::Header{{"Content-Type", "application/json"}} ); if (r.status_code != 200) { - throw std::runtime_error("Bad status: " + std::to_string(r.status_code)); + throw std::runtime_error{"Bad status: " + std::to_string(r.status_code)}; } json response = json::parse(r.text); std::vector embeddings; - for (auto& v : response["embeddings"]["ubinary"]) { + for (const auto& v : response["embeddings"]["ubinary"]) { std::stringstream buf; for (uint8_t c : v) { std::bitset<8> b{c}; @@ -45,31 +49,31 @@ std::vector embed(const std::vector& texts, const std: } int main() { - char *api_key = std::getenv("CO_API_KEY"); + char* api_key = std::getenv("CO_API_KEY"); if (!api_key) { std::cout << "Set CO_API_KEY" << std::endl; return 1; } - pqxx::connection conn("dbname=pgvector_example"); + pqxx::connection conn{"dbname=pgvector_example"}; - pqxx::nontransaction tx(conn); + pqxx::nontransaction tx{conn}; tx.exec("CREATE EXTENSION IF NOT EXISTS vector"); tx.exec("DROP TABLE IF EXISTS documents"); tx.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding bit(1536))"); - std::vector input = { + std::vector input{ "The dog is barking", "The cat is purring", "The bear is growling" }; - auto embeddings = embed(input, "search_document", api_key); + std::vector embeddings = embed(input, "search_document", api_key); for (size_t i = 0; i < input.size(); i++) { tx.exec("INSERT INTO documents (content, embedding) VALUES ($1, $2)", pqxx::params{input[i], embeddings[i]}); } - std::string query = "forest"; - auto query_embedding = embed({query}, "search_query", api_key)[0]; + std::string query{"forest"}; + std::string query_embedding = embed({query}, "search_query", api_key)[0]; pqxx::result result = tx.exec("SELECT content FROM documents ORDER BY embedding <~> $1 LIMIT 5", pqxx::params{query_embedding}); for (const auto& row : result) { std::cout << row[0].as() << std::endl; diff --git a/examples/disco/CMakeLists.txt b/examples/disco/CMakeLists.txt index c734b42..55017e1 100644 --- a/examples/disco/CMakeLists.txt +++ b/examples/disco/CMakeLists.txt @@ -7,7 +7,7 @@ set(CMAKE_CXX_STANDARD 20) include(FetchContent) FetchContent_Declare(disco GIT_REPOSITORY https://github.com/ankane/disco-cpp.git GIT_TAG v0.1.3) -FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.1) +FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 8.0.0) FetchContent_MakeAvailable(disco libpqxx) add_subdirectory("${PROJECT_SOURCE_DIR}/../.." pgvector) diff --git a/examples/disco/example.cpp b/examples/disco/example.cpp index 0e1985c..b87f631 100644 --- a/examples/disco/example.cpp +++ b/examples/disco/example.cpp @@ -16,7 +16,7 @@ using disco::Recommender; std::string convert_to_utf8(const std::string& str) { std::stringstream buf; - for (auto &v : str) { + for (const auto& v : str) { if (v >= 0) { buf << v; } else { @@ -40,7 +40,7 @@ Dataset load_movielens(const std::string& path) { } // read ratings and create dataset - auto data = Dataset(); + Dataset data; std::ifstream ratings_file(path + "/u.data"); assert(ratings_file.is_open()); while (std::getline(ratings_file, line)) { @@ -59,35 +59,35 @@ Dataset load_movielens(const std::string& path) { int main() { // https://grouplens.org/datasets/movielens/100k/ - char *movielens_path = std::getenv("MOVIELENS_100K_PATH"); + char* movielens_path = std::getenv("MOVIELENS_100K_PATH"); if (!movielens_path) { std::cout << "Set MOVIELENS_100K_PATH" << std::endl; return 1; } - pqxx::connection conn("dbname=pgvector_example"); + pqxx::connection conn{"dbname=pgvector_example"}; - pqxx::nontransaction tx(conn); + pqxx::nontransaction tx{conn}; tx.exec("CREATE EXTENSION IF NOT EXISTS vector"); tx.exec("DROP TABLE IF EXISTS users"); tx.exec("DROP TABLE IF EXISTS movies"); tx.exec("CREATE TABLE users (id integer PRIMARY KEY, factors vector(20))"); tx.exec("CREATE TABLE movies (name text PRIMARY KEY, factors vector(20))"); - auto data = load_movielens(movielens_path); + Dataset data = load_movielens(movielens_path); auto recommender = Recommender::fit_explicit(data, { .factors = 20 }); - for (auto& user_id : recommender.user_ids()) { - auto factors = pgvector::Vector(*recommender.user_factors(user_id)); + for (const auto& user_id : recommender.user_ids()) { + pgvector::Vector factors{*recommender.user_factors(user_id)}; tx.exec("INSERT INTO users (id, factors) VALUES ($1, $2)", pqxx::params{user_id, factors}); } - for (auto& item_id : recommender.item_ids()) { - auto factors = pgvector::Vector(*recommender.item_factors(item_id)); + for (const auto& item_id : recommender.item_ids()) { + pgvector::Vector factors{*recommender.item_factors(item_id)}; tx.exec("INSERT INTO movies (name, factors) VALUES ($1, $2)", pqxx::params{item_id, factors}); } - std::string movie = "Star Wars (1977)"; + std::string movie{"Star Wars (1977)"}; std::cout << "Item-based recommendations for " << movie << std::endl; pqxx::result result = tx.exec("SELECT name FROM movies WHERE name != $1 ORDER BY factors <=> (SELECT factors FROM movies WHERE name = $1) LIMIT 5", pqxx::params{movie}); for (const auto& row : result) { diff --git a/examples/hybrid/CMakeLists.txt b/examples/hybrid/CMakeLists.txt index c2acad2..5c40779 100644 --- a/examples/hybrid/CMakeLists.txt +++ b/examples/hybrid/CMakeLists.txt @@ -2,13 +2,14 @@ cmake_minimum_required(VERSION 3.18) project(example) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) +set(CPR_USE_SYSTEM_CURL ON) include(FetchContent) -FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git GIT_TAG 1.11.1) -FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG v3.11.3) -FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.1) +FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git GIT_TAG 1.14.2) +FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.12.0/json.tar.xz DOWNLOAD_EXTRACT_TIMESTAMP TRUE) +FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 8.0.0) FetchContent_MakeAvailable(cpr json libpqxx) add_subdirectory("${PROJECT_SOURCE_DIR}/../.." pgvector) diff --git a/examples/hybrid/example.cpp b/examples/hybrid/example.cpp index c9bb3a6..7cd81f6 100644 --- a/examples/hybrid/example.cpp +++ b/examples/hybrid/example.cpp @@ -13,17 +13,20 @@ using json = nlohmann::json; -std::vector> embed(const std::vector& texts, const std::string& taskType) { +std::vector> embed( + const std::vector& texts, + const std::string& taskType +) { // nomic-embed-text-v1.5 uses a task prefix // https://huggingface.co/nomic-ai/nomic-embed-text-v1.5 std::vector input; input.reserve(texts.size()); - for (auto& v : texts) { + for (const auto& v : texts) { input.push_back(taskType + ": " + v); } - std::string url = "http://localhost:8080/v1/embeddings"; - json data = { + std::string url{"http://localhost:8080/v1/embeddings"}; + json data{ {"input", input} }; @@ -33,38 +36,38 @@ std::vector> embed(const std::vector& texts, con cpr::Header{{"Content-Type", "application/json"}} ); if (r.status_code != 200) { - throw std::runtime_error("Bad status: " + std::to_string(r.status_code)); + throw std::runtime_error{"Bad status: " + std::to_string(r.status_code)}; } json response = json::parse(r.text); std::vector> embeddings; - for (auto& v : response["data"]) { + for (const auto& v : response["data"]) { embeddings.emplace_back(v["embedding"]); } return embeddings; } int main() { - pqxx::connection conn("dbname=pgvector_example"); + pqxx::connection conn{"dbname=pgvector_example"}; - pqxx::nontransaction tx(conn); + pqxx::nontransaction tx{conn}; tx.exec("CREATE EXTENSION IF NOT EXISTS vector"); tx.exec("DROP TABLE IF EXISTS documents"); tx.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(768))"); tx.exec("CREATE INDEX ON documents USING GIN (to_tsvector('english', content))"); - std::vector input = { + std::vector input{ "The dog is barking", "The cat is purring", "The bear is growling" }; - auto embeddings = embed(input, "search_document"); + std::vector> embeddings = embed(input, "search_document"); for (size_t i = 0; i < input.size(); i++) { - tx.exec("INSERT INTO documents (content, embedding) VALUES ($1, $2)", pqxx::params{input[i], pgvector::Vector(embeddings[i])}); + tx.exec("INSERT INTO documents (content, embedding) VALUES ($1, $2)", pqxx::params{input[i], pgvector::Vector{embeddings[i]}}); } - std::string sql = R"( + std::string sql{R"( WITH semantic_search AS ( SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank FROM documents @@ -86,11 +89,11 @@ int main() { FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id ORDER BY score DESC LIMIT 5 - )"; - std::string query = "growling bear"; - auto query_embedding = embed({query}, "search_query")[0]; + )"}; + std::string query{"growling bear"}; + std::vector query_embedding = embed({query}, "search_query")[0]; double k = 60; - pqxx::result result = tx.exec(sql, pqxx::params{query, pgvector::Vector(query_embedding), k}); + pqxx::result result = tx.exec(sql, pqxx::params{query, pgvector::Vector{query_embedding}, k}); for (const auto& row : result) { std::cout << "document: " << row[0].as() << ", RRF score: " << row[1].as() << std::endl; } diff --git a/examples/loading/CMakeLists.txt b/examples/loading/CMakeLists.txt index cdaa70f..f05b59c 100644 --- a/examples/loading/CMakeLists.txt +++ b/examples/loading/CMakeLists.txt @@ -2,11 +2,11 @@ cmake_minimum_required(VERSION 3.18) project(example) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) include(FetchContent) -FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.1) +FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 8.0.0) FetchContent_MakeAvailable(libpqxx) add_subdirectory("${PROJECT_SOURCE_DIR}/../.." pgvector) diff --git a/examples/loading/example.cpp b/examples/loading/example.cpp index 86e15b1..54918fc 100644 --- a/examples/loading/example.cpp +++ b/examples/loading/example.cpp @@ -12,7 +12,7 @@ int main() { std::vector> embeddings; embeddings.reserve(rows); std::mt19937_64 prng; - std::uniform_real_distribution dist(0, 1); + std::uniform_real_distribution dist{0, 1}; for (int i = 0; i < rows; i++) { std::vector embedding; embedding.reserve(dimensions); @@ -23,8 +23,8 @@ int main() { } // enable extension - pqxx::connection conn("dbname=pgvector_example"); - pqxx::nontransaction tx(conn); + pqxx::connection conn{"dbname=pgvector_example"}; + pqxx::nontransaction tx{conn}; tx.exec("CREATE EXTENSION IF NOT EXISTS vector"); // create table @@ -34,14 +34,14 @@ int main() { // load data // libpqxx does not support binary COPY std::cout << "Loading " << rows << " rows" << std::endl; - auto stream = pqxx::stream_to::table(tx, {"items"}, {"embedding"}); + pqxx::stream_to stream = pqxx::stream_to::table(tx, {"items"}, {"embedding"}); for (size_t i = 0; i < embeddings.size(); i++) { // show progress if (i % 10000 == 0) { std::cout << '.' << std::flush; } - stream << pgvector::Vector(embeddings[i]); + stream.write_values(pgvector::Vector{embeddings[i]}); } stream.complete(); std::cout << std::endl << "Success!" << std::endl; diff --git a/examples/openai/CMakeLists.txt b/examples/openai/CMakeLists.txt index c2acad2..5c40779 100644 --- a/examples/openai/CMakeLists.txt +++ b/examples/openai/CMakeLists.txt @@ -2,13 +2,14 @@ cmake_minimum_required(VERSION 3.18) project(example) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) +set(CPR_USE_SYSTEM_CURL ON) include(FetchContent) -FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git GIT_TAG 1.11.1) -FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG v3.11.3) -FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.1) +FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git GIT_TAG 1.14.2) +FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.12.0/json.tar.xz DOWNLOAD_EXTRACT_TIMESTAMP TRUE) +FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 8.0.0) FetchContent_MakeAvailable(cpr json libpqxx) add_subdirectory("${PROJECT_SOURCE_DIR}/../.." pgvector) diff --git a/examples/openai/example.cpp b/examples/openai/example.cpp index ef0feea..9098192 100644 --- a/examples/openai/example.cpp +++ b/examples/openai/example.cpp @@ -11,9 +11,9 @@ using json = nlohmann::json; // https://platform.openai.com/docs/guides/embeddings/how-to-get-embeddings // input can be an array with 2048 elements -std::vector> embed(const std::vector& input, char *api_key) { - std::string url = "https://api.openai.com/v1/embeddings"; - json data = { +std::vector> embed(const std::vector& input, char* api_key) { + std::string url{"https://api.openai.com/v1/embeddings"}; + json data{ {"input", input}, {"model", "text-embedding-3-small"} }; @@ -25,44 +25,44 @@ std::vector> embed(const std::vector& input, cha cpr::Header{{"Content-Type", "application/json"}} ); if (r.status_code != 200) { - throw std::runtime_error("Bad status: " + std::to_string(r.status_code)); + throw std::runtime_error{"Bad status: " + std::to_string(r.status_code)}; } json response = json::parse(r.text); std::vector> embeddings; - for (auto& v : response["data"]) { + for (const auto& v : response["data"]) { embeddings.emplace_back(v["embedding"]); } return embeddings; } int main() { - char *api_key = std::getenv("OPENAI_API_KEY"); + char* api_key = std::getenv("OPENAI_API_KEY"); if (!api_key) { std::cout << "Set OPENAI_API_KEY" << std::endl; return 1; } - pqxx::connection conn("dbname=pgvector_example"); + pqxx::connection conn{"dbname=pgvector_example"}; - pqxx::nontransaction tx(conn); + pqxx::nontransaction tx{conn}; tx.exec("CREATE EXTENSION IF NOT EXISTS vector"); tx.exec("DROP TABLE IF EXISTS documents"); tx.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(1536))"); - std::vector input = { + std::vector input{ "The dog is barking", "The cat is purring", "The bear is growling" }; - auto embeddings = embed(input, api_key); + std::vector> embeddings = embed(input, api_key); for (size_t i = 0; i < input.size(); i++) { - tx.exec("INSERT INTO documents (content, embedding) VALUES ($1, $2)", pqxx::params{input[i], pgvector::Vector(embeddings[i])}); + tx.exec("INSERT INTO documents (content, embedding) VALUES ($1, $2)", pqxx::params{input[i], pgvector::Vector{embeddings[i]}}); } - std::string query = "forest"; - auto query_embedding = embed({query}, api_key)[0]; - pqxx::result result = tx.exec("SELECT content FROM documents ORDER BY embedding <=> $1 LIMIT 5", pqxx::params{pgvector::Vector(query_embedding)}); + std::string query{"forest"}; + std::vector query_embedding = embed({query}, api_key)[0]; + pqxx::result result = tx.exec("SELECT content FROM documents ORDER BY embedding <=> $1 LIMIT 5", pqxx::params{pgvector::Vector{query_embedding}}); for (const auto& row : result) { std::cout << row[0].as() << std::endl; } diff --git a/examples/rdkit/CMakeLists.txt b/examples/rdkit/CMakeLists.txt index 1d607ee..504d619 100644 --- a/examples/rdkit/CMakeLists.txt +++ b/examples/rdkit/CMakeLists.txt @@ -2,14 +2,14 @@ cmake_minimum_required(VERSION 3.18) project(example) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) find_package(RDKit REQUIRED) -find_package(Boost COMPONENTS iostreams serialization system REQUIRED) +find_package(Boost COMPONENTS iostreams serialization REQUIRED) include(FetchContent) -FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.1) +FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 8.0.0) FetchContent_MakeAvailable(libpqxx) add_subdirectory("${PROJECT_SOURCE_DIR}/../.." pgvector) diff --git a/examples/rdkit/example.cpp b/examples/rdkit/example.cpp index b09f882..8caf453 100644 --- a/examples/rdkit/example.cpp +++ b/examples/rdkit/example.cpp @@ -4,14 +4,14 @@ #include #include -#include #include +#include #include #include std::string generate_fingerprint(const std::string& molecule) { - std::unique_ptr mol(RDKit::SmilesToMol(molecule)); - std::unique_ptr fp(RDKit::MorganFingerprints::getFingerprintAsBitVect(*mol, 3, 2048)); + std::unique_ptr mol{RDKit::SmilesToMol(molecule)}; + std::unique_ptr fp{RDKit::MorganFingerprints::getFingerprintAsBitVect(*mol, 3, 2048)}; std::stringstream buf; for (size_t i = 0; i < fp->getNumBits(); i++) { buf << (fp->getBit(i) ? '1' : '0'); @@ -20,21 +20,21 @@ std::string generate_fingerprint(const std::string& molecule) { } int main() { - pqxx::connection conn("dbname=pgvector_example"); + pqxx::connection conn{"dbname=pgvector_example"}; - pqxx::nontransaction tx(conn); + pqxx::nontransaction tx{conn}; tx.exec("CREATE EXTENSION IF NOT EXISTS vector"); tx.exec("DROP TABLE IF EXISTS molecules"); tx.exec("CREATE TABLE molecules (id text PRIMARY KEY, fingerprint bit(2048))"); - std::vector molecules = {"Cc1ccccc1", "Cc1ncccc1", "c1ccccn1"}; - for (auto& molecule : molecules) { - auto fingerprint = generate_fingerprint(molecule); + std::vector molecules{"Cc1ccccc1", "Cc1ncccc1", "c1ccccn1"}; + for (const auto& molecule : molecules) { + std::string fingerprint = generate_fingerprint(molecule); tx.exec("INSERT INTO molecules (id, fingerprint) VALUES ($1, $2)", pqxx::params{molecule, fingerprint}); } - std::string query_molecule = "c1ccco1"; - auto query_fingerprint = generate_fingerprint(query_molecule); + std::string query_molecule{"c1ccco1"}; + std::string query_fingerprint = generate_fingerprint(query_molecule); pqxx::result result = tx.exec("SELECT id, fingerprint <%> $1 AS distance FROM molecules ORDER BY distance LIMIT 5", pqxx::params{query_fingerprint}); for (const auto& row : result) { std::cout << row[0].as() << ": " << row[1].as() << std::endl; diff --git a/examples/sparse/CMakeLists.txt b/examples/sparse/CMakeLists.txt index c2acad2..5c40779 100644 --- a/examples/sparse/CMakeLists.txt +++ b/examples/sparse/CMakeLists.txt @@ -2,13 +2,14 @@ cmake_minimum_required(VERSION 3.18) project(example) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) +set(CPR_USE_SYSTEM_CURL ON) include(FetchContent) -FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git GIT_TAG 1.11.1) -FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG v3.11.3) -FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.1) +FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git GIT_TAG 1.14.2) +FetchContent_Declare(json URL https://github.com/nlohmann/json/releases/download/v3.12.0/json.tar.xz DOWNLOAD_EXTRACT_TIMESTAMP TRUE) +FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 8.0.0) FetchContent_MakeAvailable(cpr json libpqxx) add_subdirectory("${PROJECT_SOURCE_DIR}/../.." pgvector) diff --git a/examples/sparse/example.cpp b/examples/sparse/example.cpp index efe461a..463ab53 100644 --- a/examples/sparse/example.cpp +++ b/examples/sparse/example.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -18,8 +19,8 @@ using json = nlohmann::json; std::vector embed(const std::vector& inputs) { - std::string url = "http://localhost:3000/embed_sparse"; - json data = { + std::string url{"http://localhost:3000/embed_sparse"}; + json data{ {"inputs", inputs} }; @@ -29,43 +30,41 @@ std::vector embed(const std::vector& inputs cpr::Header{{"Content-Type", "application/json"}} ); if (r.status_code != 200) { - throw std::runtime_error("Bad status: " + std::to_string(r.status_code)); + throw std::runtime_error{"Bad status: " + std::to_string(r.status_code)}; } json response = json::parse(r.text); std::vector embeddings; - for (auto& item : response) { - std::vector indices; - std::vector values; - for (auto& e : item) { - indices.emplace_back(e["index"]); - values.emplace_back(e["value"]); + for (const auto& item : response) { + std::unordered_map map; + for (const auto& e : item) { + map.insert({e["index"], e["value"]}); } - embeddings.emplace_back(pgvector::SparseVector(30522, indices, values)); + embeddings.emplace_back(pgvector::SparseVector{map, 30522}); } return embeddings; } int main() { - pqxx::connection conn("dbname=pgvector_example"); + pqxx::connection conn{"dbname=pgvector_example"}; - pqxx::nontransaction tx(conn); + pqxx::nontransaction tx{conn}; tx.exec("CREATE EXTENSION IF NOT EXISTS vector"); tx.exec("DROP TABLE IF EXISTS documents"); tx.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding sparsevec(30522))"); - std::vector input = { + std::vector input{ "The dog is barking", "The cat is purring", "The bear is growling" }; - auto embeddings = embed(input); + std::vector embeddings = embed(input); for (size_t i = 0; i < input.size(); i++) { tx.exec("INSERT INTO documents (content, embedding) VALUES ($1, $2)", pqxx::params{input[i], embeddings[i]}); } - std::string query = "forest"; - auto query_embedding = embed({query})[0]; + std::string query{"forest"}; + pgvector::SparseVector query_embedding = embed({query})[0]; pqxx::result result = tx.exec("SELECT content FROM documents ORDER BY embedding <#> $1 LIMIT 5", pqxx::params{query_embedding}); for (const auto& row : result) { std::cout << row[0].as() << std::endl; diff --git a/include/pgvector/halfvec.hpp b/include/pgvector/halfvec.hpp index d0cae75..5513936 100644 --- a/include/pgvector/halfvec.hpp +++ b/include/pgvector/halfvec.hpp @@ -1,5 +1,5 @@ /* - * pgvector-cpp v0.2.2 + * pgvector-cpp v0.3.0 * https://github.com/pgvector/pgvector-cpp * MIT License */ @@ -8,53 +8,46 @@ #include #include +#include #include #include -#if __cplusplus >= 202002L -#include +#if __STDCPP_FLOAT16_T__ +#include +#else +#define __STDC_WANT_IEC_60559_TYPES_EXT__ +#include #endif namespace pgvector { +/// A half vector element. +#if __STDCPP_FLOAT16_T__ +using Half = std::float16_t; +#elif defined(__FLT16_MAX__) +using Half = _Float16; +#else +using Half = float; +#endif + /// A half vector. class HalfVector { public: - /// @private - // TODO remove in 0.3.0 - HalfVector() = default; - - /// Creates a half vector from a `std::vector`. - // TODO add explicit in 0.3.0 - HalfVector(const std::vector& value) { - value_ = value; - } + /// Creates a half vector from a `std::vector`. + explicit HalfVector(const std::vector& value) : value_{value} {} - /// Creates a half vector from a `std::vector`. - // TODO add explicit in 0.3.0 - HalfVector(std::vector&& value) { - value_ = std::move(value); - } - - /// Creates a half vector from an array. - HalfVector(const float* value, size_t n) { - value_ = std::vector{value, value + n}; - } + /// Creates a half vector from a `std::vector`. + explicit HalfVector(std::vector&& value) : value_{std::move(value)} {} -#if __cplusplus >= 202002L /// Creates a half vector from a span. - // TODO add explicit in 0.3.0 - HalfVector(std::span value) { - value_ = std::vector(value.begin(), value.end()); - } -#endif + explicit HalfVector(std::span value) : value_{std::vector(value.begin(), value.end())} {} /// Returns the number of dimensions. size_t dimensions() const { return value_.size(); } - /// Returns the half vector as a `std::vector`. - operator const std::vector() const { + /// Returns the values. + const std::vector& values() const { return value_; } @@ -64,18 +57,24 @@ class HalfVector { friend std::ostream& operator<<(std::ostream& os, const HalfVector& value) { os << "["; - for (size_t i = 0; i < value.value_.size(); i++) { + // TODO use std::views::enumerate for C++23 + size_t i = 0; + for (auto v : value.value_) { if (i > 0) { os << ","; } - os << value.value_[i]; +#if __STDCPP_FLOAT16_T__ + os << v; +#else + os << static_cast(v); +#endif + i++; } os << "]"; return os; } private: - // TODO use std::float16_t for C++23 - std::vector value_; + std::vector value_; }; } // namespace pgvector diff --git a/include/pgvector/pqxx.hpp b/include/pgvector/pqxx.hpp index 8c0812e..4fd5893 100644 --- a/include/pgvector/pqxx.hpp +++ b/include/pgvector/pqxx.hpp @@ -1,5 +1,5 @@ /* - * pgvector-cpp v0.2.2 + * pgvector-cpp v0.3.0 * https://github.com/pgvector/pgvector-cpp * MIT License */ @@ -7,9 +7,13 @@ #pragma once #include -#include -#include +#include +#include +#include +#include #include +#include +#include #include #include @@ -21,115 +25,217 @@ /// @cond namespace pqxx { -template <> std::string const type_name{"vector"}; +template<> +inline constexpr std::string_view name_type() noexcept { + return "vector"; +}; -template <> struct nullness : pqxx::no_null {}; +template<> +struct nullness : no_null {}; -template <> struct string_traits { - static constexpr bool converts_to_string{true}; +template<> +struct string_traits { + static pgvector::Vector from_string(std::string_view text, ctx c = {}) { + if (text.size() < 2 || text.front() != '[' || text.back() != ']') { + throw conversion_error{"Malformed vector literal"}; + } - static constexpr bool converts_from_string{true}; + std::vector values; + if (text.size() > 2) { + std::string_view inner = text.substr(1, text.size() - 2); + for (const auto& v : std::views::split(inner, ',')) { + std::string_view sv{v.begin(), v.end()}; + values.push_back(pqxx::from_string(sv, c)); + } + } + return pgvector::Vector{std::move(values)}; + } - static pgvector::Vector from_string(std::string_view text) { - if (text.front() != '[' || text.back() != ']') { - throw conversion_error("Malformed vector literal"); + static std::string_view to_buf(std::span buf, const pgvector::Vector& value, ctx c = {}) { + // confirm caller provided estimated buffer space + if (buf.size() < size_buffer(value)) { + throw conversion_overrun{"Not enough space in buffer for vector"}; } - // TODO don't copy string - std::vector result; - std::stringstream ss(std::string(text.substr(1, text.size() - 2))); - while (ss.good()) { - std::string substr; - getline(ss, substr, ','); - result.push_back(std::stof(substr)); + const std::vector& values = value.values(); + + // important! size_buffer cannot throw an exception on overflow + // so perform this check before writing any data + if (values.size() > 16000) { + throw conversion_overrun{"vector cannot have more than 16000 dimensions"}; } - return pgvector::Vector(result); - } - static zview to_buf(char* begin, char* end, const pgvector::Vector& value) { - char *const next = into_buf(begin, end, value); - return zview{begin, next - begin - 1}; - } + size_t here = 0; + here += pqxx::into_buf(buf.subspan(here), "[", c); + + // TODO use std::views::enumerate for C++23 + size_t i = 0; + for (auto v : values) { + if (i != 0) { + here += pqxx::into_buf(buf.subspan(here), ",", c); + } + here += pqxx::into_buf(buf.subspan(here), v, c); + i++; + } + + here += pqxx::into_buf(buf.subspan(here), "]", c); - static char* into_buf(char* begin, char* end, const pgvector::Vector& value) { - auto ret = string_traits>::into_buf( - begin, end, static_cast>(value)); - // replace array brackets - *begin = '['; - *(ret - 2) = ']'; - return ret; + return {std::data(buf), here}; } static size_t size_buffer(const pgvector::Vector& value) noexcept { - return string_traits>::size_buffer( - static_cast>(value)); + const std::vector& values = value.values(); + + // cannot throw an exception here on overflow + // so throw in into_buf + + size_t size = 0; + size += pqxx::size_buffer("["); + for (const auto v : values) { + size += pqxx::size_buffer(","); + size += pqxx::size_buffer(v); + } + size += pqxx::size_buffer("]"); + return size; } }; -template <> std::string const type_name{"halfvec"}; +template<> +inline constexpr std::string_view name_type() noexcept { + return "halfvec"; +}; -template <> struct nullness : pqxx::no_null {}; +template<> +struct nullness : no_null {}; -template <> struct string_traits { - static constexpr bool converts_to_string{true}; +template<> +struct string_traits { + static pgvector::HalfVector from_string(std::string_view text, ctx c = {}) { + if (text.size() < 2 || text.front() != '[' || text.back() != ']') { + throw conversion_error{"Malformed halfvec literal"}; + } - static constexpr bool converts_from_string{true}; + std::vector values; + if (text.size() > 2) { + std::string_view inner = text.substr(1, text.size() - 2); + for (const auto& v : std::views::split(inner, ',')) { + std::string_view sv{v.begin(), v.end()}; + values.push_back(static_cast(pqxx::from_string(sv, c))); + } + } + return pgvector::HalfVector{std::move(values)}; + } - static pgvector::HalfVector from_string(std::string_view text) { - if (text.front() != '[' || text.back() != ']') { - throw conversion_error("Malformed halfvec literal"); + static std::string_view to_buf(std::span buf, const pgvector::HalfVector& value, ctx c = {}) { + // confirm caller provided estimated buffer space + if (buf.size() < size_buffer(value)) { + throw conversion_overrun{"Not enough space in buffer for halfvec"}; } - // TODO don't copy string - std::vector result; - std::stringstream ss(std::string(text.substr(1, text.size() - 2))); - while (ss.good()) { - std::string substr; - getline(ss, substr, ','); - result.push_back(std::stof(substr)); + const std::vector& values = value.values(); + + // important! size_buffer cannot throw an exception on overflow + // so perform this check before writing any data + if (values.size() > 16000) { + throw conversion_overrun{"halfvec cannot have more than 16000 dimensions"}; } - return pgvector::HalfVector(result); - } - static zview to_buf(char* begin, char* end, const pgvector::HalfVector& value) { - char *const next = into_buf(begin, end, value); - return zview{begin, next - begin - 1}; - } + size_t here = 0; + here += pqxx::into_buf(buf.subspan(here), "[", c); + + // TODO use std::views::enumerate for C++23 + size_t i = 0; + for (auto v : values) { + if (i != 0) { + here += pqxx::into_buf(buf.subspan(here), ",", c); + } + here += pqxx::into_buf(buf.subspan(here), static_cast(v), c); + i++; + } + + here += pqxx::into_buf(buf.subspan(here), "]", c); - static char* into_buf(char* begin, char* end, const pgvector::HalfVector& value) { - auto ret = string_traits>::into_buf( - begin, end, static_cast>(value)); - // replace array brackets - *begin = '['; - *(ret - 2) = ']'; - return ret; + return {std::data(buf), here}; } static size_t size_buffer(const pgvector::HalfVector& value) noexcept { - return string_traits>::size_buffer( - static_cast>(value)); + const std::vector& values = value.values(); + + // cannot throw an exception here on overflow + // so throw in into_buf + + size_t size = 0; + size += pqxx::size_buffer("["); + for (const auto v : values) { + size += pqxx::size_buffer(","); + size += pqxx::size_buffer(static_cast(v)); + } + size += pqxx::size_buffer("]"); + return size; } }; -template <> std::string const type_name{"sparsevec"}; +template<> +inline constexpr std::string_view name_type() noexcept { + return "sparsevec"; +}; + +template<> +struct nullness : no_null {}; + +template<> +struct string_traits { + static pgvector::SparseVector from_string(std::string_view text, ctx c = {}) { + if (text.size() < 4 || text.front() != '{') { + throw conversion_error{"Malformed sparsevec literal"}; + } + + size_t n = text.find("}/", 1); + if (n == std::string_view::npos) { + throw conversion_error{"Malformed sparsevec literal"}; + } + + int dimensions = pqxx::from_string(text.substr(n + 2), c); + + std::unordered_map map; + if (n > 1) { + std::string_view inner = text.substr(1, n - 1); + for (const auto& v : std::views::split(inner, ',')) { + std::string_view sv{v.begin(), v.end()}; -template <> struct nullness : pqxx::no_null {}; + size_t ne = sv.find(':'); + if (ne == std::string_view::npos) { + throw conversion_error{"Malformed sparsevec literal"}; + } -template <> struct string_traits { - static constexpr bool converts_to_string{true}; + int index = pqxx::from_string(sv.substr(0, ne), c); + float value = pqxx::from_string(sv.substr(ne + 1), c); - // TODO add from_string - static constexpr bool converts_from_string{false}; + // check to avoid undefined behavior + if (index > std::numeric_limits::min()) { + index -= 1; + } - static zview to_buf(char* begin, char* end, const pgvector::SparseVector& value) { - char *const next = into_buf(begin, end, value); - return zview{begin, next - begin - 1}; + map.insert({index, value}); + } + } + + try { + return pgvector::SparseVector{map, dimensions}; + } catch (const std::invalid_argument& e) { + throw conversion_error{e.what()}; + } } - static char* into_buf(char* begin, char* end, const pgvector::SparseVector& value) { + static std::string_view to_buf(std::span buf, const pgvector::SparseVector& value, ctx c = {}) { + // confirm caller provided estimated buffer space + if (buf.size() < size_buffer(value)) { + throw conversion_overrun{"Not enough space in buffer for sparsevec"}; + } + int dimensions = value.dimensions(); - auto indices = value.indices(); - auto values = value.values(); + const std::vector& indices = value.indices(); + const std::vector& values = value.values(); size_t nnz = indices.size(); // important! size_buffer cannot throw an exception on overflow @@ -138,43 +244,46 @@ template <> struct string_traits { throw conversion_overrun{"sparsevec cannot have more than 16000 dimensions"}; } - char *here = begin; - *here++ = '{'; + size_t here = 0; + here += pqxx::into_buf(buf.subspan(here), "{", c); + // TODO use std::views::zip for C++23 for (size_t i = 0; i < nnz; i++) { if (i != 0) { - *here++ = ','; + here += pqxx::into_buf(buf.subspan(here), ",", c); } - - here = string_traits::into_buf(here, end, indices[i] + 1) - 1; - *here++ = ':'; - here = string_traits::into_buf(here, end, values[i]) - 1; + // cast to avoid undefined behavior and require less buffer space + here += pqxx::into_buf(buf.subspan(here), static_cast(indices.at(i)) + 1, c); + here += pqxx::into_buf(buf.subspan(here), ":", c); + here += pqxx::into_buf(buf.subspan(here), values.at(i), c); } - *here++ = '}'; - *here++ = '/'; - here = string_traits::into_buf(here, end, dimensions) - 1; - *here++ = '\0'; + here += pqxx::into_buf(buf.subspan(here), "}/", c); + here += pqxx::into_buf(buf.subspan(here), dimensions, c); - return here; + return {std::data(buf), here}; } static size_t size_buffer(const pgvector::SparseVector& value) noexcept { int dimensions = value.dimensions(); - auto indices = value.indices(); - auto values = value.values(); + const std::vector& indices = value.indices(); + const std::vector& values = value.values(); size_t nnz = indices.size(); // cannot throw an exception here on overflow // so throw in into_buf - size_t size = 4; // {, }, /, and \0 - size += string_traits::size_buffer(dimensions); + size_t size = 0; + size += pqxx::size_buffer("{"); + // TODO use std::views::zip for C++23 for (size_t i = 0; i < nnz; i++) { - size += 2; // : and , - size += string_traits::size_buffer(indices[i]); - size += string_traits::size_buffer(values[i]); + size += pqxx::size_buffer(","); + size += pqxx::size_buffer(static_cast(indices.at(i)) + 1); + size += pqxx::size_buffer(":"); + size += pqxx::size_buffer(values.at(i)); } + size += pqxx::size_buffer("}/"); + size += pqxx::size_buffer(dimensions); return size; } }; diff --git a/include/pgvector/sparsevec.hpp b/include/pgvector/sparsevec.hpp index 6686178..995f7e2 100644 --- a/include/pgvector/sparsevec.hpp +++ b/include/pgvector/sparsevec.hpp @@ -1,5 +1,5 @@ /* - * pgvector-cpp v0.2.2 + * pgvector-cpp v0.3.0 * https://github.com/pgvector/pgvector-cpp * MIT License */ @@ -8,81 +8,60 @@ #include #include +#include #include +#include #include #include #include -#if __cplusplus >= 202002L -#include -#endif - namespace pgvector { /// A sparse vector. class SparseVector { public: - /// @private - // TODO remove in 0.3.0 - SparseVector() = default; - - /// @private - SparseVector(int dimensions, const std::vector& indices, const std::vector& values) { - if (values.size() != indices.size()) { - throw std::invalid_argument("indices and values must be the same length"); - } - dimensions_ = dimensions; - indices_ = indices; - values_ = values; - } - /// Creates a sparse vector from a dense vector. - // TODO add explicit in 0.3.0 - SparseVector(const std::vector& value) { - dimensions_ = value.size(); - for (size_t i = 0; i < value.size(); i++) { - float v = value[i]; - if (v != 0) { - indices_.push_back(i); - values_.push_back(v); - } - } - } + explicit SparseVector(const std::vector& value) : SparseVector(std::span{value}) {} -#if __cplusplus >= 202002L /// Creates a sparse vector from a span. - // TODO add explicit in 0.3.0 - SparseVector(std::span value) { - dimensions_ = value.size(); - for (size_t i = 0; i < value.size(); i++) { - float v = value[i]; + explicit SparseVector(std::span value) { + if (value.size() > std::numeric_limits::max()) { + throw std::invalid_argument{"sparsevec cannot have more than max int dimensions"}; + } + dimensions_ = static_cast(value.size()); + + // do not reserve capacity for indices/values since likely many zeros + // TODO use std::views::enumerate for C++23 + size_t i = 0; + for (auto v : value) { if (v != 0) { - indices_.push_back(i); + indices_.push_back(static_cast(i)); values_.push_back(v); } + i++; } } -#endif /// Creates a sparse vector from a map of non-zero elements. SparseVector(const std::unordered_map& map, int dimensions) { - if (dimensions < 1) { - throw std::invalid_argument("sparsevec must have at least 1 dimension"); + if (dimensions < 0) { + throw std::invalid_argument{"sparsevec cannot have negative dimensions"}; } dimensions_ = dimensions; - for (auto [i, v] : map) { + // could probably reserve capacity for indices since not expecting zeros + for (const auto& [i, v] : map) { if (i < 0 || i >= dimensions) { - throw std::invalid_argument("sparsevec index out of bounds"); + throw std::invalid_argument{"sparsevec index out of bounds"}; } if (v != 0) { indices_.push_back(i); } } - std::sort(indices_.begin(), indices_.end()); + std::ranges::sort(indices_); values_.reserve(indices_.size()); - for (auto i : indices_) { + for (const auto i : indices_) { values_.push_back(map.at(i)); } } @@ -108,13 +87,14 @@ class SparseVector { friend std::ostream& operator<<(std::ostream& os, const SparseVector& value) { os << "{"; + // TODO use std::views::zip for C++23 for (size_t i = 0; i < value.indices_.size(); i++) { if (i > 0) { os << ","; } - os << value.indices_[i] + 1; + os << value.indices_.at(i) + 1; os << ":"; - os << value.values_[i]; + os << value.values_.at(i); } os << "}/"; os << value.dimensions_; diff --git a/include/pgvector/vector.hpp b/include/pgvector/vector.hpp index fcc385a..43804ed 100644 --- a/include/pgvector/vector.hpp +++ b/include/pgvector/vector.hpp @@ -1,5 +1,5 @@ /* - * pgvector-cpp v0.2.2 + * pgvector-cpp v0.3.0 * https://github.com/pgvector/pgvector-cpp * MIT License */ @@ -8,53 +8,30 @@ #include #include +#include #include #include -#if __cplusplus >= 202002L -#include -#endif - namespace pgvector { /// A vector. class Vector { public: - /// @private - // TODO remove in 0.3.0 - Vector() = default; + /// Creates a vector from a `std::vector`. + explicit Vector(const std::vector& value) : value_{value} {} - /// Creates a vector from a `std::vector`. - // TODO add explicit in 0.3.0 - Vector(const std::vector& value) { - value_ = value; - } + /// Creates a vector from a `std::vector`. + explicit Vector(std::vector&& value) : value_{std::move(value)} {} - /// Creates a vector from a `std::vector`. - // TODO add explicit in 0.3.0 - Vector(std::vector&& value) { - value_ = std::move(value); - } - - /// Creates a vector from an array. - Vector(const float* value, size_t n) { - value_ = std::vector{value, value + n}; - } - -#if __cplusplus >= 202002L /// Creates a vector from a span. - // TODO add explicit in 0.3.0 - Vector(std::span value) { - value_ = std::vector(value.begin(), value.end()); - } -#endif + explicit Vector(std::span value) : value_{std::vector(value.begin(), value.end())} {} /// Returns the number of dimensions. size_t dimensions() const { return value_.size(); } - /// Returns the vector as a `std::vector`. - operator const std::vector() const { + /// Returns the values. + const std::vector& values() const { return value_; } @@ -64,11 +41,14 @@ class Vector { friend std::ostream& operator<<(std::ostream& os, const Vector& value) { os << "["; - for (size_t i = 0; i < value.value_.size(); i++) { + // TODO use std::views::enumerate for C++23 + size_t i = 0; + for (auto v : value.value_) { if (i > 0) { os << ","; } - os << value.value_[i]; + os << v; + i++; } os << "]"; return os; diff --git a/test/halfvec_test.cpp b/test/halfvec_test.cpp index cfbc48e..94c68f6 100644 --- a/test/halfvec_test.cpp +++ b/test/halfvec_test.cpp @@ -1,28 +1,50 @@ -#include +#include +#include +#include -#include "../include/pgvector/halfvec.hpp" +#include -#if __cplusplus >= 202002L -#include -#endif +#include "helper.hpp" using pgvector::HalfVector; static void test_constructor_vector() { - auto vec = HalfVector({1, 2, 3}); - assert(vec.dimensions() == 3); + HalfVector vec{std::vector{1, 2, 3}}; + assert_equal(vec.dimensions(), 3u); } -#if __cplusplus >= 202002L static void test_constructor_span() { - auto vec = HalfVector(std::span({1, 2, 3})); - assert(vec.dimensions() == 3); + HalfVector vec{std::span{{1, 2, 3}}}; + assert_equal(vec.dimensions(), 3u); +} + +static void test_constructor_empty() { + HalfVector vec{std::vector{}}; + assert_equal(vec.dimensions(), 0u); +} + +static void test_dimensions() { + HalfVector vec{{1, 2, 3}}; + assert_equal(vec.dimensions(), 3u); +} + +static void test_values() { + HalfVector vec{{1, 2, 3}}; + assert_equal(vec.values() == std::vector{1, 2, 3}, true); +} + +static void test_string() { + HalfVector vec{{1, 2, 3}}; + std::ostringstream oss; + oss << vec; + assert_equal(oss.str(), "[1,2,3]"); } -#endif void test_halfvec() { test_constructor_vector(); -#if __cplusplus >= 202002L test_constructor_span(); -#endif + test_constructor_empty(); + test_dimensions(); + test_values(); + test_string(); } diff --git a/test/helper.hpp b/test/helper.hpp new file mode 100644 index 0000000..08d3e00 --- /dev/null +++ b/test/helper.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include +#include + +template +void assert_equal( + const T& left, + const U& right, + const std::source_location& loc = std::source_location::current() +) { + if (left != right) { + std::ostringstream message; + message << left << " != " << right; + message << " in " << loc.function_name() << " " << loc.file_name() << ":" << loc.line(); + throw std::runtime_error{message.str()}; + } +} + +template +void assert_exception( + const std::function& code, + std::optional message = std::nullopt +) { + std::optional exception; + try { + code(); + } catch (const T& e) { + exception = e; + } + assert_equal(exception.has_value(), true); + if (message) { + assert_equal(std::string_view{exception.value().what()}, message.value()); + } +} diff --git a/test/main.cpp b/test/main.cpp index 28d0239..c2bcb27 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -1,3 +1,6 @@ +// Test ODR +#include + void test_vector(); void test_halfvec(); void test_sparsevec(); diff --git a/test/pqxx_test.cpp b/test/pqxx_test.cpp index a60f0d1..f87e215 100644 --- a/test/pqxx_test.cpp +++ b/test/pqxx_test.cpp @@ -1,142 +1,387 @@ -#include +#include +#include #include +#include #include +#include +#include #include +#include #include -#include "../include/pgvector/pqxx.hpp" +#include "helper.hpp" -void setup(pqxx::connection &conn) { - pqxx::nontransaction tx(conn); +void setup(pqxx::connection& conn) { + pqxx::nontransaction tx{conn}; tx.exec("CREATE EXTENSION IF NOT EXISTS vector"); tx.exec("DROP TABLE IF EXISTS items"); tx.exec("CREATE TABLE items (id serial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3))"); } -void before_each(pqxx::connection &conn) { - pqxx::nontransaction tx(conn); +void before_each(pqxx::connection& conn) { + pqxx::nontransaction tx{conn}; tx.exec("TRUNCATE items"); } -void test_vector(pqxx::connection &conn) { +std::optional float_error([[maybe_unused]] std::string_view message) { +#ifdef __linux__ + return message; +#else + return std::nullopt; +#endif +} + +void test_vector(pqxx::connection& conn) { before_each(conn); - pqxx::nontransaction tx(conn); - auto embedding = pgvector::Vector({1, 2, 3}); - assert(embedding.dimensions() == 3); - float arr[] = {4, 5, 6}; - auto embedding2 = pgvector::Vector(arr, 3); + pqxx::nontransaction tx{conn}; + pgvector::Vector embedding{{1, 2, 3}}; + pgvector::Vector embedding2{{4, 5, 6}}; tx.exec("INSERT INTO items (embedding) VALUES ($1), ($2), ($3)", {embedding, embedding2, std::nullopt}); pqxx::result res = tx.exec("SELECT embedding FROM items ORDER BY embedding <-> $1", {embedding2}); - assert(res.size() == 3); - assert(res[0][0].as() == embedding2); - assert(res[1][0].as() == embedding); - assert(!res[2][0].as>().has_value()); + assert_equal(res.size(), 3); + assert_equal(res.at(0).at(0).as(), embedding2); + assert_equal(res.at(1).at(0).as(), embedding); + assert_equal(res.at(2).at(0).as>().has_value(), false); } -void test_halfvec(pqxx::connection &conn) { +void test_halfvec(pqxx::connection& conn) { before_each(conn); - pqxx::nontransaction tx(conn); - auto embedding = pgvector::HalfVector({1, 2, 3}); - assert(embedding.dimensions() == 3); - float arr[] = {4, 5, 6}; - auto embedding2 = pgvector::HalfVector(arr, 3); + pqxx::nontransaction tx{conn}; + pgvector::HalfVector embedding{{1, 2, 3}}; + pgvector::HalfVector embedding2{{4, 5, 6}}; tx.exec("INSERT INTO items (half_embedding) VALUES ($1), ($2), ($3)", {embedding, embedding2, std::nullopt}); pqxx::result res = tx.exec("SELECT half_embedding FROM items ORDER BY half_embedding <-> $1", {embedding2}); - assert(res.size() == 3); - assert(res[0][0].as() == embedding2); - assert(res[1][0].as() == embedding); - assert(!res[2][0].as>().has_value()); + assert_equal(res.size(), 3); + assert_equal(res.at(0).at(0).as(), embedding2); + assert_equal(res.at(1).at(0).as(), embedding); + assert_equal(res.at(2).at(0).as>().has_value(), false); } -void test_bit(pqxx::connection &conn) { +void test_bit(pqxx::connection& conn) { before_each(conn); - pqxx::nontransaction tx(conn); - auto embedding = "101"; - auto embedding2 = "111"; + pqxx::nontransaction tx{conn}; + std::string embedding{"101"}; + std::string embedding2{"111"}; tx.exec("INSERT INTO items (binary_embedding) VALUES ($1), ($2), ($3)", {embedding, embedding2, std::nullopt}); pqxx::result res = tx.exec("SELECT binary_embedding FROM items ORDER BY binary_embedding <~> $1", pqxx::params{embedding2}); - assert(res.size() == 3); - assert(res[0][0].as() == embedding2); - assert(res[1][0].as() == embedding); - assert(!res[2][0].as>().has_value()); + assert_equal(res.size(), 3); + assert_equal(res.at(0).at(0).as(), embedding2); + assert_equal(res.at(1).at(0).as(), embedding); + assert_equal(res.at(2).at(0).as>().has_value(), false); } -void test_sparsevec(pqxx::connection &conn) { +void test_sparsevec(pqxx::connection& conn) { before_each(conn); - pqxx::nontransaction tx(conn); - auto embedding = pgvector::SparseVector({1, 2, 3}); - auto embedding2 = pgvector::SparseVector({4, 5, 6}); + pqxx::nontransaction tx{conn}; + pgvector::SparseVector embedding{{1, 2, 3}}; + pgvector::SparseVector embedding2{{4, 5, 6}}; tx.exec("INSERT INTO items (sparse_embedding) VALUES ($1), ($2), ($3)", {embedding, embedding2, std::nullopt}); pqxx::result res = tx.exec("SELECT sparse_embedding FROM items ORDER BY sparse_embedding <-> $1", {embedding2}); - assert(res.size() == 3); - assert(res[0][0].as() == "{1:4,2:5,3:6}/3"); - assert(res[1][0].as() == "{1:1,2:2,3:3}/3"); - assert(!res[2][0].as>().has_value()); + assert_equal(res.size(), 3); + assert_equal(res.at(0).at(0).as(), embedding2); + assert_equal(res.at(1).at(0).as(), embedding); + assert_equal(res.at(2).at(0).as>().has_value(), false); } -void test_sparsevec_nnz(pqxx::connection &conn) { +void test_sparsevec_nnz(pqxx::connection& conn) { before_each(conn); - pqxx::nontransaction tx(conn); + pqxx::nontransaction tx{conn}; std::vector vec(16001, 1); - auto embedding = pgvector::SparseVector(vec); - try { + pgvector::SparseVector embedding{vec}; + assert_exception([&] { tx.exec("INSERT INTO items (sparse_embedding) VALUES ($1)", {embedding}); - assert(false); - } catch (const pqxx::conversion_overrun& e) { - assert(std::strcmp(e.what(), "sparsevec cannot have more than 16000 dimensions") == 0); - } + }, "sparsevec cannot have more than 16000 dimensions"); } -void test_stream(pqxx::connection &conn) { +void test_stream(pqxx::connection& conn) { before_each(conn); - pqxx::nontransaction tx(conn); - auto embedding = pgvector::Vector({1, 2, 3}); + pqxx::nontransaction tx{conn}; + pgvector::Vector embedding({1, 2, 3}); tx.exec("INSERT INTO items (embedding) VALUES ($1)", {embedding}); int count = 0; - for (auto [id, embedding] : tx.stream("SELECT id, embedding FROM items WHERE embedding IS NOT NULL")) { - assert(embedding.dimensions() == 3); + auto stream = tx.stream("SELECT id, embedding FROM items WHERE embedding IS NOT NULL"); + for (const auto& [id, embedding2] : stream) { + assert_equal(embedding2, embedding); count++; } - assert(count == 1); + assert_equal(count, 1); } -void test_stream_to(pqxx::connection &conn) { +void test_stream_to(pqxx::connection& conn) { before_each(conn); - pqxx::nontransaction tx(conn); - auto stream = pqxx::stream_to::table(tx, {"items"}, {"embedding"}); - stream << pgvector::Vector({1, 2, 3}); - stream << pgvector::Vector({4, 5, 6}); + pqxx::nontransaction tx{conn}; + pqxx::stream_to stream = pqxx::stream_to::table(tx, {"items"}, {"embedding"}); + stream.write_values(pgvector::Vector{{1, 2, 3}}); + stream.write_values(pgvector::Vector{{4, 5, 6}}); stream.complete(); pqxx::result res = tx.exec("SELECT embedding FROM items ORDER BY id"); - assert(res[0][0].as() == "[1,2,3]"); - assert(res[1][0].as() == "[4,5,6]"); + assert_equal(res.at(0).at(0).as(), "[1,2,3]"); + assert_equal(res.at(1).at(0).as(), "[4,5,6]"); } -void test_precision(pqxx::connection &conn) { +void test_precision(pqxx::connection& conn) { before_each(conn); - pqxx::nontransaction tx(conn); - auto embedding = pgvector::Vector({1.23456789, 0, 0}); + pqxx::nontransaction tx{conn}; + pgvector::Vector embedding{{1.23456789f, 0, 0}}; tx.exec("INSERT INTO items (embedding) VALUES ($1)", {embedding}); tx.exec("SET extra_float_digits = 3"); pqxx::result res = tx.exec("SELECT embedding FROM items ORDER BY id DESC LIMIT 1"); - assert(res[0][0].as() == embedding); + assert_equal(res.at(0).at(0).as(), embedding); +} + +void test_vector_to_string() { + assert_equal(pqxx::to_string(pgvector::Vector{{1, 2, 3}}), "[1,2,3]"); + assert_equal(pqxx::to_string(pgvector::Vector{{-1.234567890123f}}), "[-1.2345679]"); + + assert_exception([] { + pqxx::to_string(pgvector::Vector{std::vector(16001)}); + }, "vector cannot have more than 16000 dimensions"); +} + +void test_vector_from_string() { + assert_equal(pqxx::from_string("[1,2,3]"), pgvector::Vector{{1, 2, 3}}); + + // not valid, but test current behavior + assert_equal(pqxx::from_string("[]"), pgvector::Vector{std::vector{}}); + + assert_exception([] { + auto _ = pqxx::from_string(""); + }, "Malformed vector literal"); + + assert_exception([] { + auto _ = pqxx::from_string("["); + }, "Malformed vector literal"); + + assert_exception([] { + auto _ = pqxx::from_string("[hello]"); + }, float_error("Could not convert 'hello' to float: Invalid argument.")); + + assert_exception([] { + auto _ = pqxx::from_string("[4e38]"); + }, float_error("Could not convert '4e38' to float: Value out of range.")); + + assert_exception([] { + auto _ = pqxx::from_string("[,]"); + }, float_error("Could not convert '' to float: Invalid argument.")); + + assert_exception([] { + auto _ = pqxx::from_string("[1,]"); + }, float_error("Could not convert '' to float: Invalid argument.")); +} + +void test_halfvec_to_string() { + assert_equal(pqxx::to_string(pgvector::HalfVector{{1, 2, 3}}), "[1,2,3]"); +#if __STDCPP_FLOAT16_T__ || defined(__FLT16_MAX__) + assert_equal(pqxx::to_string(pgvector::HalfVector{{static_cast(-1.234567890123f)}}), "[-1.234375]"); +#else + assert_equal(pqxx::to_string(pgvector::HalfVector{{-1.234567890123f}}), "[-1.2345679]"); +#endif + + assert_exception([] { + pqxx::to_string(pgvector::HalfVector{std::vector(16001)}); + }, "halfvec cannot have more than 16000 dimensions"); +} + +void test_halfvec_from_string() { + assert_equal(pqxx::from_string("[1,2,3]"), pgvector::HalfVector{{1, 2, 3}}); + + // not valid, but test current behavior + assert_equal(pqxx::from_string("[]"), pgvector::HalfVector{std::vector{}}); + + assert_exception([] { + auto _ = pqxx::from_string(""); + }, "Malformed halfvec literal"); + + assert_exception([] { + auto _ = pqxx::from_string("["); + }, "Malformed halfvec literal"); + + assert_exception([] { + auto _ = pqxx::from_string("[hello]"); + }, float_error("Could not convert 'hello' to float: Invalid argument.")); + + assert_exception([] { + auto _ = pqxx::from_string("[4e38]"); + }, float_error("Could not convert '4e38' to float: Value out of range.")); + + assert_exception([] { + auto _ = pqxx::from_string("[,]"); + }, float_error("Could not convert '' to float: Invalid argument.")); + + assert_exception([] { + auto _ = pqxx::from_string("[1,]"); + }, float_error("Could not convert '' to float: Invalid argument.")); +} + +void test_sparsevec_to_string() { + assert_equal(pqxx::to_string(pgvector::SparseVector{{1, 0, 2, 0, 3, 0}}), "{1:1,3:2,5:3}/6"); + std::unordered_map map{{999999999, -1.234567890123f}}; + assert_equal(pqxx::to_string(pgvector::SparseVector{map, 1000000000}), "{1000000000:-1.2345679}/1000000000"); + + assert_exception([] { + pqxx::to_string(pgvector::SparseVector{std::vector(16001, 1)}); + }, "sparsevec cannot have more than 16000 dimensions"); +} + +void test_sparsevec_from_string() { + assert_equal(pqxx::from_string("{1:1,3:2,5:3}/6"), pgvector::SparseVector{{1, 0, 2, 0, 3, 0}}); + assert_equal(pqxx::from_string("{}/6"), pgvector::SparseVector{{0, 0, 0, 0, 0, 0}}); + + // not valid, but test current behavior + assert_equal(pqxx::from_string("{}/0"), pgvector::SparseVector{std::vector{}}); + assert_equal(pqxx::from_string("{1:2,1:3}/1"), pgvector::SparseVector{{2}}); + + auto vec = pqxx::from_string("{2:4,1:3,3:0}/3"); + assert_equal(vec.indices() == std::vector{0, 1}, true); + assert_equal(vec.values() == std::vector{3, 4}, true); + + assert_exception([] { + auto _ = pqxx::from_string(""); + }, "Malformed sparsevec literal"); + + assert_exception([] { + auto _ = pqxx::from_string("{"); + }, "Malformed sparsevec literal"); + + assert_exception([] { + auto _ = pqxx::from_string("{ }/"); + }, "Could not convert '' to int: Invalid argument."); + + assert_exception([] { + auto _ = pqxx::from_string("{}/-1"); + }, "sparsevec cannot have negative dimensions"); + + assert_exception([] { + auto _ = pqxx::from_string("{:}/1"); + }, "Could not convert '' to int: Invalid argument."); + + assert_exception([] { + auto _ = pqxx::from_string("{,}/1"); + }, "Malformed sparsevec literal"); + + assert_exception([] { + auto _ = pqxx::from_string("{0:1}/1"); + }, "sparsevec index out of bounds"); + + assert_exception([] { + auto _ = pqxx::from_string("{-2147483648:1}/1"); + }, "sparsevec index out of bounds"); + + assert_exception([] { + auto _ = pqxx::from_string("{2:1}/1"); + }, "sparsevec index out of bounds"); + + assert_exception([] { + auto _ = pqxx::from_string("{1:1}/0"); + }, "sparsevec index out of bounds"); + + assert_exception([] { + auto _ = pqxx::from_string("{1:4e38}/1"); + }, float_error("Could not convert '4e38' to float: Value out of range.")); + + assert_exception([] { + auto _ = pqxx::from_string("{a:1}/1"); + }, "Could not convert 'a' to int: Invalid argument."); + + assert_exception([] { + auto _ = pqxx::from_string("{1:a}/1"); + }, float_error("Could not convert 'a' to float: Invalid argument.")); + + assert_exception([] { + auto _ = pqxx::from_string("{}/a"); + }, "Could not convert 'a' to int: Invalid argument."); +} + +void test_vector_to_buf() { + std::array buf{}; + assert_equal(pqxx::to_buf(std::span{buf}, pgvector::Vector{{1, 2, 3}}), "[1,2,3]"); + + assert_exception([] { + return pqxx::to_buf(std::span{}, pgvector::Vector{{1, 2, 3}}); + }, "Not enough space in buffer for vector"); +} + +void test_vector_into_buf() { + std::array buf{}; + size_t size = pqxx::into_buf(std::span{buf}, pgvector::Vector{{1, 2, 3}}); + assert_equal(size, 7u); + assert_equal(std::string_view{buf.data(), size}, "[1,2,3]"); + + assert_exception([] { + return pqxx::into_buf(std::span{}, pgvector::Vector{{1, 2, 3}}); + }, "Not enough space in buffer for vector"); +} + +void test_halfvec_to_buf() { + std::array buf{}; + assert_equal(pqxx::to_buf(std::span{buf}, pgvector::HalfVector{{1, 2, 3}}), "[1,2,3]"); + + assert_exception([] { + return pqxx::to_buf(std::span{}, pgvector::HalfVector{{1, 2, 3}}); + }, "Not enough space in buffer for halfvec"); +} + +void test_halfvec_into_buf() { + std::array buf{}; + size_t size = pqxx::into_buf(std::span{buf}, pgvector::HalfVector{{1, 2, 3}}); + assert_equal(size, 7u); + assert_equal(std::string_view{buf.data(), size}, "[1,2,3]"); + + assert_exception([] { + return pqxx::into_buf(std::span{}, pgvector::HalfVector{{1, 2, 3}}); + }, "Not enough space in buffer for halfvec"); +} + +void test_sparsevec_to_buf() { + std::array buf{}; + assert_equal(pqxx::to_buf(std::span{buf}, pgvector::SparseVector{{1, 2, 3}}), "{1:1,2:2,3:3}/3"); + + int max = std::numeric_limits::max(); + assert_equal(pqxx::to_buf(std::span{buf}, pgvector::SparseVector{{{max - 1, 1}}, max}), "{2147483647:1}/2147483647"); + + assert_exception([] { + return pqxx::to_buf(std::span{}, pgvector::SparseVector{{1, 2, 3}}); + }, "Not enough space in buffer for sparsevec"); +} + +void test_sparsevec_into_buf() { + std::array buf{}; + size_t size = pqxx::into_buf(std::span{buf}, pgvector::SparseVector{{1, 2, 3}}); + assert_equal(size, 15u); + assert_equal(std::string_view{buf.data(), size}, "{1:1,2:2,3:3}/3"); + + assert_exception([] { + return pqxx::into_buf(std::span{}, pgvector::SparseVector{{1, 2, 3}}); + }, "Not enough space in buffer for sparsevec"); +} + +void test_vector_size_buffer() { + assert_equal(pqxx::size_buffer(pgvector::Vector{{1, 2, 3}}), 55u); +} + +void test_halfvec_size_buffer() { + assert_equal(pqxx::size_buffer(pgvector::HalfVector{{1, 2, 3}}), 55u); +} + +void test_sparsevec_size_buffer() { + assert_equal(pqxx::size_buffer(pgvector::SparseVector{{1, 2, 3}}), 103u); } void test_pqxx() { - pqxx::connection conn("dbname=pgvector_cpp_test"); + pqxx::connection conn{"dbname=pgvector_cpp_test"}; setup(conn); test_vector(conn); @@ -147,4 +392,22 @@ void test_pqxx() { test_stream(conn); test_stream_to(conn); test_precision(conn); + + test_vector_to_string(); + test_vector_from_string(); + test_halfvec_to_string(); + test_halfvec_from_string(); + test_sparsevec_to_string(); + test_sparsevec_from_string(); + + test_vector_to_buf(); + test_vector_into_buf(); + test_halfvec_to_buf(); + test_halfvec_into_buf(); + test_sparsevec_to_buf(); + test_sparsevec_into_buf(); + + test_vector_size_buffer(); + test_halfvec_size_buffer(); + test_sparsevec_size_buffer(); } diff --git a/test/sparsevec_test.cpp b/test/sparsevec_test.cpp index 9b9dff2..b2dd65e 100644 --- a/test/sparsevec_test.cpp +++ b/test/sparsevec_test.cpp @@ -1,40 +1,84 @@ -#include +#include +#include +#include #include +#include -#include "../include/pgvector/sparsevec.hpp" +#include -#if __cplusplus >= 202002L -#include -#endif +#include "helper.hpp" using pgvector::SparseVector; static void test_constructor_vector() { - auto vec = SparseVector({1, 0, 2, 0, 3, 0}); - assert(vec.dimensions() == 6); - assert(vec.indices() == (std::vector{0, 2, 4})); - assert(vec.values() == (std::vector{1, 2, 3})); + SparseVector vec{std::vector{1, 0, 2, 0, 3, 0}}; + assert_equal(vec.dimensions(), 6); + assert_equal(vec.indices() == std::vector{0, 2, 4}, true); + assert_equal(vec.values() == std::vector{1, 2, 3}, true); } -#if __cplusplus >= 202002L static void test_constructor_span() { - auto vec = SparseVector(std::span({1, 0, 2, 0, 3, 0})); - assert(vec.dimensions() == 6); + SparseVector vec{std::span{{1, 0, 2, 0, 3, 0}}}; + assert_equal(vec.dimensions(), 6); } -#endif static void test_constructor_map() { - std::unordered_map map = {{2, 2}, {4, 3}, {3, 0}, {0, 1}}; - auto vec = SparseVector(map, 6); - assert(vec.dimensions() == 6); - assert(vec.indices() == (std::vector{0, 2, 4})); - assert(vec.values() == (std::vector{1, 2, 3})); + std::unordered_map map{{2, 2}, {4, 3}, {3, 0}, {0, 1}}; + SparseVector vec{map, 6}; + assert_equal(vec.dimensions(), 6); + assert_equal(vec.indices() == std::vector{0, 2, 4}, true); + assert_equal(vec.values() == std::vector{1, 2, 3}, true); + + assert_exception([&]{ + SparseVector{map, -1}; + }, "sparsevec cannot have negative dimensions"); + + assert_exception([&]{ + SparseVector{map, 4}; + }, "sparsevec index out of bounds"); + + assert_exception([]{ + SparseVector{{{0, 1}}, 0}; + }, "sparsevec index out of bounds"); +} + +static void test_constructor_empty() { + SparseVector vec{std::vector{}}; + assert_equal(vec.dimensions(), 0); + + SparseVector vec2{{}, 0}; + assert_equal(vec2.dimensions(), 0); +} + +static void test_dimensions() { + SparseVector vec{std::vector{1, 0, 2, 0, 3, 0}}; + assert_equal(vec.dimensions(), 6); +} + +static void test_indices() { + SparseVector vec{std::vector{1, 0, 2, 0, 3, 0}}; + assert_equal(vec.indices() == std::vector{0, 2, 4}, true); +} + +static void test_values() { + SparseVector vec{std::vector{1, 0, 2, 0, 3, 0}}; + assert_equal(vec.values() == std::vector{1, 2, 3}, true); +} + +static void test_string() { + SparseVector vec{std::vector{1, 0, 2, 0, 3, 0}}; + std::ostringstream oss; + oss << vec; + assert_equal(oss.str(), "{1:1,3:2,5:3}/6"); } void test_sparsevec() { test_constructor_vector(); -#if __cplusplus >= 202002L test_constructor_span(); -#endif + test_constructor_empty(); test_constructor_map(); + test_dimensions(); + test_indices(); + test_values(); + test_string(); } diff --git a/test/vector_test.cpp b/test/vector_test.cpp index 3580a8d..44b1d7a 100644 --- a/test/vector_test.cpp +++ b/test/vector_test.cpp @@ -1,28 +1,50 @@ -#include +#include +#include +#include -#include "../include/pgvector/vector.hpp" +#include -#if __cplusplus >= 202002L -#include -#endif +#include "helper.hpp" using pgvector::Vector; static void test_constructor_vector() { - auto vec = Vector({1, 2, 3}); - assert(vec.dimensions() == 3); + Vector vec{std::vector{1, 2, 3}}; + assert_equal(vec.dimensions(), 3u); } -#if __cplusplus >= 202002L static void test_constructor_span() { - auto vec = Vector(std::span({1, 2, 3})); - assert(vec.dimensions() == 3); + Vector vec{std::span{{1, 2, 3}}}; + assert_equal(vec.dimensions(), 3u); +} + +static void test_constructor_empty() { + Vector vec{std::vector{}}; + assert_equal(vec.dimensions(), 0u); +} + +static void test_dimensions() { + Vector vec{{1, 2, 3}}; + assert_equal(vec.dimensions(), 3u); +} + +static void test_values() { + Vector vec{{1, 2, 3}}; + assert_equal(vec.values() == std::vector{1, 2, 3}, true); +} + +static void test_string() { + Vector vec{{1, 2, 3}}; + std::ostringstream oss; + oss << vec; + assert_equal(oss.str(), "[1,2,3]"); } -#endif void test_vector() { test_constructor_vector(); -#if __cplusplus >= 202002L test_constructor_span(); -#endif + test_constructor_empty(); + test_dimensions(); + test_values(); + test_string(); }