diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d4f656a..f4798bf 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -78,6 +78,7 @@ jobs: nmake /NOLOGO /F Makefile.win uninstall shell: cmd i386: + if: ${{ !startsWith(github.ref_name, 'mac') && !startsWith(github.ref_name, 'windows') }} runs-on: ubuntu-latest container: image: debian:11 diff --git a/CHANGELOG.md b/CHANGELOG.md index 7944166..07040d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,19 @@ +## 0.5.1 (2023-10-10) + +- Improved performance of HNSW index builds +- Added check for MVCC-compliant snapshot for index scans + +## 0.5.0 (2023-08-28) + +- Added HNSW index type +- Added support for parallel index builds for IVFFlat +- Added `l1_distance` function +- Added element-wise multiplication for vectors +- Added `sum` aggregate +- Improved performance of distance functions +- Fixed out of range results for cosine distance +- Fixed results for NULL and NaN distances for IVFFlat + ## 0.4.4 (2023-06-12) - Improved error message for malformed vector literal diff --git a/Dockerfile b/Dockerfile index 6fe3099..f3ded45 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,6 +5,7 @@ ARG PG_MAJOR COPY . /tmp/pgvector RUN apt-get update && \ + apt-mark hold locales && \ apt-get install -y --no-install-recommends build-essential postgresql-server-dev-$PG_MAJOR && \ cd /tmp/pgvector && \ make clean && \ @@ -15,4 +16,5 @@ RUN apt-get update && \ rm -r /tmp/pgvector && \ apt-get remove -y build-essential postgresql-server-dev-$PG_MAJOR && \ apt-get autoremove -y && \ + apt-mark unhold locales && \ rm -rf /var/lib/apt/lists/* diff --git a/META.json b/META.json index a71d810..38d3919 100644 --- a/META.json +++ b/META.json @@ -2,7 +2,7 @@ "name": "vector", "abstract": "Open-source vector similarity search for Postgres", "description": "Supports L2 distance, inner product, and cosine distance", - "version": "0.4.4", + "version": "0.5.1", "maintainer": [ "Andrew Kane " ], @@ -20,7 +20,7 @@ "vector": { "file": "sql/vector.sql", "docfile": "README.md", - "version": "0.4.4", + "version": "0.5.1", "abstract": "Open-source vector similarity search for Postgres" } }, diff --git a/Makefile b/Makefile index ff26f56..f6c1f20 100644 --- a/Makefile +++ b/Makefile @@ -1,13 +1,14 @@ EXTENSION = vector -EXTVERSION = 0.4.4 +EXTVERSION = 0.5.1 MODULE_big = vector DATA = $(wildcard sql/*--*.sql) -OBJS = src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o +OBJS = src/hnsw.o src/hnswbuild.o src/hnswinsert.o src/hnswscan.o src/hnswutils.o src/hnswvacuum.o src/ivfbuild.o src/ivfflat.o src/ivfinsert.o src/ivfkmeans.o src/ivfscan.o src/ivfutils.o src/ivfvacuum.o src/vector.o +HEADERS = src/vector.h TESTS = $(wildcard test/sql/*.sql) REGRESS = $(patsubst test/sql/%.sql,%,$(TESTS)) -REGRESS_OPTS = --inputdir=test --load-extension=vector +REGRESS_OPTS = --inputdir=test --load-extension=$(EXTENSION) OPTFLAGS = -march=native diff --git a/Makefile.win b/Makefile.win index 8ceb572..fbe5768 100644 --- a/Makefile.win +++ b/Makefile.win @@ -1,10 +1,11 @@ EXTENSION = vector -EXTVERSION = 0.4.4 +EXTVERSION = 0.5.1 -OBJS = src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj +OBJS = src\hnsw.obj src\hnswbuild.obj src\hnswinsert.obj src\hnswscan.obj src\hnswutils.obj src\hnswvacuum.obj src\ivfbuild.obj src\ivfflat.obj src\ivfinsert.obj src\ivfkmeans.obj src\ivfscan.obj src\ivfutils.obj src\ivfvacuum.obj src\vector.obj +HEADERS = src\vector.h REGRESS = btree cast copy functions input ivfflat_cosine ivfflat_ip ivfflat_l2 ivfflat_options ivfflat_unlogged -REGRESS_OPTS = --inputdir=test --load-extension=vector +REGRESS_OPTS = --inputdir=test --load-extension=$(EXTENSION) # For /arch flags # https://learn.microsoft.com/en-us/cpp/build/reference/arch-minimum-cpu-architecture @@ -54,6 +55,8 @@ install: copy $(SHLIB) "$(PKGLIBDIR)" copy $(EXTENSION).control "$(SHAREDIR)\extension" copy sql\$(EXTENSION)--*.sql "$(SHAREDIR)\extension" + mkdir "$(INCLUDEDIR_SERVER)\extension\$(EXTENSION)" + copy $(HEADERS) "$(INCLUDEDIR_SERVER)\extension\$(EXTENSION)" installcheck: "$(BINDIR)\pg_regress" --bindir="$(BINDIR)" $(REGRESS_OPTS) $(REGRESS) @@ -61,7 +64,9 @@ installcheck: uninstall: del /f "$(PKGLIBDIR)\$(SHLIB)" del /f "$(SHAREDIR)\extension\$(EXTENSION).control" - del /f "$(SHAREDIR)\extension\vector--*.sql" + del /f "$(SHAREDIR)\extension\$(EXTENSION)--*.sql" + del /f "$(INCLUDEDIR_SERVER)\extension\$(EXTENSION)\*.h" + rmdir "$(INCLUDEDIR_SERVER)\extension\$(EXTENSION)" clean: del /f $(SHLIB) $(EXTENSION).lib $(EXTENSION).exp diff --git a/README.md b/README.md index 68bcfd6..33c9945 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Open-source vector similarity search for Postgres -Store all of your application data in one place. Supports: +Store your vectors with the rest of your data. Supports: - exact and approximate nearest neighbor search - L2 distance, inner product, and cosine distance @@ -18,7 +18,7 @@ Compile and install the extension (supports Postgres 11+) ```sh cd /tmp -git clone --branch v0.4.4 https://github.com/pgvector/pgvector.git +git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git cd pgvector make make install # may need sudo @@ -157,7 +157,16 @@ SELECT category_id, AVG(embedding) FROM items GROUP BY category_id; By default, pgvector performs exact nearest neighbor search, which provides perfect recall. -You can add an index to use approximate nearest neighbor search, which trades some recall for performance. Unlike typical indexes, you will see different results for queries after adding an approximate index. +You can add an index to use approximate nearest neighbor search, which trades some recall for speed. Unlike typical indexes, you will see different results for queries after adding an approximate index. + +Supported index types are: + +- [IVFFlat](#ivfflat) +- [HNSW](#hnsw) - added in 0.5.0 + +## IVFFlat + +An IVFFlat index divides vectors into lists, and then searches a subset of those lists that are closest to the query vector. It has faster build times and uses less memory than HNSW, but has lower query performance (in terms of speed-recall tradeoff). Three keys to achieving good recall are: @@ -206,7 +215,63 @@ SELECT ... COMMIT; ``` -### Indexing Progress +## HNSW + +An HNSW index creates a multilayer graph. It has slower build times and uses more memory than IVFFlat, but has better query performance (in terms of speed-recall tradeoff). There’s no training step like IVFFlat, so the index can be created without any data in the table. + +Add an index for each distance function you want to use. + +L2 distance + +```sql +CREATE INDEX ON items USING hnsw (embedding vector_l2_ops); +``` + +Inner product + +```sql +CREATE INDEX ON items USING hnsw (embedding vector_ip_ops); +``` + +Cosine distance + +```sql +CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops); +``` + +Vectors with up to 2,000 dimensions can be indexed. + +### Index Options + +Specify HNSW parameters + +- `m` - the max number of connections per layer (16 by default) +- `ef_construction` - the size of the dynamic candidate list for constructing the graph (64 by default) + +```sql +CREATE INDEX ON items USING hnsw (embedding vector_l2_ops) WITH (m = 16, ef_construction = 64); +``` + +### Query Options + +Specify the size of the dynamic candidate list for search (40 by default) + +```sql +SET hnsw.ef_search = 100; +``` + +A higher value provides better recall at the cost of speed. + +Use `SET LOCAL` inside a transaction to set it for a single query + +```sql +BEGIN; +SET LOCAL hnsw.ef_search = 100; +SELECT ... +COMMIT; +``` + +## Indexing Progress Check [indexing progress](https://www.postgresql.org/docs/current/progress-reporting.html#CREATE-INDEX-PROGRESS-REPORTING) with Postgres 12+ @@ -217,13 +282,13 @@ SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index; The phases are: 1. `initializing` -2. `performing k-means` -3. `sorting tuples` +2. `performing k-means` - IVFFlat only +3. `assigning tuples` - IVFFlat only 4. `loading tuples` Note: `tuples_done` and `tuples_total` are only populated during the `loading tuples` phase -### Filtering +## Filtering There are a few ways to index nearest neighbor queries with a `WHERE` clause @@ -283,7 +348,7 @@ SELECT * FROM items ORDER BY embedding <#> '[3,1,2]' LIMIT 5; ### Approximate Search -To speed up queries with an index, increase the number of inverted lists (at the expense of recall). +To speed up queries with an IVFFlat index, increase the number of inverted lists (at the expense of recall). ```sql CREATE INDEX ON items USING ivfflat (embedding vector_l2_ops) WITH (lists = 1000); @@ -298,6 +363,7 @@ Language | Libraries / Examples C++ | [pgvector-cpp](https://github.com/pgvector/pgvector-cpp) C# | [pgvector-dotnet](https://github.com/pgvector/pgvector-dotnet) Crystal | [pgvector-crystal](https://github.com/pgvector/pgvector-crystal) +Dart | [pgvector-dart](https://github.com/pgvector/pgvector-dart) Elixir | [pgvector-elixir](https://github.com/pgvector/pgvector-elixir) Go | [pgvector-go](https://github.com/pgvector/pgvector-go) Haskell | [pgvector-haskell](https://github.com/pgvector/pgvector-haskell) @@ -327,10 +393,45 @@ Yes, pgvector uses the write-ahead log (WAL), which allows for replication and p You’ll need to use [dimensionality reduction](https://en.wikipedia.org/wiki/Dimensionality_reduction) at the moment. -#### Why am I seeing less results after adding an index? +## Troubleshooting + +#### Why isn’t a query using an index? + +The cost estimation in pgvector < 0.4.3 does not always work well with the planner. You can encourage the planner to use an index for a query with: + +```sql +BEGIN; +SET LOCAL enable_seqscan = off; +SELECT ... +COMMIT; +``` + +#### Why isn’t a query using a parallel table scan? + +The planner doesn’t consider [out-of-line storage](https://www.postgresql.org/docs/current/storage-toast.html) in cost estimates, which can make a serial scan look cheaper. You can reduce the cost of a parallel scan for a query with: + +```sql +BEGIN; +SET LOCAL min_parallel_table_scan_size = 1; +SET LOCAL parallel_setup_cost = 1; +SELECT ... +COMMIT; +``` + +or choose to store vectors inline: + +```sql +ALTER TABLE items ALTER COLUMN embedding SET STORAGE PLAIN; +``` + +#### Why are there less results for a query after adding an IVFFlat index? The index was likely created with too little data for the number of lists. Drop the index until the table has more data. +```sql +DROP INDEX index_name; +``` + ## Reference ### Vector Type @@ -339,29 +440,32 @@ Each vector takes `4 * dimensions + 8` bytes of storage. Each element is a singl ### Vector Operators -Operator | Description ---- | --- -\+ | element-wise addition -\- | element-wise subtraction -<-> | Euclidean distance -<#> | negative inner product -<=> | cosine distance +Operator | Description | Added +--- | --- | --- +\+ | element-wise addition | +\- | element-wise subtraction | +\* | element-wise multiplication | 0.5.0 +<-> | Euclidean distance | +<#> | negative inner product | +<=> | cosine distance | ### Vector Functions -Function | Description ---- | --- -cosine_distance(vector, vector) → double precision | cosine distance -inner_product(vector, vector) → double precision | inner product -l2_distance(vector, vector) → double precision | Euclidean distance -vector_dims(vector) → integer | number of dimensions -vector_norm(vector) → double precision | Euclidean norm +Function | Description | Added +--- | --- | --- +cosine_distance(vector, vector) → double precision | cosine distance | +inner_product(vector, vector) → double precision | inner product | +l2_distance(vector, vector) → double precision | Euclidean distance | +l1_distance(vector, vector) → double precision | taxicab distance | 0.5.0 +vector_dims(vector) → integer | number of dimensions | +vector_norm(vector) → double precision | Euclidean norm | ### Aggregate Functions -Function | Description ---- | --- -avg(vector) → vector | arithmetic mean +Function | Description | Added +--- | --- | --- +avg(vector) → vector | average | +sum(vector) → vector | sum | 0.5.0 ## Installation Notes @@ -393,11 +497,19 @@ Note: Replace `15` with your Postgres server version ### Windows -Support for Windows is currently experimental. Use `nmake` to build: +Support for Windows is currently experimental. Ensure [C++ support in Visual Studio](https://learn.microsoft.com/en-us/cpp/build/building-on-the-command-line?view=msvc-170#download-and-install-the-tools) is installed, and run: + +```cmd +call "C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" +``` + +Note: The exact path will vary depending on your Visual Studio version and edition + +Then use `nmake` to build: ```cmd set "PGROOT=C:\Program Files\PostgreSQL\15" -git clone --branch v0.4.4 https://github.com/pgvector/pgvector.git +git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git cd pgvector nmake /F Makefile.win nmake /F Makefile.win install @@ -418,7 +530,7 @@ This adds pgvector to the [Postgres image](https://hub.docker.com/_/postgres) (r You can also build the image manually: ```sh -git clone --branch v0.4.4 https://github.com/pgvector/pgvector.git +git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git cd pgvector docker build --build-arg PG_MAJOR=15 -t myuser/pgvector . ``` @@ -481,19 +593,20 @@ Download the [latest release](https://postgresapp.com/downloads.html) with Postg pgvector is available on [these providers](https://github.com/pgvector/pgvector/issues/54). -To request a new extension on other providers: - -- DigitalOcean Managed Databases - vote or comment on [this page](https://ideas.digitalocean.com/managed-database/p/pgvector-extension-for-postgresql) -- Heroku Postgres - vote or comment on [this page](https://github.com/heroku/roadmap/issues/156) - ## Upgrading -Install the latest version and run: +Install the latest version. Then in each database you want to upgrade, run: ```sql ALTER EXTENSION vector UPDATE; ``` +You can check the version in the current database with: + +```sql +SELECT extversion FROM pg_extension WHERE extname = 'vector'; +``` + ## Upgrade Notes ### 0.4.0 @@ -526,9 +639,10 @@ Thanks to: - [PASE: PostgreSQL Ultra-High-Dimensional Approximate Nearest Neighbor Search Extension](https://dl.acm.org/doi/pdf/10.1145/3318464.3386131) - [Faiss: A Library for Efficient Similarity Search and Clustering of Dense Vectors](https://github.com/facebookresearch/faiss) -- [Using the Triangle Inequality to Accelerate k-means](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf) +- [Using the Triangle Inequality to Accelerate k-means](https://cdn.aaai.org/ICML/2003/ICML03-022.pdf) - [k-means++: The Advantage of Careful Seeding](https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf) - [Concept Decompositions for Large Sparse Text Data using Clustering](https://www.cs.utexas.edu/users/inderjit/public_papers/concept_mlj.pdf) +- [Efficient and Robust Approximate Nearest Neighbor Search using Hierarchical Navigable Small World Graphs](https://arxiv.org/ftp/arxiv/papers/1603/1603.09320.pdf) ## History @@ -576,4 +690,4 @@ Resources for contributors - [Extension Building Infrastructure](https://www.postgresql.org/docs/current/extend-pgxs.html) - [Index Access Method Interface Definition](https://www.postgresql.org/docs/current/indexam.html) -- [Generic WAL Records](https://www.postgresql.org/docs/13/generic-wal.html) +- [Generic WAL Records](https://www.postgresql.org/docs/current/generic-wal.html) diff --git a/sql/vector--0.4.4--0.5.0.sql b/sql/vector--0.4.4--0.5.0.sql new file mode 100644 index 0000000..48572bf --- /dev/null +++ b/sql/vector--0.4.4--0.5.0.sql @@ -0,0 +1,43 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.5.0'" to load this file. \quit + +CREATE FUNCTION l1_distance(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE FUNCTION vector_mul(vector, vector) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + +CREATE OPERATOR * ( + LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_mul, + COMMUTATOR = * +); + +CREATE AGGREGATE sum(vector) ( + SFUNC = vector_add, + STYPE = vector, + COMBINEFUNC = vector_add, + PARALLEL = SAFE +); + +CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler + AS 'MODULE_PATHNAME' LANGUAGE C; + +CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler; + +COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method'; + +CREATE OPERATOR CLASS vector_l2_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_l2_squared_distance(vector, vector); + +CREATE OPERATOR CLASS vector_ip_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector); + +CREATE OPERATOR CLASS vector_cosine_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector), + FUNCTION 2 vector_norm(vector); diff --git a/sql/vector--0.5.0--0.5.1.sql b/sql/vector--0.5.0--0.5.1.sql new file mode 100644 index 0000000..54e09c5 --- /dev/null +++ b/sql/vector--0.5.0--0.5.1.sql @@ -0,0 +1,2 @@ +-- complain if script is sourced in psql, rather than via CREATE EXTENSION +\echo Use "ALTER EXTENSION vector UPDATE TO '0.5.1'" to load this file. \quit diff --git a/sql/vector.sql b/sql/vector.sql index 6188e2e..137931f 100644 --- a/sql/vector.sql +++ b/sql/vector.sql @@ -40,6 +40,9 @@ CREATE FUNCTION inner_product(vector, vector) RETURNS float8 CREATE FUNCTION cosine_distance(vector, vector) RETURNS float8 AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION l1_distance(vector, vector) RETURNS float8 + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + CREATE FUNCTION vector_dims(vector) RETURNS integer AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; @@ -52,6 +55,9 @@ CREATE FUNCTION vector_add(vector, vector) RETURNS vector CREATE FUNCTION vector_sub(vector, vector) RETURNS vector AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; +CREATE FUNCTION vector_mul(vector, vector) RETURNS vector + AS 'MODULE_PATHNAME' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE; + -- private functions CREATE FUNCTION vector_lt(vector, vector) RETURNS bool @@ -104,6 +110,13 @@ CREATE AGGREGATE avg(vector) ( PARALLEL = SAFE ); +CREATE AGGREGATE sum(vector) ( + SFUNC = vector_add, + STYPE = vector, + COMBINEFUNC = vector_add, + PARALLEL = SAFE +); + -- cast functions CREATE FUNCTION vector(vector, integer, boolean) RETURNS vector @@ -171,6 +184,11 @@ CREATE OPERATOR - ( COMMUTATOR = - ); +CREATE OPERATOR * ( + LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_mul, + COMMUTATOR = * +); + CREATE OPERATOR < ( LEFTARG = vector, RIGHTARG = vector, PROCEDURE = vector_lt, COMMUTATOR = > , NEGATOR = >= , @@ -209,7 +227,7 @@ CREATE OPERATOR > ( RESTRICT = scalargtsel, JOIN = scalargtjoinsel ); --- access method +-- access methods CREATE FUNCTION ivfflathandler(internal) RETURNS index_am_handler AS 'MODULE_PATHNAME' LANGUAGE C; @@ -218,6 +236,13 @@ CREATE ACCESS METHOD ivfflat TYPE INDEX HANDLER ivfflathandler; COMMENT ON ACCESS METHOD ivfflat IS 'ivfflat index access method'; +CREATE FUNCTION hnswhandler(internal) RETURNS index_am_handler + AS 'MODULE_PATHNAME' LANGUAGE C; + +CREATE ACCESS METHOD hnsw TYPE INDEX HANDLER hnswhandler; + +COMMENT ON ACCESS METHOD hnsw IS 'hnsw index access method'; + -- opclasses CREATE OPERATOR CLASS vector_ops @@ -249,3 +274,19 @@ CREATE OPERATOR CLASS vector_cosine_ops FUNCTION 2 vector_norm(vector), FUNCTION 3 vector_spherical_distance(vector, vector), FUNCTION 4 vector_norm(vector); + +CREATE OPERATOR CLASS vector_l2_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <-> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_l2_squared_distance(vector, vector); + +CREATE OPERATOR CLASS vector_ip_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <#> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector); + +CREATE OPERATOR CLASS vector_cosine_ops + FOR TYPE vector USING hnsw AS + OPERATOR 1 <=> (vector, vector) FOR ORDER BY float_ops, + FUNCTION 1 vector_negative_inner_product(vector, vector), + FUNCTION 2 vector_norm(vector); diff --git a/src/hnsw.c b/src/hnsw.c new file mode 100644 index 0000000..758e418 --- /dev/null +++ b/src/hnsw.c @@ -0,0 +1,224 @@ +#include "postgres.h" + +#include +#include + +#include "access/amapi.h" +#include "commands/vacuum.h" +#include "hnsw.h" +#include "utils/guc.h" +#include "utils/selfuncs.h" + +#if PG_VERSION_NUM >= 120000 +#include "commands/progress.h" +#endif + +int hnsw_ef_search; +static relopt_kind hnsw_relopt_kind; + +/* + * Initialize index options and variables + */ +void +HnswInit(void) +{ + hnsw_relopt_kind = add_reloption_kind(); + add_int_reloption(hnsw_relopt_kind, "m", "Max number of connections", + HNSW_DEFAULT_M, HNSW_MIN_M, HNSW_MAX_M +#if PG_VERSION_NUM >= 130000 + ,AccessExclusiveLock +#endif + ); + add_int_reloption(hnsw_relopt_kind, "ef_construction", "Size of the dynamic candidate list for construction", + HNSW_DEFAULT_EF_CONSTRUCTION, HNSW_MIN_EF_CONSTRUCTION, HNSW_MAX_EF_CONSTRUCTION +#if PG_VERSION_NUM >= 130000 + ,AccessExclusiveLock +#endif + ); + + DefineCustomIntVariable("hnsw.ef_search", "Sets the size of the dynamic candidate list for search", + "Valid range is 1..1000.", &hnsw_ef_search, + HNSW_DEFAULT_EF_SEARCH, HNSW_MIN_EF_SEARCH, HNSW_MAX_EF_SEARCH, PGC_USERSET, 0, NULL, NULL, NULL); +} + +/* + * Get the name of index build phase + */ +#if PG_VERSION_NUM >= 120000 +static char * +hnswbuildphasename(int64 phasenum) +{ + switch (phasenum) + { + case PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE: + return "initializing"; + case PROGRESS_HNSW_PHASE_LOAD: + return "loading tuples"; + default: + return NULL; + } +} +#endif + +/* + * Estimate the cost of an index scan + */ +static void +hnswcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, + Cost *indexStartupCost, Cost *indexTotalCost, + Selectivity *indexSelectivity, double *indexCorrelation, + double *indexPages) +{ + GenericCosts costs; + int m; + int entryLevel; + Relation index; +#if PG_VERSION_NUM < 120000 + List *qinfos; +#endif + + /* Never use index without order */ + if (path->indexorderbys == NULL) + { + *indexStartupCost = DBL_MAX; + *indexTotalCost = DBL_MAX; + *indexSelectivity = 0; + *indexCorrelation = 0; + *indexPages = 0; + return; + } + + MemSet(&costs, 0, sizeof(costs)); + + index = index_open(path->indexinfo->indexoid, NoLock); + HnswGetMetaPageInfo(index, &m, NULL); + index_close(index, NoLock); + + /* Approximate entry level */ + entryLevel = (int) -log(1.0 / path->indexinfo->tuples) * HnswGetMl(m); + + /* TODO Improve estimate of visited tuples (currently underestimates) */ + /* Account for number of tuples (or entry level), m, and ef_search */ + costs.numIndexTuples = (entryLevel + 2) * m; + +#if PG_VERSION_NUM >= 120000 + genericcostestimate(root, path, loop_count, &costs); +#else + qinfos = deconstruct_indexquals(path); + genericcostestimate(root, path, loop_count, qinfos, &costs); +#endif + + /* Use total cost since most work happens before first tuple is returned */ + *indexStartupCost = costs.indexTotalCost; + *indexTotalCost = costs.indexTotalCost; + *indexSelectivity = costs.indexSelectivity; + *indexCorrelation = costs.indexCorrelation; + *indexPages = costs.numIndexPages; +} + +/* + * Parse and validate the reloptions + */ +static bytea * +hnswoptions(Datum reloptions, bool validate) +{ + static const relopt_parse_elt tab[] = { + {"m", RELOPT_TYPE_INT, offsetof(HnswOptions, m)}, + {"ef_construction", RELOPT_TYPE_INT, offsetof(HnswOptions, efConstruction)}, + }; + +#if PG_VERSION_NUM >= 130000 + return (bytea *) build_reloptions(reloptions, validate, + hnsw_relopt_kind, + sizeof(HnswOptions), + tab, lengthof(tab)); +#else + relopt_value *options; + int numoptions; + HnswOptions *rdopts; + + options = parseRelOptions(reloptions, validate, hnsw_relopt_kind, &numoptions); + rdopts = allocateReloptStruct(sizeof(HnswOptions), options, numoptions); + fillRelOptions((void *) rdopts, sizeof(HnswOptions), options, numoptions, + validate, tab, lengthof(tab)); + + return (bytea *) rdopts; +#endif +} + +/* + * Validate catalog entries for the specified operator class + */ +static bool +hnswvalidate(Oid opclassoid) +{ + return true; +} + +/* + * Define index handler + * + * See https://www.postgresql.org/docs/current/index-api.html + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(hnswhandler); +Datum +hnswhandler(PG_FUNCTION_ARGS) +{ + IndexAmRoutine *amroutine = makeNode(IndexAmRoutine); + + amroutine->amstrategies = 0; + amroutine->amsupport = 2; +#if PG_VERSION_NUM >= 130000 + amroutine->amoptsprocnum = 0; +#endif + amroutine->amcanorder = false; + amroutine->amcanorderbyop = true; + amroutine->amcanbackward = false; /* can change direction mid-scan */ + amroutine->amcanunique = false; + amroutine->amcanmulticol = false; + amroutine->amoptionalkey = true; + amroutine->amsearcharray = false; + amroutine->amsearchnulls = false; + amroutine->amstorage = false; + amroutine->amclusterable = false; + amroutine->ampredlocks = false; + amroutine->amcanparallel = false; + amroutine->amcaninclude = false; +#if PG_VERSION_NUM >= 130000 + amroutine->amusemaintenanceworkmem = false; /* not used during VACUUM */ + amroutine->amparallelvacuumoptions = VACUUM_OPTION_PARALLEL_BULKDEL; +#endif + amroutine->amkeytype = InvalidOid; + + /* Interface functions */ + amroutine->ambuild = hnswbuild; + amroutine->ambuildempty = hnswbuildempty; + amroutine->aminsert = hnswinsert; + amroutine->ambulkdelete = hnswbulkdelete; + amroutine->amvacuumcleanup = hnswvacuumcleanup; + amroutine->amcanreturn = NULL; + amroutine->amcostestimate = hnswcostestimate; + amroutine->amoptions = hnswoptions; + amroutine->amproperty = NULL; /* TODO AMPROP_DISTANCE_ORDERABLE */ +#if PG_VERSION_NUM >= 120000 + amroutine->ambuildphasename = hnswbuildphasename; +#endif + amroutine->amvalidate = hnswvalidate; +#if PG_VERSION_NUM >= 140000 + amroutine->amadjustmembers = NULL; +#endif + amroutine->ambeginscan = hnswbeginscan; + amroutine->amrescan = hnswrescan; + amroutine->amgettuple = hnswgettuple; + amroutine->amgetbitmap = NULL; + amroutine->amendscan = hnswendscan; + amroutine->ammarkpos = NULL; + amroutine->amrestrpos = NULL; + + /* Interface functions to support parallel index scans */ + amroutine->amestimateparallelscan = NULL; + amroutine->aminitparallelscan = NULL; + amroutine->amparallelrescan = NULL; + + PG_RETURN_POINTER(amroutine); +} diff --git a/src/hnsw.h b/src/hnsw.h new file mode 100644 index 0000000..eb2aa9f --- /dev/null +++ b/src/hnsw.h @@ -0,0 +1,309 @@ +#ifndef HNSW_H +#define HNSW_H + +#include "postgres.h" + +#include "access/generic_xlog.h" +#include "access/reloptions.h" +#include "nodes/execnodes.h" +#include "port.h" /* for random() */ +#include "utils/sampling.h" +#include "vector.h" + +#if PG_VERSION_NUM < 110000 +#error "Requires PostgreSQL 11+" +#endif + +#define HNSW_MAX_DIM 2000 + +/* Support functions */ +#define HNSW_DISTANCE_PROC 1 +#define HNSW_NORM_PROC 2 + +#define HNSW_VERSION 1 +#define HNSW_MAGIC_NUMBER 0xA953A953 +#define HNSW_PAGE_ID 0xFF90 + +/* Preserved page numbers */ +#define HNSW_METAPAGE_BLKNO 0 +#define HNSW_HEAD_BLKNO 1 /* first element page */ + +/* Must correspond to page numbers since page lock is used */ +#define HNSW_UPDATE_LOCK 0 +#define HNSW_SCAN_LOCK 1 + +/* HNSW parameters */ +#define HNSW_DEFAULT_M 16 +#define HNSW_MIN_M 2 +#define HNSW_MAX_M 100 +#define HNSW_DEFAULT_EF_CONSTRUCTION 64 +#define HNSW_MIN_EF_CONSTRUCTION 4 +#define HNSW_MAX_EF_CONSTRUCTION 1000 +#define HNSW_DEFAULT_EF_SEARCH 40 +#define HNSW_MIN_EF_SEARCH 1 +#define HNSW_MAX_EF_SEARCH 1000 + +/* Tuple types */ +#define HNSW_ELEMENT_TUPLE_TYPE 1 +#define HNSW_NEIGHBOR_TUPLE_TYPE 2 + +/* Make graph robust against non-HOT updates */ +#define HNSW_HEAPTIDS 10 + +#define HNSW_UPDATE_ENTRY_GREATER 1 +#define HNSW_UPDATE_ENTRY_ALWAYS 2 + +/* Build phases */ +/* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ +#define PROGRESS_HNSW_PHASE_LOAD 2 + +#define HNSW_MAX_SIZE (BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData)) - sizeof(ItemIdData)) + +#define HNSW_ELEMENT_TUPLE_SIZE(_dim) MAXALIGN(offsetof(HnswElementTupleData, vec) + VECTOR_SIZE(_dim)) +#define HNSW_NEIGHBOR_TUPLE_SIZE(level, m) MAXALIGN(offsetof(HnswNeighborTupleData, indextids) + ((level) + 2) * (m) * sizeof(ItemPointerData)) + +#define HnswPageGetOpaque(page) ((HnswPageOpaque) PageGetSpecialPointer(page)) +#define HnswPageGetMeta(page) ((HnswMetaPageData *) PageGetContents(page)) + +#if PG_VERSION_NUM >= 150000 +#define RandomDouble() pg_prng_double(&pg_global_prng_state) +#else +#define RandomDouble() (((double) random()) / MAX_RANDOM_VALUE) +#endif + +#if PG_VERSION_NUM < 130000 +#define list_delete_last(list) list_truncate(list, list_length(list) - 1) +#define list_sort(list, cmp) list_qsort(list, cmp) +#endif + +#define HnswIsElementTuple(tup) ((tup)->type == HNSW_ELEMENT_TUPLE_TYPE) +#define HnswIsNeighborTuple(tup) ((tup)->type == HNSW_NEIGHBOR_TUPLE_TYPE) + +/* 2 * M connections for ground layer */ +#define HnswGetLayerM(m, layer) (layer == 0 ? (m) * 2 : (m)) + +/* Optimal ML from paper */ +#define HnswGetMl(m) (1 / log(m)) + +/* Ensure fits on page and in uint8 */ +#define HnswGetMaxLevel(m) Min(((BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(HnswPageOpaqueData)) - offsetof(HnswNeighborTupleData, indextids) - sizeof(ItemIdData)) / (sizeof(ItemPointerData)) / m) - 2, 255) + +/* Variables */ +extern int hnsw_ef_search; + +typedef struct HnswNeighborArray HnswNeighborArray; + +typedef struct HnswElementData +{ + List *heaptids; + uint8 level; + uint8 deleted; + HnswNeighborArray *neighbors; + BlockNumber blkno; + OffsetNumber offno; + OffsetNumber neighborOffno; + BlockNumber neighborPage; + Vector *vec; +} HnswElementData; + +typedef HnswElementData * HnswElement; + +typedef struct HnswCandidate +{ + HnswElement element; + float distance; + bool closer; +} HnswCandidate; + +typedef struct HnswNeighborArray +{ + int length; + bool closerSet; + HnswCandidate *items; +} HnswNeighborArray; + +typedef struct HnswPairingHeapNode +{ + pairingheap_node ph_node; + HnswCandidate *inner; +} HnswPairingHeapNode; + +/* HNSW index options */ +typedef struct HnswOptions +{ + int32 vl_len_; /* varlena header (do not touch directly!) */ + int m; /* number of connections */ + int efConstruction; /* size of dynamic candidate list */ +} HnswOptions; + +typedef struct HnswBuildState +{ + /* Info */ + Relation heap; + Relation index; + IndexInfo *indexInfo; + ForkNumber forkNum; + + /* Settings */ + int dimensions; + int m; + int efConstruction; + + /* Statistics */ + double indtuples; + double reltuples; + + /* Support functions */ + FmgrInfo *procinfo; + FmgrInfo *normprocinfo; + Oid collation; + + /* Variables */ + List *elements; + HnswElement entryPoint; + double ml; + int maxLevel; + double maxInMemoryElements; + bool flushed; + Vector *normvec; + + /* Memory */ + MemoryContext tmpCtx; +} HnswBuildState; + +typedef struct HnswMetaPageData +{ + uint32 magicNumber; + uint32 version; + uint32 dimensions; + uint16 m; + uint16 efConstruction; + BlockNumber entryBlkno; + OffsetNumber entryOffno; + int16 entryLevel; + BlockNumber insertPage; +} HnswMetaPageData; + +typedef HnswMetaPageData * HnswMetaPage; + +typedef struct HnswPageOpaqueData +{ + BlockNumber nextblkno; + uint16 unused; + uint16 page_id; /* for identification of HNSW indexes */ +} HnswPageOpaqueData; + +typedef HnswPageOpaqueData * HnswPageOpaque; + +typedef struct HnswElementTupleData +{ + uint8 type; + uint8 level; + uint8 deleted; + uint8 unused; + ItemPointerData heaptids[HNSW_HEAPTIDS]; + ItemPointerData neighbortid; + uint16 unused2; + Vector vec; +} HnswElementTupleData; + +typedef HnswElementTupleData * HnswElementTuple; + +typedef struct HnswNeighborTupleData +{ + uint8 type; + uint8 unused; + uint16 count; + ItemPointerData indextids[FLEXIBLE_ARRAY_MEMBER]; +} HnswNeighborTupleData; + +typedef HnswNeighborTupleData * HnswNeighborTuple; + +typedef struct HnswScanOpaqueData +{ + bool first; + List *w; + MemoryContext tmpCtx; + + /* Support functions */ + FmgrInfo *procinfo; + FmgrInfo *normprocinfo; + Oid collation; +} HnswScanOpaqueData; + +typedef HnswScanOpaqueData * HnswScanOpaque; + +typedef struct HnswVacuumState +{ + /* Info */ + Relation index; + IndexBulkDeleteResult *stats; + IndexBulkDeleteCallback callback; + void *callback_state; + + /* Settings */ + int m; + int efConstruction; + + /* Support functions */ + FmgrInfo *procinfo; + Oid collation; + + /* Variables */ + HTAB *deleted; + BufferAccessStrategy bas; + HnswNeighborTuple ntup; + HnswElementData highestPoint; + + /* Memory */ + MemoryContext tmpCtx; +} HnswVacuumState; + +/* Methods */ +int HnswGetM(Relation index); +int HnswGetEfConstruction(Relation index); +FmgrInfo *HnswOptionalProcInfo(Relation index, uint16 procnum); +bool HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); +void HnswCommitBuffer(Buffer buf, GenericXLogState *state); +Buffer HnswNewBuffer(Relation index, ForkNumber forkNum); +void HnswInitPage(Buffer buf, Page page); +void HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state); +void HnswInit(void); +List *HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement); +HnswElement HnswGetEntryPoint(Relation index); +void HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint); +HnswElement HnswInitElement(ItemPointer tid, int m, double ml, int maxLevel); +void HnswFreeElement(HnswElement element); +HnswElement HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno); +void HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing); +HnswElement HnswFindDuplicate(HnswElement e); +HnswCandidate *HnswEntryCandidate(HnswElement em, Datum q, Relation rel, FmgrInfo *procinfo, Oid collation, bool loadVec); +void HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum); +void HnswSetNeighborTuple(HnswNeighborTuple ntup, HnswElement e, int m); +void HnswAddHeapTid(HnswElement element, ItemPointer heaptid); +void HnswInitNeighbors(HnswElement element, int m); +bool HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel); +void HnswUpdateNeighborPages(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting); +void HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec); +void HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec); +void HnswSetElementTuple(HnswElementTuple etup, HnswElement element); +void HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation); +void HnswLoadNeighbors(HnswElement element, Relation index, int m); + +/* Index access methods */ +IndexBuildResult *hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo); +void hnswbuildempty(Relation index); +bool hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heap, IndexUniqueCheck checkUnique +#if PG_VERSION_NUM >= 140000 + ,bool indexUnchanged +#endif + ,IndexInfo *indexInfo +); +IndexBulkDeleteResult *hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state); +IndexBulkDeleteResult *hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats); +IndexScanDesc hnswbeginscan(Relation index, int nkeys, int norderbys); +void hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys); +bool hnswgettuple(IndexScanDesc scan, ScanDirection dir); +void hnswendscan(IndexScanDesc scan); + +#endif diff --git a/src/hnswbuild.c b/src/hnswbuild.c new file mode 100644 index 0000000..18959d5 --- /dev/null +++ b/src/hnswbuild.c @@ -0,0 +1,523 @@ +#include "postgres.h" + +#include + +#include "catalog/index.h" +#include "hnsw.h" +#include "miscadmin.h" +#include "lib/pairingheap.h" +#include "nodes/pg_list.h" +#include "storage/bufmgr.h" +#include "utils/memutils.h" + +#if PG_VERSION_NUM >= 140000 +#include "utils/backend_progress.h" +#elif PG_VERSION_NUM >= 120000 +#include "pgstat.h" +#endif + +#if PG_VERSION_NUM >= 120000 +#include "access/tableam.h" +#include "commands/progress.h" +#else +#define PROGRESS_CREATEIDX_TUPLES_DONE 0 +#endif + +#if PG_VERSION_NUM >= 130000 +#define CALLBACK_ITEM_POINTER ItemPointer tid +#else +#define CALLBACK_ITEM_POINTER HeapTuple hup +#endif + +#if PG_VERSION_NUM >= 120000 +#define UpdateProgress(index, val) pgstat_progress_update_param(index, val) +#else +#define UpdateProgress(index, val) ((void)val) +#endif + +/* + * Create the metapage + */ +static void +CreateMetaPage(HnswBuildState * buildstate) +{ + Relation index = buildstate->index; + ForkNumber forkNum = buildstate->forkNum; + Buffer buf; + Page page; + GenericXLogState *state; + HnswMetaPage metap; + + buf = HnswNewBuffer(index, forkNum); + HnswInitRegisterPage(index, &buf, &page, &state); + + /* Set metapage data */ + metap = HnswPageGetMeta(page); + metap->magicNumber = HNSW_MAGIC_NUMBER; + metap->version = HNSW_VERSION; + metap->dimensions = buildstate->dimensions; + metap->m = buildstate->m; + metap->efConstruction = buildstate->efConstruction; + metap->entryBlkno = InvalidBlockNumber; + metap->entryOffno = InvalidOffsetNumber; + metap->entryLevel = -1; + metap->insertPage = InvalidBlockNumber; + ((PageHeader) page)->pd_lower = + ((char *) metap + sizeof(HnswMetaPageData)) - (char *) page; + + HnswCommitBuffer(buf, state); +} + +/* + * Add a new page + */ +static void +HnswBuildAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum) +{ + /* Add a new page */ + Buffer newbuf = HnswNewBuffer(index, forkNum); + + /* Update previous page */ + HnswPageGetOpaque(*page)->nextblkno = BufferGetBlockNumber(newbuf); + + /* Commit */ + GenericXLogFinish(*state); + UnlockReleaseBuffer(*buf); + + /* Can take a while, so ensure we can interrupt */ + /* Needs to be called when no buffer locks are held */ + LockBuffer(newbuf, BUFFER_LOCK_UNLOCK); + CHECK_FOR_INTERRUPTS(); + LockBuffer(newbuf, BUFFER_LOCK_EXCLUSIVE); + + /* Prepare new page */ + *buf = newbuf; + *state = GenericXLogStart(index); + *page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(*buf, *page); +} + +/* + * Create element pages + */ +static void +CreateElementPages(HnswBuildState * buildstate) +{ + Relation index = buildstate->index; + ForkNumber forkNum = buildstate->forkNum; + int dimensions = buildstate->dimensions; + Size etupSize; + Size maxSize; + HnswElementTuple etup; + HnswNeighborTuple ntup; + BlockNumber insertPage; + Buffer buf; + Page page; + GenericXLogState *state; + ListCell *lc; + + /* Calculate sizes */ + maxSize = HNSW_MAX_SIZE; + etupSize = HNSW_ELEMENT_TUPLE_SIZE(dimensions); + + /* Allocate once */ + etup = palloc0(etupSize); + ntup = palloc0(BLCKSZ); + + /* Prepare first page */ + buf = HnswNewBuffer(index, forkNum); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(buf, page); + + foreach(lc, buildstate->elements) + { + HnswElement element = lfirst(lc); + Size ntupSize; + Size combinedSize; + + HnswSetElementTuple(etup, element); + + /* Calculate sizes */ + ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, buildstate->m); + combinedSize = etupSize + ntupSize + sizeof(ItemIdData); + + /* Keep element and neighbors on the same page if possible */ + if (PageGetFreeSpace(page) < etupSize || (combinedSize <= maxSize && PageGetFreeSpace(page) < combinedSize)) + HnswBuildAppendPage(index, &buf, &page, &state, forkNum); + + /* Calculate offsets */ + element->blkno = BufferGetBlockNumber(buf); + element->offno = OffsetNumberNext(PageGetMaxOffsetNumber(page)); + if (combinedSize <= maxSize) + { + element->neighborPage = element->blkno; + element->neighborOffno = OffsetNumberNext(element->offno); + } + else + { + element->neighborPage = element->blkno + 1; + element->neighborOffno = FirstOffsetNumber; + } + + ItemPointerSet(&etup->neighbortid, element->neighborPage, element->neighborOffno); + + /* Add element */ + if (PageAddItem(page, (Item) etup, etupSize, InvalidOffsetNumber, false, false) != element->offno) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Add new page if needed */ + if (PageGetFreeSpace(page) < ntupSize) + HnswBuildAppendPage(index, &buf, &page, &state, forkNum); + + /* Add placeholder for neighbors */ + if (PageAddItem(page, (Item) ntup, ntupSize, InvalidOffsetNumber, false, false) != element->neighborOffno) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + } + + insertPage = BufferGetBlockNumber(buf); + + /* Commit */ + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + + HnswUpdateMetaPage(index, HNSW_UPDATE_ENTRY_ALWAYS, buildstate->entryPoint, insertPage, forkNum); + + pfree(etup); + pfree(ntup); +} + +/* + * Create neighbor pages + */ +static void +CreateNeighborPages(HnswBuildState * buildstate) +{ + Relation index = buildstate->index; + ForkNumber forkNum = buildstate->forkNum; + int m = buildstate->m; + ListCell *lc; + HnswNeighborTuple ntup; + + /* Allocate once */ + ntup = palloc0(BLCKSZ); + + foreach(lc, buildstate->elements) + { + HnswElement e = lfirst(lc); + Buffer buf; + Page page; + GenericXLogState *state; + Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); + + /* Can take a while, so ensure we can interrupt */ + /* Needs to be called when no buffer locks are held */ + CHECK_FOR_INTERRUPTS(); + + buf = ReadBufferExtended(index, forkNum, e->neighborPage, RBM_NORMAL, NULL); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + HnswSetNeighborTuple(ntup, e, m); + + if (!PageIndexTupleOverwrite(page, e->neighborOffno, (Item) ntup, ntupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + } + + pfree(ntup); +} + +/* + * Free elements + */ +static void +FreeElements(HnswBuildState * buildstate) +{ + ListCell *lc; + + foreach(lc, buildstate->elements) + HnswFreeElement(lfirst(lc)); + + list_free(buildstate->elements); +} + +/* + * Flush pages + */ +static void +FlushPages(HnswBuildState * buildstate) +{ + CreateMetaPage(buildstate); + CreateElementPages(buildstate); + CreateNeighborPages(buildstate); + + buildstate->flushed = true; + FreeElements(buildstate); +} + +/* + * Insert tuple + */ +static bool +InsertTuple(Relation index, Datum *values, HnswElement element, HnswBuildState * buildstate, HnswElement * dup) +{ + FmgrInfo *procinfo = buildstate->procinfo; + Oid collation = buildstate->collation; + HnswElement entryPoint = buildstate->entryPoint; + int efConstruction = buildstate->efConstruction; + int m = buildstate->m; + + /* Detoast once for all calls */ + Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + + /* Normalize if needed */ + if (buildstate->normprocinfo != NULL) + { + if (!HnswNormValue(buildstate->normprocinfo, collation, &value, buildstate->normvec)) + return false; + } + + /* Copy value to element so accessible outside of memory context */ + memcpy(element->vec, DatumGetVector(value), VECTOR_SIZE(buildstate->dimensions)); + + /* Insert element in graph */ + HnswInsertElement(element, entryPoint, NULL, procinfo, collation, m, efConstruction, false); + + /* Look for duplicate */ + *dup = HnswFindDuplicate(element); + + /* Update neighbors if needed */ + if (*dup == NULL) + { + for (int lc = element->level; lc >= 0; lc--) + { + int lm = HnswGetLayerM(m, lc); + HnswNeighborArray *neighbors = &element->neighbors[lc]; + + for (int i = 0; i < neighbors->length; i++) + HnswUpdateConnection(element, &neighbors->items[i], lm, lc, NULL, NULL, procinfo, collation); + } + } + + /* Update entry point if needed */ + if (*dup == NULL && (entryPoint == NULL || element->level > entryPoint->level)) + buildstate->entryPoint = element; + + UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples); + + return *dup == NULL; +} + +/* + * Callback for table_index_build_scan + */ +static void +BuildCallback(Relation index, CALLBACK_ITEM_POINTER, Datum *values, + bool *isnull, bool tupleIsAlive, void *state) +{ + HnswBuildState *buildstate = (HnswBuildState *) state; + MemoryContext oldCtx; + HnswElement element; + HnswElement dup = NULL; + bool inserted; + +#if PG_VERSION_NUM < 130000 + ItemPointer tid = &hup->t_self; +#endif + + /* Skip nulls */ + if (isnull[0]) + return; + + if (buildstate->indtuples >= buildstate->maxInMemoryElements) + { + if (!buildstate->flushed) + { + ereport(NOTICE, + (errmsg("hnsw graph no longer fits into maintenance_work_mem after " INT64_FORMAT " tuples", (int64) buildstate->indtuples), + errdetail("Building will take significantly more time."), + errhint("Increase maintenance_work_mem to speed up builds."))); + + FlushPages(buildstate); + } + + oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); + + if (HnswInsertTuple(buildstate->index, values, isnull, tid, buildstate->heap)) + UpdateProgress(PROGRESS_CREATEIDX_TUPLES_DONE, ++buildstate->indtuples); + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(buildstate->tmpCtx); + + return; + } + + /* Allocate necessary memory outside of memory context */ + element = HnswInitElement(tid, buildstate->m, buildstate->ml, buildstate->maxLevel); + element->vec = palloc(VECTOR_SIZE(buildstate->dimensions)); + + /* Use memory context since detoast can allocate */ + oldCtx = MemoryContextSwitchTo(buildstate->tmpCtx); + + /* Insert tuple */ + inserted = InsertTuple(index, values, element, buildstate, &dup); + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(buildstate->tmpCtx); + + /* Add outside memory context */ + if (dup != NULL) + HnswAddHeapTid(dup, tid); + + /* Add to buildstate or free */ + if (inserted) + buildstate->elements = lappend(buildstate->elements, element); + else + HnswFreeElement(element); +} + +/* + * Get the max number of elements that fit into maintenance_work_mem + */ +static double +HnswGetMaxInMemoryElements(int m, double ml, int dimensions) +{ + Size elementSize = sizeof(HnswElementData); + double avgLevel = -log(0.5) * ml; + + elementSize += sizeof(HnswNeighborArray) * (avgLevel + 1); + elementSize += sizeof(HnswCandidate) * (m * (avgLevel + 2)); + elementSize += sizeof(ItemPointerData); + elementSize += VECTOR_SIZE(dimensions); + return (maintenance_work_mem * 1024L) / elementSize; +} + +/* + * Initialize the build state + */ +static void +InitBuildState(HnswBuildState * buildstate, Relation heap, Relation index, IndexInfo *indexInfo, ForkNumber forkNum) +{ + buildstate->heap = heap; + buildstate->index = index; + buildstate->indexInfo = indexInfo; + buildstate->forkNum = forkNum; + + buildstate->m = HnswGetM(index); + buildstate->efConstruction = HnswGetEfConstruction(index); + buildstate->dimensions = TupleDescAttr(index->rd_att, 0)->atttypmod; + + /* Require column to have dimensions to be indexed */ + if (buildstate->dimensions < 0) + elog(ERROR, "column does not have dimensions"); + + if (buildstate->dimensions > HNSW_MAX_DIM) + elog(ERROR, "column cannot have more than %d dimensions for hnsw index", HNSW_MAX_DIM); + + if (buildstate->efConstruction < 2 * buildstate->m) + elog(ERROR, "ef_construction must be greater than or equal to 2 * m"); + + buildstate->reltuples = 0; + buildstate->indtuples = 0; + + /* Get support functions */ + buildstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + buildstate->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + buildstate->collation = index->rd_indcollation[0]; + + buildstate->elements = NIL; + buildstate->entryPoint = NULL; + buildstate->ml = HnswGetMl(buildstate->m); + buildstate->maxLevel = HnswGetMaxLevel(buildstate->m); + buildstate->maxInMemoryElements = HnswGetMaxInMemoryElements(buildstate->m, buildstate->ml, buildstate->dimensions); + buildstate->flushed = false; + + /* Reuse for each tuple */ + buildstate->normvec = InitVector(buildstate->dimensions); + + buildstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw build temporary context", + ALLOCSET_DEFAULT_SIZES); +} + +/* + * Free resources + */ +static void +FreeBuildState(HnswBuildState * buildstate) +{ + pfree(buildstate->normvec); + MemoryContextDelete(buildstate->tmpCtx); +} + +/* + * Build graph + */ +static void +BuildGraph(HnswBuildState * buildstate, ForkNumber forkNum) +{ + UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_HNSW_PHASE_LOAD); + +#if PG_VERSION_NUM >= 120000 + buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, true, BuildCallback, (void *) buildstate, NULL); +#else + buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, BuildCallback, (void *) buildstate, NULL); +#endif +} + +/* + * Build the index + */ +static void +BuildIndex(Relation heap, Relation index, IndexInfo *indexInfo, + HnswBuildState * buildstate, ForkNumber forkNum) +{ + InitBuildState(buildstate, heap, index, indexInfo, forkNum); + + if (buildstate->heap != NULL) + BuildGraph(buildstate, forkNum); + + if (!buildstate->flushed) + FlushPages(buildstate); + + FreeBuildState(buildstate); +} + +/* + * Build the index for a logged table + */ +IndexBuildResult * +hnswbuild(Relation heap, Relation index, IndexInfo *indexInfo) +{ + IndexBuildResult *result; + HnswBuildState buildstate; + + BuildIndex(heap, index, indexInfo, &buildstate, MAIN_FORKNUM); + + result = (IndexBuildResult *) palloc(sizeof(IndexBuildResult)); + result->heap_tuples = buildstate.reltuples; + result->index_tuples = buildstate.indtuples; + + return result; +} + +/* + * Build the index for an unlogged table + */ +void +hnswbuildempty(Relation index) +{ + IndexInfo *indexInfo = BuildIndexInfo(index); + HnswBuildState buildstate; + + BuildIndex(NULL, index, indexInfo, &buildstate, INIT_FORKNUM); +} diff --git a/src/hnswinsert.c b/src/hnswinsert.c new file mode 100644 index 0000000..f7cd51f --- /dev/null +++ b/src/hnswinsert.c @@ -0,0 +1,582 @@ +#include "postgres.h" + +#include + +#include "hnsw.h" +#include "storage/bufmgr.h" +#include "storage/lmgr.h" +#include "utils/memutils.h" + +/* + * Get the insert page + */ +static BlockNumber +GetInsertPage(Relation index) +{ + Buffer buf; + Page page; + HnswMetaPage metap; + BlockNumber insertPage; + + buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + metap = HnswPageGetMeta(page); + + insertPage = metap->insertPage; + + UnlockReleaseBuffer(buf); + + return insertPage; +} + +/* + * Check for a free offset + */ +static bool +HnswFreeOffset(Relation index, Buffer buf, Page page, HnswElement element, Size ntupSize, Buffer *nbuf, Page *npage, OffsetNumber *freeOffno, OffsetNumber *freeNeighborOffno, BlockNumber *newInsertPage) +{ + OffsetNumber offno; + OffsetNumber maxoffno = PageGetMaxOffsetNumber(page); + + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + + /* Skip neighbor tuples */ + if (!HnswIsElementTuple(etup)) + continue; + + if (etup->deleted) + { + BlockNumber elementPage = BufferGetBlockNumber(buf); + BlockNumber neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); + OffsetNumber neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); + ItemId itemid; + + if (!BlockNumberIsValid(*newInsertPage)) + *newInsertPage = elementPage; + + if (neighborPage == elementPage) + { + *nbuf = buf; + *npage = page; + } + else + { + *nbuf = ReadBuffer(index, neighborPage); + LockBuffer(*nbuf, BUFFER_LOCK_EXCLUSIVE); + + /* Skip WAL for now */ + *npage = BufferGetPage(*nbuf); + } + + itemid = PageGetItemId(*npage, neighborOffno); + + /* Check for space on neighbor tuple page */ + if (PageGetFreeSpace(*npage) + ItemIdGetLength(itemid) - sizeof(ItemIdData) >= ntupSize) + { + *freeOffno = offno; + *freeNeighborOffno = neighborOffno; + return true; + } + else if (*nbuf != buf) + UnlockReleaseBuffer(*nbuf); + } + } + + return false; +} + +/* + * Add a new page + */ +static void +HnswInsertAppendPage(Relation index, Buffer *nbuf, Page *npage, GenericXLogState *state, Page page) +{ + /* Add a new page */ + LockRelationForExtension(index, ExclusiveLock); + *nbuf = HnswNewBuffer(index, MAIN_FORKNUM); + UnlockRelationForExtension(index, ExclusiveLock); + + /* Init new page */ + *npage = GenericXLogRegisterBuffer(state, *nbuf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(*nbuf, *npage); + + /* Update previous buffer */ + HnswPageGetOpaque(page)->nextblkno = BufferGetBlockNumber(*nbuf); +} + +/* + * Add to element and neighbor pages + */ +static void +WriteNewElementPages(Relation index, HnswElement e, int m, BlockNumber insertPage, BlockNumber *updatedInsertPage) +{ + Buffer buf; + Page page; + GenericXLogState *state; + Size etupSize; + Size ntupSize; + Size combinedSize; + Size maxSize; + Size minCombinedSize; + HnswElementTuple etup; + BlockNumber currentPage = insertPage; + int dimensions = e->vec->dim; + HnswNeighborTuple ntup; + Buffer nbuf; + Page npage; + OffsetNumber freeOffno = InvalidOffsetNumber; + OffsetNumber freeNeighborOffno = InvalidOffsetNumber; + BlockNumber newInsertPage = InvalidBlockNumber; + + /* Calculate sizes */ + etupSize = HNSW_ELEMENT_TUPLE_SIZE(dimensions); + ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(e->level, m); + combinedSize = etupSize + ntupSize + sizeof(ItemIdData); + maxSize = HNSW_MAX_SIZE; + minCombinedSize = etupSize + HNSW_NEIGHBOR_TUPLE_SIZE(0, m) + sizeof(ItemIdData); + + /* Prepare element tuple */ + etup = palloc0(etupSize); + HnswSetElementTuple(etup, e); + + /* Prepare neighbor tuple */ + ntup = palloc0(ntupSize); + HnswSetNeighborTuple(ntup, e, m); + + /* Find a page (or two if needed) to insert the tuples */ + for (;;) + { + buf = ReadBuffer(index, currentPage); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Keep track of first page where element at level 0 can fit */ + if (!BlockNumberIsValid(newInsertPage) && PageGetFreeSpace(page) >= minCombinedSize) + newInsertPage = currentPage; + + /* First, try the fastest path */ + /* Space for both tuples on the current page */ + /* This can split existing tuples in rare cases */ + if (PageGetFreeSpace(page) >= combinedSize) + { + nbuf = buf; + npage = page; + break; + } + + /* Next, try space from a deleted element */ + if (HnswFreeOffset(index, buf, page, e, ntupSize, &nbuf, &npage, &freeOffno, &freeNeighborOffno, &newInsertPage)) + { + if (nbuf != buf) + npage = GenericXLogRegisterBuffer(state, nbuf, 0); + + break; + } + + /* Finally, try space for element only if last page */ + /* Skip if both tuples can fit on the same page */ + if (combinedSize > maxSize && PageGetFreeSpace(page) >= etupSize && !BlockNumberIsValid(HnswPageGetOpaque(page)->nextblkno)) + { + HnswInsertAppendPage(index, &nbuf, &npage, state, page); + break; + } + + currentPage = HnswPageGetOpaque(page)->nextblkno; + + if (BlockNumberIsValid(currentPage)) + { + /* Move to next page */ + GenericXLogAbort(state); + UnlockReleaseBuffer(buf); + } + else + { + Buffer newbuf; + Page newpage; + + HnswInsertAppendPage(index, &newbuf, &newpage, state, page); + + /* Commit */ + GenericXLogFinish(state); + + /* Unlock previous buffer */ + UnlockReleaseBuffer(buf); + + /* Prepare new buffer */ + state = GenericXLogStart(index); + buf = newbuf; + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Create new page for neighbors if needed */ + if (PageGetFreeSpace(page) < combinedSize) + HnswInsertAppendPage(index, &nbuf, &npage, state, page); + else + { + nbuf = buf; + npage = page; + } + + break; + } + } + + e->blkno = BufferGetBlockNumber(buf); + e->neighborPage = BufferGetBlockNumber(nbuf); + + /* Added tuple to new page if newInsertPage is not set */ + /* So can set to neighbor page instead of element page */ + if (!BlockNumberIsValid(newInsertPage)) + newInsertPage = e->neighborPage; + + if (OffsetNumberIsValid(freeOffno)) + { + e->offno = freeOffno; + e->neighborOffno = freeNeighborOffno; + } + else + { + e->offno = OffsetNumberNext(PageGetMaxOffsetNumber(page)); + if (nbuf == buf) + e->neighborOffno = OffsetNumberNext(e->offno); + else + e->neighborOffno = FirstOffsetNumber; + } + + ItemPointerSet(&etup->neighbortid, e->neighborPage, e->neighborOffno); + + /* Add element and neighbors */ + if (OffsetNumberIsValid(freeOffno)) + { + if (!PageIndexTupleOverwrite(page, e->offno, (Item) etup, etupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + if (!PageIndexTupleOverwrite(npage, e->neighborOffno, (Item) ntup, ntupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + } + else + { + if (PageAddItem(page, (Item) etup, etupSize, InvalidOffsetNumber, false, false) != e->offno) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + if (PageAddItem(npage, (Item) ntup, ntupSize, InvalidOffsetNumber, false, false) != e->neighborOffno) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + } + + /* Commit */ + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + if (nbuf != buf) + UnlockReleaseBuffer(nbuf); + + /* Update the insert page */ + if (BlockNumberIsValid(newInsertPage) && newInsertPage != insertPage) + *updatedInsertPage = newInsertPage; +} + +/* + * Check if connection already exists + */ +static bool +ConnectionExists(HnswElement e, HnswNeighborTuple ntup, int startIdx, int lm) +{ + for (int i = 0; i < lm; i++) + { + ItemPointer indextid = &ntup->indextids[startIdx + i]; + + if (!ItemPointerIsValid(indextid)) + break; + + if (ItemPointerGetBlockNumber(indextid) == e->blkno && ItemPointerGetOffsetNumber(indextid) == e->offno) + return true; + } + + return false; +} + +/* + * Update neighbors + */ +void +HnswUpdateNeighborPages(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement e, int m, bool checkExisting) +{ + for (int lc = e->level; lc >= 0; lc--) + { + int lm = HnswGetLayerM(m, lc); + HnswNeighborArray *neighbors = &e->neighbors[lc]; + + for (int i = 0; i < neighbors->length; i++) + { + HnswCandidate *hc = &neighbors->items[i]; + Buffer buf; + Page page; + GenericXLogState *state; + ItemId itemid; + HnswNeighborTuple ntup; + Size ntupSize; + int idx = -1; + int startIdx; + OffsetNumber offno = hc->element->neighborOffno; + + /* Get latest neighbors since they may have changed */ + /* Do not lock yet since selecting neighbors can take time */ + HnswLoadNeighbors(hc->element, index, m); + + /* + * Could improve performance for vacuuming by checking neighbors + * against list of elements being deleted to find index. It's + * important to exclude already deleted elements for this since + * they can be replaced at any time. + */ + + /* Select neighbors */ + HnswUpdateConnection(e, hc, lm, lc, &idx, index, procinfo, collation); + + /* New element was not selected as a neighbor */ + if (idx == -1) + continue; + + /* Register page */ + buf = ReadBuffer(index, hc->element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Get tuple */ + itemid = PageGetItemId(page, offno); + ntup = (HnswNeighborTuple) PageGetItem(page, itemid); + ntupSize = ItemIdGetLength(itemid); + + /* Calculate index for update */ + startIdx = (hc->element->level - lc) * m; + + /* Check for existing connection */ + if (checkExisting && ConnectionExists(e, ntup, startIdx, lm)) + idx = -1; + else if (idx == -2) + { + /* Find free offset if still exists */ + /* TODO Retry updating connections if not */ + for (int j = 0; j < lm; j++) + { + if (!ItemPointerIsValid(&ntup->indextids[startIdx + j])) + { + idx = startIdx + j; + break; + } + } + } + else + idx += startIdx; + + /* Make robust to issues */ + if (idx >= 0 && idx < ntup->count) + { + ItemPointer indextid = &ntup->indextids[idx]; + + /* Update neighbor */ + ItemPointerSet(indextid, e->blkno, e->offno); + + /* Overwrite tuple */ + if (!PageIndexTupleOverwrite(page, offno, (Item) ntup, ntupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + GenericXLogFinish(state); + } + else + GenericXLogAbort(state); + + UnlockReleaseBuffer(buf); + } + } +} + +/* + * Add a heap TID to an existing element + */ +static bool +HnswAddDuplicate(Relation index, HnswElement element, HnswElement dup) +{ + Buffer buf; + Page page; + GenericXLogState *state; + Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(dup->vec->dim); + HnswElementTuple etup; + int i; + + /* Read page */ + buf = ReadBuffer(index, dup->blkno); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Find space */ + etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, dup->offno)); + for (i = 0; i < HNSW_HEAPTIDS; i++) + { + if (!ItemPointerIsValid(&etup->heaptids[i])) + break; + } + + /* Either being deleted or we lost our chance to another backend */ + if (i == 0 || i == HNSW_HEAPTIDS) + { + GenericXLogAbort(state); + UnlockReleaseBuffer(buf); + return false; + } + + /* Add heap TID */ + etup->heaptids[i] = *((ItemPointer) linitial(element->heaptids)); + + /* Overwrite tuple */ + if (!PageIndexTupleOverwrite(page, dup->offno, (Item) etup, etupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + + return true; +} + +/* + * Write changes to disk + */ +static void +WriteElement(Relation index, FmgrInfo *procinfo, Oid collation, HnswElement element, int m, int efConstruction, HnswElement dup, HnswElement entryPoint) +{ + BlockNumber newInsertPage = InvalidBlockNumber; + + /* Try to add to existing page */ + if (dup != NULL) + { + if (HnswAddDuplicate(index, element, dup)) + return; + } + + /* Write element and neighbor tuples */ + WriteNewElementPages(index, element, m, GetInsertPage(index), &newInsertPage); + + /* Update insert page if needed */ + if (BlockNumberIsValid(newInsertPage)) + HnswUpdateMetaPage(index, 0, NULL, newInsertPage, MAIN_FORKNUM); + + /* Update neighbors */ + HnswUpdateNeighborPages(index, procinfo, collation, element, m, false); + + /* Update metapage if needed */ + if (entryPoint == NULL || element->level > entryPoint->level) + HnswUpdateMetaPage(index, HNSW_UPDATE_ENTRY_GREATER, element, InvalidBlockNumber, MAIN_FORKNUM); +} + +/* + * Insert a tuple into the index + */ +bool +HnswInsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel) +{ + Datum value; + FmgrInfo *normprocinfo; + HnswElement entryPoint; + HnswElement element; + int m; + int efConstruction = HnswGetEfConstruction(index); + FmgrInfo *procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + Oid collation = index->rd_indcollation[0]; + HnswElement dup; + LOCKMODE lockmode = ShareLock; + + /* Detoast once for all calls */ + value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); + + /* Normalize if needed */ + normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + if (normprocinfo != NULL) + { + if (!HnswNormValue(normprocinfo, collation, &value, NULL)) + return false; + } + + /* + * Get a shared lock. This allows vacuum to ensure no in-flight inserts + * before repairing graph. Use a page lock so it does not interfere with + * buffer lock (or reads when vacuuming). + */ + LockPage(index, HNSW_UPDATE_LOCK, lockmode); + + /* Get m and entry point */ + HnswGetMetaPageInfo(index, &m, &entryPoint); + + /* Create an element */ + element = HnswInitElement(heap_tid, m, HnswGetMl(m), HnswGetMaxLevel(m)); + element->vec = DatumGetVector(value); + + /* Prevent concurrent inserts when likely updating entry point */ + if (entryPoint == NULL || element->level > entryPoint->level) + { + /* Release shared lock */ + UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); + + /* Get exclusive lock */ + lockmode = ExclusiveLock; + LockPage(index, HNSW_UPDATE_LOCK, lockmode); + + /* Get latest entry point after lock is acquired */ + entryPoint = HnswGetEntryPoint(index); + } + + /* Insert element in graph */ + HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, false); + + /* Look for duplicate */ + dup = HnswFindDuplicate(element); + + /* Write to disk */ + WriteElement(index, procinfo, collation, element, m, efConstruction, dup, entryPoint); + + /* Release lock */ + UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); + + return true; +} + +/* + * Insert a tuple into the index + */ +bool +hnswinsert(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, + Relation heap, IndexUniqueCheck checkUnique +#if PG_VERSION_NUM >= 140000 + ,bool indexUnchanged +#endif + ,IndexInfo *indexInfo +) +{ + MemoryContext oldCtx; + MemoryContext insertCtx; + + /* Skip nulls */ + if (isnull[0]) + return false; + + /* Create memory context */ + insertCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw insert temporary context", + ALLOCSET_DEFAULT_SIZES); + oldCtx = MemoryContextSwitchTo(insertCtx); + + /* Insert tuple */ + HnswInsertTuple(index, values, isnull, heap_tid, heap); + + /* Delete memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextDelete(insertCtx); + + return false; +} diff --git a/src/hnswscan.c b/src/hnswscan.c new file mode 100644 index 0000000..7cf2bf0 --- /dev/null +++ b/src/hnswscan.c @@ -0,0 +1,229 @@ +#include "postgres.h" + +#include "access/relscan.h" +#include "hnsw.h" +#include "pgstat.h" +#include "storage/bufmgr.h" +#include "storage/lmgr.h" +#include "utils/memutils.h" + +/* + * Algorithm 5 from paper + */ +static List * +GetScanItems(IndexScanDesc scan, Datum q) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + Relation index = scan->indexRelation; + FmgrInfo *procinfo = so->procinfo; + Oid collation = so->collation; + List *ep; + List *w; + int m; + HnswElement entryPoint; + + /* Get m and entry point */ + HnswGetMetaPageInfo(index, &m, &entryPoint); + + if (entryPoint == NULL) + return NIL; + + ep = list_make1(HnswEntryCandidate(entryPoint, q, index, procinfo, collation, false)); + + for (int lc = entryPoint->level; lc >= 1; lc--) + { + w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, m, false, NULL); + ep = w; + } + + return HnswSearchLayer(q, ep, hnsw_ef_search, 0, index, procinfo, collation, m, false, NULL); +} + +/* + * Get dimensions from metapage + */ +static int +GetDimensions(Relation index) +{ + Buffer buf; + Page page; + HnswMetaPage metap; + int dimensions; + + buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + metap = HnswPageGetMeta(page); + + dimensions = metap->dimensions; + + UnlockReleaseBuffer(buf); + + return dimensions; +} + +/* + * Get scan value + */ +static Datum +GetScanValue(IndexScanDesc scan) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + Datum value; + + if (scan->orderByData->sk_flags & SK_ISNULL) + value = PointerGetDatum(InitVector(GetDimensions(scan->indexRelation))); + else + { + value = scan->orderByData->sk_argument; + + /* Value should not be compressed or toasted */ + Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); + Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); + + /* Fine if normalization fails */ + if (so->normprocinfo != NULL) + HnswNormValue(so->normprocinfo, so->collation, &value, NULL); + } + + return value; +} + +/* + * Prepare for an index scan + */ +IndexScanDesc +hnswbeginscan(Relation index, int nkeys, int norderbys) +{ + IndexScanDesc scan; + HnswScanOpaque so; + + scan = RelationGetIndexScan(index, nkeys, norderbys); + + so = (HnswScanOpaque) palloc(sizeof(HnswScanOpaqueData)); + so->first = true; + so->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw scan temporary context", + ALLOCSET_DEFAULT_SIZES); + + /* Set support functions */ + so->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + so->normprocinfo = HnswOptionalProcInfo(index, HNSW_NORM_PROC); + so->collation = index->rd_indcollation[0]; + + scan->opaque = so; + + return scan; +} + +/* + * Start or restart an index scan + */ +void +hnswrescan(IndexScanDesc scan, ScanKey keys, int nkeys, ScanKey orderbys, int norderbys) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + + so->first = true; + MemoryContextReset(so->tmpCtx); + + if (keys && scan->numberOfKeys > 0) + memmove(scan->keyData, keys, scan->numberOfKeys * sizeof(ScanKeyData)); + + if (orderbys && scan->numberOfOrderBys > 0) + memmove(scan->orderByData, orderbys, scan->numberOfOrderBys * sizeof(ScanKeyData)); +} + +/* + * Fetch the next tuple in the given scan + */ +bool +hnswgettuple(IndexScanDesc scan, ScanDirection dir) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + MemoryContext oldCtx = MemoryContextSwitchTo(so->tmpCtx); + + /* + * Index can be used to scan backward, but Postgres doesn't support + * backward scan on operators + */ + Assert(ScanDirectionIsForward(dir)); + + if (so->first) + { + Datum value; + + /* Count index scan for stats */ + pgstat_count_index_scan(scan->indexRelation); + + /* Safety check */ + if (scan->orderByData == NULL) + elog(ERROR, "cannot scan hnsw index without order"); + + /* Requires MVCC-compliant snapshot as not able to maintain a pin */ + /* https://www.postgresql.org/docs/current/index-locking.html */ + if (!IsMVCCSnapshot(scan->xs_snapshot)) + elog(ERROR, "non-MVCC snapshots are not supported with hnsw"); + + /* Get scan value */ + value = GetScanValue(scan); + + /* + * Get a shared lock. This allows vacuum to ensure no in-flight scans + * before marking tuples as deleted. + */ + LockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + + so->w = GetScanItems(scan, value); + + /* Release shared lock */ + UnlockPage(scan->indexRelation, HNSW_SCAN_LOCK, ShareLock); + + so->first = false; + } + + while (list_length(so->w) > 0) + { + HnswCandidate *hc = llast(so->w); + ItemPointer heaptid; + + /* Move to next element if no valid heap TIDs */ + if (list_length(hc->element->heaptids) == 0) + { + so->w = list_delete_last(so->w); + continue; + } + + heaptid = llast(hc->element->heaptids); + + hc->element->heaptids = list_delete_last(hc->element->heaptids); + + MemoryContextSwitchTo(oldCtx); + +#if PG_VERSION_NUM >= 120000 + scan->xs_heaptid = *heaptid; +#else + scan->xs_ctup.t_self = *heaptid; +#endif + + scan->xs_recheckorderby = false; + return true; + } + + MemoryContextSwitchTo(oldCtx); + return false; +} + +/* + * End a scan and release resources + */ +void +hnswendscan(IndexScanDesc scan) +{ + HnswScanOpaque so = (HnswScanOpaque) scan->opaque; + + MemoryContextDelete(so->tmpCtx); + + pfree(so); + scan->opaque = NULL; +} diff --git a/src/hnswutils.c b/src/hnswutils.c new file mode 100644 index 0000000..e7d1705 --- /dev/null +++ b/src/hnswutils.c @@ -0,0 +1,1072 @@ +#include "postgres.h" + +#include + +#include "hnsw.h" +#include "storage/bufmgr.h" +#include "vector.h" + +/* + * Get the max number of connections in an upper layer for each element in the index + */ +int +HnswGetM(Relation index) +{ + HnswOptions *opts = (HnswOptions *) index->rd_options; + + if (opts) + return opts->m; + + return HNSW_DEFAULT_M; +} + +/* + * Get the size of the dynamic candidate list in the index + */ +int +HnswGetEfConstruction(Relation index) +{ + HnswOptions *opts = (HnswOptions *) index->rd_options; + + if (opts) + return opts->efConstruction; + + return HNSW_DEFAULT_EF_CONSTRUCTION; +} + +/* + * Get proc + */ +FmgrInfo * +HnswOptionalProcInfo(Relation index, uint16 procnum) +{ + if (!OidIsValid(index_getprocid(index, 1, procnum))) + return NULL; + + return index_getprocinfo(index, 1, procnum); +} + +/* + * Divide by the norm + * + * Returns false if value should not be indexed + * + * The caller needs to free the pointer stored in value + * if it's different than the original value + */ +bool +HnswNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result) +{ + double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); + + if (norm > 0) + { + Vector *v = DatumGetVector(*value); + + if (result == NULL) + result = InitVector(v->dim); + + for (int i = 0; i < v->dim; i++) + result->x[i] = v->x[i] / norm; + + *value = PointerGetDatum(result); + + return true; + } + + return false; +} + +/* + * New buffer + */ +Buffer +HnswNewBuffer(Relation index, ForkNumber forkNum) +{ + Buffer buf = ReadBufferExtended(index, forkNum, P_NEW, RBM_NORMAL, NULL); + + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + return buf; +} + +/* + * Init page + */ +void +HnswInitPage(Buffer buf, Page page) +{ + PageInit(page, BufferGetPageSize(buf), sizeof(HnswPageOpaqueData)); + HnswPageGetOpaque(page)->nextblkno = InvalidBlockNumber; + HnswPageGetOpaque(page)->page_id = HNSW_PAGE_ID; +} + +/* + * Init and register page + */ +void +HnswInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state) +{ + *state = GenericXLogStart(index); + *page = GenericXLogRegisterBuffer(*state, *buf, GENERIC_XLOG_FULL_IMAGE); + HnswInitPage(*buf, *page); +} + +/* + * Commit buffer + */ +void +HnswCommitBuffer(Buffer buf, GenericXLogState *state) +{ + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); +} + +/* + * Allocate neighbors + */ +void +HnswInitNeighbors(HnswElement element, int m) +{ + int level = element->level; + + element->neighbors = palloc(sizeof(HnswNeighborArray) * (level + 1)); + + for (int lc = 0; lc <= level; lc++) + { + HnswNeighborArray *a; + int lm = HnswGetLayerM(m, lc); + + a = &element->neighbors[lc]; + a->length = 0; + a->items = palloc(sizeof(HnswCandidate) * lm); + a->closerSet = false; + } +} + +/* + * Free neighbors + */ +static void +HnswFreeNeighbors(HnswElement element) +{ + for (int lc = 0; lc <= element->level; lc++) + pfree(element->neighbors[lc].items); + pfree(element->neighbors); +} + +/* + * Allocate an element + */ +HnswElement +HnswInitElement(ItemPointer heaptid, int m, double ml, int maxLevel) +{ + HnswElement element = palloc(sizeof(HnswElementData)); + + int level = (int) (-log(RandomDouble()) * ml); + + /* Cap level */ + if (level > maxLevel) + level = maxLevel; + + element->heaptids = NIL; + HnswAddHeapTid(element, heaptid); + + element->level = level; + element->deleted = 0; + + HnswInitNeighbors(element, m); + + return element; +} + +/* + * Free an element + */ +void +HnswFreeElement(HnswElement element) +{ + HnswFreeNeighbors(element); + list_free_deep(element->heaptids); + pfree(element->vec); + pfree(element); +} + +/* + * Add a heap TID to an element + */ +void +HnswAddHeapTid(HnswElement element, ItemPointer heaptid) +{ + ItemPointer copy = palloc(sizeof(ItemPointerData)); + + ItemPointerCopy(heaptid, copy); + element->heaptids = lappend(element->heaptids, copy); +} + +/* + * Allocate an element from block and offset numbers + */ +HnswElement +HnswInitElementFromBlock(BlockNumber blkno, OffsetNumber offno) +{ + HnswElement element = palloc(sizeof(HnswElementData)); + + element->blkno = blkno; + element->offno = offno; + element->neighbors = NULL; + element->vec = NULL; + return element; +} + +/* + * Get the metapage info + */ +void +HnswGetMetaPageInfo(Relation index, int *m, HnswElement * entryPoint) +{ + Buffer buf; + Page page; + HnswMetaPage metap; + + buf = ReadBuffer(index, HNSW_METAPAGE_BLKNO); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + metap = HnswPageGetMeta(page); + + if (m != NULL) + *m = metap->m; + + if (entryPoint != NULL) + { + if (BlockNumberIsValid(metap->entryBlkno)) + *entryPoint = HnswInitElementFromBlock(metap->entryBlkno, metap->entryOffno); + else + *entryPoint = NULL; + } + + UnlockReleaseBuffer(buf); +} + +/* + * Get the entry point + */ +HnswElement +HnswGetEntryPoint(Relation index) +{ + HnswElement entryPoint; + + HnswGetMetaPageInfo(index, NULL, &entryPoint); + + return entryPoint; +} + +/* + * Update the metapage info + */ +static void +HnswUpdateMetaPageInfo(Page page, int updateEntry, HnswElement entryPoint, BlockNumber insertPage) +{ + HnswMetaPage metap = HnswPageGetMeta(page); + + if (updateEntry) + { + if (entryPoint == NULL) + { + metap->entryBlkno = InvalidBlockNumber; + metap->entryOffno = InvalidOffsetNumber; + metap->entryLevel = -1; + } + else if (entryPoint->level > metap->entryLevel || updateEntry == HNSW_UPDATE_ENTRY_ALWAYS) + { + metap->entryBlkno = entryPoint->blkno; + metap->entryOffno = entryPoint->offno; + metap->entryLevel = entryPoint->level; + } + } + + if (BlockNumberIsValid(insertPage)) + metap->insertPage = insertPage; +} + +/* + * Update the metapage + */ +void +HnswUpdateMetaPage(Relation index, int updateEntry, HnswElement entryPoint, BlockNumber insertPage, ForkNumber forkNum) +{ + Buffer buf; + Page page; + GenericXLogState *state; + + buf = ReadBufferExtended(index, forkNum, HNSW_METAPAGE_BLKNO, RBM_NORMAL, NULL); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + HnswUpdateMetaPageInfo(page, updateEntry, entryPoint, insertPage); + + HnswCommitBuffer(buf, state); +} + +/* + * Set element tuple, except for neighbor info + */ +void +HnswSetElementTuple(HnswElementTuple etup, HnswElement element) +{ + etup->type = HNSW_ELEMENT_TUPLE_TYPE; + etup->level = element->level; + etup->deleted = 0; + for (int i = 0; i < HNSW_HEAPTIDS; i++) + { + if (i < list_length(element->heaptids)) + etup->heaptids[i] = *((ItemPointer) list_nth(element->heaptids, i)); + else + ItemPointerSetInvalid(&etup->heaptids[i]); + } + memcpy(&etup->vec, element->vec, VECTOR_SIZE(element->vec->dim)); +} + +/* + * Set neighbor tuple + */ +void +HnswSetNeighborTuple(HnswNeighborTuple ntup, HnswElement e, int m) +{ + int idx = 0; + + ntup->type = HNSW_NEIGHBOR_TUPLE_TYPE; + + for (int lc = e->level; lc >= 0; lc--) + { + HnswNeighborArray *neighbors = &e->neighbors[lc]; + int lm = HnswGetLayerM(m, lc); + + for (int i = 0; i < lm; i++) + { + ItemPointer indextid = &ntup->indextids[idx++]; + + if (i < neighbors->length) + { + HnswCandidate *hc = &neighbors->items[i]; + + ItemPointerSet(indextid, hc->element->blkno, hc->element->offno); + } + else + ItemPointerSetInvalid(indextid); + } + } + + ntup->count = idx; +} + +/* + * Load neighbors from page + */ +static void +LoadNeighborsFromPage(HnswElement element, Relation index, Page page, int m) +{ + HnswNeighborTuple ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); + int neighborCount = (element->level + 2) * m; + + Assert(HnswIsNeighborTuple(ntup)); + + HnswInitNeighbors(element, m); + + /* Ensure expected neighbors */ + if (ntup->count != neighborCount) + return; + + for (int i = 0; i < neighborCount; i++) + { + HnswElement e; + int level; + HnswCandidate *hc; + ItemPointer indextid; + HnswNeighborArray *neighbors; + + indextid = &ntup->indextids[i]; + + if (!ItemPointerIsValid(indextid)) + continue; + + e = HnswInitElementFromBlock(ItemPointerGetBlockNumber(indextid), ItemPointerGetOffsetNumber(indextid)); + + /* Calculate level based on offset */ + level = element->level - i / m; + if (level < 0) + level = 0; + + neighbors = &element->neighbors[level]; + hc = &neighbors->items[neighbors->length++]; + hc->element = e; + } +} + +/* + * Load neighbors + */ +void +HnswLoadNeighbors(HnswElement element, Relation index, int m) +{ + Buffer buf; + Page page; + + buf = ReadBuffer(index, element->neighborPage); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + + LoadNeighborsFromPage(element, index, page, m); + + UnlockReleaseBuffer(buf); +} + +/* + * Load an element from a tuple + */ +void +HnswLoadElementFromTuple(HnswElement element, HnswElementTuple etup, bool loadHeaptids, bool loadVec) +{ + element->level = etup->level; + element->deleted = etup->deleted; + element->neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); + element->neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); + element->heaptids = NIL; + + if (loadHeaptids) + { + for (int i = 0; i < HNSW_HEAPTIDS; i++) + { + /* Can stop at first invalid */ + if (!ItemPointerIsValid(&etup->heaptids[i])) + break; + + HnswAddHeapTid(element, &etup->heaptids[i]); + } + } + + if (loadVec) + { + element->vec = palloc(VECTOR_SIZE(etup->vec.dim)); + memcpy(element->vec, &etup->vec, VECTOR_SIZE(etup->vec.dim)); + } +} + +/* + * Load an element and optionally get its distance from q + */ +void +HnswLoadElement(HnswElement element, float *distance, Datum *q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +{ + Buffer buf; + Page page; + HnswElementTuple etup; + + /* Read vector */ + buf = ReadBuffer(index, element->blkno); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + + etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, element->offno)); + + Assert(HnswIsElementTuple(etup)); + + /* Load element */ + HnswLoadElementFromTuple(element, etup, true, loadVec); + + /* Calculate distance */ + if (distance != NULL) + *distance = (float) DatumGetFloat8(FunctionCall2Coll(procinfo, collation, *q, PointerGetDatum(&etup->vec))); + + UnlockReleaseBuffer(buf); +} + +/* + * Get the distance for a candidate + */ +static float +GetCandidateDistance(HnswCandidate * hc, Datum q, FmgrInfo *procinfo, Oid collation) +{ + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, q, PointerGetDatum(hc->element->vec))); +} + +/* + * Create a candidate for the entry point + */ +HnswCandidate * +HnswEntryCandidate(HnswElement entryPoint, Datum q, Relation index, FmgrInfo *procinfo, Oid collation, bool loadVec) +{ + HnswCandidate *hc = palloc(sizeof(HnswCandidate)); + + hc->element = entryPoint; + if (index == NULL) + hc->distance = GetCandidateDistance(hc, q, procinfo, collation); + else + HnswLoadElement(hc->element, &hc->distance, &q, index, procinfo, collation, loadVec); + return hc; +} + +/* + * Compare candidate distances + */ +static int +CompareNearestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) + return 1; + + if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) + return -1; + + return 0; +} + +/* + * Compare candidate distances + */ +static int +CompareFurthestCandidates(const pairingheap_node *a, const pairingheap_node *b, void *arg) +{ + if (((const HnswPairingHeapNode *) a)->inner->distance < ((const HnswPairingHeapNode *) b)->inner->distance) + return -1; + + if (((const HnswPairingHeapNode *) a)->inner->distance > ((const HnswPairingHeapNode *) b)->inner->distance) + return 1; + + return 0; +} + +/* + * Create a pairing heap node for a candidate + */ +static HnswPairingHeapNode * +CreatePairingHeapNode(HnswCandidate * c) +{ + HnswPairingHeapNode *node = palloc(sizeof(HnswPairingHeapNode)); + + node->inner = c; + return node; +} + +/* + * Add to visited + */ +static inline void +AddToVisited(HTAB *v, HnswCandidate * hc, Relation index, bool *found) +{ + if (index == NULL) + hash_search(v, &hc->element, HASH_ENTER, found); + else + { + ItemPointerData indextid; + + ItemPointerSet(&indextid, hc->element->blkno, hc->element->offno); + hash_search(v, &indextid, HASH_ENTER, found); + } +} + +/* + * Algorithm 2 from paper + */ +List * +HnswSearchLayer(Datum q, List *ep, int ef, int lc, Relation index, FmgrInfo *procinfo, Oid collation, int m, bool inserting, HnswElement skipElement) +{ + ListCell *lc2; + + List *w = NIL; + pairingheap *C = pairingheap_allocate(CompareNearestCandidates, NULL); + pairingheap *W = pairingheap_allocate(CompareFurthestCandidates, NULL); + int wlen = 0; + HASHCTL hash_ctl; + HTAB *v; + + /* Create hash table */ + if (index == NULL) + { + hash_ctl.keysize = sizeof(HnswElement *); + hash_ctl.entrysize = sizeof(HnswElement *); + } + else + { + hash_ctl.keysize = sizeof(ItemPointerData); + hash_ctl.entrysize = sizeof(ItemPointerData); + } + + hash_ctl.hcxt = CurrentMemoryContext; + v = hash_create("hnsw visited", 256, &hash_ctl, HASH_ELEM | HASH_BLOBS | HASH_CONTEXT); + + /* Add entry points to v, C, and W */ + foreach(lc2, ep) + { + HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); + + AddToVisited(v, hc, index, NULL); + + pairingheap_add(C, &(CreatePairingHeapNode(hc)->ph_node)); + pairingheap_add(W, &(CreatePairingHeapNode(hc)->ph_node)); + + /* + * Do not count elements being deleted towards ef when vacuuming. It + * would be ideal to do this for inserts as well, but this could + * affect insert performance. + */ + if (skipElement == NULL || list_length(hc->element->heaptids) != 0) + wlen++; + } + + while (!pairingheap_is_empty(C)) + { + HnswNeighborArray *neighborhood; + HnswCandidate *c = ((HnswPairingHeapNode *) pairingheap_remove_first(C))->inner; + HnswCandidate *f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; + + if (c->distance > f->distance) + break; + + if (c->element->neighbors == NULL) + HnswLoadNeighbors(c->element, index, m); + + /* Get the neighborhood at layer lc */ + neighborhood = &c->element->neighbors[lc]; + + for (int i = 0; i < neighborhood->length; i++) + { + HnswCandidate *e = &neighborhood->items[i]; + bool visited; + + AddToVisited(v, e, index, &visited); + + if (!visited) + { + float eDistance; + + f = ((HnswPairingHeapNode *) pairingheap_first(W))->inner; + + if (index == NULL) + eDistance = GetCandidateDistance(e, q, procinfo, collation); + else + HnswLoadElement(e->element, &eDistance, &q, index, procinfo, collation, inserting); + + Assert(!e->element->deleted); + + /* Make robust to issues */ + if (e->element->level < lc) + continue; + + if (eDistance < f->distance || wlen < ef) + { + /* Copy e */ + HnswCandidate *ec = palloc(sizeof(HnswCandidate)); + + ec->element = e->element; + ec->distance = eDistance; + + pairingheap_add(C, &(CreatePairingHeapNode(ec)->ph_node)); + pairingheap_add(W, &(CreatePairingHeapNode(ec)->ph_node)); + + /* + * Do not count elements being deleted towards ef when + * vacuuming. It would be ideal to do this for inserts as + * well, but this could affect insert performance. + */ + if (skipElement == NULL || list_length(e->element->heaptids) != 0) + { + wlen++; + + /* No need to decrement wlen */ + if (wlen > ef) + pairingheap_remove_first(W); + } + } + } + } + } + + /* Add each element of W to w */ + while (!pairingheap_is_empty(W)) + { + HnswCandidate *hc = ((HnswPairingHeapNode *) pairingheap_remove_first(W))->inner; + + w = lappend(w, hc); + } + + return w; +} + +/* + * Compare candidate distances + */ +static int +#if PG_VERSION_NUM >= 130000 +CompareCandidateDistances(const ListCell *a, const ListCell *b) +#else +CompareCandidateDistances(const void *a, const void *b) +#endif +{ + HnswCandidate *hca = lfirst((ListCell *) a); + HnswCandidate *hcb = lfirst((ListCell *) b); + + if (hca->distance < hcb->distance) + return 1; + + if (hca->distance > hcb->distance) + return -1; + + if (hca->element < hcb->element) + return 1; + + if (hca->element > hcb->element) + return -1; + + return 0; +} + +/* + * Calculate the distance between elements + */ +static float +HnswGetDistance(HnswElement a, HnswElement b, int lc, FmgrInfo *procinfo, Oid collation) +{ + /* Look for cached distance */ + if (a->neighbors != NULL) + { + Assert(a->level >= lc); + + for (int i = 0; i < a->neighbors[lc].length; i++) + { + if (a->neighbors[lc].items[i].element == b) + return a->neighbors[lc].items[i].distance; + } + } + + if (b->neighbors != NULL) + { + Assert(b->level >= lc); + + for (int i = 0; i < b->neighbors[lc].length; i++) + { + if (b->neighbors[lc].items[i].element == a) + return b->neighbors[lc].items[i].distance; + } + } + + return DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(a->vec), PointerGetDatum(b->vec))); +} + +/* + * Check if an element is closer to q than any element from R + */ +static bool +CheckElementCloser(HnswCandidate * e, List *r, int lc, FmgrInfo *procinfo, Oid collation) +{ + ListCell *lc2; + + foreach(lc2, r) + { + HnswCandidate *ri = lfirst(lc2); + float distance = HnswGetDistance(e->element, ri->element, lc, procinfo, collation); + + if (distance <= e->distance) + return false; + } + + return true; +} + +/* + * Algorithm 4 from paper + */ +static List * +SelectNeighbors(List *c, int m, int lc, FmgrInfo *procinfo, Oid collation, HnswElement e2, HnswCandidate * newCandidate, HnswCandidate * *pruned, bool sortCandidates) +{ + List *r = NIL; + List *w = list_copy(c); + pairingheap *wd; + bool mustCalculate = !e2->neighbors[lc].closerSet; + List *added = NIL; + bool removedAny = false; + + if (list_length(w) <= m) + return w; + + wd = pairingheap_allocate(CompareNearestCandidates, NULL); + + /* Ensure order of candidates is deterministic for closer caching */ + if (sortCandidates) + list_sort(w, CompareCandidateDistances); + + while (list_length(w) > 0 && list_length(r) < m) + { + /* Assumes w is already ordered desc */ + HnswCandidate *e = llast(w); + + w = list_delete_last(w); + + /* Use previous state of r and wd to skip work when possible */ + if (mustCalculate) + e->closer = CheckElementCloser(e, r, lc, procinfo, collation); + else if (list_length(added) > 0) + { + /* + * If the current candidate was closer, we only need to compare it + * with the other candidates that we have added. + */ + if (e->closer) + { + e->closer = CheckElementCloser(e, added, lc, procinfo, collation); + + if (!e->closer) + removedAny = true; + } + else + { + /* + * If we have removed any candidates from closer, a candidate + * that was not closer earlier might now be. + */ + if (removedAny) + { + e->closer = CheckElementCloser(e, r, lc, procinfo, collation); + if (e->closer) + added = lappend(added, e); + } + } + } + else if (e == newCandidate) + { + e->closer = CheckElementCloser(e, r, lc, procinfo, collation); + if (e->closer) + added = lappend(added, e); + } + + if (e->closer) + r = lappend(r, e); + else + pairingheap_add(wd, &(CreatePairingHeapNode(e)->ph_node)); + } + + /* Cached value can only be used in future if sorted deterministically */ + e2->neighbors[lc].closerSet = sortCandidates; + + /* Keep pruned connections */ + while (!pairingheap_is_empty(wd) && list_length(r) < m) + r = lappend(r, ((HnswPairingHeapNode *) pairingheap_remove_first(wd))->inner); + + /* Return pruned for update connections */ + if (pruned != NULL) + { + if (!pairingheap_is_empty(wd)) + *pruned = ((HnswPairingHeapNode *) pairingheap_first(wd))->inner; + else + *pruned = linitial(w); + } + + return r; +} + +/* + * Find duplicate element + */ +HnswElement +HnswFindDuplicate(HnswElement e) +{ + HnswNeighborArray *neighbors = &e->neighbors[0]; + + for (int i = 0; i < neighbors->length; i++) + { + HnswCandidate *neighbor = &neighbors->items[i]; + + /* Exit early since ordered by distance */ + if (vector_cmp_internal(e->vec, neighbor->element->vec) != 0) + break; + + /* Check for space */ + if (list_length(neighbor->element->heaptids) < HNSW_HEAPTIDS) + return neighbor->element; + } + + return NULL; +} + +/* + * Add connections + */ +static void +AddConnections(HnswElement element, List *neighbors, int m, int lc) +{ + ListCell *lc2; + HnswNeighborArray *a = &element->neighbors[lc]; + + foreach(lc2, neighbors) + a->items[a->length++] = *((HnswCandidate *) lfirst(lc2)); +} + +/* + * Update connections + */ +void +HnswUpdateConnection(HnswElement element, HnswCandidate * hc, int m, int lc, int *updateIdx, Relation index, FmgrInfo *procinfo, Oid collation) +{ + HnswNeighborArray *currentNeighbors = &hc->element->neighbors[lc]; + + HnswCandidate hc2; + + hc2.element = element; + hc2.distance = hc->distance; + + if (currentNeighbors->length < m) + { + currentNeighbors->items[currentNeighbors->length++] = hc2; + + /* Track update */ + if (updateIdx != NULL) + *updateIdx = -2; + } + else + { + /* Shrink connections */ + HnswCandidate *pruned = NULL; + + /* Load elements on insert */ + if (index != NULL) + { + Datum q = PointerGetDatum(hc->element->vec); + + for (int i = 0; i < currentNeighbors->length; i++) + { + HnswCandidate *hc3 = ¤tNeighbors->items[i]; + + if (hc3->element->vec == NULL) + HnswLoadElement(hc3->element, &hc3->distance, &q, index, procinfo, collation, true); + else + hc3->distance = GetCandidateDistance(hc3, q, procinfo, collation); + + /* Prune element if being deleted */ + if (list_length(hc3->element->heaptids) == 0) + { + pruned = ¤tNeighbors->items[i]; + break; + } + } + } + + if (pruned == NULL) + { + List *c = NIL; + + /* Add candidates */ + for (int i = 0; i < currentNeighbors->length; i++) + c = lappend(c, ¤tNeighbors->items[i]); + c = lappend(c, &hc2); + + SelectNeighbors(c, m, lc, procinfo, collation, hc->element, &hc2, &pruned, true); + + /* Should not happen */ + if (pruned == NULL) + return; + } + + /* Find and replace the pruned element */ + for (int i = 0; i < currentNeighbors->length; i++) + { + if (currentNeighbors->items[i].element == pruned->element) + { + currentNeighbors->items[i] = hc2; + + /* Track update */ + if (updateIdx != NULL) + *updateIdx = i; + + break; + } + } + } +} + +/* + * Remove elements being deleted or skipped + */ +static List * +RemoveElements(List *w, HnswElement skipElement) +{ + ListCell *lc2; + List *w2 = NIL; + + foreach(lc2, w) + { + HnswCandidate *hc = (HnswCandidate *) lfirst(lc2); + + /* Skip self for vacuuming update */ + if (skipElement != NULL && hc->element->blkno == skipElement->blkno && hc->element->offno == skipElement->offno) + continue; + + if (list_length(hc->element->heaptids) != 0) + w2 = lappend(w2, hc); + } + + return w2; +} + +/* + * Algorithm 1 from paper + */ +void +HnswInsertElement(HnswElement element, HnswElement entryPoint, Relation index, FmgrInfo *procinfo, Oid collation, int m, int efConstruction, bool existing) +{ + List *ep; + List *w; + int level = element->level; + int entryLevel; + Datum q = PointerGetDatum(element->vec); + HnswElement skipElement = existing ? element : NULL; + + /* No neighbors if no entry point */ + if (entryPoint == NULL) + return; + + /* Get entry point and level */ + ep = list_make1(HnswEntryCandidate(entryPoint, q, index, procinfo, collation, true)); + entryLevel = entryPoint->level; + + /* 1st phase: greedy search to insert level */ + for (int lc = entryLevel; lc >= level + 1; lc--) + { + w = HnswSearchLayer(q, ep, 1, lc, index, procinfo, collation, m, true, skipElement); + ep = w; + } + + if (level > entryLevel) + level = entryLevel; + + /* Add one for existing element */ + if (existing) + efConstruction++; + + /* 2nd phase */ + for (int lc = level; lc >= 0; lc--) + { + int lm = HnswGetLayerM(m, lc); + List *neighbors; + List *lw; + + w = HnswSearchLayer(q, ep, efConstruction, lc, index, procinfo, collation, m, true, skipElement); + + /* Elements being deleted or skipped can help with search */ + /* but should be removed before selecting neighbors */ + if (index != NULL) + lw = RemoveElements(w, skipElement); + else + lw = w; + + /* + * Candidates are sorted, but not deterministically. Could set + * sortCandidates to true for in-memory builds to enable closer + * caching, but there does not seem to be a difference in performance. + */ + neighbors = SelectNeighbors(lw, lm, lc, procinfo, collation, element, NULL, NULL, false); + + AddConnections(element, neighbors, lm, lc); + + ep = w; + } +} diff --git a/src/hnswvacuum.c b/src/hnswvacuum.c new file mode 100644 index 0000000..29b675f --- /dev/null +++ b/src/hnswvacuum.c @@ -0,0 +1,660 @@ +#include "postgres.h" + +#include + +#include "commands/vacuum.h" +#include "hnsw.h" +#include "storage/bufmgr.h" +#include "storage/lmgr.h" +#include "utils/memutils.h" + +/* + * Check if deleted list contains an index TID + */ +static bool +DeletedContains(HTAB *deleted, ItemPointer indextid) +{ + bool found; + + hash_search(deleted, indextid, HASH_FIND, &found); + return found; +} + +/* + * Remove deleted heap TIDs + * + * OK to remove for entry point, since always considered for searches and inserts + */ +static void +RemoveHeapTids(HnswVacuumState * vacuumstate) +{ + BlockNumber blkno = HNSW_HEAD_BLKNO; + HnswElement highestPoint = &vacuumstate->highestPoint; + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + HnswElement entryPoint = HnswGetEntryPoint(vacuumstate->index); + IndexBulkDeleteResult *stats = vacuumstate->stats; + + /* Store separately since highestPoint.level is uint8 */ + int highestLevel = -1; + + /* Initialize highest point */ + highestPoint->blkno = InvalidBlockNumber; + highestPoint->offno = InvalidOffsetNumber; + + while (BlockNumberIsValid(blkno)) + { + Buffer buf; + Page page; + GenericXLogState *state; + OffsetNumber offno; + OffsetNumber maxoffno; + bool updated = false; + + vacuum_delay_point(); + + buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + maxoffno = PageGetMaxOffsetNumber(page); + + /* Iterate over nodes */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + int idx = 0; + bool itemUpdated = false; + + /* Skip neighbor tuples */ + if (!HnswIsElementTuple(etup)) + continue; + + if (ItemPointerIsValid(&etup->heaptids[0])) + { + for (int i = 0; i < HNSW_HEAPTIDS; i++) + { + /* Stop at first unused */ + if (!ItemPointerIsValid(&etup->heaptids[i])) + break; + + if (vacuumstate->callback(&etup->heaptids[i], vacuumstate->callback_state)) + { + itemUpdated = true; + stats->tuples_removed++; + } + else + { + /* Move to front of list */ + etup->heaptids[idx++] = etup->heaptids[i]; + stats->num_index_tuples++; + } + } + + if (itemUpdated) + { + Size etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim); + + /* Mark rest as invalid */ + for (int i = idx; i < HNSW_HEAPTIDS; i++) + ItemPointerSetInvalid(&etup->heaptids[i]); + + if (!PageIndexTupleOverwrite(page, offno, (Item) etup, etupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + updated = true; + } + } + + if (!ItemPointerIsValid(&etup->heaptids[0])) + { + ItemPointerData ip; + + /* Add to deleted list */ + ItemPointerSet(&ip, blkno, offno); + + (void) hash_search(vacuumstate->deleted, &ip, HASH_ENTER, NULL); + } + else if (etup->level > highestLevel && !(entryPoint != NULL && blkno == entryPoint->blkno && offno == entryPoint->offno)) + { + /* Keep track of highest non-entry point */ + highestPoint->blkno = blkno; + highestPoint->offno = offno; + highestPoint->level = etup->level; + highestLevel = etup->level; + } + } + + blkno = HnswPageGetOpaque(page)->nextblkno; + + if (updated) + GenericXLogFinish(state); + else + GenericXLogAbort(state); + + UnlockReleaseBuffer(buf); + } +} + +/* + * Check for deleted neighbors + */ +static bool +NeedsUpdated(HnswVacuumState * vacuumstate, HnswElement element) +{ + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + Buffer buf; + Page page; + HnswNeighborTuple ntup; + bool needsUpdated = false; + + buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + ntup = (HnswNeighborTuple) PageGetItem(page, PageGetItemId(page, element->neighborOffno)); + + Assert(HnswIsNeighborTuple(ntup)); + + /* Check neighbors */ + for (int i = 0; i < ntup->count; i++) + { + ItemPointer indextid = &ntup->indextids[i]; + + if (!ItemPointerIsValid(indextid)) + continue; + + /* Check if in deleted list */ + if (DeletedContains(vacuumstate->deleted, indextid)) + { + needsUpdated = true; + break; + } + } + + /* Also update if layer 0 is not full */ + /* This could indicate too many candidates being deleted during insert */ + if (!needsUpdated) + needsUpdated = !ItemPointerIsValid(&ntup->indextids[ntup->count - 1]); + + UnlockReleaseBuffer(buf); + + return needsUpdated; +} + +/* + * Repair graph for a single element + */ +static void +RepairGraphElement(HnswVacuumState * vacuumstate, HnswElement element, HnswElement entryPoint) +{ + Relation index = vacuumstate->index; + Buffer buf; + Page page; + GenericXLogState *state; + int m = vacuumstate->m; + int efConstruction = vacuumstate->efConstruction; + FmgrInfo *procinfo = vacuumstate->procinfo; + Oid collation = vacuumstate->collation; + BufferAccessStrategy bas = vacuumstate->bas; + HnswNeighborTuple ntup = vacuumstate->ntup; + Size ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(element->level, m); + + /* Skip if element is entry point */ + if (entryPoint != NULL && element->blkno == entryPoint->blkno && element->offno == entryPoint->offno) + return; + + /* Init fields */ + HnswInitNeighbors(element, m); + element->heaptids = NIL; + + /* Add element to graph, skipping itself */ + HnswInsertElement(element, entryPoint, index, procinfo, collation, m, efConstruction, true); + + /* Update neighbor tuple */ + /* Do this before getting page to minimize locking */ + HnswSetNeighborTuple(ntup, element, m); + + /* Get neighbor page */ + buf = ReadBufferExtended(index, MAIN_FORKNUM, element->neighborPage, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + + /* Overwrite tuple */ + if (!PageIndexTupleOverwrite(page, element->neighborOffno, (Item) ntup, ntupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + GenericXLogFinish(state); + UnlockReleaseBuffer(buf); + + /* Update neighbors */ + HnswUpdateNeighborPages(index, procinfo, collation, element, m, true); +} + +/* + * Repair graph entry point + */ +static void +RepairGraphEntryPoint(HnswVacuumState * vacuumstate) +{ + Relation index = vacuumstate->index; + HnswElement highestPoint = &vacuumstate->highestPoint; + HnswElement entryPoint; + MemoryContext oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx); + + if (!BlockNumberIsValid(highestPoint->blkno)) + highestPoint = NULL; + + /* + * Repair graph for highest non-entry point. Highest point may be outdated + * due to inserts that happen during and after RemoveHeapTids. + */ + if (highestPoint != NULL) + { + /* Get a shared lock */ + LockPage(index, HNSW_UPDATE_LOCK, ShareLock); + + /* Load element */ + HnswLoadElement(highestPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); + + /* Repair if needed */ + if (NeedsUpdated(vacuumstate, highestPoint)) + RepairGraphElement(vacuumstate, highestPoint, HnswGetEntryPoint(index)); + + /* Release lock */ + UnlockPage(index, HNSW_UPDATE_LOCK, ShareLock); + } + + /* Prevent concurrent inserts when possibly updating entry point */ + LockPage(index, HNSW_UPDATE_LOCK, ExclusiveLock); + + /* Get latest entry point */ + entryPoint = HnswGetEntryPoint(index); + + if (entryPoint != NULL) + { + ItemPointerData epData; + + ItemPointerSet(&epData, entryPoint->blkno, entryPoint->offno); + + if (DeletedContains(vacuumstate->deleted, &epData)) + { + /* + * Replace the entry point with the highest point. If highest + * point is outdated and empty, the entry point will be empty + * until an element is repaired. + */ + HnswUpdateMetaPage(index, HNSW_UPDATE_ENTRY_ALWAYS, highestPoint, InvalidBlockNumber, MAIN_FORKNUM); + } + else + { + /* + * Repair the entry point with the highest point. If highest point + * is outdated, this can remove connections at higher levels in + * the graph until they are repaired, but this should be fine. + */ + HnswLoadElement(entryPoint, NULL, NULL, index, vacuumstate->procinfo, vacuumstate->collation, true); + + if (NeedsUpdated(vacuumstate, entryPoint)) + { + /* Reset neighbors from previous update */ + if (highestPoint != NULL) + highestPoint->neighbors = NULL; + + RepairGraphElement(vacuumstate, entryPoint, highestPoint); + } + } + } + + /* Release lock */ + UnlockPage(index, HNSW_UPDATE_LOCK, ExclusiveLock); + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(vacuumstate->tmpCtx); +} + +/* + * Repair graph for all elements + */ +static void +RepairGraph(HnswVacuumState * vacuumstate) +{ + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + BlockNumber blkno = HNSW_HEAD_BLKNO; + + /* + * Wait for inserts to complete. Inserts before this point may have + * neighbors about to be deleted. Inserts after this point will not. + */ + LockPage(index, HNSW_UPDATE_LOCK, ExclusiveLock); + UnlockPage(index, HNSW_UPDATE_LOCK, ExclusiveLock); + + /* Repair entry point first */ + RepairGraphEntryPoint(vacuumstate); + + while (BlockNumberIsValid(blkno)) + { + Buffer buf; + Page page; + OffsetNumber offno; + OffsetNumber maxoffno; + List *elements = NIL; + ListCell *lc2; + MemoryContext oldCtx; + + vacuum_delay_point(); + + oldCtx = MemoryContextSwitchTo(vacuumstate->tmpCtx); + + buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + maxoffno = PageGetMaxOffsetNumber(page); + + /* Load items into memory to minimize locking */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + HnswElement element; + + /* Skip neighbor tuples */ + if (!HnswIsElementTuple(etup)) + continue; + + /* Skip updating neighbors if being deleted */ + if (!ItemPointerIsValid(&etup->heaptids[0])) + continue; + + /* Create an element */ + element = HnswInitElementFromBlock(blkno, offno); + HnswLoadElementFromTuple(element, etup, false, true); + + elements = lappend(elements, element); + } + + blkno = HnswPageGetOpaque(page)->nextblkno; + + UnlockReleaseBuffer(buf); + + /* Update neighbor pages */ + foreach(lc2, elements) + { + HnswElement element = (HnswElement) lfirst(lc2); + HnswElement entryPoint; + LOCKMODE lockmode = ShareLock; + + /* Check if any neighbors point to deleted values */ + if (!NeedsUpdated(vacuumstate, element)) + continue; + + /* Get a shared lock */ + LockPage(index, HNSW_UPDATE_LOCK, lockmode); + + /* Refresh entry point for each element */ + entryPoint = HnswGetEntryPoint(index); + + /* Prevent concurrent inserts when likely updating entry point */ + if (entryPoint == NULL || element->level > entryPoint->level) + { + /* Release shared lock */ + UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); + + /* Get exclusive lock */ + lockmode = ExclusiveLock; + LockPage(index, HNSW_UPDATE_LOCK, lockmode); + + /* Get latest entry point after lock is acquired */ + entryPoint = HnswGetEntryPoint(index); + } + + /* Repair connections */ + RepairGraphElement(vacuumstate, element, entryPoint); + + /* + * Update metapage if needed. Should only happen if entry point + * was replaced and highest point was outdated. + */ + if (entryPoint == NULL || element->level > entryPoint->level) + HnswUpdateMetaPage(index, HNSW_UPDATE_ENTRY_GREATER, element, InvalidBlockNumber, MAIN_FORKNUM); + + /* Release lock */ + UnlockPage(index, HNSW_UPDATE_LOCK, lockmode); + } + + /* Reset memory context */ + MemoryContextSwitchTo(oldCtx); + MemoryContextReset(vacuumstate->tmpCtx); + } +} + +/* + * Mark items as deleted + */ +static void +MarkDeleted(HnswVacuumState * vacuumstate) +{ + BlockNumber blkno = HNSW_HEAD_BLKNO; + BlockNumber insertPage = InvalidBlockNumber; + Relation index = vacuumstate->index; + BufferAccessStrategy bas = vacuumstate->bas; + + /* + * Wait for index scans to complete. Scans before this point may contain + * tuples about to be deleted. Scans after this point will not, since the + * graph has been repaired. + */ + LockPage(index, HNSW_SCAN_LOCK, ExclusiveLock); + UnlockPage(index, HNSW_SCAN_LOCK, ExclusiveLock); + + while (BlockNumberIsValid(blkno)) + { + Buffer buf; + Page page; + GenericXLogState *state; + OffsetNumber offno; + OffsetNumber maxoffno; + + vacuum_delay_point(); + + buf = ReadBufferExtended(index, MAIN_FORKNUM, blkno, RBM_NORMAL, bas); + + /* + * ambulkdelete cannot delete entries from pages that are pinned by + * other backends + * + * https://www.postgresql.org/docs/current/index-locking.html + */ + LockBufferForCleanup(buf); + + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + maxoffno = PageGetMaxOffsetNumber(page); + + /* Update element and neighbors together */ + for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + { + HnswElementTuple etup = (HnswElementTuple) PageGetItem(page, PageGetItemId(page, offno)); + HnswNeighborTuple ntup; + Size etupSize; + Size ntupSize; + Buffer nbuf; + Page npage; + BlockNumber neighborPage; + OffsetNumber neighborOffno; + + /* Skip neighbor tuples */ + if (!HnswIsElementTuple(etup)) + continue; + + /* Skip deleted tuples */ + if (etup->deleted) + { + /* Set to first free page */ + if (!BlockNumberIsValid(insertPage)) + insertPage = blkno; + + continue; + } + + /* Skip live tuples */ + if (ItemPointerIsValid(&etup->heaptids[0])) + continue; + + /* Calculate sizes */ + etupSize = HNSW_ELEMENT_TUPLE_SIZE(etup->vec.dim); + ntupSize = HNSW_NEIGHBOR_TUPLE_SIZE(etup->level, vacuumstate->m); + + /* Get neighbor page */ + neighborPage = ItemPointerGetBlockNumber(&etup->neighbortid); + neighborOffno = ItemPointerGetOffsetNumber(&etup->neighbortid); + + if (neighborPage == blkno) + { + nbuf = buf; + npage = page; + } + else + { + nbuf = ReadBufferExtended(index, MAIN_FORKNUM, neighborPage, RBM_NORMAL, bas); + LockBuffer(nbuf, BUFFER_LOCK_EXCLUSIVE); + npage = GenericXLogRegisterBuffer(state, nbuf, 0); + } + + ntup = (HnswNeighborTuple) PageGetItem(npage, PageGetItemId(npage, neighborOffno)); + + /* Overwrite element */ + etup->deleted = 1; + MemSet(&etup->vec.x, 0, etup->vec.dim * sizeof(float)); + + /* Overwrite neighbors */ + for (int i = 0; i < ntup->count; i++) + ItemPointerSetInvalid(&ntup->indextids[i]); + + /* Overwrite element tuple */ + if (!PageIndexTupleOverwrite(page, offno, (Item) etup, etupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Overwrite neighbor tuple */ + if (!PageIndexTupleOverwrite(npage, neighborOffno, (Item) ntup, ntupSize)) + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); + + /* Commit */ + GenericXLogFinish(state); + if (nbuf != buf) + UnlockReleaseBuffer(nbuf); + + /* Set to first free page */ + if (!BlockNumberIsValid(insertPage)) + insertPage = blkno; + + /* Prepare new xlog */ + state = GenericXLogStart(index); + page = GenericXLogRegisterBuffer(state, buf, 0); + } + + blkno = HnswPageGetOpaque(page)->nextblkno; + + GenericXLogAbort(state); + UnlockReleaseBuffer(buf); + } + + /* Update insert page last, after everything has been marked as deleted */ + HnswUpdateMetaPage(index, 0, NULL, insertPage, MAIN_FORKNUM); +} + +/* + * Initialize the vacuum state + */ +static void +InitVacuumState(HnswVacuumState * vacuumstate, IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state) +{ + Relation index = info->index; + HASHCTL hash_ctl; + + if (stats == NULL) + stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult)); + + vacuumstate->index = index; + vacuumstate->stats = stats; + vacuumstate->callback = callback; + vacuumstate->callback_state = callback_state; + vacuumstate->efConstruction = HnswGetEfConstruction(index); + vacuumstate->bas = GetAccessStrategy(BAS_BULKREAD); + vacuumstate->procinfo = index_getprocinfo(index, 1, HNSW_DISTANCE_PROC); + vacuumstate->collation = index->rd_indcollation[0]; + vacuumstate->ntup = palloc0(BLCKSZ); + vacuumstate->tmpCtx = AllocSetContextCreate(CurrentMemoryContext, + "Hnsw vacuum temporary context", + ALLOCSET_DEFAULT_SIZES); + + /* Get m from metapage */ + HnswGetMetaPageInfo(index, &vacuumstate->m, NULL); + + /* Create hash table */ + hash_ctl.keysize = sizeof(ItemPointerData); + hash_ctl.entrysize = sizeof(ItemPointerData); + hash_ctl.hcxt = CurrentMemoryContext; + vacuumstate->deleted = hash_create("hnswbulkdelete indextids", 256, &hash_ctl, HASH_ELEM | HASH_BLOBS | HASH_CONTEXT); +} + +/* + * Free resources + */ +static void +FreeVacuumState(HnswVacuumState * vacuumstate) +{ + hash_destroy(vacuumstate->deleted); + FreeAccessStrategy(vacuumstate->bas); + pfree(vacuumstate->ntup); + MemoryContextDelete(vacuumstate->tmpCtx); +} + +/* + * Bulk delete tuples from the index + */ +IndexBulkDeleteResult * +hnswbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, + IndexBulkDeleteCallback callback, void *callback_state) +{ + HnswVacuumState vacuumstate; + + InitVacuumState(&vacuumstate, info, stats, callback, callback_state); + + /* Pass 1: Remove heap TIDs */ + RemoveHeapTids(&vacuumstate); + + /* Pass 2: Repair graph */ + RepairGraph(&vacuumstate); + + /* Pass 3: Mark as deleted */ + MarkDeleted(&vacuumstate); + + FreeVacuumState(&vacuumstate); + + return vacuumstate.stats; +} + +/* + * Clean up after a VACUUM operation + */ +IndexBulkDeleteResult * +hnswvacuumcleanup(IndexVacuumInfo *info, IndexBulkDeleteResult *stats) +{ + Relation rel = info->index; + + if (info->analyze_only) + return stats; + + /* stats is NULL if ambulkdelete not called */ + /* OK to return NULL if index not changed */ + if (stats == NULL) + return NULL; + + stats->num_pages = RelationGetNumberOfBlocks(rel); + + return stats; +} diff --git a/src/ivfbuild.c b/src/ivfbuild.c index a96581f..3915177 100644 --- a/src/ivfbuild.c +++ b/src/ivfbuild.c @@ -2,11 +2,16 @@ #include +#include "access/parallel.h" +#include "access/xact.h" #include "catalog/index.h" #include "cdb/cdbvars.h" +#include "catalog/pg_operator_d.h" +#include "catalog/pg_type_d.h" #include "ivfflat.h" #include "miscadmin.h" #include "storage/bufmgr.h" +#include "tcop/tcopprot.h" #include "utils/memutils.h" #if PG_VERSION_NUM >= 140000 @@ -24,9 +29,6 @@ #define PROGRESS_CREATEIDX_TUPLES_DONE 0 #endif -#include "catalog/pg_operator_d.h" -#include "catalog/pg_type_d.h" - #if PG_VERSION_NUM >= 130000 #define CALLBACK_ITEM_POINTER ItemPointer tid #else @@ -39,6 +41,25 @@ #define UpdateProgress(index, val) ((void)val) #endif +#if PG_VERSION_NUM >= 140000 +#include "utils/backend_status.h" +#include "utils/wait_event.h" +#endif + +#if PG_VERSION_NUM >= 120000 +#include "access/table.h" +#include "optimizer/optimizer.h" +#else +#include "access/heapam.h" +#include "optimizer/planner.h" +#include "pgstat.h" +#endif + +#define PARALLEL_KEY_IVFFLAT_SHARED UINT64CONST(0xA000000000000001) +#define PARALLEL_KEY_TUPLESORT UINT64CONST(0xA000000000000002) +#define PARALLEL_KEY_IVFFLAT_CENTERS UINT64CONST(0xA000000000000003) +#define PARALLEL_KEY_QUERY_TEXT UINT64CONST(0xA000000000000004) + /* * Add sample */ @@ -151,7 +172,6 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState int closestCenter = 0; VectorArray centers = buildstate->centers; TupleTableSlot *slot = buildstate->slot; - int i; /* Detoast once for all calls */ Datum value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); @@ -164,7 +184,7 @@ AddTupleToSort(Relation index, ItemPointer tid, Datum *values, IvfflatBuildState } /* Find the list that minimizes the distance */ - for (i = 0; i < centers->length; i++) + for (int i = 0; i < centers->length; i++) { distance = DatumGetFloat8(FunctionCall2Coll(buildstate->procinfo, buildstate->collation, value, PointerGetDatum(VectorArrayGet(centers, i)))); @@ -259,15 +279,8 @@ GetNextTuple(Tuplesortstate *sortstate, TupleDesc tupdesc, TupleTableSlot *slot, static void InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum) { - Buffer buf; - Page page; - GenericXLogState *state; int list; IndexTuple itup = NULL; /* silence compiler warning */ - BlockNumber startPage; - BlockNumber insertPage; - Size itemsz; - int i; int64 inserted = 0; #if PG_VERSION_NUM >= 120000 @@ -283,8 +296,14 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum) GetNextTuple(buildstate->sortstate, tupdesc, slot, &itup, &list); - for (i = 0; i < buildstate->centers->length; i++) + for (int i = 0; i < buildstate->centers->length; i++) { + Buffer buf; + Page page; + GenericXLogState *state; + BlockNumber startPage; + BlockNumber insertPage; + /* Can take a while, so ensure we can interrupt */ /* Needs to be called when no buffer locks are held */ CHECK_FOR_INTERRUPTS(); @@ -298,7 +317,8 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum) while (list == i) { /* Check for free space */ - itemsz = MAXALIGN(IndexTupleSize(itup)); + Size itemsz = MAXALIGN(IndexTupleSize(itup)); + if (PageGetFreeSpace(page) < itemsz) IvfflatAppendPage(index, &buf, &page, &state, forkNum); @@ -318,7 +338,7 @@ InsertTuples(Relation index, IvfflatBuildState * buildstate, ForkNumber forkNum) IvfflatCommitBuffer(buf, state); /* Set the start and insert pages */ - IvfflatUpdateList(index, state, buildstate->listInfo[i], insertPage, InvalidBlockNumber, startPage, forkNum); + IvfflatUpdateList(index, buildstate->listInfo[i], insertPage, InvalidBlockNumber, startPage, forkNum); } } @@ -352,9 +372,7 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->collation = index->rd_indcollation[0]; /* Require more than one dimension for spherical k-means */ - /* Lists check for backwards compatibility */ - /* TODO Remove lists check in 0.3.0 */ - if (buildstate->kmeansnormprocinfo != NULL && buildstate->dimensions == 1 && buildstate->lists > 1) + if (buildstate->kmeansnormprocinfo != NULL && buildstate->dimensions == 1) elog(ERROR, "dimensions must be greater than one for this opclass"); /* Create tuple description for sorting */ @@ -388,6 +406,8 @@ InitBuildState(IvfflatBuildState * buildstate, Relation heap, Relation index, In buildstate->listSums = palloc0(sizeof(double) * buildstate->lists); buildstate->listCounts = palloc0(sizeof(int) * buildstate->lists); #endif + + buildstate->ivfleader = NULL; } /* @@ -484,33 +504,33 @@ static void CreateListPages(Relation index, VectorArray centers, int dimensions, int lists, ForkNumber forkNum, ListInfo * *listInfo) { - int i; Buffer buf; Page page; GenericXLogState *state; - OffsetNumber offno; - Size itemsz; + Size listSize; IvfflatList list; - itemsz = MAXALIGN(IVFFLAT_LIST_SIZE(dimensions)); - list = palloc(itemsz); + listSize = MAXALIGN(IVFFLAT_LIST_SIZE(dimensions)); + list = palloc(listSize); buf = IvfflatNewBuffer(index, forkNum); IvfflatInitRegisterPage(index, &buf, &page, &state); - for (i = 0; i < lists; i++) + for (int i = 0; i < lists; i++) { + OffsetNumber offno; + /* Load list */ list->startPage = InvalidBlockNumber; list->insertPage = InvalidBlockNumber; memcpy(&list->center, VectorArrayGet(centers, i), VECTOR_SIZE(dimensions)); /* Ensure free space */ - if (PageGetFreeSpace(page) < itemsz) + if (PageGetFreeSpace(page) < listSize) IvfflatAppendPage(index, &buf, &page, &state, forkNum); /* Add the item */ - offno = PageAddItem(page, (Item) list, itemsz, InvalidOffsetNumber, false, false); + offno = PageAddItem(page, (Item) list, listSize, InvalidOffsetNumber, false, false); if (offno == InvalidOffsetNumber) elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); @@ -534,7 +554,7 @@ PrintKmeansMetrics(IvfflatBuildState * buildstate) elog(INFO, "inertia: %.3e", buildstate->inertia); /* Calculate Davies-Bouldin index */ - if (buildstate->lists > 1) + if (buildstate->lists > 1 && !buildstate->ivfleader) { double db = 0.0; @@ -570,49 +590,477 @@ PrintKmeansMetrics(IvfflatBuildState * buildstate) #endif /* - * Scan table for tuples to index + * Within leader, wait for end of heap scan + */ +static double +ParallelHeapScan(IvfflatBuildState * buildstate) +{ + IvfflatShared *ivfshared = buildstate->ivfleader->ivfshared; + int nparticipanttuplesorts; + double reltuples; + + nparticipanttuplesorts = buildstate->ivfleader->nparticipanttuplesorts; + for (;;) + { + SpinLockAcquire(&ivfshared->mutex); + if (ivfshared->nparticipantsdone == nparticipanttuplesorts) + { + buildstate->indtuples = ivfshared->indtuples; + reltuples = ivfshared->reltuples; +#ifdef IVFFLAT_KMEANS_DEBUG + buildstate->inertia = ivfshared->inertia; +#endif + SpinLockRelease(&ivfshared->mutex); + break; + } + SpinLockRelease(&ivfshared->mutex); + + ConditionVariableSleep(&ivfshared->workersdonecv, + WAIT_EVENT_PARALLEL_CREATE_INDEX_SCAN); + } + + ConditionVariableCancelSleep(); + + return reltuples; +} + +/* + * Perform a worker's portion of a parallel sort */ static void -ScanTable(IvfflatBuildState * buildstate) +IvfflatParallelScanAndSort(IvfflatSpool * ivfspool, IvfflatShared * ivfshared, Sharedsort *sharedsort, Vector * ivfcenters, int sortmem, bool progress) { + SortCoordinate coordinate; + IvfflatBuildState buildstate; #if PG_VERSION_NUM >= 120000 - buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, - true, true, BuildCallback, (void *) buildstate, NULL); + TableScanDesc scan; #else - buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo, - true, BuildCallback, (void *) buildstate, NULL); + HeapScanDesc scan; #endif + double reltuples; + IndexInfo *indexInfo; + + /* Sort options, which must match AssignTuples */ + AttrNumber attNums[] = {1}; + Oid sortOperators[] = {Int4LessOperator}; + Oid sortCollations[] = {InvalidOid}; + bool nullsFirstFlags[] = {false}; + + /* Initialize local tuplesort coordination state */ + coordinate = palloc0(sizeof(SortCoordinateData)); + coordinate->isWorker = true; + coordinate->nParticipants = -1; + coordinate->sharedsort = sharedsort; + + /* Join parallel scan */ + indexInfo = BuildIndexInfo(ivfspool->index); + indexInfo->ii_Concurrent = ivfshared->isconcurrent; + InitBuildState(&buildstate, ivfspool->heap, ivfspool->index, indexInfo); + memcpy(buildstate.centers->items, ivfcenters, VECTOR_SIZE(buildstate.centers->dim) * buildstate.centers->maxlen); + buildstate.centers->length = buildstate.centers->maxlen; + ivfspool->sortstate = tuplesort_begin_heap(buildstate.tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, sortmem, coordinate, false); + buildstate.sortstate = ivfspool->sortstate; +#if PG_VERSION_NUM >= 120000 + scan = table_beginscan_parallel(ivfspool->heap, + ParallelTableScanFromIvfflatShared(ivfshared)); + reltuples = table_index_build_scan(ivfspool->heap, ivfspool->index, indexInfo, + true, progress, BuildCallback, + (void *) &buildstate, scan); +#else + scan = heap_beginscan_parallel(ivfspool->heap, &ivfshared->heapdesc); + reltuples = IndexBuildHeapScan(ivfspool->heap, ivfspool->index, indexInfo, + true, BuildCallback, + (void *) &buildstate, scan); +#endif + + /* Execute this worker's part of the sort */ + tuplesort_performsort(ivfspool->sortstate); + + /* Record statistics */ + SpinLockAcquire(&ivfshared->mutex); + ivfshared->nparticipantsdone++; + ivfshared->reltuples += reltuples; + ivfshared->indtuples += buildstate.indtuples; +#ifdef IVFFLAT_KMEANS_DEBUG + ivfshared->inertia += buildstate.inertia; +#endif + SpinLockRelease(&ivfshared->mutex); + + /* Log statistics */ + if (progress) + ereport(DEBUG1, (errmsg("leader processed " INT64_FORMAT " tuples", (int64) reltuples))); + else + ereport(DEBUG1, (errmsg("worker processed " INT64_FORMAT " tuples", (int64) reltuples))); + + /* Notify leader */ + ConditionVariableSignal(&ivfshared->workersdonecv); + + /* We can end tuplesorts immediately */ + tuplesort_end(ivfspool->sortstate); + + FreeBuildState(&buildstate); } /* - * Create entry pages + * Perform work within a launched parallel process + */ +void +IvfflatParallelBuildMain(dsm_segment *seg, shm_toc *toc) +{ + char *sharedquery; + IvfflatSpool *ivfspool; + IvfflatShared *ivfshared; + Sharedsort *sharedsort; + Vector *ivfcenters; + Relation heapRel; + Relation indexRel; + LOCKMODE heapLockmode; + LOCKMODE indexLockmode; + int sortmem; + + /* Set debug_query_string for individual workers first */ + sharedquery = shm_toc_lookup(toc, PARALLEL_KEY_QUERY_TEXT, true); + debug_query_string = sharedquery; + + /* Report the query string from leader */ + pgstat_report_activity(STATE_RUNNING, debug_query_string); + + /* Look up shared state */ + ivfshared = shm_toc_lookup(toc, PARALLEL_KEY_IVFFLAT_SHARED, false); + + /* Open relations using lock modes known to be obtained by index.c */ + if (!ivfshared->isconcurrent) + { + heapLockmode = ShareLock; + indexLockmode = AccessExclusiveLock; + } + else + { + heapLockmode = ShareUpdateExclusiveLock; + indexLockmode = RowExclusiveLock; + } + + /* Open relations within worker */ +#if PG_VERSION_NUM >= 120000 + heapRel = table_open(ivfshared->heaprelid, heapLockmode); +#else + heapRel = heap_open(ivfshared->heaprelid, heapLockmode); +#endif + indexRel = index_open(ivfshared->indexrelid, indexLockmode); + + /* Initialize worker's own spool */ + ivfspool = (IvfflatSpool *) palloc0(sizeof(IvfflatSpool)); + ivfspool->heap = heapRel; + ivfspool->index = indexRel; + + /* Look up shared state private to tuplesort.c */ + sharedsort = shm_toc_lookup(toc, PARALLEL_KEY_TUPLESORT, false); + tuplesort_attach_shared(sharedsort, seg); + + ivfcenters = shm_toc_lookup(toc, PARALLEL_KEY_IVFFLAT_CENTERS, false); + + /* Perform sorting */ + sortmem = maintenance_work_mem / ivfshared->scantuplesortstates; + IvfflatParallelScanAndSort(ivfspool, ivfshared, sharedsort, ivfcenters, sortmem, false); + + /* Close relations within worker */ + index_close(indexRel, indexLockmode); +#if PG_VERSION_NUM >= 120000 + table_close(heapRel, heapLockmode); +#else + heap_close(heapRel, heapLockmode); +#endif +} + +/* + * End parallel build */ static void -CreateEntryPages(IvfflatBuildState * buildstate, ForkNumber forkNum) +IvfflatEndParallel(IvfflatLeader * ivfleader) +{ + /* Shutdown worker processes */ + WaitForParallelWorkersToFinish(ivfleader->pcxt); + + /* Free last reference to MVCC snapshot, if one was used */ + if (IsMVCCSnapshot(ivfleader->snapshot)) + UnregisterSnapshot(ivfleader->snapshot); + DestroyParallelContext(ivfleader->pcxt); + ExitParallelMode(); +} + +/* + * Return size of shared memory required for parallel index build + */ +static Size +ParallelEstimateShared(Relation heap, Snapshot snapshot) +{ +#if PG_VERSION_NUM >= 120000 + return add_size(BUFFERALIGN(sizeof(IvfflatShared)), table_parallelscan_estimate(heap, snapshot)); +#else + if (!IsMVCCSnapshot(snapshot)) + { + Assert(snapshot == SnapshotAny); + return sizeof(IvfflatShared); + } + + return add_size(offsetof(IvfflatShared, heapdesc) + + offsetof(ParallelHeapScanDescData, phs_snapshot_data), + EstimateSnapshotSpace(snapshot)); +#endif +} + +/* + * Within leader, participate as a parallel worker + */ +static void +IvfflatLeaderParticipateAsWorker(IvfflatBuildState * buildstate) { + IvfflatLeader *ivfleader = buildstate->ivfleader; + IvfflatSpool *leaderworker; + int sortmem; + + /* Allocate memory and initialize private spool */ + leaderworker = (IvfflatSpool *) palloc0(sizeof(IvfflatSpool)); + leaderworker->heap = buildstate->heap; + leaderworker->index = buildstate->index; + + /* Perform work common to all participants */ + sortmem = maintenance_work_mem / ivfleader->nparticipanttuplesorts; + IvfflatParallelScanAndSort(leaderworker, ivfleader->ivfshared, + ivfleader->sharedsort, ivfleader->ivfcenters, + sortmem, true); +} + +/* + * Begin parallel build + */ +static void +IvfflatBeginParallel(IvfflatBuildState * buildstate, bool isconcurrent, int request) +{ + ParallelContext *pcxt; + int scantuplesortstates; + Snapshot snapshot; + Size estivfshared; + Size estsort; + Size estcenters; + IvfflatShared *ivfshared; + Sharedsort *sharedsort; + Vector *ivfcenters; + IvfflatLeader *ivfleader = (IvfflatLeader *) palloc0(sizeof(IvfflatLeader)); + bool leaderparticipates = true; + int querylen; + +#ifdef DISABLE_LEADER_PARTICIPATION + leaderparticipates = false; +#endif + + /* Enter parallel mode and create context */ + EnterParallelMode(); + Assert(request > 0); +#if PG_VERSION_NUM >= 120000 + pcxt = CreateParallelContext("vector", "IvfflatParallelBuildMain", request); +#else + pcxt = CreateParallelContext("vector", "IvfflatParallelBuildMain", request, true); +#endif + + scantuplesortstates = leaderparticipates ? request + 1 : request; + + /* Get snapshot for table scan */ + if (!isconcurrent) + snapshot = SnapshotAny; + else + snapshot = RegisterSnapshot(GetTransactionSnapshot()); + + /* Estimate size of workspaces */ + estivfshared = ParallelEstimateShared(buildstate->heap, snapshot); + shm_toc_estimate_chunk(&pcxt->estimator, estivfshared); + estsort = tuplesort_estimate_shared(scantuplesortstates); + shm_toc_estimate_chunk(&pcxt->estimator, estsort); + estcenters = VECTOR_SIZE(buildstate->dimensions) * buildstate->lists; + shm_toc_estimate_chunk(&pcxt->estimator, estcenters); + shm_toc_estimate_keys(&pcxt->estimator, 3); + + /* Finally, estimate PARALLEL_KEY_QUERY_TEXT space */ + if (debug_query_string) + { + querylen = strlen(debug_query_string); + shm_toc_estimate_chunk(&pcxt->estimator, querylen + 1); + shm_toc_estimate_keys(&pcxt->estimator, 1); + } + else + querylen = 0; /* keep compiler quiet */ + + /* Everyone's had a chance to ask for space, so now create the DSM */ + InitializeParallelDSM(pcxt); + + /* If no DSM segment was available, back out (do serial build) */ + if (pcxt->seg == NULL) + { + if (IsMVCCSnapshot(snapshot)) + UnregisterSnapshot(snapshot); + DestroyParallelContext(pcxt); + ExitParallelMode(); + return; + } + + /* Store shared build state, for which we reserved space */ + ivfshared = (IvfflatShared *) shm_toc_allocate(pcxt->toc, estivfshared); + /* Initialize immutable state */ + ivfshared->heaprelid = RelationGetRelid(buildstate->heap); + ivfshared->indexrelid = RelationGetRelid(buildstate->index); + ivfshared->isconcurrent = isconcurrent; + ivfshared->scantuplesortstates = scantuplesortstates; + ConditionVariableInit(&ivfshared->workersdonecv); + SpinLockInit(&ivfshared->mutex); + /* Initialize mutable state */ + ivfshared->nparticipantsdone = 0; + ivfshared->reltuples = 0; + ivfshared->indtuples = 0; +#ifdef IVFFLAT_KMEANS_DEBUG + ivfshared->inertia = 0; +#endif +#if PG_VERSION_NUM >= 120000 + table_parallelscan_initialize(buildstate->heap, + ParallelTableScanFromIvfflatShared(ivfshared), + snapshot); +#else + heap_parallelscan_initialize(&ivfshared->heapdesc, buildstate->heap, snapshot); +#endif + + /* Store shared tuplesort-private state, for which we reserved space */ + sharedsort = (Sharedsort *) shm_toc_allocate(pcxt->toc, estsort); + tuplesort_initialize_shared(sharedsort, scantuplesortstates, + pcxt->seg); + + ivfcenters = (Vector *) shm_toc_allocate(pcxt->toc, estcenters); + memcpy(ivfcenters, buildstate->centers->items, estcenters); + + shm_toc_insert(pcxt->toc, PARALLEL_KEY_IVFFLAT_SHARED, ivfshared); + shm_toc_insert(pcxt->toc, PARALLEL_KEY_TUPLESORT, sharedsort); + shm_toc_insert(pcxt->toc, PARALLEL_KEY_IVFFLAT_CENTERS, ivfcenters); + + /* Store query string for workers */ + if (debug_query_string) + { + char *sharedquery; + + sharedquery = (char *) shm_toc_allocate(pcxt->toc, querylen + 1); + memcpy(sharedquery, debug_query_string, querylen + 1); + shm_toc_insert(pcxt->toc, PARALLEL_KEY_QUERY_TEXT, sharedquery); + } + + /* Launch workers, saving status for leader/caller */ + LaunchParallelWorkers(pcxt); + ivfleader->pcxt = pcxt; + ivfleader->nparticipanttuplesorts = pcxt->nworkers_launched; + if (leaderparticipates) + ivfleader->nparticipanttuplesorts++; + ivfleader->ivfshared = ivfshared; + ivfleader->sharedsort = sharedsort; + ivfleader->snapshot = snapshot; + ivfleader->ivfcenters = ivfcenters; + + /* If no workers were successfully launched, back out (do serial build) */ + if (pcxt->nworkers_launched == 0) + { + IvfflatEndParallel(ivfleader); + return; + } + + /* Log participants */ + ereport(DEBUG1, (errmsg("using %d parallel workers", pcxt->nworkers_launched))); + + /* Save leader state now that it's clear build will be parallel */ + buildstate->ivfleader = ivfleader; + + /* Join heap scan ourselves */ + if (leaderparticipates) + IvfflatLeaderParticipateAsWorker(buildstate); + + /* Wait for all launched workers */ + WaitForParallelWorkersToAttach(pcxt); +} + +/* + * Scan table for tuples to index + */ +static void +AssignTuples(IvfflatBuildState * buildstate) +{ + int parallel_workers = 0; + SortCoordinate coordinate = NULL; + + /* Sort options, which must match IvfflatParallelScanAndSort */ AttrNumber attNums[] = {1}; Oid sortOperators[] = {Int4LessOperator}; Oid sortCollations[] = {InvalidOid}; bool nullsFirstFlags[] = {false}; - UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_SORT); + UpdateProgress(PROGRESS_CREATEIDX_SUBPHASE, PROGRESS_IVFFLAT_PHASE_ASSIGN); + + /* Calculate parallel workers */ + if (buildstate->heap != NULL) + parallel_workers = plan_create_index_workers(RelationGetRelid(buildstate->heap), RelationGetRelid(buildstate->index)); - buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, NULL, false); + /* Attempt to launch parallel worker scan when required */ + if (parallel_workers > 0) + IvfflatBeginParallel(buildstate, buildstate->indexInfo->ii_Concurrent, parallel_workers); + + /* Set up coordination state if at least one worker launched */ + if (buildstate->ivfleader) + { + coordinate = (SortCoordinate) palloc0(sizeof(SortCoordinateData)); + coordinate->isWorker = false; + coordinate->nParticipants = buildstate->ivfleader->nparticipanttuplesorts; + coordinate->sharedsort = buildstate->ivfleader->sharedsort; + } + + /* Begin serial/leader tuplesort */ + buildstate->sortstate = tuplesort_begin_heap(buildstate->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, maintenance_work_mem, coordinate, false); /* Add tuples to sort */ if (buildstate->heap != NULL) - IvfflatBench("assign tuples", ScanTable(buildstate)); - - /* Sort */ - IvfflatBench("sort tuples", tuplesort_performsort(buildstate->sortstate)); + { + if (buildstate->ivfleader) + buildstate->reltuples = ParallelHeapScan(buildstate); + else + { +#if PG_VERSION_NUM >= 120000 + buildstate->reltuples = table_index_build_scan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, true, BuildCallback, (void *) buildstate, NULL); +#else + buildstate->reltuples = IndexBuildHeapScan(buildstate->heap, buildstate->index, buildstate->indexInfo, + true, BuildCallback, (void *) buildstate, NULL); +#endif + } #ifdef IVFFLAT_KMEANS_DEBUG - PrintKmeansMetrics(buildstate); + PrintKmeansMetrics(buildstate); #endif + } +} + +/* + * Create entry pages + */ +static void +CreateEntryPages(IvfflatBuildState * buildstate, ForkNumber forkNum) +{ + /* Assign */ + IvfflatBench("assign tuples", AssignTuples(buildstate)); - /* Insert */ + /* Sort */ + IvfflatBench("sort tuples", tuplesort_performsort(buildstate->sortstate)); + + /* Load */ IvfflatBench("load tuples", InsertTuples(buildstate->index, buildstate, forkNum)); + + /* End sort */ tuplesort_end(buildstate->sortstate); + + /* End parallel build */ + if (buildstate->ivfleader) + IvfflatEndParallel(buildstate->ivfleader); } /* diff --git a/src/ivfflat.c b/src/ivfflat.c index ef926a8..d6383f4 100644 --- a/src/ivfflat.c +++ b/src/ivfflat.c @@ -20,11 +20,11 @@ static relopt_kind ivfflat_relopt_kind; * Initialize index options and variables */ void -_PG_init(void) +IvfflatInit(void) { ivfflat_relopt_kind = add_reloption_kind(); add_int_reloption(ivfflat_relopt_kind, "lists", "Number of inverted lists", - IVFFLAT_DEFAULT_LISTS, 1, IVFFLAT_MAX_LISTS + IVFFLAT_DEFAULT_LISTS, IVFFLAT_MIN_LISTS, IVFFLAT_MAX_LISTS #if PG_VERSION_NUM >= 130000 ,AccessExclusiveLock #endif @@ -32,7 +32,7 @@ _PG_init(void) DefineCustomIntVariable("ivfflat.probes", "Sets the number of probes", "Valid range is 1..lists.", &ivfflat_probes, - 1, 1, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL); + IVFFLAT_DEFAULT_PROBES, IVFFLAT_MIN_LISTS, IVFFLAT_MAX_LISTS, PGC_USERSET, 0, NULL, NULL, NULL); } /* @@ -48,8 +48,8 @@ ivfflatbuildphasename(int64 phasenum) return "initializing"; case PROGRESS_IVFFLAT_PHASE_KMEANS: return "performing k-means"; - case PROGRESS_IVFFLAT_PHASE_SORT: - return "sorting tuples"; + case PROGRESS_IVFFLAT_PHASE_ASSIGN: + return "assigning tuples"; case PROGRESS_IVFFLAT_PHASE_LOAD: return "loading tuples"; default: @@ -71,7 +71,7 @@ ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, int lists; double ratio; double spc_seq_page_cost; - Relation indexRel; + Relation index; #if PG_VERSION_NUM < 120000 List *qinfos; #endif @@ -89,9 +89,9 @@ ivfflatcostestimate(PlannerInfo *root, IndexPath *path, double loop_count, MemSet(&costs, 0, sizeof(costs)); - indexRel = index_open(path->indexinfo->indexoid, NoLock); - lists = IvfflatGetLists(indexRel); - index_close(indexRel, NoLock); + index = index_open(path->indexinfo->indexoid, NoLock); + IvfflatGetMetaPageInfo(index, &lists, NULL); + index_close(index, NoLock); /* Get the ratio of lists that we need to visit */ ratio = ((double) ivfflat_probes) / lists; diff --git a/src/ivfflat.h b/src/ivfflat.h index 5bd7622..1eb35b0 100644 --- a/src/ivfflat.h +++ b/src/ivfflat.h @@ -3,14 +3,11 @@ #include "postgres.h" -#if PG_VERSION_NUM < 110000 -#error "Requires PostgreSQL 11+" -#endif - #include "access/generic_xlog.h" +#include "access/parallel.h" #include "access/reloptions.h" #include "nodes/execnodes.h" -#include "port.h" /* for strtof() and random() */ +#include "port.h" /* for random() */ #include "utils/sampling.h" #include "utils/tuplesort.h" #include "vector.h" @@ -19,6 +16,10 @@ #include "common/pg_prng.h" #endif +#if PG_VERSION_NUM < 120000 +#include "access/relscan.h" +#endif + #ifdef IVFFLAT_BENCH #include "portability/instr_time.h" #endif @@ -39,13 +40,16 @@ #define IVFFLAT_METAPAGE_BLKNO 0 #define IVFFLAT_HEAD_BLKNO 1 /* first list page */ +/* IVFFlat parameters */ #define IVFFLAT_DEFAULT_LISTS 100 +#define IVFFLAT_MIN_LISTS 1 #define IVFFLAT_MAX_LISTS 32768 +#define IVFFLAT_DEFAULT_PROBES 1 /* Build phases */ /* PROGRESS_CREATEIDX_SUBPHASE_INITIALIZE is 1 */ #define PROGRESS_IVFFLAT_PHASE_KMEANS 2 -#define PROGRESS_IVFFLAT_PHASE_SORT 3 +#define PROGRESS_IVFFLAT_PHASE_ASSIGN 3 #define PROGRESS_IVFFLAT_PHASE_LOAD 4 #define IVFFLAT_LIST_SIZE(_dim) (offsetof(IvfflatListData, center) + VECTOR_SIZE(_dim)) @@ -79,9 +83,6 @@ /* Variables */ extern int ivfflat_probes; -/* Exported functions */ -PGDLLEXPORT void _PG_init(void); - typedef struct VectorArrayData { int length; @@ -105,6 +106,56 @@ typedef struct IvfflatOptions int lists; /* number of lists */ } IvfflatOptions; +typedef struct IvfflatSpool +{ + Tuplesortstate *sortstate; + Relation heap; + Relation index; +} IvfflatSpool; + +typedef struct IvfflatShared +{ + /* Immutable state */ + Oid heaprelid; + Oid indexrelid; + bool isconcurrent; + int scantuplesortstates; + + /* Worker progress */ + ConditionVariable workersdonecv; + + /* Mutex for mutable state */ + slock_t mutex; + + /* Mutable state */ + int nparticipantsdone; + double reltuples; + double indtuples; + +#ifdef IVFFLAT_KMEANS_DEBUG + double inertia; +#endif + +#if PG_VERSION_NUM < 120000 + ParallelHeapScanDescData heapdesc; /* must come last */ +#endif +} IvfflatShared; + +#if PG_VERSION_NUM >= 120000 +#define ParallelTableScanFromIvfflatShared(shared) \ + (ParallelTableScanDesc) ((char *) (shared) + BUFFERALIGN(sizeof(IvfflatShared))) +#endif + +typedef struct IvfflatLeader +{ + ParallelContext *pcxt; + int nparticipanttuplesorts; + IvfflatShared *ivfshared; + Sharedsort *sharedsort; + Snapshot snapshot; + Vector *ivfcenters; +} IvfflatLeader; + typedef struct IvfflatBuildState { /* Info */ @@ -150,6 +201,9 @@ typedef struct IvfflatBuildState /* Memory */ MemoryContext tmpCtx; + + /* Parallel builds */ + IvfflatLeader *ivfleader; } IvfflatBuildState; typedef struct IvfflatMetaPageData @@ -190,8 +244,8 @@ typedef struct IvfflatScanList typedef struct IvfflatScanOpaqueData { int probes; + int dimensions; bool first; - Buffer buf; /* Sorting */ Tuplesortstate *sortstate; @@ -221,15 +275,18 @@ VectorArray VectorArrayInit(int maxlen, int dimensions); void VectorArrayFree(VectorArray arr); void PrintVectorArray(char *msg, VectorArray arr); void IvfflatKmeans(Relation index, VectorArray samples, VectorArray centers); -FmgrInfo *IvfflatOptionalProcInfo(Relation rel, uint16 procnum); +FmgrInfo *IvfflatOptionalProcInfo(Relation index, uint16 procnum); bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result); int IvfflatGetLists(Relation index); -void IvfflatUpdateList(Relation index, GenericXLogState *state, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum); +void IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions); +void IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum); void IvfflatCommitBuffer(Buffer buf, GenericXLogState *state); void IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state, ForkNumber forkNum); Buffer IvfflatNewBuffer(Relation index, ForkNumber forkNum); void IvfflatInitPage(Buffer buf, Page page); void IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogState **state); +void IvfflatInit(void); +PGDLLEXPORT void IvfflatParallelBuildMain(dsm_segment *seg, shm_toc *toc); /* Index access methods */ IndexBuildResult *ivfflatbuild(Relation heap, Relation index, IndexInfo *indexInfo); diff --git a/src/ivfinsert.c b/src/ivfinsert.c index 8761f6a..103fe49 100644 --- a/src/ivfinsert.c +++ b/src/ivfinsert.c @@ -4,42 +4,44 @@ #include "ivfflat.h" #include "storage/bufmgr.h" +#include "storage/lmgr.h" #include "utils/memutils.h" /* * Find the list that minimizes the distance function */ static void -FindInsertPage(Relation rel, Datum *values, BlockNumber *insertPage, ListInfo * listInfo) +FindInsertPage(Relation index, Datum *values, BlockNumber *insertPage, ListInfo * listInfo) { - Buffer cbuf; - Page cpage; - IvfflatList list; - double distance; double minDistance = DBL_MAX; BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO; FmgrInfo *procinfo; Oid collation; - OffsetNumber offno; - OffsetNumber maxoffno; /* Avoid compiler warning */ listInfo->blkno = nextblkno; listInfo->offno = FirstOffsetNumber; - procinfo = index_getprocinfo(rel, 1, IVFFLAT_DISTANCE_PROC); - collation = rel->rd_indcollation[0]; + procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); + collation = index->rd_indcollation[0]; /* Search all list pages */ while (BlockNumberIsValid(nextblkno)) { - cbuf = ReadBuffer(rel, nextblkno); + Buffer cbuf; + Page cpage; + OffsetNumber maxoffno; + + cbuf = ReadBuffer(index, nextblkno); LockBuffer(cbuf, BUFFER_LOCK_SHARE); cpage = BufferGetPage(cbuf); maxoffno = PageGetMaxOffsetNumber(cpage); - for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + for (OffsetNumber offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { + IvfflatList list; + double distance; + list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno)); distance = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, values[0], PointerGetDatum(&list->center))); @@ -62,7 +64,7 @@ FindInsertPage(Relation rel, Datum *values, BlockNumber *insertPage, ListInfo * * Insert a tuple into the index */ static void -InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel) +InsertTuple(Relation index, Datum *values, bool *isnull, ItemPointer heap_tid, Relation heapRel) { IndexTuple itup; Datum value; @@ -79,33 +81,33 @@ InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Rel value = PointerGetDatum(PG_DETOAST_DATUM(values[0])); /* Normalize if needed */ - normprocinfo = IvfflatOptionalProcInfo(rel, IVFFLAT_NORM_PROC); + normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); if (normprocinfo != NULL) { - if (!IvfflatNormValue(normprocinfo, rel->rd_indcollation[0], &value, NULL)) + if (!IvfflatNormValue(normprocinfo, index->rd_indcollation[0], &value, NULL)) return; } /* Find the insert page - sets the page and list info */ - FindInsertPage(rel, values, &insertPage, &listInfo); + FindInsertPage(index, values, &insertPage, &listInfo); Assert(BlockNumberIsValid(insertPage)); originalInsertPage = insertPage; /* Form tuple */ - itup = index_form_tuple(RelationGetDescr(rel), &value, isnull); + itup = index_form_tuple(RelationGetDescr(index), &value, isnull); itup->t_tid = *heap_tid; /* Get tuple size */ itemsz = MAXALIGN(IndexTupleSize(itup)); - Assert(itemsz <= BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(IvfflatPageOpaqueData))); + Assert(itemsz <= BLCKSZ - MAXALIGN(SizeOfPageHeaderData) - MAXALIGN(sizeof(IvfflatPageOpaqueData)) - sizeof(ItemIdData)); /* Find a page to insert the item */ for (;;) { - buf = ReadBuffer(rel, insertPage); + buf = ReadBuffer(index, insertPage); LockBuffer(buf, BUFFER_LOCK_EXCLUSIVE); - state = GenericXLogStart(rel); + state = GenericXLogStart(index); page = GenericXLogRegisterBuffer(state, buf, 0); if (PageGetFreeSpace(page) >= itemsz) @@ -121,23 +123,16 @@ InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Rel } else { - Buffer metabuf; Buffer newbuf; Page newpage; - /* - * From ReadBufferExtended: Caller is responsible for ensuring - * that only one backend tries to extend a relation at the same - * time! - */ - metabuf = ReadBuffer(rel, IVFFLAT_METAPAGE_BLKNO); - LockBuffer(metabuf, BUFFER_LOCK_EXCLUSIVE); - /* Add a new page */ - newbuf = IvfflatNewBuffer(rel, MAIN_FORKNUM); - newpage = GenericXLogRegisterBuffer(state, newbuf, GENERIC_XLOG_FULL_IMAGE); + LockRelationForExtension(index, ExclusiveLock); + newbuf = IvfflatNewBuffer(index, MAIN_FORKNUM); + UnlockRelationForExtension(index, ExclusiveLock); /* Init new page */ + newpage = GenericXLogRegisterBuffer(state, newbuf, GENERIC_XLOG_FULL_IMAGE); IvfflatInitPage(newbuf, newpage); /* Update insert page */ @@ -147,18 +142,13 @@ InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Rel IvfflatPageGetOpaque(page)->nextblkno = insertPage; /* Commit */ - MarkBufferDirty(newbuf); - MarkBufferDirty(buf); GenericXLogFinish(state); - /* Unlock extend relation lock as early as possible */ - UnlockReleaseBuffer(metabuf); - /* Unlock previous buffer */ UnlockReleaseBuffer(buf); /* Prepare new buffer */ - state = GenericXLogStart(rel); + state = GenericXLogStart(index); buf = newbuf; page = GenericXLogRegisterBuffer(state, buf, 0); break; @@ -167,13 +157,13 @@ InsertTuple(Relation rel, Datum *values, bool *isnull, ItemPointer heap_tid, Rel /* Add to next offset */ if (PageAddItem(page, (Item) itup, itemsz, InvalidOffsetNumber, false, false) == InvalidOffsetNumber) - elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(rel)); + elog(ERROR, "failed to add index item to \"%s\"", RelationGetRelationName(index)); IvfflatCommitBuffer(buf, state); /* Update the insert page */ if (insertPage != originalInsertPage) - IvfflatUpdateList(rel, state, listInfo, insertPage, originalInsertPage, InvalidBlockNumber, MAIN_FORKNUM); + IvfflatUpdateList(index, listInfo, insertPage, originalInsertPage, InvalidBlockNumber, MAIN_FORKNUM); } /* diff --git a/src/ivfkmeans.c b/src/ivfkmeans.c index eb94de0..a87edcb 100644 --- a/src/ivfkmeans.c +++ b/src/ivfkmeans.c @@ -16,12 +16,7 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low { FmgrInfo *procinfo; Oid collation; - int i; int64 j; - double distance; - double sum; - double choice; - Vector *vec; float *weight = palloc(samples->length * sizeof(float)); int numCenters = centers->maxlen; int numSamples = samples->length; @@ -34,17 +29,21 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low centers->length++; for (j = 0; j < numSamples; j++) - weight[j] = DBL_MAX; + weight[j] = FLT_MAX; - for (i = 0; i < numCenters; i++) + for (int i = 0; i < numCenters; i++) { + double sum; + double choice; + CHECK_FOR_INTERRUPTS(); sum = 0.0; for (j = 0; j < numSamples; j++) { - vec = VectorArrayGet(samples, j); + Vector *vec = VectorArrayGet(samples, j); + double distance; /* Only need to compute distance for new center */ /* TODO Use triangle inequality to reduce distance calculations */ @@ -88,13 +87,12 @@ InitCenters(Relation index, VectorArray samples, VectorArray centers, float *low static inline void ApplyNorm(FmgrInfo *normprocinfo, Oid collation, Vector * vec) { - int i; double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(vec))); /* TODO Handle zero norm */ if (norm > 0) { - for (i = 0; i < vec->dim; i++) + for (int i = 0; i < vec->dim; i++) vec->x[i] /= norm; } } @@ -114,9 +112,6 @@ CompareVectors(const void *a, const void *b) static void QuickCenters(Relation index, VectorArray samples, VectorArray centers) { - int i; - int j; - Vector *vec; int dimensions = centers->dim; Oid collation = index->rd_indcollation[0]; FmgrInfo *normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_KMEANS_NORM_PROC); @@ -125,9 +120,9 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) if (samples->length > 0) { qsort(samples->items, samples->length, VECTOR_SIZE(samples->dim), CompareVectors); - for (i = 0; i < samples->length; i++) + for (int i = 0; i < samples->length; i++) { - vec = VectorArrayGet(samples, i); + Vector *vec = VectorArrayGet(samples, i); if (i == 0 || CompareVectors(vec, VectorArrayGet(samples, i - 1)) != 0) { @@ -140,12 +135,12 @@ QuickCenters(Relation index, VectorArray samples, VectorArray centers) /* Fill remaining with random data */ while (centers->length < centers->maxlen) { - vec = VectorArrayGet(centers, centers->length); + Vector *vec = VectorArrayGet(centers, centers->length); SET_VARSIZE(vec, VECTOR_SIZE(dimensions)); vec->dim = dimensions; - for (j = 0; j < dimensions; j++) + for (int j = 0; j < dimensions; j++) vec->x[j] = RandomDouble(); /* Normalize if needed (only needed for random centers) */ @@ -172,7 +167,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) Oid collation; Vector *vec; Vector *newCenter; - int iteration; int64 j; int64 k; int dimensions = centers->dim; @@ -186,14 +180,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) float *s; float *halfcdist; float *newcdist; - int changes; - double minDistance; - int closestCenter; - double distance; - bool rj; - bool rjreset; - double dxcx; - double dxc; /* Calculate allocation sizes */ Size samplesSize = VECTOR_ARRAY_SIZE(samples->maxlen, samples->dim); @@ -251,14 +237,14 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) /* Assign each x to its closest initial center c(x) = argmin d(x,c) */ for (j = 0; j < numSamples; j++) { - minDistance = DBL_MAX; - closestCenter = 0; + float minDistance = FLT_MAX; + int closestCenter = 0; /* Find closest center */ for (k = 0; k < numCenters; k++) { /* TODO Use Lemma 1 in k-means++ initialization */ - distance = lowerBound[j * numCenters + k]; + float distance = lowerBound[j * numCenters + k]; if (distance < minDistance) { @@ -272,13 +258,14 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) } /* Give 500 iterations to converge */ - for (iteration = 0; iteration < 500; iteration++) + for (int iteration = 0; iteration < 500; iteration++) { + int changes = 0; + bool rjreset; + /* Can take a while, so ensure we can interrupt */ CHECK_FOR_INTERRUPTS(); - changes = 0; - /* Step 1: For all centers, compute distance */ for (j = 0; j < numCenters; j++) { @@ -286,7 +273,8 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) for (k = j + 1; k < numCenters; k++) { - distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k)))); + float distance = 0.5 * DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k)))); + halfcdist[j * numCenters + k] = distance; halfcdist[k * numCenters + j] = distance; } @@ -295,10 +283,12 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) /* For all centers c, compute s(c) */ for (j = 0; j < numCenters; j++) { - minDistance = DBL_MAX; + float minDistance = FLT_MAX; for (k = 0; k < numCenters; k++) { + float distance; + if (j == k) continue; @@ -314,6 +304,8 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) for (j = 0; j < numSamples; j++) { + bool rj; + /* Step 2: Identify all points x such that u(x) <= s(c(x)) */ if (upperBound[j] <= s[closestCenters[j]]) continue; @@ -322,6 +314,8 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) for (k = 0; k < numCenters; k++) { + float dxcx; + /* Step 3: For all remaining points x and centers c */ if (k == closestCenters[j]) continue; @@ -351,7 +345,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) /* Step 3b */ if (dxcx > lowerBound[j * numCenters + k] || dxcx > halfcdist[closestCenters[j] * numCenters + k]) { - dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k)))); + float dxc = DatumGetFloat8(FunctionCall2Coll(procinfo, collation, PointerGetDatum(vec), PointerGetDatum(VectorArrayGet(centers, k)))); /* d(x,c) calculated */ lowerBound[j * numCenters + k] = dxc; @@ -365,7 +359,6 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) changes++; } - } } } @@ -382,6 +375,8 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) for (j = 0; j < numSamples; j++) { + int closestCenter; + vec = VectorArrayGet(samples, j); closestCenter = closestCenters[j]; @@ -430,7 +425,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) { for (k = 0; k < numCenters; k++) { - distance = lowerBound[j * numCenters + k] - newcdist[k]; + float distance = lowerBound[j * numCenters + k] - newcdist[k]; if (distance < 0) distance = 0; @@ -446,7 +441,7 @@ ElkanKmeans(Relation index, VectorArray samples, VectorArray centers) /* Step 7 */ for (j = 0; j < numCenters; j++) - memcpy(VectorArrayGet(centers, j), VectorArrayGet(newCenters, j), VECTOR_SIZE(dimensions)); + VectorArraySet(centers, j, VectorArrayGet(newCenters, j)); if (changes == 0 && iteration != 0) break; @@ -469,21 +464,16 @@ static void CheckCenters(Relation index, VectorArray centers) { FmgrInfo *normprocinfo; - Oid collation; - Vector *vec; - int i; - int j; - double norm; if (centers->length != centers->maxlen) elog(ERROR, "Not enough centers. Please report a bug."); /* Ensure no NaN or infinite values */ - for (i = 0; i < centers->length; i++) + for (int i = 0; i < centers->length; i++) { - vec = VectorArrayGet(centers, i); + Vector *vec = VectorArrayGet(centers, i); - for (j = 0; j < vec->dim; j++) + for (int j = 0; j < vec->dim; j++) { if (isnan(vec->x[j])) elog(ERROR, "NaN detected. Please report a bug."); @@ -496,7 +486,7 @@ CheckCenters(Relation index, VectorArray centers) /* Ensure no duplicate centers */ /* Fine to sort in-place */ qsort(centers->items, centers->length, VECTOR_SIZE(centers->dim), CompareVectors); - for (i = 1; i < centers->length; i++) + for (int i = 1; i < centers->length; i++) { if (CompareVectors(VectorArrayGet(centers, i), VectorArrayGet(centers, i - 1)) == 0) elog(ERROR, "Duplicate centers detected. Please report a bug."); @@ -507,11 +497,12 @@ CheckCenters(Relation index, VectorArray centers) normprocinfo = IvfflatOptionalProcInfo(index, IVFFLAT_NORM_PROC); if (normprocinfo != NULL) { - collation = index->rd_indcollation[0]; + Oid collation = index->rd_indcollation[0]; - for (i = 0; i < centers->length; i++) + for (int i = 0; i < centers->length; i++) { - norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i)))); + double norm = DatumGetFloat8(FunctionCall1Coll(normprocinfo, collation, PointerGetDatum(VectorArrayGet(centers, i)))); + if (norm == 0) elog(ERROR, "Zero norm detected. Please report a bug."); } diff --git a/src/ivfscan.c b/src/ivfscan.c index fa3961b..e6a96bb 100644 --- a/src/ivfscan.c +++ b/src/ivfscan.c @@ -3,14 +3,13 @@ #include #include "access/relscan.h" +#include "catalog/pg_operator_d.h" +#include "catalog/pg_type_d.h" #include "ivfflat.h" #include "miscadmin.h" #include "pgstat.h" #include "storage/bufmgr.h" -#include "catalog/pg_operator_d.h" -#include "catalog/pg_type_d.h" - /* * Compare list distances */ @@ -32,36 +31,36 @@ CompareLists(const pairingheap_node *a, const pairingheap_node *b, void *arg) static void GetScanLists(IndexScanDesc scan, Datum value) { - Buffer cbuf; - Page cpage; - IvfflatList list; - OffsetNumber offno; - OffsetNumber maxoffno; + IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO; int listCount = 0; - IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; - double distance; - IvfflatScanList *scanlist; double maxDistance = DBL_MAX; /* Search all list pages */ while (BlockNumberIsValid(nextblkno)) { + Buffer cbuf; + Page cpage; + OffsetNumber maxoffno; + cbuf = ReadBuffer(scan->indexRelation, nextblkno); LockBuffer(cbuf, BUFFER_LOCK_SHARE); cpage = BufferGetPage(cbuf); maxoffno = PageGetMaxOffsetNumber(cpage); - for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + for (OffsetNumber offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { - list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno)); + IvfflatList list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, offno)); + double distance; /* Use procinfo from the index instead of scan key for performance */ distance = DatumGetFloat8(FunctionCall2Coll(so->procinfo, so->collation, PointerGetDatum(&list->center), value)); if (listCount < so->probes) { + IvfflatScanList *scanlist; + scanlist = &so->lists[listCount]; scanlist->startPage = list->startPage; scanlist->distance = distance; @@ -76,6 +75,8 @@ GetScanLists(IndexScanDesc scan, Datum value) } else if (distance < maxDistance) { + IvfflatScanList *scanlist; + /* Remove */ scanlist = (IvfflatScanList *) pairingheap_remove_first(so->listQueue); @@ -102,14 +103,6 @@ static void GetScanItems(IndexScanDesc scan, Datum value) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; - Buffer buf; - Page page; - IndexTuple itup; - BlockNumber searchPage; - OffsetNumber offno; - OffsetNumber maxoffno; - Datum datum; - bool isnull; TupleDesc tupdesc = RelationGetDescr(scan->indexRelation); double tuples = 0; @@ -129,19 +122,28 @@ GetScanItems(IndexScanDesc scan, Datum value) /* Search closest probes lists */ while (!pairingheap_is_empty(so->listQueue)) { - searchPage = ((IvfflatScanList *) pairingheap_remove_first(so->listQueue))->startPage; + BlockNumber searchPage = ((IvfflatScanList *) pairingheap_remove_first(so->listQueue))->startPage; /* Search all entry pages for list */ while (BlockNumberIsValid(searchPage)) { + Buffer buf; + Page page; + OffsetNumber maxoffno; + buf = ReadBufferExtended(scan->indexRelation, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas); LockBuffer(buf, BUFFER_LOCK_SHARE); page = BufferGetPage(buf); maxoffno = PageGetMaxOffsetNumber(page); - for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) + for (OffsetNumber offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { - itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno)); + IndexTuple itup; + Datum datum; + bool isnull; + ItemId itemid = PageGetItemId(page, offno); + + itup = (IndexTuple) PageGetItem(page, itemid); datum = index_getattr(itup, 1, tupdesc, &isnull); /* @@ -155,8 +157,6 @@ GetScanItems(IndexScanDesc scan, Datum value) slot->tts_isnull[0] = false; slot->tts_values[1] = PointerGetDatum(&itup->t_tid); slot->tts_isnull[1] = false; - slot->tts_values[2] = Int32GetDatum((int) searchPage); - slot->tts_isnull[2] = false; ExecStoreVirtualTuple(slot); tuplesort_puttupleslot(so->sortstate, slot); @@ -172,7 +172,6 @@ GetScanItems(IndexScanDesc scan, Datum value) FreeAccessStrategy(bas); - /* TODO Scan more lists */ if (tuples < 100) ereport(DEBUG1, (errmsg("index scan found few tuples"), @@ -191,6 +190,7 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) IndexScanDesc scan; IvfflatScanOpaque so; int lists; + int dimensions; AttrNumber attNums[] = {1}; Oid sortOperators[] = {Float8LessOperator}; Oid sortCollations[] = {InvalidOid}; @@ -198,15 +198,17 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) int probes = ivfflat_probes; scan = RelationGetIndexScan(index, nkeys, norderbys); - lists = IvfflatGetLists(scan->indexRelation); + + /* Get lists and dimensions from metapage */ + IvfflatGetMetaPageInfo(index, &lists, &dimensions); if (probes > lists) probes = lists; so = (IvfflatScanOpaque) palloc(offsetof(IvfflatScanOpaqueData, lists) + probes * sizeof(IvfflatScanList)); - so->buf = InvalidBuffer; so->first = true; so->probes = probes; + so->dimensions = dimensions; /* Set support functions */ so->procinfo = index_getprocinfo(index, 1, IVFFLAT_DISTANCE_PROC); @@ -215,13 +217,12 @@ ivfflatbeginscan(Relation index, int nkeys, int norderbys) /* Create tuple description for sorting */ #if PG_VERSION_NUM >= 120000 - so->tupdesc = CreateTemplateTupleDesc(3); + so->tupdesc = CreateTemplateTupleDesc(2); #else - so->tupdesc = CreateTemplateTupleDesc(3, false); + so->tupdesc = CreateTemplateTupleDesc(2, false); #endif TupleDescInitEntry(so->tupdesc, (AttrNumber) 1, "distance", FLOAT8OID, -1, 0); - TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "tid", TIDOID, -1, 0); - TupleDescInitEntry(so->tupdesc, (AttrNumber) 3, "indexblkno", INT4OID, -1, 0); + TupleDescInitEntry(so->tupdesc, (AttrNumber) 2, "heaptid", TIDOID, -1, 0); /* Prep sort */ so->sortstate = tuplesort_begin_heap(so->tupdesc, 1, attNums, sortOperators, sortCollations, nullsFirstFlags, work_mem, NULL, false); @@ -287,21 +288,24 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) if (scan->orderByData == NULL) elog(ERROR, "cannot scan ivfflat index without order"); - /* No items will match if null */ - if (scan->orderByData->sk_flags & SK_ISNULL) - return false; + /* Requires MVCC-compliant snapshot as not able to pin during sorting */ + /* https://www.postgresql.org/docs/current/index-locking.html */ + if (!IsMVCCSnapshot(scan->xs_snapshot)) + elog(ERROR, "non-MVCC snapshots are not supported with ivfflat"); - value = scan->orderByData->sk_argument; + if (scan->orderByData->sk_flags & SK_ISNULL) + value = PointerGetDatum(InitVector(so->dimensions)); + else + { + value = scan->orderByData->sk_argument; - /* Value should not be compressed or toasted */ - Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); - Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); + /* Value should not be compressed or toasted */ + Assert(!VARATT_IS_COMPRESSED(DatumGetPointer(value))); + Assert(!VARATT_IS_EXTENDED(DatumGetPointer(value))); - if (so->normprocinfo != NULL) - { - /* No items will match if normalization fails */ - if (!IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL)) - return false; + /* Fine if normalization fails */ + if (so->normprocinfo != NULL) + IvfflatNormValue(so->normprocinfo, so->collation, &value, NULL); } IvfflatBench("GetScanLists", GetScanLists(scan, value)); @@ -315,26 +319,14 @@ ivfflatgettuple(IndexScanDesc scan, ScanDirection dir) if (tuplesort_gettupleslot(so->sortstate, true, false, so->slot, NULL)) { - ItemPointer tid = (ItemPointer) DatumGetPointer(slot_getattr(so->slot, 2, &so->isnull)); - BlockNumber indexblkno = DatumGetInt32(slot_getattr(so->slot, 3, &so->isnull)); + ItemPointer heaptid = (ItemPointer) DatumGetPointer(slot_getattr(so->slot, 2, &so->isnull)); #if PG_VERSION_NUM >= 120000 - scan->xs_heaptid = *tid; + scan->xs_heaptid = *heaptid; #else - scan->xs_ctup.t_self = *tid; + scan->xs_ctup.t_self = *heaptid; #endif - if (BufferIsValid(so->buf)) - ReleaseBuffer(so->buf); - - /* - * An index scan must maintain a pin on the index page holding the - * item last returned by amgettuple - * - * https://www.postgresql.org/docs/current/index-locking.html - */ - so->buf = ReadBuffer(scan->indexRelation, indexblkno); - scan->xs_recheckorderby = false; return true; } @@ -350,10 +342,6 @@ ivfflatendscan(IndexScanDesc scan) { IvfflatScanOpaque so = (IvfflatScanOpaque) scan->opaque; - /* Release pin */ - if (BufferIsValid(so->buf)) - ReleaseBuffer(so->buf); - pairingheap_free(so->listQueue); tuplesort_end(so->sortstate); diff --git a/src/ivfutils.c b/src/ivfutils.c index 7cf6fe6..7959a17 100644 --- a/src/ivfutils.c +++ b/src/ivfutils.c @@ -35,9 +35,7 @@ VectorArrayFree(VectorArray arr) void PrintVectorArray(char *msg, VectorArray arr) { - int i; - - for (i = 0; i < arr->length; i++) + for (int i = 0; i < arr->length; i++) PrintVector(msg, VectorArrayGet(arr, i)); } @@ -59,12 +57,12 @@ IvfflatGetLists(Relation index) * Get proc */ FmgrInfo * -IvfflatOptionalProcInfo(Relation rel, uint16 procnum) +IvfflatOptionalProcInfo(Relation index, uint16 procnum) { - if (!OidIsValid(index_getprocid(rel, 1, procnum))) + if (!OidIsValid(index_getprocid(index, 1, procnum))) return NULL; - return index_getprocinfo(rel, 1, procnum); + return index_getprocinfo(index, 1, procnum); } /* @@ -78,20 +76,16 @@ IvfflatOptionalProcInfo(Relation rel, uint16 procnum) bool IvfflatNormValue(FmgrInfo *procinfo, Oid collation, Datum *value, Vector * result) { - Vector *v; - int i; - double norm; - - norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); + double norm = DatumGetFloat8(FunctionCall1Coll(procinfo, collation, *value)); if (norm > 0) { - v = DatumGetVector(*value); + Vector *v = DatumGetVector(*value); if (result == NULL) result = InitVector(v->dim); - for (i = 0; i < v->dim; i++) + for (int i = 0; i < v->dim; i++) result->x[i] = v->x[i] / norm; *value = PointerGetDatum(result); @@ -142,7 +136,6 @@ IvfflatInitRegisterPage(Relation index, Buffer *buf, Page *page, GenericXLogStat void IvfflatCommitBuffer(Buffer buf, GenericXLogState *state) { - MarkBufferDirty(buf); GenericXLogFinish(state); UnlockReleaseBuffer(buf); } @@ -166,8 +159,6 @@ IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **st IvfflatInitPage(newbuf, newpage); /* Commit */ - MarkBufferDirty(*buf); - MarkBufferDirty(newbuf); GenericXLogFinish(*state); /* Unlock */ @@ -178,16 +169,40 @@ IvfflatAppendPage(Relation index, Buffer *buf, Page *page, GenericXLogState **st *buf = newbuf; } +/* + * Get the metapage info + */ +void +IvfflatGetMetaPageInfo(Relation index, int *lists, int *dimensions) +{ + Buffer buf; + Page page; + IvfflatMetaPage metap; + + buf = ReadBuffer(index, IVFFLAT_METAPAGE_BLKNO); + LockBuffer(buf, BUFFER_LOCK_SHARE); + page = BufferGetPage(buf); + metap = IvfflatPageGetMeta(page); + + *lists = metap->lists; + + if (dimensions != NULL) + *dimensions = metap->dimensions; + + UnlockReleaseBuffer(buf); +} + /* * Update the start or insert page of a list */ void -IvfflatUpdateList(Relation index, GenericXLogState *state, ListInfo listInfo, +IvfflatUpdateList(Relation index, ListInfo listInfo, BlockNumber insertPage, BlockNumber originalInsertPage, BlockNumber startPage, ForkNumber forkNum) { Buffer buf; Page page; + GenericXLogState *state; IvfflatList list; bool changed = false; diff --git a/src/ivfvacuum.c b/src/ivfvacuum.c index f9725f7..b548af1 100644 --- a/src/ivfvacuum.c +++ b/src/ivfvacuum.c @@ -12,34 +12,23 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, IndexBulkDeleteCallback callback, void *callback_state) { Relation index = info->index; - Buffer cbuf; - Page cpage; - Buffer buf; - Page page; - IvfflatList list; - IndexTuple itup; - ItemPointer htup; - OffsetNumber deletable[MaxOffsetNumber]; - int ndeletable; - BlockNumber startPages[MaxOffsetNumber]; - BlockNumber nextblkno = IVFFLAT_HEAD_BLKNO; - BlockNumber searchPage; - BlockNumber insertPage; - GenericXLogState *state; - OffsetNumber coffno; - OffsetNumber cmaxoffno; - OffsetNumber offno; - OffsetNumber maxoffno; - ListInfo listInfo; + BlockNumber blkno = IVFFLAT_HEAD_BLKNO; BufferAccessStrategy bas = GetAccessStrategy(BAS_BULKREAD); if (stats == NULL) stats = (IndexBulkDeleteResult *) palloc0(sizeof(IndexBulkDeleteResult)); /* Iterate over list pages */ - while (BlockNumberIsValid(nextblkno)) + while (BlockNumberIsValid(blkno)) { - cbuf = ReadBuffer(index, nextblkno); + Buffer cbuf; + Page cpage; + OffsetNumber coffno; + OffsetNumber cmaxoffno; + BlockNumber startPages[MaxOffsetNumber]; + ListInfo listInfo; + + cbuf = ReadBuffer(index, blkno); LockBuffer(cbuf, BUFFER_LOCK_SHARE); cpage = BufferGetPage(cbuf); @@ -48,23 +37,32 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, /* Iterate over lists */ for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno)) { - list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, coffno)); + IvfflatList list = (IvfflatList) PageGetItem(cpage, PageGetItemId(cpage, coffno)); + startPages[coffno - FirstOffsetNumber] = list->startPage; } - listInfo.blkno = nextblkno; - nextblkno = IvfflatPageGetOpaque(cpage)->nextblkno; + listInfo.blkno = blkno; + blkno = IvfflatPageGetOpaque(cpage)->nextblkno; UnlockReleaseBuffer(cbuf); for (coffno = FirstOffsetNumber; coffno <= cmaxoffno; coffno = OffsetNumberNext(coffno)) { - searchPage = startPages[coffno - FirstOffsetNumber]; - insertPage = InvalidBlockNumber; + BlockNumber searchPage = startPages[coffno - FirstOffsetNumber]; + BlockNumber insertPage = InvalidBlockNumber; /* Iterate over entry pages */ while (BlockNumberIsValid(searchPage)) { + Buffer buf; + Page page; + GenericXLogState *state; + OffsetNumber offno; + OffsetNumber maxoffno; + OffsetNumber deletable[MaxOffsetNumber]; + int ndeletable; + vacuum_delay_point(); buf = ReadBufferExtended(index, MAIN_FORKNUM, searchPage, RBM_NORMAL, bas); @@ -86,8 +84,8 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, /* Find deleted tuples */ for (offno = FirstOffsetNumber; offno <= maxoffno; offno = OffsetNumberNext(offno)) { - itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno)); - htup = &(itup->t_tid); + IndexTuple itup = (IndexTuple) PageGetItem(page, PageGetItemId(page, offno)); + ItemPointer htup = &(itup->t_tid); if (callback(htup, callback_state)) { @@ -109,7 +107,6 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, { /* Delete tuples */ PageIndexMultiDelete(page, deletable, ndeletable); - MarkBufferDirty(buf); GenericXLogFinish(state); } else @@ -127,7 +124,7 @@ ivfflatbulkdelete(IndexVacuumInfo *info, IndexBulkDeleteResult *stats, if (BlockNumberIsValid(insertPage)) { listInfo.offno = coffno; - IvfflatUpdateList(index, state, listInfo, insertPage, InvalidBlockNumber, InvalidBlockNumber, MAIN_FORKNUM); + IvfflatUpdateList(index, listInfo, insertPage, InvalidBlockNumber, InvalidBlockNumber, MAIN_FORKNUM); } } } diff --git a/src/vector.c b/src/vector.c index 394a478..d3ebedb 100644 --- a/src/vector.c +++ b/src/vector.c @@ -2,15 +2,22 @@ #include -#include "vector.h" -#include "fmgr.h" #include "catalog/pg_type.h" +#include "fmgr.h" +#include "hnsw.h" +#include "ivfflat.h" #include "lib/stringinfo.h" #include "libpq/pqformat.h" +#include "port.h" /* for strtof() */ #include "utils/array.h" #include "utils/builtins.h" #include "utils/lsyscache.h" #include "utils/numeric.h" +#include "vector.h" + +#if PG_VERSION_NUM >= 160000 +#include "varatt.h" +#endif #if PG_VERSION_NUM >= 120000 #include "common/shortest_dec.h" @@ -29,6 +36,17 @@ PG_MODULE_MAGIC; +/* + * Initialize index options and variables + */ +PGDLLEXPORT void _PG_init(void); +void +_PG_init(void) +{ + HnswInit(); + IvfflatInit(); +} + /* * Ensure same dimensions */ @@ -87,6 +105,23 @@ CheckElement(float value) errmsg("infinite value not allowed in vector"))); } +/* + * Allocate and initialize a new vector + */ +Vector * +InitVector(int dim) +{ + Vector *result; + int size; + + size = VECTOR_SIZE(dim); + result = (Vector *) palloc0(size); + SET_VARSIZE(result, size); + result->dim = dim; + + return result; +} + /* * Check for whitespace, since array_isspace() is static */ @@ -125,6 +160,14 @@ float_overflow_error(void) (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), errmsg("value out of range: overflow"))); } + +static pg_noinline void +float_underflow_error(void) +{ + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("value out of range: underflow"))); +} #endif /* @@ -136,7 +179,6 @@ vector_in(PG_FUNCTION_ARGS) { char *str = PG_GETARG_CSTRING(0); int32 typmod = PG_GETARG_INT32(2); - int i; float x[VECTOR_MAX_DIM]; int dim = 0; char *pt; @@ -231,7 +273,7 @@ vector_in(PG_FUNCTION_ARGS) CheckExpectedDim(typmod, dim); result = InitVector(dim); - for (i = 0; i < dim; i++) + for (int i = 0; i < dim; i++) result->x[i] = x[i]; PG_RETURN_POINTER(result); @@ -248,7 +290,6 @@ vector_out(PG_FUNCTION_ARGS) int dim = vector->dim; char *buf; char *ptr; - int i; int n; #if PG_VERSION_NUM < 120000 @@ -275,7 +316,7 @@ vector_out(PG_FUNCTION_ARGS) *ptr = '['; ptr++; - for (i = 0; i < dim; i++) + for (int i = 0; i < dim; i++) { if (i > 0) { @@ -353,7 +394,6 @@ vector_recv(PG_FUNCTION_ARGS) Vector *result; int16 dim; int16 unused; - int i; dim = pq_getmsgint(buf, sizeof(int16)); unused = pq_getmsgint(buf, sizeof(int16)); @@ -367,7 +407,7 @@ vector_recv(PG_FUNCTION_ARGS) errmsg("expected unused to be 0, not %d", unused))); result = InitVector(dim); - for (i = 0; i < dim; i++) + for (int i = 0; i < dim; i++) { result->x[i] = pq_getmsgfloat4(buf); CheckElement(result->x[i]); @@ -385,12 +425,11 @@ vector_send(PG_FUNCTION_ARGS) { Vector *vec = PG_GETARG_VECTOR_P(0); StringInfoData buf; - int i; pq_begintypsend(&buf); pq_sendint(&buf, vec->dim, sizeof(int16)); pq_sendint(&buf, vec->unused, sizeof(int16)); - for (i = 0; i < vec->dim; i++) + for (int i = 0; i < vec->dim; i++) pq_sendfloat4(&buf, vec->x[i]); PG_RETURN_BYTEA_P(pq_endtypsend(&buf)); @@ -420,7 +459,6 @@ array_to_vector(PG_FUNCTION_ARGS) { ArrayType *array = PG_GETARG_ARRAYTYPE_P(0); int32 typmod = PG_GETARG_INT32(1); - int i; Vector *result; int16 typlen; bool typbyval; @@ -434,6 +472,11 @@ array_to_vector(PG_FUNCTION_ARGS) (errcode(ERRCODE_DATA_EXCEPTION), errmsg("array must be 1-D"))); + if (ARR_HASNULL(array) && array_contains_nulls(array)) + ereport(ERROR, + (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("array must not contain nulls"))); + get_typlenbyvalalign(ARR_ELEMTYPE(array), &typlen, &typbyval, &typalign); deconstruct_array(array, ARR_ELEMTYPE(array), typlen, typbyval, typalign, &elemsp, &nullsp, &nelemsp); @@ -441,29 +484,37 @@ array_to_vector(PG_FUNCTION_ARGS) CheckExpectedDim(typmod, nelemsp); result = InitVector(nelemsp); - for (i = 0; i < nelemsp; i++) - { - if (nullsp[i]) - ereport(ERROR, - (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), - errmsg("array must not containing NULLs"))); - /* TODO Move outside loop in 0.5.0 */ - if (ARR_ELEMTYPE(array) == INT4OID) + if (ARR_ELEMTYPE(array) == INT4OID) + { + for (int i = 0; i < nelemsp; i++) result->x[i] = DatumGetInt32(elemsp[i]); - else if (ARR_ELEMTYPE(array) == FLOAT8OID) + } + else if (ARR_ELEMTYPE(array) == FLOAT8OID) + { + for (int i = 0; i < nelemsp; i++) result->x[i] = DatumGetFloat8(elemsp[i]); - else if (ARR_ELEMTYPE(array) == FLOAT4OID) + } + else if (ARR_ELEMTYPE(array) == FLOAT4OID) + { + for (int i = 0; i < nelemsp; i++) result->x[i] = DatumGetFloat4(elemsp[i]); - else if (ARR_ELEMTYPE(array) == NUMERICOID) + } + else if (ARR_ELEMTYPE(array) == NUMERICOID) + { + for (int i = 0; i < nelemsp; i++) result->x[i] = DatumGetFloat4(DirectFunctionCall1(numeric_float4, elemsp[i])); - else - ereport(ERROR, - (errcode(ERRCODE_DATA_EXCEPTION), - errmsg("unsupported array type"))); + } + else + { + ereport(ERROR, + (errcode(ERRCODE_DATA_EXCEPTION), + errmsg("unsupported array type"))); + } + /* Check elements */ + for (int i = 0; i < result->dim; i++) CheckElement(result->x[i]); - } PG_RETURN_POINTER(result); } @@ -478,11 +529,10 @@ vector_to_float4(PG_FUNCTION_ARGS) Vector *vec = PG_GETARG_VECTOR_P(0); Datum *datums; ArrayType *result; - int i; datums = (Datum *) palloc(sizeof(Datum) * vec->dim); - for (i = 0; i < vec->dim; i++) + for (int i = 0; i < vec->dim; i++) datums[i] = Float4GetDatum(vec->x[i]); /* Use TYPALIGN_INT for float4 */ @@ -504,8 +554,8 @@ l2_distance(PG_FUNCTION_ARGS) Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; - double distance = 0.0; - double diff; + float distance = 0.0; + float diff; CheckDims(a, b); @@ -516,7 +566,7 @@ l2_distance(PG_FUNCTION_ARGS) distance += diff * diff; } - PG_RETURN_FLOAT8(sqrt(distance)); + PG_RETURN_FLOAT8(sqrt((double) distance)); } /* @@ -531,8 +581,8 @@ vector_l2_squared_distance(PG_FUNCTION_ARGS) Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; - double distance = 0.0; - double diff; + float distance = 0.0; + float diff; CheckDims(a, b); @@ -543,7 +593,7 @@ vector_l2_squared_distance(PG_FUNCTION_ARGS) distance += diff * diff; } - PG_RETURN_FLOAT8(distance); + PG_RETURN_FLOAT8((double) distance); } /* @@ -557,7 +607,7 @@ inner_product(PG_FUNCTION_ARGS) Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; - double distance = 0.0; + float distance = 0.0; CheckDims(a, b); @@ -565,7 +615,7 @@ inner_product(PG_FUNCTION_ARGS) for (int i = 0; i < a->dim; i++) distance += ax[i] * bx[i]; - PG_RETURN_FLOAT8(distance); + PG_RETURN_FLOAT8((double) distance); } /* @@ -579,7 +629,7 @@ vector_negative_inner_product(PG_FUNCTION_ARGS) Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; - double distance = 0.0; + float distance = 0.0; CheckDims(a, b); @@ -587,7 +637,7 @@ vector_negative_inner_product(PG_FUNCTION_ARGS) for (int i = 0; i < a->dim; i++) distance += ax[i] * bx[i]; - PG_RETURN_FLOAT8(distance * -1); + PG_RETURN_FLOAT8((double) distance * -1); } /* @@ -601,9 +651,10 @@ cosine_distance(PG_FUNCTION_ARGS) Vector *b = PG_GETARG_VECTOR_P(1); float *ax = a->x; float *bx = b->x; - double distance = 0.0; - double norma = 0.0; - double normb = 0.0; + float distance = 0.0; + float norma = 0.0; + float normb = 0.0; + double similarity; CheckDims(a, b); @@ -616,7 +667,21 @@ cosine_distance(PG_FUNCTION_ARGS) } /* Use sqrt(a * b) over sqrt(a) * sqrt(b) */ - PG_RETURN_FLOAT8(1 - (distance / sqrt(norma * normb))); + similarity = (double) distance / sqrt((double) norma * (double) normb); + +#ifdef _MSC_VER + /* /fp:fast may not propagate NaN */ + if (isnan(similarity)) + PG_RETURN_FLOAT8(NAN); +#endif + + /* Keep in range */ + if (similarity > 1) + similarity = 1.0; + else if (similarity < -1) + similarity = -1.0; + + PG_RETURN_FLOAT8(1.0 - similarity); } /* @@ -630,13 +695,18 @@ vector_spherical_distance(PG_FUNCTION_ARGS) { Vector *a = PG_GETARG_VECTOR_P(0); Vector *b = PG_GETARG_VECTOR_P(1); - double distance = 0.0; + float *ax = a->x; + float *bx = b->x; + float dp = 0.0; + double distance; CheckDims(a, b); /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) - distance += a->x[i] * b->x[i]; + dp += ax[i] * bx[i]; + + distance = (double) dp; /* Prevent NaN with acos with loss of precision */ if (distance > 1) @@ -647,6 +717,28 @@ vector_spherical_distance(PG_FUNCTION_ARGS) PG_RETURN_FLOAT8(acos(distance) / M_PI); } +/* + * Get the L1 distance between vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(l1_distance); +Datum +l1_distance(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + float *ax = a->x; + float *bx = b->x; + float distance = 0.0; + + CheckDims(a, b); + + /* Auto-vectorized */ + for (int i = 0; i < a->dim; i++) + distance += fabsf(ax[i] - bx[i]); + + PG_RETURN_FLOAT8((double) distance); +} + /* * Get the dimensions of a vector */ @@ -672,7 +764,7 @@ vector_norm(PG_FUNCTION_ARGS) /* Auto-vectorized */ for (int i = 0; i < a->dim; i++) - norm += ax[i] * ax[i]; + norm += (double) ax[i] * (double) ax[i]; PG_RETURN_FLOAT8(sqrt(norm)); } @@ -743,17 +835,51 @@ vector_sub(PG_FUNCTION_ARGS) PG_RETURN_POINTER(result); } +/* + * Multiply vectors + */ +PGDLLEXPORT PG_FUNCTION_INFO_V1(vector_mul); +Datum +vector_mul(PG_FUNCTION_ARGS) +{ + Vector *a = PG_GETARG_VECTOR_P(0); + Vector *b = PG_GETARG_VECTOR_P(1); + float *ax = a->x; + float *bx = b->x; + Vector *result; + float *rx; + + CheckDims(a, b); + + result = InitVector(a->dim); + rx = result->x; + + /* Auto-vectorized */ + for (int i = 0, imax = a->dim; i < imax; i++) + rx[i] = ax[i] * bx[i]; + + /* Check for overflow and underflow */ + for (int i = 0, imax = a->dim; i < imax; i++) + { + if (isinf(rx[i])) + float_overflow_error(); + + if (rx[i] == 0 && !(ax[i] == 0 || bx[i] == 0)) + float_underflow_error(); + } + + PG_RETURN_POINTER(result); +} + /* * Internal helper to compare vectors */ int vector_cmp_internal(Vector * a, Vector * b) { - int i; - CheckDims(a, b); - for (i = 0; i < a->dim; i++) + for (int i = 0; i < a->dim; i++) { if (a->x[i] < b->x[i]) return -1; diff --git a/src/vector.h b/src/vector.h index 93aeb6a..e649471 100644 --- a/src/vector.h +++ b/src/vector.h @@ -1,12 +1,6 @@ #ifndef VECTOR_H #define VECTOR_H -#include "postgres.h" - -#if PG_VERSION_NUM >= 160000 -#include "varatt.h" -#endif - #define VECTOR_MAX_DIM 16000 #define VECTOR_SIZE(_dim) (offsetof(Vector, x) + sizeof(float)*(_dim)) @@ -22,24 +16,8 @@ typedef struct Vector float x[FLEXIBLE_ARRAY_MEMBER]; } Vector; +Vector *InitVector(int dim); void PrintVector(char *msg, Vector * vector); int vector_cmp_internal(Vector * a, Vector * b); -/* - * Allocate and initialize a new vector - */ -static inline Vector * -InitVector(int dim) -{ - Vector *result; - int size; - - size = VECTOR_SIZE(dim); - result = (Vector *) palloc0(size); - SET_VARSIZE(result, size); - result->dim = dim; - - return result; -} - #endif diff --git a/test/expected/cast.out b/test/expected/cast.out index 37614d9..4824261 100644 --- a/test/expected/cast.out +++ b/test/expected/cast.out @@ -29,7 +29,7 @@ SELECT ARRAY[1,2,3]::numeric[]::vector; (1 row) SELECT '{NULL}'::real[]::vector; -ERROR: array must not containing NULLs +ERROR: array must not contain nulls SELECT '{NaN}'::real[]::vector; ERROR: NaN not allowed in vector SELECT '{Infinity}'::real[]::vector; @@ -38,6 +38,8 @@ SELECT '{-Infinity}'::real[]::vector; ERROR: infinite value not allowed in vector SELECT '{}'::real[]::vector; ERROR: vector must have at least 1 dimension +SELECT '{{1}}'::real[]::vector; +ERROR: array must be 1-D SELECT '[1,2,3]'::vector::real[]; float4 --------- diff --git a/test/expected/functions.out b/test/expected/functions.out index 0272282..2840688 100644 --- a/test/expected/functions.out +++ b/test/expected/functions.out @@ -14,6 +14,16 @@ SELECT '[1,2,3]'::vector - '[4,5,6]'; SELECT '[-3e38]'::vector - '[3e38]'; ERROR: value out of range: overflow +SELECT '[1,2,3]'::vector * '[4,5,6]'; + ?column? +----------- + [4,10,18] +(1 row) + +SELECT '[1e37]'::vector * '[1e37]'; +ERROR: value out of range: overflow +SELECT '[1e-37]'::vector * '[1e-37]'; +ERROR: value out of range: underflow SELECT vector_dims('[1,2,3]'); vector_dims ------------- @@ -38,6 +48,12 @@ SELECT vector_norm('[0,1]'); 1 (1 row) +SELECT vector_norm('[3e37,4e37]')::real; + vector_norm +------------- + 5e+37 +(1 row) + SELECT l2_distance('[0,0]', '[3,4]'); l2_distance ------------- @@ -52,6 +68,12 @@ SELECT l2_distance('[0,0]', '[0,1]'); SELECT l2_distance('[1,2]', '[3]'); ERROR: different vector dimensions 2 and 1 +SELECT l2_distance('[3e38]', '[-3e38]'); + l2_distance +------------- + Infinity +(1 row) + SELECT inner_product('[1,2]', '[3,4]'); inner_product --------------- @@ -60,6 +82,12 @@ SELECT inner_product('[1,2]', '[3,4]'); SELECT inner_product('[1,2]', '[3]'); ERROR: different vector dimensions 2 and 1 +SELECT inner_product('[3e38]', '[3e38]'); + inner_product +--------------- + Infinity +(1 row) + SELECT cosine_distance('[1,2]', '[2,4]'); cosine_distance ----------------- @@ -78,6 +106,12 @@ SELECT cosine_distance('[1,1]', '[1,1]'); 0 (1 row) +SELECT cosine_distance('[1,0]', '[0,2]'); + cosine_distance +----------------- + 1 +(1 row) + SELECT cosine_distance('[1,1]', '[-1,-1]'); cosine_distance ----------------- @@ -86,6 +120,44 @@ SELECT cosine_distance('[1,1]', '[-1,-1]'); SELECT cosine_distance('[1,2]', '[3]'); ERROR: different vector dimensions 2 and 1 +SELECT cosine_distance('[1,1]', '[1.1,1.1]'); + cosine_distance +----------------- + 0 +(1 row) + +SELECT cosine_distance('[1,1]', '[-1.1,-1.1]'); + cosine_distance +----------------- + 2 +(1 row) + +SELECT cosine_distance('[3e38]', '[3e38]'); + cosine_distance +----------------- + NaN +(1 row) + +SELECT l1_distance('[0,0]', '[3,4]'); + l1_distance +------------- + 7 +(1 row) + +SELECT l1_distance('[0,0]', '[0,1]'); + l1_distance +------------- + 1 +(1 row) + +SELECT l1_distance('[1,2]', '[3]'); +ERROR: different vector dimensions 2 and 1 +SELECT l1_distance('[3e38]', '[-3e38]'); + l1_distance +------------- + Infinity +(1 row) + SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; avg ----------- @@ -106,5 +178,33 @@ SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v; SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; ERROR: expected 2 dimensions, not 1 +SELECT avg(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v; + avg +--------- + [3e+38] +(1 row) + SELECT vector_avg(array_agg(n)) FROM generate_series(1, 16002) n; ERROR: vector cannot have more than 16000 dimensions +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; + sum +---------- + [4,7,10] +(1 row) + +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; + sum +---------- + [4,7,10] +(1 row) + +SELECT sum(v) FROM unnest(ARRAY[]::vector[]) v; + sum +----- + +(1 row) + +SELECT sum(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; +ERROR: different vector dimensions 2 and 1 +SELECT sum(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v; +ERROR: value out of range: overflow diff --git a/test/expected/hnsw_cosine.out b/test/expected/hnsw_cosine.out new file mode 100644 index 0000000..e23894c --- /dev/null +++ b/test/expected/hnsw_cosine.out @@ -0,0 +1,27 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +NOTICE: Table doesn't have 'DISTRIBUTED BY' clause, and no column type is suitable for a distribution key. Creating a NULL policy entry. +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_cosine_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; + val +--------- + [1,1,1] + [1,2,3] + [1,2,4] +(3 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; + count +------- + 3 +(1 row) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; + count +------- + 3 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_ip.out b/test/expected/hnsw_ip.out new file mode 100644 index 0000000..2255c98 --- /dev/null +++ b/test/expected/hnsw_ip.out @@ -0,0 +1,22 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +NOTICE: Table doesn't have 'DISTRIBUTED BY' clause, and no column type is suitable for a distribution key. Creating a NULL policy entry. +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_ip_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; + val +--------- + [1,2,4] + [1,2,3] + [1,1,1] + [0,0,0] +(4 rows) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2; + count +------- + 4 +(1 row) + +DROP TABLE t; diff --git a/test/expected/hnsw_l2.out b/test/expected/hnsw_l2.out new file mode 100644 index 0000000..9085469 --- /dev/null +++ b/test/expected/hnsw_l2.out @@ -0,0 +1,40 @@ +SET enable_seqscan = off; +CREATE TABLE t (val vector(3)); +NOTICE: Table doesn't have 'DISTRIBUTED BY' clause, and no column type is suitable for a distribution key. Creating a NULL policy entry. +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); +INSERT INTO t (val) VALUES ('[1,2,4]'); +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,2,4] + [1,1,1] + [0,0,0] +(4 rows) + +-- this sql will convert to ‘order by ctid’ clause, but the result is not stable on MPP architecture. +-- SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); +SELECT * FROM t ORDER BY val; + val +--------- + [0,0,0] + [1,1,1] + [1,2,3] + [1,2,4] + +(5 rows) + +SELECT COUNT(*) FROM t; + count +------- + 5 +(1 row) + +TRUNCATE t; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +----- +(0 rows) + +DROP TABLE t; diff --git a/test/expected/hnsw_options.out b/test/expected/hnsw_options.out new file mode 100644 index 0000000..f6e2a1f --- /dev/null +++ b/test/expected/hnsw_options.out @@ -0,0 +1,27 @@ +CREATE TABLE t (val vector(3)); +NOTICE: Table doesn't have 'DISTRIBUTED BY' clause, and no column type is suitable for a distribution key. Creating a NULL policy entry. +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 1); +ERROR: value 1 out of bounds for option "m" +DETAIL: Valid values are between "2" and "100". +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101); +ERROR: value 101 out of bounds for option "m" +DETAIL: Valid values are between "2" and "100". +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 3); +ERROR: value 3 out of bounds for option "ef_construction" +DETAIL: Valid values are between "4" and "1000". +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001); +ERROR: value 1001 out of bounds for option "ef_construction" +DETAIL: Valid values are between "4" and "1000". +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 16, ef_construction = 31); +ERROR: ef_construction must be greater than or equal to 2 * m (hnswbuild.c:425) +SHOW hnsw.ef_search; + hnsw.ef_search +---------------- + 40 +(1 row) + +SET hnsw.ef_search = 0; +ERROR: 0 is outside the valid range for parameter "hnsw.ef_search" (1 .. 1000) +SET hnsw.ef_search = 1001; +ERROR: 1001 is outside the valid range for parameter "hnsw.ef_search" (1 .. 1000) +DROP TABLE t; diff --git a/test/expected/hnsw_unlogged.out b/test/expected/hnsw_unlogged.out new file mode 100644 index 0000000..c6a5410 --- /dev/null +++ b/test/expected/hnsw_unlogged.out @@ -0,0 +1,14 @@ +SET enable_seqscan = off; +CREATE UNLOGGED TABLE t (val vector(3)); +NOTICE: Table doesn't have 'DISTRIBUTED BY' clause, and no column type is suitable for a distribution key. Creating a NULL policy entry. +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +--------- + [1,2,3] + [1,1,1] + [0,0,0] +(3 rows) + +DROP TABLE t; diff --git a/test/expected/input.out b/test/expected/input.out index 19ef74d..102ca51 100644 --- a/test/expected/input.out +++ b/test/expected/input.out @@ -81,6 +81,11 @@ ERROR: malformed vector literal: "1,2,3" LINE 1: SELECT '1,2,3'::vector; ^ DETAIL: Vector contents must start with "[". +SELECT ''::vector; +ERROR: malformed vector literal: "" +LINE 1: SELECT ''::vector; + ^ +DETAIL: Vector contents must start with "[". SELECT '['::vector; ERROR: malformed vector literal: "[" LINE 1: SELECT '['::vector; diff --git a/test/expected/ivfflat_cosine.out b/test/expected/ivfflat_cosine.out index 208ba84..0c855da 100644 --- a/test/expected/ivfflat_cosine.out +++ b/test/expected/ivfflat_cosine.out @@ -1,5 +1,4 @@ SET enable_seqscan = off; -SET optimizer = off; CREATE TABLE t (val vector(3)); NOTICE: Table doesn't have 'DISTRIBUTED BY' clause, and no column type is suitable for a distribution key. Creating a NULL policy entry. INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); @@ -15,9 +14,16 @@ SELECT * FROM t ORDER BY val <=> '[3,3,3]'; [1,2,4] (3 rows) -SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector); - val ------ -(0 rows) +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; + count +------- + 3 +(1 row) + +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; + count +------- + 3 +(1 row) DROP TABLE t; diff --git a/test/expected/ivfflat_ip.out b/test/expected/ivfflat_ip.out index e64e7ab..1617c32 100644 --- a/test/expected/ivfflat_ip.out +++ b/test/expected/ivfflat_ip.out @@ -1,13 +1,9 @@ SET enable_seqscan = off; -SET optimizer = off; CREATE TABLE t (val vector(3)); NOTICE: Table doesn't have 'DISTRIBUTED BY' clause, and no column type is suitable for a distribution key. Creating a NULL policy entry. INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); -- start_ignore CREATE INDEX ON t USING ivfflat (val vector_ip_ops) WITH (lists = 1); -NOTICE: ivfflat index created with little data (seg1 127.0.1.1:7003 pid=424029) -DETAIL: This will cause low recall. -HINT: Drop the index until the table has more data. -- end_ignore INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <#> '[3,3,3]'; @@ -19,9 +15,10 @@ SELECT * FROM t ORDER BY val <#> '[3,3,3]'; [0,0,0] (4 rows) -SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector); - val ------ -(0 rows) +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2; + count +------- + 4 +(1 row) DROP TABLE t; diff --git a/test/expected/ivfflat_l2.out b/test/expected/ivfflat_l2.out index c203a68..a8abff4 100644 --- a/test/expected/ivfflat_l2.out +++ b/test/expected/ivfflat_l2.out @@ -1,10 +1,9 @@ SET enable_seqscan = off; -SET optimizer = off; CREATE TABLE t (val vector(3)); NOTICE: Table doesn't have 'DISTRIBUTED BY' clause, and no column type is suitable for a distribution key. Creating a NULL policy entry. INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); -- start_ignore -CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1); +CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 1); -- end_ignore INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; @@ -16,10 +15,17 @@ SELECT * FROM t ORDER BY val <-> '[3,3,3]'; [0,0,0] (4 rows) -SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); - val ------ -(0 rows) +-- this sql will convert to ‘order by ctid’ clause, but the result is not stable on MPP architecture. +-- SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); +SELECT * FROM t ORDER BY val; + val +--------- + [0,0,0] + [1,1,1] + [1,2,3] + [1,2,4] + +(5 rows) SELECT COUNT(*) FROM t; count @@ -27,4 +33,15 @@ SELECT COUNT(*) FROM t; 5 (1 row) +-- start_ignore +TRUNCATE t; +NOTICE: ivfflat index created with little data +DETAIL: This will cause low recall. +HINT: Drop the index until the table has more data. +-- end_ignore +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + val +----- +(0 rows) + DROP TABLE t; diff --git a/test/expected/ivfflat_options.out b/test/expected/ivfflat_options.out index 8d1b216..99d64d2 100644 --- a/test/expected/ivfflat_options.out +++ b/test/expected/ivfflat_options.out @@ -1,10 +1,9 @@ -SET enable_seqscan = off; CREATE TABLE t (val vector(3)); NOTICE: Table doesn't have 'DISTRIBUTED BY' clause, and no column type is suitable for a distribution key. Creating a NULL policy entry. -CREATE INDEX ON t USING ivfflat (val) WITH (lists = 0); +CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 0); ERROR: value 0 out of bounds for option "lists" DETAIL: Valid values are between "1" and "32768". -CREATE INDEX ON t USING ivfflat (val) WITH (lists = 32769); +CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 32769); ERROR: value 32769 out of bounds for option "lists" DETAIL: Valid values are between "1" and "32768". SHOW ivfflat.probes; diff --git a/test/expected/ivfflat_unlogged.out b/test/expected/ivfflat_unlogged.out index 313609b..05bc7e8 100644 --- a/test/expected/ivfflat_unlogged.out +++ b/test/expected/ivfflat_unlogged.out @@ -1,10 +1,9 @@ SET enable_seqscan = off; -SET optimizer = off; CREATE UNLOGGED TABLE t (val vector(3)); NOTICE: Table doesn't have 'DISTRIBUTED BY' clause, and no column type is suitable for a distribution key. Creating a NULL policy entry. INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); -- start_ignore -CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1); +CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 1); -- end_ignore SELECT * FROM t ORDER BY val <-> '[3,3,3]'; val diff --git a/test/sql/cast.sql b/test/sql/cast.sql index cb5c880..c73ab07 100644 --- a/test/sql/cast.sql +++ b/test/sql/cast.sql @@ -8,6 +8,7 @@ SELECT '{NaN}'::real[]::vector; SELECT '{Infinity}'::real[]::vector; SELECT '{-Infinity}'::real[]::vector; SELECT '{}'::real[]::vector; +SELECT '{{1}}'::real[]::vector; SELECT '[1,2,3]'::vector::real[]; SELECT array_agg(n)::vector FROM generate_series(1, 16001) n; SELECT array_to_vector(array_agg(n), 16001, false) FROM generate_series(1, 16001) n; diff --git a/test/sql/functions.sql b/test/sql/functions.sql index e4d3317..914df36 100644 --- a/test/sql/functions.sql +++ b/test/sql/functions.sql @@ -2,28 +2,50 @@ SELECT '[1,2,3]'::vector + '[4,5,6]'; SELECT '[3e38]'::vector + '[3e38]'; SELECT '[1,2,3]'::vector - '[4,5,6]'; SELECT '[-3e38]'::vector - '[3e38]'; +SELECT '[1,2,3]'::vector * '[4,5,6]'; +SELECT '[1e37]'::vector * '[1e37]'; +SELECT '[1e-37]'::vector * '[1e-37]'; SELECT vector_dims('[1,2,3]'); SELECT round(vector_norm('[1,1]')::numeric, 5); SELECT vector_norm('[3,4]'); SELECT vector_norm('[0,1]'); +SELECT vector_norm('[3e37,4e37]')::real; SELECT l2_distance('[0,0]', '[3,4]'); SELECT l2_distance('[0,0]', '[0,1]'); SELECT l2_distance('[1,2]', '[3]'); +SELECT l2_distance('[3e38]', '[-3e38]'); SELECT inner_product('[1,2]', '[3,4]'); SELECT inner_product('[1,2]', '[3]'); +SELECT inner_product('[3e38]', '[3e38]'); SELECT cosine_distance('[1,2]', '[2,4]'); SELECT cosine_distance('[1,2]', '[0,0]'); SELECT cosine_distance('[1,1]', '[1,1]'); +SELECT cosine_distance('[1,0]', '[0,2]'); SELECT cosine_distance('[1,1]', '[-1,-1]'); SELECT cosine_distance('[1,2]', '[3]'); +SELECT cosine_distance('[1,1]', '[1.1,1.1]'); +SELECT cosine_distance('[1,1]', '[-1.1,-1.1]'); +SELECT cosine_distance('[3e38]', '[3e38]'); + +SELECT l1_distance('[0,0]', '[3,4]'); +SELECT l1_distance('[0,0]', '[0,1]'); +SELECT l1_distance('[1,2]', '[3]'); +SELECT l1_distance('[3e38]', '[-3e38]'); SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; SELECT avg(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; SELECT avg(v) FROM unnest(ARRAY[]::vector[]) v; SELECT avg(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; +SELECT avg(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v; SELECT vector_avg(array_agg(n)) FROM generate_series(1, 16002) n; + +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]']) v; +SELECT sum(v) FROM unnest(ARRAY['[1,2,3]'::vector, '[3,5,7]', NULL]) v; +SELECT sum(v) FROM unnest(ARRAY[]::vector[]) v; +SELECT sum(v) FROM unnest(ARRAY['[1,2]'::vector, '[3]']) v; +SELECT sum(v) FROM unnest(ARRAY['[3e38]'::vector, '[3e38]']) v; diff --git a/test/sql/hnsw_cosine.sql b/test/sql/hnsw_cosine.sql new file mode 100644 index 0000000..d23f4f3 --- /dev/null +++ b/test/sql/hnsw_cosine.sql @@ -0,0 +1,13 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_cosine_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <=> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; + +DROP TABLE t; diff --git a/test/sql/hnsw_ip.sql b/test/sql/hnsw_ip.sql new file mode 100644 index 0000000..5a616a1 --- /dev/null +++ b/test/sql/hnsw_ip.sql @@ -0,0 +1,12 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_ip_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <#> '[3,3,3]'; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2; + +DROP TABLE t; diff --git a/test/sql/hnsw_l2.sql b/test/sql/hnsw_l2.sql new file mode 100644 index 0000000..97795cb --- /dev/null +++ b/test/sql/hnsw_l2.sql @@ -0,0 +1,18 @@ +SET enable_seqscan = off; + +CREATE TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); + +INSERT INTO t (val) VALUES ('[1,2,4]'); + +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; +-- this sql will convert to ‘order by ctid’ clause, but the result is not stable on MPP architecture. +-- SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); +SELECT * FROM t ORDER BY val; +SELECT COUNT(*) FROM t; + +TRUNCATE t; +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +DROP TABLE t; diff --git a/test/sql/hnsw_options.sql b/test/sql/hnsw_options.sql new file mode 100644 index 0000000..7b9662f --- /dev/null +++ b/test/sql/hnsw_options.sql @@ -0,0 +1,13 @@ +CREATE TABLE t (val vector(3)); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 1); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 101); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 3); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (ef_construction = 1001); +CREATE INDEX ON t USING hnsw (val vector_l2_ops) WITH (m = 16, ef_construction = 31); + +SHOW hnsw.ef_search; + +SET hnsw.ef_search = 0; +SET hnsw.ef_search = 1001; + +DROP TABLE t; diff --git a/test/sql/hnsw_unlogged.sql b/test/sql/hnsw_unlogged.sql new file mode 100644 index 0000000..2efcc95 --- /dev/null +++ b/test/sql/hnsw_unlogged.sql @@ -0,0 +1,9 @@ +SET enable_seqscan = off; + +CREATE UNLOGGED TABLE t (val vector(3)); +INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); +CREATE INDEX ON t USING hnsw (val vector_l2_ops); + +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + +DROP TABLE t; diff --git a/test/sql/input.sql b/test/sql/input.sql index a4ad08d..9f5809c 100644 --- a/test/sql/input.sql +++ b/test/sql/input.sql @@ -14,6 +14,7 @@ SELECT '[4e38,1]'::vector; SELECT '[1,2,3'::vector; SELECT '[1,2,3]9'::vector; SELECT '1,2,3'::vector; +SELECT ''::vector; SELECT '['::vector; SELECT '[,'::vector; SELECT '[]'::vector; diff --git a/test/sql/ivfflat_cosine.sql b/test/sql/ivfflat_cosine.sql index 6705d0f..8422f7a 100644 --- a/test/sql/ivfflat_cosine.sql +++ b/test/sql/ivfflat_cosine.sql @@ -1,5 +1,4 @@ SET enable_seqscan = off; -SET optimizer = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); @@ -10,6 +9,7 @@ CREATE INDEX ON t USING ivfflat (val vector_cosine_ops) WITH (lists = 1); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <=> '[3,3,3]'; -SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector); +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> '[0,0,0]') t2; +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <=> (SELECT NULL::vector)) t2; DROP TABLE t; diff --git a/test/sql/ivfflat_ip.sql b/test/sql/ivfflat_ip.sql index d6a8742..1f39b90 100644 --- a/test/sql/ivfflat_ip.sql +++ b/test/sql/ivfflat_ip.sql @@ -1,5 +1,4 @@ SET enable_seqscan = off; -SET optimizer = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); @@ -10,6 +9,6 @@ CREATE INDEX ON t USING ivfflat (val vector_ip_ops) WITH (lists = 1); INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <#> '[3,3,3]'; -SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector); +SELECT COUNT(*) FROM (SELECT * FROM t ORDER BY val <#> (SELECT NULL::vector)) t2; DROP TABLE t; diff --git a/test/sql/ivfflat_l2.sql b/test/sql/ivfflat_l2.sql index e9742b3..08ac3a9 100644 --- a/test/sql/ivfflat_l2.sql +++ b/test/sql/ivfflat_l2.sql @@ -1,16 +1,22 @@ SET enable_seqscan = off; -SET optimizer = off; CREATE TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); -- start_ignore -CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1); +CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 1); -- end_ignore INSERT INTO t (val) VALUES ('[1,2,4]'); SELECT * FROM t ORDER BY val <-> '[3,3,3]'; -SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); +-- this sql will convert to ‘order by ctid’ clause, but the result is not stable on MPP architecture. +-- SELECT * FROM t ORDER BY val <-> (SELECT NULL::vector); +SELECT * FROM t ORDER BY val; SELECT COUNT(*) FROM t; +-- start_ignore +TRUNCATE t; +-- end_ignore +SELECT * FROM t ORDER BY val <-> '[3,3,3]'; + DROP TABLE t; diff --git a/test/sql/ivfflat_options.sql b/test/sql/ivfflat_options.sql index d8dc45c..aa909a5 100644 --- a/test/sql/ivfflat_options.sql +++ b/test/sql/ivfflat_options.sql @@ -1,8 +1,6 @@ -SET enable_seqscan = off; - CREATE TABLE t (val vector(3)); -CREATE INDEX ON t USING ivfflat (val) WITH (lists = 0); -CREATE INDEX ON t USING ivfflat (val) WITH (lists = 32769); +CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 0); +CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 32769); SHOW ivfflat.probes; diff --git a/test/sql/ivfflat_unlogged.sql b/test/sql/ivfflat_unlogged.sql index 33d84c6..d4d6c5f 100644 --- a/test/sql/ivfflat_unlogged.sql +++ b/test/sql/ivfflat_unlogged.sql @@ -1,11 +1,11 @@ SET enable_seqscan = off; -SET optimizer = off; CREATE UNLOGGED TABLE t (val vector(3)); INSERT INTO t (val) VALUES ('[0,0,0]'), ('[1,2,3]'), ('[1,1,1]'), (NULL); -- start_ignore -CREATE INDEX ON t USING ivfflat (val) WITH (lists = 1); +CREATE INDEX ON t USING ivfflat (val vector_l2_ops) WITH (lists = 1); -- end_ignore + SELECT * FROM t ORDER BY val <-> '[3,3,3]'; DROP TABLE t; diff --git a/test/t/001_wal.pl b/test/t/001_ivfflat_wal.pl similarity index 92% rename from test/t/001_wal.pl rename to test/t/001_ivfflat_wal.pl index d56f131..b19eb40 100644 --- a/test/t/001_wal.pl +++ b/test/t/001_ivfflat_wal.pl @@ -19,14 +19,13 @@ sub test_index_replay # Wait for replica to catch up my $applname = $node_replica->name; - - my $server_version_num = $node_primary->safe_psql("postgres", "SHOW server_version_num"); my $caughtup_query = "SELECT pg_current_wal_lsn() <= replay_lsn FROM pg_stat_replication WHERE application_name = '$applname';"; $node_primary->poll_query_until('postgres', $caughtup_query) or die "Timed out while waiting for replica 1 to catch up"; my @r = (); - for (1 .. $dim) { + for (1 .. $dim) + { push(@r, rand()); } my $sql = join(",", @r); @@ -52,11 +51,13 @@ sub test_index_replay # Initialize primary node $node_primary = get_new_node('primary'); $node_primary->init(allows_streaming => 1); -if ($dim > 32) { +if ($dim > 32) +{ # TODO use wal_keep_segments for Postgres < 13 $node_primary->append_conf('postgresql.conf', qq(wal_keep_size = 1GB)); } -if ($dim > 1500) { +if ($dim > 1500) +{ $node_primary->append_conf('postgresql.conf', qq(maintenance_work_mem = 128MB)); } $node_primary->start; @@ -67,8 +68,7 @@ sub test_index_replay # Create streaming replica linking to primary $node_replica = get_new_node('replica'); -$node_replica->init_from_backup($node_primary, $backup_name, - has_streaming => 1); +$node_replica->init_from_backup($node_primary, $backup_name, has_streaming => 1); $node_replica->start; # Create ivfflat index on primary @@ -77,7 +77,7 @@ sub test_index_replay $node_primary->safe_psql("postgres", "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" ); -$node_primary->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v);"); +$node_primary->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); # Test that queries give same result test_index_replay('initial'); diff --git a/test/t/002_vacuum.pl b/test/t/002_ivfflat_vacuum.pl similarity index 96% rename from test/t/002_vacuum.pl rename to test/t/002_ivfflat_vacuum.pl index 16ac0d6..d930444 100644 --- a/test/t/002_vacuum.pl +++ b/test/t/002_ivfflat_vacuum.pl @@ -7,7 +7,8 @@ my $dim = 3; my @r = (); -for (1 .. $dim) { +for (1 .. $dim) +{ my $v = int(rand(1000)) + 1; push(@r, "i % $v"); } @@ -24,7 +25,7 @@ $node->safe_psql("postgres", "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 100000) i;" ); -$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v);"); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); # Get size my $size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); diff --git a/test/t/003_ivfflat_build_recall.pl b/test/t/003_ivfflat_build_recall.pl new file mode 100644 index 0000000..de96093 --- /dev/null +++ b/test/t/003_ivfflat_build_recall.pl @@ -0,0 +1,128 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($probes, $min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx on tst/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;" +); + +# Generate queries +for (1 .. 20) +{ + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Check each index type +my @operators = ("<->", "<#>", "<=>"); +my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); + push(@expected, $res); + } + + # Build index serially + $node->safe_psql("postgres", qq( + SET max_parallel_maintenance_workers = 0; + CREATE INDEX idx ON tst USING ivfflat (v $opclass); + )); + + # Test approximate results + if ($operator ne "<#>") + { + # TODO Fix test (uniform random vectors all have similar inner product) + test_recall(1, 0.71, $operator); + test_recall(10, 0.95, $operator); + } + # Account for equal distances + test_recall(100, 0.9925, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + + # Build index in parallel + my ($ret, $stdout, $stderr) = $node->psql("postgres", qq( + SET client_min_messages = DEBUG; + SET min_parallel_table_scan_size = 1; + CREATE INDEX idx ON tst USING ivfflat (v $opclass); + )); + is($ret, 0, $stderr); + like($stderr, qr/using \d+ parallel workers/); + + # Test approximate results + if ($operator ne "<#>") + { + # TODO Fix test (uniform random vectors all have similar inner product) + test_recall(1, 0.71, $operator); + test_recall(10, 0.95, $operator); + } + # Account for equal distances + test_recall(100, 0.9925, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); +} + +done_testing(); diff --git a/test/t/004_centers.pl b/test/t/004_ivfflat_centers.pl similarity index 91% rename from test/t/004_centers.pl rename to test/t/004_ivfflat_centers.pl index 47939ff..4c125dd 100644 --- a/test/t/004_centers.pl +++ b/test/t/004_ivfflat_centers.pl @@ -20,7 +20,7 @@ sub test_centers { my ($lists, $min) = @_; - my ($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING ivfflat (v) WITH (lists = $lists);"); + my ($ret, $stdout, $stderr) = $node->psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops) WITH (lists = $lists);"); is($ret, 0, $stderr); } diff --git a/test/t/005_query_recall.pl b/test/t/005_ivfflat_query_recall.pl similarity index 75% rename from test/t/005_query_recall.pl rename to test/t/005_ivfflat_query_recall.pl index 50bbb56..1edebb3 100644 --- a/test/t/005_query_recall.pl +++ b/test/t/005_ivfflat_query_recall.pl @@ -18,24 +18,21 @@ # Check each index type my @operators = ("<->", "<#>", "<=>"); -foreach (@operators) { - my $operator = $_; +my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; # Add index - my $opclass; - if ($operator eq "<->") { - $opclass = "vector_l2_ops"; - } elsif ($operator eq "<#>") { - $opclass = "vector_ip_ops"; - } else { - $opclass = "vector_cosine_ops"; - } $node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v $opclass);"); # Test 100% recall - for (1..20) { - my $i = int(rand() * 100000); - my $query = $node->safe_psql("postgres", "SELECT v FROM tst WHERE i = $i;"); + for (1 .. 20) + { + my $id = int(rand() * 100000); + my $query = $node->safe_psql("postgres", "SELECT v FROM tst WHERE i = $id;"); my $res = $node->safe_psql("postgres", qq( SET enable_seqscan = off; SELECT v FROM tst ORDER BY v <-> '$query' LIMIT 1; diff --git a/test/t/006_lists.pl b/test/t/006_ivfflat_lists.pl similarity index 82% rename from test/t/006_lists.pl rename to test/t/006_ivfflat_lists.pl index 302c9b3..9812f50 100644 --- a/test/t/006_lists.pl +++ b/test/t/006_ivfflat_lists.pl @@ -16,8 +16,8 @@ "INSERT INTO tst SELECT ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;" ); -$node->safe_psql("postgres", "CREATE INDEX lists50 ON tst USING ivfflat (v) WITH (lists = 50);"); -$node->safe_psql("postgres", "CREATE INDEX lists100 ON tst USING ivfflat (v) WITH (lists = 100);"); +$node->safe_psql("postgres", "CREATE INDEX lists50 ON tst USING ivfflat (v vector_l2_ops) WITH (lists = 50);"); +$node->safe_psql("postgres", "CREATE INDEX lists100 ON tst USING ivfflat (v vector_l2_ops) WITH (lists = 100);"); # Test prefers more lists my $res = $node->safe_psql("postgres", "EXPLAIN SELECT v FROM tst ORDER BY v <-> '[0.5,0.5,0.5]' LIMIT 10;"); @@ -26,7 +26,7 @@ # Test errors with too much memory my ($ret, $stdout, $stderr) = $node->psql("postgres", - "CREATE INDEX lists10000 ON tst USING ivfflat (v) WITH (lists = 10000);" + "CREATE INDEX lists10000 ON tst USING ivfflat (v vector_l2_ops) WITH (lists = 10000);" ); like($stderr, qr/memory required is/); diff --git a/test/t/007_inserts.pl b/test/t/007_ivfflat_inserts.pl similarity index 91% rename from test/t/007_inserts.pl rename to test/t/007_ivfflat_inserts.pl index 0cf087f..dd7a95d 100644 --- a/test/t/007_inserts.pl +++ b/test/t/007_ivfflat_inserts.pl @@ -19,7 +19,7 @@ $node->safe_psql("postgres", "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10000) i;" ); -$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v);"); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); $node->pgbench( "--no-vacuum --client=5 --transactions=100", @@ -28,7 +28,7 @@ [qr{^$}], "concurrent INSERTs", { - "007_inserts" => "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10) i;" + "007_ivfflat_inserts" => "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10) i;" } ); diff --git a/test/t/008_aggregates.pl b/test/t/008_aggregates.pl new file mode 100644 index 0000000..0465890 --- /dev/null +++ b/test/t/008_aggregates.pl @@ -0,0 +1,49 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (r1 real, r2 real, r3 real, v vector(3));"); +$node->safe_psql("postgres", qq( + INSERT INTO tst SELECT r1, r2, r3, ARRAY[r1, r2, r3] FROM ( + SELECT random() + 1.01 AS r1, random() + 2.01 AS r2, random() + 3.01 AS r3 FROM generate_series(1, 1000000) t + ) i; +)); + +sub test_aggregate +{ + my ($agg) = @_; + + # Test value + my $res = $node->safe_psql("postgres", "SELECT $agg(v) FROM tst;"); + like($res, qr/\[1\.5/); + like($res, qr/,2\.5/); + like($res, qr/,3\.5/); + + # Test matches real for avg + # Cannot test sum since sum(real) varies between calls + if ($agg eq 'avg') + { + my $r1 = $node->safe_psql("postgres", "SELECT $agg(r1)::float4 FROM tst;"); + my $r2 = $node->safe_psql("postgres", "SELECT $agg(r2)::float4 FROM tst;"); + my $r3 = $node->safe_psql("postgres", "SELECT $agg(r3)::float4 FROM tst;"); + is($res, "[$r1,$r2,$r3]"); + } + + # Test explain + my $explain = $node->safe_psql("postgres", "EXPLAIN SELECT $agg(v) FROM tst;"); + like($explain, qr/Partial Aggregate/); +} + +test_aggregate('avg'); +test_aggregate('sum'); + +done_testing(); diff --git a/test/t/008_avg.pl b/test/t/008_avg.pl deleted file mode 100644 index b036678..0000000 --- a/test/t/008_avg.pl +++ /dev/null @@ -1,37 +0,0 @@ -use strict; -use warnings; -use PostgresNode; -use TestLib; -use Test::More; - -# Initialize node -my $node = get_new_node('node'); -$node->init; -$node->start; - -# Create table -$node->safe_psql("postgres", "CREATE EXTENSION vector;"); -$node->safe_psql("postgres", "CREATE TABLE tst (r1 real, r2 real, r3 real, v vector(3));"); -$node->safe_psql("postgres", qq( - INSERT INTO tst SELECT r1, r2, r3, ARRAY[r1, r2, r3] FROM ( - SELECT random() + 1.01 AS r1, random() + 2.01 AS r2, random() + 3.01 AS r3 FROM generate_series(1, 1000000) t - ) i; -)); - -# Test avg -my $avg = $node->safe_psql("postgres", "SELECT AVG(v) FROM tst;"); -like($avg, qr/\[1\.5/); -like($avg, qr/,2\.5/); -like($avg, qr/,3\.5/); - -# Test matches real -my $r1 = $node->safe_psql("postgres", "SELECT AVG(r1)::float4 FROM tst;"); -my $r2 = $node->safe_psql("postgres", "SELECT AVG(r2)::float4 FROM tst;"); -my $r3 = $node->safe_psql("postgres", "SELECT AVG(r3)::float4 FROM tst;"); -is($avg, "[$r1,$r2,$r3]"); - -# Test explain -my $explain = $node->safe_psql("postgres", "EXPLAIN SELECT AVG(v) FROM tst;"); -like($explain, qr/Partial Aggregate/); - -done_testing(); diff --git a/test/t/010_hnsw_wal.pl b/test/t/010_hnsw_wal.pl new file mode 100644 index 0000000..36c0dc5 --- /dev/null +++ b/test/t/010_hnsw_wal.pl @@ -0,0 +1,99 @@ +# Based on postgres/contrib/bloom/t/001_wal.pl + +# Test generic xlog record work for hnsw index replication. +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $dim = 32; + +my $node_primary; +my $node_replica; + +# Run few queries on both primary and replica and check their results match. +sub test_index_replay +{ + my ($test_name) = @_; + + # Wait for replica to catch up + my $applname = $node_replica->name; + my $caughtup_query = "SELECT pg_current_wal_lsn() <= replay_lsn FROM pg_stat_replication WHERE application_name = '$applname';"; + $node_primary->poll_query_until('postgres', $caughtup_query) + or die "Timed out while waiting for replica 1 to catch up"; + + my @r = (); + for (1 .. $dim) + { + push(@r, rand()); + } + my $sql = join(",", @r); + + my $queries = qq( + SET enable_seqscan = off; + SELECT * FROM tst ORDER BY v <-> '[$sql]' LIMIT 10; + ); + + # Run test queries and compare their result + my $primary_result = $node_primary->safe_psql("postgres", $queries); + my $replica_result = $node_replica->safe_psql("postgres", $queries); + + is($primary_result, $replica_result, "$test_name: query result matches"); + return; +} + +# Use ARRAY[random(), random(), random(), ...] over +# SELECT array_agg(random()) FROM generate_series(1, $dim) +# to generate different values for each row +my $array_sql = join(",", ('random()') x $dim); + +# Initialize primary node +$node_primary = get_new_node('primary'); +$node_primary->init(allows_streaming => 1); +if ($dim > 32) +{ + # TODO use wal_keep_segments for Postgres < 13 + $node_primary->append_conf('postgresql.conf', qq(wal_keep_size = 1GB)); +} +if ($dim > 1500) +{ + $node_primary->append_conf('postgresql.conf', qq(maintenance_work_mem = 128MB)); +} +$node_primary->start; +my $backup_name = 'my_backup'; + +# Take backup +$node_primary->backup($backup_name); + +# Create streaming replica linking to primary +$node_replica = get_new_node('replica'); +$node_replica->init_from_backup($node_primary, $backup_name, has_streaming => 1); +$node_replica->start; + +# Create hnsw index on primary +$node_primary->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node_primary->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); +$node_primary->safe_psql("postgres", + "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series(1, 1000) i;" +); +$node_primary->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); + +# Test that queries give same result +test_index_replay('initial'); + +# Run 10 cycles of table modification. Run test queries after each modification. +for my $i (1 .. 10) +{ + $node_primary->safe_psql("postgres", "DELETE FROM tst WHERE i = $i;"); + test_index_replay("delete $i"); + $node_primary->safe_psql("postgres", "VACUUM tst;"); + test_index_replay("vacuum $i"); + my ($start, $end) = (1001 + ($i - 1) * 100, 1000 + $i * 100); + $node_primary->safe_psql("postgres", + "INSERT INTO tst SELECT i % 10, ARRAY[$array_sql] FROM generate_series($start, $end) i;" + ); + test_index_replay("insert $i"); +} + +done_testing(); diff --git a/test/t/011_hnsw_vacuum.pl b/test/t/011_hnsw_vacuum.pl new file mode 100644 index 0000000..10c301f --- /dev/null +++ b/test/t/011_hnsw_vacuum.pl @@ -0,0 +1,54 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $dim = 3; + +my @r = (); +for (1 .. $dim) +{ + my $v = int(rand(1000)) + 1; + push(@r, "i % $v"); +} +my $array_sql = join(", ", @r); + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table and index +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" +); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); + +# Get size +my $size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); + +# Delete all, vacuum, and insert same data +$node->safe_psql("postgres", "DELETE FROM tst;"); +$node->safe_psql("postgres", "VACUUM tst;"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[$array_sql] FROM generate_series(1, 10000) i;" +); + +# Check size +# May increase some due to different levels +my $new_size = $node->safe_psql("postgres", "SELECT pg_total_relation_size('tst_v_idx');"); +cmp_ok($new_size, "<=", $size * 1.02, "size does not increase too much"); + +# Delete all but one +$node->safe_psql("postgres", "DELETE FROM tst WHERE i != 123;"); +$node->safe_psql("postgres", "VACUUM tst;"); +my $res = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v <-> '[0,0,0]' LIMIT 10; +)); +is($res, 123); + +done_testing(); diff --git a/test/t/003_recall.pl b/test/t/012_hnsw_build_recall.pl similarity index 65% rename from test/t/003_recall.pl rename to test/t/012_hnsw_build_recall.pl index 8e7042a..e9074c6 100644 --- a/test/t/003_recall.pl +++ b/test/t/012_hnsw_build_recall.pl @@ -11,21 +11,20 @@ sub test_recall { - my ($probes, $min, $operator) = @_; + my ($min, $operator) = @_; my $correct = 0; my $total = 0; my $explain = $node->safe_psql("postgres", qq( SET enable_seqscan = off; - SET ivfflat.probes = $probes; EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; )); like($explain, qr/Index Scan/); - for my $i (0 .. $#queries) { + for my $i (0 .. $#queries) + { my $actual = $node->safe_psql("postgres", qq( SET enable_seqscan = off; - SET ivfflat.probes = $probes; SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; )); my @actual_ids = split("\n", $actual); @@ -33,8 +32,10 @@ sub test_recall my @expected_ids = split("\n", $expected[$i]); - foreach (@expected_ids) { - if (exists($actual_set{$_})) { + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { $correct++; } $total++; @@ -53,11 +54,12 @@ sub test_recall $node->safe_psql("postgres", "CREATE EXTENSION vector;"); $node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); $node->safe_psql("postgres", - "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 100000) i;" + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" ); # Generate queries -for (1..20) { +for (1 .. 20) +{ my $r1 = rand(); my $r2 = rand(); my $r3 = rand(); @@ -66,36 +68,26 @@ sub test_recall # Check each index type my @operators = ("<->", "<#>", "<=>"); +my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); -foreach (@operators) { - my $operator = $_; +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; # Get exact results @expected = (); - foreach (@queries) { + foreach (@queries) + { my $res = $node->safe_psql("postgres", "SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit;"); push(@expected, $res); } # Add index - my $opclass; - if ($operator eq "<->") { - $opclass = "vector_l2_ops"; - } elsif ($operator eq "<#>") { - $opclass = "vector_ip_ops"; - } else { - $opclass = "vector_cosine_ops"; - } - $node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v $opclass);"); + $node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v $opclass);"); - # Test approximate results - if ($operator ne "<#>") { - # TODO fix test - test_recall(1, 0.75, $operator); - test_recall(10, 0.95, $operator); - } - # Account for equal distances - test_recall(100, 0.9975, $operator); + my $min = $operator eq "<#>" ? 0.80 : 0.99; + test_recall($min, $operator); } done_testing(); diff --git a/test/t/013_hnsw_insert_recall.pl b/test/t/013_hnsw_insert_recall.pl new file mode 100644 index 0000000..d0c24f8 --- /dev/null +++ b/test/t/013_hnsw_insert_recall.pl @@ -0,0 +1,108 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i serial, v vector(3));"); + +# Generate queries +for (1 .. 20) +{ + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Check each index type +my @operators = ("<->", "<#>", "<=>"); +my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Add index + $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v $opclass);"); + + # Use concurrent inserts + $node->pgbench( + "--no-vacuum --client=10 --transactions=1000", + 0, + [qr{actually processed}], + [qr{^$}], + "concurrent INSERTs", + { + "013_hnsw_insert_recall_$opclass" => "INSERT INTO tst (v) VALUES (ARRAY[random(), random(), random()]);" + } + ); + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit; + )); + push(@expected, $res); + } + + my $min = $operator eq "<#>" ? 0.80 : 0.99; + test_recall($min, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + $node->safe_psql("postgres", "TRUNCATE tst;"); +} + +done_testing(); diff --git a/test/t/014_hnsw_inserts.pl b/test/t/014_hnsw_inserts.pl new file mode 100644 index 0000000..f69bcd6 --- /dev/null +++ b/test/t/014_hnsw_inserts.pl @@ -0,0 +1,74 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +# Ensures elements and neighbors on both same and different pages +my $dim = 1900; + +my $array_sql = join(",", ('random()') x $dim); + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table and index +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (v vector($dim));"); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops);"); + +sub idx_scan +{ + # Stats do not update instantaneously + # https://www.postgresql.org/docs/current/monitoring-stats.html#MONITORING-STATS-VIEWS + sleep(1); + $node->safe_psql("postgres", "SELECT idx_scan FROM pg_stat_user_indexes WHERE indexrelid = 'tst_v_idx'::regclass;"); +} + +for my $i (1 .. 20) +{ + $node->pgbench( + "--no-vacuum --client=10 --transactions=1", + 0, + [qr{actually processed}], + [qr{^$}], + "concurrent INSERTs", + { + "014_hnsw_inserts_$i" => "INSERT INTO tst VALUES (ARRAY[$array_sql]);" + } + ); + + my $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SELECT COUNT(*) FROM (SELECT v FROM tst ORDER BY v <-> (SELECT v FROM tst LIMIT 1)) t; + )); + is($count, 10); + + $node->safe_psql("postgres", "TRUNCATE tst;"); +} + +$node->pgbench( + "--no-vacuum --client=20 --transactions=5", + 0, + [qr{actually processed}], + [qr{^$}], + "concurrent INSERTs", + { + "014_hnsw_inserts" => "INSERT INTO tst SELECT ARRAY[$array_sql] FROM generate_series(1, 10) i;" + } +); + +my $count = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = 1000; + SELECT COUNT(*) FROM (SELECT v FROM tst ORDER BY v <-> (SELECT v FROM tst LIMIT 1)) t; +)); +# Elements may lose all incoming connections with the HNSW algorithm +# Vacuuming can fix this if one of the elements neighbors is deleted +cmp_ok($count, ">=", 997); + +is(idx_scan(), 21); + +done_testing(); diff --git a/test/t/015_hnsw_duplicates.pl b/test/t/015_hnsw_duplicates.pl new file mode 100644 index 0000000..7e11dee --- /dev/null +++ b/test/t/015_hnsw_duplicates.pl @@ -0,0 +1,58 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (v vector(3));"); + +sub insert_vectors +{ + for my $i (1 .. 20) + { + $node->safe_psql("postgres", "INSERT INTO tst VALUES ('[1,1,1]');"); + } +} + +sub test_duplicates +{ + my $res = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = 1; + SELECT COUNT(*) FROM (SELECT * FROM tst ORDER BY v <-> '[1,1,1]') t; + )); + is($res, 10); +} + +# Test duplicates with build +insert_vectors(); +$node->safe_psql("postgres", "CREATE INDEX idx ON tst USING hnsw (v vector_l2_ops);"); +test_duplicates(); + +# Reset +$node->safe_psql("postgres", "TRUNCATE tst;"); + +# Test duplicates with inserts +insert_vectors(); +test_duplicates(); + +# Test fallback path for inserts +$node->pgbench( + "--no-vacuum --client=5 --transactions=100", + 0, + [qr{actually processed}], + [qr{^$}], + "concurrent INSERTs", + { + "015_hnsw_duplicates" => "INSERT INTO tst VALUES ('[1,1,1]');" + } +); + +done_testing(); diff --git a/test/t/016_hnsw_vacuum_recall.pl b/test/t/016_hnsw_vacuum_recall.pl new file mode 100644 index 0000000..1cc267d --- /dev/null +++ b/test/t/016_hnsw_vacuum_recall.pl @@ -0,0 +1,97 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($min, $ef_search, $test_name) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = $ef_search; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v <-> '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET hnsw.ef_search = $ef_search; + SELECT i FROM tst ORDER BY v <-> '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $test_name); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i int4, v vector(3));"); +$node->safe_psql("postgres", "ALTER TABLE tst SET (autovacuum_enabled = false);"); +$node->safe_psql("postgres", + "INSERT INTO tst SELECT i, ARRAY[random(), random(), random()] FROM generate_series(1, 10000) i;" +); + +# Add index +$node->safe_psql("postgres", "CREATE INDEX ON tst USING hnsw (v vector_l2_ops) WITH (m = 4, ef_construction = 8);"); + +# Delete data +$node->safe_psql("postgres", "DELETE FROM tst WHERE i > 2500;"); + +# Generate queries +for (1 .. 20) +{ + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Get exact results +@expected = (); +foreach (@queries) +{ + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + SELECT i FROM tst ORDER BY v <-> '$_' LIMIT $limit; + )); + push(@expected, $res); +} + +test_recall(0.20, $limit, "before vacuum"); +test_recall(0.95, 100, "before vacuum"); + +# TODO Test concurrent inserts with vacuum +$node->safe_psql("postgres", "VACUUM tst;"); + +test_recall(0.95, $limit, "after vacuum"); + +done_testing(); diff --git a/test/t/017_ivfflat_insert_recall.pl b/test/t/017_ivfflat_insert_recall.pl new file mode 100644 index 0000000..c2e320c --- /dev/null +++ b/test/t/017_ivfflat_insert_recall.pl @@ -0,0 +1,117 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $node; +my @queries = (); +my @expected; +my $limit = 20; + +sub test_recall +{ + my ($probes, $min, $operator) = @_; + my $correct = 0; + my $total = 0; + + my $explain = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + EXPLAIN ANALYZE SELECT i FROM tst ORDER BY v $operator '$queries[0]' LIMIT $limit; + )); + like($explain, qr/Index Scan using idx on tst/); + + for my $i (0 .. $#queries) + { + my $actual = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = $probes; + SELECT i FROM tst ORDER BY v $operator '$queries[$i]' LIMIT $limit; + )); + my @actual_ids = split("\n", $actual); + my %actual_set = map { $_ => 1 } @actual_ids; + + my @expected_ids = split("\n", $expected[$i]); + + foreach (@expected_ids) + { + if (exists($actual_set{$_})) + { + $correct++; + } + $total++; + } + } + + cmp_ok($correct / $total, ">=", $min, $operator); +} + +# Initialize node +$node = get_new_node('node'); +$node->init; +$node->start; + +# Create table +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i serial, v vector(3));"); + +# Generate queries +for (1 .. 20) +{ + my $r1 = rand(); + my $r2 = rand(); + my $r3 = rand(); + push(@queries, "[$r1,$r2,$r3]"); +} + +# Check each index type +my @operators = ("<->", "<#>", "<=>"); +my @opclasses = ("vector_l2_ops", "vector_ip_ops", "vector_cosine_ops"); + +for my $i (0 .. $#operators) +{ + my $operator = $operators[$i]; + my $opclass = $opclasses[$i]; + + # Add index + $node->safe_psql("postgres", "CREATE INDEX idx ON tst USING ivfflat (v $opclass);"); + + # Use concurrent inserts + $node->pgbench( + "--no-vacuum --client=10 --transactions=1000", + 0, + [qr{actually processed}], + [qr{^$}], + "concurrent INSERTs", + { + "017_ivfflat_insert_recall_$opclass" => "INSERT INTO tst (v) SELECT ARRAY[random(), random(), random()] FROM generate_series(1, 10) i;" + } + ); + + # Get exact results + @expected = (); + foreach (@queries) + { + my $res = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + SELECT i FROM tst ORDER BY v $operator '$_' LIMIT $limit; + )); + push(@expected, $res); + } + + # Test approximate results + if ($operator ne "<#>") + { + # TODO Fix test (uniform random vectors all have similar inner product) + test_recall(1, 0.71, $operator); + test_recall(10, 0.95, $operator); + } + # Account for equal distances + test_recall(100, 0.9925, $operator); + + $node->safe_psql("postgres", "DROP INDEX idx;"); + $node->safe_psql("postgres", "TRUNCATE tst;"); +} + +done_testing(); diff --git a/test/t/018_ivfflat_deletes.pl b/test/t/018_ivfflat_deletes.pl new file mode 100644 index 0000000..a0ea0e6 --- /dev/null +++ b/test/t/018_ivfflat_deletes.pl @@ -0,0 +1,43 @@ +use strict; +use warnings; +use PostgresNode; +use TestLib; +use Test::More; + +my $dim = 3; + +my $array_sql = join(",", ('random()') x $dim); + +# Initialize node +my $node = get_new_node('node'); +$node->init; +$node->start; + +# Create table and index +$node->safe_psql("postgres", "CREATE EXTENSION vector;"); +$node->safe_psql("postgres", "CREATE TABLE tst (i serial, v vector($dim));"); +$node->safe_psql("postgres", + "INSERT INTO tst (v) SELECT ARRAY[$array_sql] FROM generate_series(1, 10000) i;" +); +$node->safe_psql("postgres", "CREATE INDEX ON tst USING ivfflat (v vector_l2_ops);"); + +# Delete data +$node->safe_psql("postgres", "DELETE FROM tst WHERE i % 100 != 0;"); + +my $exp = $node->safe_psql("postgres", qq( + SET enable_indexscan = off; + SELECT i FROM tst ORDER BY v <-> '[0,0,0]'; +)); + +# Run twice to make sure correct tuples marked as dead +for (1 .. 2) +{ + my $res = $node->safe_psql("postgres", qq( + SET enable_seqscan = off; + SET ivfflat.probes = 100; + SELECT i FROM tst ORDER BY v <-> '[0,0,0]'; + )); + is($res, $exp); +} + +done_testing(); diff --git a/vector.control b/vector.control index fe1f94e..7091703 100644 --- a/vector.control +++ b/vector.control @@ -1,4 +1,4 @@ -comment = 'vector data type and ivfflat access method' -default_version = '0.4.4' +comment = 'vector data type and ivfflat and hnsw access methods' +default_version = '0.5.1' module_pathname = '$libdir/vector' relocatable = true