From 4ffe83c1441ae266bb8749539caac7852a54b0e5 Mon Sep 17 00:00:00 2001 From: Kevin Eger Date: Wed, 12 Nov 2025 01:45:54 +0000 Subject: [PATCH] feat: colab language client for an Colab server The language client maintains ownership of the `vscode-languageclient/node` `LanguageClient`, binding the correct middleware and `pipe`-ing the stream to be LSP compliant. Here we need to add `ws` and `@types/ws` to support the `WebSocket` communication. Minor update to `Updated src/test/helpers/uri.ts` ensuring we correctly parse URIs with port numbers. --- package-lock.json | 3 +- package.json | 2 + src/lsp/language-client.ts | 183 ++++++++++++++++ src/lsp/language-client.unit.test.ts | 304 +++++++++++++++++++++++++++ src/test/helpers/uri.ts | 2 +- 5 files changed, 492 insertions(+), 2 deletions(-) create mode 100644 src/lsp/language-client.ts create mode 100644 src/lsp/language-client.unit.test.ts diff --git a/package-lock.json b/package-lock.json index 0eb78739..b24902e3 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14,6 +14,7 @@ "semver": "^7.7.3", "uuid": "^11.0.3", "vscode-languageclient": "^10.0.0-next.17", + "ws": "^8.18.3", "zod": "^4.0.17" }, "devDependencies": { @@ -30,6 +31,7 @@ "@types/sinon": "^17.0.3", "@types/uuid": "^10.0.0", "@types/vscode": "^1.99.3", + "@types/ws": "^8.18.1", "@vscode/jupyter-extension": "1.0.93", "@vscode/test-cli": "^0.0.10", "@vscode/test-electron": "^2.4.1", @@ -11877,7 +11879,6 @@ "version": "8.18.3", "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz", "integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==", - "dev": true, "license": "MIT", "engines": { "node": ">=10.0.0" diff --git a/package.json b/package.json index 6045c331..c656c9c5 100644 --- a/package.json +++ b/package.json @@ -159,6 +159,7 @@ "@types/sinon": "^17.0.3", "@types/uuid": "^10.0.0", "@types/vscode": "^1.99.3", + "@types/ws": "^8.18.1", "@vscode/jupyter-extension": "1.0.93", "@vscode/test-cli": "^0.0.10", "@vscode/test-electron": "^2.4.1", @@ -193,6 +194,7 @@ "semver": "^7.7.3", "uuid": "^11.0.3", "vscode-languageclient": "^10.0.0-next.17", + "ws": "^8.18.3", "zod": "^4.0.17" } } diff --git a/src/lsp/language-client.ts b/src/lsp/language-client.ts new file mode 100644 index 00000000..45c7d59d --- /dev/null +++ b/src/lsp/language-client.ts @@ -0,0 +1,183 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { Duplex } from "stream"; +import vscode, { Disposable } from "vscode"; +import type { + LanguageClientOptions, + ServerOptions, + LanguageClient, + DocumentSelector, + StreamInfo, +} from "vscode-languageclient/node"; +import { WebSocket, createWebSocketStream } from "ws"; +import { log } from "../common/logging"; +import { ColabAssignedServer } from "../jupyter/servers"; +import { ContentLengthTransformer } from "./content-length-transformer"; +import { + filterNonIPythonDiagnostics as filterDiags, + filterNonIPythonWorkspaceDiagnostics as filterWorkspaceDiags, +} from "./middleware"; + +/** + * The document selector for Python notebook cells. + */ +const PYTHON_NOTEBOOK: DocumentSelector = [ + { + scheme: "vscode-notebook-cell", + language: "python", + }, +]; + +/** + * Factory function for creating new {@link LanguageClient}s. + */ +export type LanguageClientFactory = ( + id: string, + name: string, + serverOptions: ServerOptions, + clientOptions: LanguageClientOptions, +) => LanguageClient; + +type WebSocketFactory = (url: string) => WebSocket; +type StreamFactory = (socket: WebSocket) => Duplex; + +/** + * A language client to the configured server. + * + * Must call {@link ColabLanguageClient.start | start} to begin receiving + * diagnostics. Callers should then call + * {@link ColabLanguageClient.dispose | dispose} when they no longer need the + * client. + */ +export class ColabLanguageClient implements Disposable { + private languageClient?: LanguageClient; + + constructor( + private vs: typeof vscode, + private readonly server: ColabAssignedServer, + private readonly createClient: LanguageClientFactory, + private readonly createSocket: WebSocketFactory = (url) => + new WebSocket(url), + private readonly createStream: StreamFactory = (socket) => + createWebSocketStream(socket), + ) { + this.languageClient = this.buildClient(); + } + + /** + * Starts the language client if it needs starting. + * + * Cannot be started if {@link ColabLanguageClient.dispose | dispose} has been + * called. + */ + async start(): Promise { + if (!this.languageClient) { + throw new Error("Cannot start after being disposed"); + } + + if (!this.languageClient.needsStart()) { + return; + } + + await this.languageClient.start(); + log.info(`Started a Colab Language Client for ${this.server.label}`); + } + + async dispose(): Promise { + if (!this.languageClient) { + return; + } + await this.languageClient.dispose(); + this.languageClient = undefined; + log.info(`Removed the Colab Language Client for ${this.server.label}`); + } + + private buildClient(): LanguageClient { + const serverOptions = this.getServerOptions(); + const clientOptions = this.getClientOptions(); + + return this.createClient( + "colabLanguageServer", + "Colab Language Server", + serverOptions, + clientOptions, + ); + } + + private getServerOptions(): ServerOptions { + return async () => { + const url = this.buildLanguageServerUrl(); + const socket = this.createSocket(url.toString()); + socket.binaryType = "arraybuffer"; + return this.createSocketConnection(socket); + }; + } + + private getClientOptions(): LanguageClientOptions { + return { + documentSelector: PYTHON_NOTEBOOK, + middleware: { + provideDiagnostics: (d, p, t, n) => { + return filterDiags(this.vs, d, p, t, n); + }, + provideWorkspaceDiagnostics: (r, t, p, n) => { + return filterWorkspaceDiags(this.vs, r, t, p, n); + }, + }, + }; + } + + private buildLanguageServerUrl(): URL { + const c = this.server.connectionInformation; + const url = new URL(c.baseUrl.toString()); + url.protocol = "wss"; + url.pathname = "/colab/lsp"; + url.search = `?colab-runtime-proxy-token=${c.token}`; + return url; + } + + /** + * Creates the websocket connection. Pipes the stream to transform messages to + * the required/expected format and logs relevant events. + */ + private createSocketConnection(socket: WebSocket): Promise { + return new Promise<{ + writer: NodeJS.WritableStream; + reader: NodeJS.ReadableStream; + }>((resolve, reject) => { + socket.onopen = () => { + log.debug("Language server socket opened."); + const stream = this.createStream(socket); + const reader = stream.pipe(new ContentLengthTransformer()); + const writer = stream; + + stream.on("error", (err) => { + log.error("Language server stream error", err); + }); + stream.on("close", () => { + log.debug("Language server stream closed"); + }); + resolve({ + writer, + reader, + }); + }; + socket.onerror = (err) => { + log.error("Language server socket error", err); + const e = + err.error instanceof Error + ? err.error + : new Error(`Socket error: ${err.message}`); + reject(e); + }; + socket.onclose = (event) => { + log.info("Language server socket closed", event); + reject(new Error("Language server socket closed unexpectedly")); + }; + }); + } +} diff --git a/src/lsp/language-client.unit.test.ts b/src/lsp/language-client.unit.test.ts new file mode 100644 index 00000000..8bca3651 --- /dev/null +++ b/src/lsp/language-client.unit.test.ts @@ -0,0 +1,304 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { randomUUID } from "crypto"; +import { Duplex, EventEmitter } from "stream"; +import { assert, expect } from "chai"; +import * as sinon from "sinon"; +import { TextDocument } from "vscode"; +import { + vsdiag, + type LanguageClient, + type LanguageClientOptions, +} from "vscode-languageclient/node"; +import { WebSocket } from "ws"; +import { Variant } from "../colab/api"; +import { + COLAB_CLIENT_AGENT_HEADER, + COLAB_RUNTIME_PROXY_TOKEN_HEADER, +} from "../colab/headers"; +import { LogLevel } from "../common/logging"; +import { ColabAssignedServer } from "../jupyter/servers"; +import { TestCancellationToken } from "../test/helpers/cancellation"; +import { ColabLogWatcher } from "../test/helpers/logging"; +import { TestUri } from "../test/helpers/uri"; +import { newVsCodeStub, VsCodeStub } from "../test/helpers/vscode"; +import { ContentLengthTransformer } from "./content-length-transformer"; +import { ColabLanguageClient, LanguageClientFactory } from "./language-client"; + +const DEFAULT_SERVER: ColabAssignedServer = { + id: randomUUID(), + label: "Colab GPU A100", + variant: Variant.GPU, + accelerator: "A100", + endpoint: "m-s-foo", + connectionInformation: { + baseUrl: TestUri.parse("https://example.com"), + token: "123", + tokenExpiry: new Date(Date.now() + 1000 * 60 * 60), + headers: { + [COLAB_RUNTIME_PROXY_TOKEN_HEADER.key]: "123", + [COLAB_CLIENT_AGENT_HEADER.key]: COLAB_CLIENT_AGENT_HEADER.value, + }, + }, + dateAssigned: new Date(), +}; + +type LanguageClientStub = sinon.SinonStubbedInstance; + +function newLanguageClientStub(): LanguageClientStub { + return { + needsStart: sinon.stub<[], boolean>(), + start: sinon.stub<[], Promise>(), + dispose: sinon.stub<[], Promise>(), + } as unknown as LanguageClientStub; +} + +type WebSocketStub = sinon.SinonStubbedInstance; + +function newWebSocketStub(): WebSocketStub { + const partial = new EventEmitter() as Partial; + partial.binaryType = "arraybuffer"; + return partial as WebSocketStub; +} + +type DuplexStub = sinon.SinonStubbedInstance; + +function newDuplexStub(): DuplexStub { + const stub = sinon.createStubInstance(Duplex); + stub.pipe.returns(stub); + + return stub; +} + +describe("ColabLanguageClient", () => { + let vs: VsCodeStub; + let logs: ColabLogWatcher; + let lsClient: LanguageClientStub; + let socket: WebSocketStub; + let stream: DuplexStub; + let client: ColabLanguageClient; + let factory: sinon.SinonStub< + Parameters, + ReturnType + >; + let createSocket: sinon.SinonStub<[string], WebSocket>; + + beforeEach(() => { + vs = newVsCodeStub(); + logs = new ColabLogWatcher(vs, LogLevel.Error); + lsClient = newLanguageClientStub(); + socket = newWebSocketStub(); + stream = newDuplexStub(); + + factory = sinon.stub(); + factory.returns(lsClient); + + createSocket = sinon.stub<[string], WebSocket>().returns(socket); + + client = new ColabLanguageClient( + vs.asVsCode(), + DEFAULT_SERVER, + factory, + createSocket, + () => stream, + ); + }); + + afterEach(async () => { + await client.dispose(); + logs.dispose(); + sinon.restore(); + }); + + describe("lifecycle", () => { + it("throws when started after being disposed", async () => { + await client.dispose(); + + try { + await client.start(); + expect.fail("Should have thrown"); + } catch (e) { + expect((e as Error).message).to.equal( + "Cannot start after being disposed", + ); + } + }); + + it("disposes the supporting language client", async () => { + await client.dispose(); + + expect(lsClient.dispose.callCount).to.equal(1); + }); + + it("no-ops on repeat dispose calls", async () => { + await client.dispose(); + await client.dispose(); + + expect(lsClient.dispose.callCount).to.equal(1); + }); + }); + + describe("configuration", () => { + let clientOptions: LanguageClientOptions; + + beforeEach(() => { + expect(factory.callCount).to.equal(1); + const call = factory.getCall(0); + clientOptions = call.args[3]; + }); + + it("initializes the client with correct arguments", () => { + const call = factory.getCall(0); + const [id, name, serverOptions] = call.args; + expect(id).to.equal("colabLanguageServer"); + expect(name).to.equal("Colab Language Server"); + expect(serverOptions).to.be.a("function"); + }); + + it("includes the expected document selector", () => { + expect(clientOptions.documentSelector).to.deep.equal([ + { + scheme: "vscode-notebook-cell", + language: "python", + }, + ]); + }); + + it("binds the diagnostics middleware", async () => { + const middleware = clientOptions.middleware; + assert(middleware, "middleware is undefined"); + + const provideDiagnostics = middleware.provideDiagnostics; + assert(provideDiagnostics, "provideDiagnostics is undefined"); + + const docUri = TestUri.parse("file:///test.ipynb"); + const doc = { + uri: docUri, + getText: sinon.stub().returns("!"), + }; + vs.workspace.textDocuments = [ + doc as Pick as TextDocument, + ]; + const token = new TestCancellationToken(new vs.EventEmitter()); + const next = sinon.stub().resolves({ + kind: "full", + items: [ + { + range: { + start: { line: 0, character: 0 }, + end: { line: 0, character: 1 }, + }, + message: "bash command", + }, + ], + }); + + const result = await provideDiagnostics(docUri, undefined, token, next); + + assert( + result?.kind.toString() === "full", + "Expected full diagnostic report", + ); + expect((result as vsdiag.FullDocumentDiagnosticReport).items).to.be.empty; + expect(doc.getText.called).to.be.true; + }); + }); + + describe("connection", () => { + it("connects to the correct URL", async () => { + lsClient.needsStart.returns(true); + await client.start(); + + const call = factory.getCall(0); + const serverOptions = call.args[2]; + const promise = (serverOptions as () => Promise)(); + + expect(createSocket.calledOnce).to.be.true; + const urlString = createSocket.firstCall.args[0]; + const url = new URL(urlString); + + expect(url.protocol).to.equal("wss:"); + expect(url.hostname).to.equal("example.com"); + expect(url.pathname).to.equal("/colab/lsp"); + expect(url.searchParams.get("colab-runtime-proxy-token")).to.equal("123"); + + assert(socket.onopen); + socket.onopen({ type: "open", target: socket }); + await promise; + }); + + it("rejects if socket closes before opening", async () => { + lsClient.needsStart.returns(true); + await client.start(); + + const call = factory.getCall(0); + const serverOptions = call.args[2]; + const promise = (serverOptions as () => Promise)(); + + if (!socket.onclose) { + expect.fail("onclose was not assigned"); + } + socket.onclose({ + code: 1006, // Abnormal closure + reason: "connection refused", + wasClean: false, + type: "close", + target: socket, + }); + + await expect(promise).to.be.rejectedWith( + "Language server socket closed unexpectedly", + ); + }); + }); + + describe("when started", () => { + beforeEach(async () => { + lsClient.needsStart.returns(true); + await client.start(); + const call = factory.getCall(0); + const serverOptions = call.args[2]; + const promise = (serverOptions as () => Promise)(); + assert(socket.onopen); + socket.onopen({ type: "open", target: socket }); + await promise; + }); + + it("pipes the stream with the content-length header", () => { + expect(stream.pipe.callCount).to.equal(1); + const arg = stream.pipe.getCall(0).args[0]; + expect(arg).to.be.instanceOf(ContentLengthTransformer); + }); + + it("logs piped stream errors", () => { + const streamCall = stream.on + .getCalls() + .find((c) => c.args[0] === "error"); + assert(streamCall, "no error listener registered"); + const listener = streamCall.args[1]; + + listener(new Error("stream error")); + + const output = logs.output; + expect(output).to.match(/stream/); + }); + + it("logs socket errors", () => { + if (!socket.onerror) { + expect.fail("onerror was not assigned"); + } + socket.onerror({ + error: new Error("socket error"), + message: "socket error", + type: "error", + target: socket, + }); + const output = logs.output; + expect(output).to.match(/socket/); + }); + }); +}); diff --git a/src/test/helpers/uri.ts b/src/test/helpers/uri.ts index a50ed68f..10a1ddd1 100644 --- a/src/test/helpers/uri.ts +++ b/src/test/helpers/uri.ts @@ -24,7 +24,7 @@ export class TestUri implements vscode.Uri { const url = new URL(stringUri); return new TestUri( url.protocol.replace(/:$/, ""), - url.hostname, + url.hostname + (url.port.length > 0 ? `:${url.port}` : ""), url.pathname, url.search.replace(/^\?/, ""), url.hash.replace(/^#/, ""),