if (process.env.NODE_ENV !== "test") require("server-only");

import crypto from "node:crypto";
import getPool from "@/lib/server/db";
import { apiError } from "@/lib/server/errors";

/** Parameters for a single fixed-window counter check. */
type LimitInput = {
  /** Fully-qualified counter key (primary key of `rate_limits`). */
  key: string;
  /** Hits allowed per window; the hit that pushes the count above this is rejected. */
  limit: number;
  /** Window length in milliseconds; windows are aligned to multiples of it since the Unix epoch. */
  windowMs: number;
  /** Scope label reported in the RATE_LIMITED error payload. */
  scope: string;
};

/** Minimum interval between two stale-row cleanups issued by this process. */
const CLEANUP_INTERVAL_MS = 10 * 60 * 1000;

/** Default window used by the exported limiters (15 minutes). */
const DEFAULT_WINDOW_MS = 15 * 60 * 1000;

let ensureTablePromise: Promise<void> | null = null;
let lastCleanupAtMs = 0;

/**
 * Lazily creates the `rate_limits` table and its `updated_at` index, at most
 * once per process as long as creation succeeds.
 *
 * A failed attempt is deliberately NOT cached: the shared promise is cleared
 * so the next caller retries. Otherwise a single transient error (connection
 * blip, or two processes racing on `create table if not exists`) would make
 * every later rate-limit check in this process reject forever.
 */
async function ensureRateLimitsTable(): Promise<void> {
  if (!ensureTablePromise) {
    ensureTablePromise = (async () => {
      const pool = getPool();
      await pool.query(`
        create table if not exists rate_limits(
          key text primary key,
          window_start timestamptz not null,
          count integer not null default 0,
          updated_at timestamptz not null default now()
        )
      `);
      await pool.query("create index if not exists rate_limits_updated_at_idx on rate_limits(updated_at)");
    })().catch((error: unknown) => {
      ensureTablePromise = null;
      throw error;
    });
  }
  await ensureTablePromise;
}

/**
 * Turns an arbitrary string into a segment that is safe to embed in a counter key.
 *
 * The value is trimmed, lower-cased and cut to 256 characters; every character
 * outside `[a-z0-9:._-]` becomes `_`. Results longer than 96 characters are
 * replaced by `sha256:<hex digest>` of the sanitized text.
 *
 * Because sanitization happens before hashing, inputs differing only in
 * replaced characters (e.g. "a@b.com" vs "a_b.com") or beyond the first 256
 * characters map to the same segment.
 *
 * @param value - Raw input (IP address, login identifier, scope name, ...).
 * @param fallbackUnknown - When true, blank input yields "unknown"; otherwise "".
 * @returns The normalized segment.
 */
function normalizeSegment(value: string, fallbackUnknown = true): string {
  const trimmed = value.trim().toLowerCase().slice(0, 256);
  if (!trimmed) return fallbackUnknown ? "unknown" : "";
  const safe = trimmed.replace(/[^a-z0-9:._-]/g, "_");
  if (safe.length <= 96) return safe;
  return `sha256:${crypto.createHash("sha256").update(safe).digest("hex")}`;
}

/**
 * Deletes counter rows not touched for 2 days, at most once per
 * CLEANUP_INTERVAL_MS per process.
 *
 * This is housekeeping only, so a failing DELETE is logged instead of failing
 * the request that happened to trigger it. Note the fixed 2-day retention: a
 * window longer than that could lose an idle key's count mid-window.
 */
async function cleanupStaleRateLimits(): Promise<void> {
  const nowMs = Date.now();
  if (nowMs - lastCleanupAtMs < CLEANUP_INTERVAL_MS) return;
  // Claim the slot before awaiting so concurrent callers skip the DELETE.
  lastCleanupAtMs = nowMs;
  try {
    await getPool().query("delete from rate_limits where updated_at < now() - interval '2 days'");
  } catch (error: unknown) {
    console.warn("rate_limits cleanup failed; retrying after the next interval", error);
  }
}

/** Start of the fixed window containing `nowMs`, aligned to multiples of `windowMs`. */
function normalizeWindowStart(nowMs: number, windowMs: number): Date {
  const bucketStart = Math.floor(nowMs / windowMs) * windowMs;
  return new Date(bucketStart);
}

/**
 * Rejects configurations that would otherwise produce an Invalid Date
 * (windowMs <= 0 / non-finite) or silently never limit (NaN limit).
 *
 * @throws RangeError on an invalid windowMs or limit.
 */
function assertValidLimitInput(input: LimitInput): void {
  if (!Number.isFinite(input.windowMs) || input.windowMs <= 0) {
    throw new RangeError(`rate limit windowMs must be a positive finite number, got ${input.windowMs}`);
  }
  if (Number.isNaN(input.limit)) {
    throw new RangeError(`rate limit limit must be a number, got ${input.limit}`);
  }
}

/**
 * Atomically records one hit for `input.key` and rejects it once the count in
 * the current window exceeds `input.limit`.
 *
 * The upsert resets the counter only when the incoming window is NEWER than
 * the stored one. A request whose bucket is older than the stored one (e.g.
 * computed on an app server whose clock lags) counts toward the stored, newer
 * window instead of resetting it and moving the window backwards.
 *
 * @throws RangeError for invalid limit/window configuration.
 * @throws Whatever `apiError("RATE_LIMITED", ...)` raises when over the limit.
 */
async function consumeRateLimit(input: LimitInput): Promise<void> {
  assertValidLimitInput(input);
  await ensureRateLimitsTable();
  await cleanupStaleRateLimits();
  const windowStart = normalizeWindowStart(Date.now(), input.windowMs);
  const pool = getPool();
  const { rows } = await pool.query<{ count: number }>(
    `insert into rate_limits(key, window_start, count, updated_at)
     values($1, $2, 1, now())
     on conflict (key) do update set
       count = case
         when excluded.window_start > rate_limits.window_start then 1
         else rate_limits.count + 1
       end,
       window_start = greatest(rate_limits.window_start, excluded.window_start),
       updated_at = now()
     returning count`,
    [input.key, windowStart]
  );
  const count = Number(rows[0]?.count ?? 0);
  if (count > input.limit) {
    // NOTE(review): assumes apiError throws; if it ever returns, the
    // over-limit request would proceed — confirm its signature returns `never`.
    apiError("RATE_LIMITED", { scope: input.scope, limit: input.limit, windowMs: input.windowMs });
  }
}

/**
 * Rate-limits login/register attempts per client IP and, when provided, per
 * normalized identifier (both within the same window).
 *
 * Requests without an IP all share the single "unknown" IP bucket.
 *
 * @param input.route - Auth route being limited.
 * @param input.ip - Client IP; blank/missing maps to "unknown".
 * @param input.identifier - Login identifier; skipped when blank/missing.
 * @param input.ipLimit - Hits per window per IP (default 20).
 * @param input.identifierLimit - Hits per window per identifier (default 10).
 * @param input.windowMs - Window length in ms (default 15 minutes).
 */
export async function enforceAuthRateLimit(input: {
  route: "login" | "register";
  ip?: string | null;
  identifier?: string | null;
  ipLimit?: number;
  identifierLimit?: number;
  windowMs?: number;
}): Promise<void> {
  const scope = normalizeSegment(`auth:${input.route}`);
  const ip = normalizeSegment(String(input.ip || "unknown"));
  const windowMs = input.windowMs ?? DEFAULT_WINDOW_MS;
  await consumeRateLimit({ key: `${scope}:ip:${ip}`, scope, limit: input.ipLimit ?? 20, windowMs });
  const identifier = normalizeSegment(String(input.identifier || ""), false);
  if (identifier) {
    await consumeRateLimit({
      key: `${scope}:identifier:${identifier}`,
      scope,
      limit: input.identifierLimit ?? 10,
      windowMs,
    });
  }
}

/**
 * Rate-limits write operations per authenticated user and scope.
 *
 * @param input.userId - Authenticated user id embedded in the counter key.
 * @param input.scope - Free-form operation scope; normalized before use.
 * @param input.limit - Hits per window (default 120).
 * @param input.windowMs - Window length in ms (default 15 minutes).
 */
export async function enforceUserWriteRateLimit(input: {
  userId: number;
  scope: string;
  limit?: number;
  windowMs?: number;
}): Promise<void> {
  const scope = normalizeSegment(input.scope);
  await consumeRateLimit({
    key: `write:user:${input.userId}:scope:${scope}`,
    scope,
    limit: input.limit ?? 120,
    windowMs: input.windowMs ?? DEFAULT_WINDOW_MS,
  });
}