fiddy/apps/web/lib/server/rate-limit.ts

106 lines
3.3 KiB
TypeScript

if (process.env.NODE_ENV !== "test")
require("server-only");
import getPool from "@/lib/server/db";
import { apiError } from "@/lib/server/errors";
/** Parameters for a single fixed-window rate-limit check. */
interface LimitInput {
  /** Row key in `rate_limits`; one counter is kept per distinct key. */
  key: string;
  /** Maximum number of hits allowed in one window; the hit after that is rejected. */
  limit: number;
  /** Window length in milliseconds. */
  windowMs: number;
  /** Label reported in the `RATE_LIMITED` error details. */
  scope: string;
}
/**
 * Memoized promise for the one-time `rate_limits` DDL below. Concurrent callers
 * share the in-flight attempt. The promise is cleared again if that attempt fails.
 */
let ensureTablePromise: Promise<void> | null = null;

/**
 * Creates the `rate_limits` table and its `updated_at` index if they do not
 * already exist. On success, the DDL runs at most once per module instance.
 *
 * If setup fails (e.g. the database is briefly unreachable), the cached promise
 * is cleared so the next call retries. Otherwise every later call would rethrow
 * the same stale rejection for as long as the process lives.
 *
 * @throws whatever `pool.query` rejects with while running the DDL.
 */
async function ensureRateLimitsTable(): Promise<void> {
  if (!ensureTablePromise) {
    ensureTablePromise = (async () => {
      const pool = getPool();
      await pool.query(`
        create table if not exists rate_limits(
          key text primary key,
          window_start timestamptz not null,
          count integer not null default 0,
          updated_at timestamptz not null default now()
        )
      `);
      await pool.query("create index if not exists rate_limits_updated_at_idx on rate_limits(updated_at)");
    })().catch((error: unknown) => {
      // Forget the failed attempt so a later call can retry the DDL.
      ensureTablePromise = null;
      throw error;
    });
  }
  await ensureTablePromise;
}
/**
 * Returns the start of the fixed window that contains `nowMs`. This is `nowMs`
 * rounded down to a multiple of `windowMs`, so windows are aligned to the Unix epoch.
 *
 * @param nowMs - Current time in epoch milliseconds.
 * @param windowMs - Window length in milliseconds; must be positive and finite.
 * @returns The window start as a `Date`.
 * @throws RangeError if `windowMs` is not a positive finite number. Without this
 *   check, the arithmetic below yields NaN (an Invalid Date) that would be passed on
 *   to the database query.
 */
function normalizeWindowStart(nowMs: number, windowMs: number): Date {
  if (!Number.isFinite(windowMs) || windowMs <= 0) {
    throw new RangeError(`windowMs must be a positive finite number, got ${windowMs}`);
  }
  const bucketStart = Math.floor(nowMs / windowMs) * windowMs;
  return new Date(bucketStart);
}
/**
 * Records one hit for `input.key` in the current fixed window. If the window's
 * count exceeds `input.limit`, it calls `apiError("RATE_LIMITED", ...)`.
 *
 * Each key has one `rate_limits` row holding the current window start and its
 * hit count. The upsert is a single statement, so concurrent hits on the same key
 * are counted by the database rather than by read-modify-write in application code.
 *
 * NOTE(review): `apiError` is called without `throw`, so it is assumed to throw
 * (return `never`). Confirm, otherwise the limit is never enforced.
 */
async function consumeRateLimit(input: LimitInput): Promise<void> {
  await ensureRateLimitsTable();
  const windowStart = normalizeWindowStart(Date.now(), input.windowMs);
  const pool = getPool();
  // Only a strictly newer window resets the counter. A caller whose clock lags
  // (another app instance, or a backwards clock step) increments the current
  // window instead of resetting it, and `greatest` never moves the stored
  // window start backwards. In the SET list, `rate_limits.*` refers to the
  // pre-update row, so both expressions compare against the same old values.
  const { rows } = await pool.query<{ count: number }>(
    `insert into rate_limits(key, window_start, count, updated_at)
     values($1, $2, 1, now())
     on conflict (key) do update
     set count = case
           when excluded.window_start > rate_limits.window_start then 1
           else rate_limits.count + 1
         end,
         window_start = greatest(rate_limits.window_start, excluded.window_start),
         updated_at = now()
     returning count`,
    [input.key, windowStart]
  );
  // The upsert always returns a row; the fallback of 0 (allow) is purely defensive.
  const count = Number(rows[0]?.count ?? 0);
  if (count > input.limit) {
    apiError("RATE_LIMITED", {
      scope: input.scope,
      limit: input.limit,
      windowMs: input.windowMs
    });
  }
}
/**
 * Maximum length, in UTF-16 code units, of a caller-supplied segment (IP or
 * identifier) inside a rate-limit key.
 */
const MAX_KEY_PART_LENGTH = 256;

/**
 * Normalizes an untrusted value for use in a rate-limit key. It trims and
 * lowercases the value, returns `fallback` when nothing is left, and truncates
 * long values to `MAX_KEY_PART_LENGTH`.
 *
 * Truncation keeps `rate_limits.key` (a `text primary key`) well under
 * PostgreSQL's B-tree entry size limit. Without it, an oversized login
 * identifier could make the rate-limit check itself fail with a database error.
 * 256 code units still fits any valid email address (at most 254 characters).
 */
function normalizeKeyPart(value: string | null | undefined, fallback: string): string {
  const normalized = String(value || "").trim().toLowerCase();
  if (!normalized) {
    return fallback;
  }
  return normalized.length > MAX_KEY_PART_LENGTH
    ? normalized.slice(0, MAX_KEY_PART_LENGTH)
    : normalized;
}

/**
 * Applies the auth rate limits for a login or registration attempt.
 *
 * It always counts a hit against the client IP bucket (default 20 per window),
 * then against the normalized identifier bucket (default 10 per window) when an
 * identifier is given. Both use a default window of 15 minutes. A missing or
 * blank IP is bucketed as `"unknown"`.
 *
 * NOTE(review): this relies on `consumeRateLimit` throwing when a limit is
 * exceeded. In that case the identifier bucket is not counted for a request
 * already rejected by IP.
 */
export async function enforceAuthRateLimit(input: {
  route: "login" | "register";
  ip?: string | null;
  identifier?: string | null;
  ipLimit?: number;
  identifierLimit?: number;
  windowMs?: number;
}) {
  const scope = `auth:${input.route}`;
  const ip = normalizeKeyPart(input.ip, "unknown");
  const windowMs = input.windowMs ?? (15 * 60 * 1000);
  await consumeRateLimit({
    key: `${scope}:ip:${ip}`,
    scope,
    limit: input.ipLimit ?? 20,
    windowMs
  });
  const identifier = normalizeKeyPart(input.identifier, "");
  if (identifier) {
    await consumeRateLimit({
      key: `${scope}:identifier:${identifier}`,
      scope,
      limit: input.identifierLimit ?? 10,
      windowMs
    });
  }
}
/**
 * Counts one write by `userId` within `scope` against a per-user, per-scope
 * bucket. By default this allows 120 writes per 15-minute window. Enforcement
 * (rejecting once the limit is exceeded) is delegated to `consumeRateLimit`.
 */
export async function enforceUserWriteRateLimit(input: { userId: number; scope: string; limit?: number; windowMs?: number }) {
  const { userId, scope } = input;
  const fifteenMinutesMs = 15 * 60 * 1000;
  const bucketKey = `write:user:${userId}:scope:${scope}`;
  const maxWrites = input.limit ?? 120;
  const windowLengthMs = input.windowMs ?? fifteenMinutesMs;

  await consumeRateLimit({
    key: bucketKey,
    scope,
    limit: maxWrites,
    windowMs: windowLengthMs
  });
}