diff --git a/src/McpContext.ts b/src/McpContext.ts index 2ffb1316f..cf390a1a5 100644 --- a/src/McpContext.ts +++ b/src/McpContext.ts @@ -6,6 +6,7 @@ import fs from 'node:fs/promises'; import path from 'node:path'; +import {fileURLToPath} from 'node:url'; import type {TargetUniverse} from './DevtoolsUtils.js'; import {UniverseManager} from './DevtoolsUtils.js'; @@ -18,21 +19,22 @@ import { type ListenerMap, type UncaughtError, } from './PageCollector.js'; -import type { - Browser, - BrowserContext, - ConsoleMessage, - Debugger, - HTTPRequest, - Page, - ScreenRecorder, - Viewport, - Target, - Extension, +import { + Locator, + PredefinedNetworkConditions, + type Browser, + type BrowserContext, + type ConsoleMessage, + type Debugger, + type HTTPRequest, + type Page, + type ScreenRecorder, + type Viewport, + type Target, + type Extension, + type Root, + type DevTools, } from './third_party/index.js'; -import type {DevTools} from './third_party/index.js'; -import {Locator} from './third_party/index.js'; -import {PredefinedNetworkConditions} from './third_party/index.js'; import {listPages} from './tools/pages.js'; import {CLOSE_PAGE_ERROR} from './tools/ToolDefinition.js'; import type {Context, SupportedExtensions} from './tools/ToolDefinition.js'; @@ -42,7 +44,7 @@ import type { GeolocationOptions, ExtensionServiceWorker, } from './types.js'; -import {ensureExtension, saveTemporaryFile} from './utils/files.js'; +import {ensureExtension, getTempFilePath} from './utils/files.js'; import {getNetworkMultiplierFromString} from './WaitForHelper.js'; interface McpContextOptions { @@ -90,6 +92,7 @@ export class McpContext implements Context { #locatorClass: typeof Locator; #options: McpContextOptions; #heapSnapshotManager = new HeapSnapshotManager(); + #roots: Root[] | undefined = undefined; private constructor( browser: Browser, @@ -154,6 +157,34 @@ export class McpContext implements Context { return context; } + roots(): Root[] | undefined { + return this.#roots; + } + + setRoots(roots: Root[] | undefined): void { + this.#roots = roots; + } + + validatePath(filePath: string): void { + const roots = this.roots(); + if (roots === undefined) { + return; + } + const absolutePath = path.resolve(filePath); + for (const root of roots) { + const rootPath = path.resolve(fileURLToPath(root.uri)); + if ( + absolutePath === rootPath || + absolutePath.startsWith(rootPath + path.sep) + ) { + return; + } + } + throw new Error( + `Access denied: path ${filePath} is not within any of the workspace roots ${JSON.stringify(roots)}.`, + ); + } + resolveCdpRequestId(page: McpPage, cdpRequestId: string): number | undefined { if (!cdpRequestId) { this.logger('no network request'); @@ -643,13 +674,18 @@ export class McpContext implements Context { data: Uint8Array, filename: string, ): Promise<{filepath: string}> { - return await saveTemporaryFile(data, filename); + const filepath = await getTempFilePath(filename); + this.validatePath(filepath); + await fs.writeFile(filepath, data); + return {filepath}; } + async saveFile( data: Uint8Array, clientProvidedFilePath: string, extension: SupportedExtensions, ): Promise<{filename: string}> { + this.validatePath(clientProvidedFilePath); try { const filePath = ensureExtension( path.resolve(clientProvidedFilePath), @@ -721,6 +757,7 @@ export class McpContext implements Context { } async installExtension(extensionPath: string): Promise { + this.validatePath(extensionPath); const id = await this.browser.installExtension(extensionPath); return id; } @@ -751,18 +788,21 @@ export class McpContext implements Context { async getHeapSnapshotAggregates( filePath: string, ): Promise> { + this.validatePath(filePath); return await this.#heapSnapshotManager.getAggregates(filePath); } async getHeapSnapshotStats( filePath: string, ): Promise { + this.validatePath(filePath); return await this.#heapSnapshotManager.getStats(filePath); } async getHeapSnapshotStaticData( filePath: string, ): Promise { + this.validatePath(filePath); return await this.#heapSnapshotManager.getStaticData(filePath); } @@ -770,6 +810,7 @@ export class McpContext implements Context { filePath: string, uid: number, ): Promise { + this.validatePath(filePath); return await this.#heapSnapshotManager.getNodesByUid(filePath, uid); } } diff --git a/src/daemon/client.ts b/src/daemon/client.ts index 47862f8e5..61d921b18 100644 --- a/src/daemon/client.ts +++ b/src/daemon/client.ts @@ -11,7 +11,7 @@ import net from 'node:net'; import {logger} from '../logger.js'; import type {CallToolResult} from '../third_party/index.js'; import {PipeTransport} from '../third_party/index.js'; -import {saveTemporaryFile} from '../utils/files.js'; +import {getTempFilePath} from '../utils/files.js'; import type {DaemonMessage, DaemonResponse} from './types.js'; import { @@ -179,7 +179,8 @@ export async function handleResponse( } const data = Buffer.from(imageData, 'base64'); const name = crypto.randomUUID(); - const {filepath} = await saveTemporaryFile(data, `${name}${extension}`); + const filepath = await getTempFilePath(`${name}${extension}`); + fs.writeFileSync(filepath, data); chunks.push(`Saved to ${filepath}.`); } else { throw new Error('Not supported response content type'); diff --git a/src/index.ts b/src/index.ts index 24b7d42a0..9f522f63b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -21,6 +21,8 @@ import { McpServer, type CallToolResult, SetLevelRequestSchema, + ListRootsResultSchema, + RootsListChangedNotificationSchema, } from './third_party/index.js'; import {ToolCategory} from './tools/categories.js'; import type {DefinedPageTool, ToolDefinition} from './tools/ToolDefinition.js'; @@ -57,11 +59,35 @@ export async function createMcpServer( return {}; }); + const updateRoots = async () => { + if (!server.server.getClientCapabilities()?.roots) { + return; + } + try { + const roots = await server.server.request( + {method: 'roots/list'}, + ListRootsResultSchema, + ); + context?.setRoots(roots.roots); + } catch (e) { + logger('Failed to list roots', e); + } + }; + server.server.oninitialized = () => { const clientName = server.server.getClientVersion()?.name; if (clientName) { clearcutLogger?.setClientName(clientName); } + if (server.server.getClientCapabilities()?.roots) { + void updateRoots(); + server.server.setNotificationHandler( + RootsListChangedNotificationSchema, + () => { + void updateRoots(); + }, + ); + } }; let context: McpContext; @@ -109,6 +135,7 @@ export async function createMcpServer( experimentalIncludeAllPages: serverArgs.experimentalIncludeAllPages, performanceCrux: serverArgs.performanceCrux, }); + await updateRoots(); } return context; } diff --git a/src/third_party/index.ts b/src/third_party/index.ts index caf183248..7cf50217c 100644 --- a/src/third_party/index.ts +++ b/src/third_party/index.ts @@ -30,6 +30,10 @@ export { SetLevelRequestSchema, type ImageContent, type TextContent, + type Root, + ListRootsRequestSchema, + RootsListChangedNotificationSchema, + ListRootsResultSchema, } from '@modelcontextprotocol/sdk/types.js'; export {z as zod} from 'zod'; export {default as ajv} from 'ajv'; diff --git a/src/tools/ToolDefinition.ts b/src/tools/ToolDefinition.ts index 4ff48ecf1..7334de3e8 100644 --- a/src/tools/ToolDefinition.ts +++ b/src/tools/ToolDefinition.ts @@ -167,9 +167,10 @@ export type SupportedExtensions = | '.json.gz'; /** - * Only add methods required by tools/*. + * Only add methods used by tools/*. */ export type Context = Readonly<{ + validatePath(filePath: string): void; isRunningPerformanceTrace(): boolean; setIsRunningPerformanceTrace(x: boolean): void; isCruxEnabled(): boolean; @@ -244,6 +245,9 @@ export type Context = Readonly<{ ): Promise; }>; +/** + * Only add methods used by tools/*. + */ export type ContextPage = Readonly<{ readonly pptrPage: Page; getAXNodeByUid(uid: string): TextSnapshotNode | undefined; diff --git a/src/tools/input.ts b/src/tools/input.ts index bce93f73b..588f5308b 100644 --- a/src/tools/input.ts +++ b/src/tools/input.ts @@ -365,8 +365,9 @@ export const uploadFile = definePageTool({ filePath: zod.string().describe('The local path of the file to upload'), includeSnapshot: includeSnapshotSchema, }, - handler: async (request, response) => { + handler: async (request, response, context) => { const {uid, filePath} = request.params; + context.validatePath(filePath); const handle = (await request.page.getElementByUid( uid, )) as ElementHandle; diff --git a/src/tools/lighthouse.ts b/src/tools/lighthouse.ts index 4356f1e62..f92f342eb 100644 --- a/src/tools/lighthouse.ts +++ b/src/tools/lighthouse.ts @@ -53,6 +53,10 @@ export const lighthouseAudit = definePageTool({ outputDirPath, } = request.params; + if (outputDirPath) { + context.validatePath(outputDirPath); + } + const flags: Flags = { onlyCategories: categories, output: formats, diff --git a/src/tools/memory.ts b/src/tools/memory.ts index db61dfb7b..dcc92dad2 100644 --- a/src/tools/memory.ts +++ b/src/tools/memory.ts @@ -22,8 +22,9 @@ export const takeMemorySnapshot = definePageTool({ .string() .describe('A path to a .heapsnapshot file to save the heapsnapshot to.'), }, - handler: async (request, response, _context) => { + handler: async (request, response, context) => { const page = request.page; + context.validatePath(request.params.filePath); await page.pptrPage.captureHeapSnapshot({ path: ensureExtension(request.params.filePath, '.heapsnapshot'), @@ -48,6 +49,7 @@ export const exploreMemorySnapshot = defineTool({ filePath: zod.string().describe('A path to a .heapsnapshot file to read.'), }, handler: async (request, response, context) => { + context.validatePath(request.params.filePath); const stats = await context.getHeapSnapshotStats(request.params.filePath); const staticData = await context.getHeapSnapshotStaticData( request.params.filePath, @@ -78,6 +80,7 @@ export const getMemorySnapshotDetails = defineTool({ .describe('The page size for pagination of aggregates.'), }, handler: async (request, response, context) => { + context.validatePath(request.params.filePath); const aggregates = await context.getHeapSnapshotAggregates( request.params.filePath, ); @@ -109,6 +112,7 @@ export const getNodesByClass = defineTool({ pageSize: zod.number().optional().describe('The page size for pagination.'), }, handler: async (request, response, context) => { + context.validatePath(request.params.filePath); const nodes = await context.getHeapSnapshotNodesByUid( request.params.filePath, request.params.uid, diff --git a/src/tools/network.ts b/src/tools/network.ts index 181a2ed26..05afea232 100644 --- a/src/tools/network.ts +++ b/src/tools/network.ts @@ -114,6 +114,12 @@ export const getNetworkRequest = definePageTool({ ), }, handler: async (request, response, context) => { + if (request.params.requestFilePath) { + context.validatePath(request.params.requestFilePath); + } + if (request.params.responseFilePath) { + context.validatePath(request.params.responseFilePath); + } if (request.params.reqid) { response.attachNetworkRequest(request.params.reqid, { requestFilePath: request.params.requestFilePath, diff --git a/src/tools/performance.ts b/src/tools/performance.ts index acc588655..23f1f2566 100644 --- a/src/tools/performance.ts +++ b/src/tools/performance.ts @@ -49,6 +49,9 @@ export const startTrace = definePageTool({ filePath: filePathSchema, }, handler: async (request, response, context) => { + if (request.params.filePath) { + context.validatePath(request.params.filePath); + } if (context.isRunningPerformanceTrace()) { response.appendResponseLine( 'Error: a performance trace is already running. Use performance_stop_trace to stop it. Only one trace can be running at any given time.', @@ -126,6 +129,9 @@ export const stopTrace = definePageTool({ filePath: filePathSchema, }, handler: async (request, response, context) => { + if (request.params.filePath) { + context.validatePath(request.params.filePath); + } if (!context.isRunningPerformanceTrace()) { return; } diff --git a/src/tools/screencast.ts b/src/tools/screencast.ts index ad7130455..ed548518c 100644 --- a/src/tools/screencast.ts +++ b/src/tools/screencast.ts @@ -40,6 +40,9 @@ export const startScreencast = definePageTool(args => ({ ), }, handler: async (request, response, context) => { + if (request.params.filePath) { + context.validatePath(request.params.filePath); + } if (context.getScreenRecorder() !== null) { response.appendResponseLine( 'Error: a screencast recording is already in progress. Use screencast_stop to stop it before starting a new one.', diff --git a/src/tools/screenshot.ts b/src/tools/screenshot.ts index f740fda4e..2b80021f1 100644 --- a/src/tools/screenshot.ts +++ b/src/tools/screenshot.ts @@ -51,6 +51,9 @@ export const screenshot = definePageTool({ ), }, handler: async (request, response, context) => { + if (request.params.filePath) { + context.validatePath(request.params.filePath); + } if (request.params.uid && request.params.fullPage) { throw new Error('Providing both "uid" and "fullPage" is not allowed.'); } diff --git a/src/tools/snapshot.ts b/src/tools/snapshot.ts index 338bd6794..105ae2d62 100644 --- a/src/tools/snapshot.ts +++ b/src/tools/snapshot.ts @@ -33,7 +33,10 @@ in the DevTools Elements panel (if any).`, 'The absolute path, or a path relative to the current working directory, to save the snapshot to instead of attaching it to the response.', ), }, - handler: async (request, response) => { + handler: async (request, response, context) => { + if (request.params.filePath) { + context.validatePath(request.params.filePath); + } response.includeSnapshot({ verbose: request.params.verbose ?? false, filePath: request.params.filePath, diff --git a/src/utils/files.ts b/src/utils/files.ts index abdba3ed8..00083dfe1 100644 --- a/src/utils/files.ts +++ b/src/utils/files.ts @@ -8,21 +8,11 @@ import fs from 'node:fs/promises'; import os from 'node:os'; import path from 'node:path'; -export async function saveTemporaryFile( - data: Uint8Array, - filename: string, -): Promise<{filepath: string}> { - try { - const dir = await fs.mkdtemp( - path.join(os.tmpdir(), 'chrome-devtools-mcp-'), - ); +export async function getTempFilePath(filename: string) { + const dir = await fs.mkdtemp(path.join(os.tmpdir(), 'chrome-devtools-mcp-')); - const filepath = path.join(dir, filename); - await fs.writeFile(filepath, data); - return {filepath}; - } catch (err) { - throw new Error('Could not save a file', {cause: err}); - } + const filepath = path.join(dir, filename); + return filepath; } export function ensureExtension( diff --git a/tests/McpContext.test.ts b/tests/McpContext.test.ts index a33dc71e5..12d2c640b 100644 --- a/tests/McpContext.test.ts +++ b/tests/McpContext.test.ts @@ -5,7 +5,9 @@ */ import assert from 'node:assert'; +import path from 'node:path'; import {afterEach, describe, it} from 'node:test'; +import {pathToFileURL} from 'node:url'; import sinon from 'sinon'; @@ -213,4 +215,46 @@ describe('McpContext', () => { fromStub.restore(); }); }); + + it('can store and retrieve roots', async () => { + await withMcpContext(async (_response, context) => { + const roots = [{uri: 'file:///test', name: 'test'}]; + context.setRoots(roots); + assert.deepEqual(context.roots(), roots); + }); + }); + + it('validatePath allows paths within roots', async () => { + await withMcpContext(async (_response, context) => { + const workspacePath = path.resolve('/tmp/workspace'); + const roots = [ + {uri: pathToFileURL(workspacePath).href, name: 'workspace'}, + ]; + context.setRoots(roots); + // Valid path within root + context.validatePath(path.join(workspacePath, 'test.txt')); + context.validatePath(workspacePath); + + // Invalid path outside root + const outsidePath = path.resolve('/tmp/outside.txt'); + assert.throws(() => context.validatePath(outsidePath), /Access denied/); + }); + }); + + it('validatePath allows all paths if roots are undefined (legacy)', async () => { + await withMcpContext(async (_response, context) => { + context.setRoots(undefined); + context.validatePath(path.resolve('/tmp/anywhere.txt')); + }); + }); + + it('validatePath denies all paths if roots list is empty', async () => { + await withMcpContext(async (_response, context) => { + context.setRoots([]); + assert.throws( + () => context.validatePath(path.resolve('/tmp/anywhere.txt')), + /Access denied/, + ); + }); + }); }); diff --git a/tests/index.test.ts b/tests/index.test.ts index a1c7765bb..60c4c83d0 100644 --- a/tests/index.test.ts +++ b/tests/index.test.ts @@ -10,6 +10,12 @@ import {describe, it} from 'node:test'; import {Client} from '@modelcontextprotocol/sdk/client/index.js'; import {StdioClientTransport} from '@modelcontextprotocol/sdk/client/stdio.js'; +import { + ListRootsRequestSchema, + RootsListChangedNotificationSchema, + type ClientCapabilities, + type TextContent, +} from '@modelcontextprotocol/sdk/types.js'; import {executablePath} from 'puppeteer'; import type {ToolCategory} from '../src/tools/categories.js'; @@ -20,6 +26,7 @@ describe('e2e', () => { async function withClient( cb: (client: Client) => Promise, extraArgs: string[] = [], + options: {capabilities?: ClientCapabilities} = {}, ) { const transport = new StdioClientTransport({ command: 'node', @@ -38,7 +45,7 @@ describe('e2e', () => { version: '1.0.0', }, { - capabilities: {}, + capabilities: options.capabilities ?? {}, }, ); @@ -160,6 +167,84 @@ describe('e2e', () => { ['--experimental-webmcp'], ); }); + + it('updates roots when client notifies', async () => { + const roots = [{uri: 'file:///test-root', name: 'test-root'}]; + let resolvePromise: () => void; + const promise = new Promise(resolve => { + resolvePromise = resolve; + }); + + await withClient( + async client => { + client.setRequestHandler(ListRootsRequestSchema, () => { + resolvePromise(); + return {roots}; + }); + + await client.notification({ + method: RootsListChangedNotificationSchema.shape.method.value, + }); + + // Wait for the server to process the notification and request roots + await promise; + }, + [], + { + capabilities: { + roots: {listChanged: true}, + }, + }, + ); + }); + + it('denies file access if roots list is empty', async () => { + await withClient( + async client => { + client.setRequestHandler(ListRootsRequestSchema, () => { + return {roots: []}; + }); + + const result = await client.callTool({ + name: 'take_screenshot', + arguments: { + filePath: '/tmp/test.png', + }, + }); + + assert.strictEqual(result.isError, true); + const content = result.content as TextContent[]; + assert.match(content[0].text, /Access denied/); + }, + [], + { + capabilities: { + roots: {listChanged: true}, + }, + }, + ); + }); + + it('allows file access if roots capability is missing', async () => { + await withClient( + async client => { + const result = await client.callTool({ + name: 'take_screenshot', + arguments: { + filePath: '/tmp/test.png', + }, + }); + + assert.strictEqual(result.isError, undefined); + const content = result.content as TextContent[]; + assert.match(content[0].text, /Saved screenshot to/); + }, + [], + { + capabilities: {}, + }, + ); + }); }); async function getToolsWithFilteredCategories(