Teach local AI from user category corrections
- Add MerchantCorrection model: upsert by merchantName, Category enum - Check corrections DB first in suggestCategoryForMerchant (source: "learned", no confirmation required); falls through to rules then Ollama if no match - Inject recent corrections as few-shot examples in the Ollama prompt so the model improves even for merchants not yet explicitly corrected - Add POST /categories/correct route to persist corrections - Detect category override on form save (suggestedCategory !== chosen category) and silently fire a correction — no extra UX required - Fix test isolation: beforeEach re-applies vi.fn() defaults after restoreAllMocks Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "MerchantCorrection" (
|
||||
"id" TEXT NOT NULL PRIMARY KEY,
|
||||
"merchantName" TEXT NOT NULL,
|
||||
"category" TEXT NOT NULL,
|
||||
"createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" DATETIME NOT NULL
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE UNIQUE INDEX "MerchantCorrection_merchantName_key" ON "MerchantCorrection"("merchantName");
|
||||
@@ -42,6 +42,14 @@ model PaySchedule {
|
||||
createdAt DateTime @default(now())
|
||||
}
|
||||
|
||||
model MerchantCorrection {
|
||||
id String @id @default(cuid())
|
||||
merchantName String @unique
|
||||
category Category
|
||||
createdAt DateTime @default(now())
|
||||
updatedAt DateTime @updatedAt
|
||||
}
|
||||
|
||||
model MonthlyInsight {
|
||||
id String @id @default(cuid())
|
||||
month String @unique
|
||||
|
||||
25
src/app/categories/correct/route.ts
Normal file
25
src/app/categories/correct/route.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { Category } from "@prisma/client";
|
||||
import { NextResponse } from "next/server";
|
||||
import { z } from "zod";
|
||||
|
||||
import { saveMerchantCorrection } from "@/lib/merchant-corrections";
|
||||
|
||||
const correctionSchema = z.object({
|
||||
merchantName: z.string().trim().min(1).max(80),
|
||||
category: z.nativeEnum(Category, { message: "Choose a valid category." }),
|
||||
});
|
||||
|
||||
export async function POST(request: Request) {
|
||||
const payload = await request.json().catch(() => null);
|
||||
const parsed = correctionSchema.safeParse(payload);
|
||||
|
||||
if (!parsed.success) {
|
||||
return NextResponse.json(
|
||||
{ error: parsed.error.issues[0]?.message ?? "Invalid correction." },
|
||||
{ status: 400 },
|
||||
);
|
||||
}
|
||||
|
||||
await saveMerchantCorrection(parsed.data.merchantName, parsed.data.category);
|
||||
return NextResponse.json({ ok: true });
|
||||
}
|
||||
@@ -10,7 +10,7 @@ type SuggestionResponse = {
|
||||
message: string;
|
||||
merchantName: string;
|
||||
requiresConfirmation: boolean;
|
||||
source: "rule" | "model" | "unavailable";
|
||||
source: "rule" | "model" | "unavailable" | "learned";
|
||||
};
|
||||
|
||||
type ExpenseRecord = {
|
||||
@@ -49,6 +49,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
||||
const [suggestionMessage, setSuggestionMessage] = useState<string | null>(null);
|
||||
const [needsSuggestionConfirmation, setNeedsSuggestionConfirmation] = useState(false);
|
||||
const [lastSuggestedMerchant, setLastSuggestedMerchant] = useState("");
|
||||
const [suggestedCategory, setSuggestedCategory] = useState<CategoryValue | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
async function loadExpenses() {
|
||||
@@ -90,8 +91,8 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
||||
setSuggestionMessage(suggestion.message);
|
||||
|
||||
if (suggestion.category) {
|
||||
const suggestedCategory = suggestion.category;
|
||||
setFormState((current) => ({ ...current, category: suggestedCategory }));
|
||||
setFormState((current) => ({ ...current, category: suggestion.category! }));
|
||||
setSuggestedCategory(suggestion.category);
|
||||
}
|
||||
|
||||
setNeedsSuggestionConfirmation(suggestion.requiresConfirmation);
|
||||
@@ -108,6 +109,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
||||
setSuggestionMessage(null);
|
||||
setNeedsSuggestionConfirmation(false);
|
||||
setLastSuggestedMerchant("");
|
||||
setSuggestedCategory(null);
|
||||
setError(null);
|
||||
}
|
||||
|
||||
@@ -122,6 +124,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
||||
setSuggestionMessage(null);
|
||||
setNeedsSuggestionConfirmation(false);
|
||||
setLastSuggestedMerchant("");
|
||||
setSuggestedCategory(null);
|
||||
setError(null);
|
||||
}
|
||||
|
||||
@@ -133,6 +136,20 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If the AI (model or learned) suggested a category and the user changed it,
|
||||
// silently record the correction so future suggestions improve.
|
||||
if (
|
||||
lastSuggestedMerchant &&
|
||||
suggestedCategory !== null &&
|
||||
formState.category !== suggestedCategory
|
||||
) {
|
||||
void fetch("/categories/correct", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ merchantName: lastSuggestedMerchant, category: formState.category }),
|
||||
});
|
||||
}
|
||||
|
||||
setBusy(true);
|
||||
setError(null);
|
||||
|
||||
@@ -166,6 +183,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
||||
setSuggestionMessage(null);
|
||||
setNeedsSuggestionConfirmation(false);
|
||||
setLastSuggestedMerchant("");
|
||||
setSuggestedCategory(null);
|
||||
}
|
||||
|
||||
async function handleDelete(id: string) {
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
vi.mock("@/lib/merchant-corrections", () => ({
|
||||
getMerchantCorrection: vi.fn(),
|
||||
getRecentCorrections: vi.fn(),
|
||||
}));
|
||||
|
||||
import { getMerchantCorrection, getRecentCorrections } from "@/lib/merchant-corrections";
|
||||
import { getMerchantRuleCategory, suggestCategoryForMerchant } from "@/lib/category-suggestion";
|
||||
|
||||
describe("getMerchantRuleCategory", () => {
|
||||
@@ -10,10 +16,26 @@ describe("getMerchantRuleCategory", () => {
|
||||
});
|
||||
|
||||
describe("suggestCategoryForMerchant", () => {
|
||||
beforeEach(() => {
|
||||
vi.mocked(getMerchantCorrection).mockResolvedValue(null);
|
||||
vi.mocked(getRecentCorrections).mockResolvedValue([]);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("returns learned category without confirmation when a correction exists", async () => {
|
||||
const { getMerchantCorrection } = await import("@/lib/merchant-corrections");
|
||||
vi.mocked(getMerchantCorrection).mockResolvedValueOnce({ merchantName: "Blue Tokai", category: "FOOD" });
|
||||
|
||||
const suggestion = await suggestCategoryForMerchant("Blue Tokai");
|
||||
|
||||
expect(suggestion.category).toBe("FOOD");
|
||||
expect(suggestion.source).toBe("learned");
|
||||
expect(suggestion.requiresConfirmation).toBe(false);
|
||||
});
|
||||
|
||||
it("uses the local model for unknown merchants", async () => {
|
||||
vi.spyOn(globalThis, "fetch").mockResolvedValue({
|
||||
ok: true,
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { CATEGORY_VALUES, type CategoryValue } from "@/lib/categories";
|
||||
import { getMerchantCorrection, getRecentCorrections } from "@/lib/merchant-corrections";
|
||||
import { generateOllamaJson } from "@/lib/ollama";
|
||||
|
||||
type SuggestionSource = "rule" | "model" | "unavailable";
|
||||
type SuggestionSource = "rule" | "model" | "unavailable" | "learned";
|
||||
|
||||
export type CategorySuggestion = {
|
||||
category: CategoryValue | null;
|
||||
@@ -51,6 +52,24 @@ function parseSuggestedCategory(raw: unknown): CategoryValue | null {
|
||||
return CATEGORY_VALUES.includes(normalized as CategoryValue) ? (normalized as CategoryValue) : null;
|
||||
}
|
||||
|
||||
function buildOllamaPrompt(merchantName: string, examples: Array<{ merchantName: string; category: string }>) {
|
||||
const lines = [
|
||||
"You categorize personal expense merchants.",
|
||||
"Return JSON with one key named category.",
|
||||
"Allowed values only: RENT, FOOD, TRANSPORT, BILLS, SHOPPING, HEALTH, ENTERTAINMENT, MISC.",
|
||||
];
|
||||
|
||||
if (examples.length > 0) {
|
||||
lines.push("Use these corrections the user has made previously as guidance:");
|
||||
for (const ex of examples) {
|
||||
lines.push(` "${ex.merchantName}" → ${ex.category}`);
|
||||
}
|
||||
}
|
||||
|
||||
lines.push(`Merchant: ${merchantName}`);
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
export async function suggestCategoryForMerchant(merchantName: string): Promise<CategorySuggestion> {
|
||||
const normalized = normalizeMerchantName(merchantName);
|
||||
|
||||
@@ -64,6 +83,19 @@ export async function suggestCategoryForMerchant(merchantName: string): Promise<
|
||||
};
|
||||
}
|
||||
|
||||
// 1. Check stored user corrections first — highest priority, no confirmation needed.
|
||||
const learned = await getMerchantCorrection(normalized);
|
||||
if (learned) {
|
||||
return {
|
||||
category: learned.category as CategoryValue,
|
||||
message: "Category auto-filled from your previous correction.",
|
||||
merchantName: normalized,
|
||||
requiresConfirmation: false,
|
||||
source: "learned",
|
||||
};
|
||||
}
|
||||
|
||||
// 2. Hardcoded rules for well-known merchants.
|
||||
const matchedCategory = getMerchantRuleCategory(normalized);
|
||||
if (matchedCategory) {
|
||||
return {
|
||||
@@ -75,11 +107,11 @@ export async function suggestCategoryForMerchant(merchantName: string): Promise<
|
||||
};
|
||||
}
|
||||
|
||||
// 3. Ask Ollama, providing recent user corrections as few-shot examples.
|
||||
try {
|
||||
const recentCorrections = await getRecentCorrections(20);
|
||||
const parsed = await generateOllamaJson<{ category?: string }>({
|
||||
prompt:
|
||||
"You categorize personal expense merchants. Return JSON with one key named category. Allowed values only: RENT, FOOD, TRANSPORT, BILLS, SHOPPING, HEALTH, ENTERTAINMENT, MISC. Merchant: " +
|
||||
normalized,
|
||||
prompt: buildOllamaPrompt(normalized, recentCorrections),
|
||||
});
|
||||
const category = parseSuggestedCategory(parsed?.category);
|
||||
|
||||
|
||||
28
src/lib/merchant-corrections.ts
Normal file
28
src/lib/merchant-corrections.ts
Normal file
@@ -0,0 +1,28 @@
|
||||
import type { Category } from "@prisma/client";
|
||||
|
||||
import { db } from "@/lib/db";
|
||||
|
||||
export type MerchantCorrection = {
|
||||
merchantName: string;
|
||||
category: Category;
|
||||
};
|
||||
|
||||
export async function getMerchantCorrection(merchantName: string): Promise<MerchantCorrection | null> {
|
||||
return db.merchantCorrection.findUnique({ where: { merchantName } });
|
||||
}
|
||||
|
||||
export async function getRecentCorrections(limit = 30): Promise<MerchantCorrection[]> {
|
||||
return db.merchantCorrection.findMany({
|
||||
orderBy: { updatedAt: "desc" },
|
||||
take: limit,
|
||||
select: { merchantName: true, category: true },
|
||||
});
|
||||
}
|
||||
|
||||
export async function saveMerchantCorrection(merchantName: string, category: Category): Promise<void> {
|
||||
await db.merchantCorrection.upsert({
|
||||
where: { merchantName },
|
||||
update: { category },
|
||||
create: { merchantName, category },
|
||||
});
|
||||
}
|
||||
Reference in New Issue
Block a user