diff --git a/prisma/migrations/20260323212510_add_merchant_corrections/migration.sql b/prisma/migrations/20260323212510_add_merchant_corrections/migration.sql new file mode 100644 index 0000000..f964976 --- /dev/null +++ b/prisma/migrations/20260323212510_add_merchant_corrections/migration.sql @@ -0,0 +1,11 @@ +-- CreateTable +CREATE TABLE "MerchantCorrection" ( + "id" TEXT NOT NULL PRIMARY KEY, + "merchantName" TEXT NOT NULL, + "category" TEXT NOT NULL, + "createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updatedAt" DATETIME NOT NULL +); + +-- CreateIndex +CREATE UNIQUE INDEX "MerchantCorrection_merchantName_key" ON "MerchantCorrection"("merchantName"); diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 1cb7601..3be8dd3 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -42,6 +42,14 @@ model PaySchedule { createdAt DateTime @default(now()) } +model MerchantCorrection { + id String @id @default(cuid()) + merchantName String @unique + category Category + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt +} + model MonthlyInsight { id String @id @default(cuid()) month String @unique diff --git a/src/app/categories/correct/route.ts b/src/app/categories/correct/route.ts new file mode 100644 index 0000000..acfecd3 --- /dev/null +++ b/src/app/categories/correct/route.ts @@ -0,0 +1,25 @@ +import { Category } from "@prisma/client"; +import { NextResponse } from "next/server"; +import { z } from "zod"; + +import { saveMerchantCorrection } from "@/lib/merchant-corrections"; + +const correctionSchema = z.object({ + merchantName: z.string().trim().min(1).max(80), + category: z.nativeEnum(Category, { message: "Choose a valid category." }), +}); + +export async function POST(request: Request) { + const payload = await request.json().catch(() => null); + const parsed = correctionSchema.safeParse(payload); + + if (!parsed.success) { + return NextResponse.json( + { error: parsed.error.issues[0]?.message ?? "Invalid correction." }, + { status: 400 }, + ); + } + + await saveMerchantCorrection(parsed.data.merchantName, parsed.data.category); + return NextResponse.json({ ok: true }); +} diff --git a/src/components/expense-workspace.tsx b/src/components/expense-workspace.tsx index 246d23f..b968b63 100644 --- a/src/components/expense-workspace.tsx +++ b/src/components/expense-workspace.tsx @@ -10,7 +10,7 @@ type SuggestionResponse = { message: string; merchantName: string; requiresConfirmation: boolean; - source: "rule" | "model" | "unavailable"; + source: "rule" | "model" | "unavailable" | "learned"; }; type ExpenseRecord = { @@ -49,6 +49,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) { const [suggestionMessage, setSuggestionMessage] = useState(null); const [needsSuggestionConfirmation, setNeedsSuggestionConfirmation] = useState(false); const [lastSuggestedMerchant, setLastSuggestedMerchant] = useState(""); + const [suggestedCategory, setSuggestedCategory] = useState(null); useEffect(() => { async function loadExpenses() { @@ -90,8 +91,8 @@ export function ExpenseWorkspace({ categoryOptions }: Props) { setSuggestionMessage(suggestion.message); if (suggestion.category) { - const suggestedCategory = suggestion.category; - setFormState((current) => ({ ...current, category: suggestedCategory })); + setFormState((current) => ({ ...current, category: suggestion.category! })); + setSuggestedCategory(suggestion.category); } setNeedsSuggestionConfirmation(suggestion.requiresConfirmation); @@ -108,6 +109,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) { setSuggestionMessage(null); setNeedsSuggestionConfirmation(false); setLastSuggestedMerchant(""); + setSuggestedCategory(null); setError(null); } @@ -122,6 +124,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) { setSuggestionMessage(null); setNeedsSuggestionConfirmation(false); setLastSuggestedMerchant(""); + setSuggestedCategory(null); setError(null); } @@ -133,6 +136,20 @@ export function ExpenseWorkspace({ categoryOptions }: Props) { return; } + // If the AI (model or learned) suggested a category and the user changed it, + // silently record the correction so future suggestions improve. + if ( + lastSuggestedMerchant && + suggestedCategory !== null && + formState.category !== suggestedCategory + ) { + void fetch("/categories/correct", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ merchantName: lastSuggestedMerchant, category: formState.category }), + }); + } + setBusy(true); setError(null); @@ -166,6 +183,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) { setSuggestionMessage(null); setNeedsSuggestionConfirmation(false); setLastSuggestedMerchant(""); + setSuggestedCategory(null); } async function handleDelete(id: string) { diff --git a/src/lib/category-suggestion.test.ts b/src/lib/category-suggestion.test.ts index 4f9887d..65560a8 100644 --- a/src/lib/category-suggestion.test.ts +++ b/src/lib/category-suggestion.test.ts @@ -1,5 +1,11 @@ -import { afterEach, describe, expect, it, vi } from "vitest"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +vi.mock("@/lib/merchant-corrections", () => ({ + getMerchantCorrection: vi.fn(), + getRecentCorrections: vi.fn(), +})); + +import { getMerchantCorrection, getRecentCorrections } from "@/lib/merchant-corrections"; import { getMerchantRuleCategory, suggestCategoryForMerchant } from "@/lib/category-suggestion"; describe("getMerchantRuleCategory", () => { @@ -10,10 +16,26 @@ describe("getMerchantRuleCategory", () => { }); describe("suggestCategoryForMerchant", () => { + beforeEach(() => { + vi.mocked(getMerchantCorrection).mockResolvedValue(null); + vi.mocked(getRecentCorrections).mockResolvedValue([]); + }); + afterEach(() => { vi.restoreAllMocks(); }); + it("returns learned category without confirmation when a correction exists", async () => { + const { getMerchantCorrection } = await import("@/lib/merchant-corrections"); + vi.mocked(getMerchantCorrection).mockResolvedValueOnce({ merchantName: "Blue Tokai", category: "FOOD" }); + + const suggestion = await suggestCategoryForMerchant("Blue Tokai"); + + expect(suggestion.category).toBe("FOOD"); + expect(suggestion.source).toBe("learned"); + expect(suggestion.requiresConfirmation).toBe(false); + }); + it("uses the local model for unknown merchants", async () => { vi.spyOn(globalThis, "fetch").mockResolvedValue({ ok: true, diff --git a/src/lib/category-suggestion.ts b/src/lib/category-suggestion.ts index cfe508b..faefff5 100644 --- a/src/lib/category-suggestion.ts +++ b/src/lib/category-suggestion.ts @@ -1,7 +1,8 @@ import { CATEGORY_VALUES, type CategoryValue } from "@/lib/categories"; +import { getMerchantCorrection, getRecentCorrections } from "@/lib/merchant-corrections"; import { generateOllamaJson } from "@/lib/ollama"; -type SuggestionSource = "rule" | "model" | "unavailable"; +type SuggestionSource = "rule" | "model" | "unavailable" | "learned"; export type CategorySuggestion = { category: CategoryValue | null; @@ -51,6 +52,24 @@ function parseSuggestedCategory(raw: unknown): CategoryValue | null { return CATEGORY_VALUES.includes(normalized as CategoryValue) ? (normalized as CategoryValue) : null; } +function buildOllamaPrompt(merchantName: string, examples: Array<{ merchantName: string; category: string }>) { + const lines = [ + "You categorize personal expense merchants.", + "Return JSON with one key named category.", + "Allowed values only: RENT, FOOD, TRANSPORT, BILLS, SHOPPING, HEALTH, ENTERTAINMENT, MISC.", + ]; + + if (examples.length > 0) { + lines.push("Use these corrections the user has made previously as guidance:"); + for (const ex of examples) { + lines.push(` "${ex.merchantName}" → ${ex.category}`); + } + } + + lines.push(`Merchant: ${merchantName}`); + return lines.join("\n"); +} + export async function suggestCategoryForMerchant(merchantName: string): Promise { const normalized = normalizeMerchantName(merchantName); @@ -64,6 +83,19 @@ export async function suggestCategoryForMerchant(merchantName: string): Promise< }; } + // 1. Check stored user corrections first — highest priority, no confirmation needed. + const learned = await getMerchantCorrection(normalized); + if (learned) { + return { + category: learned.category as CategoryValue, + message: "Category auto-filled from your previous correction.", + merchantName: normalized, + requiresConfirmation: false, + source: "learned", + }; + } + + // 2. Hardcoded rules for well-known merchants. const matchedCategory = getMerchantRuleCategory(normalized); if (matchedCategory) { return { @@ -75,11 +107,11 @@ export async function suggestCategoryForMerchant(merchantName: string): Promise< }; } + // 3. Ask Ollama, providing recent user corrections as few-shot examples. try { + const recentCorrections = await getRecentCorrections(20); const parsed = await generateOllamaJson<{ category?: string }>({ - prompt: - "You categorize personal expense merchants. Return JSON with one key named category. Allowed values only: RENT, FOOD, TRANSPORT, BILLS, SHOPPING, HEALTH, ENTERTAINMENT, MISC. Merchant: " + - normalized, + prompt: buildOllamaPrompt(normalized, recentCorrections), }); const category = parseSuggestedCategory(parsed?.category); diff --git a/src/lib/merchant-corrections.ts b/src/lib/merchant-corrections.ts new file mode 100644 index 0000000..fbff3b6 --- /dev/null +++ b/src/lib/merchant-corrections.ts @@ -0,0 +1,28 @@ +import type { Category } from "@prisma/client"; + +import { db } from "@/lib/db"; + +export type MerchantCorrection = { + merchantName: string; + category: Category; +}; + +export async function getMerchantCorrection(merchantName: string): Promise { + return db.merchantCorrection.findUnique({ where: { merchantName } }); +} + +export async function getRecentCorrections(limit = 30): Promise { + return db.merchantCorrection.findMany({ + orderBy: { updatedAt: "desc" }, + take: limit, + select: { merchantName: true, category: true }, + }); +} + +export async function saveMerchantCorrection(merchantName: string, category: Category): Promise { + await db.merchantCorrection.upsert({ + where: { merchantName }, + update: { category }, + create: { merchantName, category }, + }); +}