diff --git a/tools/wasm/src/bytes.ts b/tools/wasm/src/bytes.ts index 5cbca5ffdb1f44..a166958ef7ad69 100644 --- a/tools/wasm/src/bytes.ts +++ b/tools/wasm/src/bytes.ts @@ -10,11 +10,12 @@ export function Struct( ) { let size = 0; - const TheStruct = class { + return class { #dv: DataView; constructor(dv: DataView) { this.#dv = dv; } + static { for ( const [key, type] of Object.entries( @@ -33,41 +34,37 @@ export function Struct( size += type.size; } } - } as { new (dv: DataView): T }; - const type: Type = { - get(dv, offset) { + static get(dv: DataView, offset: number) { if (offset !== 0) dv = new DataView(dv.buffer, dv.byteOffset + offset); - return new TheStruct(dv); - }, - set(dv, offset, value) { + return new this(dv); + } + static set(dv: DataView, offset: number, value: T) { if (offset !== 0) dv = new DataView(dv.buffer, dv.byteOffset + offset); - Object.assign(new TheStruct(dv), value); - }, - size, - }; - - return Object.assign(TheStruct, type); + Object.assign(new this(dv), value); + } + static size = size; + } as { new (dv: DataView): T } & Type; } export function FixedArray( - { get, set, size }: Type, + type: Type, length: number, ): Type { return { get(dv, offset) { const arr = Array(length); for (let i = 0; i < length; i++) { - arr[i] = get(dv, offset + size * i); + arr[i] = type.get(dv, offset + type.size * i); } return arr; }, set(dv, offset, value) { for (let i = 0; i < length; i++) { - set(dv, offset + size * i, value[i]!); + type.set(dv, offset + type.size * i, value[i]!); } }, - size: size * length, + size: type.size * length, }; } diff --git a/tools/wasm/src/worker.ts b/tools/wasm/src/worker.ts index 6592697331f442..3317c38589ada8 100644 --- a/tools/wasm/src/worker.ts +++ b/tools/wasm/src/worker.ts @@ -8,6 +8,7 @@ import { type Unwrap, } from "./bytes.ts"; import { type Imports, type Instance, kernel_imports } from "./wasm.ts"; +import { assert } from "./util.ts"; export interface InitMessage { fn: number; @@ -32,21 +33,13 @@ const VIRTQ_DESC_F_WRITE = 1 << 1; const VIRTQ_DESC_F_INDIRECT = 1 << 2; const VIRTQ_DESC_F_AVAIL = 1 << 7; const VIRTQ_DESC_F_USED = 1 << 15; -// const VirtqDescriptor = Struct({ -// addr: U64LE, -// len: U32LE, -// flags: U16LE, -// next: U16LE, -// }); -// type VirtqDescriptor = Unwrap; - -const VirtqDescriptor = Struct({ + +class VirtqDescriptor extends Struct({ addr: U64LE, len: U32LE, id: U16LE, flags: U16LE, -}); -type VirtqDescriptor = Unwrap; +}) {} const VIRTQ_AVAIL_F_NO_INTERRUPT = 1; const VirtqAvail = (size: number) => @@ -72,76 +65,132 @@ const VirtqUsed = (size: number) => }); type VirtqUsed = Unwrap>; +class Chain { + #mem: DataView; + #queue: Virtqueue; + id: number; + skip: number; + desc: VirtqDescriptor[]; + + constructor( + mem: DataView, + queue: Virtqueue, + id: number, + skip: number, + desc: VirtqDescriptor[], + ) { + this.#mem = mem; + this.#queue = queue; + this.id = id; + this.skip = skip; + this.desc = desc; + } + + release(written: number) { + const queue = this.#queue; + const desc = queue.desc[queue.used_idx]; + assert(desc); + const avail = (desc.flags & VIRTQ_DESC_F_AVAIL) !== 0; + const used = (desc.flags & VIRTQ_DESC_F_USED) !== 0; + if (avail === used || avail !== queue.wrap) throw new Error("ring full"); + + let flags = 0; + if (queue.wrap) flags |= VIRTQ_DESC_F_AVAIL | VIRTQ_DESC_F_USED; + if (written > 0) flags |= VIRTQ_DESC_F_WRITE; + + desc.id = this.id; + desc.len = written; + desc.flags = flags; + + queue.used_idx += this.skip; + if (queue.used_idx >= queue.size) { + queue.used_idx -= queue.size; + queue.wrap = !queue.wrap; + } + } + + *[Symbol.iterator]() { + for (const desc of this.desc) { + console.log(desc.addr, desc.flags, desc.id, desc.len); + yield new Uint8Array(this.#mem.buffer, Number(desc.addr), desc.len); + } + } +} + class Virtqueue { - #dv: DataView; + #mem: DataView; - // ring stuff: size: number; desc: VirtqDescriptor[]; used: VirtqUsed; avail: VirtqAvail; - - // queue stuff: - count_used = 0; - used_wrap_count = true; - last_avail = 0; + wrap = true; + used_idx = 0; + avail_idx = 0; constructor( - dv: DataView, + mem: DataView, size: number, desc_addr: number, used_addr: number, avail_addr: number, ) { - this.#dv = dv; + assert(size !== 0); + assert(mem.byteOffset === 0); + this.#mem = mem; this.size = size; - this.desc = FixedArray(VirtqDescriptor, size).get(dv, desc_addr); - this.used = VirtqUsed(size).get(dv, used_addr); - this.avail = VirtqAvail(size).get(dv, avail_addr); + this.desc = FixedArray(VirtqDescriptor, size).get(mem, desc_addr); + this.used = VirtqUsed(size).get(mem, used_addr); + this.avail = VirtqAvail(size).get(mem, avail_addr); } - take() { - if (this.count_used >= this.size) { - throw new Error("Virtqueue size exceeded"); - } - - const last_seen = this.desc[this.last_avail]!; - console.log("flags:", last_seen.flags.toString(16)); - let id; - let count = 0; - const descs = []; - if (last_seen.flags & VIRTQ_DESC_F_INDIRECT) { - ({ id } = last_seen); - const max = last_seen.len / VirtqDescriptor.size; - for (let i = 0; i < max; i++) descs.push(this.desc[i]!); - count = 1; - } else { - for ( - let i = 0; - this.desc[i] !== undefined && this.desc[i]!.flags & VIRTQ_DESC_F_NEXT; - i = (i + 1) % this.size - ) { - if (++count > this.size) throw new Error("looped"); - descs.push(this.desc[i]!); + pop() { + let i = this.#advance(); + if (i === null) return null; + const head = i; + + let desc = this.desc[i]; + assert(desc); + + const chain = new Chain( + this.#mem, + this, + desc.id, + 1, + this.desc.slice(head, i + 1), + ); + + if (desc.flags & VIRTQ_DESC_F_NEXT) { + do { + i = this.#advance(); + if (i === null) throw new Error("no next descriptor is available"); + desc = this.desc[i]; + assert(desc); + } while (desc.flags & VIRTQ_DESC_F_NEXT); + chain.skip = i - head + 1; + chain.desc = this.desc.slice(head, i + 1); + } else if (desc.flags & VIRTQ_DESC_F_INDIRECT) { + if (desc.len % VirtqDescriptor.size !== 0) { + throw new Error("malformed indirect buffer"); } - if (descs.length) ({ id } = descs.at(-1)!); + chain.desc = FixedArray(VirtqDescriptor, desc.len / VirtqDescriptor.size) + .get(this.#mem, Number(desc.addr)); } - const readable = []; - const writable = []; - for (const desc of descs) { - if (desc.flags & VIRTQ_DESC_F_WRITE) writable.push(desc); - else readable.push(desc); - } + return chain; + } - this.count_used += count; - this.last_avail += count; - if (this.last_avail > this.size) { - this.last_avail -= this.size; - this.used_wrap_count = !this.used_wrap_count; - } + #advance() { + const desc = this.desc[this.avail_idx]; + assert(desc); + + const avail = (desc.flags & VIRTQ_DESC_F_AVAIL) !== 0; + const used = (desc.flags & VIRTQ_DESC_F_USED) !== 0; + if (avail === used || avail !== this.wrap) return null; - return { id, count, readable, writable }; + const index = this.avail_idx; + this.avail_idx = (this.avail_idx + 1) % this.size; + return index; } } @@ -153,8 +202,7 @@ abstract class VirtioDevice { config = new Uint8Array(0); features = VIRTIO_F_VERSION_1 | VIRTIO_F_RING_PACKED | - VIRTIO_F_EVENT_IDX | - VIRTIO_F_INDIRECT_DESC; + VIRTIO_F_EVENT_IDX | VIRTIO_F_INDIRECT_DESC; trigger_interrupt = (kind: "config" | "vring"): void => { // this function is overwritten on device setup @@ -185,9 +233,14 @@ class EntropyDevice extends VirtioDevice { const queue = this.vqs[vq]!; console.log("notify", vq, queue); - const elem = queue.take(); - console.log(elem); - // TODO fill with bytes, enqueue buffer + const chain = queue.pop(); + assert(chain); + let n = 0; + for (const buf of chain) { + crypto.getRandomValues(buf); + n += buf.byteLength; + } + chain.release(n); this.trigger_interrupt("vring"); }