106 lines
3.3 KiB
TypeScript
106 lines
3.3 KiB
TypeScript
if (process.env.NODE_ENV !== "test")
|
|
require("server-only");
|
|
import getPool from "@/lib/server/db";
|
|
import { apiError } from "@/lib/server/errors";
|
|
|
|
/** Options describing one fixed-window rate-limit check against the `rate_limits` table. */
interface LimitInput {
  /** Primary key of the counter row; every distinct key gets its own counter. */
  key: string;
  /** Maximum hits tolerated per window; a hit that pushes the count above this is rejected. */
  limit: number;
  /** Window length in milliseconds; windows are aligned to multiples of this value. */
  windowMs: number;
  /** Label echoed back in the `RATE_LIMITED` error details. */
  scope: string;
}
|
|
|
|
/**
 * Memoized result of the table bootstrap. `null` means "not started yet, or the
 * last attempt failed and should be retried".
 */
let ensureTablePromise: Promise<void> | null = null;

/**
 * Creates the `rate_limits` table and its `updated_at` index if they are missing.
 *
 * Concurrent and subsequent callers in this module instance share one bootstrap
 * promise, so the DDL is issued once on success. If the bootstrap fails, the
 * memo is cleared before the error propagates, so the next call retries instead
 * of replaying the same rejection for the rest of the process lifetime.
 *
 * @returns Resolves once both DDL statements have completed.
 * @throws Whatever `pool.query` rejects with (e.g. connection errors).
 */
async function ensureRateLimitsTable(): Promise<void> {
  if (!ensureTablePromise) {
    ensureTablePromise = (async () => {
      const pool = getPool();
      await pool.query(`
        create table if not exists rate_limits(
          key text primary key,
          window_start timestamptz not null,
          count integer not null default 0,
          updated_at timestamptz not null default now()
        )
      `);
      await pool.query("create index if not exists rate_limits_updated_at_idx on rate_limits(updated_at)");
    })().catch((err: unknown) => {
      // Do not memoize a failure: a transient DB error would otherwise disable
      // every rate-limited route until the process restarts.
      ensureTablePromise = null;
      throw err;
    });
  }
  await ensureTablePromise;
}
|
|
|
|
/**
 * Aligns a timestamp to the start of the fixed window that contains it.
 *
 * Windows are aligned to multiples of `windowMs` since the Unix epoch, so every
 * process computing the bucket for the same instant gets the same start.
 *
 * @param nowMs - Timestamp in milliseconds since the epoch (e.g. `Date.now()`).
 * @param windowMs - Window length in milliseconds; must be positive and finite.
 * @returns The window's start as a `Date`.
 * @throws RangeError if `windowMs` is zero, negative, or not finite — these would
 *   otherwise yield an Invalid Date (0/NaN) or a window start in the future (negative).
 */
function normalizeWindowStart(nowMs: number, windowMs: number): Date {
  if (!Number.isFinite(windowMs) || windowMs <= 0) {
    throw new RangeError(`windowMs must be a positive finite number, got ${windowMs}`);
  }
  const bucketStart = Math.floor(nowMs / windowMs) * windowMs;
  return new Date(bucketStart);
}
|
|
|
|
/**
 * Charges one hit to `input.key` for the current fixed window and reports
 * `RATE_LIMITED` once the window's hit count exceeds `input.limit`.
 *
 * The window is derived from this process's clock (`Date.now()`). The counter
 * row is maintained by a single `insert ... on conflict (key) do update`:
 * - If the stored window matches the current one, the count is incremented.
 * - Otherwise the count restarts at 1 in the new window.
 *
 * In both branches the stored `window_start` ends up equal to the current
 * window, and `updated_at` is refreshed.
 *
 * NOTE(review): this relies on `apiError` throwing. If it instead returns a
 * value, the limit would not be enforced here — confirm against
 * `@/lib/server/errors`.
 *
 * @param input - Key, limit, window length and scope label for this check.
 */
async function consumeRateLimit(input: LimitInput) {
  await ensureRateLimitsTable();

  const bucket = normalizeWindowStart(Date.now(), input.windowMs);
  const db = getPool();

  const result = await db.query<{ count: number }>(
    `insert into rate_limits(key, window_start, count, updated_at)
     values($1, $2, 1, now())
     on conflict (key) do update
     set count = case
       when rate_limits.window_start = excluded.window_start then rate_limits.count + 1
       else 1
     end,
     window_start = case
       when rate_limits.window_start = excluded.window_start then rate_limits.window_start
       else excluded.window_start
     end,
     updated_at = now()
     returning count`,
    [input.key, bucket]
  );

  // A missing row (should not happen with `returning`) is treated as 0 hits.
  const hits = Number(result.rows[0]?.count || 0);
  const overLimit = hits > input.limit;
  if (!overLimit) {
    return;
  }

  apiError("RATE_LIMITED", {
    scope: input.scope,
    limit: input.limit,
    windowMs: input.windowMs
  });
}
|
|
|
|
/**
 * Applies the login/register rate limits.
 *
 * Two counters are charged, both scoped to `auth:<route>`:
 * 1. Per client IP, always. The IP is trimmed and lowercased; a missing or
 *    empty value shares a single `unknown` bucket. Default limit: 20.
 * 2. Per identifier, only when it is non-empty after trimming. It is
 *    lowercased. Default limit: 10.
 *
 * The IP counter is charged first, so if that check rejects, the identifier
 * counter is not touched.
 *
 * @param input.route - Which auth endpoint is being limited.
 * @param input.ip - Client IP, if known.
 * @param input.identifier - Account identifier submitted by the client, if any.
 * @param input.ipLimit - Per-IP hits per window (default 20).
 * @param input.identifierLimit - Per-identifier hits per window (default 10).
 * @param input.windowMs - Window length in ms (default 15 minutes).
 */
export async function enforceAuthRateLimit(input: {
  route: "login" | "register";
  ip?: string | null;
  identifier?: string | null;
  ipLimit?: number;
  identifierLimit?: number;
  windowMs?: number;
}) {
  const scope = `auth:${input.route}`;
  const windowMs = input.windowMs ?? 15 * 60 * 1000;
  const canonical = (raw: string | null | undefined, fallback: string) =>
    String(raw || fallback).trim().toLowerCase();

  await consumeRateLimit({
    key: [scope, "ip", canonical(input.ip, "unknown")].join(":"),
    scope,
    limit: input.ipLimit ?? 20,
    windowMs
  });

  const identifierKey = canonical(input.identifier, "");
  if (identifierKey === "") {
    return;
  }

  await consumeRateLimit({
    key: [scope, "identifier", identifierKey].join(":"),
    scope,
    limit: input.identifierLimit ?? 10,
    windowMs
  });
}
|
|
|
|
/**
 * Charges one write by `input.userId` against the per-user, per-scope counter.
 *
 * The counter key is `write:user:<userId>:scope:<scope>`, so each scope is
 * limited independently.
 *
 * @param input.userId - Authenticated user performing the write.
 * @param input.scope - Feature/endpoint label; also reported in the error details.
 * @param input.limit - Writes allowed per window (default 120).
 * @param input.windowMs - Window length in ms (default 15 minutes).
 */
export async function enforceUserWriteRateLimit(input: { userId: number; scope: string; limit?: number; windowMs?: number }) {
  const { userId, scope } = input;
  const limit = input.limit ?? 120;
  const windowMs = input.windowMs ?? 15 * 60 * 1000;

  await consumeRateLimit({ key: `write:user:${userId}:scope:${scope}`, scope, limit, windowMs });
}
|