fix: read-after-write issues (#215)

* fix: read-after-write issues

* fix: coderabbit comments

* fix: clear cache on invite

* fix: use primary after a read
This commit is contained in:
Carl-Gerhard Lindesvärd
2025-10-31 09:56:07 +01:00
committed by GitHub
parent abacf66155
commit f454449365
19 changed files with 470 additions and 167 deletions

View File

@@ -9,7 +9,7 @@ import {
} from '@/utils/ai-tools'; } from '@/utils/ai-tools';
import { HttpError } from '@/utils/errors'; import { HttpError } from '@/utils/errors';
import { db, getOrganizationByProjectIdCached } from '@openpanel/db'; import { db, getOrganizationByProjectIdCached } from '@openpanel/db';
import { getProjectAccessCached } from '@openpanel/trpc/src/access'; import { getProjectAccess } from '@openpanel/trpc/src/access';
import { type Message, appendResponseMessages, streamText } from 'ai'; import { type Message, appendResponseMessages, streamText } from 'ai';
import type { FastifyReply, FastifyRequest } from 'fastify'; import type { FastifyReply, FastifyRequest } from 'fastify';
@@ -37,7 +37,7 @@ export async function chat(
} }
const organization = await getOrganizationByProjectIdCached(projectId); const organization = await getOrganizationByProjectIdCached(projectId);
const access = await getProjectAccessCached({ const access = await getProjectAccess({
projectId, projectId,
userId: session.userId, userId: session.userId,
}); });

View File

@@ -113,6 +113,17 @@ export async function slackWebhook(
} }
} }
async function clearOrganizationCache(organizationId: string) {
const projects = await db.project.findMany({
where: {
organizationId,
},
});
for (const project of projects) {
await getOrganizationByProjectIdCached.clear(project.id);
}
}
export async function polarWebhook( export async function polarWebhook(
request: FastifyRequest<{ request: FastifyRequest<{
Querystring: unknown; Querystring: unknown;
@@ -141,8 +152,11 @@ export async function polarWebhook(
}, },
data: { data: {
subscriptionPeriodEventsCount: 0, subscriptionPeriodEventsCount: 0,
subscriptionPeriodEventsCountExceededAt: null,
}, },
}); });
await clearOrganizationCache(metadata.organizationId);
} }
break; break;
} }
@@ -205,15 +219,7 @@ export async function polarWebhook(
}, },
}); });
const projects = await db.project.findMany({ await clearOrganizationCache(metadata.organizationId);
where: {
organizationId: metadata.organizationId,
},
});
for (const project of projects) {
await getOrganizationByProjectIdCached.clear(project.id);
}
await publishEvent('organization', 'subscription_updated', { await publishEvent('organization', 'subscription_updated', {
organizationId: metadata.organizationId, organizationId: metadata.organizationId,

View File

@@ -8,14 +8,18 @@ import Fastify from 'fastify';
import metricsPlugin from 'fastify-metrics'; import metricsPlugin from 'fastify-metrics';
import { generateId } from '@openpanel/common'; import { generateId } from '@openpanel/common';
import type { IServiceClientWithProject } from '@openpanel/db'; import {
import { getRedisPub } from '@openpanel/redis'; type IServiceClientWithProject,
runWithAlsSession,
} from '@openpanel/db';
import { getCache, getRedisPub } from '@openpanel/redis';
import type { AppRouter } from '@openpanel/trpc'; import type { AppRouter } from '@openpanel/trpc';
import { appRouter, createContext } from '@openpanel/trpc'; import { appRouter, createContext } from '@openpanel/trpc';
import { import {
EMPTY_SESSION, EMPTY_SESSION,
type SessionValidationResult, type SessionValidationResult,
decodeSessionToken,
validateSessionToken, validateSessionToken,
} from '@openpanel/auth'; } from '@openpanel/auth';
import sourceMapSupport from 'source-map-support'; import sourceMapSupport from 'source-map-support';
@@ -140,7 +144,14 @@ const startServer = async () => {
instance.addHook('onRequest', async (req) => { instance.addHook('onRequest', async (req) => {
if (req.cookies?.session) { if (req.cookies?.session) {
try { try {
const session = await validateSessionToken(req.cookies.session); const sessionId = decodeSessionToken(req.cookies.session);
const session = await runWithAlsSession(sessionId, () =>
sessionId
? getCache(`validateSession:${sessionId}`, 60 * 5, async () =>
validateSessionToken(req.cookies.session),
)
: validateSessionToken(req.cookies.session),
);
if (session.session) { if (session.session) {
req.session = session; req.session = session;
} }

View File

@@ -24,9 +24,11 @@ export function useSessionExtension() {
1000 * 60 * 5, 1000 * 60 * 5,
); );
extendSessionFn(); // Delay initial call a bit to prioritize other requests
const timer = setTimeout(() => extendSessionFn(), 5000);
return () => { return () => {
clearTimeout(timer);
if (intervalRef.current) { if (intervalRef.current) {
clearInterval(intervalRef.current); clearInterval(intervalRef.current);
} }

View File

@@ -59,8 +59,14 @@ export async function createDemoSession(
}; };
} }
export const decodeSessionToken = (token: string): string | null => {
return token
? encodeHexLowerCase(sha256(new TextEncoder().encode(token)))
: null;
};
export async function validateSessionToken( export async function validateSessionToken(
token: string | null, token: string | null | undefined,
): Promise<SessionValidationResult> { ): Promise<SessionValidationResult> {
if (process.env.DEMO_USER_ID) { if (process.env.DEMO_USER_ID) {
return createDemoSession(process.env.DEMO_USER_ID); return createDemoSession(process.env.DEMO_USER_ID);
@@ -69,7 +75,10 @@ export async function validateSessionToken(
if (!token) { if (!token) {
return EMPTY_SESSION; return EMPTY_SESSION;
} }
const sessionId = encodeHexLowerCase(sha256(new TextEncoder().encode(token))); const sessionId = decodeSessionToken(token);
if (!sessionId) {
return EMPTY_SESSION;
}
const result = await db.$primary().session.findUnique({ const result = await db.$primary().session.findUnique({
where: { where: {
id: sessionId, id: sessionId,

View File

@@ -19,7 +19,9 @@ export * from './src/services/reference.service';
export * from './src/services/id.service'; export * from './src/services/id.service';
export * from './src/services/retention.service'; export * from './src/services/retention.service';
export * from './src/services/notification.service'; export * from './src/services/notification.service';
export * from './src/services/access.service';
export * from './src/buffers'; export * from './src/buffers';
export * from './src/types'; export * from './src/types';
export * from './src/clickhouse/query-builder'; export * from './src/clickhouse/query-builder';
export * from './src/services/overview.service'; export * from './src/services/overview.service';
export * from './src/session-context';

View File

@@ -0,0 +1,3 @@
import { createLogger } from '@openpanel/logger';
export const logger = createLogger({ name: 'db:prisma' });

View File

@@ -1,11 +1,15 @@
import { createLogger } from '@openpanel/logger'; import { createLogger } from '@openpanel/logger';
import { readReplicas } from '@prisma/extension-read-replicas'; import { readReplicas } from '@prisma/extension-read-replicas';
import { type Organization, PrismaClient } from './generated/prisma/client'; import {
type Organization,
Prisma,
PrismaClient,
} from './generated/prisma/client';
import { logger } from './logger';
import { sessionConsistency } from './session-consistency';
export * from './generated/prisma/client'; export * from './generated/prisma/client';
const logger = createLogger({ name: 'db' });
const isWillBeCanceled = ( const isWillBeCanceled = (
organization: Pick< organization: Pick<
Organization, Organization,
@@ -30,11 +34,6 @@ const getPrismaClient = () => {
const prisma = new PrismaClient({ const prisma = new PrismaClient({
log: ['error'], log: ['error'],
}) })
.$extends(
readReplicas({
url: process.env.DATABASE_URL_REPLICA ?? process.env.DATABASE_URL!,
}),
)
.$extends({ .$extends({
query: { query: {
async $allOperations({ operation, model, args, query }) { async $allOperations({ operation, model, args, query }) {
@@ -53,6 +52,8 @@ const getPrismaClient = () => {
}, },
}, },
}) })
.$extends(sessionConsistency())
.$extends({ .$extends({
result: { result: {
organization: { organization: {
@@ -258,7 +259,12 @@ const getPrismaClient = () => {
}, },
}, },
}, },
}); })
.$extends(
readReplicas({
url: process.env.DATABASE_URL_REPLICA ?? process.env.DATABASE_URL!,
}),
);
return prisma; return prisma;
}; };

View File

@@ -0,0 +1,96 @@
import { cacheable } from '@openpanel/redis';
import { db } from '../prisma-client';
import { getProjectById } from './project.service';
export const getProjectAccess = cacheable(
'getProjectAccess',
async ({
userId,
projectId,
}: {
userId: string;
projectId: string;
}) => {
try {
// Check if user has access to the project
const project = await getProjectById(projectId);
if (!project?.organizationId) {
return false;
}
const [projectAccess, member] = await Promise.all([
db.$primary().projectAccess.findMany({
where: {
userId,
organizationId: project.organizationId,
},
}),
db.$primary().member.findFirst({
where: {
organizationId: project.organizationId,
userId,
},
}),
]);
if (projectAccess.length === 0 && member) {
return true;
}
return projectAccess.find((item) => item.projectId === projectId);
} catch (err) {
return false;
}
},
60 * 5,
);
export const getOrganizationAccess = cacheable(
'getOrganizationAccess',
async ({
userId,
organizationId,
}: {
userId: string;
organizationId: string;
}) => {
return db.$primary().member.findFirst({
where: {
userId,
organizationId,
},
});
},
60 * 5,
);
export async function getClientAccess({
userId,
clientId,
}: {
userId: string;
clientId: string;
}) {
const client = await db.client.findFirst({
where: {
id: clientId,
},
});
if (!client) {
return false;
}
if (client.projectId) {
return getProjectAccess({ userId, projectId: client.projectId });
}
if (client.organizationId) {
return getOrganizationAccess({
userId,
organizationId: client.organizationId,
});
}
return false;
}

View File

@@ -5,7 +5,8 @@ import { chQuery, formatClickhouseDate } from '../clickhouse/client';
import type { Invite, Prisma, ProjectAccess, User } from '../prisma-client'; import type { Invite, Prisma, ProjectAccess, User } from '../prisma-client';
import { db } from '../prisma-client'; import { db } from '../prisma-client';
import { createSqlBuilder } from '../sql-builder'; import { createSqlBuilder } from '../sql-builder';
import type { IServiceProject } from './project.service'; import { getOrganizationAccess, getProjectAccess } from './access.service';
import { type IServiceProject, getProjectById } from './project.service';
export type IServiceOrganization = Awaited< export type IServiceOrganization = Awaited<
ReturnType<typeof db.organization.findUniqueOrThrow> ReturnType<typeof db.organization.findUniqueOrThrow>
>; >;
@@ -61,7 +62,7 @@ export async function getOrganizationByProjectId(projectId: string) {
export const getOrganizationByProjectIdCached = cacheable( export const getOrganizationByProjectIdCached = cacheable(
getOrganizationByProjectId, getOrganizationByProjectId,
60 * 60 * 24, 60 * 5,
); );
export async function getInvites(organizationId: string) { export async function getInvites(organizationId: string) {
@@ -168,8 +169,14 @@ export async function connectUserToOrganization({
}, },
}); });
await getOrganizationAccess.clear({
userId: user.id,
organizationId: invite.organizationId,
});
if (invite.projectAccess.length > 0) { if (invite.projectAccess.length > 0) {
for (const projectId of invite.projectAccess) { for (const projectId of invite.projectAccess) {
await getProjectAccess.clear({ userId: user.id, projectId });
await db.projectAccess.create({ await db.projectAccess.create({
data: { data: {
projectId, projectId,

View File

@@ -0,0 +1,225 @@
import { getRedisCache } from '@openpanel/redis';
import type { Operation } from '@prisma/client/runtime/client';
import { Prisma, type PrismaClient } from './generated/prisma/client';
import { logger } from './logger';
import { getAlsSessionId } from './session-context';
type BarePrismaClient = {
$queryRaw: <T>(query: TemplateStringsArray, ...args: unknown[]) => Promise<T>;
};
// WAL LSN tracking for read-after-write consistency
const LSN_CACHE_PREFIX = 'db:wal_lsn:';
const LSN_CACHE_TTL = 5;
const MAX_RETRY_ATTEMPTS = 5;
const INITIAL_RETRY_DELAY_MS = 10;
const READ_OPERATIONS: Operation[] = [
'findUnique',
'findUniqueOrThrow',
'findFirst',
'findFirstOrThrow',
'findMany',
'aggregate',
'groupBy',
'count',
];
const WRITE_OPERATIONS: Operation[] = [
'create',
'update',
'delete',
'createMany',
'createManyAndReturn',
'updateMany',
'deleteMany',
'upsert',
];
const isWriteOperation = (operation: string) =>
WRITE_OPERATIONS.includes(operation as Operation);
const isReadOperation = (operation: string) =>
READ_OPERATIONS.includes(operation as Operation);
async function getCurrentWalLsn(
prismaClient: BarePrismaClient,
): Promise<string | null> {
try {
const result = await prismaClient.$queryRaw<[{ lsn: string }]>`
SELECT pg_current_wal_lsn()::text AS lsn
`;
return result[0]?.lsn || null;
} catch (error) {
logger.error('Failed to get WAL LSN', { error });
return null;
}
}
async function cacheWalLsnForSession(
sessionId: string,
lsn: string,
): Promise<void> {
try {
const redis = getRedisCache();
await redis.setex(`${LSN_CACHE_PREFIX}${sessionId}`, LSN_CACHE_TTL, lsn);
} catch (error) {
logger.error('Failed to cache WAL LSN', { error, sessionId });
}
}
async function getCachedWalLsn(sessionId: string): Promise<string | null> {
try {
const redis = getRedisCache();
return await redis.get(`${LSN_CACHE_PREFIX}${sessionId}`);
} catch (error) {
logger.error('Failed to get cached WAL LSN', { error, sessionId });
return null;
}
}
function compareWalLsn(lsn1: string, lsn2: string): number {
const [x1, y1] = lsn1.split('/').map((x) => BigInt(`0x${x}`));
const [x2, y2] = lsn2.split('/').map((x) => BigInt(`0x${x}`));
const v1 = ((x1 ?? 0n) << 32n) + (y1 ?? 0n);
const v2 = ((x2 ?? 0n) << 32n) + (y2 ?? 0n);
if (v1 < v2) return -1;
if (v1 > v2) return 1;
return 0;
}
async function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
// Method not used for now,
// Need a way to check LSN on the actual replica that will be used for the read.
async function waitForReplicaCatchup(
prismaClient: BarePrismaClient,
sessionId: string,
): Promise<boolean> {
const expectedLsn = await getCachedWalLsn(sessionId);
if (!expectedLsn) {
return true;
}
for (let attempt = 0; attempt < MAX_RETRY_ATTEMPTS; attempt++) {
const currentLsn = await getCurrentWalLsn(prismaClient);
if (!currentLsn) {
return true;
}
// Check if replica has caught up (current >= expected)
if (compareWalLsn(currentLsn, expectedLsn) >= 0) {
logger.debug('Replica caught up', {
attempt: attempt + 1,
currentLsn,
expectedLsn,
sessionId,
});
return true;
}
// Exponential backoff
if (attempt < MAX_RETRY_ATTEMPTS - 1) {
const delayMs = INITIAL_RETRY_DELAY_MS * 2 ** attempt;
logger.debug('Waiting for replica to catch up', {
attempt: attempt + 1,
delayMs,
currentLsn,
expectedLsn,
sessionId,
});
await sleep(delayMs);
}
}
logger.warn(
'Replica did not catch up after max retries, falling back to primary',
{
sessionId,
expectedLsn,
},
);
return false;
}
/**
* Prisma extension for session-based read-after-write consistency.
*
* This extension tracks WAL LSN positions after writes and ensures that
* subsequent reads within the same session see those writes, even when
* using read replicas.
*
* How it works:
* 1. After any write operation with a session ID, it captures the WAL LSN
* 2. Before read operations with a session ID, it checks if the replica has caught up
* 3. If the replica hasn't caught up after retries, it forces the read to the primary
*
*/
export function sessionConsistency() {
return Prisma.defineExtension((client) =>
client.$extends({
name: 'session-consistency',
query: {
$allOperations: async ({
operation,
model,
args,
query,
// This is a hack to force reads to primary when replica hasn't caught up.
// The readReplicas extension routes queries to primary when in a transaction,
// so we set __internalParams.transaction = true to achieve this.
// @ts-expect-error - __internalParams is not in the types
__internalParams,
}) => {
const sessionId = getAlsSessionId();
// For write operations with session: cache WAL LSN after write
if (isWriteOperation(operation)) {
logger.info('Prisma operation', {
operation,
args,
model,
});
const result = await query(args);
if (sessionId) {
// Get current WAL LSN and cache it for this session
const lsn = await getCurrentWalLsn(client);
if (lsn) {
await cacheWalLsnForSession(sessionId, lsn);
logger.debug('Cached WAL LSN after write', {
sessionId,
lsn,
operation,
model,
});
}
}
return result;
}
// For now, we just force the read to the primary without checking the replica
// Since the check probably goes to the primary anyways it will always be true,
// Not sure how to check LSN on the actual replica that will be used for the read.
if (
isReadOperation(operation) &&
sessionId &&
(await getCachedWalLsn(sessionId))
) {
// This will force readReplicas extension to use primary
__internalParams.transaction = true;
}
return query(args);
},
},
}),
);
}

View File

@@ -0,0 +1,12 @@
import { AsyncLocalStorage } from 'node:async_hooks';
type Ctx = { sessionId: string | null };
export const als = new AsyncLocalStorage<Ctx>();
export const runWithAlsSession = <T>(
sid: string | null | undefined,
fn: () => Promise<T>,
) => als.run({ sessionId: sid || null }, fn);
export const getAlsSessionId = () => als.getStore()?.sessionId ?? null;

View File

@@ -1,5 +1,9 @@
import { getRedisCache } from './redis'; import { getRedisCache } from './redis';
export const deleteCache = async (key: string) => {
return getRedisCache().del(key);
};
export async function getCache<T>( export async function getCache<T>(
key: string, key: string,
expireInSec: number, expireInSec: number,

View File

@@ -1,93 +1,5 @@
import { db, getProjectById } from '@openpanel/db'; export {
import { cacheable } from '@openpanel/redis';
export const getProjectAccessCached = cacheable(getProjectAccess, 60 * 5);
export async function getProjectAccess({
userId,
projectId,
}: {
userId: string;
projectId: string;
}) {
try {
// Check if user has access to the project
const project = await getProjectById(projectId);
if (!project?.organizationId) {
return false;
}
const [projectAccess, member] = await Promise.all([
db.projectAccess.findMany({
where: {
userId,
organizationId: project.organizationId,
},
}),
db.member.findFirst({
where: {
organizationId: project.organizationId,
userId,
},
}),
]);
if (projectAccess.length === 0 && member) {
return true;
}
return projectAccess.find((item) => item.projectId === projectId);
} catch (err) {
return false;
}
}
export const getOrganizationAccessCached = cacheable(
getOrganizationAccess, getOrganizationAccess,
60 * 5, getProjectAccess,
); getClientAccess,
export async function getOrganizationAccess({ } from '@openpanel/db';
userId,
organizationId,
}: {
userId: string;
organizationId: string;
}) {
return db.member.findFirst({
where: {
userId,
organizationId,
},
});
}
export const getClientAccessCached = cacheable(getClientAccess, 60 * 5);
export async function getClientAccess({
userId,
clientId,
}: {
userId: string;
clientId: string;
}) {
const client = await db.client.findFirst({
where: {
id: clientId,
},
});
if (!client) {
return false;
}
if (client.projectId) {
return getProjectAccess({ userId, projectId: client.projectId });
}
if (client.organizationId) {
return getOrganizationAccess({
userId,
organizationId: client.organizationId,
});
}
return false;
}

View File

@@ -20,6 +20,7 @@ import {
getUserAccount, getUserAccount,
} from '@openpanel/db'; } from '@openpanel/db';
import { sendEmail } from '@openpanel/email'; import { sendEmail } from '@openpanel/email';
import { deleteCache } from '@openpanel/redis';
import { import {
zRequestResetPassword, zRequestResetPassword,
zResetPassword, zResetPassword,
@@ -74,6 +75,7 @@ export const authRouter = createTRPCRouter({
deleteSessionTokenCookie(ctx.setCookie); deleteSessionTokenCookie(ctx.setCookie);
if (ctx.session?.session?.id) { if (ctx.session?.session?.id) {
await invalidateSession(ctx.session.session.id); await invalidateSession(ctx.session.session.id);
await deleteCache(`validateSession:${ctx.session.session.id}`);
} }
}), }),
signInOAuth: publicProcedure signInOAuth: publicProcedure
@@ -333,6 +335,7 @@ export const authRouter = createTRPCRouter({
const session = await validateSessionToken(token); const session = await validateSessionToken(token);
if (session.session) { if (session.session) {
await deleteCache(`validateSession:${session.session.id}`);
// Re-set the cookie with updated expiration // Re-set the cookie with updated expiration
setSessionTokenCookie(ctx.setCookie, token, session.session.expiresAt); setSessionTokenCookie(ctx.setCookie, token, session.session.expiresAt);
return { return {

View File

@@ -32,7 +32,7 @@ import {
differenceInWeeks, differenceInWeeks,
formatISO, formatISO,
} from 'date-fns'; } from 'date-fns';
import { getProjectAccessCached } from '../access'; import { getProjectAccess } from '../access';
import { TRPCAccessError } from '../errors'; import { TRPCAccessError } from '../errors';
import { import {
cacheMiddleware, cacheMiddleware,
@@ -367,7 +367,7 @@ export const chartRouter = createTRPCRouter({
.input(zChartInput) .input(zChartInput)
.query(async ({ input, ctx }) => { .query(async ({ input, ctx }) => {
if (ctx.session.userId) { if (ctx.session.userId) {
const access = await getProjectAccessCached({ const access = await getProjectAccess({
projectId: input.projectId, projectId: input.projectId,
userId: ctx.session.userId, userId: ctx.session.userId,
}); });

View File

@@ -27,7 +27,7 @@ import {
} from '@openpanel/validation'; } from '@openpanel/validation';
import { clone } from 'ramda'; import { clone } from 'ramda';
import { getProjectAccessCached } from '../access'; import { getProjectAccess } from '../access';
import { TRPCAccessError } from '../errors'; import { TRPCAccessError } from '../errors';
import { createTRPCRouter, protectedProcedure, publicProcedure } from '../trpc'; import { createTRPCRouter, protectedProcedure, publicProcedure } from '../trpc';
@@ -266,7 +266,7 @@ export const eventRouter = createTRPCRouter({
) )
.query(async ({ input: { projectId, cursor, limit }, ctx }) => { .query(async ({ input: { projectId, cursor, limit }, ctx }) => {
if (ctx.session.userId) { if (ctx.session.userId) {
const access = await getProjectAccessCached({ const access = await getProjectAccess({
projectId, projectId,
userId: ctx.session.userId, userId: ctx.session.userId,
}); });

View File

@@ -9,7 +9,7 @@ import {
zCreateSlackIntegration, zCreateSlackIntegration,
zCreateWebhookIntegration, zCreateWebhookIntegration,
} from '@openpanel/validation'; } from '@openpanel/validation';
import { getOrganizationAccessCached } from '../access'; import { getOrganizationAccess } from '../access';
import { TRPCAccessError } from '../errors'; import { TRPCAccessError } from '../errors';
import { createTRPCRouter, protectedProcedure } from '../trpc'; import { createTRPCRouter, protectedProcedure } from '../trpc';
@@ -23,7 +23,7 @@ export const integrationRouter = createTRPCRouter({
}, },
}); });
const access = await getOrganizationAccessCached({ const access = await getOrganizationAccess({
userId: ctx.session.userId, userId: ctx.session.userId,
organizationId: integration.organizationId, organizationId: integration.organizationId,
}); });
@@ -122,7 +122,7 @@ export const integrationRouter = createTRPCRouter({
}, },
}); });
const access = await getOrganizationAccessCached({ const access = await getOrganizationAccess({
userId: ctx.session.userId, userId: ctx.session.userId,
organizationId: integration.organizationId, organizationId: integration.organizationId,
}); });

View File

@@ -4,18 +4,15 @@ import { has } from 'ramda';
import superjson from 'superjson'; import superjson from 'superjson';
import { ZodError } from 'zod'; import { ZodError } from 'zod';
import { import { COOKIE_OPTIONS, type SessionValidationResult } from '@openpanel/auth';
COOKIE_OPTIONS, import { runWithAlsSession } from '@openpanel/db';
EMPTY_SESSION, import { getRedisCache } from '@openpanel/redis';
validateSessionToken,
} from '@openpanel/auth';
import { getCache, getRedisCache } from '@openpanel/redis';
import type { ISetCookie } from '@openpanel/validation'; import type { ISetCookie } from '@openpanel/validation';
import { import {
createTrpcRedisLimiter, createTrpcRedisLimiter,
defaultFingerPrint, defaultFingerPrint,
} from '@trpc-limiter/redis'; } from '@trpc-limiter/redis';
import { getOrganizationAccessCached, getProjectAccessCached } from './access'; import { getOrganizationAccess, getProjectAccess } from './access';
import { TRPCAccessError } from './errors'; import { TRPCAccessError } from './errors';
export const rateLimitMiddleware = ({ export const rateLimitMiddleware = ({
@@ -44,10 +41,6 @@ export async function createContext({ req, res }: CreateFastifyContextOptions) {
}); });
}; };
const session = cookies?.session
? await validateSessionToken(cookies.session!)
: EMPTY_SESSION;
if (process.env.NODE_ENV !== 'production') { if (process.env.NODE_ENV !== 'production') {
await new Promise((res) => await new Promise((res) =>
setTimeout(() => res(1), Math.min(Math.random() * 500, 200)), setTimeout(() => res(1), Math.min(Math.random() * 500, 200)),
@@ -57,7 +50,7 @@ export async function createContext({ req, res }: CreateFastifyContextOptions) {
return { return {
req, req,
res, res,
session, session: (req as any).session as SessionValidationResult,
// we do not get types for `setCookie` from fastify // we do not get types for `setCookie` from fastify
// so define it here and be safe in routers // so define it here and be safe in routers
setCookie, setCookie,
@@ -102,6 +95,7 @@ const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
// Only used on protected routes // Only used on protected routes
const enforceAccess = t.middleware(async ({ ctx, next, type, getRawInput }) => { const enforceAccess = t.middleware(async ({ ctx, next, type, getRawInput }) => {
return runWithAlsSession(ctx.session.session?.id, async () => {
const rawInput = await getRawInput(); const rawInput = await getRawInput();
if (type === 'mutation' && process.env.DEMO_USER_ID) { if (type === 'mutation' && process.env.DEMO_USER_ID) {
throw new TRPCError({ throw new TRPCError({
@@ -111,7 +105,7 @@ const enforceAccess = t.middleware(async ({ ctx, next, type, getRawInput }) => {
} }
if (has('projectId', rawInput)) { if (has('projectId', rawInput)) {
const access = await getProjectAccessCached({ const access = await getProjectAccess({
userId: ctx.session.userId!, userId: ctx.session.userId!,
projectId: rawInput.projectId as string, projectId: rawInput.projectId as string,
}); });
@@ -122,7 +116,7 @@ const enforceAccess = t.middleware(async ({ ctx, next, type, getRawInput }) => {
} }
if (has('organizationId', rawInput)) { if (has('organizationId', rawInput)) {
const access = await getOrganizationAccessCached({ const access = await getOrganizationAccess({
userId: ctx.session.userId!, userId: ctx.session.userId!,
organizationId: rawInput.organizationId as string, organizationId: rawInput.organizationId as string,
}); });
@@ -134,6 +128,7 @@ const enforceAccess = t.middleware(async ({ ctx, next, type, getRawInput }) => {
return next(); return next();
}); });
});
export const createTRPCRouter = t.router; export const createTRPCRouter = t.router;
@@ -157,11 +152,21 @@ const loggerMiddleware = t.middleware(
}, },
); );
export const publicProcedure = t.procedure.use(loggerMiddleware); const sessionScopeMiddleware = t.middleware(async ({ ctx, next }) => {
const sessionId = ctx.session.session?.id ?? null;
return runWithAlsSession(sessionId, async () => {
return next();
});
});
export const publicProcedure = t.procedure
.use(loggerMiddleware)
.use(sessionScopeMiddleware);
export const protectedProcedure = t.procedure export const protectedProcedure = t.procedure
.use(enforceUserIsAuthed) .use(enforceUserIsAuthed)
.use(enforceAccess) .use(enforceAccess)
.use(loggerMiddleware); .use(loggerMiddleware)
.use(sessionScopeMiddleware);
const middlewareMarker = 'middlewareMarker' as 'middlewareMarker' & { const middlewareMarker = 'middlewareMarker' as 'middlewareMarker' & {
__brand: 'middlewareMarker'; __brand: 'middlewareMarker';