From 5635e3da9aca573c0231775815eed65db47a884c Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Mon, 30 Sep 2024 07:28:40 -0700 Subject: [PATCH 1/2] Adds List.sample and List.sampleN --- .../__tests__/library/list_test.ts | 17 ++++++++++ .../src/dists/SampleSetDist/index.ts | 5 ++- packages/squiggle-lang/src/fr/list.ts | 34 ++++++++++++++++++- packages/squiggle-lang/src/utility/E_A.ts | 24 +++++++++++++ 4 files changed, 76 insertions(+), 4 deletions(-) diff --git a/packages/squiggle-lang/__tests__/library/list_test.ts b/packages/squiggle-lang/__tests__/library/list_test.ts index 803055df5a..ea62d5392b 100644 --- a/packages/squiggle-lang/__tests__/library/list_test.ts +++ b/packages/squiggle-lang/__tests__/library/list_test.ts @@ -308,4 +308,21 @@ describe("List functions", () => { testEvalToBe("List.unzip([[1,3],[2,4],[5,6]])", "[[1,2,5],[3,4,6]]"); testEvalToBe("List.unzip([])", "[[],[]]"); }); + + describe("sample", () => { + testEvalToBe("List.sample([1, 2, 3, 4, 5]) -> typeOf", '"Number"'); + testEvalToBe( + "List.sample([])", + "Error(Argument Error: List must not be empty)" + ); + }); + + describe("sampleN", () => { + testEvalToBe("List.sampleN([1, 2, 3, 4, 5], 3) -> List.length", "3"); + testEvalToBe("List.sampleN([1, 2, 3], 0)", "[]"); + testEvalToBe( + "List.sampleN([], 2)", + "Error(Argument Error: List must not be empty)" + ); + }); }); diff --git a/packages/squiggle-lang/src/dists/SampleSetDist/index.ts b/packages/squiggle-lang/src/dists/SampleSetDist/index.ts index b61d9c9b95..1bfd6b757a 100644 --- a/packages/squiggle-lang/src/dists/SampleSetDist/index.ts +++ b/packages/squiggle-lang/src/dists/SampleSetDist/index.ts @@ -5,7 +5,7 @@ import * as Discrete from "../../PointSet/Discrete.js"; import { DiscreteShape } from "../../PointSet/Discrete.js"; import { buildMixedShape } from "../../PointSet/Mixed.js"; import { PRNG } from "../../rng/index.js"; -import { isEqual } from "../../utility/E_A.js"; +import { isEqual, sample as E_A_sample } from "../../utility/E_A.js"; import * as E_A_Floats from "../../utility/E_A_Floats.js"; import * as E_A_Sorted from "../../utility/E_A_Sorted.js"; import * as Result from "../../utility/result.js"; @@ -125,8 +125,7 @@ export class SampleSetDist extends BaseDist { // Randomly get one sample from the distribution sample(rng: PRNG) { - const index = Math.floor(rng() * this.samples.length); - return this.samples[index]; + return E_A_sample(this.samples, rng); } /* diff --git a/packages/squiggle-lang/src/fr/list.ts b/packages/squiggle-lang/src/fr/list.ts index d2337e2d09..b5ff99abfc 100644 --- a/packages/squiggle-lang/src/fr/list.ts +++ b/packages/squiggle-lang/src/fr/list.ts @@ -30,7 +30,7 @@ import { Lambda } from "../reducer/lambda/index.js"; import { Reducer } from "../reducer/Reducer.js"; import { tBool, tNumber } from "../types/TIntrinsic.js"; import { tAny } from "../types/Type.js"; -import { shuffle, unzip, zip } from "../utility/E_A.js"; +import { sample, sampleN, shuffle, unzip, zip } from "../utility/E_A.js"; import * as E_A_Floats from "../utility/E_A_Floats.js"; import { uniq, uniqBy, Value } from "../value/index.js"; import { vNumber } from "../value/VNumber.js"; @@ -800,4 +800,36 @@ List.reduceWhile( ), ], }), + maker.make({ + name: "sample", + requiresNamespace: true, + examples: [makeFnExample(`List.sample([1,4,5])`)], + displaySection: "Queries", + definitions: [ + makeDefinition( + [frArray(frAny({ genericName: "A" }))], + frAny({ genericName: "A" }), + ([array], { rng }) => { + _assertUnemptyArray(array); + return sample(array, rng); + } + ), + ], + }), + maker.make({ + name: "sampleN", + requiresNamespace: true, + examples: [makeFnExample(`List.sampleN([1,4,5], 2)`)], + displaySection: "Queries", + definitions: [ + makeDefinition( + [frArray(frAny({ genericName: "A" })), namedInput("n", frNumber)], + frArray(frAny({ genericName: "A" })), + ([array, n], { rng }) => { + _assertUnemptyArray(array); + return sampleN(array, n, rng); + } + ), + ], + }), ]; diff --git a/packages/squiggle-lang/src/utility/E_A.ts b/packages/squiggle-lang/src/utility/E_A.ts index 048a9a6f25..60b8296965 100644 --- a/packages/squiggle-lang/src/utility/E_A.ts +++ b/packages/squiggle-lang/src/utility/E_A.ts @@ -119,3 +119,27 @@ export function isEqual(arr1: readonly T[], arr2: readonly T[]): boolean { return true; } + +export function sample(array: readonly T[], rng: PRNG): T { + const index = Math.floor(rng() * array.length); + return array[index]; +} + +export function sampleN(array: readonly T[], n: number, rng: PRNG): T[] { + const size = Math.max(0, Math.floor(n)); + if (size === 0) return []; + if (size >= array.length) return shuffle(array, rng); + + const result: T[] = []; + const indices = new Set(); + + while (indices.size < size) { + const index = Math.floor(rng() * array.length); + if (!indices.has(index)) { + indices.add(index); + result.push(array[index]); + } + } + + return result; +} From a241b8b67ac5757b7a35fcc1bdb2265015abb669 Mon Sep 17 00:00:00 2001 From: Ozzie Gooen Date: Mon, 30 Sep 2024 07:34:20 -0700 Subject: [PATCH 2/2] Added changeset --- .changeset/late-onions-tease.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/late-onions-tease.md diff --git a/.changeset/late-onions-tease.md b/.changeset/late-onions-tease.md new file mode 100644 index 0000000000..4682d1da36 --- /dev/null +++ b/.changeset/late-onions-tease.md @@ -0,0 +1,5 @@ +--- +"@quri/squiggle-lang": patch +--- + +Added List.sample and List.sampleN