diff --git a/README.md b/README.md index d68f34fe..fff48fbb 100644 --- a/README.md +++ b/README.md @@ -15,9 +15,9 @@ this extension exposes Colab servers directly in VS Code! ## Quick Start 1. Install [VS Code](https://code.visualstudio.com). -1. Install the [Colab - extension](https://marketplace.visualstudio.com/items?itemName=google.colab) - (and Jupyter if not already installed). +1. Install the Colab extension from either the [Visual Studio + Marketplace](https://marketplace.visualstudio.com/items?itemName=google.colab) + or [Open VSX](https://open-vsx.org/extension/Google/colab). 1. Open or create a notebook file. 1. When prompted, sign in. 1. Click `Select Kernel` > `Colab` > `New Colab Server`. diff --git a/docs/assets/hello-world.gif b/docs/assets/hello-world.gif index d1d1c748..0d4e2f94 100644 Binary files a/docs/assets/hello-world.gif and b/docs/assets/hello-world.gif differ diff --git a/package-lock.json b/package-lock.json index 95ad3808..113ab3b3 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14,6 +14,7 @@ "semver": "^7.7.3", "uuid": "^11.0.3", "vscode-languageclient": "^10.0.0-next.17", + "ws": "^8.18.3", "zod": "^4.0.17" }, "devDependencies": { @@ -30,6 +31,7 @@ "@types/sinon": "^17.0.3", "@types/uuid": "^10.0.0", "@types/vscode": "^1.100.0", + "@types/ws": "^8.18.1", "@vscode/jupyter-extension": "1.0.93", "@vscode/test-cli": "^0.0.10", "@vscode/test-electron": "^2.4.1", @@ -11877,7 +11879,6 @@ "version": "8.18.3", "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz", "integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==", - "dev": true, "license": "MIT", "engines": { "node": ">=10.0.0" diff --git a/package.json b/package.json index c0c90f29..f4928ea2 100644 --- a/package.json +++ b/package.json @@ -43,6 +43,11 @@ "configuration": { "title": "Colab", "properties": { + "colab.codeDiagnostics": { + "type": "boolean", + "default": false, + "description": "Enable code diagnostics powered by a connected Colab server." + }, "colab.logging.level": { "type": "string", "default": "info", @@ -153,6 +158,7 @@ "@types/sinon": "^17.0.3", "@types/uuid": "^10.0.0", "@types/vscode": "^1.99.3", + "@types/ws": "^8.18.1", "@vscode/jupyter-extension": "1.0.93", "@vscode/test-cli": "^0.0.10", "@vscode/test-electron": "^2.4.1", @@ -187,6 +193,7 @@ "semver": "^7.7.3", "uuid": "^11.0.3", "vscode-languageclient": "^10.0.0-next.17", + "ws": "^8.18.3", "zod": "^4.0.17" } } diff --git a/src/lsp/language-client.ts b/src/lsp/language-client.ts new file mode 100644 index 00000000..f5b0b88a --- /dev/null +++ b/src/lsp/language-client.ts @@ -0,0 +1,201 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import vscode, { Disposable } from "vscode"; +import type { + LanguageClientOptions, + ServerOptions, + LanguageClient, +} from "vscode-languageclient/node"; +import { ClientOptions, WebSocket, createWebSocketStream } from "ws"; +import { log } from "../common/logging"; +import { AsyncToggleable } from "../common/toggleable"; +import { AssignmentManager } from "../jupyter/assignments"; +import { ColabAssignedServer } from "../jupyter/servers"; +import { ContentLengthTransformer } from "./content-length-transformer"; +import { getMiddleware } from "./middleware"; + +type VSLanguageClientFactory = ( + id: string, + name: string, + serverOptions: ServerOptions, + clientOptions: LanguageClientOptions, +) => LanguageClient; + +/** + * Manages the lifecycle of a LanguageClient connected to the latest assigned + * Colab server. + */ +export class LanguageClientController extends AsyncToggleable { + private client: ColabLanguageClient | undefined; + private latestServerEndpoint: string; + private abortController = new AbortController(); + + constructor( + private vs: typeof vscode, + private readonly assignments: AssignmentManager, + private readonly vsLanguageClientFactory: VSLanguageClientFactory, + ) { + super(); + } + + override async initialize(signal: AbortSignal): Promise { + // signal will be aborted when the Toggleable is turned off. + signal.onabort = (e) => { + this.abortController.abort(e); + }; + const listenDispose = this.assignments.onDidAssignmentsChange(async (e) => { + if ( + e.added.length || + e.removed.some((s) => { + return s.server.endpoint === this.latestServerEndpoint; + }) + ) { + // Abort any in-flight work from the last call. + this.abortController.abort(); + await this.tearDownClient("Server removed"); + } else { + // Don't care about updated server lists, or servers being + // removed that we weren't connected to. + return; + } + this.abortController = new AbortController(); + await this.connectToLatest(this.abortController.signal); + }); + await this.connectToLatest(this.abortController.signal); + return { + dispose: () => { + listenDispose.dispose(); + void this.tearDownClient("Toggled off"); + }, + }; + } + + private async connectToLatest(signal?: AbortSignal): Promise { + const latestServer = await this.assignments.latestServer(signal); + if (!latestServer) { + await this.tearDownClient("No assigned servers"); + return; + } + // Don't make a new client if the latest runtime has not changed. + if (latestServer.endpoint === this.latestServerEndpoint) { + return; + } + await this.tearDownClient("Newer runtime found"); + this.latestServerEndpoint = latestServer.endpoint; + if (signal?.aborted) { + return; + } + this.latestServerEndpoint = latestServer.endpoint; + this.client = new ColabLanguageClient( + this.vsLanguageClientFactory, + latestServer, + this.vs, + ); + await this.client.start(); + return; + } + + private async tearDownClient(reason: string) { + if (!this.client) { + return; + } + log.info( + `Tearing down LanguageClient for endpoint ${this.latestServerEndpoint}: ${reason}`, + ); + await this.client.dispose(); + this.client = undefined; + this.latestServerEndpoint = ""; + } +} + +class ColabLanguageClient implements Disposable { + private languageClient: LanguageClient; + + constructor( + private readonly createVSLanguageClient: VSLanguageClientFactory, + server: ColabAssignedServer, + private vs: typeof vscode, + ) { + this.languageClient = this.buildVSLanguageClient(server); + } + + async start(): Promise { + if (!this.languageClient.needsStart()) { + return; + } + + await this.languageClient.start(); + } + + async dispose(): Promise { + await this.languageClient.dispose(); + } + + private buildVSLanguageClient(server: ColabAssignedServer): LanguageClient { + const runtimeProxyInfo = server.connectionInformation; + const url = new URL(runtimeProxyInfo.baseUrl.toString()); + const isLocalhost = + url.hostname === "localhost" || url.hostname === "127.0.0.1"; + url.protocol = isLocalhost ? "ws" : "wss"; + url.pathname = "/colab/lsp"; + url.search = `?colab-runtime-proxy-token=${runtimeProxyInfo.token}`; + + log.info( + `Setting up Colab Language Client for endpoint ${server.endpoint}`, + ); + + const socketOptions: ClientOptions = { + rejectUnauthorized: isLocalhost ? false : true, + }; + + const socket = new WebSocket(url.toString(), socketOptions); + socket.binaryType = "arraybuffer"; + const vs = this.vs; + const serverOptions: ServerOptions = async () => { + return new Promise((resolve, reject) => { + socket.onopen = () => { + log.debug("Language server socket opened."); + const stream = createWebSocketStream(socket); + const reader = stream.pipe(new ContentLengthTransformer()); + // The LanguageClient handles framing for outgoing messages. + const writer = stream; + + stream.on("error", (err) => { + log.error("Language server stream error:", err); + }); + stream.on("close", () => { + log.debug("Language server stream closed."); + }); + resolve({ + writer, + reader, + }); + }; + socket.onerror = (err) => { + log.error("Language server socket error:", err); + // eslint-disable-next-line @typescript-eslint/prefer-promise-reject-errors + reject(err); + }; + socket.onclose = (event) => { + log.info("Language server socket closed:", event); + }; + }); + }; + const clientOptions: LanguageClientOptions = { + documentSelector: [ + { scheme: "vscode-notebook-cell", language: "python" }, + ], + middleware: getMiddleware(vs), + }; + return this.createVSLanguageClient( + "colabLanguageServer", + "Colab Language Server", + serverOptions, + clientOptions, + ); + } +} diff --git a/src/lsp/language-client.unit.test.ts b/src/lsp/language-client.unit.test.ts new file mode 100644 index 00000000..78222072 --- /dev/null +++ b/src/lsp/language-client.unit.test.ts @@ -0,0 +1,460 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { randomUUID } from "crypto"; +import { expect } from "chai"; +import sinon, { SinonStubbedInstance } from "sinon"; +import type { + LanguageClientOptions, + LanguageClient, + MessageTransports, + ServerOptions, +} from "vscode-languageclient/node"; +import { WebSocket, AddressInfo, WebSocketServer } from "ws"; +import { Variant } from "../colab/api"; +import { + COLAB_CLIENT_AGENT_HEADER, + COLAB_RUNTIME_PROXY_TOKEN_HEADER, +} from "../colab/headers"; +import { + AssignmentChangeEvent, + AssignmentManager, +} from "../jupyter/assignments"; +import { ColabAssignedServer } from "../jupyter/servers"; +import { TestUri } from "../test/helpers/uri"; +import { newVsCodeStub, VsCodeStub } from "../test/helpers/vscode"; +import { LanguageClientController } from "./language-client"; + +class TestLanguageClient + implements Pick +{ + private connection: Promise; + private sendPingHandle: NodeJS.Timeout; + + constructor( + _id: string, + _name: string, + private readonly serverOptions: ServerOptions, + _clientOptions: LanguageClientOptions, + ) {} + + needsStart(): boolean { + return true; + } + + async ping() { + // The interface calls for passing an object, but the + // test implementation expects a stringified object. + ( + (await this.connection).writer as unknown as { + write: (s: string) => void; + } + ).write(JSON.stringify({ jsonrpc: "{}" })); + } + + start(): Promise { + this.connection = ( + this.serverOptions as () => Promise + )(); + // Periodically send empty ping messages so that tests + // can verify that the connection is live. + this.sendPingHandle = setInterval(() => { + try { + void this.ping(); + } catch (e) { + console.log(e); + } + }, 10); + return Promise.resolve(); + } + + async dispose(): Promise { + (await this.connection).writer.end(); + clearTimeout(this.sendPingHandle); + } +} + +function newTestLanguageClient( + id: string, + name: string, + serverOptions: ServerOptions, + clientOptions: LanguageClientOptions, +): LanguageClient { + return new TestLanguageClient( + id, + name, + serverOptions, + clientOptions, + ) as Partial as LanguageClient; +} + +const REFRESH_MS = 60000; +const emptyFunc = () => { + // Empty intentionally. +}; + +describe("LanguageClientController", () => { + let assignmentStub: SinonStubbedInstance; + let vsStub: VsCodeStub; + let server: WebSocketServer; + let latestServer: ColabAssignedServer; + + beforeEach(async () => { + assignmentStub = sinon.createStubInstance(AssignmentManager); + vsStub = newVsCodeStub(); + Object.defineProperty(assignmentStub, "onDidAssignmentsChange", { + value: sinon.stub(), + }); + assignmentStub.onDidAssignmentsChange.returns({ dispose: emptyFunc }); + server = new WebSocketServer({ port: 9876, host: "127.0.0.1" }); + // Wait for the server to be listening. + await new Promise((resolve) => + server.on("listening", () => { + resolve(); + }), + ); + const addr = server.address() as AddressInfo; + const baseUrl = TestUri.parse( + `ws://${addr.address}:${addr.port.toString()}`, + ); + latestServer = { + id: randomUUID(), + label: "Colab GPU A100", + variant: Variant.GPU, + accelerator: "A100", + endpoint: "m-s-foo", + connectionInformation: { + baseUrl, + token: "123", + tokenExpiry: new Date(Date.now() + REFRESH_MS), + headers: { + [COLAB_RUNTIME_PROXY_TOKEN_HEADER.key]: "123", + [COLAB_CLIENT_AGENT_HEADER.key]: COLAB_CLIENT_AGENT_HEADER.value, + }, + }, + dateAssigned: new Date(), + }; + }); + + afterEach(() => { + server.close(); + }); + + it("sets up a socket to a server", async () => { + assignmentStub.latestServer.returns(Promise.resolve(latestServer)); + // Promise that resolves when the server receives a websocket connection. + const connectionPromise = new Promise((resolve, reject) => { + server.on("connection", (socket) => { + resolve(socket); + }); + // Avoid hanging the test forever. + setTimeout(() => { + reject(new Error("Timeout waiting for connection")); + }, 2000); + }); + const languageClient = new LanguageClientController( + vsStub.asVsCode(), + assignmentStub, + newTestLanguageClient, + ); + // Ensure the client is started so it registers its assignment-change + // handler. + languageClient.on(); + // Await connection after enabling the client. + const socket = await connectionPromise; + + // Promise that resolves when the server disconnects. + const disconnectPromise = new Promise((resolve, reject) => { + socket.on("close", () => { + resolve(); + }); + setTimeout(() => { + reject(new Error("Timeout waiting for close")); + }, 5000); + }); + + // Ensure the client disconnects on disposal. + languageClient.off(); + await disconnectPromise; + }); + + it("disconnects when server is unassigned", async () => { + let connectedCallback = (_: AssignmentChangeEvent) => { + // NoOp + }; + assignmentStub.onDidAssignmentsChange.callsFake( + (listener: (e: AssignmentChangeEvent) => void) => { + connectedCallback = listener; + return { dispose: emptyFunc }; + }, + ); + assignmentStub.latestServer.returns(Promise.resolve(latestServer)); + // Promise that resolves when the server receives a websocket connection. + const connectionPromise = new Promise((resolve, reject) => { + server.on("connection", (socket) => { + resolve(socket); + }); + // Avoid hanging the test forever. + setTimeout(() => { + reject(new Error("Timeout waiting for connection")); + }, 2000); + }); + const languageClient = new LanguageClientController( + vsStub.asVsCode(), + assignmentStub, + newTestLanguageClient, + ); + // Ensure the client is started so it registers its assignment-change + // handler. + languageClient.on(); + // Await connection after enabling the client. + // Await connection after enabling the client. + const socket = await connectionPromise; + + // Promise that resolves when the server disconnects. + const disconnectPromise = new Promise((resolve, reject) => { + socket.on("close", () => { + resolve(); + }); + setTimeout(() => { + reject(new Error("Timeout waiting for close")); + }, 5000); + }); + + // // Ensure the client disconnects on the latest runtime disappearing. + assignmentStub.latestServer.returns(Promise.resolve(undefined)); + connectedCallback({ + added: [], + changed: [], + removed: [{ server: latestServer, userInitiated: true }], + }); + await disconnectPromise; + }); + + it("connects to a newer runtime", async () => { + let assignmentsChangedCallback = (_: AssignmentChangeEvent) => { + // NoOp + }; + assignmentStub.onDidAssignmentsChange.callsFake( + (listener: (e: AssignmentChangeEvent) => void) => { + assignmentsChangedCallback = listener; + return { dispose: emptyFunc }; + }, + ); + assignmentStub.latestServer.returns(Promise.resolve(latestServer)); + + // Promise that resolves when the server receives a websocket connection. + const connectionPromise1 = new Promise((resolve, reject) => { + server.on("connection", (socket) => { + resolve(socket); + }); + // Avoid hanging the test forever. + setTimeout(() => { + reject(new Error("Timeout waiting for connection to server 1")); + }, 2000); + }); + const languageClient = new LanguageClientController( + vsStub.asVsCode(), + assignmentStub, + newTestLanguageClient, + ); + languageClient.on(); + const socket1 = await connectionPromise1; + + // Promise that resolves when the server disconnects. + const disconnectPromise1 = new Promise((resolve, reject) => { + socket1.on("close", () => { + resolve(); + }); + setTimeout(() => { + reject(new Error("Timeout waiting for close from server 1")); + }, 5000); + }); + + // Set up a second server. + const server2 = new WebSocketServer({ port: 9877, host: "127.0.0.1" }); + after(() => { + server2.close(); + }); + await new Promise((resolve) => + server2.on("listening", () => { + resolve(); + }), + ); + const addr2 = server2.address() as AddressInfo; + const baseUrl2 = TestUri.parse( + `ws://${addr2.address}:${addr2.port.toString()}`, + ); + const latestServer2: ColabAssignedServer = { + ...latestServer, + id: randomUUID(), + // Must be a new endpoint to trigger a new connection. + endpoint: "m-s-foo2", + connectionInformation: { + ...latestServer.connectionInformation, + baseUrl: baseUrl2, + }, + }; + + const connectionPromise2 = new Promise((resolve, reject) => { + server2.on("connection", (socket) => { + resolve(socket); + }); + setTimeout(() => { + reject(new Error("Timeout waiting for connection to server 2")); + }, 2000); + }); + + // Switch to the new server. + assignmentStub.latestServer.returns(Promise.resolve(latestServer2)); + assignmentsChangedCallback({ + added: [latestServer2], + changed: [], + removed: [], + }); + + await disconnectPromise1; + const socket2 = await connectionPromise2; + + const disconnectPromise2 = new Promise((resolve, reject) => { + socket2.on("close", () => { + resolve(); + }); + setTimeout(() => { + reject(new Error("Timeout waiting for close from server 2")); + }, 5000); + }); + + languageClient.dispose(); + await disconnectPromise2; + }); + + it("does not reconnect when an older server is removed", async () => { + let assignmentsChangedCallback = (_: AssignmentChangeEvent) => { + // NoOp + }; + assignmentStub.onDidAssignmentsChange.callsFake( + (listener: (e: AssignmentChangeEvent) => void) => { + assignmentsChangedCallback = listener; + return { dispose: emptyFunc }; + }, + ); + assignmentStub.latestServer.returns(Promise.resolve(latestServer)); + + // Promise that resolves when the server receives a websocket connection. + const connectionPromise1 = new Promise((resolve, reject) => { + server.on("connection", (socket) => { + resolve(socket); + }); + // Avoid hanging the test forever. + setTimeout(() => { + reject(new Error("Timeout waiting for connection to server 1")); + }, 2000); + }); + const languageClient = new LanguageClientController( + vsStub.asVsCode(), + assignmentStub, + newTestLanguageClient, + ); + languageClient.on(); + const socket1 = await connectionPromise1; + + let closed = false; + socket1.on("close", () => { + closed = true; + }); + const removedServer: ColabAssignedServer = { + ...latestServer, + endpoint: "not-newest", + }; + // Call the callback, even though latestServer is returning the same + // value. This is expected if the user removes an older runtime. + assignmentsChangedCallback({ + added: [], + changed: [], + removed: [{ server: removedServer, userInitiated: true }], + }); + + // Listen for another message on the client to know that the connection + // is still live. + await new Promise((resolve, reject) => { + socket1.once("message", () => { + resolve(); + }); + setTimeout(() => { + reject(new Error("Did not complete within timeout")); + }, 5000); + }); + expect(closed).to.equal(false); + languageClient.dispose(); + }); + + it("can call the assignment callback multiple times", async () => { + assignmentStub.latestServer + .onFirstCall() + .returns(Promise.resolve(undefined)); + assignmentStub.latestServer + .onSecondCall() + .returns(Promise.reject(new Error("Test error"))); + assignmentStub.latestServer + .onThirdCall() + .returns(Promise.resolve(latestServer)); + let assignmentsChangedCallback = (_: AssignmentChangeEvent) => { + // NoOp + }; + assignmentStub.onDidAssignmentsChange.callsFake( + (listener: (e: AssignmentChangeEvent) => void) => { + assignmentsChangedCallback = listener; + return { dispose: emptyFunc }; + }, + ); + // Promise that resolves when the server receives a websocket connection. + const connectionPromise = new Promise((resolve, reject) => { + server.on("connection", (socket) => { + resolve(socket); + }); + // Avoid hanging the test forever. + setTimeout(() => { + reject(new Error("Timeout waiting for connection")); + }, 2000); + }); + const languageClient = new LanguageClientController( + vsStub.asVsCode(), + assignmentStub, + newTestLanguageClient, + ); + // Ensure the client is started so it registers its assignment-change + // handler. + languageClient.on(); + // Call the callback twice, to check that it recovers from the first error + // returned, and also can handle multiple callbacks happening at once. + assignmentsChangedCallback({ + added: [latestServer], + changed: [], + removed: [], + }); + assignmentsChangedCallback({ + added: [latestServer], + changed: [], + removed: [], + }); + // Await connection after enabling the client. + const socket = await connectionPromise; + // Promise that resolves when the server disconnects. + const disconnectPromise = new Promise((resolve, reject) => { + socket.on("close", () => { + resolve(); + }); + setTimeout(() => { + reject(new Error("Timeout waiting for close")); + }, 5000); + }); + + // Ensure the client disconnects on disposal. + languageClient.off(); + await disconnectPromise; + }); +}); diff --git a/src/test/helpers/uri.ts b/src/test/helpers/uri.ts index a50ed68f..10a1ddd1 100644 --- a/src/test/helpers/uri.ts +++ b/src/test/helpers/uri.ts @@ -24,7 +24,7 @@ export class TestUri implements vscode.Uri { const url = new URL(stringUri); return new TestUri( url.protocol.replace(/:$/, ""), - url.hostname, + url.hostname + (url.port.length > 0 ? `:${url.port}` : ""), url.pathname, url.search.replace(/^\?/, ""), url.hash.replace(/^#/, ""),