Skip to content

Commit

Permalink
Use Pyodide native impl in module auto load (#1176)
Browse files Browse the repository at this point in the history
  • Loading branch information
whitphx authored Oct 19, 2024
1 parent a6543de commit b3385f2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 116 deletions.
63 changes: 0 additions & 63 deletions packages/kernel/src/module-auto-load.spec.ts

This file was deleted.

59 changes: 6 additions & 53 deletions packages/kernel/src/module-auto-load.ts
Original file line number Diff line number Diff line change
@@ -1,67 +1,20 @@
import type { PyodideInterface } from "pyodide";
import type { PyProxy } from "pyodide/ffi";
import type { ModuleAutoLoadMessage } from "./types";
import type { PostMessageFn } from "./worker-runtime";

let findImportsPyFn: ((source: string) => PyProxy) | undefined;
export function findImports(
pyodide: PyodideInterface,
source: string,
): Set<string> {
if (!findImportsPyFn) {
// Ref: https://github.com/pyodide/pyodide/blob/10b484cfe427e076c929a55dc35cfff01ea8d3bc/src/py/_pyodide/_base.py#L586
const pyCode = `
import ast
from textwrap import dedent
def find_imports(source: str) -> list[str]:
source = dedent(source)
try:
mod = ast.parse(source)
except SyntaxError:
return []
imports = set()
for node in mod.body:
if isinstance(node, ast.Import):
for name in node.names:
node_name = name.name
imports.add(node_name.split(".")[0])
elif isinstance(node, ast.ImportFrom):
module_name = node.module
if module_name is None:
continue
imports.add(module_name.split(".")[0])
return imports
`;
pyodide.runPython(pyCode);
findImportsPyFn = pyodide.globals.get("find_imports") as (
source: string,
) => PyProxy;
}
return findImportsPyFn(source).toJs();
}

export function unionSets<T>(sets: Set<T>[]): Set<T> {
const union = new Set<T>();
for (const set of sets) {
for (const item of set) {
union.add(item);
}
}
return union;
}

export async function tryModuleAutoLoad(
pyodide: PyodideInterface,
postMessage: PostMessageFn,
sources: string[],
): Promise<void> {
// Ref: `pyodide.loadPackagesFromImports` (https://github.com/pyodide/pyodide/blob/0.26.0/src/js/api.ts#L191)

const importsArr = sources.map((source) => findImports(pyodide, source));
const importsSet = unionSets(importsArr);
const imports = Array.from(importsSet);
const pyodidePy = pyodide.pyimport("pyodide");
const findImports = (source: string): string[] =>
pyodidePy.code.find_imports(source).toJs();

const importsArr = sources.map((source) => findImports(source));
const imports = Array.from(new Set(importsArr.flat()));

const notFoundImports = imports.filter(
(name) =>
Expand Down

0 comments on commit b3385f2

Please sign in to comment.