Teach local AI from user category corrections
- Add MerchantCorrection model: upsert by merchantName, Category enum - Check corrections DB first in suggestCategoryForMerchant (source: "learned", no confirmation required); falls through to rules then Ollama if no match - Inject recent corrections as few-shot examples in the Ollama prompt so the model improves even for merchants not yet explicitly corrected - Add POST /categories/correct route to persist corrections - Detect category override on form save (suggestedCategory !== chosen category) and silently fire a correction — no extra UX required - Fix test isolation: beforeEach re-applies vi.fn() defaults after restoreAllMocks Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,11 @@
|
|||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "MerchantCorrection" (
|
||||||
|
"id" TEXT NOT NULL PRIMARY KEY,
|
||||||
|
"merchantName" TEXT NOT NULL,
|
||||||
|
"category" TEXT NOT NULL,
|
||||||
|
"createdAt" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updatedAt" DATETIME NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "MerchantCorrection_merchantName_key" ON "MerchantCorrection"("merchantName");
|
||||||
@@ -42,6 +42,14 @@ model PaySchedule {
|
|||||||
createdAt DateTime @default(now())
|
createdAt DateTime @default(now())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
model MerchantCorrection {
|
||||||
|
id String @id @default(cuid())
|
||||||
|
merchantName String @unique
|
||||||
|
category Category
|
||||||
|
createdAt DateTime @default(now())
|
||||||
|
updatedAt DateTime @updatedAt
|
||||||
|
}
|
||||||
|
|
||||||
model MonthlyInsight {
|
model MonthlyInsight {
|
||||||
id String @id @default(cuid())
|
id String @id @default(cuid())
|
||||||
month String @unique
|
month String @unique
|
||||||
|
|||||||
25
src/app/categories/correct/route.ts
Normal file
25
src/app/categories/correct/route.ts
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import { Category } from "@prisma/client";
|
||||||
|
import { NextResponse } from "next/server";
|
||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
import { saveMerchantCorrection } from "@/lib/merchant-corrections";
|
||||||
|
|
||||||
|
const correctionSchema = z.object({
|
||||||
|
merchantName: z.string().trim().min(1).max(80),
|
||||||
|
category: z.nativeEnum(Category, { message: "Choose a valid category." }),
|
||||||
|
});
|
||||||
|
|
||||||
|
export async function POST(request: Request) {
|
||||||
|
const payload = await request.json().catch(() => null);
|
||||||
|
const parsed = correctionSchema.safeParse(payload);
|
||||||
|
|
||||||
|
if (!parsed.success) {
|
||||||
|
return NextResponse.json(
|
||||||
|
{ error: parsed.error.issues[0]?.message ?? "Invalid correction." },
|
||||||
|
{ status: 400 },
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
await saveMerchantCorrection(parsed.data.merchantName, parsed.data.category);
|
||||||
|
return NextResponse.json({ ok: true });
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ type SuggestionResponse = {
|
|||||||
message: string;
|
message: string;
|
||||||
merchantName: string;
|
merchantName: string;
|
||||||
requiresConfirmation: boolean;
|
requiresConfirmation: boolean;
|
||||||
source: "rule" | "model" | "unavailable";
|
source: "rule" | "model" | "unavailable" | "learned";
|
||||||
};
|
};
|
||||||
|
|
||||||
type ExpenseRecord = {
|
type ExpenseRecord = {
|
||||||
@@ -49,6 +49,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
|||||||
const [suggestionMessage, setSuggestionMessage] = useState<string | null>(null);
|
const [suggestionMessage, setSuggestionMessage] = useState<string | null>(null);
|
||||||
const [needsSuggestionConfirmation, setNeedsSuggestionConfirmation] = useState(false);
|
const [needsSuggestionConfirmation, setNeedsSuggestionConfirmation] = useState(false);
|
||||||
const [lastSuggestedMerchant, setLastSuggestedMerchant] = useState("");
|
const [lastSuggestedMerchant, setLastSuggestedMerchant] = useState("");
|
||||||
|
const [suggestedCategory, setSuggestedCategory] = useState<CategoryValue | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
async function loadExpenses() {
|
async function loadExpenses() {
|
||||||
@@ -90,8 +91,8 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
|||||||
setSuggestionMessage(suggestion.message);
|
setSuggestionMessage(suggestion.message);
|
||||||
|
|
||||||
if (suggestion.category) {
|
if (suggestion.category) {
|
||||||
const suggestedCategory = suggestion.category;
|
setFormState((current) => ({ ...current, category: suggestion.category! }));
|
||||||
setFormState((current) => ({ ...current, category: suggestedCategory }));
|
setSuggestedCategory(suggestion.category);
|
||||||
}
|
}
|
||||||
|
|
||||||
setNeedsSuggestionConfirmation(suggestion.requiresConfirmation);
|
setNeedsSuggestionConfirmation(suggestion.requiresConfirmation);
|
||||||
@@ -108,6 +109,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
|||||||
setSuggestionMessage(null);
|
setSuggestionMessage(null);
|
||||||
setNeedsSuggestionConfirmation(false);
|
setNeedsSuggestionConfirmation(false);
|
||||||
setLastSuggestedMerchant("");
|
setLastSuggestedMerchant("");
|
||||||
|
setSuggestedCategory(null);
|
||||||
setError(null);
|
setError(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,6 +124,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
|||||||
setSuggestionMessage(null);
|
setSuggestionMessage(null);
|
||||||
setNeedsSuggestionConfirmation(false);
|
setNeedsSuggestionConfirmation(false);
|
||||||
setLastSuggestedMerchant("");
|
setLastSuggestedMerchant("");
|
||||||
|
setSuggestedCategory(null);
|
||||||
setError(null);
|
setError(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,6 +136,20 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the AI (model or learned) suggested a category and the user changed it,
|
||||||
|
// silently record the correction so future suggestions improve.
|
||||||
|
if (
|
||||||
|
lastSuggestedMerchant &&
|
||||||
|
suggestedCategory !== null &&
|
||||||
|
formState.category !== suggestedCategory
|
||||||
|
) {
|
||||||
|
void fetch("/categories/correct", {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({ merchantName: lastSuggestedMerchant, category: formState.category }),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
setBusy(true);
|
setBusy(true);
|
||||||
setError(null);
|
setError(null);
|
||||||
|
|
||||||
@@ -166,6 +183,7 @@ export function ExpenseWorkspace({ categoryOptions }: Props) {
|
|||||||
setSuggestionMessage(null);
|
setSuggestionMessage(null);
|
||||||
setNeedsSuggestionConfirmation(false);
|
setNeedsSuggestionConfirmation(false);
|
||||||
setLastSuggestedMerchant("");
|
setLastSuggestedMerchant("");
|
||||||
|
setSuggestedCategory(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleDelete(id: string) {
|
async function handleDelete(id: string) {
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||||
|
|
||||||
|
vi.mock("@/lib/merchant-corrections", () => ({
|
||||||
|
getMerchantCorrection: vi.fn(),
|
||||||
|
getRecentCorrections: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
import { getMerchantCorrection, getRecentCorrections } from "@/lib/merchant-corrections";
|
||||||
import { getMerchantRuleCategory, suggestCategoryForMerchant } from "@/lib/category-suggestion";
|
import { getMerchantRuleCategory, suggestCategoryForMerchant } from "@/lib/category-suggestion";
|
||||||
|
|
||||||
describe("getMerchantRuleCategory", () => {
|
describe("getMerchantRuleCategory", () => {
|
||||||
@@ -10,10 +16,26 @@ describe("getMerchantRuleCategory", () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
describe("suggestCategoryForMerchant", () => {
|
describe("suggestCategoryForMerchant", () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.mocked(getMerchantCorrection).mockResolvedValue(null);
|
||||||
|
vi.mocked(getRecentCorrections).mockResolvedValue([]);
|
||||||
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
vi.restoreAllMocks();
|
vi.restoreAllMocks();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("returns learned category without confirmation when a correction exists", async () => {
|
||||||
|
const { getMerchantCorrection } = await import("@/lib/merchant-corrections");
|
||||||
|
vi.mocked(getMerchantCorrection).mockResolvedValueOnce({ merchantName: "Blue Tokai", category: "FOOD" });
|
||||||
|
|
||||||
|
const suggestion = await suggestCategoryForMerchant("Blue Tokai");
|
||||||
|
|
||||||
|
expect(suggestion.category).toBe("FOOD");
|
||||||
|
expect(suggestion.source).toBe("learned");
|
||||||
|
expect(suggestion.requiresConfirmation).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
it("uses the local model for unknown merchants", async () => {
|
it("uses the local model for unknown merchants", async () => {
|
||||||
vi.spyOn(globalThis, "fetch").mockResolvedValue({
|
vi.spyOn(globalThis, "fetch").mockResolvedValue({
|
||||||
ok: true,
|
ok: true,
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import { CATEGORY_VALUES, type CategoryValue } from "@/lib/categories";
|
import { CATEGORY_VALUES, type CategoryValue } from "@/lib/categories";
|
||||||
|
import { getMerchantCorrection, getRecentCorrections } from "@/lib/merchant-corrections";
|
||||||
import { generateOllamaJson } from "@/lib/ollama";
|
import { generateOllamaJson } from "@/lib/ollama";
|
||||||
|
|
||||||
type SuggestionSource = "rule" | "model" | "unavailable";
|
type SuggestionSource = "rule" | "model" | "unavailable" | "learned";
|
||||||
|
|
||||||
export type CategorySuggestion = {
|
export type CategorySuggestion = {
|
||||||
category: CategoryValue | null;
|
category: CategoryValue | null;
|
||||||
@@ -51,6 +52,24 @@ function parseSuggestedCategory(raw: unknown): CategoryValue | null {
|
|||||||
return CATEGORY_VALUES.includes(normalized as CategoryValue) ? (normalized as CategoryValue) : null;
|
return CATEGORY_VALUES.includes(normalized as CategoryValue) ? (normalized as CategoryValue) : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function buildOllamaPrompt(merchantName: string, examples: Array<{ merchantName: string; category: string }>) {
|
||||||
|
const lines = [
|
||||||
|
"You categorize personal expense merchants.",
|
||||||
|
"Return JSON with one key named category.",
|
||||||
|
"Allowed values only: RENT, FOOD, TRANSPORT, BILLS, SHOPPING, HEALTH, ENTERTAINMENT, MISC.",
|
||||||
|
];
|
||||||
|
|
||||||
|
if (examples.length > 0) {
|
||||||
|
lines.push("Use these corrections the user has made previously as guidance:");
|
||||||
|
for (const ex of examples) {
|
||||||
|
lines.push(` "${ex.merchantName}" → ${ex.category}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.push(`Merchant: ${merchantName}`);
|
||||||
|
return lines.join("\n");
|
||||||
|
}
|
||||||
|
|
||||||
export async function suggestCategoryForMerchant(merchantName: string): Promise<CategorySuggestion> {
|
export async function suggestCategoryForMerchant(merchantName: string): Promise<CategorySuggestion> {
|
||||||
const normalized = normalizeMerchantName(merchantName);
|
const normalized = normalizeMerchantName(merchantName);
|
||||||
|
|
||||||
@@ -64,6 +83,19 @@ export async function suggestCategoryForMerchant(merchantName: string): Promise<
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 1. Check stored user corrections first — highest priority, no confirmation needed.
|
||||||
|
const learned = await getMerchantCorrection(normalized);
|
||||||
|
if (learned) {
|
||||||
|
return {
|
||||||
|
category: learned.category as CategoryValue,
|
||||||
|
message: "Category auto-filled from your previous correction.",
|
||||||
|
merchantName: normalized,
|
||||||
|
requiresConfirmation: false,
|
||||||
|
source: "learned",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Hardcoded rules for well-known merchants.
|
||||||
const matchedCategory = getMerchantRuleCategory(normalized);
|
const matchedCategory = getMerchantRuleCategory(normalized);
|
||||||
if (matchedCategory) {
|
if (matchedCategory) {
|
||||||
return {
|
return {
|
||||||
@@ -75,11 +107,11 @@ export async function suggestCategoryForMerchant(merchantName: string): Promise<
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 3. Ask Ollama, providing recent user corrections as few-shot examples.
|
||||||
try {
|
try {
|
||||||
|
const recentCorrections = await getRecentCorrections(20);
|
||||||
const parsed = await generateOllamaJson<{ category?: string }>({
|
const parsed = await generateOllamaJson<{ category?: string }>({
|
||||||
prompt:
|
prompt: buildOllamaPrompt(normalized, recentCorrections),
|
||||||
"You categorize personal expense merchants. Return JSON with one key named category. Allowed values only: RENT, FOOD, TRANSPORT, BILLS, SHOPPING, HEALTH, ENTERTAINMENT, MISC. Merchant: " +
|
|
||||||
normalized,
|
|
||||||
});
|
});
|
||||||
const category = parseSuggestedCategory(parsed?.category);
|
const category = parseSuggestedCategory(parsed?.category);
|
||||||
|
|
||||||
|
|||||||
28
src/lib/merchant-corrections.ts
Normal file
28
src/lib/merchant-corrections.ts
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import type { Category } from "@prisma/client";
|
||||||
|
|
||||||
|
import { db } from "@/lib/db";
|
||||||
|
|
||||||
|
export type MerchantCorrection = {
|
||||||
|
merchantName: string;
|
||||||
|
category: Category;
|
||||||
|
};
|
||||||
|
|
||||||
|
export async function getMerchantCorrection(merchantName: string): Promise<MerchantCorrection | null> {
|
||||||
|
return db.merchantCorrection.findUnique({ where: { merchantName } });
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getRecentCorrections(limit = 30): Promise<MerchantCorrection[]> {
|
||||||
|
return db.merchantCorrection.findMany({
|
||||||
|
orderBy: { updatedAt: "desc" },
|
||||||
|
take: limit,
|
||||||
|
select: { merchantName: true, category: true },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function saveMerchantCorrection(merchantName: string, category: Category): Promise<void> {
|
||||||
|
await db.merchantCorrection.upsert({
|
||||||
|
where: { merchantName },
|
||||||
|
update: { category },
|
||||||
|
create: { merchantName, category },
|
||||||
|
});
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user