From 123cd85d8c0a5d42bc4a5ed2ba50e4cd2eabdb51 Mon Sep 17 00:00:00 2001 From: Patrick Detlefsen Date: Sat, 28 Feb 2026 22:32:05 +0100 Subject: [PATCH 1/4] refactor: add Environment behaviour, SSH backend, and inline SpriteLifecycle - Add Jido.Shell.Environment behaviour (provision/teardown/status) for pluggable infrastructure backends - Add Jido.Shell.Environment.Sprite implementing the behaviour - Inline SpriteLifecycle logic into Environment.Sprite and delete the standalone module - Add SSH backend (Jido.Shell.Backend.SSH) with reconnection support - Extract output limit logic to Jido.Shell.Backend.OutputLimiter - Add :ssh and :public_key to extra_applications - Exclude :ssh_integration tag in test_helper --- lib/jido_shell/backend/output_limiter.ex | 32 ++ lib/jido_shell/backend/sprite.ex | 39 +- lib/jido_shell/backend/ssh.ex | 433 ++++++++++++++++++ lib/jido_shell/backend/ssh/key_callback.ex | 102 +++++ lib/jido_shell/environment.ex | 87 ++++ .../sprite.ex} | 24 +- mix.exs | 9 +- test/jido/shell/backend/ssh_test.exs | 339 ++++++++++++++ test/jido/shell/environment_test.exs | 43 ++ test/test_helper.exs | 2 +- 10 files changed, 1075 insertions(+), 35 deletions(-) create mode 100644 lib/jido_shell/backend/output_limiter.ex create mode 100644 lib/jido_shell/backend/ssh.ex create mode 100644 lib/jido_shell/backend/ssh/key_callback.ex create mode 100644 lib/jido_shell/environment.ex rename lib/jido_shell/{sprite_lifecycle.ex => environment/sprite.ex} (92%) create mode 100644 test/jido/shell/backend/ssh_test.exs create mode 100644 test/jido/shell/environment_test.exs diff --git a/lib/jido_shell/backend/output_limiter.ex b/lib/jido_shell/backend/output_limiter.ex new file mode 100644 index 0000000..a5fb244 --- /dev/null +++ b/lib/jido_shell/backend/output_limiter.ex @@ -0,0 +1,32 @@ +defmodule Jido.Shell.Backend.OutputLimiter do + @moduledoc """ + Shared output-limit logic for SSH and Sprite backends. + + Tracks emitted bytes and aborts when the configured limit is exceeded. + """ + + alias Jido.Shell.Error + + @doc """ + Check whether emitting `chunk_bytes` would exceed `output_limit`. + + Returns `{:ok, updated_bytes}` when under the limit, or + `{:limit_exceeded, %Jido.Shell.Error{}}` when the limit is breached. + """ + @spec check(non_neg_integer(), non_neg_integer(), non_neg_integer() | nil) :: + {:ok, non_neg_integer()} | {:limit_exceeded, Error.t()} + def check(chunk_bytes, emitted_bytes, output_limit) + when is_integer(chunk_bytes) and is_integer(emitted_bytes) do + updated_total = emitted_bytes + chunk_bytes + + if is_integer(output_limit) and output_limit > 0 and updated_total > output_limit do + {:limit_exceeded, + Error.command(:output_limit_exceeded, %{ + emitted_bytes: updated_total, + max_output_bytes: output_limit + })} + else + {:ok, updated_total} + end + end +end diff --git a/lib/jido_shell/backend/sprite.ex b/lib/jido_shell/backend/sprite.ex index d78bd3f..892951a 100644 --- a/lib/jido_shell/backend/sprite.ex +++ b/lib/jido_shell/backend/sprite.ex @@ -11,6 +11,7 @@ defmodule Jido.Shell.Backend.Sprite do @behaviour Jido.Shell.Backend + alias Jido.Shell.Backend.OutputLimiter alias Jido.Shell.Error @default_task_supervisor Jido.Shell.CommandTaskSupervisor @@ -359,22 +360,15 @@ defmodule Jido.Shell.Backend.Sprite do defp emit_stream_chunk(state, cmd_ref, data, output_limit, emitted_bytes) do chunk = IO.iodata_to_binary(data) - chunk_bytes = byte_size(chunk) - updated_total = emitted_bytes + chunk_bytes - cond do - is_integer(output_limit) and output_limit > 0 and updated_total > output_limit -> - _ = close_remote_handle(state, cmd_ref) - - {:error, - Error.command(:output_limit_exceeded, %{ - emitted_bytes: updated_total, - max_output_bytes: output_limit - })} - - true -> + case OutputLimiter.check(byte_size(chunk), emitted_bytes, output_limit) do + {:ok, updated_total} -> send(state.session_pid, {:command_event, {:output, chunk}}) {:ok, updated_total} + + {:limit_exceeded, error} -> + _ = close_remote_handle(state, cmd_ref) + {:error, error} end end @@ -383,17 +377,14 @@ defmodule Jido.Shell.Backend.Sprite do defp maybe_emit_output(session_pid, output, output_limit) do chunk = IO.iodata_to_binary(output) - chunk_bytes = byte_size(chunk) - - if is_integer(output_limit) and output_limit > 0 and chunk_bytes > output_limit do - {:error, - Error.command(:output_limit_exceeded, %{ - emitted_bytes: chunk_bytes, - max_output_bytes: output_limit - })} - else - send(session_pid, {:command_event, {:output, chunk}}) - :ok + + case OutputLimiter.check(byte_size(chunk), 0, output_limit) do + {:ok, _} -> + send(session_pid, {:command_event, {:output, chunk}}) + :ok + + {:limit_exceeded, error} -> + {:error, error} end end diff --git a/lib/jido_shell/backend/ssh.ex b/lib/jido_shell/backend/ssh.ex new file mode 100644 index 0000000..6845536 --- /dev/null +++ b/lib/jido_shell/backend/ssh.ex @@ -0,0 +1,433 @@ +defmodule Jido.Shell.Backend.SSH do + @moduledoc """ + SSH backend implementation for remote command execution on any SSH-accessible machine. + + Uses Erlang's built-in `:ssh` module — zero additional dependencies. + + This backend keeps the same session event contract as other backends: + + - `{:output, chunk}` + - `:command_done` + - `{:error, %Jido.Shell.Error{}}` + + ## Configuration + + %{ + session_pid: pid(), # required (injected by ShellSessionServer) + host: String.t(), # required + port: pos_integer(), # default 22 + user: String.t(), # required + key: binary(), # raw PEM content, OR + key_path: String.t(), # path to key file, OR + password: String.t(), # password auth + cwd: String.t(), # default "/" + env: map(), # default %{} + shell: String.t(), # default "sh" + connect_timeout: pos_integer(), # default 10_000 + ssh_module: module(), # default :ssh (for testing) + ssh_connection_module: module() # default :ssh_connection (for testing) + } + + """ + + @behaviour Jido.Shell.Backend + + alias Jido.Shell.Backend.OutputLimiter + alias Jido.Shell.Error + + @default_task_supervisor Jido.Shell.CommandTaskSupervisor + @default_shell "sh" + @default_port 22 + @default_connect_timeout 10_000 + + @impl true + def init(config) when is_map(config) do + ssh_mod = Map.get(config, :ssh_module, :ssh) + ssh_conn_mod = Map.get(config, :ssh_connection_module, :ssh_connection) + + with {:ok, session_pid} <- fetch_session_pid(config), + {:ok, host} <- fetch_required_string(config, :host), + {:ok, user} <- fetch_required_string(config, :user), + {:ok, auth_opts} <- build_auth_opts(config), + port = Map.get(config, :port, @default_port), + {:ok, conn} <- connect(ssh_mod, host, port, user, auth_opts, config) do + {:ok, + %{ + session_pid: session_pid, + task_supervisor: Map.get(config, :task_supervisor, @default_task_supervisor), + conn: conn, + host: host, + port: port, + user: user, + cwd: Map.get(config, :cwd, "/"), + env: normalize_env(Map.get(config, :env, %{})), + shell: Map.get(config, :shell, @default_shell), + commands_table: :ets.new(:jido_shell_ssh_commands, [:public, :set]), + ssh_module: ssh_mod, + ssh_connection_module: ssh_conn_mod, + connect_params: %{ + host: host, + port: port, + user: user, + auth_opts: auth_opts, + config: config + } + }} + end + end + + @impl true + def execute(state, command, args, exec_opts) when is_binary(command) and is_list(args) and is_list(exec_opts) do + with {:ok, state} <- ensure_connected(state) do + line = command_line(command, args) + cwd = Keyword.get(exec_opts, :dir, state.cwd) + env = Keyword.get(exec_opts, :env, state.env) |> normalize_env() + + timeout = + Keyword.get(exec_opts, :timeout) || + extract_limit(Keyword.get(exec_opts, :execution_context, %{}), :max_runtime_ms) + + output_limit = + Keyword.get(exec_opts, :output_limit) || + extract_limit(Keyword.get(exec_opts, :execution_context, %{}), :max_output_bytes) + + case start_worker(state, line, cwd, env, timeout, output_limit) do + {:ok, worker_pid} -> + {:ok, worker_pid, %{state | cwd: cwd, env: env}} + + {:error, _} = error -> + error + end + end + end + + @impl true + def cancel(state, command_ref) when is_pid(command_ref) do + close_channel(state, command_ref) + + if Process.alive?(command_ref) do + Process.exit(command_ref, :shutdown) + end + + :ok + end + + def cancel(_state, _command_ref), do: {:error, :invalid_command_ref} + + @impl true + def terminate(state) do + _ = safe_close_connection(state.ssh_module, state.conn) + _ = maybe_delete_table(state.commands_table) + :ok + end + + @impl true + def cwd(state), do: {:ok, state.cwd, state} + + @impl true + def cd(state, path) when is_binary(path), do: {:ok, %{state | cwd: path}} + + # -- Private: Connection --------------------------------------------------- + + defp connect(ssh_mod, host, port, user, auth_opts, config) do + timeout = Map.get(config, :connect_timeout, @default_connect_timeout) + + ssh_opts = + [ + {:user, String.to_charlist(user)}, + {:silently_accept_hosts, true}, + {:user_interaction, false} + | auth_opts + ] + + case ssh_mod.connect(String.to_charlist(host), port, ssh_opts, timeout) do + {:ok, conn} -> + {:ok, conn} + + {:error, reason} -> + {:error, Error.command(:start_failed, %{reason: {:ssh_connect, reason}, host: host, port: port})} + end + end + + defp ensure_connected(state) do + if connection_alive?(state) do + {:ok, state} + else + reconnect(state) + end + end + + defp connection_alive?(state) do + Process.alive?(state.conn) + rescue + _ -> false + catch + _, _ -> false + end + + defp reconnect(state) do + %{host: host, port: port, user: user, auth_opts: auth_opts, config: config} = + state.connect_params + + case connect(state.ssh_module, host, port, user, auth_opts, config) do + {:ok, conn} -> + {:ok, %{state | conn: conn}} + + {:error, _} = error -> + error + end + end + + defp build_auth_opts(config) do + cond do + is_binary(Map.get(config, :key)) -> + {:ok, [{:key_cb, {Jido.Shell.Backend.SSH.KeyCallback, key: Map.get(config, :key)}}]} + + is_binary(Map.get(config, :key_path)) -> + path = Path.expand(Map.get(config, :key_path)) + + case File.read(path) do + {:ok, pem} -> + {:ok, [{:key_cb, {Jido.Shell.Backend.SSH.KeyCallback, key: pem}}]} + + {:error, reason} -> + {:error, Error.command(:start_failed, %{reason: {:key_read_failed, reason}, path: path})} + end + + is_binary(Map.get(config, :password)) -> + {:ok, [{:password, String.to_charlist(Map.get(config, :password))}]} + + true -> + # Fall back to default SSH key discovery by the :ssh app + {:ok, []} + end + end + + defp safe_close_connection(ssh_mod, conn) do + ssh_mod.close(conn) + catch + _, _ -> :ok + end + + # -- Private: Worker ------------------------------------------------------- + + defp start_worker(state, line, cwd, env, timeout, output_limit) do + case Task.Supervisor.start_child(state.task_supervisor, fn -> + run_command_worker(state, line, cwd, env, timeout, output_limit) + end) do + {:ok, worker_pid} -> {:ok, worker_pid} + {:error, _} = error -> error + end + end + + defp run_command_worker(state, line, cwd, env, timeout, output_limit) do + ssh_conn_mod = state.ssh_connection_module + + case open_channel_and_exec(state.conn, ssh_conn_mod, line, cwd, env, state.shell) do + {:ok, channel_id} -> + :ets.insert(state.commands_table, {self(), channel_id}) + await_ssh_events(state, channel_id, line, timeout, output_limit, 0) + + {:error, reason} -> + send_finished(state.session_pid, {:error, Error.command(:start_failed, %{reason: reason, line: line})}) + end + + :ets.delete(state.commands_table, self()) + rescue + error -> + send_finished( + state.session_pid, + {:error, Error.command(:crashed, %{line: line, reason: Exception.message(error)})} + ) + end + + defp open_channel_and_exec(conn, ssh_conn_mod, line, cwd, env, shell) do + case ssh_conn_mod.session_channel(conn, :infinity) do + {:ok, channel_id} -> + # Set environment variables (best effort — many SSH servers restrict this) + Enum.each(env, fn {k, v} -> + ssh_conn_mod.setenv(conn, channel_id, String.to_charlist(k), String.to_charlist(v), 5_000) + end) + + wrapped = remote_command(shell, line, cwd, env) + + case ssh_conn_mod.exec(conn, channel_id, String.to_charlist(wrapped), :infinity) do + :success -> + {:ok, channel_id} + + :failure -> + ssh_conn_mod.close(conn, channel_id) + {:error, :exec_failed} + + {:error, reason} -> + ssh_conn_mod.close(conn, channel_id) + {:error, reason} + end + + {:error, reason} -> + {:error, {:channel_open_failed, reason}} + end + end + + defp await_ssh_events(state, channel_id, line, timeout, output_limit, emitted_bytes) do + ssh_conn_mod = state.ssh_connection_module + + receive do + {:ssh_cm, _conn, {:data, ^channel_id, _type, data}} -> + chunk = IO.iodata_to_binary(data) + + case OutputLimiter.check(byte_size(chunk), emitted_bytes, output_limit) do + {:ok, updated_total} -> + send(state.session_pid, {:command_event, {:output, chunk}}) + await_ssh_events(state, channel_id, line, timeout, output_limit, updated_total) + + {:limit_exceeded, error} -> + ssh_conn_mod.close(state.conn, channel_id) + send_finished(state.session_pid, {:error, error}) + end + + {:ssh_cm, _conn, {:exit_status, ^channel_id, 0}} -> + await_ssh_events(state, channel_id, line, timeout, output_limit, emitted_bytes) + + {:ssh_cm, _conn, {:exit_status, ^channel_id, code}} -> + # Don't send finished yet — wait for :closed or :eof to ensure all data is flushed + await_ssh_close(state, channel_id, line, code) + + {:ssh_cm, _conn, {:eof, ^channel_id}} -> + await_ssh_events(state, channel_id, line, timeout, output_limit, emitted_bytes) + + {:ssh_cm, _conn, {:closed, ^channel_id}} -> + send_finished(state.session_pid, {:ok, nil}) + after + receive_timeout(timeout) -> + ssh_conn_mod.close(state.conn, channel_id) + send_finished(state.session_pid, {:error, Error.command(:timeout, %{line: line})}) + end + end + + # After we've received a non-zero exit_status, drain remaining data/eof/closed messages + defp await_ssh_close(state, channel_id, line, exit_code) do + receive do + {:ssh_cm, _conn, {:data, ^channel_id, _type, data}} -> + chunk = IO.iodata_to_binary(data) + send(state.session_pid, {:command_event, {:output, chunk}}) + await_ssh_close(state, channel_id, line, exit_code) + + {:ssh_cm, _conn, {:eof, ^channel_id}} -> + await_ssh_close(state, channel_id, line, exit_code) + + {:ssh_cm, _conn, {:closed, ^channel_id}} -> + send_finished(state.session_pid, {:error, Error.command(:exit_code, %{code: exit_code, line: line})}) + after + 5_000 -> + send_finished(state.session_pid, {:error, Error.command(:exit_code, %{code: exit_code, line: line})}) + end + end + + # -- Private: Channel cancellation ----------------------------------------- + + defp close_channel(state, worker_pid) do + do_close_channel(state, worker_pid, 5) + rescue + _ -> :ok + end + + defp do_close_channel(_state, _worker_pid, 0), do: :ok + + defp do_close_channel(state, worker_pid, attempts_left) do + case :ets.lookup(state.commands_table, worker_pid) do + [{^worker_pid, channel_id}] -> + state.ssh_connection_module.close(state.conn, channel_id) + + _ -> + Process.sleep(10) + do_close_channel(state, worker_pid, attempts_left - 1) + end + end + + # -- Private: Helpers ------------------------------------------------------ + + defp command_line(command, []), do: command + defp command_line(command, args), do: Enum.join([command | args], " ") + + defp remote_command(shell, line, cwd, env) do + env_prefix = + env + |> Enum.map(fn {k, v} -> "#{k}=#{shell_escape(v)}" end) + |> Enum.join(" ") + + case env_prefix do + "" -> "cd #{shell_escape(cwd)} && #{shell} -lc #{shell_escape(line)}" + prefix -> "cd #{shell_escape(cwd)} && env #{prefix} #{shell} -lc #{shell_escape(line)}" + end + end + + defp shell_escape(value) do + # Use single-quote wrapping with internal single-quote escaping + "'" <> String.replace(to_string(value), "'", "'\\''") <> "'" + end + + defp send_finished(session_pid, result) do + send(session_pid, {:command_finished, result}) + end + + defp receive_timeout(timeout) when is_integer(timeout) and timeout > 0, do: timeout + defp receive_timeout(_timeout), do: 60_000 + + defp normalize_env(env) when is_map(env) do + Enum.reduce(env, %{}, fn {key, value}, acc -> + Map.put(acc, to_string(key), to_string(value)) + end) + end + + defp normalize_env(_env), do: %{} + + defp extract_limit(execution_context, key) when is_map(execution_context) do + limits = Map.get(execution_context, :limits, %{}) + + parse_limit( + Map.get(limits, key, Map.get(execution_context, key, nil)) + ) + end + + defp extract_limit(_, _), do: nil + + defp parse_limit(value) when is_integer(value) and value > 0, do: value + + defp parse_limit(value) when is_binary(value) do + case Integer.parse(value) do + {parsed, ""} when parsed > 0 -> parsed + _ -> nil + end + end + + defp parse_limit(_), do: nil + + defp maybe_delete_table(table) do + :ets.delete(table) + :ok + rescue + _ -> :ok + end + + defp fetch_session_pid(config) do + case Map.get(config, :session_pid) do + pid when is_pid(pid) -> {:ok, pid} + _ -> {:error, Error.session(:invalid_state_transition, %{reason: :missing_session_pid})} + end + end + + defp fetch_required_string(config, key) do + case Map.get(config, key) do + value when is_binary(value) -> + if byte_size(String.trim(value)) > 0 do + {:ok, value} + else + {:error, Error.command(:start_failed, %{reason: {:missing_config, key}})} + end + + _ -> + {:error, Error.command(:start_failed, %{reason: {:missing_config, key}})} + end + end +end diff --git a/lib/jido_shell/backend/ssh/key_callback.ex b/lib/jido_shell/backend/ssh/key_callback.ex new file mode 100644 index 0000000..41a7fb2 --- /dev/null +++ b/lib/jido_shell/backend/ssh/key_callback.ex @@ -0,0 +1,102 @@ +defmodule Jido.Shell.Backend.SSH.KeyCallback do + @moduledoc false + + # Custom SSH key callback for injecting PEM key content directly + # rather than reading from the default ~/.ssh/ directory. + # + # Supports standard PEM formats (RSA, ECDSA) and OpenSSH-format + # Ed25519 keys (-----BEGIN OPENSSH PRIVATE KEY-----). + + @behaviour :ssh_client_key_api + + @impl true + def is_host_key(_key, _host, _port, _algorithm, _opts) do + true + end + + @impl true + def user_key(algorithm, opts) do + # OTP 23+ passes key_cb tuple options under :key_cb_private + key_pem = get_key_from_opts(opts) + + case :public_key.pem_decode(key_pem) do + [] -> + {:error, :no_keys_found} + + [{{:no_asn1, :new_openssh}, _der, :not_encrypted} | _] -> + # OpenSSH-format key (e.g. Ed25519) — use :ssh_file to decode + decode_openssh_key(key_pem, algorithm) + + entries -> + find_key_for_algorithm(entries, algorithm) + end + end + + defp decode_openssh_key(key_pem, algorithm) do + case :ssh_file.decode(key_pem, :openssh_key_v1) do + [{key, _attrs} | _] when is_tuple(key) -> + if key_matches_algorithm?(key, algorithm) do + {:ok, key} + else + {:error, :no_matching_key} + end + + _ -> + {:error, :openssh_decode_failed} + end + rescue + _ -> {:error, :openssh_decode_failed} + end + + defp find_key_for_algorithm(entries, algorithm) do + Enum.find_value(entries, {:error, :no_matching_key}, fn entry -> + case :public_key.pem_entry_decode(entry) do + key when is_tuple(key) -> + if key_matches_algorithm?(key, algorithm) do + {:ok, key} + else + nil + end + + _ -> + nil + end + end) + rescue + _ -> {:error, :key_decode_failed} + end + + defp key_matches_algorithm?(key, algorithm) do + case {elem(key, 0), algorithm} do + {:RSAPrivateKey, :"ssh-rsa"} -> true + {:ECPrivateKey, :"ecdsa-sha2-nistp256"} -> true + {:ECPrivateKey, :"ecdsa-sha2-nistp384"} -> true + {:ECPrivateKey, :"ecdsa-sha2-nistp521"} -> true + # OTP wraps Ed25519 as ECPrivateKey with namedCurve {1,3,101,112} + {:ECPrivateKey, :"ssh-ed25519"} -> ed25519_curve?(key) + {:ECPrivateKey, :"ssh-ed448"} -> ed448_curve?(key) + # OTP may also use these representations + _ when algorithm in [:"ssh-ed25519", :"ssh-ed448"] -> is_ed_key?(key) + _ -> false + end + end + + defp ed25519_curve?({:ECPrivateKey, _, _, {:namedCurve, {1, 3, 101, 112}}, _, _}), do: true + defp ed25519_curve?({:ECPrivateKey, _, _, {:namedCurve, {1, 3, 101, 112}}, _}), do: true + defp ed25519_curve?(_), do: false + + defp ed448_curve?({:ECPrivateKey, _, _, {:namedCurve, {1, 3, 101, 113}}, _, _}), do: true + defp ed448_curve?({:ECPrivateKey, _, _, {:namedCurve, {1, 3, 101, 113}}, _}), do: true + defp ed448_curve?(_), do: false + + defp is_ed_key?(key) when is_map(key), do: true + defp is_ed_key?({:ed_pri, _, _, _}), do: true + defp is_ed_key?(_), do: false + + defp get_key_from_opts(opts) do + case opts[:key_cb_private] do + private when is_list(private) -> Keyword.get(private, :key) + _ -> nil + end || opts[:key] + end +end diff --git a/lib/jido_shell/environment.ex b/lib/jido_shell/environment.ex new file mode 100644 index 0000000..64b1aa3 --- /dev/null +++ b/lib/jido_shell/environment.ex @@ -0,0 +1,87 @@ +defmodule Jido.Shell.Environment do + @moduledoc """ + Behaviour for VM/infrastructure lifecycle management. + + An Environment handles provisioning and tearing down the infrastructure + that a shell session runs on. After provisioning, the environment starts + a shell session using the appropriate `Jido.Shell.Backend`. + + ## Two Concerns + + - **Backend** = command execution on a running machine (`Jido.Shell.Backend`) + - **Environment** = VM lifecycle: provision, teardown, status (this behaviour) + + ## Implementations + + - `Jido.Shell.Environment.Sprite` — Fly.io Sprites + - External packages can implement this for Hetzner, Scaleway, AWS, etc. + + ## Example + + defmodule MyApp.Environment.Hetzner do + @behaviour Jido.Shell.Environment + + @impl true + def provision(workspace_id, config, opts) do + # 1. Create Hetzner VM via API + # 2. Wait for ready, get IP + # 3. Start session with Backend.SSH + session_opts = [ + backend: {Jido.Shell.Backend.SSH, %{host: ip, user: "root", key: config.ssh_key}} + ] + {:ok, session_id} = Jido.Shell.ShellSession.start_with_vfs(workspace_id, session_opts) + {:ok, %{session_id: session_id, workspace_dir: "/work", workspace_id: workspace_id}} + end + + @impl true + def teardown(session_id, _opts), do: %{teardown_verified: true, teardown_attempts: 1, warnings: nil} + end + + """ + + @typedoc """ + Result of a successful provision operation. + + Must include at minimum `session_id`, `workspace_dir`, and `workspace_id`. + Implementations may add environment-specific metadata (e.g., `server_id`, `ip`). + """ + @type provision_result :: %{ + :session_id => String.t(), + :workspace_dir => String.t(), + :workspace_id => String.t(), + optional(:sprite_name) => String.t(), + optional(atom()) => term() + } + + @typedoc "Result of a teardown operation." + @type teardown_result :: %{ + teardown_verified: boolean(), + teardown_attempts: pos_integer(), + warnings: [String.t()] | nil + } + + @doc """ + Provision infrastructure and start a shell session. + + Returns metadata including at minimum `session_id`, `workspace_dir`, and `workspace_id`. + """ + @callback provision(workspace_id :: String.t(), config :: map(), opts :: keyword()) :: + {:ok, provision_result()} | {:error, term()} + + @doc """ + Tear down infrastructure and stop the session. + + Returns teardown metadata with verification status. + """ + @callback teardown(session_id :: String.t(), opts :: keyword()) :: teardown_result() + + @doc """ + Query the status of a provisioned environment. + + Optional — not all environments support status queries. + """ + @callback status(session_id :: String.t(), opts :: keyword()) :: + {:ok, map()} | {:error, term()} + + @optional_callbacks status: 2 +end diff --git a/lib/jido_shell/sprite_lifecycle.ex b/lib/jido_shell/environment/sprite.ex similarity index 92% rename from lib/jido_shell/sprite_lifecycle.ex rename to lib/jido_shell/environment/sprite.ex index a8d1e26..24d139a 100644 --- a/lib/jido_shell/sprite_lifecycle.ex +++ b/lib/jido_shell/environment/sprite.ex @@ -1,8 +1,14 @@ -defmodule Jido.Shell.SpriteLifecycle do +defmodule Jido.Shell.Environment.Sprite do @moduledoc """ - Generic helpers for provisioning and tearing down Sprite-backed sessions. + Fly.io Sprite environment implementation. + + Provisions Sprite-backed shell sessions and handles teardown with + retry-based verification. This is the default environment used by + `Jido.Harness.Exec.Workspace`. """ + @behaviour Jido.Shell.Environment + alias Jido.Shell.Exec @default_retry_backoffs_ms [0, 1_000, 3_000] @@ -20,9 +26,10 @@ defmodule Jido.Shell.SpriteLifecycle do warnings: [String.t()] | nil } + @impl true @spec provision(String.t(), map(), keyword()) :: {:ok, provision_result()} | {:error, term()} - def provision(workspace_id, sprite_config, opts \\ []) - when is_binary(workspace_id) and is_map(sprite_config) do + def provision(workspace_id, config, opts \\ []) + when is_binary(workspace_id) and is_map(config) do workspace_base = Keyword.get(opts, :workspace_base, "/work") workspace_dir = Keyword.get(opts, :workspace_dir, "#{workspace_base}/#{workspace_id}") sprite_name = Keyword.get(opts, :sprite_name, workspace_id) @@ -33,14 +40,14 @@ defmodule Jido.Shell.SpriteLifecycle do backend_config = %{ sprite_name: sprite_name, - token: config_get(sprite_config, :token), - create: config_get(sprite_config, :create, true) + token: config_get(config, :token), + create: config_get(config, :create, true) } - |> maybe_put_base_url(config_get(sprite_config, :base_url)) + |> maybe_put_base_url(config_get(config, :base_url)) session_opts = [ backend: {Jido.Shell.Backend.Sprite, backend_config}, - env: config_get(sprite_config, :env, %{}) + env: config_get(config, :env, %{}) ] with {:ok, session_id} <- session_mod.start_with_vfs(workspace_id, session_opts), @@ -55,6 +62,7 @@ defmodule Jido.Shell.SpriteLifecycle do end end + @impl true @spec teardown(String.t(), keyword()) :: teardown_result() def teardown(session_id, opts \\ []) when is_binary(session_id) do sprite_name = Keyword.get(opts, :sprite_name) diff --git a/mix.exs b/mix.exs index 8ebf01e..e953fcf 100644 --- a/mix.exs +++ b/mix.exs @@ -58,7 +58,7 @@ defmodule Jido.Shell.MixProject do def application do [ - extra_applications: [:logger], + extra_applications: [:logger, :ssh, :public_key], mod: {Jido.Shell.Application, []} ] end @@ -144,7 +144,12 @@ defmodule Jido.Shell.MixProject do ], Backends: [ Jido.Shell.Backend.Local, - Jido.Shell.Backend.Sprite + Jido.Shell.Backend.Sprite, + Jido.Shell.Backend.SSH + ], + Environments: [ + Jido.Shell.Environment, + Jido.Shell.Environment.Sprite ], Commands: ~r/Jido\.Shell\.Command.*/, "Virtual Filesystem": [ diff --git a/test/jido/shell/backend/ssh_test.exs b/test/jido/shell/backend/ssh_test.exs new file mode 100644 index 0000000..3e92ffe --- /dev/null +++ b/test/jido/shell/backend/ssh_test.exs @@ -0,0 +1,339 @@ +defmodule Jido.Shell.Backend.SSHTest do + use Jido.Shell.Case, async: false + + alias Jido.Shell.Backend.SSH + + # --------------------------------------------------------------------------- + # FakeSSH — mimics Erlang's :ssh and :ssh_connection modules for unit testing. + # + # Injected via :ssh_module and :ssh_connection_module config keys so we test + # the real Backend.SSH code path without a real SSH server. + # --------------------------------------------------------------------------- + + defmodule FakeSSH do + @moduledoc false + + # -- :ssh API surface -- + + def connect(host, port, _opts, _timeout) do + conn = spawn(fn -> Process.sleep(:infinity) end) + notify({:connect, host, port, conn}) + {:ok, conn} + end + + def close(conn) do + notify({:close, conn}) + :ok + end + + # -- :ssh_connection API surface -- + + def session_channel(conn, _timeout) do + channel_id = :erlang.unique_integer([:positive]) + notify({:session_channel, conn, channel_id}) + {:ok, channel_id} + end + + def setenv(_conn, _channel_id, _var, _value, _timeout), do: :success + + def exec(conn, channel_id, command, _timeout) do + command_str = to_string(command) + notify({:exec, conn, channel_id, command_str}) + + caller = self() + + cond do + String.contains?(command_str, "echo ssh") -> + send(caller, {:ssh_cm, conn, {:data, channel_id, 0, "ssh\n"}}) + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + + String.contains?(command_str, "fail ssh") -> + send(caller, {:ssh_cm, conn, {:data, channel_id, 1, "failed\n"}}) + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 7}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + + String.contains?(command_str, "limit ssh") -> + send(caller, {:ssh_cm, conn, {:data, channel_id, 0, "123456"}}) + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + + String.contains?(command_str, "sleep ssh") -> + Process.send_after(caller, {:ssh_cm, conn, {:data, channel_id, 0, "sleeping\n"}}, 5) + Process.send_after(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}, 250) + Process.send_after(caller, {:ssh_cm, conn, {:eof, channel_id}}, 260) + Process.send_after(caller, {:ssh_cm, conn, {:closed, channel_id}}, 270) + + true -> + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + end + + :success + end + + def close(conn, channel_id) do + notify({:close_channel, conn, channel_id}) + :ok + end + + defp notify(event) do + case :persistent_term.get({__MODULE__, :test_pid}, nil) do + pid when is_pid(pid) -> send(pid, {:fake_ssh, event}) + _ -> :ok + end + end + end + + # --------------------------------------------------------------------------- + # Tests + # --------------------------------------------------------------------------- + + @fake_config %{ + ssh_module: FakeSSH, + ssh_connection_module: FakeSSH + } + + setup do + :persistent_term.put({FakeSSH, :test_pid}, self()) + + on_exit(fn -> + :persistent_term.erase({FakeSSH, :test_pid}) + end) + + :ok + end + + defp init_fake(overrides \\ %{}) do + config = Map.merge(%{session_pid: self(), host: "test-host", user: "root"}, @fake_config) + SSH.init(Map.merge(config, overrides)) + end + + test "init connects and terminate closes" do + {:ok, state} = init_fake(%{port: 22}) + + assert_receive {:fake_ssh, {:connect, ~c"test-host", 22, _conn}} + assert state.host == "test-host" + assert state.user == "root" + assert state.cwd == "/" + + assert :ok = SSH.terminate(state) + assert_receive {:fake_ssh, {:close, _}} + end + + test "execute streams stdout and returns command_done" do + {:ok, state} = init_fake() + + {:ok, worker_pid, _state} = SSH.execute(state, "echo ssh", [], []) + assert is_pid(worker_pid) + + assert_receive {:command_event, {:output, "ssh\n"}} + assert_receive {:command_finished, {:ok, nil}} + + ref = Process.monitor(worker_pid) + assert_receive {:DOWN, ^ref, :process, ^worker_pid, _} + end + + test "execute maps non-zero exits to structured errors" do + {:ok, state} = init_fake() + + {:ok, _worker_pid, _state} = SSH.execute(state, "fail ssh", [], []) + + assert_receive {:command_event, {:output, "failed\n"}} + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :exit_code}}}} + end + + test "execute enforces output limits" do + {:ok, state} = init_fake() + + {:ok, _worker_pid, _state} = SSH.execute(state, "limit ssh", [], output_limit: 3) + + assert_receive {:command_finished, + {:error, %Jido.Shell.Error{code: {:command, :output_limit_exceeded}}}} + end + + test "cancel closes channel and stops worker" do + {:ok, state} = init_fake() + + {:ok, worker_pid, _state} = SSH.execute(state, "sleep ssh", [], []) + assert_receive {:fake_ssh, {:exec, _, _, _}} + + # Give the worker a moment to register in ETS + Process.sleep(20) + + assert :ok = SSH.cancel(state, worker_pid) + assert_receive {:fake_ssh, {:close_channel, _, _}} + end + + test "cwd and cd track working directory" do + {:ok, state} = init_fake(%{cwd: "/home"}) + + assert {:ok, "/home", ^state} = SSH.cwd(state) + + {:ok, updated} = SSH.cd(state, "/tmp") + assert {:ok, "/tmp", ^updated} = SSH.cwd(updated) + end + + test "execute updates cwd from exec_opts" do + {:ok, state} = init_fake(%{cwd: "/home"}) + + {:ok, _worker_pid, updated_state} = SSH.execute(state, "echo ssh", [], dir: "/tmp") + + assert updated_state.cwd == "/tmp" + assert_receive {:command_finished, {:ok, nil}} + end + + test "execute with env variables" do + {:ok, state} = init_fake(%{env: %{"FOO" => "bar"}}) + + {:ok, _worker_pid, updated_state} = SSH.execute(state, "echo ssh", [], []) + + assert updated_state.env == %{"FOO" => "bar"} + assert_receive {:command_finished, {:ok, nil}} + end + + test "state stores connect_params for reconnection" do + {:ok, state} = init_fake() + + assert state.connect_params.host == "test-host" + assert state.connect_params.port == 22 + assert state.connect_params.user == "root" + assert state.ssh_module == FakeSSH + assert state.ssh_connection_module == FakeSSH + end + + test "real SSH backend module compiles and implements behaviour" do + # Verify the actual module exists and exports the right functions + assert {:module, SSH} = Code.ensure_loaded(SSH) + assert function_exported?(SSH, :init, 1) + assert function_exported?(SSH, :execute, 4) + assert function_exported?(SSH, :cancel, 2) + assert function_exported?(SSH, :terminate, 1) + assert function_exported?(SSH, :cwd, 1) + assert function_exported?(SSH, :cd, 2) + end + + describe "Docker SSH integration" do + @container_name "jido_shell_ssh_test" + @ssh_port 2222 + @ssh_password "testpass" + + setup do + ensure_container_running!() + wait_for_sshd!("127.0.0.1", @ssh_port, 30_000) + + on_exit(fn -> cleanup_container() end) + + :ok + end + + @tag :ssh_integration + test "connects to Docker SSHD container and executes commands" do + {:ok, state} = + SSH.init(%{ + session_pid: self(), + host: "127.0.0.1", + port: @ssh_port, + user: "root", + password: @ssh_password + }) + + # Test basic echo + {:ok, _worker, state} = SSH.execute(state, "echo hello-docker", [], []) + assert_receive {:command_event, {:output, output}}, 10_000 + assert output =~ "hello-docker" + assert_receive {:command_finished, {:ok, nil}}, 10_000 + + # Test non-zero exit code + {:ok, _worker, state} = SSH.execute(state, "exit 42", [], []) + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :exit_code}} = err}}, 10_000 + assert err.context.code == 42 + + # Test cd / cwd tracking + {:ok, _worker, state} = SSH.execute(state, "pwd", [], dir: "/tmp") + assert_receive {:command_event, {:output, pwd_output}}, 10_000 + assert String.trim(pwd_output) == "/tmp" + assert_receive {:command_finished, {:ok, nil}}, 10_000 + assert state.cwd == "/tmp" + + assert :ok = SSH.terminate(state) + end + + @tag :ssh_integration + test "handles output limit enforcement against real SSH" do + {:ok, state} = + SSH.init(%{ + session_pid: self(), + host: "127.0.0.1", + port: @ssh_port, + user: "root", + password: @ssh_password + }) + + # Generate output larger than the limit + {:ok, _worker, _state} = + SSH.execute(state, "dd if=/dev/zero bs=1024 count=10 2>/dev/null | base64", [], output_limit: 100) + + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :output_limit_exceeded}}}}, 10_000 + + assert :ok = SSH.terminate(state) + end + + defp ensure_container_running! do + # Stop any existing container + System.cmd("docker", ["rm", "-f", @container_name], stderr_to_stdout: true) + + # Start an Alpine container with SSHD and password auth + {_, 0} = + System.cmd("docker", [ + "run", "-d", + "--name", @container_name, + "-p", "#{@ssh_port}:22", + "alpine:latest", + "sh", "-c", + Enum.join([ + "apk add --no-cache openssh", + "echo 'root:#{@ssh_password}' | chpasswd", + "ssh-keygen -A", + "sed -i 's/#PermitRootLogin.*/PermitRootLogin yes/' /etc/ssh/sshd_config", + "sed -i 's/#PasswordAuthentication.*/PasswordAuthentication yes/' /etc/ssh/sshd_config", + "/usr/sbin/sshd -D -e" + ], " && ") + ], stderr_to_stdout: true) + end + + defp cleanup_container do + System.cmd("docker", ["rm", "-f", @container_name], stderr_to_stdout: true) + end + + defp wait_for_sshd!(host, port, timeout) do + deadline = System.monotonic_time(:millisecond) + timeout + do_wait_for_sshd(host, port, deadline) + end + + defp do_wait_for_sshd(host, port, deadline) do + if System.monotonic_time(:millisecond) > deadline do + raise "Timed out waiting for SSHD on #{host}:#{port}" + end + + # Try an actual SSH connection, not just TCP — SSHD needs time after port opens + case :ssh.connect(String.to_charlist(host), port, [ + {:user, ~c"root"}, + {:password, ~c"testpass"}, + {:silently_accept_hosts, true}, + {:user_interaction, false} + ], 3_000) do + {:ok, conn} -> + :ssh.close(conn) + + {:error, _} -> + Process.sleep(1_000) + do_wait_for_sshd(host, port, deadline) + end + end + end +end diff --git a/test/jido/shell/environment_test.exs b/test/jido/shell/environment_test.exs new file mode 100644 index 0000000..2339287 --- /dev/null +++ b/test/jido/shell/environment_test.exs @@ -0,0 +1,43 @@ +defmodule Jido.Shell.EnvironmentTest do + use Jido.Shell.Case, async: true + + alias Jido.Shell.Environment + alias Jido.Shell.Environment.Sprite, as: SpriteEnv + + describe "Environment behaviour" do + test "defines provision/3 callback" do + assert {:provision, 3} in Environment.behaviour_info(:callbacks) + end + + test "defines teardown/2 callback" do + assert {:teardown, 2} in Environment.behaviour_info(:callbacks) + end + + test "defines status/2 as optional callback" do + assert {:status, 2} in Environment.behaviour_info(:optional_callbacks) + end + end + + describe "Environment.Sprite" do + test "implements Environment behaviour" do + behaviours = + SpriteEnv.__info__(:attributes) + |> Keyword.get_values(:behaviour) + |> List.flatten() + + assert Environment in behaviours + end + + test "exports provision/3" do + assert function_exported?(SpriteEnv, :provision, 3) + end + + test "exports teardown/2" do + assert function_exported?(SpriteEnv, :teardown, 2) + end + + test "does not implement optional status/2" do + refute function_exported?(SpriteEnv, :status, 2) + end + end +end diff --git a/test/test_helper.exs b/test/test_helper.exs index e79057f..a1e79c8 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -1 +1 @@ -ExUnit.start(exclude: [:flaky, :sprites]) +ExUnit.start(exclude: [:flaky, :sprites, :ssh_integration]) From 0881a30f60df02d57c35a5df0250f338f305a80a Mon Sep 17 00:00:00 2001 From: Mike Hostetler <84222+mikehostetler@users.noreply.github.com> Date: Wed, 4 Mar 2026 10:41:34 -0600 Subject: [PATCH 2/4] fix(ssh): close output-limit gaps and harden async tests --- CHANGELOG.md | 1 + MIGRATION.md | 22 +++++ lib/jido_shell/backend/ssh.ex | 45 +++++---- lib/jido_shell/backend/ssh/key_callback.ex | 56 ++++++++---- .../shell/backend/ssh_key_callback_test.exs | 22 +++++ test/jido/shell/backend/ssh_test.exs | 91 ++++++++++++++----- test/jido/shell/cancellation_test.exs | 43 ++++----- test/jido/shell/environment_test.exs | 9 ++ test/jido/shell/session_server_test.exs | 60 ++++++++---- 9 files changed, 256 insertions(+), 93 deletions(-) create mode 100644 test/jido/shell/backend/ssh_key_callback_test.exs diff --git a/CHANGELOG.md b/CHANGELOG.md index f757863..b4be804 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Jido.Shell.ShellSession`, - `Jido.Shell.ShellSessionServer`, - `Jido.Shell.ShellSession.State`. +- Removed `Jido.Shell.SpriteLifecycle`; use `Jido.Shell.Environment.Sprite` for environment lifecycle APIs. - Removed `Jido.Shell.Session`, `Jido.Shell.SessionServer`, and `Jido.Shell.Session.State` shim modules. - Canonicalized state struct identity to `%Jido.Shell.ShellSession.State{}`. - Hardened identifier model to use binary workspace IDs across public APIs. diff --git a/MIGRATION.md b/MIGRATION.md index 5fdd8a9..f089a13 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -141,3 +141,25 @@ Session events are emitted as: ```elixir {:jido_shell_session, session_id, event} ``` + +## 10. Sprite Lifecycle Module Renamed + +The lifecycle helper module was renamed and the old name was removed: + +- `Jido.Shell.SpriteLifecycle` -> `Jido.Shell.Environment.Sprite` + +### Example update + +Before (removed): + +```elixir +{:ok, result} = Jido.Shell.SpriteLifecycle.provision(workspace_id, sprite_config) +teardown = Jido.Shell.SpriteLifecycle.teardown(session_id, sprite_name: workspace_id) +``` + +After (canonical): + +```elixir +{:ok, result} = Jido.Shell.Environment.Sprite.provision(workspace_id, sprite_config) +teardown = Jido.Shell.Environment.Sprite.teardown(session_id, sprite_name: workspace_id) +``` diff --git a/lib/jido_shell/backend/ssh.ex b/lib/jido_shell/backend/ssh.ex index 6845536..1865d7c 100644 --- a/lib/jido_shell/backend/ssh.ex +++ b/lib/jido_shell/backend/ssh.ex @@ -274,16 +274,12 @@ defmodule Jido.Shell.Backend.SSH do receive do {:ssh_cm, _conn, {:data, ^channel_id, _type, data}} -> - chunk = IO.iodata_to_binary(data) - - case OutputLimiter.check(byte_size(chunk), emitted_bytes, output_limit) do + case emit_checked_output(state, channel_id, data, output_limit, emitted_bytes) do {:ok, updated_total} -> - send(state.session_pid, {:command_event, {:output, chunk}}) await_ssh_events(state, channel_id, line, timeout, output_limit, updated_total) - {:limit_exceeded, error} -> - ssh_conn_mod.close(state.conn, channel_id) - send_finished(state.session_pid, {:error, error}) + {:error, :output_limit_exceeded} -> + :ok end {:ssh_cm, _conn, {:exit_status, ^channel_id, 0}} -> @@ -291,7 +287,7 @@ defmodule Jido.Shell.Backend.SSH do {:ssh_cm, _conn, {:exit_status, ^channel_id, code}} -> # Don't send finished yet — wait for :closed or :eof to ensure all data is flushed - await_ssh_close(state, channel_id, line, code) + await_ssh_close(state, channel_id, line, code, output_limit, emitted_bytes) {:ssh_cm, _conn, {:eof, ^channel_id}} -> await_ssh_events(state, channel_id, line, timeout, output_limit, emitted_bytes) @@ -306,15 +302,19 @@ defmodule Jido.Shell.Backend.SSH do end # After we've received a non-zero exit_status, drain remaining data/eof/closed messages - defp await_ssh_close(state, channel_id, line, exit_code) do + defp await_ssh_close(state, channel_id, line, exit_code, output_limit, emitted_bytes) do receive do {:ssh_cm, _conn, {:data, ^channel_id, _type, data}} -> - chunk = IO.iodata_to_binary(data) - send(state.session_pid, {:command_event, {:output, chunk}}) - await_ssh_close(state, channel_id, line, exit_code) + case emit_checked_output(state, channel_id, data, output_limit, emitted_bytes) do + {:ok, updated_total} -> + await_ssh_close(state, channel_id, line, exit_code, output_limit, updated_total) + + {:error, :output_limit_exceeded} -> + :ok + end {:ssh_cm, _conn, {:eof, ^channel_id}} -> - await_ssh_close(state, channel_id, line, exit_code) + await_ssh_close(state, channel_id, line, exit_code, output_limit, emitted_bytes) {:ssh_cm, _conn, {:closed, ^channel_id}} -> send_finished(state.session_pid, {:error, Error.command(:exit_code, %{code: exit_code, line: line})}) @@ -350,6 +350,21 @@ defmodule Jido.Shell.Backend.SSH do defp command_line(command, []), do: command defp command_line(command, args), do: Enum.join([command | args], " ") + defp emit_checked_output(state, channel_id, data, output_limit, emitted_bytes) do + chunk = IO.iodata_to_binary(data) + + case OutputLimiter.check(byte_size(chunk), emitted_bytes, output_limit) do + {:ok, updated_total} -> + send(state.session_pid, {:command_event, {:output, chunk}}) + {:ok, updated_total} + + {:limit_exceeded, error} -> + state.ssh_connection_module.close(state.conn, channel_id) + send_finished(state.session_pid, {:error, error}) + {:error, :output_limit_exceeded} + end + end + defp remote_command(shell, line, cwd, env) do env_prefix = env @@ -385,9 +400,7 @@ defmodule Jido.Shell.Backend.SSH do defp extract_limit(execution_context, key) when is_map(execution_context) do limits = Map.get(execution_context, :limits, %{}) - parse_limit( - Map.get(limits, key, Map.get(execution_context, key, nil)) - ) + parse_limit(Map.get(limits, key, Map.get(execution_context, key, nil))) end defp extract_limit(_, _), do: nil diff --git a/lib/jido_shell/backend/ssh/key_callback.ex b/lib/jido_shell/backend/ssh/key_callback.ex index 41a7fb2..41a6c7b 100644 --- a/lib/jido_shell/backend/ssh/key_callback.ex +++ b/lib/jido_shell/backend/ssh/key_callback.ex @@ -17,24 +17,43 @@ defmodule Jido.Shell.Backend.SSH.KeyCallback do @impl true def user_key(algorithm, opts) do # OTP 23+ passes key_cb tuple options under :key_cb_private - key_pem = get_key_from_opts(opts) - - case :public_key.pem_decode(key_pem) do - [] -> - {:error, :no_keys_found} + case get_key_from_opts(opts) do + key_pem when is_binary(key_pem) -> + entries = :public_key.pem_decode(key_pem) + + case entries do + [] -> + maybe_decode_openssh_key(key_pem, algorithm, :no_keys_found) + + _ -> + case find_key_for_algorithm(entries, algorithm) do + {:ok, _} = ok -> + ok + + {:error, reason} when reason in [:no_matching_key, :key_decode_failed] -> + # Some key formats may decode as PEM but still require OpenSSH decoding. + maybe_decode_openssh_key(key_pem, algorithm, reason) + + {:error, _} = error -> + error + end + end - [{{:no_asn1, :new_openssh}, _der, :not_encrypted} | _] -> - # OpenSSH-format key (e.g. Ed25519) — use :ssh_file to decode - decode_openssh_key(key_pem, algorithm) + _ -> + {:error, :no_key_provided} + end + end - entries -> - find_key_for_algorithm(entries, algorithm) + defp maybe_decode_openssh_key(key_pem, algorithm, default_error) do + case decode_openssh_key(key_pem, algorithm) do + {:ok, _} = ok -> ok + _ -> {:error, default_error} end end defp decode_openssh_key(key_pem, algorithm) do case :ssh_file.decode(key_pem, :openssh_key_v1) do - [{key, _attrs} | _] when is_tuple(key) -> + [{key, _attrs} | _] -> if key_matches_algorithm?(key, algorithm) do {:ok, key} else @@ -66,9 +85,11 @@ defmodule Jido.Shell.Backend.SSH.KeyCallback do _ -> {:error, :key_decode_failed} end - defp key_matches_algorithm?(key, algorithm) do + defp key_matches_algorithm?(key, algorithm) when is_tuple(key) do case {elem(key, 0), algorithm} do {:RSAPrivateKey, :"ssh-rsa"} -> true + {:RSAPrivateKey, :"rsa-sha2-256"} -> true + {:RSAPrivateKey, :"rsa-sha2-512"} -> true {:ECPrivateKey, :"ecdsa-sha2-nistp256"} -> true {:ECPrivateKey, :"ecdsa-sha2-nistp384"} -> true {:ECPrivateKey, :"ecdsa-sha2-nistp521"} -> true @@ -76,11 +97,15 @@ defmodule Jido.Shell.Backend.SSH.KeyCallback do {:ECPrivateKey, :"ssh-ed25519"} -> ed25519_curve?(key) {:ECPrivateKey, :"ssh-ed448"} -> ed448_curve?(key) # OTP may also use these representations - _ when algorithm in [:"ssh-ed25519", :"ssh-ed448"] -> is_ed_key?(key) + _ when algorithm in [:"ssh-ed25519", :"ssh-ed448"] -> is_ed_key_tuple?(key) _ -> false end end + defp key_matches_algorithm?(key, algorithm) when is_map(key) do + algorithm in [:"ssh-ed25519", :"ssh-ed448"] + end + defp ed25519_curve?({:ECPrivateKey, _, _, {:namedCurve, {1, 3, 101, 112}}, _, _}), do: true defp ed25519_curve?({:ECPrivateKey, _, _, {:namedCurve, {1, 3, 101, 112}}, _}), do: true defp ed25519_curve?(_), do: false @@ -89,9 +114,8 @@ defmodule Jido.Shell.Backend.SSH.KeyCallback do defp ed448_curve?({:ECPrivateKey, _, _, {:namedCurve, {1, 3, 101, 113}}, _}), do: true defp ed448_curve?(_), do: false - defp is_ed_key?(key) when is_map(key), do: true - defp is_ed_key?({:ed_pri, _, _, _}), do: true - defp is_ed_key?(_), do: false + defp is_ed_key_tuple?({:ed_pri, _, _, _}), do: true + defp is_ed_key_tuple?(_), do: false defp get_key_from_opts(opts) do case opts[:key_cb_private] do diff --git a/test/jido/shell/backend/ssh_key_callback_test.exs b/test/jido/shell/backend/ssh_key_callback_test.exs new file mode 100644 index 0000000..faa228d --- /dev/null +++ b/test/jido/shell/backend/ssh_key_callback_test.exs @@ -0,0 +1,22 @@ +defmodule Jido.Shell.Backend.SSHKeyCallbackTest do + use Jido.Shell.Case, async: true + + alias Jido.Shell.Backend.SSH.KeyCallback + + test "accepts RSA keys for legacy and SHA2 RSA algorithms" do + pem = rsa_private_key_pem() + + assert_rsa_key(KeyCallback.user_key(:"ssh-rsa", key: pem)) + assert_rsa_key(KeyCallback.user_key(:"rsa-sha2-256", key: pem)) + assert_rsa_key(KeyCallback.user_key(:"rsa-sha2-512", key: pem)) + end + + defp assert_rsa_key({:ok, key}) when is_tuple(key) do + assert elem(key, 0) == :RSAPrivateKey + end + + defp rsa_private_key_pem do + key = :public_key.generate_key({:rsa, 1_024, 65_537}) + :public_key.pem_encode([:public_key.pem_entry_encode(:RSAPrivateKey, key)]) + end +end diff --git a/test/jido/shell/backend/ssh_test.exs b/test/jido/shell/backend/ssh_test.exs index 3e92ffe..35d9dea 100644 --- a/test/jido/shell/backend/ssh_test.exs +++ b/test/jido/shell/backend/ssh_test.exs @@ -55,6 +55,18 @@ defmodule Jido.Shell.Backend.SSHTest do send(caller, {:ssh_cm, conn, {:eof, channel_id}}) send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + String.contains?(command_str, "fail trailing ssh") -> + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 7}}) + send(caller, {:ssh_cm, conn, {:data, channel_id, 1, "small\n"}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + + String.contains?(command_str, "fail limit ssh") -> + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 7}}) + send(caller, {:ssh_cm, conn, {:data, channel_id, 1, "123456"}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + String.contains?(command_str, "limit ssh") -> send(caller, {:ssh_cm, conn, {:data, channel_id, 0, "123456"}}) send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}) @@ -147,13 +159,30 @@ defmodule Jido.Shell.Backend.SSHTest do assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :exit_code}}}} end + test "enforces output limit when non-zero exit arrives before trailing oversized data" do + {:ok, state} = init_fake() + + {:ok, _worker_pid, _state} = SSH.execute(state, "fail limit ssh", [], output_limit: 3) + + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :output_limit_exceeded}}}} + end + + test "preserves non-zero exit when trailing data remains under the output limit" do + {:ok, state} = init_fake() + + {:ok, _worker_pid, _state} = SSH.execute(state, "fail trailing ssh", [], output_limit: 100) + + assert_receive {:command_event, {:output, "small\n"}} + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :exit_code}} = error}} + assert error.context.code == 7 + end + test "execute enforces output limits" do {:ok, state} = init_fake() {:ok, _worker_pid, _state} = SSH.execute(state, "limit ssh", [], output_limit: 3) - assert_receive {:command_finished, - {:error, %Jido.Shell.Error{code: {:command, :output_limit_exceeded}}}} + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :output_limit_exceeded}}}} end test "cancel closes channel and stops worker" do @@ -289,21 +318,32 @@ defmodule Jido.Shell.Backend.SSHTest do # Start an Alpine container with SSHD and password auth {_, 0} = - System.cmd("docker", [ - "run", "-d", - "--name", @container_name, - "-p", "#{@ssh_port}:22", - "alpine:latest", - "sh", "-c", - Enum.join([ - "apk add --no-cache openssh", - "echo 'root:#{@ssh_password}' | chpasswd", - "ssh-keygen -A", - "sed -i 's/#PermitRootLogin.*/PermitRootLogin yes/' /etc/ssh/sshd_config", - "sed -i 's/#PasswordAuthentication.*/PasswordAuthentication yes/' /etc/ssh/sshd_config", - "/usr/sbin/sshd -D -e" - ], " && ") - ], stderr_to_stdout: true) + System.cmd( + "docker", + [ + "run", + "-d", + "--name", + @container_name, + "-p", + "#{@ssh_port}:22", + "alpine:latest", + "sh", + "-c", + Enum.join( + [ + "apk add --no-cache openssh", + "echo 'root:#{@ssh_password}' | chpasswd", + "ssh-keygen -A", + "sed -i 's/#PermitRootLogin.*/PermitRootLogin yes/' /etc/ssh/sshd_config", + "sed -i 's/#PasswordAuthentication.*/PasswordAuthentication yes/' /etc/ssh/sshd_config", + "/usr/sbin/sshd -D -e" + ], + " && " + ) + ], + stderr_to_stdout: true + ) end defp cleanup_container do @@ -321,12 +361,17 @@ defmodule Jido.Shell.Backend.SSHTest do end # Try an actual SSH connection, not just TCP — SSHD needs time after port opens - case :ssh.connect(String.to_charlist(host), port, [ - {:user, ~c"root"}, - {:password, ~c"testpass"}, - {:silently_accept_hosts, true}, - {:user_interaction, false} - ], 3_000) do + case :ssh.connect( + String.to_charlist(host), + port, + [ + {:user, ~c"root"}, + {:password, ~c"testpass"}, + {:silently_accept_hosts, true}, + {:user_interaction, false} + ], + 3_000 + ) do {:ok, conn} -> :ssh.close(conn) diff --git a/test/jido/shell/cancellation_test.exs b/test/jido/shell/cancellation_test.exs index cdd7790..4c5e7ec 100644 --- a/test/jido/shell/cancellation_test.exs +++ b/test/jido/shell/cancellation_test.exs @@ -4,6 +4,8 @@ defmodule Jido.Shell.CancellationTest do alias Jido.Shell.ShellSession alias Jido.Shell.ShellSessionServer + @event_timeout 1_000 + setup do workspace_id = "test_ws_#{System.unique_integer([:positive])}" {:ok, session_id} = ShellSession.start(workspace_id) @@ -16,12 +18,12 @@ defmodule Jido.Shell.CancellationTest do test "cancels running command", %{session_id: session_id} do {:ok, :accepted} = ShellSessionServer.run_command(session_id, "sleep 10") - assert_receive {:jido_shell_session, _, {:command_started, "sleep 10"}} - assert_receive {:jido_shell_session, _, {:output, "Sleeping for 10 seconds...\n"}} + assert_receive {:jido_shell_session, _, {:command_started, "sleep 10"}}, @event_timeout + assert_receive {:jido_shell_session, _, {:output, "Sleeping for 10 seconds...\n"}}, @event_timeout {:ok, :cancelled} = ShellSessionServer.cancel(session_id) - assert_receive {:jido_shell_session, _, :command_cancelled} + assert_receive {:jido_shell_session, _, :command_cancelled}, @event_timeout {:ok, state} = ShellSessionServer.get_state(session_id) refute state.current_command @@ -36,17 +38,17 @@ defmodule Jido.Shell.CancellationTest do test "allows new command after cancellation", %{session_id: session_id} do {:ok, :accepted} = ShellSessionServer.run_command(session_id, "sleep 10") - assert_receive {:jido_shell_session, _, {:command_started, _}} + assert_receive {:jido_shell_session, _, {:command_started, _}}, @event_timeout {:ok, :cancelled} = ShellSessionServer.cancel(session_id) - assert_receive {:jido_shell_session, _, :command_cancelled} + assert_receive {:jido_shell_session, _, :command_cancelled}, @event_timeout wait_until_idle(session_id) {:ok, :accepted} = ShellSessionServer.run_command(session_id, "echo done") - assert_receive {:jido_shell_session, _, {:command_started, "echo done"}} - assert_receive {:jido_shell_session, _, {:output, "done\n"}}, 1_000 - assert_receive {:jido_shell_session, _, :command_done}, 1_000 + assert_receive {:jido_shell_session, _, {:command_started, "echo done"}}, @event_timeout + assert_receive {:jido_shell_session, _, {:output, "done\n"}}, @event_timeout + assert_receive {:jido_shell_session, _, :command_done}, @event_timeout end end @@ -54,25 +56,24 @@ defmodule Jido.Shell.CancellationTest do test "streams output chunks", %{session_id: session_id} do {:ok, :accepted} = ShellSessionServer.run_command(session_id, "seq 3 10") - assert_receive {:jido_shell_session, _, {:command_started, _}} - assert_receive {:jido_shell_session, _, {:output, "1\n"}} - assert_receive {:jido_shell_session, _, {:output, "2\n"}} - assert_receive {:jido_shell_session, _, {:output, "3\n"}} - assert_receive {:jido_shell_session, _, :command_done} + assert_receive {:jido_shell_session, _, {:command_started, _}}, @event_timeout + assert_receive {:jido_shell_session, _, {:output, "1\n"}}, @event_timeout + assert_receive {:jido_shell_session, _, {:output, "2\n"}}, @event_timeout + assert_receive {:jido_shell_session, _, {:output, "3\n"}}, @event_timeout + assert_receive {:jido_shell_session, _, :command_done}, @event_timeout end end describe "robustness" do test "handles late messages from cancelled command", %{session_id: session_id} do {:ok, :accepted} = ShellSessionServer.run_command(session_id, "seq 5 50") - assert_receive {:jido_shell_session, _, {:command_started, _}} + assert_receive {:jido_shell_session, _, {:command_started, _}}, @event_timeout - assert_receive {:jido_shell_session, _, {:output, "1\n"}} + assert_receive {:jido_shell_session, _, {:output, "1\n"}}, @event_timeout {:ok, :cancelled} = ShellSessionServer.cancel(session_id) - assert_receive {:jido_shell_session, _, :command_cancelled} - - Process.sleep(100) + assert_receive {:jido_shell_session, _, :command_cancelled}, @event_timeout + wait_until_idle(session_id) {:ok, state} = ShellSessionServer.get_state(session_id) refute state.current_command @@ -80,18 +81,18 @@ defmodule Jido.Shell.CancellationTest do test "rejects command when busy", %{session_id: session_id} do {:ok, :accepted} = ShellSessionServer.run_command(session_id, "sleep 5") - assert_receive {:jido_shell_session, _, {:command_started, _}} + assert_receive {:jido_shell_session, _, {:command_started, _}}, @event_timeout assert {:error, %Jido.Shell.Error{code: {:shell, :busy}}} = ShellSessionServer.run_command(session_id, "echo hello") - assert_receive {:jido_shell_session, _, {:error, %Jido.Shell.Error{code: {:shell, :busy}}}} + assert_receive {:jido_shell_session, _, {:error, %Jido.Shell.Error{code: {:shell, :busy}}}}, @event_timeout {:ok, :cancelled} = ShellSessionServer.cancel(session_id) end end - defp wait_until_idle(session_id, attempts \\ 20) + defp wait_until_idle(session_id, attempts \\ 100) defp wait_until_idle(_session_id, 0), do: :ok defp wait_until_idle(session_id, attempts) do diff --git a/test/jido/shell/environment_test.exs b/test/jido/shell/environment_test.exs index 2339287..d25cd0f 100644 --- a/test/jido/shell/environment_test.exs +++ b/test/jido/shell/environment_test.exs @@ -19,6 +19,11 @@ defmodule Jido.Shell.EnvironmentTest do end describe "Environment.Sprite" do + setup do + assert {:module, SpriteEnv} = Code.ensure_loaded(SpriteEnv) + :ok + end + test "implements Environment behaviour" do behaviours = SpriteEnv.__info__(:attributes) @@ -39,5 +44,9 @@ defmodule Jido.Shell.EnvironmentTest do test "does not implement optional status/2" do refute function_exported?(SpriteEnv, :status, 2) end + + test "does not expose legacy SpriteLifecycle module" do + refute Code.ensure_loaded?(Jido.Shell.SpriteLifecycle) + end end end diff --git a/test/jido/shell/session_server_test.exs b/test/jido/shell/session_server_test.exs index 5932e6b..7c20638 100644 --- a/test/jido/shell/session_server_test.exs +++ b/test/jido/shell/session_server_test.exs @@ -4,6 +4,8 @@ defmodule Jido.Shell.ShellSessionServerTest do alias Jido.Shell.ShellSession alias Jido.Shell.ShellSessionServer + @event_timeout 1_000 + describe "start_link/1" do test "starts a session server" do session_id = ShellSession.generate_id() @@ -109,7 +111,7 @@ defmodule Jido.Shell.ShellSessionServerTest do assert MapSet.member?(state.transports, transport) Process.exit(transport, :kill) - Process.sleep(10) + wait_until_transport_removed(session_id, transport) {:ok, state} = ShellSessionServer.get_state(session_id) refute MapSet.member?(state.transports, transport) @@ -124,9 +126,9 @@ defmodule Jido.Shell.ShellSessionServerTest do {:ok, :accepted} = ShellSessionServer.run_command(session_id, "echo hello") - assert_receive {:jido_shell_session, ^session_id, {:command_started, "echo hello"}} - assert_receive {:jido_shell_session, ^session_id, {:output, "hello\n"}} - assert_receive {:jido_shell_session, ^session_id, :command_done} + assert_receive {:jido_shell_session, ^session_id, {:command_started, "echo hello"}}, @event_timeout + assert_receive {:jido_shell_session, ^session_id, {:output, "hello\n"}}, @event_timeout + assert_receive {:jido_shell_session, ^session_id, :command_done}, @event_timeout {:ok, state} = ShellSessionServer.get_state(session_id) assert "echo hello" in state.history @@ -139,8 +141,12 @@ defmodule Jido.Shell.ShellSessionServerTest do {:ok, :accepted} = ShellSessionServer.run_command(session_id, "unknown_cmd") - assert_receive {:jido_shell_session, ^session_id, {:command_started, "unknown_cmd"}} - assert_receive {:jido_shell_session, ^session_id, {:error, %Jido.Shell.Error{code: {:shell, :unknown_command}}}} + assert_receive {:jido_shell_session, ^session_id, {:command_started, "unknown_cmd"}}, @event_timeout + + assert_receive( + {:jido_shell_session, ^session_id, {:error, %Jido.Shell.Error{code: {:shell, :unknown_command}}}}, + @event_timeout + ) end test "broadcasts busy error when command already running" do @@ -149,12 +155,13 @@ defmodule Jido.Shell.ShellSessionServerTest do {:ok, :subscribed} = ShellSessionServer.subscribe(session_id, self()) {:ok, :accepted} = ShellSessionServer.run_command(session_id, "sleep 5") - assert_receive {:jido_shell_session, ^session_id, {:command_started, "sleep 5"}} + assert_receive {:jido_shell_session, ^session_id, {:command_started, "sleep 5"}}, @event_timeout assert {:error, %Jido.Shell.Error{code: {:shell, :busy}}} = ShellSessionServer.run_command(session_id, "echo second") - assert_receive {:jido_shell_session, ^session_id, {:error, %Jido.Shell.Error{code: {:shell, :busy}}}} + assert_receive {:jido_shell_session, ^session_id, {:error, %Jido.Shell.Error{code: {:shell, :busy}}}}, + @event_timeout {:ok, :cancelled} = ShellSessionServer.cancel(session_id) end @@ -166,9 +173,9 @@ defmodule Jido.Shell.ShellSessionServerTest do {:ok, :accepted} = ShellSessionServer.run_command(session_id, "pwd") - assert_receive {:jido_shell_session, ^session_id, {:command_started, "pwd"}} - assert_receive {:jido_shell_session, ^session_id, {:output, "/home/user\n"}} - assert_receive {:jido_shell_session, ^session_id, :command_done} + assert_receive {:jido_shell_session, ^session_id, {:command_started, "pwd"}}, @event_timeout + assert_receive {:jido_shell_session, ^session_id, {:output, "/home/user\n"}}, @event_timeout + assert_receive {:jido_shell_session, ^session_id, :command_done}, @event_timeout end test "clears current_command after completion" do @@ -178,7 +185,7 @@ defmodule Jido.Shell.ShellSessionServerTest do {:ok, :accepted} = ShellSessionServer.run_command(session_id, "echo test") - assert_receive {:jido_shell_session, ^session_id, :command_done} + assert_receive {:jido_shell_session, ^session_id, :command_done}, @event_timeout {:ok, state} = ShellSessionServer.get_state(session_id) assert state.current_command == nil @@ -191,9 +198,9 @@ defmodule Jido.Shell.ShellSessionServerTest do {:ok, server_pid} = ShellSession.lookup(session_id) GenServer.cast(server_pid, {:run_command, "echo cast", []}) - assert_receive {:jido_shell_session, ^session_id, {:command_started, "echo cast"}} - assert_receive {:jido_shell_session, ^session_id, {:output, "cast\n"}} - assert_receive {:jido_shell_session, ^session_id, :command_done} + assert_receive {:jido_shell_session, ^session_id, {:command_started, "echo cast"}}, @event_timeout + assert_receive {:jido_shell_session, ^session_id, {:output, "cast\n"}}, @event_timeout + assert_receive {:jido_shell_session, ^session_id, :command_done}, @event_timeout # Idle cancel cast should be a no-op with explicit invalid transition handling internally. GenServer.cast(server_pid, :cancel) @@ -208,12 +215,12 @@ defmodule Jido.Shell.ShellSessionServerTest do {:ok, server_pid} = ShellSession.lookup(session_id) {:ok, :accepted} = ShellSessionServer.run_command(session_id, "sleep 1") - assert_receive {:jido_shell_session, ^session_id, {:command_started, "sleep 1"}} + assert_receive {:jido_shell_session, ^session_id, {:command_started, "sleep 1"}}, @event_timeout {:ok, state} = ShellSessionServer.get_state(session_id) assert %{ref: ref, task: task_pid} = state.current_command send(server_pid, {:DOWN, ref, :process, task_pid, :boom}) - assert_receive {:jido_shell_session, ^session_id, {:command_crashed, :boom}} + assert_receive {:jido_shell_session, ^session_id, {:command_crashed, :boom}}, @event_timeout end test "ignores late command events and late finished messages after cancellation" do @@ -275,4 +282,23 @@ defmodule Jido.Shell.ShellSessionServerTest do ShellSessionServer.get_state(session_id) end end + + defp wait_until_transport_removed(session_id, transport, attempts \\ 100) + defp wait_until_transport_removed(_session_id, _transport, 0), do: :ok + + defp wait_until_transport_removed(session_id, transport, attempts) do + case ShellSessionServer.get_state(session_id) do + {:ok, %{transports: transports}} -> + if MapSet.member?(transports, transport) do + Process.sleep(10) + wait_until_transport_removed(session_id, transport, attempts - 1) + else + :ok + end + + _ -> + Process.sleep(10) + wait_until_transport_removed(session_id, transport, attempts - 1) + end + end end From 75293162eb6be366d66993163f92b1c0189d01fb Mon Sep 17 00:00:00 2001 From: Mike Hostetler <84222+mikehostetler@users.noreply.github.com> Date: Wed, 4 Mar 2026 10:51:14 -0600 Subject: [PATCH 3/4] fix(ci): satisfy changelog guard by removing manual entry --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b4be804..f757863 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Jido.Shell.ShellSession`, - `Jido.Shell.ShellSessionServer`, - `Jido.Shell.ShellSession.State`. -- Removed `Jido.Shell.SpriteLifecycle`; use `Jido.Shell.Environment.Sprite` for environment lifecycle APIs. - Removed `Jido.Shell.Session`, `Jido.Shell.SessionServer`, and `Jido.Shell.Session.State` shim modules. - Canonicalized state struct identity to `%Jido.Shell.ShellSession.State{}`. - Hardened identifier model to use binary workspace IDs across public APIs. From a2d5f367321ee93c9edb91a81330e0b6372a204b Mon Sep 17 00:00:00 2001 From: Mike Hostetler <84222+mikehostetler@users.noreply.github.com> Date: Wed, 4 Mar 2026 11:13:38 -0600 Subject: [PATCH 4/4] test(ssh): raise coverage with deterministic backend branches --- .../shell/backend/ssh_key_callback_test.exs | 58 +++ test/jido/shell/backend/ssh_test.exs | 341 +++++++++++++++--- 2 files changed, 346 insertions(+), 53 deletions(-) diff --git a/test/jido/shell/backend/ssh_key_callback_test.exs b/test/jido/shell/backend/ssh_key_callback_test.exs index faa228d..eea059a 100644 --- a/test/jido/shell/backend/ssh_key_callback_test.exs +++ b/test/jido/shell/backend/ssh_key_callback_test.exs @@ -3,20 +3,78 @@ defmodule Jido.Shell.Backend.SSHKeyCallbackTest do alias Jido.Shell.Backend.SSH.KeyCallback + @openssh_ed25519_key """ + -----BEGIN OPENSSH PRIVATE KEY----- + b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW + QyNTUxOQAAACBwzSpfxppMoqPam0WopBM+z1EdRKk6Eh7jy0z1+sjWbwAAAJg+i4wMPouM + DAAAAAtzc2gtZWQyNTUxOQAAACBwzSpfxppMoqPam0WopBM+z1EdRKk6Eh7jy0z1+sjWbw + AAAECoRPdviD1Vv2dENzRydnfynesTGNQEr/Zfeqn4AT/zWnDNKl/Gmkyio9qbRaikEz7P + UR1EqToSHuPLTPX6yNZvAAAAEHRlc3RAZXhhbXBsZS5jb20BAgMEBQ== + -----END OPENSSH PRIVATE KEY----- + """ + + test "host key callback always accepts" do + assert KeyCallback.is_host_key(:key, ~c"host", 22, :"ssh-rsa", []) == true + end + test "accepts RSA keys for legacy and SHA2 RSA algorithms" do pem = rsa_private_key_pem() assert_rsa_key(KeyCallback.user_key(:"ssh-rsa", key: pem)) assert_rsa_key(KeyCallback.user_key(:"rsa-sha2-256", key: pem)) assert_rsa_key(KeyCallback.user_key(:"rsa-sha2-512", key: pem)) + assert_rsa_key(KeyCallback.user_key(:"ssh-rsa", key_cb_private: [key: pem])) + end + + test "accepts ECDSA algorithms for EC keys" do + pem = ecdsa_private_key_pem() + + assert_ec_key(KeyCallback.user_key(:"ecdsa-sha2-nistp256", key: pem)) + assert_ec_key(KeyCallback.user_key(:"ecdsa-sha2-nistp384", key: pem)) + assert_ec_key(KeyCallback.user_key(:"ecdsa-sha2-nistp521", key: pem)) + end + + test "decodes OpenSSH ed25519 keys and rejects mismatched algorithms" do + assert {:ok, {:ECPrivateKey, _, _, {:namedCurve, {1, 3, 101, 112}}, _, _}} = + KeyCallback.user_key(:"ssh-ed25519", key: @openssh_ed25519_key) + + assert {:error, :key_decode_failed} = + KeyCallback.user_key(:"ssh-ed448", key: @openssh_ed25519_key) + + assert {:error, :key_decode_failed} = + KeyCallback.user_key(:"ssh-rsa", key: @openssh_ed25519_key) + end + + test "returns expected errors for missing, unmatched, and malformed keys" do + rsa_pem = rsa_private_key_pem() + + assert {:error, :no_key_provided} = KeyCallback.user_key(:"ssh-rsa", []) + assert {:error, :no_key_provided} = KeyCallback.user_key(:"ssh-rsa", key_cb_private: []) + + assert {:error, :no_matching_key} = KeyCallback.user_key(:"ssh-ed25519", key: rsa_pem) + + assert {:error, :no_keys_found} = + KeyCallback.user_key(:"ssh-rsa", key: "not-a-key") + + undecodable_pem = :public_key.pem_encode([{:RSAPrivateKey, <<1, 2, 3, 4>>, :not_encrypted}]) + assert {:error, :key_decode_failed} = KeyCallback.user_key(:"ssh-rsa", key: undecodable_pem) end defp assert_rsa_key({:ok, key}) when is_tuple(key) do assert elem(key, 0) == :RSAPrivateKey end + defp assert_ec_key({:ok, key}) when is_tuple(key) do + assert elem(key, 0) == :ECPrivateKey + end + defp rsa_private_key_pem do key = :public_key.generate_key({:rsa, 1_024, 65_537}) :public_key.pem_encode([:public_key.pem_entry_encode(:RSAPrivateKey, key)]) end + + defp ecdsa_private_key_pem do + key = :public_key.generate_key({:namedCurve, {1, 2, 840, 10045, 3, 1, 7}}) + :public_key.pem_encode([:public_key.pem_entry_encode(:ECPrivateKey, key)]) + end end diff --git a/test/jido/shell/backend/ssh_test.exs b/test/jido/shell/backend/ssh_test.exs index 35d9dea..92d27e9 100644 --- a/test/jido/shell/backend/ssh_test.exs +++ b/test/jido/shell/backend/ssh_test.exs @@ -15,23 +15,44 @@ defmodule Jido.Shell.Backend.SSHTest do # -- :ssh API surface -- - def connect(host, port, _opts, _timeout) do - conn = spawn(fn -> Process.sleep(:infinity) end) - notify({:connect, host, port, conn}) - {:ok, conn} + def connect(host, port, opts, _timeout) do + case mode() do + :connect_error -> + {:error, :econnrefused} + + _ -> + conn = spawn(fn -> Process.sleep(:infinity) end) + notify({:connect, host, port, opts, conn}) + {:ok, conn} + end end def close(conn) do - notify({:close, conn}) - :ok + case mode() do + :close_throw -> + throw(:close_failed) + + _ -> + notify({:close, conn}) + :ok + end end # -- :ssh_connection API surface -- def session_channel(conn, _timeout) do - channel_id = :erlang.unique_integer([:positive]) - notify({:session_channel, conn, channel_id}) - {:ok, channel_id} + case mode() do + :session_channel_error -> + {:error, :session_channel_failed} + + :session_channel_raise -> + raise "session channel crash" + + _ -> + channel_id = :erlang.unique_integer([:positive]) + notify({:session_channel, conn, channel_id}) + {:ok, channel_id} + end end def setenv(_conn, _channel_id, _var, _value, _timeout), do: :success @@ -42,50 +63,62 @@ defmodule Jido.Shell.Backend.SSHTest do caller = self() - cond do - String.contains?(command_str, "echo ssh") -> - send(caller, {:ssh_cm, conn, {:data, channel_id, 0, "ssh\n"}}) - send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}) - send(caller, {:ssh_cm, conn, {:eof, channel_id}}) - send(caller, {:ssh_cm, conn, {:closed, channel_id}}) - - String.contains?(command_str, "fail ssh") -> - send(caller, {:ssh_cm, conn, {:data, channel_id, 1, "failed\n"}}) - send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 7}}) - send(caller, {:ssh_cm, conn, {:eof, channel_id}}) - send(caller, {:ssh_cm, conn, {:closed, channel_id}}) - - String.contains?(command_str, "fail trailing ssh") -> - send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 7}}) - send(caller, {:ssh_cm, conn, {:data, channel_id, 1, "small\n"}}) - send(caller, {:ssh_cm, conn, {:eof, channel_id}}) - send(caller, {:ssh_cm, conn, {:closed, channel_id}}) - - String.contains?(command_str, "fail limit ssh") -> - send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 7}}) - send(caller, {:ssh_cm, conn, {:data, channel_id, 1, "123456"}}) - send(caller, {:ssh_cm, conn, {:eof, channel_id}}) - send(caller, {:ssh_cm, conn, {:closed, channel_id}}) - - String.contains?(command_str, "limit ssh") -> - send(caller, {:ssh_cm, conn, {:data, channel_id, 0, "123456"}}) - send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}) - send(caller, {:ssh_cm, conn, {:eof, channel_id}}) - send(caller, {:ssh_cm, conn, {:closed, channel_id}}) - - String.contains?(command_str, "sleep ssh") -> - Process.send_after(caller, {:ssh_cm, conn, {:data, channel_id, 0, "sleeping\n"}}, 5) - Process.send_after(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}, 250) - Process.send_after(caller, {:ssh_cm, conn, {:eof, channel_id}}, 260) - Process.send_after(caller, {:ssh_cm, conn, {:closed, channel_id}}, 270) - - true -> - send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}) - send(caller, {:ssh_cm, conn, {:eof, channel_id}}) - send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + case mode() do + :exec_failure -> + :failure + + :exec_error -> + {:error, :exec_rejected} + + :no_events -> + :success + + _ -> + cond do + String.contains?(command_str, "echo ssh") -> + send(caller, {:ssh_cm, conn, {:data, channel_id, 0, "ssh\n"}}) + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + + String.contains?(command_str, "fail ssh") -> + send(caller, {:ssh_cm, conn, {:data, channel_id, 1, "failed\n"}}) + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 7}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + + String.contains?(command_str, "fail trailing ssh") -> + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 7}}) + send(caller, {:ssh_cm, conn, {:data, channel_id, 1, "small\n"}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + + String.contains?(command_str, "fail limit ssh") -> + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 7}}) + send(caller, {:ssh_cm, conn, {:data, channel_id, 1, "123456"}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + + String.contains?(command_str, "limit ssh") -> + send(caller, {:ssh_cm, conn, {:data, channel_id, 0, "123456"}}) + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + + String.contains?(command_str, "sleep ssh") -> + Process.send_after(caller, {:ssh_cm, conn, {:data, channel_id, 0, "sleeping\n"}}, 5) + Process.send_after(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}, 250) + Process.send_after(caller, {:ssh_cm, conn, {:eof, channel_id}}, 260) + Process.send_after(caller, {:ssh_cm, conn, {:closed, channel_id}}, 270) + + true -> + send(caller, {:ssh_cm, conn, {:exit_status, channel_id, 0}}) + send(caller, {:ssh_cm, conn, {:eof, channel_id}}) + send(caller, {:ssh_cm, conn, {:closed, channel_id}}) + end + + :success end - - :success end def close(conn, channel_id) do @@ -99,6 +132,10 @@ defmodule Jido.Shell.Backend.SSHTest do _ -> :ok end end + + defp mode do + :persistent_term.get({__MODULE__, :mode}, :normal) + end end # --------------------------------------------------------------------------- @@ -112,9 +149,11 @@ defmodule Jido.Shell.Backend.SSHTest do setup do :persistent_term.put({FakeSSH, :test_pid}, self()) + :persistent_term.put({FakeSSH, :mode}, :normal) on_exit(fn -> :persistent_term.erase({FakeSSH, :test_pid}) + :persistent_term.erase({FakeSSH, :mode}) end) :ok @@ -125,10 +164,19 @@ defmodule Jido.Shell.Backend.SSHTest do SSH.init(Map.merge(config, overrides)) end + defp set_fake_mode(mode) do + :persistent_term.put({FakeSSH, :mode}, mode) + end + + defp rsa_private_key_pem do + key = :public_key.generate_key({:rsa, 1_024, 65_537}) + :public_key.pem_encode([:public_key.pem_entry_encode(:RSAPrivateKey, key)]) + end + test "init connects and terminate closes" do {:ok, state} = init_fake(%{port: 22}) - assert_receive {:fake_ssh, {:connect, ~c"test-host", 22, _conn}} + assert_receive {:fake_ssh, {:connect, ~c"test-host", 22, _opts, _conn}} assert state.host == "test-host" assert state.user == "root" assert state.cwd == "/" @@ -246,6 +294,193 @@ defmodule Jido.Shell.Backend.SSHTest do assert function_exported?(SSH, :cd, 2) end + test "init validates required session and host/user config" do + assert {:error, %Jido.Shell.Error{code: {:session, :invalid_state_transition}}} = SSH.init(%{}) + + assert {:error, %Jido.Shell.Error{code: {:command, :start_failed}}} = + SSH.init(%{session_pid: self(), user: "root", ssh_module: FakeSSH, ssh_connection_module: FakeSSH}) + + assert {:error, %Jido.Shell.Error{code: {:command, :start_failed}}} = + SSH.init(%{ + session_pid: self(), + host: " ", + user: "root", + ssh_module: FakeSSH, + ssh_connection_module: FakeSSH + }) + + assert {:error, %Jido.Shell.Error{code: {:command, :start_failed}}} = + SSH.init(%{ + session_pid: self(), + host: "test-host", + ssh_module: FakeSSH, + ssh_connection_module: FakeSSH + }) + end + + test "init builds key and password auth options" do + pem = rsa_private_key_pem() + path = Path.join(System.tmp_dir!(), "jido_shell_test_key_#{System.unique_integer([:positive])}.pem") + File.write!(path, pem) + + on_exit(fn -> File.rm(path) end) + + {:ok, _state} = init_fake(%{key: pem}) + assert_receive {:fake_ssh, {:connect, _, _, opts_with_key, _}} + assert [{:key_cb, {Jido.Shell.Backend.SSH.KeyCallback, [key: ^pem]}}] = Keyword.take(opts_with_key, [:key_cb]) + + {:ok, _state} = init_fake(%{key_path: path}) + assert_receive {:fake_ssh, {:connect, _, _, opts_with_key_path, _}} + assert [{:key_cb, {Jido.Shell.Backend.SSH.KeyCallback, [key: ^pem]}}] = Keyword.take(opts_with_key_path, [:key_cb]) + + {:ok, _state} = init_fake(%{password: "secret"}) + assert_receive {:fake_ssh, {:connect, _, _, opts_with_password, _}} + assert [password: ~c"secret"] = Keyword.take(opts_with_password, [:password]) + end + + test "init returns start_failed when key_path cannot be read or connect fails" do + missing = Path.join(System.tmp_dir!(), "missing_#{System.unique_integer([:positive])}.pem") + + assert {:error, %Jido.Shell.Error{code: {:command, :start_failed}} = error} = + init_fake(%{key_path: missing}) + + assert error.context.reason == {:key_read_failed, :enoent} + + set_fake_mode(:connect_error) + + assert {:error, %Jido.Shell.Error{code: {:command, :start_failed}} = error} = + init_fake() + + assert error.context.reason == {:ssh_connect, :econnrefused} + end + + test "execute reconnects when existing connection pid is dead" do + {:ok, state} = init_fake() + + assert_receive {:fake_ssh, {:connect, _, _, _, old_conn}} + Process.exit(old_conn, :kill) + + {:ok, _worker_pid, _updated_state} = SSH.execute(%{state | conn: old_conn}, "echo ssh", [], []) + + assert_receive {:fake_ssh, {:connect, _, _, _, new_conn}} + assert old_conn != new_conn + assert_receive {:command_finished, {:ok, nil}} + end + + test "execute reconnects when conn value is not a pid" do + {:ok, state} = init_fake() + {:ok, _worker_pid, _updated_state} = SSH.execute(%{state | conn: :invalid_conn}, "echo ssh", [], []) + assert_receive {:fake_ssh, {:connect, _, _, _, _}} + assert_receive {:command_finished, {:ok, nil}} + end + + test "execute returns task start errors when task supervisor cannot accept children" do + {:ok, full_supervisor} = Task.Supervisor.start_link(max_children: 0) + {:ok, state} = init_fake(%{task_supervisor: full_supervisor}) + + assert {:error, :max_children} = SSH.execute(state, "echo ssh", [], []) + end + + test "execute reports start_failed for channel and exec setup errors" do + {:ok, state} = init_fake() + + set_fake_mode(:session_channel_error) + {:ok, _worker_pid, _state} = SSH.execute(state, "echo ssh", [], []) + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :start_failed}} = error}} + assert error.context.reason == {:channel_open_failed, :session_channel_failed} + + set_fake_mode(:exec_failure) + {:ok, _worker_pid, _state} = SSH.execute(state, "echo ssh", [], []) + assert_receive {:fake_ssh, {:close_channel, _, _}} + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :start_failed}} = error}} + assert error.context.reason == :exec_failed + + set_fake_mode(:exec_error) + {:ok, _worker_pid, _state} = SSH.execute(state, "echo ssh", [], []) + assert_receive {:fake_ssh, {:close_channel, _, _}} + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :start_failed}} = error}} + assert error.context.reason == :exec_rejected + end + + test "execute reports crashed when session channel raises" do + {:ok, state} = init_fake() + set_fake_mode(:session_channel_raise) + + {:ok, _worker_pid, _state} = SSH.execute(state, "echo ssh", [], []) + + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :crashed}}}} + end + + test "execute reads runtime/output limits from execution_context and normalizes env/args" do + {:ok, state} = init_fake(%{env: %{"PERSIST" => "1"}}) + + {:ok, _worker_pid, updated_state} = + SSH.execute( + state, + "echo", + ["ssh"], + env: [ignored: :value], + execution_context: %{limits: %{max_runtime_ms: 50, max_output_bytes: "64"}} + ) + + assert_receive {:fake_ssh, {:exec, _, _, wrapped_command}} + assert wrapped_command =~ "echo ssh" + assert updated_state.env == %{} + assert_receive {:command_finished, {:ok, nil}} + end + + test "execute handles non-map execution_context and invalid numeric limits" do + {:ok, state} = init_fake() + + {:ok, _worker_pid, _updated_state} = + SSH.execute( + state, + "echo ssh", + [], + execution_context: :invalid + ) + + assert_receive {:command_finished, {:ok, nil}} + + {:ok, _worker_pid, _updated_state} = + SSH.execute( + state, + "echo ssh", + [], + execution_context: %{max_runtime_ms: "not-a-number"} + ) + + assert_receive {:command_finished, {:ok, nil}} + end + + test "execute times out when channel emits no events" do + {:ok, state} = init_fake() + set_fake_mode(:no_events) + + {:ok, _worker_pid, _state} = SSH.execute(state, "echo ssh", [], timeout: 25) + + assert_receive {:fake_ssh, {:close_channel, _, _}} + assert_receive {:command_finished, {:error, %Jido.Shell.Error{code: {:command, :timeout}}}} + end + + test "cancel handles invalid refs and missing worker channel registrations" do + {:ok, state} = init_fake() + idle_worker = spawn(fn -> Process.sleep(200) end) + + assert :ok = SSH.cancel(state, idle_worker) + assert {:error, :invalid_command_ref} = SSH.cancel(state, :not_a_pid) + end + + test "cancel tolerates invalid commands table and terminate tolerates close/delete failures" do + {:ok, state} = init_fake() + idle_worker = spawn(fn -> Process.sleep(200) end) + + assert :ok = SSH.cancel(%{state | commands_table: :invalid_table}, idle_worker) + + set_fake_mode(:close_throw) + assert :ok = SSH.terminate(%{state | commands_table: :invalid_table}) + end + describe "Docker SSH integration" do @container_name "jido_shell_ssh_test" @ssh_port 2222