fix: read-after-write issues (#215)
* fix: read-after-write issues * fix: coderabbit comments * fix: clear cache on invite * fix: use primary after a read
This commit is contained in:
committed by
GitHub
parent
abacf66155
commit
f454449365
@@ -9,7 +9,7 @@ import {
|
|||||||
} from '@/utils/ai-tools';
|
} from '@/utils/ai-tools';
|
||||||
import { HttpError } from '@/utils/errors';
|
import { HttpError } from '@/utils/errors';
|
||||||
import { db, getOrganizationByProjectIdCached } from '@openpanel/db';
|
import { db, getOrganizationByProjectIdCached } from '@openpanel/db';
|
||||||
import { getProjectAccessCached } from '@openpanel/trpc/src/access';
|
import { getProjectAccess } from '@openpanel/trpc/src/access';
|
||||||
import { type Message, appendResponseMessages, streamText } from 'ai';
|
import { type Message, appendResponseMessages, streamText } from 'ai';
|
||||||
import type { FastifyReply, FastifyRequest } from 'fastify';
|
import type { FastifyReply, FastifyRequest } from 'fastify';
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ export async function chat(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const organization = await getOrganizationByProjectIdCached(projectId);
|
const organization = await getOrganizationByProjectIdCached(projectId);
|
||||||
const access = await getProjectAccessCached({
|
const access = await getProjectAccess({
|
||||||
projectId,
|
projectId,
|
||||||
userId: session.userId,
|
userId: session.userId,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -113,6 +113,17 @@ export async function slackWebhook(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function clearOrganizationCache(organizationId: string) {
|
||||||
|
const projects = await db.project.findMany({
|
||||||
|
where: {
|
||||||
|
organizationId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
for (const project of projects) {
|
||||||
|
await getOrganizationByProjectIdCached.clear(project.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export async function polarWebhook(
|
export async function polarWebhook(
|
||||||
request: FastifyRequest<{
|
request: FastifyRequest<{
|
||||||
Querystring: unknown;
|
Querystring: unknown;
|
||||||
@@ -141,8 +152,11 @@ export async function polarWebhook(
|
|||||||
},
|
},
|
||||||
data: {
|
data: {
|
||||||
subscriptionPeriodEventsCount: 0,
|
subscriptionPeriodEventsCount: 0,
|
||||||
|
subscriptionPeriodEventsCountExceededAt: null,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
await clearOrganizationCache(metadata.organizationId);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -205,15 +219,7 @@ export async function polarWebhook(
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const projects = await db.project.findMany({
|
await clearOrganizationCache(metadata.organizationId);
|
||||||
where: {
|
|
||||||
organizationId: metadata.organizationId,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
for (const project of projects) {
|
|
||||||
await getOrganizationByProjectIdCached.clear(project.id);
|
|
||||||
}
|
|
||||||
|
|
||||||
await publishEvent('organization', 'subscription_updated', {
|
await publishEvent('organization', 'subscription_updated', {
|
||||||
organizationId: metadata.organizationId,
|
organizationId: metadata.organizationId,
|
||||||
|
|||||||
@@ -8,14 +8,18 @@ import Fastify from 'fastify';
|
|||||||
import metricsPlugin from 'fastify-metrics';
|
import metricsPlugin from 'fastify-metrics';
|
||||||
|
|
||||||
import { generateId } from '@openpanel/common';
|
import { generateId } from '@openpanel/common';
|
||||||
import type { IServiceClientWithProject } from '@openpanel/db';
|
import {
|
||||||
import { getRedisPub } from '@openpanel/redis';
|
type IServiceClientWithProject,
|
||||||
|
runWithAlsSession,
|
||||||
|
} from '@openpanel/db';
|
||||||
|
import { getCache, getRedisPub } from '@openpanel/redis';
|
||||||
import type { AppRouter } from '@openpanel/trpc';
|
import type { AppRouter } from '@openpanel/trpc';
|
||||||
import { appRouter, createContext } from '@openpanel/trpc';
|
import { appRouter, createContext } from '@openpanel/trpc';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
EMPTY_SESSION,
|
EMPTY_SESSION,
|
||||||
type SessionValidationResult,
|
type SessionValidationResult,
|
||||||
|
decodeSessionToken,
|
||||||
validateSessionToken,
|
validateSessionToken,
|
||||||
} from '@openpanel/auth';
|
} from '@openpanel/auth';
|
||||||
import sourceMapSupport from 'source-map-support';
|
import sourceMapSupport from 'source-map-support';
|
||||||
@@ -140,7 +144,14 @@ const startServer = async () => {
|
|||||||
instance.addHook('onRequest', async (req) => {
|
instance.addHook('onRequest', async (req) => {
|
||||||
if (req.cookies?.session) {
|
if (req.cookies?.session) {
|
||||||
try {
|
try {
|
||||||
const session = await validateSessionToken(req.cookies.session);
|
const sessionId = decodeSessionToken(req.cookies.session);
|
||||||
|
const session = await runWithAlsSession(sessionId, () =>
|
||||||
|
sessionId
|
||||||
|
? getCache(`validateSession:${sessionId}`, 60 * 5, async () =>
|
||||||
|
validateSessionToken(req.cookies.session),
|
||||||
|
)
|
||||||
|
: validateSessionToken(req.cookies.session),
|
||||||
|
);
|
||||||
if (session.session) {
|
if (session.session) {
|
||||||
req.session = session;
|
req.session = session;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,9 +24,11 @@ export function useSessionExtension() {
|
|||||||
1000 * 60 * 5,
|
1000 * 60 * 5,
|
||||||
);
|
);
|
||||||
|
|
||||||
extendSessionFn();
|
// Delay initial call a bit to prioritize other requests
|
||||||
|
const timer = setTimeout(() => extendSessionFn(), 5000);
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
|
clearTimeout(timer);
|
||||||
if (intervalRef.current) {
|
if (intervalRef.current) {
|
||||||
clearInterval(intervalRef.current);
|
clearInterval(intervalRef.current);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,8 +59,14 @@ export async function createDemoSession(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const decodeSessionToken = (token: string): string | null => {
|
||||||
|
return token
|
||||||
|
? encodeHexLowerCase(sha256(new TextEncoder().encode(token)))
|
||||||
|
: null;
|
||||||
|
};
|
||||||
|
|
||||||
export async function validateSessionToken(
|
export async function validateSessionToken(
|
||||||
token: string | null,
|
token: string | null | undefined,
|
||||||
): Promise<SessionValidationResult> {
|
): Promise<SessionValidationResult> {
|
||||||
if (process.env.DEMO_USER_ID) {
|
if (process.env.DEMO_USER_ID) {
|
||||||
return createDemoSession(process.env.DEMO_USER_ID);
|
return createDemoSession(process.env.DEMO_USER_ID);
|
||||||
@@ -69,7 +75,10 @@ export async function validateSessionToken(
|
|||||||
if (!token) {
|
if (!token) {
|
||||||
return EMPTY_SESSION;
|
return EMPTY_SESSION;
|
||||||
}
|
}
|
||||||
const sessionId = encodeHexLowerCase(sha256(new TextEncoder().encode(token)));
|
const sessionId = decodeSessionToken(token);
|
||||||
|
if (!sessionId) {
|
||||||
|
return EMPTY_SESSION;
|
||||||
|
}
|
||||||
const result = await db.$primary().session.findUnique({
|
const result = await db.$primary().session.findUnique({
|
||||||
where: {
|
where: {
|
||||||
id: sessionId,
|
id: sessionId,
|
||||||
|
|||||||
@@ -19,7 +19,9 @@ export * from './src/services/reference.service';
|
|||||||
export * from './src/services/id.service';
|
export * from './src/services/id.service';
|
||||||
export * from './src/services/retention.service';
|
export * from './src/services/retention.service';
|
||||||
export * from './src/services/notification.service';
|
export * from './src/services/notification.service';
|
||||||
|
export * from './src/services/access.service';
|
||||||
export * from './src/buffers';
|
export * from './src/buffers';
|
||||||
export * from './src/types';
|
export * from './src/types';
|
||||||
export * from './src/clickhouse/query-builder';
|
export * from './src/clickhouse/query-builder';
|
||||||
export * from './src/services/overview.service';
|
export * from './src/services/overview.service';
|
||||||
|
export * from './src/session-context';
|
||||||
|
|||||||
3
packages/db/src/logger.ts
Normal file
3
packages/db/src/logger.ts
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
import { createLogger } from '@openpanel/logger';
|
||||||
|
|
||||||
|
export const logger = createLogger({ name: 'db:prisma' });
|
||||||
@@ -1,11 +1,15 @@
|
|||||||
import { createLogger } from '@openpanel/logger';
|
import { createLogger } from '@openpanel/logger';
|
||||||
import { readReplicas } from '@prisma/extension-read-replicas';
|
import { readReplicas } from '@prisma/extension-read-replicas';
|
||||||
import { type Organization, PrismaClient } from './generated/prisma/client';
|
import {
|
||||||
|
type Organization,
|
||||||
|
Prisma,
|
||||||
|
PrismaClient,
|
||||||
|
} from './generated/prisma/client';
|
||||||
|
import { logger } from './logger';
|
||||||
|
import { sessionConsistency } from './session-consistency';
|
||||||
|
|
||||||
export * from './generated/prisma/client';
|
export * from './generated/prisma/client';
|
||||||
|
|
||||||
const logger = createLogger({ name: 'db' });
|
|
||||||
|
|
||||||
const isWillBeCanceled = (
|
const isWillBeCanceled = (
|
||||||
organization: Pick<
|
organization: Pick<
|
||||||
Organization,
|
Organization,
|
||||||
@@ -30,11 +34,6 @@ const getPrismaClient = () => {
|
|||||||
const prisma = new PrismaClient({
|
const prisma = new PrismaClient({
|
||||||
log: ['error'],
|
log: ['error'],
|
||||||
})
|
})
|
||||||
.$extends(
|
|
||||||
readReplicas({
|
|
||||||
url: process.env.DATABASE_URL_REPLICA ?? process.env.DATABASE_URL!,
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.$extends({
|
.$extends({
|
||||||
query: {
|
query: {
|
||||||
async $allOperations({ operation, model, args, query }) {
|
async $allOperations({ operation, model, args, query }) {
|
||||||
@@ -53,6 +52,8 @@ const getPrismaClient = () => {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
.$extends(sessionConsistency())
|
||||||
.$extends({
|
.$extends({
|
||||||
result: {
|
result: {
|
||||||
organization: {
|
organization: {
|
||||||
@@ -258,7 +259,12 @@ const getPrismaClient = () => {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
})
|
||||||
|
.$extends(
|
||||||
|
readReplicas({
|
||||||
|
url: process.env.DATABASE_URL_REPLICA ?? process.env.DATABASE_URL!,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
return prisma;
|
return prisma;
|
||||||
};
|
};
|
||||||
|
|||||||
96
packages/db/src/services/access.service.ts
Normal file
96
packages/db/src/services/access.service.ts
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import { cacheable } from '@openpanel/redis';
|
||||||
|
import { db } from '../prisma-client';
|
||||||
|
import { getProjectById } from './project.service';
|
||||||
|
|
||||||
|
export const getProjectAccess = cacheable(
|
||||||
|
'getProjectAccess',
|
||||||
|
async ({
|
||||||
|
userId,
|
||||||
|
projectId,
|
||||||
|
}: {
|
||||||
|
userId: string;
|
||||||
|
projectId: string;
|
||||||
|
}) => {
|
||||||
|
try {
|
||||||
|
// Check if user has access to the project
|
||||||
|
const project = await getProjectById(projectId);
|
||||||
|
if (!project?.organizationId) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const [projectAccess, member] = await Promise.all([
|
||||||
|
db.$primary().projectAccess.findMany({
|
||||||
|
where: {
|
||||||
|
userId,
|
||||||
|
organizationId: project.organizationId,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
db.$primary().member.findFirst({
|
||||||
|
where: {
|
||||||
|
organizationId: project.organizationId,
|
||||||
|
userId,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
]);
|
||||||
|
|
||||||
|
if (projectAccess.length === 0 && member) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return projectAccess.find((item) => item.projectId === projectId);
|
||||||
|
} catch (err) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
60 * 5,
|
||||||
|
);
|
||||||
|
|
||||||
|
export const getOrganizationAccess = cacheable(
|
||||||
|
'getOrganizationAccess',
|
||||||
|
async ({
|
||||||
|
userId,
|
||||||
|
organizationId,
|
||||||
|
}: {
|
||||||
|
userId: string;
|
||||||
|
organizationId: string;
|
||||||
|
}) => {
|
||||||
|
return db.$primary().member.findFirst({
|
||||||
|
where: {
|
||||||
|
userId,
|
||||||
|
organizationId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
},
|
||||||
|
60 * 5,
|
||||||
|
);
|
||||||
|
|
||||||
|
export async function getClientAccess({
|
||||||
|
userId,
|
||||||
|
clientId,
|
||||||
|
}: {
|
||||||
|
userId: string;
|
||||||
|
clientId: string;
|
||||||
|
}) {
|
||||||
|
const client = await db.client.findFirst({
|
||||||
|
where: {
|
||||||
|
id: clientId,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!client) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (client.projectId) {
|
||||||
|
return getProjectAccess({ userId, projectId: client.projectId });
|
||||||
|
}
|
||||||
|
|
||||||
|
if (client.organizationId) {
|
||||||
|
return getOrganizationAccess({
|
||||||
|
userId,
|
||||||
|
organizationId: client.organizationId,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
@@ -5,7 +5,8 @@ import { chQuery, formatClickhouseDate } from '../clickhouse/client';
|
|||||||
import type { Invite, Prisma, ProjectAccess, User } from '../prisma-client';
|
import type { Invite, Prisma, ProjectAccess, User } from '../prisma-client';
|
||||||
import { db } from '../prisma-client';
|
import { db } from '../prisma-client';
|
||||||
import { createSqlBuilder } from '../sql-builder';
|
import { createSqlBuilder } from '../sql-builder';
|
||||||
import type { IServiceProject } from './project.service';
|
import { getOrganizationAccess, getProjectAccess } from './access.service';
|
||||||
|
import { type IServiceProject, getProjectById } from './project.service';
|
||||||
export type IServiceOrganization = Awaited<
|
export type IServiceOrganization = Awaited<
|
||||||
ReturnType<typeof db.organization.findUniqueOrThrow>
|
ReturnType<typeof db.organization.findUniqueOrThrow>
|
||||||
>;
|
>;
|
||||||
@@ -61,7 +62,7 @@ export async function getOrganizationByProjectId(projectId: string) {
|
|||||||
|
|
||||||
export const getOrganizationByProjectIdCached = cacheable(
|
export const getOrganizationByProjectIdCached = cacheable(
|
||||||
getOrganizationByProjectId,
|
getOrganizationByProjectId,
|
||||||
60 * 60 * 24,
|
60 * 5,
|
||||||
);
|
);
|
||||||
|
|
||||||
export async function getInvites(organizationId: string) {
|
export async function getInvites(organizationId: string) {
|
||||||
@@ -168,8 +169,14 @@ export async function connectUserToOrganization({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
await getOrganizationAccess.clear({
|
||||||
|
userId: user.id,
|
||||||
|
organizationId: invite.organizationId,
|
||||||
|
});
|
||||||
|
|
||||||
if (invite.projectAccess.length > 0) {
|
if (invite.projectAccess.length > 0) {
|
||||||
for (const projectId of invite.projectAccess) {
|
for (const projectId of invite.projectAccess) {
|
||||||
|
await getProjectAccess.clear({ userId: user.id, projectId });
|
||||||
await db.projectAccess.create({
|
await db.projectAccess.create({
|
||||||
data: {
|
data: {
|
||||||
projectId,
|
projectId,
|
||||||
|
|||||||
225
packages/db/src/session-consistency.ts
Normal file
225
packages/db/src/session-consistency.ts
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
import { getRedisCache } from '@openpanel/redis';
|
||||||
|
import type { Operation } from '@prisma/client/runtime/client';
|
||||||
|
import { Prisma, type PrismaClient } from './generated/prisma/client';
|
||||||
|
import { logger } from './logger';
|
||||||
|
import { getAlsSessionId } from './session-context';
|
||||||
|
|
||||||
|
type BarePrismaClient = {
|
||||||
|
$queryRaw: <T>(query: TemplateStringsArray, ...args: unknown[]) => Promise<T>;
|
||||||
|
};
|
||||||
|
|
||||||
|
// WAL LSN tracking for read-after-write consistency
|
||||||
|
const LSN_CACHE_PREFIX = 'db:wal_lsn:';
|
||||||
|
const LSN_CACHE_TTL = 5;
|
||||||
|
const MAX_RETRY_ATTEMPTS = 5;
|
||||||
|
const INITIAL_RETRY_DELAY_MS = 10;
|
||||||
|
|
||||||
|
const READ_OPERATIONS: Operation[] = [
|
||||||
|
'findUnique',
|
||||||
|
'findUniqueOrThrow',
|
||||||
|
'findFirst',
|
||||||
|
'findFirstOrThrow',
|
||||||
|
'findMany',
|
||||||
|
'aggregate',
|
||||||
|
'groupBy',
|
||||||
|
'count',
|
||||||
|
];
|
||||||
|
|
||||||
|
const WRITE_OPERATIONS: Operation[] = [
|
||||||
|
'create',
|
||||||
|
'update',
|
||||||
|
'delete',
|
||||||
|
'createMany',
|
||||||
|
'createManyAndReturn',
|
||||||
|
'updateMany',
|
||||||
|
'deleteMany',
|
||||||
|
'upsert',
|
||||||
|
];
|
||||||
|
|
||||||
|
const isWriteOperation = (operation: string) =>
|
||||||
|
WRITE_OPERATIONS.includes(operation as Operation);
|
||||||
|
|
||||||
|
const isReadOperation = (operation: string) =>
|
||||||
|
READ_OPERATIONS.includes(operation as Operation);
|
||||||
|
|
||||||
|
async function getCurrentWalLsn(
|
||||||
|
prismaClient: BarePrismaClient,
|
||||||
|
): Promise<string | null> {
|
||||||
|
try {
|
||||||
|
const result = await prismaClient.$queryRaw<[{ lsn: string }]>`
|
||||||
|
SELECT pg_current_wal_lsn()::text AS lsn
|
||||||
|
`;
|
||||||
|
return result[0]?.lsn || null;
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to get WAL LSN', { error });
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function cacheWalLsnForSession(
|
||||||
|
sessionId: string,
|
||||||
|
lsn: string,
|
||||||
|
): Promise<void> {
|
||||||
|
try {
|
||||||
|
const redis = getRedisCache();
|
||||||
|
await redis.setex(`${LSN_CACHE_PREFIX}${sessionId}`, LSN_CACHE_TTL, lsn);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to cache WAL LSN', { error, sessionId });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function getCachedWalLsn(sessionId: string): Promise<string | null> {
|
||||||
|
try {
|
||||||
|
const redis = getRedisCache();
|
||||||
|
return await redis.get(`${LSN_CACHE_PREFIX}${sessionId}`);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to get cached WAL LSN', { error, sessionId });
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function compareWalLsn(lsn1: string, lsn2: string): number {
|
||||||
|
const [x1, y1] = lsn1.split('/').map((x) => BigInt(`0x${x}`));
|
||||||
|
const [x2, y2] = lsn2.split('/').map((x) => BigInt(`0x${x}`));
|
||||||
|
|
||||||
|
const v1 = ((x1 ?? 0n) << 32n) + (y1 ?? 0n);
|
||||||
|
const v2 = ((x2 ?? 0n) << 32n) + (y2 ?? 0n);
|
||||||
|
|
||||||
|
if (v1 < v2) return -1;
|
||||||
|
if (v1 > v2) return 1;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function sleep(ms: number): Promise<void> {
|
||||||
|
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Method not used for now,
|
||||||
|
// Need a way to check LSN on the actual replica that will be used for the read.
|
||||||
|
async function waitForReplicaCatchup(
|
||||||
|
prismaClient: BarePrismaClient,
|
||||||
|
sessionId: string,
|
||||||
|
): Promise<boolean> {
|
||||||
|
const expectedLsn = await getCachedWalLsn(sessionId);
|
||||||
|
|
||||||
|
if (!expectedLsn) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (let attempt = 0; attempt < MAX_RETRY_ATTEMPTS; attempt++) {
|
||||||
|
const currentLsn = await getCurrentWalLsn(prismaClient);
|
||||||
|
if (!currentLsn) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if replica has caught up (current >= expected)
|
||||||
|
if (compareWalLsn(currentLsn, expectedLsn) >= 0) {
|
||||||
|
logger.debug('Replica caught up', {
|
||||||
|
attempt: attempt + 1,
|
||||||
|
currentLsn,
|
||||||
|
expectedLsn,
|
||||||
|
sessionId,
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exponential backoff
|
||||||
|
if (attempt < MAX_RETRY_ATTEMPTS - 1) {
|
||||||
|
const delayMs = INITIAL_RETRY_DELAY_MS * 2 ** attempt;
|
||||||
|
logger.debug('Waiting for replica to catch up', {
|
||||||
|
attempt: attempt + 1,
|
||||||
|
delayMs,
|
||||||
|
currentLsn,
|
||||||
|
expectedLsn,
|
||||||
|
sessionId,
|
||||||
|
});
|
||||||
|
await sleep(delayMs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.warn(
|
||||||
|
'Replica did not catch up after max retries, falling back to primary',
|
||||||
|
{
|
||||||
|
sessionId,
|
||||||
|
expectedLsn,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prisma extension for session-based read-after-write consistency.
|
||||||
|
*
|
||||||
|
* This extension tracks WAL LSN positions after writes and ensures that
|
||||||
|
* subsequent reads within the same session see those writes, even when
|
||||||
|
* using read replicas.
|
||||||
|
*
|
||||||
|
* How it works:
|
||||||
|
* 1. After any write operation with a session ID, it captures the WAL LSN
|
||||||
|
* 2. Before read operations with a session ID, it checks if the replica has caught up
|
||||||
|
* 3. If the replica hasn't caught up after retries, it forces the read to the primary
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
export function sessionConsistency() {
|
||||||
|
return Prisma.defineExtension((client) =>
|
||||||
|
client.$extends({
|
||||||
|
name: 'session-consistency',
|
||||||
|
query: {
|
||||||
|
$allOperations: async ({
|
||||||
|
operation,
|
||||||
|
model,
|
||||||
|
args,
|
||||||
|
query,
|
||||||
|
// This is a hack to force reads to primary when replica hasn't caught up.
|
||||||
|
// The readReplicas extension routes queries to primary when in a transaction,
|
||||||
|
// so we set __internalParams.transaction = true to achieve this.
|
||||||
|
// @ts-expect-error - __internalParams is not in the types
|
||||||
|
__internalParams,
|
||||||
|
}) => {
|
||||||
|
const sessionId = getAlsSessionId();
|
||||||
|
|
||||||
|
// For write operations with session: cache WAL LSN after write
|
||||||
|
if (isWriteOperation(operation)) {
|
||||||
|
logger.info('Prisma operation', {
|
||||||
|
operation,
|
||||||
|
args,
|
||||||
|
model,
|
||||||
|
});
|
||||||
|
|
||||||
|
const result = await query(args);
|
||||||
|
|
||||||
|
if (sessionId) {
|
||||||
|
// Get current WAL LSN and cache it for this session
|
||||||
|
const lsn = await getCurrentWalLsn(client);
|
||||||
|
if (lsn) {
|
||||||
|
await cacheWalLsnForSession(sessionId, lsn);
|
||||||
|
logger.debug('Cached WAL LSN after write', {
|
||||||
|
sessionId,
|
||||||
|
lsn,
|
||||||
|
operation,
|
||||||
|
model,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For now, we just force the read to the primary without checking the replica
|
||||||
|
// Since the check probably goes to the primary anyways it will always be true,
|
||||||
|
// Not sure how to check LSN on the actual replica that will be used for the read.
|
||||||
|
if (
|
||||||
|
isReadOperation(operation) &&
|
||||||
|
sessionId &&
|
||||||
|
(await getCachedWalLsn(sessionId))
|
||||||
|
) {
|
||||||
|
// This will force readReplicas extension to use primary
|
||||||
|
__internalParams.transaction = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return query(args);
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
12
packages/db/src/session-context.ts
Normal file
12
packages/db/src/session-context.ts
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
import { AsyncLocalStorage } from 'node:async_hooks';
|
||||||
|
|
||||||
|
type Ctx = { sessionId: string | null };
|
||||||
|
|
||||||
|
export const als = new AsyncLocalStorage<Ctx>();
|
||||||
|
|
||||||
|
export const runWithAlsSession = <T>(
|
||||||
|
sid: string | null | undefined,
|
||||||
|
fn: () => Promise<T>,
|
||||||
|
) => als.run({ sessionId: sid || null }, fn);
|
||||||
|
|
||||||
|
export const getAlsSessionId = () => als.getStore()?.sessionId ?? null;
|
||||||
@@ -1,5 +1,9 @@
|
|||||||
import { getRedisCache } from './redis';
|
import { getRedisCache } from './redis';
|
||||||
|
|
||||||
|
export const deleteCache = async (key: string) => {
|
||||||
|
return getRedisCache().del(key);
|
||||||
|
};
|
||||||
|
|
||||||
export async function getCache<T>(
|
export async function getCache<T>(
|
||||||
key: string,
|
key: string,
|
||||||
expireInSec: number,
|
expireInSec: number,
|
||||||
|
|||||||
@@ -1,93 +1,5 @@
|
|||||||
import { db, getProjectById } from '@openpanel/db';
|
export {
|
||||||
import { cacheable } from '@openpanel/redis';
|
|
||||||
|
|
||||||
export const getProjectAccessCached = cacheable(getProjectAccess, 60 * 5);
|
|
||||||
export async function getProjectAccess({
|
|
||||||
userId,
|
|
||||||
projectId,
|
|
||||||
}: {
|
|
||||||
userId: string;
|
|
||||||
projectId: string;
|
|
||||||
}) {
|
|
||||||
try {
|
|
||||||
// Check if user has access to the project
|
|
||||||
const project = await getProjectById(projectId);
|
|
||||||
if (!project?.organizationId) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const [projectAccess, member] = await Promise.all([
|
|
||||||
db.projectAccess.findMany({
|
|
||||||
where: {
|
|
||||||
userId,
|
|
||||||
organizationId: project.organizationId,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
db.member.findFirst({
|
|
||||||
where: {
|
|
||||||
organizationId: project.organizationId,
|
|
||||||
userId,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
]);
|
|
||||||
|
|
||||||
if (projectAccess.length === 0 && member) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return projectAccess.find((item) => item.projectId === projectId);
|
|
||||||
} catch (err) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export const getOrganizationAccessCached = cacheable(
|
|
||||||
getOrganizationAccess,
|
getOrganizationAccess,
|
||||||
60 * 5,
|
getProjectAccess,
|
||||||
);
|
getClientAccess,
|
||||||
export async function getOrganizationAccess({
|
} from '@openpanel/db';
|
||||||
userId,
|
|
||||||
organizationId,
|
|
||||||
}: {
|
|
||||||
userId: string;
|
|
||||||
organizationId: string;
|
|
||||||
}) {
|
|
||||||
return db.member.findFirst({
|
|
||||||
where: {
|
|
||||||
userId,
|
|
||||||
organizationId,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
export const getClientAccessCached = cacheable(getClientAccess, 60 * 5);
|
|
||||||
export async function getClientAccess({
|
|
||||||
userId,
|
|
||||||
clientId,
|
|
||||||
}: {
|
|
||||||
userId: string;
|
|
||||||
clientId: string;
|
|
||||||
}) {
|
|
||||||
const client = await db.client.findFirst({
|
|
||||||
where: {
|
|
||||||
id: clientId,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!client) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (client.projectId) {
|
|
||||||
return getProjectAccess({ userId, projectId: client.projectId });
|
|
||||||
}
|
|
||||||
|
|
||||||
if (client.organizationId) {
|
|
||||||
return getOrganizationAccess({
|
|
||||||
userId,
|
|
||||||
organizationId: client.organizationId,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import {
|
|||||||
getUserAccount,
|
getUserAccount,
|
||||||
} from '@openpanel/db';
|
} from '@openpanel/db';
|
||||||
import { sendEmail } from '@openpanel/email';
|
import { sendEmail } from '@openpanel/email';
|
||||||
|
import { deleteCache } from '@openpanel/redis';
|
||||||
import {
|
import {
|
||||||
zRequestResetPassword,
|
zRequestResetPassword,
|
||||||
zResetPassword,
|
zResetPassword,
|
||||||
@@ -74,6 +75,7 @@ export const authRouter = createTRPCRouter({
|
|||||||
deleteSessionTokenCookie(ctx.setCookie);
|
deleteSessionTokenCookie(ctx.setCookie);
|
||||||
if (ctx.session?.session?.id) {
|
if (ctx.session?.session?.id) {
|
||||||
await invalidateSession(ctx.session.session.id);
|
await invalidateSession(ctx.session.session.id);
|
||||||
|
await deleteCache(`validateSession:${ctx.session.session.id}`);
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
signInOAuth: publicProcedure
|
signInOAuth: publicProcedure
|
||||||
@@ -333,6 +335,7 @@ export const authRouter = createTRPCRouter({
|
|||||||
const session = await validateSessionToken(token);
|
const session = await validateSessionToken(token);
|
||||||
|
|
||||||
if (session.session) {
|
if (session.session) {
|
||||||
|
await deleteCache(`validateSession:${session.session.id}`);
|
||||||
// Re-set the cookie with updated expiration
|
// Re-set the cookie with updated expiration
|
||||||
setSessionTokenCookie(ctx.setCookie, token, session.session.expiresAt);
|
setSessionTokenCookie(ctx.setCookie, token, session.session.expiresAt);
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import {
|
|||||||
differenceInWeeks,
|
differenceInWeeks,
|
||||||
formatISO,
|
formatISO,
|
||||||
} from 'date-fns';
|
} from 'date-fns';
|
||||||
import { getProjectAccessCached } from '../access';
|
import { getProjectAccess } from '../access';
|
||||||
import { TRPCAccessError } from '../errors';
|
import { TRPCAccessError } from '../errors';
|
||||||
import {
|
import {
|
||||||
cacheMiddleware,
|
cacheMiddleware,
|
||||||
@@ -367,7 +367,7 @@ export const chartRouter = createTRPCRouter({
|
|||||||
.input(zChartInput)
|
.input(zChartInput)
|
||||||
.query(async ({ input, ctx }) => {
|
.query(async ({ input, ctx }) => {
|
||||||
if (ctx.session.userId) {
|
if (ctx.session.userId) {
|
||||||
const access = await getProjectAccessCached({
|
const access = await getProjectAccess({
|
||||||
projectId: input.projectId,
|
projectId: input.projectId,
|
||||||
userId: ctx.session.userId,
|
userId: ctx.session.userId,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ import {
|
|||||||
} from '@openpanel/validation';
|
} from '@openpanel/validation';
|
||||||
|
|
||||||
import { clone } from 'ramda';
|
import { clone } from 'ramda';
|
||||||
import { getProjectAccessCached } from '../access';
|
import { getProjectAccess } from '../access';
|
||||||
import { TRPCAccessError } from '../errors';
|
import { TRPCAccessError } from '../errors';
|
||||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from '../trpc';
|
import { createTRPCRouter, protectedProcedure, publicProcedure } from '../trpc';
|
||||||
|
|
||||||
@@ -266,7 +266,7 @@ export const eventRouter = createTRPCRouter({
|
|||||||
)
|
)
|
||||||
.query(async ({ input: { projectId, cursor, limit }, ctx }) => {
|
.query(async ({ input: { projectId, cursor, limit }, ctx }) => {
|
||||||
if (ctx.session.userId) {
|
if (ctx.session.userId) {
|
||||||
const access = await getProjectAccessCached({
|
const access = await getProjectAccess({
|
||||||
projectId,
|
projectId,
|
||||||
userId: ctx.session.userId,
|
userId: ctx.session.userId,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import {
|
|||||||
zCreateSlackIntegration,
|
zCreateSlackIntegration,
|
||||||
zCreateWebhookIntegration,
|
zCreateWebhookIntegration,
|
||||||
} from '@openpanel/validation';
|
} from '@openpanel/validation';
|
||||||
import { getOrganizationAccessCached } from '../access';
|
import { getOrganizationAccess } from '../access';
|
||||||
import { TRPCAccessError } from '../errors';
|
import { TRPCAccessError } from '../errors';
|
||||||
import { createTRPCRouter, protectedProcedure } from '../trpc';
|
import { createTRPCRouter, protectedProcedure } from '../trpc';
|
||||||
|
|
||||||
@@ -23,7 +23,7 @@ export const integrationRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const access = await getOrganizationAccessCached({
|
const access = await getOrganizationAccess({
|
||||||
userId: ctx.session.userId,
|
userId: ctx.session.userId,
|
||||||
organizationId: integration.organizationId,
|
organizationId: integration.organizationId,
|
||||||
});
|
});
|
||||||
@@ -122,7 +122,7 @@ export const integrationRouter = createTRPCRouter({
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
const access = await getOrganizationAccessCached({
|
const access = await getOrganizationAccess({
|
||||||
userId: ctx.session.userId,
|
userId: ctx.session.userId,
|
||||||
organizationId: integration.organizationId,
|
organizationId: integration.organizationId,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -4,18 +4,15 @@ import { has } from 'ramda';
|
|||||||
import superjson from 'superjson';
|
import superjson from 'superjson';
|
||||||
import { ZodError } from 'zod';
|
import { ZodError } from 'zod';
|
||||||
|
|
||||||
import {
|
import { COOKIE_OPTIONS, type SessionValidationResult } from '@openpanel/auth';
|
||||||
COOKIE_OPTIONS,
|
import { runWithAlsSession } from '@openpanel/db';
|
||||||
EMPTY_SESSION,
|
import { getRedisCache } from '@openpanel/redis';
|
||||||
validateSessionToken,
|
|
||||||
} from '@openpanel/auth';
|
|
||||||
import { getCache, getRedisCache } from '@openpanel/redis';
|
|
||||||
import type { ISetCookie } from '@openpanel/validation';
|
import type { ISetCookie } from '@openpanel/validation';
|
||||||
import {
|
import {
|
||||||
createTrpcRedisLimiter,
|
createTrpcRedisLimiter,
|
||||||
defaultFingerPrint,
|
defaultFingerPrint,
|
||||||
} from '@trpc-limiter/redis';
|
} from '@trpc-limiter/redis';
|
||||||
import { getOrganizationAccessCached, getProjectAccessCached } from './access';
|
import { getOrganizationAccess, getProjectAccess } from './access';
|
||||||
import { TRPCAccessError } from './errors';
|
import { TRPCAccessError } from './errors';
|
||||||
|
|
||||||
export const rateLimitMiddleware = ({
|
export const rateLimitMiddleware = ({
|
||||||
@@ -44,10 +41,6 @@ export async function createContext({ req, res }: CreateFastifyContextOptions) {
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
const session = cookies?.session
|
|
||||||
? await validateSessionToken(cookies.session!)
|
|
||||||
: EMPTY_SESSION;
|
|
||||||
|
|
||||||
if (process.env.NODE_ENV !== 'production') {
|
if (process.env.NODE_ENV !== 'production') {
|
||||||
await new Promise((res) =>
|
await new Promise((res) =>
|
||||||
setTimeout(() => res(1), Math.min(Math.random() * 500, 200)),
|
setTimeout(() => res(1), Math.min(Math.random() * 500, 200)),
|
||||||
@@ -57,7 +50,7 @@ export async function createContext({ req, res }: CreateFastifyContextOptions) {
|
|||||||
return {
|
return {
|
||||||
req,
|
req,
|
||||||
res,
|
res,
|
||||||
session,
|
session: (req as any).session as SessionValidationResult,
|
||||||
// we do not get types for `setCookie` from fastify
|
// we do not get types for `setCookie` from fastify
|
||||||
// so define it here and be safe in routers
|
// so define it here and be safe in routers
|
||||||
setCookie,
|
setCookie,
|
||||||
@@ -102,6 +95,7 @@ const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
|
|||||||
|
|
||||||
// Only used on protected routes
|
// Only used on protected routes
|
||||||
const enforceAccess = t.middleware(async ({ ctx, next, type, getRawInput }) => {
|
const enforceAccess = t.middleware(async ({ ctx, next, type, getRawInput }) => {
|
||||||
|
return runWithAlsSession(ctx.session.session?.id, async () => {
|
||||||
const rawInput = await getRawInput();
|
const rawInput = await getRawInput();
|
||||||
if (type === 'mutation' && process.env.DEMO_USER_ID) {
|
if (type === 'mutation' && process.env.DEMO_USER_ID) {
|
||||||
throw new TRPCError({
|
throw new TRPCError({
|
||||||
@@ -111,7 +105,7 @@ const enforceAccess = t.middleware(async ({ ctx, next, type, getRawInput }) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (has('projectId', rawInput)) {
|
if (has('projectId', rawInput)) {
|
||||||
const access = await getProjectAccessCached({
|
const access = await getProjectAccess({
|
||||||
userId: ctx.session.userId!,
|
userId: ctx.session.userId!,
|
||||||
projectId: rawInput.projectId as string,
|
projectId: rawInput.projectId as string,
|
||||||
});
|
});
|
||||||
@@ -122,7 +116,7 @@ const enforceAccess = t.middleware(async ({ ctx, next, type, getRawInput }) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (has('organizationId', rawInput)) {
|
if (has('organizationId', rawInput)) {
|
||||||
const access = await getOrganizationAccessCached({
|
const access = await getOrganizationAccess({
|
||||||
userId: ctx.session.userId!,
|
userId: ctx.session.userId!,
|
||||||
organizationId: rawInput.organizationId as string,
|
organizationId: rawInput.organizationId as string,
|
||||||
});
|
});
|
||||||
@@ -134,6 +128,7 @@ const enforceAccess = t.middleware(async ({ ctx, next, type, getRawInput }) => {
|
|||||||
|
|
||||||
return next();
|
return next();
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
export const createTRPCRouter = t.router;
|
export const createTRPCRouter = t.router;
|
||||||
|
|
||||||
@@ -157,11 +152,21 @@ const loggerMiddleware = t.middleware(
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
export const publicProcedure = t.procedure.use(loggerMiddleware);
|
const sessionScopeMiddleware = t.middleware(async ({ ctx, next }) => {
|
||||||
|
const sessionId = ctx.session.session?.id ?? null;
|
||||||
|
return runWithAlsSession(sessionId, async () => {
|
||||||
|
return next();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
export const publicProcedure = t.procedure
|
||||||
|
.use(loggerMiddleware)
|
||||||
|
.use(sessionScopeMiddleware);
|
||||||
export const protectedProcedure = t.procedure
|
export const protectedProcedure = t.procedure
|
||||||
.use(enforceUserIsAuthed)
|
.use(enforceUserIsAuthed)
|
||||||
.use(enforceAccess)
|
.use(enforceAccess)
|
||||||
.use(loggerMiddleware);
|
.use(loggerMiddleware)
|
||||||
|
.use(sessionScopeMiddleware);
|
||||||
|
|
||||||
const middlewareMarker = 'middlewareMarker' as 'middlewareMarker' & {
|
const middlewareMarker = 'middlewareMarker' as 'middlewareMarker' & {
|
||||||
__brand: 'middlewareMarker';
|
__brand: 'middlewareMarker';
|
||||||
|
|||||||
Reference in New Issue
Block a user