fix: read-after-write issues (#215)
* fix: read-after-write issues * fix: coderabbit comments * fix: clear cache on invite * fix: use primary after a read
This commit is contained in:
committed by
GitHub
parent
abacf66155
commit
f454449365
3
packages/db/src/logger.ts
Normal file
3
packages/db/src/logger.ts
Normal file
@@ -0,0 +1,3 @@
|
||||
import { createLogger } from '@openpanel/logger';
|
||||
|
||||
// Logger shared by the db package's Prisma code; created with the name 'db:prisma'.
export const logger = createLogger({ name: 'db:prisma' });
|
||||
@@ -1,11 +1,15 @@
|
||||
import { createLogger } from '@openpanel/logger';
|
||||
import { readReplicas } from '@prisma/extension-read-replicas';
|
||||
import { type Organization, PrismaClient } from './generated/prisma/client';
|
||||
import {
|
||||
type Organization,
|
||||
Prisma,
|
||||
PrismaClient,
|
||||
} from './generated/prisma/client';
|
||||
import { logger } from './logger';
|
||||
import { sessionConsistency } from './session-consistency';
|
||||
|
||||
export * from './generated/prisma/client';
|
||||
|
||||
const logger = createLogger({ name: 'db' });
|
||||
|
||||
const isWillBeCanceled = (
|
||||
organization: Pick<
|
||||
Organization,
|
||||
@@ -30,11 +34,6 @@ const getPrismaClient = () => {
|
||||
const prisma = new PrismaClient({
|
||||
log: ['error'],
|
||||
})
|
||||
.$extends(
|
||||
readReplicas({
|
||||
url: process.env.DATABASE_URL_REPLICA ?? process.env.DATABASE_URL!,
|
||||
}),
|
||||
)
|
||||
.$extends({
|
||||
query: {
|
||||
async $allOperations({ operation, model, args, query }) {
|
||||
@@ -53,6 +52,8 @@ const getPrismaClient = () => {
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
.$extends(sessionConsistency())
|
||||
.$extends({
|
||||
result: {
|
||||
organization: {
|
||||
@@ -258,7 +259,12 @@ const getPrismaClient = () => {
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
})
|
||||
.$extends(
|
||||
readReplicas({
|
||||
url: process.env.DATABASE_URL_REPLICA ?? process.env.DATABASE_URL!,
|
||||
}),
|
||||
);
|
||||
|
||||
return prisma;
|
||||
};
|
||||
|
||||
96
packages/db/src/services/access.service.ts
Normal file
96
packages/db/src/services/access.service.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import { cacheable } from '@openpanel/redis';
|
||||
import { db } from '../prisma-client';
|
||||
import { getProjectById } from './project.service';
|
||||
|
||||
/**
 * Cached lookup of a user's access to a single project.
 *
 * Resolution order:
 * - `false` when the project cannot be loaded or has no organization id;
 * - `true` when the user is a member of the project's organization and has no
 *   project-access rows at all in that organization;
 * - otherwise the project-access row matching `projectId`, or `undefined`
 *   when the user's rows cover other projects only;
 * - `false` when anything in the lookup throws.
 *
 * Both database reads go through `db.$primary()`.
 * The trailing `60 * 5` is handed to `cacheable` (presumably a TTL in seconds — confirm).
 */
export const getProjectAccess = cacheable(
  'getProjectAccess',
  async (params: { userId: string; projectId: string }) => {
    const { userId, projectId } = params;

    try {
      const project = await getProjectById(projectId);
      const organizationId = project?.organizationId;
      if (!organizationId) {
        return false;
      }

      // Independent reads: issue them together.
      const accessRowsQuery = db.$primary().projectAccess.findMany({
        where: { userId, organizationId },
      });
      const membershipQuery = db.$primary().member.findFirst({
        where: { organizationId, userId },
      });
      const [accessRows, membership] = await Promise.all([
        accessRowsQuery,
        membershipQuery,
      ]);

      // A member without any explicit per-project rows can access the project.
      const hasExplicitRows = accessRows.length !== 0;
      if (!hasExplicitRows && membership) {
        return true;
      }

      return accessRows.find((row) => row.projectId === projectId);
    } catch {
      return false;
    }
  },
  60 * 5,
);
|
||||
|
||||
/**
 * Cached lookup of a user's membership in an organization.
 *
 * Resolves to the first `member` row matching both ids, read through
 * `db.$primary()`, or `null` when there is none.
 * The trailing `60 * 5` is handed to `cacheable` (presumably a TTL in seconds — confirm).
 */
export const getOrganizationAccess = cacheable(
  'getOrganizationAccess',
  async (params: { userId: string; organizationId: string }) => {
    const { userId, organizationId } = params;
    const membership = await db.$primary().member.findFirst({
      where: { userId, organizationId },
    });
    return membership;
  },
  60 * 5,
);
|
||||
|
||||
/**
 * Resolves a user's access to a client by delegating to the access check of
 * whatever the client belongs to.
 *
 * - Unknown client: `false`.
 * - Client with a project id: result of `getProjectAccess`.
 * - Client with only an organization id: result of `getOrganizationAccess`.
 * - Client with neither: `false`.
 */
export async function getClientAccess({
  userId,
  clientId,
}: {
  userId: string;
  clientId: string;
}) {
  const client = await db.client.findFirst({ where: { id: clientId } });
  if (!client) {
    return false;
  }

  const { projectId, organizationId } = client;

  // The project check takes precedence over the organization check.
  if (projectId) {
    return getProjectAccess({ userId, projectId });
  }

  if (organizationId) {
    return getOrganizationAccess({ userId, organizationId });
  }

  return false;
}
|
||||
@@ -5,7 +5,8 @@ import { chQuery, formatClickhouseDate } from '../clickhouse/client';
|
||||
import type { Invite, Prisma, ProjectAccess, User } from '../prisma-client';
|
||||
import { db } from '../prisma-client';
|
||||
import { createSqlBuilder } from '../sql-builder';
|
||||
import type { IServiceProject } from './project.service';
|
||||
import { getOrganizationAccess, getProjectAccess } from './access.service';
|
||||
import { type IServiceProject, getProjectById } from './project.service';
|
||||
export type IServiceOrganization = Awaited<
|
||||
ReturnType<typeof db.organization.findUniqueOrThrow>
|
||||
>;
|
||||
@@ -61,7 +62,7 @@ export async function getOrganizationByProjectId(projectId: string) {
|
||||
|
||||
export const getOrganizationByProjectIdCached = cacheable(
|
||||
getOrganizationByProjectId,
|
||||
60 * 60 * 24,
|
||||
60 * 5,
|
||||
);
|
||||
|
||||
export async function getInvites(organizationId: string) {
|
||||
@@ -168,8 +169,14 @@ export async function connectUserToOrganization({
|
||||
},
|
||||
});
|
||||
|
||||
await getOrganizationAccess.clear({
|
||||
userId: user.id,
|
||||
organizationId: invite.organizationId,
|
||||
});
|
||||
|
||||
if (invite.projectAccess.length > 0) {
|
||||
for (const projectId of invite.projectAccess) {
|
||||
await getProjectAccess.clear({ userId: user.id, projectId });
|
||||
await db.projectAccess.create({
|
||||
data: {
|
||||
projectId,
|
||||
|
||||
225
packages/db/src/session-consistency.ts
Normal file
225
packages/db/src/session-consistency.ts
Normal file
@@ -0,0 +1,225 @@
|
||||
import { getRedisCache } from '@openpanel/redis';
|
||||
import type { Operation } from '@prisma/client/runtime/client';
|
||||
import { Prisma, type PrismaClient } from './generated/prisma/client';
|
||||
import { logger } from './logger';
|
||||
import { getAlsSessionId } from './session-context';
|
||||
|
||||
/**
 * Minimal client surface needed by the LSN helpers: only the tagged-template
 * `$queryRaw` call is used, so anything exposing it can be passed in.
 */
interface BarePrismaClient {
  $queryRaw: <T>(query: TemplateStringsArray, ...args: unknown[]) => Promise<T>;
}
|
||||
|
||||
// WAL LSN tracking for read-after-write consistency
|
||||
// Redis key prefix; the session id is appended to form the per-session key.
const LSN_CACHE_PREFIX = 'db:wal_lsn:';
|
||||
// Expiry passed to redis.setex when caching a session's LSN (SETEX expiry is in seconds).
const LSN_CACHE_TTL = 5;
|
||||
// Upper bound on the number of LSN checks waitForReplicaCatchup performs before giving up.
const MAX_RETRY_ATTEMPTS = 5;
|
||||
// Wait after the first failed check; waitForReplicaCatchup doubles it on each further attempt.
const INITIAL_RETRY_DELAY_MS = 10;
|
||||
|
||||
/** Prisma operations that `isReadOperation` classifies as reads. */
const READ_OPERATIONS: Operation[] = [
  // Single-record lookups
  'findUnique',
  'findUniqueOrThrow',
  'findFirst',
  'findFirstOrThrow',
  // Multi-record lookup
  'findMany',
  // Aggregations
  'aggregate',
  'groupBy',
  'count',
];
|
||||
|
||||
/**
 * Prisma operations that `isWriteOperation` classifies as writes.
 *
 * Every model-level mutating operation must be listed here: a write that is
 * missing from this list never records a WAL LSN for the session, so reads
 * issued right after it are not forced to the primary.
 * `updateManyAndReturn` is included alongside `createManyAndReturn` for that
 * reason.
 */
const WRITE_OPERATIONS: Operation[] = [
  'create',
  'update',
  'delete',
  'createMany',
  'createManyAndReturn',
  'updateMany',
  'updateManyAndReturn',
  'deleteMany',
  'upsert',
];
|
||||
|
||||
/** Returns true when `operation` is one of the names in `WRITE_OPERATIONS`. */
function isWriteOperation(operation: string): boolean {
  return WRITE_OPERATIONS.some((candidate) => candidate === operation);
}
|
||||
|
||||
/** Returns true when `operation` is one of the names in `READ_OPERATIONS`. */
function isReadOperation(operation: string): boolean {
  return READ_OPERATIONS.some((candidate) => candidate === operation);
}
|
||||
|
||||
/**
 * Asks the database for its current WAL LSN as text.
 *
 * @param prismaClient - Client whose `$queryRaw` runs the query.
 * @returns The LSN string, or `null` when the row/value is missing or empty
 *          or when the query throws (the error is logged, not rethrown).
 */
async function getCurrentWalLsn(
  prismaClient: BarePrismaClient,
): Promise<string | null> {
  try {
    const rows = await prismaClient.$queryRaw<[{ lsn: string }]>`
      SELECT pg_current_wal_lsn()::text AS lsn
    `;
    const lsn = rows[0]?.lsn;
    return lsn ? lsn : null;
  } catch (error) {
    logger.error('Failed to get WAL LSN', { error });
    return null;
  }
}
|
||||
|
||||
/**
 * Stores `lsn` in Redis under the session's key with an expiry of
 * `LSN_CACHE_TTL`.
 *
 * Best effort: a Redis failure is logged and swallowed so the write that
 * triggered the caching is not affected.
 */
async function cacheWalLsnForSession(
  sessionId: string,
  lsn: string,
): Promise<void> {
  const key = `${LSN_CACHE_PREFIX}${sessionId}`;

  try {
    await getRedisCache().setex(key, LSN_CACHE_TTL, lsn);
  } catch (error) {
    logger.error('Failed to cache WAL LSN', { error, sessionId });
  }
}
|
||||
|
||||
/**
 * Reads the LSN cached for a session.
 *
 * @returns Whatever Redis holds under the session's key (`null` when the key
 *          is absent), or `null` when the Redis call throws (the error is
 *          logged, not rethrown).
 */
async function getCachedWalLsn(sessionId: string): Promise<string | null> {
  const key = `${LSN_CACHE_PREFIX}${sessionId}`;

  try {
    const cached = await getRedisCache().get(key);
    return cached;
  } catch (error) {
    logger.error('Failed to get cached WAL LSN', { error, sessionId });
    return null;
  }
}
|
||||
|
||||
/**
 * Three-way comparison of two WAL LSNs written as `HIGH/LOW` hexadecimal.
 *
 * Each LSN is split on `/`, every part is parsed as hex, and the first two
 * parts are combined as `(high << 32) + low`; a missing part counts as 0.
 *
 * @returns -1 when `lsn1` is behind `lsn2`, 1 when it is ahead, 0 when equal.
 * @throws SyntaxError (from `BigInt`) when a part is not valid hexadecimal.
 */
function compareWalLsn(lsn1: string, lsn2: string): number {
  const toPosition = (lsn: string): bigint => {
    const parts = lsn.split('/').map((hex) => BigInt(`0x${hex}`));
    const high = parts[0] ?? 0n;
    const low = parts[1] ?? 0n;
    return (high << 32n) + low;
  };

  const left = toPosition(lsn1);
  const right = toPosition(lsn2);

  if (left === right) {
    return 0;
  }
  return left < right ? -1 : 1;
}
|
||||
|
||||
/** Resolves (with no value) once a `setTimeout` of `ms` milliseconds fires. */
async function sleep(ms: number): Promise<void> {
  await new Promise<void>((done) => {
    setTimeout(done, ms);
  });
}
|
||||
|
||||
/**
 * Polls the database until its WAL position reaches the LSN cached for the
 * session, backing off exponentially between checks.
 *
 * Not used for now: there is no way yet to run the LSN check on the actual
 * replica that will serve the read.
 * NOTE(review): per the PostgreSQL docs `pg_current_wal_lsn()` cannot be
 * executed on a standby in recovery; `pg_last_wal_replay_lsn()` is the
 * standby-side function — confirm before wiring this in.
 *
 * @param prismaClient - Client used to query the current LSN.
 * @param sessionId - Session whose cached LSN is the target.
 * @returns `true` when no LSN is cached for the session or the current LSN
 *          has reached it; `false` when the caller should fall back to the
 *          primary — either the target was not reached within
 *          `MAX_RETRY_ATTEMPTS` checks, or the current LSN could not be
 *          determined at all.
 */
async function waitForReplicaCatchup(
  prismaClient: BarePrismaClient,
  sessionId: string,
): Promise<boolean> {
  const expectedLsn = await getCachedWalLsn(sessionId);

  // Nothing recorded for this session: no recent write to wait for.
  if (!expectedLsn) {
    return true;
  }

  for (let attempt = 0; attempt < MAX_RETRY_ATTEMPTS; attempt++) {
    const currentLsn = await getCurrentWalLsn(prismaClient);

    // A recent write is known to exist but the position cannot be verified.
    // Reporting "caught up" here could serve stale data, so take the safe
    // route and tell the caller to use the primary.
    if (!currentLsn) {
      logger.warn(
        'Could not determine current WAL LSN, falling back to primary',
        {
          sessionId,
          expectedLsn,
        },
      );
      return false;
    }

    // Check if replica has caught up (current >= expected)
    if (compareWalLsn(currentLsn, expectedLsn) >= 0) {
      logger.debug('Replica caught up', {
        attempt: attempt + 1,
        currentLsn,
        expectedLsn,
        sessionId,
      });
      return true;
    }

    // Exponential backoff; no sleep after the final check.
    if (attempt < MAX_RETRY_ATTEMPTS - 1) {
      const delayMs = INITIAL_RETRY_DELAY_MS * 2 ** attempt;
      logger.debug('Waiting for replica to catch up', {
        attempt: attempt + 1,
        delayMs,
        currentLsn,
        expectedLsn,
        sessionId,
      });
      await sleep(delayMs);
    }
  }

  logger.warn(
    'Replica did not catch up after max retries, falling back to primary',
    {
      sessionId,
      expectedLsn,
    },
  );
  return false;
}
|
||||
|
||||
/**
 * Prisma extension for session-based read-after-write consistency.
 *
 * The session id comes from `getAlsSessionId()`.
 *
 * What the extension does:
 * 1. Write operations are logged, executed, and — when a session id is
 *    present — followed by capturing the current WAL LSN and caching it for
 *    that session.
 * 2. Read operations issued with a session id that still has a cached LSN are
 *    forced to the primary. No replica catch-up check is performed: the check
 *    cannot currently be run against the replica that would serve the read.
 * 3. Everything else is passed through unchanged.
 */
export function sessionConsistency() {
  return Prisma.defineExtension((client) =>
    client.$extends({
      name: 'session-consistency',
      query: {
        async $allOperations({
          operation,
          model,
          args,
          query,
          // Hack used to force reads to the primary: the readReplicas
          // extension routes queries to the primary when they are part of a
          // transaction, so __internalParams.transaction is set to true below.
          // @ts-expect-error - __internalParams is not in the types
          __internalParams,
        }) {
          const sessionId = getAlsSessionId();

          if (!isWriteOperation(operation)) {
            // Read inside a session that wrote recently (its LSN is still
            // cached): skip the replica entirely.
            const hasRecentWrite =
              isReadOperation(operation) &&
              sessionId &&
              (await getCachedWalLsn(sessionId));

            if (hasRecentWrite) {
              // This will force readReplicas extension to use primary
              __internalParams.transaction = true;
            }

            return query(args);
          }

          logger.info('Prisma operation', {
            operation,
            args,
            model,
          });

          const result = await query(args);

          // Without a session there is nobody to keep consistent.
          if (!sessionId) {
            return result;
          }

          // Remember how far the WAL had advanced once this write completed.
          const lsn = await getCurrentWalLsn(client);
          if (!lsn) {
            return result;
          }

          await cacheWalLsnForSession(sessionId, lsn);
          logger.debug('Cached WAL LSN after write', {
            sessionId,
            lsn,
            operation,
            model,
          });

          return result;
        },
      },
    }),
  );
}
|
||||
12
packages/db/src/session-context.ts
Normal file
12
packages/db/src/session-context.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
import { AsyncLocalStorage } from 'node:async_hooks';
|
||||
|
||||
/** Per-call-chain context carried through async work. */
interface Ctx {
  sessionId: string | null;
}

/** Async-local store holding the current session context. */
export const als = new AsyncLocalStorage<Ctx>();

/**
 * Runs `fn` with `sid` as the ambient session id.
 *
 * Any falsy `sid` (`null`, `undefined`, empty string) is stored as `null`.
 *
 * @returns Whatever `fn` returns.
 */
export function runWithAlsSession<T>(
  sid: string | null | undefined,
  fn: () => Promise<T>,
): Promise<T> {
  const context: Ctx = { sessionId: sid || null };
  return als.run(context, fn);
}

/** Returns the ambient session id, or `null` outside of `runWithAlsSession`. */
export function getAlsSessionId(): string | null {
  const store = als.getStore();
  return store?.sessionId ?? null;
}
|
||||
Reference in New Issue
Block a user