Adds List.sample and List.sampleN #3398

Merged
merged 2 commits into from
Sep 30, 2024
5 changes: 5 additions & 0 deletions .changeset/late-onions-tease.md
@@ -0,0 +1,5 @@
+---
+"@quri/squiggle-lang": patch
+---
+
+Added List.sample and List.sampleN
17 changes: 17 additions & 0 deletions packages/squiggle-lang/__tests__/library/list_test.ts
@@ -308,4 +308,21 @@ describe("List functions", () => {
     testEvalToBe("List.unzip([[1,3],[2,4],[5,6]])", "[[1,2,5],[3,4,6]]");
     testEvalToBe("List.unzip([])", "[[],[]]");
   });
+
+  describe("sample", () => {
+    testEvalToBe("List.sample([1, 2, 3, 4, 5]) -> typeOf", '"Number"');
+    testEvalToBe(
+      "List.sample([])",
+      "Error(Argument Error: List must not be empty)"
+    );
+  });
+
+  describe("sampleN", () => {
+    testEvalToBe("List.sampleN([1, 2, 3, 4, 5], 3) -> List.length", "3");
+    testEvalToBe("List.sampleN([1, 2, 3], 0)", "[]");
+    testEvalToBe(
+      "List.sampleN([], 2)",
+      "Error(Argument Error: List must not be empty)"
+    );
+  });
 });
5 changes: 2 additions & 3 deletions packages/squiggle-lang/src/dists/SampleSetDist/index.ts
@@ -5,7 +5,7 @@ import * as Discrete from "../../PointSet/Discrete.js";
 import { DiscreteShape } from "../../PointSet/Discrete.js";
 import { buildMixedShape } from "../../PointSet/Mixed.js";
 import { PRNG } from "../../rng/index.js";
-import { isEqual } from "../../utility/E_A.js";
+import { isEqual, sample as E_A_sample } from "../../utility/E_A.js";
 import * as E_A_Floats from "../../utility/E_A_Floats.js";
 import * as E_A_Sorted from "../../utility/E_A_Sorted.js";
 import * as Result from "../../utility/result.js";
@@ -125,8 +125,7 @@ export class SampleSetDist extends BaseDist {
 
   // Randomly get one sample from the distribution
   sample(rng: PRNG) {
-    const index = Math.floor(rng() * this.samples.length);
-    return this.samples[index];
+    return E_A_sample(this.samples, rng);
   }
 
   /*
34 changes: 33 additions & 1 deletion packages/squiggle-lang/src/fr/list.ts
@@ -30,7 +30,7 @@ import { Lambda } from "../reducer/lambda/index.js";
 import { Reducer } from "../reducer/Reducer.js";
 import { tBool, tNumber } from "../types/TIntrinsic.js";
 import { tAny } from "../types/Type.js";
-import { shuffle, unzip, zip } from "../utility/E_A.js";
+import { sample, sampleN, shuffle, unzip, zip } from "../utility/E_A.js";
 import * as E_A_Floats from "../utility/E_A_Floats.js";
 import { uniq, uniqBy, Value } from "../value/index.js";
 import { vNumber } from "../value/VNumber.js";
@@ -800,4 +800,36 @@ List.reduceWhile(
       ),
     ],
   }),
+  maker.make({
+    name: "sample",
+    requiresNamespace: true,
+    examples: [makeFnExample(`List.sample([1,4,5])`)],
+    displaySection: "Queries",
+    definitions: [
+      makeDefinition(
+        [frArray(frAny({ genericName: "A" }))],
+        frAny({ genericName: "A" }),
+        ([array], { rng }) => {
+          _assertUnemptyArray(array);
+          return sample(array, rng);
+        }
+      ),
+    ],
+  }),
+  maker.make({
+    name: "sampleN",
+    requiresNamespace: true,
+    examples: [makeFnExample(`List.sampleN([1,4,5], 2)`)],
+    displaySection: "Queries",
+    definitions: [
+      makeDefinition(
+        [frArray(frAny({ genericName: "A" })), namedInput("n", frNumber)],
+        frArray(frAny({ genericName: "A" })),
+        ([array, n], { rng }) => {
+          _assertUnemptyArray(array);
+          return sampleN(array, n, rng);
+        }
+      ),
+    ],
+  }),
 ];
24 changes: 24 additions & 0 deletions packages/squiggle-lang/src/utility/E_A.ts
@@ -119,3 +119,27 @@ export function isEqual<T>(arr1: readonly T[], arr2: readonly T[]): boolean {
 
   return true;
 }
+
+export function sample<T>(array: readonly T[], rng: PRNG): T {
+  const index = Math.floor(rng() * array.length);
+  return array[index];
+}
+
+export function sampleN<T>(array: readonly T[], n: number, rng: PRNG): T[] {
+  const size = Math.max(0, Math.floor(n));
+  if (size === 0) return [];
+  if (size >= array.length) return shuffle(array, rng);
Collaborator

Sampling 10 items from a list with 3 items will return a list of 3 items. Is this intentional?

I'd expect it to create duplicates, and that would be consistent with how Dist.sampleN works

Collaborator

Ok, I see that this defaults to sampling without replacement for lower N values, so it makes sense to keep it consistent.

It's a pity that it's inconsistent with how Dist.sampleN works, though. I assume it was intentional, and also the reason why you didn't make the un-namespaced sampleN polymorphic?

Seems like a potential footgun...

How about this: we rename the version without replacement to pickN, and make sampleN consistent for lists and dists (with replacement)?

Contributor Author

Good points. Agreed it's gnarly.
Another option is to add configs. Maybe something like,
sampleN(samples, {withReplacement: boolean, shuffle: boolean})

Collaborator

Yeah, options could work, but if we want to combine this with polymorphism on dists vs lists (which seems natural and good), and have different replacement defaults, then it's still problematic.
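
For reference, the with-replacement behaviour discussed in this thread could look roughly like the sketch below. It is only an illustration: the name `sampleNWithReplacement` is hypothetical and not part of this PR, and `PRNG` is assumed to be a function returning a float in [0, 1), as the `Math.floor(rng() * array.length)` usage in the diff suggests.

```typescript
// Assumption: a PRNG is a zero-argument function returning a float in [0, 1).
type PRNG = () => number;

// Hypothetical helper, not part of this PR: draws n items independently,
// so duplicates are possible and n may exceed array.length.
function sampleNWithReplacement<T>(
  array: readonly T[],
  n: number,
  rng: PRNG
): T[] {
  const size = Math.max(0, Math.floor(n));
  if (size === 0) return [];
  if (array.length === 0) {
    // Mirrors the non-empty check the PR performs before calling sampleN.
    throw new Error("List must not be empty");
  }

  const result: T[] = [];
  for (let i = 0; i < size; i++) {
    // Each draw picks a uniformly random index, independent of earlier draws.
    result.push(array[Math.floor(rng() * array.length)]);
  }
  return result;
}
```

With this variant, asking for 10 items from a 3-item list yields 10 items with repeats, which is the behaviour the first comment expected.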


+  const result: T[] = [];
+  const indices = new Set<number>();
+
+  while (indices.size < size) {
+    const index = Math.floor(rng() * array.length);
+    if (!indices.has(index)) {
+      indices.add(index);
+      result.push(array[index]);
Collaborator

This doesn't look optimal, but I tested it on List.make(n,1) -> List.sampleN(n-1), and even for n = 1M it does only 14M loops, so I guess it's fine.

(the asymptotic cost is probably something like O(n*log(n)))
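
The loop count reported above is in line with a coupon-collector estimate. As a rough check, assume `rng()` is uniform on [0, 1) and write n for `array.length` and k for `size` (symbols introduced here for illustration, not taken from the PR). Once i distinct indices have been found, each iteration finds a new one with probability (n - i)/n, so the expected number of iterations is:

```latex
\mathbb{E}[\text{loops}]
  = \sum_{i=0}^{k-1} \frac{n}{n-i}
  = n\,\bigl(H_n - H_{n-k}\bigr),
\qquad H_m = \sum_{j=1}^{m} \frac{1}{j}.

% Worst case reached by the loop: k = n - 1 (k >= n takes the shuffle branch).
\mathbb{E}[\text{loops}] = n\,(H_n - 1) \approx n\,(\ln n + \gamma - 1),
\qquad \gamma \approx 0.5772.

% For n = 10^6:
10^6 \cdot (13.82 + 0.58 - 1) \approx 1.34 \times 10^{7}.
```

That is about 13.4M iterations for n = 1M, close to the 14M observed, and it supports the O(n*log(n)) guess for k near n. When k is at most n/2, every term in the sum is at most 2, so the expected count stays below 2k.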

+    }
+  }
+
+  return result;
+}
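
If the rejection loop ever becomes a bottleneck, one possible alternative for sampling without replacement is a partial Fisher-Yates shuffle, which does exactly `size` swaps after copying the input. The sketch below is an assumption-laden illustration rather than a proposed change: `sampleNPartialShuffle` is a hypothetical name, and `PRNG` is again assumed to return a float in [0, 1).

```typescript
// Assumption: a PRNG is a zero-argument function returning a float in [0, 1).
type PRNG = () => number;

// Hypothetical alternative to sampleN, not part of this PR: picks up to n
// distinct positions by shuffling only the first `count` slots of a copy.
function sampleNPartialShuffle<T>(
  array: readonly T[],
  n: number,
  rng: PRNG
): T[] {
  const pool = array.slice(); // copy so the caller's array is not mutated
  const count = Math.min(Math.max(0, Math.floor(n)), pool.length);

  for (let i = 0; i < count; i++) {
    // Choose a random position from the not-yet-fixed tail [i, pool.length).
    const j = i + Math.floor(rng() * (pool.length - i));
    const tmp = pool[i];
    pool[i] = pool[j];
    pool[j] = tmp;
  }

  return pool.slice(0, count);
}
```

The trade-off is an O(array.length) copy up front, which the Set-based loop avoids when `size` is small relative to the list.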