Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions frontend/__tests__/test/funbox/funbox-validation.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ describe("funbox-validation", () => {
"58008", //getWord
"wikipedia", //pullSection
"polyglot", //withWords
"llm", //withWords
"zipf", //changesWordsFrequency
].map((funbox) => ({
key: "mode",
Expand Down
211 changes: 211 additions & 0 deletions frontend/__tests__/test/llm/constraint-engine-oracle.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import { describe, expect, it } from "vitest";
import { ConstraintEngine } from "../../../src/ts/test/llm/constraint-engine";
import {
OracleState,
createConstraintOracle,
} from "../../../src/ts/test/llm/spec-oracle";

// Minimal stand-in for a tokenizer vocabulary entry: a token id paired
// with its decoded surface text.
type DecodedToken = {
  id: number;
  text: string;
};

describe("constraint engine oracle", () => {
  it("models the spec with a simple oracle", () => {
    // The oracle is the trusted reference model used by the other tests;
    // pin down its own behavior on a tiny vocabulary first.
    const oracle = createConstraintOracle(["the", "there", "world"]);

    // Consuming a whole word leaves it as the pending partial word, and
    // since "the" is a complete vocabulary word, output may stop here.
    const theState = oracle.getNextState(oracle.initialState, "the");
    expect(theState).toEqual({ partialWord: "the" });
    expect(oracle.canTerminate(theState as OracleState)).toBe(true);

    // A leading space closes the finished word and starts the next one.
    const worldState = oracle.getNextState(theState as OracleState, " world");
    expect(worldState).toEqual({ partialWord: "world" });
    expect(oracle.canTerminate(worldState as OracleState)).toBe(true);

    // Rejected inputs: a space-prefixed token with no word in progress,
    // a character that extends "the" toward no vocabulary word ("thex"),
    // and a bare separator after a completed word.
    expect(oracle.getNextState(oracle.initialState, " world")).toBeNull();
    expect(oracle.getNextState(theState as OracleState, "x")).toBeNull();
    expect(oracle.getNextState(theState as OracleState, " ")).toBeNull();
  });

  it("matches the oracle on curated prefix-heavy cases", () => {
    // All token texts of length 0..3 over {a, b, c, space}.
    const tokenTexts = enumerateTokenTexts(["a", "b", "c", " "], 3);

    // Word lists deliberately contain words that are prefixes of other
    // words — the trickiest case for prefix-based constraint states.
    for (const words of [
      ["a", "ab", "aba", "b", "ba"],
      ["the", "there", "world"],
      ["cat", "car", "card", "dog"],
      ["go", "good", "goods", "gone"],
    ]) {
      assertEngineMatchesOracle(words, tokenTexts);
    }
  });

  it("matches the oracle across randomized small wordsets", () => {
    // Seeded PRNG keeps the randomized vocabularies reproducible.
    const rng = createMulberry32(123456);
    const tokenTexts = enumerateTokenTexts(["a", "b", "c", " "], 3);
    // Candidate words: non-empty strings of length 1..4 over {a, b, c}.
    const availableWords = enumerateTokenTexts(["a", "b", "c"], 4).filter(
      (word) => word.length > 0,
    );

    for (let index = 0; index < 75; index++) {
      // Each round checks a random vocabulary of 2..7 distinct words.
      const wordCount = 2 + Math.floor(rng() * 6);
      const words = pickUniqueWords(availableWords, wordCount, rng);
      assertEngineMatchesOracle(words, tokenTexts);
    }
  });

  it("stays inside the language during random valid walks", () => {
    const rng = createMulberry32(98765);
    const words = ["a", "ab", "abc", "b", "ba"];
    const decodedTokens = buildDecodedTokens(
      enumerateTokenTexts(["a", "b", "c", " "], 3),
    );
    const engine = new ConstraintEngine(words, decodedTokens);
    const oracle = createConstraintOracle(words);

    for (let walk = 0; walk < 100; walk++) {
      let stateId = engine.rootStateId;
      let oracleState = oracle.initialState;
      let renderedText = "";

      for (let step = 0; step < 20; step++) {
        // A reachable state must always offer at least one valid token,
        // otherwise generation could dead-end.
        const validTokenIds = engine.getValidTokenIds(stateId);
        expect(validTokenIds.length).toBeGreaterThan(0);

        // Pick one uniformly random valid token and advance both models.
        const tokenId = validTokenIds[
          Math.floor(rng() * validTokenIds.length)
        ] as number;
        const tokenText = decodedTokens[tokenId]?.text ?? "";
        const nextStateId = engine.getNextState(stateId, tokenId);
        const nextOracleState = oracle.getNextState(oracleState, tokenText);

        // A token the engine marked valid must be accepted by both sides.
        expect(nextStateId).not.toBeNull();
        expect(nextOracleState).not.toBeNull();

        renderedText += tokenText;
        stateId = nextStateId as number;
        oracleState = nextOracleState as OracleState;

        // Engine and oracle must agree on the pending partial word.
        expect(engine.getStatePrefix(stateId)).toBe(oracleState.partialWord);

        // Every space-delimited word completed so far (everything except
        // the trailing, possibly partial, segment) must be in the vocabulary.
        const parsedWords = renderedText
          .split(" ")
          .filter(
            (word, index, items) => word.length > 0 && index < items.length - 1,
          );

        for (const word of parsedWords) {
          expect(words).toContain(word);
        }

        // Occasionally stop a walk early, but only where termination is legal.
        if (engine.canTerminate(stateId) && step > 0 && rng() < 0.35) {
          break;
        }
      }
    }
  });
});

function assertEngineMatchesOracle(
words: string[],
tokenTexts: string[],
): void {
const decodedTokens = buildDecodedTokens(tokenTexts);
const engine = new ConstraintEngine(words, decodedTokens);
const oracle = createConstraintOracle(words);

for (let stateId = 0; stateId < engine.getStateCount(); stateId++) {
const oracleState = {
partialWord: engine.getStatePrefix(stateId),
};

expect(engine.canTerminate(stateId)).toBe(oracle.canTerminate(oracleState));

const expectedValidTokenIds: number[] = [];

for (const token of decodedTokens) {
const nextOracleState = oracle.getNextState(oracleState, token.text);
const nextEngineStateId = engine.getNextState(stateId, token.id);

if (nextOracleState === null) {
expect(nextEngineStateId).toBeNull();
continue;
}

expectedValidTokenIds.push(token.id);
expect(nextEngineStateId).not.toBeNull();
expect(engine.getStatePrefix(nextEngineStateId as number)).toBe(
nextOracleState.partialWord,
);
}

expect(engine.getValidTokenIds(stateId)).toEqual(expectedValidTokenIds);
}
}

/** Pairs each token text with its array index as the token id. */
function buildDecodedTokens(tokenTexts: string[]): DecodedToken[] {
  return Array.from(tokenTexts.entries(), ([id, text]) => ({ id, text }));
}

/**
 * Enumerates every string over `alphabet` of length 0..maxLength,
 * shortest first; within one length, earlier alphabet characters vary
 * slowest (depth-first, prefix-major order).
 */
function enumerateTokenTexts(alphabet: string[], maxLength: number): string[] {
  const texts = [""];
  let currentLayer = [""];

  for (let length = 1; length <= maxLength; length++) {
    // Extend every string of the previous length by each alphabet char.
    currentLayer = currentLayer.flatMap((prefix) =>
      alphabet.map((char) => `${prefix}${char}`),
    );
    texts.push(...currentLayer);
  }

  return texts;
}

/**
 * Depth-first expansion helper: appends alphabet characters to `prefix`
 * until the remaining length reaches zero, then records the completed
 * text in the shared `texts` accumulator.
 */
function buildTextsOfLength(
  prefix: string,
  remainingLength: number,
  alphabet: string[],
  texts: string[],
): void {
  if (remainingLength === 0) {
    // Target length reached — emit the finished text.
    texts.push(prefix);
  } else {
    alphabet.forEach((char) => {
      buildTextsOfLength(prefix + char, remainingLength - 1, alphabet, texts);
    });
  }
}

/**
 * Mulberry32: a tiny deterministic 32-bit PRNG returning floats in [0, 1).
 * The bit-mixing expressions below must stay exact to reproduce the
 * well-known reference stream for a given seed.
 */
function createMulberry32(seed: number): () => number {
  let state = seed;

  return () => {
    // Force 32-bit integer arithmetic before mixing.
    state |= 0;
    state = (state + 0x6d2b79f5) | 0;
    let mixed = Math.imul(state ^ (state >>> 15), 1 | state);
    mixed = (mixed + Math.imul(mixed ^ (mixed >>> 7), 61 | mixed)) ^ mixed;
    // Map the mixed 32-bit value onto [0, 1).
    return ((mixed ^ (mixed >>> 14)) >>> 0) / 4294967296;
  };
}

/**
 * Samples up to `count` distinct words without replacement by removing
 * each pick from a working copy of the pool. Once the pool runs dry it
 * silently yields fewer words than requested.
 */
function pickUniqueWords(
  availableWords: string[],
  count: number,
  rng: () => number,
): string[] {
  const remaining = [...availableWords];
  const picked: string[] = [];

  for (let taken = 0; taken < count; taken++) {
    const choiceIndex = Math.floor(rng() * remaining.length);
    const choice = remaining.splice(choiceIndex, 1)[0];

    // splice on an exhausted pool yields no element — skip those rounds.
    if (choice !== undefined) {
      picked.push(choice);
    }
  }

  return picked;
}
79 changes: 79 additions & 0 deletions frontend/__tests__/test/llm/constraint-engine.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import { describe, expect, it } from "vitest";
import {
ConstraintEngine,
decodeTokenizerVocabulary,
} from "../../../src/ts/test/llm/constraint-engine";

describe("constraint engine", () => {
  it("precomputes valid token transitions for root and mid-word states", () => {
    // Vocabulary ["the", "there", "world"]; the token list mixes full
    // words, mid-word fragments, a multi-word span, a dead-end letter,
    // and an empty token.
    const engine = new ConstraintEngine(
      ["the", "there", "world"],
      [
        { id: 0, text: "the" },
        { id: 1, text: "there" },
        { id: 2, text: "re" },
        { id: 3, text: " world" },
        { id: 4, text: "w" },
        { id: 5, text: "orld" },
        { id: 6, text: "the world" },
        { id: 7, text: "x" },
        { id: 8, text: "" },
      ],
    );

    // At the root only tokens that begin a vocabulary word are valid:
    // "the", "there", "w", and the multi-word span "the world".
    expect(engine.getValidTokenIds(engine.rootStateId)).toEqual([0, 1, 4, 6]);

    // After "the": the word is complete, so termination is allowed; the
    // only continuations are "re" (toward "there") and " world".
    const theState = engine.getNextState(engine.rootStateId, 0);
    expect(theState).not.toBeNull();
    expect(engine.getStatePrefix(theState as number)).toBe("the");
    expect(engine.canTerminate(theState as number)).toBe(true);
    expect(engine.getValidTokenIds(theState as number)).toEqual([2, 3]);

    // A single token spanning "the world" crosses the word boundary and
    // lands on the terminal-capable state for "world".
    const worldFromSpan = engine.getNextState(engine.rootStateId, 6);
    expect(worldFromSpan).not.toBeNull();
    expect(engine.getStatePrefix(worldFromSpan as number)).toBe("world");
    expect(engine.canTerminate(worldFromSpan as number)).toBe(true);

    // Rejected at the root: a space-prefixed token with no word in
    // progress and the empty token; the root itself is not terminal.
    expect(engine.getNextState(engine.rootStateId, 3)).toBeNull();
    expect(engine.getNextState(engine.rootStateId, 8)).toBeNull();
    expect(engine.canTerminate(engine.rootStateId)).toBe(false);
  });

  it("rejects invalid surface forms in the constructor", () => {
    // Surface forms containing spaces must be refused up front.
    expect(() => {
      new ConstraintEngine(["hello", "two words"], [{ id: 0, text: "hello" }]);
    }).toThrow("invalid surface form (must not contain spaces)");
  });

  it("decodes tokenizer vocab through the adapter interface", () => {
    // The adapter only needs getVocabSize/decodeToken to produce the
    // { id, text } list the engine consumes.
    const decoded = decodeTokenizerVocabulary({
      getVocabSize() {
        return 3;
      },
      decodeToken(tokenId) {
        return ["a", "b", "c"][tokenId] as string;
      },
    });

    expect(decoded).toEqual([
      { id: 0, text: "a" },
      { id: 1, text: "b" },
      { id: 2, text: "c" },
    ]);
  });

  it("reports basic precompute stats", () => {
    const engine = new ConstraintEngine(
      ["a", "ab"],
      [
        { id: 0, text: "a" },
        { id: 1, text: "b" },
        { id: 2, text: " ab" },
      ],
    );

    // 3 states and 3 tokens; the asserted 4/3 average implies 4 valid
    // (state, token) transitions in total.
    expect(engine.getStats().stateCount).toBe(3);
    expect(engine.getStats().tokenCount).toBe(3);
    expect(engine.getStats().averageValidTokensPerState).toBeCloseTo(4 / 3);
  });
});
37 changes: 37 additions & 0 deletions frontend/__tests__/test/llm/constraint-trie.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { describe, expect, it } from "vitest";
import { ConstraintTrie } from "../../../src/ts/test/llm/constraint-trie";

describe("constraint trie", () => {
  it("supports words that are prefixes of longer words", () => {
    const trie = new ConstraintTrie(["the", "there", "world"]);

    // Walk "t" -> "th" -> "the" one character at a time.
    const tState = trie.consumeChar(trie.rootStateId, "t");
    expect(tState).not.toBeNull();

    const thState = trie.consumeChar(tState as number, "h");
    expect(thState).not.toBeNull();

    const theState = trie.consumeChar(thState as number, "e");
    expect(theState).not.toBeNull();
    expect(trie.getPrefix(theState as number)).toBe("the");
    // "the" is itself a vocabulary word even though "there" continues it.
    expect(trie.isWordState(theState as number)).toBe(true);

    // A space after a completed word resets the walk to the root state.
    expect(trie.consumeChar(theState as number, " ")).toBe(trie.rootStateId);

    // The same state can still be extended toward "there"...
    const therState = trie.consumeChar(theState as number, "r");
    expect(therState).not.toBeNull();
    expect(trie.getPrefix(therState as number)).toBe("ther");

    // ...but a character leading to no vocabulary word is rejected.
    expect(trie.consumeChar(theState as number, "x")).toBeNull();
  });

  it("consumes multi-word token text across boundaries", () => {
    const trie = new ConstraintTrie(["hello", "world"]);

    // One text spanning two words must end on the "world" word state.
    const nextState = trie.consumeText(trie.rootStateId, "hello world");

    expect(nextState).not.toBeNull();
    expect(trie.getPrefix(nextState as number)).toBe("world");
    expect(trie.isWordState(nextState as number)).toBe(true);
  });
});
47 changes: 47 additions & 0 deletions frontend/__tests__/test/llm/surface-forms.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { describe, expect, it } from "vitest";
import {
buildSurfaceForms,
isValidSurfaceForm,
} from "../../../src/ts/test/llm/surface-forms";

describe("llm surface forms", () => {
  describe("isValidSurfaceForm", () => {
    it("accepts plain words", () => {
      expect(isValidSurfaceForm("hello")).toBe(true);
      expect(isValidSurfaceForm("Bonjour")).toBe(true);
      expect(isValidSurfaceForm("cafe")).toBe(true);
    });

    it("accepts words with non-letter characters", () => {
      // Punctuation, digits, and underscores are all legal as long as
      // the form contains no spaces.
      expect(isValidSurfaceForm("can't")).toBe(true);
      expect(isValidSurfaceForm("re-entry")).toBe(true);
      expect(isValidSurfaceForm("h3llo")).toBe(true);
      expect(isValidSurfaceForm("hello,")).toBe(true);
      expect(isValidSurfaceForm("__init__")).toBe(true);
      expect(isValidSurfaceForm("123")).toBe(true);
    });

    it("rejects empty strings and strings containing spaces", () => {
      // A space anywhere — interior, leading, or trailing — is invalid.
      expect(isValidSurfaceForm("")).toBe(false);
      expect(isValidSurfaceForm("two words")).toBe(false);
      expect(isValidSurfaceForm(" hello")).toBe(false);
      expect(isValidSurfaceForm("hello ")).toBe(false);
    });
  });

  describe("buildSurfaceForms", () => {
    it("filters space-containing entries and deduplicates", () => {
      // "two words" is dropped, the duplicate "hello" collapses to a
      // single entry, and first-seen order is preserved.
      expect(
        buildSurfaceForms([
          "hello",
          "hello",
          "two words",
          "world",
          "world!",
          "123",
          "re-entry",
        ]),
      ).toEqual(["hello", "world", "world!", "123", "re-entry"]);
    });
  });
});
Loading
Loading