Skip to content

Commit d910a7a

Browse files
committed
feat: support client roots feature
1 parent dbddb2e commit d910a7a

16 files changed

Lines changed: 263 additions & 37 deletions

src/McpContext.ts

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import fs from 'node:fs/promises';
88
import path from 'node:path';
9+
import {fileURLToPath} from 'node:url';
910

1011
import type {TargetUniverse} from './DevtoolsUtils.js';
1112
import {UniverseManager} from './DevtoolsUtils.js';
@@ -18,21 +19,22 @@ import {
1819
type ListenerMap,
1920
type UncaughtError,
2021
} from './PageCollector.js';
21-
import type {
22-
Browser,
23-
BrowserContext,
24-
ConsoleMessage,
25-
Debugger,
26-
HTTPRequest,
27-
Page,
28-
ScreenRecorder,
29-
Viewport,
30-
Target,
31-
Extension,
22+
import {
23+
Locator,
24+
PredefinedNetworkConditions,
25+
type Browser,
26+
type BrowserContext,
27+
type ConsoleMessage,
28+
type Debugger,
29+
type HTTPRequest,
30+
type Page,
31+
type ScreenRecorder,
32+
type Viewport,
33+
type Target,
34+
type Extension,
35+
type Root,
36+
type DevTools,
3237
} from './third_party/index.js';
33-
import type {DevTools} from './third_party/index.js';
34-
import {Locator} from './third_party/index.js';
35-
import {PredefinedNetworkConditions} from './third_party/index.js';
3638
import {listPages} from './tools/pages.js';
3739
import {CLOSE_PAGE_ERROR} from './tools/ToolDefinition.js';
3840
import type {Context, SupportedExtensions} from './tools/ToolDefinition.js';
@@ -42,7 +44,7 @@ import type {
4244
GeolocationOptions,
4345
ExtensionServiceWorker,
4446
} from './types.js';
45-
import {ensureExtension, saveTemporaryFile} from './utils/files.js';
47+
import {ensureExtension, getTempFilePath} from './utils/files.js';
4648
import {getNetworkMultiplierFromString} from './WaitForHelper.js';
4749

4850
interface McpContextOptions {
@@ -90,6 +92,7 @@ export class McpContext implements Context {
9092
#locatorClass: typeof Locator;
9193
#options: McpContextOptions;
9294
#heapSnapshotManager = new HeapSnapshotManager();
95+
#roots: Root[] | undefined = undefined;
9396

9497
private constructor(
9598
browser: Browser,
@@ -154,6 +157,34 @@ export class McpContext implements Context {
154157
return context;
155158
}
156159

160+
roots(): Root[] | undefined {
161+
return this.#roots;
162+
}
163+
164+
setRoots(roots: Root[] | undefined): void {
165+
this.#roots = roots;
166+
}
167+
168+
validatePath(filePath: string): void {
169+
const roots = this.roots();
170+
if (roots === undefined) {
171+
return;
172+
}
173+
const absolutePath = path.resolve(filePath);
174+
for (const root of roots) {
175+
const rootPath = path.resolve(fileURLToPath(root.uri));
176+
if (
177+
absolutePath === rootPath ||
178+
absolutePath.startsWith(rootPath + path.sep)
179+
) {
180+
return;
181+
}
182+
}
183+
throw new Error(
184+
`Access denied: path ${filePath} is not within any of the workspace roots ${JSON.stringify(roots)}.`,
185+
);
186+
}
187+
157188
resolveCdpRequestId(page: McpPage, cdpRequestId: string): number | undefined {
158189
if (!cdpRequestId) {
159190
this.logger('no network request');
@@ -643,13 +674,18 @@ export class McpContext implements Context {
643674
data: Uint8Array<ArrayBufferLike>,
644675
filename: string,
645676
): Promise<{filepath: string}> {
646-
return await saveTemporaryFile(data, filename);
677+
const filepath = await getTempFilePath(filename);
678+
this.validatePath(filepath);
679+
await fs.writeFile(filepath, data);
680+
return {filepath};
647681
}
682+
648683
async saveFile(
649684
data: Uint8Array<ArrayBufferLike>,
650685
clientProvidedFilePath: string,
651686
extension: SupportedExtensions,
652687
): Promise<{filename: string}> {
688+
this.validatePath(clientProvidedFilePath);
653689
try {
654690
const filePath = ensureExtension(
655691
path.resolve(clientProvidedFilePath),
@@ -721,6 +757,7 @@ export class McpContext implements Context {
721757
}
722758

723759
async installExtension(extensionPath: string): Promise<string> {
760+
this.validatePath(extensionPath);
724761
const id = await this.browser.installExtension(extensionPath);
725762
return id;
726763
}
@@ -751,25 +788,29 @@ export class McpContext implements Context {
751788
async getHeapSnapshotAggregates(
752789
filePath: string,
753790
): Promise<Record<string, AggregatedInfoWithUid>> {
791+
this.validatePath(filePath);
754792
return await this.#heapSnapshotManager.getAggregates(filePath);
755793
}
756794

757795
async getHeapSnapshotStats(
758796
filePath: string,
759797
): Promise<DevTools.HeapSnapshotModel.HeapSnapshotModel.Statistics> {
798+
this.validatePath(filePath);
760799
return await this.#heapSnapshotManager.getStats(filePath);
761800
}
762801

763802
async getHeapSnapshotStaticData(
764803
filePath: string,
765804
): Promise<DevTools.HeapSnapshotModel.HeapSnapshotModel.StaticData | null> {
805+
this.validatePath(filePath);
766806
return await this.#heapSnapshotManager.getStaticData(filePath);
767807
}
768808

769809
async getHeapSnapshotNodesByUid(
770810
filePath: string,
771811
uid: number,
772812
): Promise<DevTools.HeapSnapshotModel.HeapSnapshotModel.ItemsRange> {
813+
this.validatePath(filePath);
773814
return await this.#heapSnapshotManager.getNodesByUid(filePath, uid);
774815
}
775816
}

src/daemon/client.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import net from 'node:net';
1111
import {logger} from '../logger.js';
1212
import type {CallToolResult} from '../third_party/index.js';
1313
import {PipeTransport} from '../third_party/index.js';
14-
import {saveTemporaryFile} from '../utils/files.js';
14+
import {getTempFilePath} from '../utils/files.js';
1515

1616
import type {DaemonMessage, DaemonResponse} from './types.js';
1717
import {
@@ -179,7 +179,8 @@ export async function handleResponse(
179179
}
180180
const data = Buffer.from(imageData, 'base64');
181181
const name = crypto.randomUUID();
182-
const {filepath} = await saveTemporaryFile(data, `${name}${extension}`);
182+
const filepath = await getTempFilePath(`${name}${extension}`);
183+
fs.writeFileSync(filepath, data);
183184
chunks.push(`Saved to ${filepath}.`);
184185
} else {
185186
throw new Error('Not supported response content type');

src/index.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import {
2121
McpServer,
2222
type CallToolResult,
2323
SetLevelRequestSchema,
24+
ListRootsResultSchema,
25+
RootsListChangedNotificationSchema,
2426
} from './third_party/index.js';
2527
import {ToolCategory} from './tools/categories.js';
2628
import type {DefinedPageTool, ToolDefinition} from './tools/ToolDefinition.js';
@@ -57,11 +59,35 @@ export async function createMcpServer(
5759
return {};
5860
});
5961

62+
const updateRoots = async () => {
63+
if (!server.server.getClientCapabilities()?.roots) {
64+
return;
65+
}
66+
try {
67+
const roots = await server.server.request(
68+
{method: 'roots/list'},
69+
ListRootsResultSchema,
70+
);
71+
context?.setRoots(roots.roots);
72+
} catch (e) {
73+
logger('Failed to list roots', e);
74+
}
75+
};
76+
6077
server.server.oninitialized = () => {
6178
const clientName = server.server.getClientVersion()?.name;
6279
if (clientName) {
6380
clearcutLogger?.setClientName(clientName);
6481
}
82+
if (server.server.getClientCapabilities()?.roots) {
83+
void updateRoots();
84+
server.server.setNotificationHandler(
85+
RootsListChangedNotificationSchema,
86+
() => {
87+
void updateRoots();
88+
},
89+
);
90+
}
6591
};
6692

6793
let context: McpContext;
@@ -109,6 +135,7 @@ export async function createMcpServer(
109135
experimentalIncludeAllPages: serverArgs.experimentalIncludeAllPages,
110136
performanceCrux: serverArgs.performanceCrux,
111137
});
138+
await updateRoots();
112139
}
113140
return context;
114141
}

src/third_party/index.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ export {
3030
SetLevelRequestSchema,
3131
type ImageContent,
3232
type TextContent,
33+
type Root,
34+
ListRootsRequestSchema,
35+
RootsListChangedNotificationSchema,
36+
ListRootsResultSchema,
3337
} from '@modelcontextprotocol/sdk/types.js';
3438
export {z as zod} from 'zod';
3539
export {default as ajv} from 'ajv';

src/tools/ToolDefinition.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,10 @@ export type SupportedExtensions =
167167
| '.json.gz';
168168

169169
/**
170-
* Only add methods required by tools/*.
170+
* Only add methods used by tools/*.
171171
*/
172172
export type Context = Readonly<{
173+
validatePath(filePath: string): void;
173174
isRunningPerformanceTrace(): boolean;
174175
setIsRunningPerformanceTrace(x: boolean): void;
175176
isCruxEnabled(): boolean;
@@ -244,6 +245,9 @@ export type Context = Readonly<{
244245
): Promise<DevTools.HeapSnapshotModel.HeapSnapshotModel.ItemsRange>;
245246
}>;
246247

248+
/**
249+
* Only add methods used by tools/*.
250+
*/
247251
export type ContextPage = Readonly<{
248252
readonly pptrPage: Page;
249253
getAXNodeByUid(uid: string): TextSnapshotNode | undefined;

src/tools/input.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,9 @@ export const uploadFile = definePageTool({
365365
filePath: zod.string().describe('The local path of the file to upload'),
366366
includeSnapshot: includeSnapshotSchema,
367367
},
368-
handler: async (request, response) => {
368+
handler: async (request, response, context) => {
369369
const {uid, filePath} = request.params;
370+
context.validatePath(filePath);
370371
const handle = (await request.page.getElementByUid(
371372
uid,
372373
)) as ElementHandle<HTMLInputElement>;

src/tools/lighthouse.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ export const lighthouseAudit = definePageTool({
5353
outputDirPath,
5454
} = request.params;
5555

56+
if (outputDirPath) {
57+
context.validatePath(outputDirPath);
58+
}
59+
5660
const flags: Flags = {
5761
onlyCategories: categories,
5862
output: formats,

src/tools/memory.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ export const takeMemorySnapshot = definePageTool({
2222
.string()
2323
.describe('A path to a .heapsnapshot file to save the heapsnapshot to.'),
2424
},
25-
handler: async (request, response, _context) => {
25+
handler: async (request, response, context) => {
2626
const page = request.page;
27+
context.validatePath(request.params.filePath);
2728

2829
await page.pptrPage.captureHeapSnapshot({
2930
path: ensureExtension(request.params.filePath, '.heapsnapshot'),
@@ -48,6 +49,7 @@ export const exploreMemorySnapshot = defineTool({
4849
filePath: zod.string().describe('A path to a .heapsnapshot file to read.'),
4950
},
5051
handler: async (request, response, context) => {
52+
context.validatePath(request.params.filePath);
5153
const stats = await context.getHeapSnapshotStats(request.params.filePath);
5254
const staticData = await context.getHeapSnapshotStaticData(
5355
request.params.filePath,
@@ -78,6 +80,7 @@ export const getMemorySnapshotDetails = defineTool({
7880
.describe('The page size for pagination of aggregates.'),
7981
},
8082
handler: async (request, response, context) => {
83+
context.validatePath(request.params.filePath);
8184
const aggregates = await context.getHeapSnapshotAggregates(
8285
request.params.filePath,
8386
);
@@ -109,6 +112,7 @@ export const getNodesByClass = defineTool({
109112
pageSize: zod.number().optional().describe('The page size for pagination.'),
110113
},
111114
handler: async (request, response, context) => {
115+
context.validatePath(request.params.filePath);
112116
const nodes = await context.getHeapSnapshotNodesByUid(
113117
request.params.filePath,
114118
request.params.uid,

src/tools/network.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ export const getNetworkRequest = definePageTool({
114114
),
115115
},
116116
handler: async (request, response, context) => {
117+
if (request.params.requestFilePath) {
118+
context.validatePath(request.params.requestFilePath);
119+
}
120+
if (request.params.responseFilePath) {
121+
context.validatePath(request.params.responseFilePath);
122+
}
117123
if (request.params.reqid) {
118124
response.attachNetworkRequest(request.params.reqid, {
119125
requestFilePath: request.params.requestFilePath,

src/tools/performance.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ export const startTrace = definePageTool({
4949
filePath: filePathSchema,
5050
},
5151
handler: async (request, response, context) => {
52+
if (request.params.filePath) {
53+
context.validatePath(request.params.filePath);
54+
}
5255
if (context.isRunningPerformanceTrace()) {
5356
response.appendResponseLine(
5457
'Error: a performance trace is already running. Use performance_stop_trace to stop it. Only one trace can be running at any given time.',
@@ -126,6 +129,9 @@ export const stopTrace = definePageTool({
126129
filePath: filePathSchema,
127130
},
128131
handler: async (request, response, context) => {
132+
if (request.params.filePath) {
133+
context.validatePath(request.params.filePath);
134+
}
129135
if (!context.isRunningPerformanceTrace()) {
130136
return;
131137
}

0 commit comments

Comments
 (0)