add better access control
This commit is contained in:
53
packages/trpc/src/access.ts
Normal file
53
packages/trpc/src/access.ts
Normal file
@@ -0,0 +1,53 @@
|
||||
import { clerkClient } from '@clerk/fastify';
|
||||
|
||||
import { getProjectById } from '@openpanel/db';
|
||||
import { cacheable } from '@openpanel/redis';
|
||||
|
||||
export const getProjectAccessCached = cacheable(getProjectAccess, 60 * 60);
|
||||
export async function getProjectAccess({
|
||||
userId,
|
||||
projectId,
|
||||
}: {
|
||||
userId: string;
|
||||
projectId: string;
|
||||
}) {
|
||||
try {
|
||||
// Check if user has access to the project
|
||||
const [project, organizations] = await Promise.all([
|
||||
getProjectById(projectId),
|
||||
clerkClient.users.getOrganizationMembershipList({
|
||||
userId,
|
||||
}),
|
||||
]);
|
||||
|
||||
if (!project) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !!organizations.data.find(
|
||||
(org) => org.organization.slug === project.organizationSlug
|
||||
);
|
||||
} catch (err) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export const getOrganizationAccessCached = cacheable(
|
||||
getOrganizationAccess,
|
||||
60 * 60
|
||||
);
|
||||
export async function getOrganizationAccess({
|
||||
userId,
|
||||
organizationId,
|
||||
}: {
|
||||
userId: string;
|
||||
organizationId: string;
|
||||
}) {
|
||||
const organizations = await clerkClient.users.getOrganizationMembershipList({
|
||||
userId,
|
||||
});
|
||||
|
||||
return !!organizations.data.find(
|
||||
(org) => org.organization.id === organizationId
|
||||
);
|
||||
}
|
||||
7
packages/trpc/src/errors.ts
Normal file
7
packages/trpc/src/errors.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
import { TRPCError } from '@trpc/server';
|
||||
|
||||
export const TRPCAccessError = (message: string) =>
|
||||
new TRPCError({
|
||||
code: 'UNAUTHORIZED',
|
||||
message,
|
||||
});
|
||||
@@ -3,10 +3,12 @@ import { escape } from 'sqlstring';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { average, max, min, round, slug, sum } from '@openpanel/common';
|
||||
import { chQuery, createSqlBuilder } from '@openpanel/db';
|
||||
import { chQuery, createSqlBuilder, db } from '@openpanel/db';
|
||||
import { zChartInput } from '@openpanel/validation';
|
||||
import type { IChartEvent, IChartInput } from '@openpanel/validation';
|
||||
|
||||
import { getProjectAccessCached } from '../access';
|
||||
import { TRPCAccessError } from '../errors';
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from '../trpc';
|
||||
import {
|
||||
getChartPrevStartEndDate,
|
||||
@@ -111,8 +113,7 @@ export const chartRouter = createTRPCRouter({
|
||||
)(properties);
|
||||
}),
|
||||
|
||||
// TODO: Make this private
|
||||
values: publicProcedure
|
||||
values: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
event: z.string(),
|
||||
@@ -154,7 +155,7 @@ export const chartRouter = createTRPCRouter({
|
||||
};
|
||||
}),
|
||||
|
||||
funnel: publicProcedure.input(zChartInput).query(async ({ input }) => {
|
||||
funnel: protectedProcedure.input(zChartInput).query(async ({ input }) => {
|
||||
const currentPeriod = getChartStartEndDate(input);
|
||||
const previousPeriod = getChartPrevStartEndDate({
|
||||
range: input.range,
|
||||
@@ -172,7 +173,7 @@ export const chartRouter = createTRPCRouter({
|
||||
};
|
||||
}),
|
||||
|
||||
funnelStep: publicProcedure
|
||||
funnelStep: protectedProcedure
|
||||
.input(
|
||||
zChartInput.extend({
|
||||
step: z.number(),
|
||||
@@ -183,8 +184,27 @@ export const chartRouter = createTRPCRouter({
|
||||
return getFunnelStep({ ...input, ...currentPeriod });
|
||||
}),
|
||||
|
||||
// TODO: Make this private
|
||||
chart: publicProcedure.input(zChartInput).query(async ({ input }) => {
|
||||
chart: publicProcedure.input(zChartInput).query(async ({ input, ctx }) => {
|
||||
if (ctx.session.userId) {
|
||||
const access = await getProjectAccessCached({
|
||||
projectId: input.projectId,
|
||||
userId: ctx.session.userId,
|
||||
});
|
||||
if (!access) {
|
||||
throw TRPCAccessError('You do not have access to this project');
|
||||
}
|
||||
} else {
|
||||
const share = await db.shareOverview.findFirst({
|
||||
where: {
|
||||
projectId: input.projectId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!share) {
|
||||
throw TRPCAccessError('You do not have access to this project');
|
||||
}
|
||||
}
|
||||
|
||||
const currentPeriod = getChartStartEndDate(input);
|
||||
const previousPeriod = getChartPrevStartEndDate({
|
||||
range: input.range,
|
||||
|
||||
@@ -3,6 +3,8 @@ import { z } from 'zod';
|
||||
|
||||
import { chQuery, convertClickhouseDateToJs, db } from '@openpanel/db';
|
||||
|
||||
import { getProjectAccessCached } from '../access';
|
||||
import { TRPCAccessError } from '../errors';
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from '../trpc';
|
||||
|
||||
export const eventRouter = createTRPCRouter({
|
||||
@@ -37,7 +39,27 @@ export const eventRouter = createTRPCRouter({
|
||||
limit: z.number().default(8),
|
||||
})
|
||||
)
|
||||
.query(async ({ input: { projectId, cursor, limit } }) => {
|
||||
.query(async ({ input: { projectId, cursor, limit }, ctx }) => {
|
||||
if (ctx.session.userId) {
|
||||
const access = await getProjectAccessCached({
|
||||
projectId,
|
||||
userId: ctx.session.userId,
|
||||
});
|
||||
if (!access) {
|
||||
throw TRPCAccessError('You do not have access to this project');
|
||||
}
|
||||
} else {
|
||||
const share = await db.shareOverview.findFirst({
|
||||
where: {
|
||||
projectId,
|
||||
},
|
||||
});
|
||||
|
||||
if (!share) {
|
||||
throw TRPCAccessError('You do not have access to this project');
|
||||
}
|
||||
}
|
||||
|
||||
const [events, counts] = await Promise.all([
|
||||
chQuery<{
|
||||
id: string;
|
||||
|
||||
@@ -4,7 +4,7 @@ import { z } from 'zod';
|
||||
|
||||
import { chQuery, createSqlBuilder } from '@openpanel/db';
|
||||
|
||||
import { createTRPCRouter, protectedProcedure, publicProcedure } from '../trpc';
|
||||
import { createTRPCRouter, protectedProcedure } from '../trpc';
|
||||
|
||||
export const profileRouter = createTRPCRouter({
|
||||
properties: protectedProcedure
|
||||
@@ -28,7 +28,7 @@ export const profileRouter = createTRPCRouter({
|
||||
)(properties);
|
||||
}),
|
||||
|
||||
values: publicProcedure
|
||||
values: protectedProcedure
|
||||
.input(
|
||||
z.object({
|
||||
property: z.string(),
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import { getAuth } from '@clerk/fastify';
|
||||
import { initTRPC, TRPCError } from '@trpc/server';
|
||||
import type { CreateFastifyContextOptions } from '@trpc/server/adapters/fastify';
|
||||
import { has } from 'ramda';
|
||||
import superjson from 'superjson';
|
||||
import { ZodError } from 'zod';
|
||||
|
||||
import { getProjectAccessCached } from './access';
|
||||
import { TRPCAccessError } from './errors';
|
||||
|
||||
export function createContext({ req, res }: CreateFastifyContextOptions) {
|
||||
return {
|
||||
req,
|
||||
@@ -41,10 +45,11 @@ const t = initTRPC.context<Context>().create({
|
||||
},
|
||||
});
|
||||
|
||||
const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
|
||||
const enforceUserIsAuthed = t.middleware(async ({ ctx, next, input }) => {
|
||||
if (!ctx.session?.userId) {
|
||||
throw new TRPCError({ code: 'UNAUTHORIZED', message: 'Not authenticated' });
|
||||
}
|
||||
|
||||
try {
|
||||
return next({
|
||||
ctx: {
|
||||
@@ -60,7 +65,25 @@ const enforceUserIsAuthed = t.middleware(async ({ ctx, next }) => {
|
||||
}
|
||||
});
|
||||
|
||||
// Only used on protected routes
|
||||
const enforceProjectAccess = t.middleware(async ({ ctx, next, rawInput }) => {
|
||||
if (has('projectId', rawInput)) {
|
||||
const access = await getProjectAccessCached({
|
||||
userId: ctx.session.userId!,
|
||||
projectId: rawInput.projectId as string,
|
||||
});
|
||||
|
||||
if (!access) {
|
||||
throw TRPCAccessError('You do not have access to this project');
|
||||
}
|
||||
}
|
||||
|
||||
return next();
|
||||
});
|
||||
|
||||
export const createTRPCRouter = t.router;
|
||||
|
||||
export const publicProcedure = t.procedure;
|
||||
export const protectedProcedure = t.procedure.use(enforceUserIsAuthed);
|
||||
export const protectedProcedure = t.procedure
|
||||
.use(enforceUserIsAuthed)
|
||||
.use(enforceProjectAccess);
|
||||
|
||||
Reference in New Issue
Block a user