Skip to content

Commit

Permalink
- Utility methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
jjzazuet committed Jul 10, 2023
1 parent cfff3af commit b962ad0
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 30 deletions.
19 changes: 16 additions & 3 deletions src/main/java/io/vacco/bert/BtContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public BtContext(File modelPath, int nThreads) {
this.ctxPtr = Bt.bertLoadFromFile(modelPath.getAbsolutePath());
this.tokenBuffer = new int[Bt.bertNMaxTokens(ctxPtr)];
this.embedBuffer = new float[Bt.bertNEmbd(ctxPtr)];
this.tokensRead[0] = 0;
}

public float[] eval(String sentence) {
Expand All @@ -36,9 +37,21 @@ public float[] eval(String sentence) {
return embedBuffer;
}

public float[] embeddingBufferCopy() {
var copy = new float[embedBuffer.length];
System.arraycopy(embedBuffer, 0, copy, 0, embedBuffer.length);
public float[] evalCopy(String sentence) {
return copy(eval(sentence));
}

public String[] tokenSymbols() {
var ts = new String[tokensRead[0]];
for (int i = 0; i < tokensRead[0]; i++) {
ts[i] = Bt.bertVocabIdToToken(ctxPtr, tokenBuffer[i]);
}
return ts;
}

public float[] copy(float[] in) {
var copy = new float[in.length];
System.arraycopy(in, 0, copy, 0, in.length);
return copy;
}

Expand Down
3 changes: 1 addition & 2 deletions src/test/java/BtRecord.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ public BtRecord withSimilarity(float similarity) {
public static BtRecord from(float[] embedding, String text) {
var r = new BtRecord();
r.text = Objects.requireNonNull(text);
r.embedding = new float[embedding.length];
System.arraycopy(embedding, 0, r.embedding, 0, embedding.length);
r.embedding = embedding;
return r;
}

Expand Down
62 changes: 37 additions & 25 deletions src/test/java/BtTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import io.vacco.bert.Bt;
import io.vacco.bert.BtContext;
;import io.vacco.bert.BtContext;
import j8spec.annotation.DefinedOrder;
import j8spec.junit.J8SpecRunner;
import org.junit.runner.RunWith;
Expand All @@ -17,6 +16,7 @@
public class BtTest {

public static final File modelPath = new File("/home/jjzazuet/code/bert.cpp/models/multi-qa-MiniLM-L6-cos-v1/ggml-model-f16.bin");
public static BtContext bt = new BtContext(modelPath, 4);

public static float cosineSimilarity(float[] vectorA, float[] vectorB) {
float dotProduct = 0.0f;
Expand All @@ -32,35 +32,47 @@ public static float cosineSimilarity(float[] vectorA, float[] vectorB) {

static {
if (!GraphicsEnvironment.isHeadless()) {
it("Loads/releases a BERT model", () -> {
var ctx = Bt.bertLoadFromFile(modelPath.getAbsolutePath());
Bt.bertFree(ctx);
});
it("Encodes a sentence into an embedding", () -> {
try (var bt = new BtContext(modelPath, 2)) {
var embedding = bt.eval("This is a prompt");
System.out.println(Arrays.toString(embedding));
var embedding = bt.eval("This is a prompt, it should get tokenized.");
System.out.println(Arrays.toString(embedding));
System.out.println(Arrays.toString(bt.tokenSymbols()));
});
it("Computes pair-wise sequence similarity", () -> {
var sentences = new String[] {
"Kittens are cute",
"We want to have a cat recognition system",
"You should use a neural network for this",
"It's better to apply some deep learning techniques"
};
for (var s0 : sentences) {
for (var s1 : sentences) {
var vec0 = bt.evalCopy(s0);
var vec1 = bt.evalCopy(s1);
System.out.printf("[%.8f], %s <---> %s%n", cosineSimilarity(vec0, vec1), s0, s1);
}
}
});
it("Queries embeddings for a search term", () -> {
try (var bt = new BtContext(modelPath, 2)) {
var lines = Files.readAllLines(Paths.get("./src/test/resources/documents.txt"));
var recs = lines.stream()
.map(txt -> BtRecord.from(bt.eval(txt), txt))
.collect(Collectors.toList());
var qText = "Should I get health insurance?";
var query = bt.eval(qText);
var results = recs.stream()
.map(rec -> rec.withSimilarity(cosineSimilarity(query, rec.embedding)))
.sorted(Comparator.comparing(rec -> -rec.similarity))
.limit(25)
.collect(Collectors.toList());
System.out.printf("====> %s <====%n", qText);
for (var rec : results) {
System.out.printf("[%.8f] %s%n", rec.similarity, rec.text);
}
var lines = Files.readAllLines(Paths.get("./src/test/resources/documents.txt"));
var recs = lines.stream()
.map(txt -> BtRecord.from(bt.evalCopy(txt), txt))
.collect(Collectors.toList());
var qText = "Should I get health insurance?";
var query = bt.eval(qText);
var results = recs.stream()
.map(rec -> rec.withSimilarity(cosineSimilarity(query, rec.embedding)))
.sorted(Comparator.comparing(rec -> -rec.similarity))
.limit(25)
.collect(Collectors.toList());
System.out.printf("====> %s <====%n", qText);
for (var rec : results) {
System.out.printf("[%.8f] %s%n", rec.similarity, rec.text);
}
});
it("Closes the BERT context", () -> {
System.out.println("Closing BERT context");
bt.close();
});
} else {
System.out.println("Headless mode, skipping tests");
}
Expand Down

0 comments on commit b962ad0

Please sign in to comment.