Skip to content

Commit

Permalink
Add SSH known host key check
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnavi committed Dec 17, 2024
1 parent 3f47ba8 commit 2c0469e
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 32 deletions.
5 changes: 5 additions & 0 deletions src/defaults.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ pub fn defaults(lua: &Lua) -> mlua::Result<Table> {
defaults.set("ignore_exit_code", false)?;
defaults.set("elevate", false)?;
defaults.set("elevation_method", "sudo")?;
defaults.set(
"known_hosts_file",
format!("{}/.ssh/known_hosts", env!("HOME")),
)?;
defaults.set("host_key_check", true)?;

let env = lua.create_table()?;
env.set("DEBIAN_FRONTEND", "noninteractive")?;
Expand Down
38 changes: 35 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,43 @@ async fn komando(lua: Lua, (host, task): (Value, Value)) -> mlua::Result<Table>
as_user,
};

let mut ssh = SSHSession::connect(
(host.get::<String>("address")?.as_str(), port),
let known_hosts_file = match host.get::<String>("known_hosts_file") {
Ok(known_hosts_file) => Some(known_hosts_file),
Err(_) => match defaults.get::<String>("known_hosts_file") {
Ok(known_hosts_file) => Some(known_hosts_file),
Err(_) => None,
},
};

let host_key_check = match host.get::<Value>("host_key_check") {
Ok(host_key_check) => match host_key_check {
Value::Nil => match defaults.get::<Value>("host_key_check") {
Ok(host_key_check) => match host_key_check {
Value::Nil => true,
Value::Boolean(false) => false,
_ => true,
},
Err(_) => true,
},
Value::Boolean(false) => false,
_ => true,
},
Err(_) => true,
};

let mut ssh = SSHSession::new()?;

if host_key_check {
ssh.known_hosts_file = known_hosts_file;
}

ssh.elevation = elevation;

ssh.connect(
host.get::<String>("address")?.as_str(),
port,
&user,
ssh_auth_method,
elevation,
)?;

let env_defaults = defaults.get::<Table>("env").unwrap_or(lua.create_table()?);
Expand Down
14 changes: 10 additions & 4 deletions src/modules/lineinfile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use rand::{distributions::Alphanumeric, Rng};

pub fn lineinfile(lua: &Lua, params: Table) -> mlua::Result<Table> {
let path = match params.get::<String>("path") {
Ok(path) => path,
Ok(path) => path,
Err(_) => return Err(RuntimeError(String::from("'path' parameter is required"))),
};

Expand All @@ -14,13 +14,19 @@ Ok(path) => path,
Ok(state) => match state.as_str() {
"present" => state,
"absent" => state,
_ => return Err(RuntimeError(String::from("'state' parameter must be 'present' or 'absent'"))),
_ => {
return Err(RuntimeError(String::from(
"'state' parameter must be 'present' or 'absent'",
)))
}
},
Err(_) => String::from("present"),
};

if line.is_nil() && pattern.is_nil() {
return Err(RuntimeError(String::from("'line' or 'pattern' parameter is required")));
return Err(RuntimeError(String::from(
"'line' or 'pattern' parameter is required",
)));
}

let insert_after = params.get::<Value>("insert_after")?;
Expand Down Expand Up @@ -50,7 +56,7 @@ Ok(path) => path,
if $line ~= nil then
cmd = cmd .. " --line \"" .. $line .. "\""
end

if $pattern ~= nil then
cmd = cmd .. " --pattern \"" .. $pattern .. "\""
end
Expand Down
6 changes: 4 additions & 2 deletions src/modules/template.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@ pub fn template(lua: &Lua, params: Table) -> mlua::Result<Table> {
Ok(s) => s,
Err(_) => return Err(RuntimeError(String::from("'src' parameter is required"))),
};

let dst = match params.get::<String>("dst") {
Ok(s) => s,
Err(_) => return Err(RuntimeError(String::from("'dst' parameter is required"))),
};

let vars = params.get::<Value>("vars")?;
if !vars.is_nil() && !vars.is_table() {
return Err(RuntimeError(String::from("'vars' parameter must be a table")));
return Err(RuntimeError(String::from(
"'vars' parameter must be a table",
)));
};

if !std::path::Path::new(&src).exists() {
Expand Down
80 changes: 57 additions & 23 deletions src/ssh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use std::{
collections::HashMap,
fs,
io::{self, Read, Write},
net::{TcpStream, ToSocketAddrs},
net::TcpStream,
path::Path,
};

use anyhow::Result;
use mlua::UserData;
use ssh2::{Session, Sftp};
use ssh2::{CheckResult, KnownHostFileKind, Session, Sftp};

pub enum SSHAuthMethod {
Password(String),
Expand All @@ -31,35 +31,76 @@ pub enum ElevateMethod {

pub struct SSHSession {
session: Session,
pub known_hosts_file: Option<String>,
env: HashMap<String, String>,
elevation: Elevation,
pub elevation: Elevation,
stdout: Option<String>,
stderr: Option<String>,
exit_code: Option<i32>,
}

impl SSHSession {
pub fn connect<A: ToSocketAddrs>(
addr: A,
pub fn new() -> Result<Self> {
Ok(Self {
session: Session::new()?,
known_hosts_file: None,
env: HashMap::new(),
elevation: Elevation {
method: ElevateMethod::None,
as_user: None,
},
stdout: Some(String::new()),
stderr: Some(String::new()),
exit_code: Some(0),
})
}

pub fn connect(
&mut self,
address: &str,
port: u16,
username: &str,
auth_method: SSHAuthMethod,
elevation: Elevation,
) -> Result<Self> {
let tcp = TcpStream::connect(addr)?;
let mut session = Session::new()?;

session.set_tcp_stream(tcp);
session.handshake()?;
) -> Result<()> {
let tcp = TcpStream::connect((address, port))?;

self.session.set_tcp_stream(tcp);
self.session.handshake()?;

match &self.known_hosts_file {
Some(file) => {
let host_key = self.session.host_key().unwrap();
let mut known_hosts = self.session.known_hosts()?;
match known_hosts.read_file(Path::new(file.as_str()), KnownHostFileKind::OpenSSH) {
Ok(_) => {}
Err(_) => {
return Err(anyhow::Error::msg(
format!("SSH host key verification failed. Please add the host key to the known_hosts file: {}", file),
));
}
};

match known_hosts.check(&address, &host_key.0) {
CheckResult::Match => {}
_ => {
return Err(anyhow::Error::msg(
format!("SSH host key verification failed. Please add the host key to the known_hosts file: {}", file),
));
}
};
}
None => {}
}

match auth_method {
SSHAuthMethod::Password(password) => {
session.userauth_password(&username, &password)?;
self.session.userauth_password(&username, &password)?;
}
SSHAuthMethod::PublicKey {
private_key,
passphrase,
} => {
session.userauth_pubkey_file(
self.session.userauth_pubkey_file(
&username,
None,
Path::new(&private_key),
Expand All @@ -68,18 +109,11 @@ impl SSHSession {
}
}

if !session.authenticated() {
if !self.session.authenticated() {
return Err(anyhow::Error::msg("SSH authentication failed."));
}

Ok(Self {
session,
env: HashMap::new(),
elevation,
stdout: Some(String::new()),
stderr: Some(String::new()),
exit_code: Some(0),
})
Ok(())
}

pub fn set_env(&mut self, key: &str, value: &str) {
Expand Down

0 comments on commit 2c0469e

Please sign in to comment.