fiddy/apps/web/lib/server/rate-limit.ts

137 lines
4.5 KiB
TypeScript

if (process.env.NODE_ENV !== "test")
require("server-only");
import crypto from "node:crypto";
import getPool from "@/lib/server/db";
import { apiError } from "@/lib/server/errors";
/** Parameters for one rate-limit check performed by `consumeRateLimit`. */
interface LimitInput {
  /** Fully qualified bucket key; used as the `rate_limits` primary key. */
  key: string;
  /** Maximum hits allowed per window; the hit that pushes the count above this is rejected. */
  limit: number;
  /** Window length in milliseconds. */
  windowMs: number;
  /** Normalized scope name, reported in the RATE_LIMITED error details. */
  scope: string;
}
/**
 * Memo for `ensureRateLimitsTable`: holds the in-flight or completed schema
 * setup. It is reset to null when setup fails so a later call can retry.
 */
let ensureTablePromise: Promise<void> | null = null;
/** Epoch-ms time at which this process last started a stale-row sweep (0 = never). */
let lastCleanupAtMs = 0;
/**
 * Creates the `rate_limits` table and its `updated_at` index if they do not exist.
 *
 * Concurrent callers share one in-flight setup promise, and once setup
 * succeeds, later calls resolve immediately without touching the database.
 * If setup fails, the memo is cleared before the rejection reaches callers.
 * This covers cases such as a briefly unreachable database, or concurrent
 * `create table if not exists` from several processes racing. The next call
 * then retries instead of replaying the same failure for the lifetime of the
 * process.
 */
async function ensureRateLimitsTable(): Promise<void> {
  if (!ensureTablePromise) {
    ensureTablePromise = (async () => {
      const pool = getPool();
      await pool.query(`
        create table if not exists rate_limits(
          key text primary key,
          window_start timestamptz not null,
          count integer not null default 0,
          updated_at timestamptz not null default now()
        )
      `);
      await pool.query("create index if not exists rate_limits_updated_at_idx on rate_limits(updated_at)");
    })().catch((error: unknown) => {
      // This handler runs after the assignment above completes, so the reset
      // sticks. Rethrow so every waiting caller still observes the failure.
      ensureTablePromise = null;
      throw error;
    });
  }
  await ensureTablePromise;
}
/**
 * Turns an arbitrary string (scope, IP, login identifier) into a compact
 * segment for a rate-limit key.
 *
 * The value is trimmed and lower-cased.
 * - A blank result maps to `"unknown"`, or to `""` when `fallbackUnknown` is false.
 * - A value of at most 96 characters made only of `[a-z0-9:._-]` is returned
 *   verbatim, so keys stay readable.
 * - Anything else is replaced by `sha256:<hex>` of the full normalized value.
 *
 * Why not sanitize first: hashing the unsanitized, untruncated value keeps
 * distinct inputs in distinct buckets. Replacing characters with `_` or
 * truncating first would merge e.g. `a@b` with `a#b`, any two non-ASCII names
 * of equal length, or long values sharing a prefix. Unrelated users could then
 * exhaust each other's limits.
 *
 * @param value Raw input to normalize.
 * @param fallbackUnknown Whether a blank input maps to `"unknown"` (default) or `""`.
 * @returns The key-safe segment.
 */
function normalizeSegment(value: string, fallbackUnknown = true): string {
  const normalized = value.trim().toLowerCase();
  if (!normalized) return fallbackUnknown ? "unknown" : "";
  const isKeySafe = /^[a-z0-9:._-]+$/.test(normalized);
  if (isKeySafe && normalized.length <= 96) return normalized;
  return `sha256:${crypto.createHash("sha256").update(normalized).digest("hex")}`;
}
/**
 * Best-effort sweep that deletes `rate_limits` rows not touched recently.
 *
 * It runs at most once per `minIntervalMs` per process, tracked in the
 * module-level `lastCleanupAtMs`. That timestamp is advanced before the
 * query, so:
 * - other callers in this process skip the sweep while one is in flight;
 * - a failed sweep is not retried until the interval elapses again.
 *
 * A failure is logged rather than thrown. Deleting old rows is housekeeping
 * and should not fail the request that happened to trigger it.
 *
 * @param retentionMs Rows whose `updated_at` is older than this many ms are
 *   deleted (default 2 days). Keep it larger than the longest `windowMs` in
 *   use; otherwise a key untouched for longer than the retention loses its
 *   counter mid-window.
 * @param minIntervalMs Minimum spacing between sweeps in this process (default 10 minutes).
 */
async function cleanupStaleRateLimits(
  retentionMs = 2 * 24 * 60 * 60 * 1000,
  minIntervalMs = 10 * 60 * 1000
): Promise<void> {
  const nowMs = Date.now();
  if (nowMs - lastCleanupAtMs < minIntervalMs) return;
  lastCleanupAtMs = nowMs;
  try {
    const pool = getPool();
    // Compare against the database clock, matching how `updated_at` is written (now()).
    await pool.query(
      "delete from rate_limits where updated_at < now() - ($1::double precision * interval '1 millisecond')",
      [retentionMs]
    );
  } catch (error: unknown) {
    console.error("rate-limit cleanup failed", error);
  }
}
/**
 * Returns the start of the fixed window that contains `nowMs`.
 *
 * Windows are aligned to multiples of `windowMs` since the Unix epoch, so every
 * caller using the same `windowMs` agrees on the boundaries.
 *
 * @param nowMs Current time in epoch milliseconds.
 * @param windowMs Window length in milliseconds; must be a positive finite number.
 * @returns The window start as a `Date`.
 * @throws RangeError when `windowMs` is zero, negative, NaN or infinite. Such
 *   values would otherwise produce an Invalid Date or a window in the future.
 */
function normalizeWindowStart(nowMs: number, windowMs: number): Date {
  if (!Number.isFinite(windowMs) || windowMs <= 0) {
    throw new RangeError(`windowMs must be a positive finite number, got ${windowMs}`);
  }
  const bucketStart = Math.floor(nowMs / windowMs) * windowMs;
  return new Date(bucketStart);
}
/**
 * Records one hit against `input.key` in the current fixed window and rejects
 * the call once the window's count exceeds `input.limit`.
 *
 * A single upsert does the counting:
 * - a new key is inserted with count 1;
 * - an existing key in the same window is incremented;
 * - an existing key from any other window restarts at 1 with the new
 *   `window_start`.
 *
 * The count is updated before the limit is checked. Rejected hits therefore
 * keep counting, and a client that keeps retrying stays blocked until the
 * window rolls over.
 *
 * NOTE(review): `apiError` is assumed to throw; if it returned, this function
 * would simply resolve after calling it. Confirm in "@/lib/server/errors".
 * NOTE(review): window starts are computed from this server's clock. App
 * servers with skewed clocks near a window boundary can restart each other's
 * counters.
 */
async function consumeRateLimit(input: LimitInput): Promise<void> {
  await ensureRateLimitsTable();
  await cleanupStaleRateLimits();
  const windowStart = normalizeWindowStart(Date.now(), input.windowMs);
  const pool = getPool();
  const { rows } = await pool.query<{ count: number }>(
    `insert into rate_limits(key, window_start, count, updated_at)
     values($1, $2, 1, now())
     on conflict (key) do update
     set count = case
           when rate_limits.window_start = excluded.window_start then rate_limits.count + 1
           else 1
         end,
         window_start = excluded.window_start,
         updated_at = now()
     returning count`,
    [input.key, windowStart]
  );
  // Number() guards against the driver handing back a non-number; a missing row counts as 0.
  const count = Number(rows[0]?.count ?? 0);
  if (count > input.limit) {
    apiError("RATE_LIMITED", {
      scope: input.scope,
      limit: input.limit,
      windowMs: input.windowMs
    });
  }
}
/**
 * Throttles the authentication routes (`login` / `register`).
 *
 * Up to two fixed-window buckets are charged, in this order:
 *  1. the client IP bucket (default limit 20), then
 *  2. the submitted identifier's bucket (default limit 10), but only when the
 *     identifier is non-blank after trimming.
 * Both use the same window, 15 minutes unless `windowMs` is given. If the IP
 * check throws, the identifier bucket is never touched.
 *
 * Requests that arrive without an IP all share the `unknown` IP bucket.
 */
export async function enforceAuthRateLimit(input: {
  route: "login" | "register";
  ip?: string | null;
  identifier?: string | null;
  ipLimit?: number;
  identifierLimit?: number;
  windowMs?: number;
}) {
  const routeScope = normalizeSegment(`auth:${input.route}`);
  const clientIp = normalizeSegment(String(input.ip || "unknown"));
  const sharedWindowMs = input.windowMs ?? 15 * 60 * 1000;

  await consumeRateLimit({
    key: `${routeScope}:ip:${clientIp}`,
    scope: routeScope,
    limit: input.ipLimit ?? 20,
    windowMs: sharedWindowMs
  });

  const account = normalizeSegment(String(input.identifier || ""), false);
  if (!account) return;

  await consumeRateLimit({
    key: `${routeScope}:identifier:${account}`,
    scope: routeScope,
    limit: input.identifierLimit ?? 10,
    windowMs: sharedWindowMs
  });
}
/**
 * Throttles write operations per authenticated user and per scope.
 *
 * The bucket key combines the numeric user id with the normalized scope, so
 * each scope is limited independently. Defaults: 120 hits per 15-minute window.
 */
export async function enforceUserWriteRateLimit(input: { userId: number; scope: string; limit?: number; windowMs?: number }) {
  const writeScope = normalizeSegment(input.scope);
  const maxHits = input.limit ?? 120;
  const windowLengthMs = input.windowMs ?? 15 * 60 * 1000;
  const bucketKey = `write:user:${input.userId}:scope:${writeScope}`;
  await consumeRateLimit({
    key: bucketKey,
    scope: writeScope,
    limit: maxHits,
    windowMs: windowLengthMs
  });
}
/**
 * Throttles requests per client IP within a named scope.
 *
 * A missing or empty IP falls into the shared `unknown` bucket for that scope.
 * Defaults: 120 hits per 15-minute window.
 */
export async function enforceIpRateLimit(input: { scope: string; ip?: string | null; limit?: number; windowMs?: number }) {
  const ipScope = normalizeSegment(input.scope);
  const clientAddress = normalizeSegment(String(input.ip || "unknown"));
  const maxHits = input.limit ?? 120;
  const windowLengthMs = input.windowMs ?? 15 * 60 * 1000;
  await consumeRateLimit({
    key: `ip:scope:${ipScope}:ip:${clientAddress}`,
    scope: ipScope,
    limit: maxHits,
    windowMs: windowLengthMs
  });
}