import { appendUnique, arrayEquals, domainDescriptions, flatMorph, groupBy, hasKey, isArray, jsTypeOfDescriptions, printable, range, throwParseError, unset } from "@ark/util";
import { compileLiteralPropAccess, compileSerializedValue } from "../shared/compile.js";
import { Disjoint } from "../shared/disjoint.js";
import { implementNode } from "../shared/implement.js";
import { intersectNodesRoot, intersectOrPipeNodes } from "../shared/intersections.js";
import { $ark, registeredReference } from "../shared/registry.js";
import { Traversal } from "../shared/traversal.js";
import { hasArkKind } from "../shared/utils.js";
import { BaseRoot } from "./root.js";
import { defineRightwardIntersections } from "./utils.js";

/**
 * Node implementation for `union` roots: branch parsing, reduction,
 * default error descriptions, and intersection behavior.
 */
const implementation = implementNode({
    kind: "union",
    hasAssociatedError: true,
    collapsibleKey: "branches",
    keys: {
        ordered: {},
        branches: {
            child: true,
            /**
             * Flattens each branch schema into its branch nodes. Morph branches
             * whose morphs are equal to an already-collected morph branch are
             * merged into it by unioning their inputs. Unless the union is
             * ordered, the result is sorted by node hash.
             */
            parse: (schema, ctx) => {
                const branches = [];
                for (const branchSchema of schema) {
                    const branchNodes = hasArkKind(branchSchema, "root")
                        ? branchSchema.branches
                        : ctx.$.parseSchema(branchSchema).branches;
                    for (const node of branchNodes) {
                        if (node.hasKind("morph")) {
                            const matchingMorphIndex = branches.findIndex(matching => matching.hasKind("morph") && matching.hasEqualMorphs(node));
                            if (matchingMorphIndex === -1)
                                branches.push(node);
                            else {
                                const matchingMorph = branches[matchingMorphIndex];
                                branches[matchingMorphIndex] = ctx.$.node("morph", {
                                    ...matchingMorph.inner,
                                    in: matchingMorph.rawIn.rawOr(node.rawIn)
                                });
                            }
                        }
                        else
                            branches.push(node);
                    }
                }
                if (!ctx.def.ordered)
                    branches.sort((l, r) => (l.hash < r.hash ? -1 : 1));
                return branches;
            }
        }
    },
    // a bare array schema is shorthand for { branches }
    normalize: schema => (isArray(schema) ? { branches: schema } : schema),
    /**
     * Returns the single remaining branch if reduction leaves one, undefined if
     * nothing was removed, or a new prereduced union of the remaining branches.
     */
    reduce: (inner, $) => {
        const reducedBranches = reduceBranches(inner);
        if (reducedBranches.length === 1)
            return reducedBranches[0];
        if (reducedBranches.length === inner.branches.length)
            return;
        return $.node("union", { ...inner, branches: reducedBranches }, { prereduced: true });
    },
    defaults: {
        description: node => node.distribute(branch => branch.description, describeBranches),
        expected: ctx => {
            const byPath = groupBy(ctx.errors, "propString");
            const pathDescriptions = Object.entries(byPath).map(([path, errors]) => {
                const branchesAtPath = [];
                for (const errorAtPath of errors)
                    appendUnique(branchesAtPath, errorAtPath.expected);
                const expected = describeBranches(branchesAtPath);
                // if there are multiple actual descriptions that differ,
                // just fall back to printable, which is the most specific
                const actual = errors.every(e => e.actual === errors[0].actual)
                    ? errors[0].actual
                    : printable(errors[0].data);
                return `${path && `${path} `}must be ${expected}${actual && ` (was ${actual})`}`;
            });
            return describeBranches(pathDescriptions);
        },
        problem: ctx => ctx.expected,
        message: ctx => {
            if (ctx.problem[0] === "[") {
                // clarify paths like [1], [0][1], and ["key!"] that could be confusing
                return `value at ${ctx.problem}`;
            }
            return ctx.problem;
        }
    },
    intersections: {
        union: (l, r, ctx) => {
            if (l.isNever !== r.isNever) {
                // if exactly one operand is never, we can use it to discriminate based on presence
                return Disjoint.init("presence", l, r);
            }
            let resultBranches;
            if (l.ordered) {
                if (r.ordered) {
                    throwParseError(writeOrderedIntersectionMessage(l.expression, r.expression));
                }
                // the ordered operand's branches go second so their order is preserved
                resultBranches = intersectBranches(r.branches, l.branches, ctx);
                if (resultBranches instanceof Disjoint) {
                    // the operands were passed in reversed order, so the disjoint's
                    // sides must be swapped back. NOTE(review): Disjoint#invert is
                    // defined elsewhere; its return value is used when it is a
                    // Disjoint, which is correct whether invert mutates in place or
                    // returns a new instance (previously the result was discarded)
                    const inverted = resultBranches.invert();
                    if (inverted instanceof Disjoint)
                        resultBranches = inverted;
                }
            }
            else
                resultBranches = intersectBranches(l.branches, r.branches, ctx);
            if (resultBranches instanceof Disjoint)
                return resultBranches;
            return ctx.$.parseSchema(l.ordered || r.ordered
                ? { branches: resultBranches, ordered: true }
                : { branches: resultBranches });
        },
        ...defineRightwardIntersections("union", (l, r, ctx) => {
            const branches = intersectBranches(l.branches, [r], ctx);
            if (branches instanceof Disjoint)
                return branches;
            if (branches.length === 1)
                return branches[0];
            return ctx.$.parseSchema(l.ordered ? { branches, ordered: true } : { branches });
        })
    }
});

/**
 * Root node representing a union of branch nodes. Handles discrimination
 * (compiling a `switch` over a shared path when one exists), traversal,
 * JSON Schema output, and expression/description rendering.
 */
export class UnionNode extends BaseRoot {
    // true only for exactly [false, true] in that order
    isBoolean = this.branches.length === 2 &&
        this.branches[0].hasUnit(false) &&
        this.branches[1].hasUnit(true);
    /**
     * Branches with any boolean unit branches grouped at the position of the
     * first one; if a second boolean unit is found, that slot becomes the
     * intrinsic `boolean` node.
     */
    get branchGroups() {
        const branchGroups = [];
        let firstBooleanIndex = -1;
        for (const branch of this.branches) {
            if (branch.hasKind("unit") && branch.domain === "boolean") {
                if (firstBooleanIndex === -1) {
                    firstBooleanIndex = branchGroups.length;
                    branchGroups.push(branch);
                }
                else
                    branchGroups[firstBooleanIndex] = $ark.intrinsic.boolean;
                continue;
            }
            branchGroups.push(branch);
        }
        return branchGroups;
    }
    // field initialization order matters: discriminate() reads unitBranches
    unitBranches = this.branches.filter((n) => n.rawIn.hasKind("unit"));
    discriminant = this.discriminate();
    discriminantJson = this.discriminant ? discriminantToJson(this.discriminant) : null;
    expression = this.distribute(n => n.nestableExpression, expressBranches);
    /**
     * Returns an apply function that first tries the optimistic path and only
     * falls back to a full Traversal (with error collection) if it fails.
     */
    createBranchedOptimisticRootApply() {
        return (data, onFail) => {
            const optimisticResult = this.traverseOptimistic(data);
            if (optimisticResult !== unset)
                return optimisticResult;
            const ctx = new Traversal(data, this.$.resolvedConfig);
            this.traverseApply(data, ctx);
            return ctx.finalize(onFail);
        };
    }
    get shallowMorphs() {
        return this.branches.reduce((morphs, branch) => appendUnique(morphs, branch.shallowMorphs), []);
    }
    get defaultShortDescription() {
        return this.distribute(branch => branch.defaultShortDescription, describeBranches);
    }
    innerToJsonSchema(ctx) {
        // special case to simplify { const: true } | { const: false }
        // to the canonical JSON Schema representation { type: "boolean" }
        if (this.branchGroups.length === 1 &&
            this.branchGroups[0].equals($ark.intrinsic.boolean))
            return { type: "boolean" };
        const jsonSchemaBranches = this.branchGroups.map(group => group.toJsonSchemaRecurse(ctx));
        // iff all branches are pure unit values with no metadata,
        // we can simplify the representation to an enum
        if (jsonSchemaBranches.every(branch => Object.keys(branch).length === 1 && hasKey(branch, "const"))) {
            return {
                enum: jsonSchemaBranches.map(branch => branch.const)
            };
        }
        return {
            anyOf: jsonSchemaBranches
        };
    }
    traverseAllows = (data, ctx) => this.branches.some(b => b.traverseAllows(data, ctx));
    /**
     * Tries each branch in order inside its own pushed branch context. The
     * first branch without errors wins (its queued morphs are forwarded if it
     * includes a transform); otherwise a "union" error with every branch's
     * error is reported.
     */
    traverseApply = (data, ctx) => {
        const errors = [];
        for (let i = 0; i < this.branches.length; i++) {
            ctx.pushBranch();
            this.branches[i].traverseApply(data, ctx);
            if (!ctx.hasError()) {
                if (this.branches[i].includesTransform)
                    return ctx.queuedMorphs.push(...ctx.popBranch().queuedMorphs);
                return ctx.popBranch();
            }
            errors.push(ctx.popBranch().error);
        }
        ctx.errorFromNodeContext({ code: "union", errors, meta: this.meta });
    };
    /**
     * Returns the output of the first allowing branch (applying its
     * context-free morph if it has one), or `unset` if no branch allows data.
     */
    traverseOptimistic = (data) => {
        for (let i = 0; i < this.branches.length; i++) {
            const branch = this.branches[i];
            if (branch.traverseAllows(data)) {
                if (branch.contextFreeMorph)
                    return branch.contextFreeMorph(data);
                // if we're calling this function and the matching branch didn't have
                // a context-free morph, it shouldn't have morphs at all
                return data;
            }
        }
        return unset;
    };
    compile(js) {
        if (!this.discriminant ||
            // if we have a union of two units like `boolean`, the
            // undiscriminated compilation will be just as fast
            (this.unitBranches.length === this.branches.length &&
                this.branches.length === 2))
            return this.compileIndiscriminable(js);
        // we need to access the path as optional so we don't throw if it isn't present
        let condition = this.discriminant.optionallyChainedPropString;
        if (this.discriminant.kind === "domain")
            condition = `typeof ${condition} === "object" ? ${condition} === null ? "null" : "object" : typeof ${condition} === "function" ? "object" : typeof ${condition}`;
        const cases = this.discriminant.cases;
        const caseKeys = Object.keys(cases);
        const { optimistic } = js;
        // only the first layer can be optimistic
        js.optimistic = false;
        js.block(`switch(${condition})`, () => {
            for (const k in cases) {
                const v = cases[k];
                const caseCondition = k === "default" ? k : `case ${k}`;
                let caseResult;
                if (v === true)
                    caseResult = optimistic ? "data" : "true";
                else if (optimistic) {
                    if (v.rootApplyStrategy === "branchedOptimistic")
                        caseResult = js.invoke(v, { kind: "Optimistic" });
                    else if (v.contextFreeMorph)
                        caseResult = `${js.invoke(v)} ? ${registeredReference(v.contextFreeMorph)}(data) : "${unset}"`;
                    else
                        caseResult = `${js.invoke(v)} ? data : "${unset}"`;
                }
                else
                    caseResult = js.invoke(v);
                js.line(`${caseCondition}: return ${caseResult}`);
            }
            return js;
        });
        if (js.traversalKind === "Allows") {
            js.return(optimistic ? `"${unset}"` : false);
            return;
        }
        const expected = describeBranches(this.discriminant.kind === "domain"
            ? caseKeys.map(k => {
                const jsTypeOf = k.slice(1, -1);
                return jsTypeOf === "function"
                    ? domainDescriptions.object
                    : domainDescriptions[jsTypeOf];
            })
            : caseKeys);
        const serializedPathSegments = this.discriminant.path.map(k => typeof k === "symbol" ? registeredReference(k) : JSON.stringify(k));
        const serializedExpected = JSON.stringify(expected);
        const serializedActual = this.discriminant.kind === "domain"
            ? `${serializedTypeOfDescriptions}[${condition}]`
            : `${serializedPrintable}(${condition})`;
        js.line(`ctx.errorFromNodeContext({ code: "predicate", expected: ${serializedExpected}, actual: ${serializedActual}, relativePath: [${serializedPathSegments}], meta: ${this.compiledMeta} })`);
    }
    compileIndiscriminable(js) {
        if (js.traversalKind === "Apply") {
            js.const("errors", "[]");
            for (const branch of this.branches) {
                js.line("ctx.pushBranch()")
                    .line(js.invoke(branch))
                    .if("!ctx.hasError()", () => js.return(branch.includesTransform
                        ? "ctx.queuedMorphs.push(...ctx.popBranch().queuedMorphs)"
                        : "ctx.popBranch()"))
                    .line("errors.push(ctx.popBranch().error)");
            }
            js.line(`ctx.errorFromNodeContext({ code: "union", errors, meta: ${this.compiledMeta} })`);
        }
        else {
            const { optimistic } = js;
            // only the first layer can be optimistic
            js.optimistic = false;
            for (const branch of this.branches) {
                js.if(`${js.invoke(branch)}`, () => js.return(optimistic
                    ? branch.contextFreeMorph
                        ? `${registeredReference(branch.contextFreeMorph)}(data)`
                        : "data"
                    : true));
            }
            js.return(optimistic ? `"${unset}"` : false);
        }
    }
    get nestableExpression() {
        // avoid adding unnecessary parentheses around boolean since it's
        // already collapsed to a single keyword
        return this.isBoolean ? "boolean" : `(${this.expression})`;
    }
    /**
     * Searches for a path and kind ("domain" or "unit") at which pairs of
     * branches are disjoint and builds switch cases from it. Returns null if
     * the union has fewer than 2 branches, is cyclic, or no usable
     * discriminant is found.
     */
    discriminate() {
        if (this.branches.length < 2 || this.isCyclic)
            return null;
        if (this.unitBranches.length === this.branches.length) {
            const cases = flatMorph(this.unitBranches, (i, n) => [
                `${n.rawIn.serializedValue}`,
                n.hasKind("morph") ? n : true
            ]);
            return {
                kind: "unit",
                path: [],
                optionallyChainedPropString: "data",
                cases
            };
        }
        const candidates = [];
        for (let lIndex = 0; lIndex < this.branches.length - 1; lIndex++) {
            const l = this.branches[lIndex];
            for (let rIndex = lIndex + 1; rIndex < this.branches.length; rIndex++) {
                const r = this.branches[rIndex];
                const result = intersectNodesRoot(l.rawIn, r.rawIn, l.$);
                if (!(result instanceof Disjoint))
                    continue;
                for (const entry of result) {
                    if (!entry.kind || entry.optional)
                        continue;
                    let lSerialized;
                    let rSerialized;
                    if (entry.kind === "domain") {
                        const lValue = entry.l;
                        const rValue = entry.r;
                        lSerialized = `"${typeof lValue === "string" ? lValue : lValue.domain}"`;
                        rSerialized = `"${typeof rValue === "string" ? rValue : rValue.domain}"`;
                    }
                    else if (entry.kind === "unit") {
                        lSerialized = entry.l.serializedValue;
                        rSerialized = entry.r.serializedValue;
                    }
                    else
                        continue;
                    const matching = candidates.find(d => arrayEquals(d.path, entry.path) && d.kind === entry.kind);
                    if (!matching) {
                        candidates.push({
                            kind: entry.kind,
                            cases: {
                                [lSerialized]: {
                                    branchIndices: [lIndex],
                                    condition: entry.l
                                },
                                [rSerialized]: {
                                    branchIndices: [rIndex],
                                    condition: entry.r
                                }
                            },
                            path: entry.path
                        });
                    }
                    else {
                        const lCase = matching.cases[lSerialized];
                        if (lCase)
                            lCase.branchIndices = appendUnique(lCase.branchIndices, lIndex);
                        else {
                            matching.cases[lSerialized] = {
                                branchIndices: [lIndex],
                                condition: entry.l
                            };
                        }
                        const rCase = matching.cases[rSerialized];
                        if (rCase)
                            rCase.branchIndices = appendUnique(rCase.branchIndices, rIndex);
                        else {
                            matching.cases[rSerialized] = {
                                branchIndices: [rIndex],
                                condition: entry.r
                            };
                        }
                    }
                }
            }
        }
        const viableCandidates = this.ordered
            ? viableOrderedCandidates(candidates, this.branches)
            : candidates;
        if (!viableCandidates.length)
            return null;
        const ctx = createCaseResolutionContext(viableCandidates, this);
        const cases = {};
        for (const k in ctx.best.cases) {
            const resolution = resolveCase(ctx, k);
            if (resolution === null) {
                cases[k] = true;
                continue;
            }
            // if all the branches ended up back in pruned, we'd loop if we continued
            // so just bail out- nothing left to discriminate
            if (resolution.length === this.branches.length)
                return null;
            if (this.ordered) {
                // ensure the original order of the pruned branches is preserved
                resolution.sort((l, r) => l.originalIndex - r.originalIndex);
            }
            const branches = resolution.map(entry => entry.branch);
            const caseNode = branches.length === 1
                ? branches[0]
                : this.$.node("union", this.ordered ? { branches, ordered: true } : branches);
            Object.assign(this.referencesById, caseNode.referencesById);
            cases[k] = caseNode;
        }
        if (ctx.defaultEntries.length) {
            // we don't have to worry about order here as it is always preserved
            // within defaultEntries
            const branches = ctx.defaultEntries.map(entry => entry.branch);
            cases.default = this.$.node("union", this.ordered ? { branches, ordered: true } : branches, {
                prereduced: true
            });
            Object.assign(this.referencesById, cases.default.referencesById);
        }
        return Object.assign(ctx.location, {
            cases
        });
    }
}

/**
 * Picks the best discriminant candidate (sorting viableCandidates in place)
 * and seeds the resolution state with every branch as a default entry.
 */
const createCaseResolutionContext = (viableCandidates, node) => {
    const ordered = viableCandidates.sort((l, r) => l.path.length === r.path.length
        // at equal depth, prefer candidates with more cases
        ? Object.keys(r.cases).length - Object.keys(l.cases).length
        // prefer shorter paths first
        : l.path.length - r.path.length);
    const best = ordered[0];
    const location = {
        kind: best.kind,
        path: best.path,
        optionallyChainedPropString: optionallyChainPropString(best.path)
    };
    const defaultEntries = node.branches.map((branch, originalIndex) => ({
        originalIndex,
        branch
    }));
    return {
        best,
        location,
        defaultEntries,
        node
    };
};

/**
 * Computes the branch entries that can match case `key` of the best
 * discriminant, pruned of the discriminant check itself. Returns null when a
 * matching branch prunes to null. Mutates ctx.defaultEntries, keeping only
 * branches not listed for this case (and not the alias special case).
 */
const resolveCase = (ctx, key) => {
    const caseCtx = ctx.best.cases[key];
    const discriminantNode = discriminantCaseToNode(caseCtx.condition, ctx.location.path, ctx.node.$);
    let resolvedEntries = [];
    const nextDefaults = [];
    for (let i = 0; i < ctx.defaultEntries.length; i++) {
        const entry = ctx.defaultEntries[i];
        if (caseCtx.branchIndices.includes(entry.originalIndex)) {
            const pruned = pruneDiscriminant(ctx.node.branches[entry.originalIndex], ctx.location);
            if (pruned === null) {
                // if any branch of the union has no constraints (i.e. is
                // unknown), the others won't affect the resolution type, but could still
                // remove additional cases from defaultEntries
                resolvedEntries = null;
            }
            else {
                resolvedEntries?.push({
                    originalIndex: entry.originalIndex,
                    branch: pruned
                });
            }
        }
        else if (
        // we shouldn't need a special case for alias to avoid the below
        // once alias resolution issues are improved:
        // https://github.com/arktypeio/arktype/issues/1026
        entry.branch.hasKind("alias") &&
            discriminantNode.hasKind("domain") &&
            discriminantNode.domain === "object")
            resolvedEntries?.push(entry);
        else {
            if (entry.branch.rawIn.overlaps(discriminantNode)) {
                // include cases where an object not including the
                // discriminant path might have that value present as an undeclared key
                const overlapping = pruneDiscriminant(entry.branch, ctx.location);
                resolvedEntries?.push({
                    originalIndex: entry.originalIndex,
                    branch: overlapping
                });
            }
            nextDefaults.push(entry);
        }
    }
    ctx.defaultEntries = nextDefaults;
    return resolvedEntries;
};

/**
 * For ordered unions, keeps only candidates whose case grouping never places
 * a later, overlapping branch ahead of an earlier one.
 */
const viableOrderedCandidates = (candidates, originalBranches) => {
    const viableCandidates = candidates.filter(candidate => {
        const caseGroups = Object.values(candidate.cases).map(caseCtx => caseCtx.branchIndices);
        // compare each group against all subsequent groups.
        for (let i = 0; i < caseGroups.length - 1; i++) {
            const currentGroup = caseGroups[i];
            for (let j = i + 1; j < caseGroups.length; j++) {
                const nextGroup = caseGroups[j];
                // for each group pair, check for branches whose order was reversed
                for (const currentIndex of currentGroup) {
                    for (const nextIndex of nextGroup) {
                        if (currentIndex > nextIndex) {
                            if (originalBranches[currentIndex].overlaps(originalBranches[nextIndex])) {
                                // if the order was not preserved and the branches overlap,
                                // this is not a viable discriminant as it cannot guarantee the same behavior
                                return false;
                            }
                        }
                    }
                }
            }
        }
        // branch groups preserved order for non-disjoint pairs and is viable
        return true;
    });
    return viableCandidates;
};

/**
 * Builds a node asserting the case condition at `path`, wrapping it in
 * array (numeric key) or object (other key) intersections from the innermost
 * key outward.
 */
const discriminantCaseToNode = (caseDiscriminant, path, $) => {
    let node = caseDiscriminant === "undefined" ? $.node("unit", { unit: undefined })
        : caseDiscriminant === "null" ? $.node("unit", { unit: null })
            : caseDiscriminant === "boolean" ? $.units([true, false])
                : caseDiscriminant;
    for (let i = path.length - 1; i >= 0; i--) {
        const key = path[i];
        node = $.node("intersection", typeof key === "number"
            ? {
                proto: "Array",
                // create unknown for preceding elements (could be optimized with safe imports)
                sequence: [...range(key).map(_ => ({})), node]
            }
            : {
                domain: "object",
                required: [{ key, value: node }]
            });
    }
    return node;
};

// compiles a path to an optionally chained access expression rooted at `data`
const optionallyChainPropString = (path) => path.reduce((acc, k) => acc + compileLiteralPropAccess(k, true), "data");

const serializedTypeOfDescriptions = registeredReference(jsTypeOfDescriptions);

const serializedPrintable = registeredReference(printable);

export const Union = {
    implementation,
    Node: UnionNode
};

// serializable form of a discriminant, nesting child union discriminants where present
const discriminantToJson = (discriminant) => ({
    kind: discriminant.kind,
    path: discriminant.path.map(k => typeof k === "string" ? k : compileSerializedValue(k)),
    cases: flatMorph(discriminant.cases, (k, node) => [
        k,
        node === true ? node
            : node.hasKind("union") && node.discriminantJson ? node.discriminantJson
                : node.json
    ])
});

const describeExpressionOptions = {
    delimiter: " | ",
    finalDelimiter: " | "
};

const expressBranches = (expressions) => describeBranches(expressions, describeExpressionOptions);

/**
 * Joins branch descriptions into one phrase, e.g. "a, b or c".
 *
 * @param descriptions - branch descriptions; duplicates are removed, keeping
 *   first-seen order
 * @param opts - optional `delimiter` (default ", ") and `finalDelimiter`
 *   (default " or ")
 * @returns "never" for no descriptions, "boolean" when the unique descriptions
 *   are exactly "true" and "false", otherwise the joined descriptions
 */
export const describeBranches = (descriptions, opts) => {
    const delimiter = opts?.delimiter ?? ", ";
    const finalDelimiter = opts?.finalDelimiter ?? " or ";
    if (descriptions.length === 0)
        return "never";
    // a Set (rather than a plain object lookup) dedupes without treating
    // descriptions such as "constructor" or "__proto__" as already seen via
    // Object.prototype, and preserves insertion order
    const unique = [...new Set(descriptions)];
    if (unique.length === 1)
        return unique[0];
    // only exactly the two boolean literals (in either order) collapse to
    // "boolean"; the length check applies to both orderings
    if (unique.length === 2 && unique.includes("true") && unique.includes("false"))
        return "boolean";
    const last = unique.pop();
    return `${unique.join(delimiter)}${unique.length ? finalDelimiter : ""}${last}`;
};

/**
 * Intersects each branch of `l` with each branch of `r`, avoiding redundant
 * results when one branch is a subtype of another.
 *
 * @returns the resulting branches, or a "union" Disjoint if none remain
 */
export const intersectBranches = (l, r, ctx) => {
    // If the corresponding r branch is identified as a subtype of an l branch, the
    // value at rIndex is set to null so we can avoid including previous/future
    // intersections in the reduced result.
    const batchesByR = r.map(() => []);
    for (let lIndex = 0; lIndex < l.length; lIndex++) {
        let candidatesByR = {};
        for (let rIndex = 0; rIndex < r.length; rIndex++) {
            if (batchesByR[rIndex] === null) {
                // rBranch is a subtype of an lBranch and
                // will not yield any distinct intersection
                continue;
            }
            if (l[lIndex].equals(r[rIndex])) {
                // Combination of subtype and supertype cases
                batchesByR[rIndex] = null;
                candidatesByR = {};
                break;
            }
            const branchIntersection = intersectOrPipeNodes(l[lIndex], r[rIndex], ctx);
            if (branchIntersection instanceof Disjoint) {
                // Doesn't tell us anything useful about their relationships
                // with other branches
                continue;
            }
            if (branchIntersection.equals(l[lIndex])) {
                // If the current l branch is a subtype of r, intersections
                // with previous and remaining branches of r won't lead to
                // distinct intersections.
                batchesByR[rIndex].push(l[lIndex]);
                candidatesByR = {};
                break;
            }
            if (branchIntersection.equals(r[rIndex])) {
                // If the current r branch is a subtype of l, set its batch to
                // null, removing any previous intersections and preventing any
                // of its remaining intersections from being computed.
                batchesByR[rIndex] = null;
            }
            else {
                // If neither l nor r is a subtype of the other, add their
                // intersection as a candidate (could still be removed if it is
                // determined l or r is a subtype of a remaining branch).
                candidatesByR[rIndex] = branchIntersection;
            }
        }
        for (const rIndex in candidatesByR) {
            // batchesByR at rIndex should never be null if it is in candidatesByR
            batchesByR[rIndex][lIndex] = candidatesByR[rIndex];
        }
    }
    // Compile the reduced intersection result, including:
    // 1. Remaining candidates resulting from distinct intersections or strict subtypes of r
    // 2. Original r branches corresponding to indices with a null batch (subtypes of l)
    const resultBranches = batchesByR.flatMap(
    // ensure unions returned from branchable intersections like sequence are flattened
    (batch, i) => batch?.flatMap(branch => branch.branches) ?? r[i]);
    return resultBranches.length === 0
        ? Disjoint.init("union", l, r)
        : resultBranches;
};

/**
 * Removes branches that are duplicates of, or subsumed by, another branch.
 * In ordered unions, an earlier branch that is a subtype of a later one is
 * kept, and a later branch subsumed by an earlier one is removed. For
 * unordered unions, overlapping branches are checked with
 * assertDeterminateOverlap (which may throw).
 */
export const reduceBranches = ({ branches, ordered }) => {
    if (branches.length < 2)
        return branches;
    const uniquenessByIndex = branches.map(() => true);
    for (let i = 0; i < branches.length; i++) {
        for (let j = i + 1; j < branches.length && uniquenessByIndex[i]; j++) {
            // j was already found redundant against another branch that is (or
            // is itself covered by) a branch still compared with i, so skip it.
            // This check previously lived in the loop condition, which ended the
            // scan for i at the first redundant j and left later branches
            // uncompared.
            if (!uniquenessByIndex[j])
                continue;
            if (branches[i].equals(branches[j])) {
                // if the two branches are equal, only "j" is marked as
                // redundant so at least one copy could still be included in
                // the final set of branches.
                uniquenessByIndex[j] = false;
                continue;
            }
            const intersection = intersectNodesRoot(branches[i].rawIn, branches[j].rawIn, branches[0].$);
            if (intersection instanceof Disjoint)
                continue;
            if (!ordered)
                assertDeterminateOverlap(branches[i], branches[j]);
            if (intersection.equals(branches[i].rawIn)) {
                // preserve ordered branches that are a subtype of a subsequent branch
                uniquenessByIndex[i] = !!ordered;
            }
            else if (intersection.equals(branches[j].rawIn))
                uniquenessByIndex[j] = false;
        }
    }
    return branches.filter((_, i) => uniquenessByIndex[i]);
};

/**
 * Throws a parse error if two overlapping branches, at least one of which
 * includes a transform, do not have matching morphs (compared both shallowly
 * and per flattened path).
 */
const assertDeterminateOverlap = (l, r) => {
    if (!l.includesTransform && !r.includesTransform)
        return;
    if (!arrayEquals(l.shallowMorphs, r.shallowMorphs)) {
        throwParseError(writeIndiscriminableMorphMessage(l.expression, r.expression));
    }
    if (!arrayEquals(l.flatMorphs, r.flatMorphs, {
        isEqual: (l, r) => l.propString === r.propString &&
            (l.node.hasKind("morph") && r.node.hasKind("morph")
                ? l.node.hasEqualMorphs(r.node)
                : l.node.hasKind("intersection") && r.node.hasKind("intersection")
                    ? l.node.structure?.structuralMorphRef === r.node.structure?.structuralMorphRef
                    : false)
    })) {
        throwParseError(writeIndiscriminableMorphMessage(l.expression, r.expression));
    }
};

/**
 * Returns `discriminantBranch` transformed so that domain/unit nodes along the
 * discriminant's path (which the generated switch already checks) are removed.
 * The result may be null; resolveCase treats that as a branch with no
 * remaining constraints.
 */
export const pruneDiscriminant = (discriminantBranch, discriminantCtx) => discriminantBranch.transform((nodeKind, inner) => {
    if (nodeKind === "domain" || nodeKind === "unit")
        return null;
    return inner;
}, {
    shouldTransform: (node, ctx) => {
        // safe to cast here as index nodes are never discriminants
        const propString = optionallyChainPropString(ctx.path);
        if (!discriminantCtx.optionallyChainedPropString.startsWith(propString))
            return false;
        if (node.hasKind("domain") && node.domain === "object")
            // if we've already checked a path at least as long as the current one,
            // we don't need to revalidate that we're in an object
            return true;
        if ((node.hasKind("domain") || discriminantCtx.kind === "unit") &&
            propString === discriminantCtx.optionallyChainedPropString)
            // if the discriminant has already checked the domain at the current path
            // (or a unit literal, implying a domain), we don't need to recheck it
            return true;
        // we don't need to recurse into index nodes as they will never
        // have a required path therefore can't be used to discriminate
        return node.children.length !== 0 && node.kind !== "index";
    }
});

export const writeIndiscriminableMorphMessage = (lDescription, rDescription) => `An unordered union of a type including a morph and a type with overlapping input is indeterminate: Left: ${lDescription} Right: ${rDescription}`;

export const writeOrderedIntersectionMessage = (lDescription, rDescription) => `The intersection of two ordered unions is indeterminate: Left: ${lDescription} Right: ${rDescription}`;