Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion cpp/include/cuvs/neighbors/cagra.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -574,6 +574,17 @@ cuvsError_t cuvsCagraSerialize(cuvsResources_t res,
cuvsCagraIndex_t index,
bool include_dataset);

cuvsError_t cuvsCagraSerializeWithWriter(cuvsResources_t res,
void (*writer)(int),
cuvsCagraIndex_t index,
bool include_dataset);

cuvsError_t cuvsCagraSerializeWithBufferedWriter(cuvsResources_t res,
size_t (*writer)(void*, size_t),
size_t buffer_size,
cuvsCagraIndex_t index,
bool include_dataset);

/**
* Save the CAGRA index to file in hnswlib format.
* NOTE: The saved index can only be read by the hnswlib wrapper in cuVS,
Expand Down
122 changes: 122 additions & 0 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,87 @@ void _serialize(cuvsResources_t res,
cuvs::neighbors::cagra::serialize(*res_ptr, std::string(filename), *index_ptr, include_dataset);
}

struct _direct_write_buf : public std::streambuf {
_direct_write_buf(void (*writer)(int)) : _writer(writer) {}

protected:
int_type overflow(int_type ch) override
{
if (ch == EOF) { return 0; }
_writer(ch);
return (char_type)ch;
}

private:
std::function<void(int)> _writer;
};

template <typename T>
void _serialize(cuvsResources_t res,
void (*writer)(int),
cuvsCagraIndex_t index,
bool include_dataset)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index->addr);

_direct_write_buf write_buffer(writer);
std::ostream out(&write_buffer);

cuvs::neighbors::cagra::serialize(*res_ptr, out, *index_ptr, include_dataset);
}

struct _write_buf : public std::streambuf {
_write_buf(size_t (*writer)(void*, size_t), size_t buffer_size)
: _writer(writer), _buffer(buffer_size)
{
// save 1 char space for overflows
setp(_buffer.data(), _buffer.data() + buffer_size - 1);
}

protected:
int_type overflow(int_type ch) override
{
int_type eof = traits_type::eof();
size_t bytes_to_write = pptr() - pbase();
if (_writer(pbase(), bytes_to_write) != bytes_to_write) {
ch = eof; // returning eof represents an error
} else if (traits_type::eq_int_type(ch, eof)) {
ch = 0; // we are done
}
setp(_buffer.data(), _buffer.data() + buffer_size - 1);
}
return ch;
}

int sync()
{
size_t bytes_to_write = pptr() - pbase();
if (bytes_to_write == 0 || _writer(pbase(), bytes_to_write) == bytes_to_write) { return 0; }
return -1;
}

private:
std::function<size_t(void*, size_t)> _writer;
std::vector<char_type> _buffer;
};

template <typename T>
void _serialize(cuvsResources_t res,
size_t (*writer)(void*, size_t),
size_t buffer_size,
cuvsCagraIndex_t index,
bool include_dataset)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(index->addr);

_write_buf write_buffer(writer, buffer_size);
std::ostream out(&write_buffer);

cuvs::neighbors::cagra::serialize(*res_ptr, out, *index_ptr, include_dataset);
}

template <typename T>
void _serialize_to_hnswlib(cuvsResources_t res, const char* filename, cuvsCagraIndex_t index)
{
Expand Down Expand Up @@ -652,6 +733,47 @@ extern "C" cuvsError_t cuvsCagraSerialize(cuvsResources_t res,
});
}

extern "C" cuvsError_t cuvsCagraSerializeWithWriter(cuvsResources_t res,
void (*writer)(int),
cuvsCagraIndex_t index,
bool include_dataset)
{
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
_serialize<float>(res, writer, index, include_dataset);
} else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) {
_serialize<half>(res, writer, index, include_dataset);
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
_serialize<int8_t>(res, writer, index, include_dataset);
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
_serialize<uint8_t>(res, writer, index, include_dataset);
} else {
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits);
}
});
}

extern "C" cuvsError_t cuvsCagraSerializeWithBufferedWriter(cuvsResources_t res,
size_t (*writer)(void*, size_t),
size_t buffer_size,
cuvsCagraIndex_t index,
bool include_dataset)
{
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
_serialize<float>(res, writer, index, include_dataset);
} else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) {
_serialize<half>(res, writer, index, include_dataset);
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
_serialize<int8_t>(res, writer, index, include_dataset);
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
_serialize<uint8_t>(res, writer, index, include_dataset);
} else {
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits);
}
});
}

extern "C" cuvsError_t cuvsCagraSerializeToHnswlib(cuvsResources_t res,
const char* filename,
cuvsCagraIndex_t index)
Expand Down
25 changes: 3 additions & 22 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,30 +72,11 @@ public interface CagraIndex {
void serialize(OutputStream outputStream, int bufferLength) throws Throwable;

/**
* A method to persist a CAGRA index using an instance of {@link OutputStream}
* for writing index bytes.
*
* @param outputStream an instance of {@link OutputStream} to write the index
* bytes into
* @param tempFile an intermediate {@link Path} where CAGRA index is written
* temporarily
*/
default void serialize(OutputStream outputStream, Path tempFile) throws Throwable {
serialize(outputStream, tempFile, 1024);
}

/**
* A method to persist a CAGRA index using an instance of {@link OutputStream}
* and path to the intermediate temporary file.
* A method to persist a CAGRA index to a file.
*
* @param outputStream an instance of {@link OutputStream} to write the index
* bytes to
* @param tempFile an intermediate {@link Path} where CAGRA index is written
* temporarily
* @param bufferLength the length of buffer to use for writing bytes. Default
* value is 1024
* @param tempFile a {@link Path} where the CAGRA index is written
*/
void serialize(OutputStream outputStream, Path tempFile, int bufferLength) throws Throwable;
void serialize(Path tempFile) throws Throwable;

/**
* A method to create and persist HNSW index from CAGRA index using an instance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,7 @@
import static com.nvidia.cuvs.internal.common.Util.concatenate;
import static com.nvidia.cuvs.internal.common.Util.cudaMemcpy;
import static com.nvidia.cuvs.internal.common.Util.prepareTensor;
import static com.nvidia.cuvs.internal.panama.headers_h.cudaStream_t;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraBuild;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraDeserialize;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndexCreate;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndexDestroy;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndexGetDims;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraIndex_t;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraMerge;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraMergeParams_t;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraSearch;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraSerialize;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraSerializeToHnswlib;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMAlloc;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMFree;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamGet;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync;
import static com.nvidia.cuvs.internal.panama.headers_h.omp_set_num_threads;
import static com.nvidia.cuvs.internal.panama.headers_h.*;

import com.nvidia.cuvs.CagraCompressionParams;
import com.nvidia.cuvs.CagraIndex;
Expand All @@ -66,13 +50,13 @@
import com.nvidia.cuvs.internal.panama.cuvsIvfPqParams;
import com.nvidia.cuvs.internal.panama.cuvsIvfPqSearchParams;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SequenceLayout;
import java.lang.foreign.ValueLayout;
import java.lang.foreign.*;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.BitSet;
Expand Down Expand Up @@ -394,23 +378,118 @@ public SearchResults search(CagraQuery query) throws Throwable {
}
}

private static class BufferedOutputStreamWrapper {
private final byte[] buffer;
private final OutputStream out;
private final MemorySegment bufferPtr;

BufferedOutputStreamWrapper(int bufferSize, OutputStream out) {
buffer = new byte[bufferSize];
this.out = out;
bufferPtr = MemorySegment.ofArray(buffer);
}

long write(MemorySegment buf, long size) throws IOException {
if (size <= buffer.length) {
MemorySegment.copy(buf, 0, bufferPtr, 0, size);
out.write(buffer, 0, (int) size);
return size;
}
return -1;
}

static long exceptionHandler(Throwable t, MemorySegment buf, long size) {
// TODO: log exception or rethrow?
return -1;
}
}

private static final MethodHandle writeSingleChar$mh;
private static final MethodHandle writeSingleCharErrorHandler$mh;

private static final MethodHandle writeBuffered$mh;
private static final MethodHandle writeBufferedErrorHandler$mh;

static {
try {
writeSingleChar$mh =
MethodHandles.lookup()
.findVirtual(
OutputStream.class, "write", MethodType.methodType(void.class, int.class));
writeSingleCharErrorHandler$mh =
MethodHandles.lookup()
.findStatic(
CagraIndexImpl.class,
"exceptionHandler",
MethodType.methodType(void.class, Throwable.class, int.class));
writeBuffered$mh =
MethodHandles.lookup()
.findVirtual(
BufferedOutputStreamWrapper.class,
"write",
MethodType.methodType(long.class, MemorySegment.class, long.class));
writeBufferedErrorHandler$mh =
MethodHandles.lookup()
.findStatic(
BufferedOutputStreamWrapper.class,
"exceptionHandler",
MethodType.methodType(
long.class, Throwable.class, MemorySegment.class, long.class));

} catch (NoSuchMethodException | IllegalAccessException e) {
throw new RuntimeException(e);
}
}

private static void exceptionHandler(Throwable e, int argument) {
// TODO: log exception or rethrow?
}

@Override
public void serialize(OutputStream outputStream) throws Throwable {
Path path =
Files.createTempFile(resources.tempDirectory(), UUID.randomUUID().toString(), ".cag");
serialize(outputStream, path, 1024);
public void serialize(OutputStream outputStream) {
// Create a stub as a native symbol to be passed into native function.
try (var arena = Arena.ofConfined()) {
MethodHandle outWrite =
MethodHandles.catchException(
writeSingleChar$mh.bindTo(outputStream),
Throwable.class,
writeSingleCharErrorHandler$mh);

MemorySegment outWriteNativeSymbol =
Linker.nativeLinker()
.upcallStub(outWrite, FunctionDescriptor.ofVoid(ValueLayout.JAVA_INT), arena);

var index = cagraIndexReference.getMemorySegment();
var returnValue =
cuvsCagraSerializeWithWriter(resources.getHandle(), outWriteNativeSymbol, index, true);
checkCuVSError(returnValue, "cuvsCagraSerializeWithWriter");
}
}

@Override
public void serialize(OutputStream outputStream, int bufferLength) throws Throwable {
Path path =
Files.createTempFile(resources.tempDirectory(), UUID.randomUUID().toString(), ".cag");
serialize(outputStream, path, bufferLength);
public void serialize(OutputStream outputStream, int bufferLength) {
// Create a stub as a native symbol to be passed into native function.
try (var arena = Arena.ofConfined()) {
MethodHandle outWrite =
MethodHandles.catchException(
writeBuffered$mh.bindTo(new BufferedOutputStreamWrapper(bufferLength, outputStream)),
Throwable.class,
writeBufferedErrorHandler$mh);

MemorySegment outWriteNativeSymbol =
Linker.nativeLinker()
.upcallStub(outWrite, FunctionDescriptor.of(C_LONG, C_POINTER, C_LONG), arena);

var index = cagraIndexReference.getMemorySegment();
var returnValue =
cuvsCagraSerializeWithBufferedWriter(
resources.getHandle(), outWriteNativeSymbol, bufferLength, index, true);
checkCuVSError(returnValue, "cuvsCagraSerializeWithWriter");
}
}

@Override
public void serialize(OutputStream outputStream, Path tempFile, int bufferLength)
throws Throwable {
public void serialize(Path tempFile) {
checkNotDestroyed();
tempFile = tempFile.toAbsolutePath();
try (var localArena = Arena.ofConfined()) {
Expand All @@ -423,16 +502,6 @@ public void serialize(OutputStream outputStream, Path tempFile, int bufferLength
cagraIndexReference.getMemorySegment(),
true);
checkCuVSError(returnValue, "cuvsCagraSerialize");

try (var fileInputStream = Files.newInputStream(tempFile)) {
byte[] chunk = new byte[bufferLength];
int chunkLength = 0;
while ((chunkLength = fileInputStream.read(chunk)) != -1) {
outputStream.write(chunk, 0, chunkLength);
}
} finally {
Files.deleteIfExists(tempFile);
}
}
}

Expand Down