126 lines
4.1 KiB
TypeScript
126 lines
4.1 KiB
TypeScript
// Load the `server-only` guard package everywhere except under the test
// environment. Presumably it exists to stop this module from being bundled
// into client code — confirm against the package in use.
if (process.env.NODE_ENV !== "test") {
  require("server-only");
}
|
|
import crypto from "node:crypto";
|
|
import getPool from "@/lib/server/db";
|
|
import { apiError } from "@/lib/server/errors";
|
|
|
|
/** One rate-limit check: which bucket to hit and how strict it is. */
type LimitInput = {
  /** Label reported alongside a RATE_LIMITED error. */
  scope: string;
  /** Primary key of the counter row in `rate_limits`. */
  key: string;
  /** Maximum number of hits allowed within one window. */
  limit: number;
  /** Window length in milliseconds. */
  windowMs: number;
};
|
|
|
|
/** Memoized schema-setup promise; `null` until the first call starts it. */
let ensureTablePromise: Promise<void> | null = null;

/** Epoch milliseconds of the most recent stale-row sweep; 0 means none yet in this process. */
let lastCleanupAtMs = 0;
|
|
|
|
/**
 * Create the `rate_limits` table and its `updated_at` index if they are missing.
 *
 * The DDL is memoized in `ensureTablePromise`. Concurrent callers share the
 * same in-flight promise, and once it succeeds later calls just await the
 * settled promise.
 *
 * If the DDL fails, the memo is cleared. The current caller still sees the
 * rejection, but the next call retries. Previously a single transient failure
 * left a rejected promise cached, and every later call failed without retrying.
 */
async function ensureRateLimitsTable(): Promise<void> {
  if (!ensureTablePromise) {
    const attempt = (async () => {
      const pool = getPool();
      await pool.query(`
        create table if not exists rate_limits(
          key text primary key,
          window_start timestamptz not null,
          count integer not null default 0,
          updated_at timestamptz not null default now()
        )
      `);
      await pool.query("create index if not exists rate_limits_updated_at_idx on rate_limits(updated_at)");
    })();
    ensureTablePromise = attempt;
    // Forget a failed attempt so the next caller retries. Only clear the memo
    // if it still points at this attempt, not at a newer one.
    void attempt.catch(() => {
      if (ensureTablePromise === attempt) ensureTablePromise = null;
    });
  }
  await ensureTablePromise;
}
|
|
|
|
/**
 * Turn an arbitrary caller-supplied value into a compact, key-safe segment of
 * a rate-limit key.
 *
 * The value is trimmed and lower-cased. If the result contains only
 * `[a-z0-9:._-]` and is at most 96 characters, it is returned unchanged so
 * keys stay human-readable. Otherwise the full normalized value is hashed and
 * returned as `sha256:<hex>`.
 *
 * Hashing, rather than replacing disallowed characters with `_` and truncating,
 * keeps distinct inputs in distinct buckets. Under the old scheme these pairs
 * shared one bucket, so one user's traffic could throttle another's:
 * - `a@b.com` and `a_b.com`
 * - two different non-ASCII names of equal length
 * - long values sharing a 256-character prefix
 *
 * @param value - Raw segment value.
 * @param fallbackUnknown - When true (the default), a blank value maps to
 *   `"unknown"`; otherwise it maps to `""`.
 * @returns The key-safe segment.
 */
function normalizeSegment(value: string, fallbackUnknown = true): string {
  const normalized = value.trim().toLowerCase();
  if (!normalized) return fallbackUnknown ? "unknown" : "";

  if (normalized.length <= 96 && /^[a-z0-9:._-]+$/.test(normalized)) return normalized;

  return `sha256:${crypto.createHash("sha256").update(normalized).digest("hex")}`;
}
|
|
|
|
/**
 * Delete `rate_limits` rows whose `updated_at` is more than two days old.
 *
 * The sweep is throttled per process to at most once every ten minutes.
 * `lastCleanupAtMs` is updated synchronously, before the delete is issued.
 * A second call arriving while a sweep is in flight therefore returns
 * immediately, and a failed sweep is not retried until the interval elapses.
 */
async function cleanupStaleRateLimits() {
  const sweepIntervalMs = 10 * 60 * 1000;
  const startedAtMs = Date.now();
  const elapsedMs = startedAtMs - lastCleanupAtMs;
  if (elapsedMs < sweepIntervalMs) {
    return;
  }

  lastCleanupAtMs = startedAtMs;
  await getPool().query("delete from rate_limits where updated_at < now() - interval '2 days'");
}
|
|
|
|
/**
 * Align a timestamp to the start of its fixed-length window.
 *
 * @param nowMs - Epoch milliseconds to bucket.
 * @param windowMs - Window length in milliseconds; must be a positive, finite number.
 * @returns A `Date` at the greatest multiple of `windowMs` that is <= `nowMs`.
 * @throws RangeError if `windowMs` is zero, negative, or not finite. Without
 *   this check the result would be an Invalid Date that only fails later,
 *   deep inside the database driver.
 */
function normalizeWindowStart(nowMs: number, windowMs: number): Date {
  if (!Number.isFinite(windowMs) || windowMs <= 0) {
    throw new RangeError(`windowMs must be a positive finite number, got ${windowMs}`);
  }
  const bucketStart = Math.floor(nowMs / windowMs) * windowMs;
  return new Date(bucketStart);
}
|
|
|
|
/**
 * Record one hit against `input.key` and report RATE_LIMITED once the count
 * for the current window exceeds `input.limit`.
 *
 * Steps:
 * 1. Ensure the table exists, and run the throttled stale-row sweep.
 * 2. Bucket the current time into a fixed window of `input.windowMs`.
 * 3. Upsert the counter in one `INSERT ... ON CONFLICT DO UPDATE` statement:
 *    - A first hit inserts count 1.
 *    - A hit in the stored window (or an older one) increments the count.
 *    - A hit in a newer window resets the count to 1.
 *
 * The stored window never moves backwards (`greatest`). Previously an
 * instance whose clock lagged slightly would reset the counter to its own,
 * older window. Two skewed instances could then keep resetting each other,
 * and the limit was never reached.
 *
 * NOTE(review): `apiError` is assumed to throw, since its return value is
 * ignored here. Confirm in `@/lib/server/errors`.
 */
async function consumeRateLimit(input: LimitInput): Promise<void> {
  await ensureRateLimitsTable();
  await cleanupStaleRateLimits();

  const windowStart = normalizeWindowStart(Date.now(), input.windowMs);
  const pool = getPool();
  const { rows } = await pool.query<{ count: number }>(
    `insert into rate_limits(key, window_start, count, updated_at)
     values($1, $2, 1, now())
     on conflict (key) do update
     set count = case
           when excluded.window_start <= rate_limits.window_start then rate_limits.count + 1
           else 1
         end,
         window_start = greatest(rate_limits.window_start, excluded.window_start),
         updated_at = now()
     returning count`,
    [input.key, windowStart]
  );

  const count = Number(rows[0]?.count ?? 0);
  if (count > input.limit) {
    apiError("RATE_LIMITED", {
      scope: input.scope,
      limit: input.limit,
      windowMs: input.windowMs
    });
  }
}
|
|
|
|
/**
 * Throttle an authentication attempt (`login` or `register`).
 *
 * Two buckets share one window (default 15 minutes) under the scope
 * `auth:<route>`:
 * - a per-IP bucket (default 20 hits), always consumed, and checked first;
 * - a per-identifier bucket (default 10 hits), consumed only when a non-blank
 *   identifier is supplied.
 *
 * If the IP bucket is over its limit, `consumeRateLimit` reports RATE_LIMITED
 * before the identifier bucket is touched.
 *
 * NOTE(review): every request without an IP shares the single `unknown`
 * bucket. Confirm that callers always resolve a client IP.
 */
export async function enforceAuthRateLimit(input: {
  route: "login" | "register";
  ip?: string | null;
  identifier?: string | null;
  ipLimit?: number;
  identifierLimit?: number;
  windowMs?: number;
}) {
  const scope = normalizeSegment(`auth:${input.route}`);
  const windowMs = input.windowMs ?? 15 * 60 * 1000;

  const ipSegment = normalizeSegment(String(input.ip || "unknown"));
  await consumeRateLimit({
    scope,
    windowMs,
    key: `${scope}:ip:${ipSegment}`,
    limit: input.ipLimit ?? 20
  });

  const identifierSegment = normalizeSegment(String(input.identifier || ""), false);
  if (!identifierSegment) {
    return;
  }

  await consumeRateLimit({
    scope,
    windowMs,
    key: `${scope}:identifier:${identifierSegment}`,
    limit: input.identifierLimit ?? 10
  });
}
|
|
|
|
/**
 * Throttle write operations per user and per scope.
 *
 * Each `(userId, scope)` pair gets its own bucket. The defaults are 120 hits
 * per 15-minute window.
 */
export async function enforceUserWriteRateLimit(input: { userId: number; scope: string; limit?: number; windowMs?: number }) {
  const scope = normalizeSegment(input.scope);
  const limit = input.limit ?? 120;
  const windowMs = input.windowMs ?? 15 * 60 * 1000;
  const key = `write:user:${input.userId}:scope:${scope}`;

  await consumeRateLimit({ key, scope, limit, windowMs });
}
|