Teach local AI from user category corrections

- Add MerchantCorrection model: upsert by merchantName, Category enum
- Check corrections DB first in suggestCategoryForMerchant (source: "learned",
  no confirmation required); falls through to rules then Ollama if no match
- Inject recent corrections as few-shot examples in the Ollama prompt so the
  model improves even for merchants not yet explicitly corrected
- Add POST /categories/correct route to persist corrections
- Detect category override on form save (suggestedCategory !== chosen category)
  and silently fire a correction — no extra UX required
- Fix test isolation: beforeEach re-applies vi.fn() defaults after restoreAllMocks

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-23 17:28:26 -04:00
parent 1015e24e69
commit 3e6231b654
7 changed files with 152 additions and 8 deletions

View File

@@ -0,0 +1,25 @@
import { Category } from "@prisma/client";
import { NextResponse } from "next/server";
import { z } from "zod";
import { saveMerchantCorrection } from "@/lib/merchant-corrections";
const correctionSchema = z.object({
merchantName: z.string().trim().min(1).max(80),
category: z.nativeEnum(Category, { message: "Choose a valid category." }),
});
export async function POST(request: Request) {
const payload = await request.json().catch(() => null);
const parsed = correctionSchema.safeParse(payload);
if (!parsed.success) {
return NextResponse.json(
{ error: parsed.error.issues[0]?.message ?? "Invalid correction." },
{ status: 400 },
);
}
await saveMerchantCorrection(parsed.data.merchantName, parsed.data.category);
return NextResponse.json({ ok: true });
}

View File

@@ -10,7 +10,7 @@ type SuggestionResponse = {
message: string;
merchantName: string;
requiresConfirmation: boolean;
source: "rule" | "model" | "unavailable";
source: "rule" | "model" | "unavailable" | "learned";
};
type ExpenseRecord = {
@@ -49,6 +49,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
const [suggestionMessage, setSuggestionMessage] = useState<string | null>(null);
const [needsSuggestionConfirmation, setNeedsSuggestionConfirmation] = useState(false);
const [lastSuggestedMerchant, setLastSuggestedMerchant] = useState("");
const [suggestedCategory, setSuggestedCategory] = useState<CategoryValue | null>(null);
useEffect(() => {
async function loadExpenses() {
@@ -90,8 +91,8 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
setSuggestionMessage(suggestion.message);
if (suggestion.category) {
const suggestedCategory = suggestion.category;
setFormState((current) => ({ ...current, category: suggestedCategory }));
setFormState((current) => ({ ...current, category: suggestion.category! }));
setSuggestedCategory(suggestion.category);
}
setNeedsSuggestionConfirmation(suggestion.requiresConfirmation);
@@ -108,6 +109,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
setSuggestionMessage(null);
setNeedsSuggestionConfirmation(false);
setLastSuggestedMerchant("");
setSuggestedCategory(null);
setError(null);
}
@@ -122,6 +124,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
setSuggestionMessage(null);
setNeedsSuggestionConfirmation(false);
setLastSuggestedMerchant("");
setSuggestedCategory(null);
setError(null);
}
@@ -133,6 +136,20 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
return;
}
// If the AI (model or learned) suggested a category and the user changed it,
// silently record the correction so future suggestions improve.
if (
lastSuggestedMerchant &&
suggestedCategory !== null &&
formState.category !== suggestedCategory
) {
void fetch("/categories/correct", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ merchantName: lastSuggestedMerchant, category: formState.category }),
});
}
setBusy(true);
setError(null);
@@ -166,6 +183,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
setSuggestionMessage(null);
setNeedsSuggestionConfirmation(false);
setLastSuggestedMerchant("");
setSuggestedCategory(null);
}
async function handleDelete(id: string) {

View File

@@ -1,5 +1,11 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
vi.mock("@/lib/merchant-corrections", () => ({
getMerchantCorrection: vi.fn(),
getRecentCorrections: vi.fn(),
}));
import { getMerchantCorrection, getRecentCorrections } from "@/lib/merchant-corrections";
import { getMerchantRuleCategory, suggestCategoryForMerchant } from "@/lib/category-suggestion";
describe("getMerchantRuleCategory", () => {
@@ -10,10 +16,26 @@ describe("getMerchantRuleCategory", () => {
});
describe("suggestCategoryForMerchant", () => {
beforeEach(() => {
vi.mocked(getMerchantCorrection).mockResolvedValue(null);
vi.mocked(getRecentCorrections).mockResolvedValue([]);
});
afterEach(() => {
vi.restoreAllMocks();
});
it("returns learned category without confirmation when a correction exists", async () => {
const { getMerchantCorrection } = await import("@/lib/merchant-corrections");
vi.mocked(getMerchantCorrection).mockResolvedValueOnce({ merchantName: "Blue Tokai", category: "FOOD" });
const suggestion = await suggestCategoryForMerchant("Blue Tokai");
expect(suggestion.category).toBe("FOOD");
expect(suggestion.source).toBe("learned");
expect(suggestion.requiresConfirmation).toBe(false);
});
it("uses the local model for unknown merchants", async () => {
vi.spyOn(globalThis, "fetch").mockResolvedValue({
ok: true,

View File

@@ -1,7 +1,8 @@
import { CATEGORY_VALUES, type CategoryValue } from "@/lib/categories";
import { getMerchantCorrection, getRecentCorrections } from "@/lib/merchant-corrections";
import { generateOllamaJson } from "@/lib/ollama";
type SuggestionSource = "rule" | "model" | "unavailable";
type SuggestionSource = "rule" | "model" | "unavailable" | "learned";
export type CategorySuggestion = {
category: CategoryValue | null;
@@ -51,6 +52,24 @@ function parseSuggestedCategory(raw: unknown): CategoryValue | null {
return CATEGORY_VALUES.includes(normalized as CategoryValue) ? (normalized as CategoryValue) : null;
}
function buildOllamaPrompt(merchantName: string, examples: Array<{ merchantName: string; category: string }>) {
const lines = [
"You categorize personal expense merchants.",
"Return JSON with one key named category.",
"Allowed values only: RENT, FOOD, TRANSPORT, BILLS, SHOPPING, HEALTH, ENTERTAINMENT, MISC.",
];
if (examples.length > 0) {
lines.push("Use these corrections the user has made previously as guidance:");
for (const ex of examples) {
lines.push(` "${ex.merchantName}" → ${ex.category}`);
}
}
lines.push(`Merchant: ${merchantName}`);
return lines.join("\n");
}
export async function suggestCategoryForMerchant(merchantName: string): Promise<CategorySuggestion> {
const normalized = normalizeMerchantName(merchantName);
@@ -64,6 +83,19 @@ export async function suggestCategoryForMerchant(merchantName: string): Promise<
};
}
// 1. Check stored user corrections first — highest priority, no confirmation needed.
const learned = await getMerchantCorrection(normalized);
if (learned) {
return {
category: learned.category as CategoryValue,
message: "Category auto-filled from your previous correction.",
merchantName: normalized,
requiresConfirmation: false,
source: "learned",
};
}
// 2. Hardcoded rules for well-known merchants.
const matchedCategory = getMerchantRuleCategory(normalized);
if (matchedCategory) {
return {
@@ -75,11 +107,11 @@ export async function suggestCategoryForMerchant(merchantName: string): Promise<
};
}
// 3. Ask Ollama, providing recent user corrections as few-shot examples.
try {
const recentCorrections = await getRecentCorrections(20);
const parsed = await generateOllamaJson<{ category?: string }>({
prompt:
"You categorize personal expense merchants. Return JSON with one key named category. Allowed values only: RENT, FOOD, TRANSPORT, BILLS, SHOPPING, HEALTH, ENTERTAINMENT, MISC. Merchant: " +
normalized,
prompt: buildOllamaPrompt(normalized, recentCorrections),
});
const category = parseSuggestedCategory(parsed?.category);

View File

@@ -0,0 +1,28 @@
import type { Category } from "@prisma/client";
import { db } from "@/lib/db";
export type MerchantCorrection = {
merchantName: string;
category: Category;
};
export async function getMerchantCorrection(merchantName: string): Promise<MerchantCorrection | null> {
return db.merchantCorrection.findUnique({ where: { merchantName } });
}
export async function getRecentCorrections(limit = 30): Promise<MerchantCorrection[]> {
return db.merchantCorrection.findMany({
orderBy: { updatedAt: "desc" },
take: limit,
select: { merchantName: true, category: true },
});
}
export async function saveMerchantCorrection(merchantName: string, category: Category): Promise<void> {
await db.merchantCorrection.upsert({
where: { merchantName },
update: { category },
create: { merchantName, category },
});
}