diff --git a/src/platform/aws.rs b/src/platform/aws.rs index f67b6ce..30ca698 100644 --- a/src/platform/aws.rs +++ b/src/platform/aws.rs @@ -1,9 +1,54 @@ -use crate::result::{Result, Context}; +use crate::result::{Context, Result}; pub struct Aws; +#[derive(Debug)] +pub struct InvalidHeartbeatResponse(u8); + +impl std::fmt::Display for InvalidHeartbeatResponse { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Invalid heartbeat response: expected 0x87, got 0x{}", + self.0 + ) + } +} + +impl std::error::Error for InvalidHeartbeatResponse {} + impl Aws { fn init_heartbeat() -> Result<()> { + use crate::system::syscall::{ + close, connect, read, sockaddr_vm, socket, write, SocketFamily, SocketType, + }; + + let mut sockaddr = sockaddr_vm { + svm_family: SocketFamily::Vsock as u16, + svm_reserved1: 0, + svm_port: 9000, + svm_cid: 3, + svm_zero: [0; 4], + }; + let mut buf = [0x87]; + + let fd = socket(SocketFamily::Vsock, SocketType::Stream)?; + unsafe { + connect( + fd, + std::ptr::from_mut(&mut sockaddr).cast(), + std::mem::size_of_val(&sockaddr), + )?; + }; + write(fd, &buf)?; + read(fd, &mut buf)?; + close(fd)?; + + if buf[0] != 0x87 { + return Err(InvalidHeartbeatResponse(buf[0])) + .context(format_args!("Bad value from heartbeat")); + } + Ok(()) } } @@ -14,9 +59,7 @@ impl super::Platform for Aws { } fn get_modules(&self) -> Result> { - Ok(vec![ - ("/nsm.ko".into(), String::new()) - ]) + Ok(vec![("/nsm.ko".into(), String::new())]) } fn init(&self) -> Result<()> { diff --git a/src/system/syscall.rs b/src/system/syscall.rs index 61dbe8b..496aa45 100644 --- a/src/system/syscall.rs +++ b/src/system/syscall.rs @@ -11,7 +11,7 @@ use crate::result::{ctx_os_error, Context, Result}; use libc::{self, c_ulong, c_void}; use std::{ ffi::CString, - os::fd::AsRawFd, + os::fd::{AsRawFd, RawFd}, path::Path, }; @@ -107,3 +107,61 @@ pub fn finit_module(fd: impl AsRawFd, params: impl AsRef) -> Result<()> { n => unreachable!("syscall(SYS_finit_module, ...) returned bad value: {n}"), } } + +#[repr(i32)] +#[derive(Debug, Clone, Copy)] +pub enum SocketFamily { + Vsock = libc::AF_VSOCK, +} + +#[repr(i32)] +#[derive(Debug, Clone, Copy)] +pub enum SocketType { + Stream = libc::SOCK_STREAM, +} + +// TODO: allow specifying protocol? +pub fn socket(family: SocketFamily, typ: SocketType) -> Result { + match unsafe { libc::socket(family as i32, typ as i32, 0) } { + -1 => ctx_os_error(format_args!("error calling socket({family:?}, {typ:?})")).map(|()| 0), + fd => Ok(RawFd::from(fd)), + } +} + +pub use libc::sockaddr_vm; + +// This function is unsafe since we have to pass it a C-style union. +pub unsafe fn connect(fd: RawFd, sockaddr: *mut libc::sockaddr, size: usize) -> Result<()> { + let size = u32::try_from(size).context(format_args!( + "connect(..., size = {size}) has size > {}", + u32::MAX + ))?; + + match unsafe { libc::connect(fd, sockaddr, size) } { + 0 => Ok(()), + -1 => ctx_os_error(format_args!("error calling connect({fd}, ...)")), + n => unreachable!("connect({fd}, ...) returned bad value: {n}"), + } +} + +pub fn write(fd: RawFd, bytes: &[u8]) -> Result { + match unsafe { libc::write(fd, bytes.as_ptr().cast(), bytes.len()) } { + ..0 => ctx_os_error(format_args!("error calling write({fd}, ...)")).map(|()| 0), + n @ 0.. => Ok(usize::try_from(n).unwrap()), + } +} + +pub fn read(fd: RawFd, buffer: &mut [u8]) -> Result { + match unsafe { libc::read(fd, buffer.as_mut_ptr().cast(), buffer.len()) } { + ..0 => ctx_os_error(format_args!("error calling read({fd}, ...)")).map(|()| 0), + n @ 0.. => Ok(usize::try_from(n).unwrap()), + } +} + +pub fn close(fd: RawFd) -> Result<()> { + match unsafe { libc::close(fd) } { + 0 => Ok(()), + -1 => ctx_os_error(format_args!("error calling close({fd})")), + n => unreachable!("close({fd}) returned bad value: {n}"), + } +}