...@@ -12,12 +12,17 @@ use std::{ ...@@ -12,12 +12,17 @@ use std::{
}; };
use libc::{ use libc::{
malloc,
c_char, c_char,
}; };
use crate::Error;
use crate::Result;
use crate::ffi::MM;
/// Copies a Rust string to a buffer, adding a terminating zero. /// Copies a Rust string to a buffer, adding a terminating zero.
pub fn rust_str_to_c_str<S: AsRef<str>>(s: S) -> *mut c_char { pub fn rust_str_to_c_str<S: AsRef<str>>(mm: MM, s: S) -> *mut c_char {
let malloc = mm.malloc;
let s = s.as_ref(); let s = s.as_ref();
let bytes = s.as_bytes(); let bytes = s.as_bytes();
unsafe { unsafe {
...@@ -31,7 +36,9 @@ pub fn rust_str_to_c_str<S: AsRef<str>>(s: S) -> *mut c_char { ...@@ -31,7 +36,9 @@ pub fn rust_str_to_c_str<S: AsRef<str>>(s: S) -> *mut c_char {
/// Copies a C string to a buffer, adding a terminating zero. /// Copies a C string to a buffer, adding a terminating zero.
/// ///
/// Replaces embedded zeros with '_'. /// Replaces embedded zeros with '_'.
pub fn rust_bytes_to_c_str_lossy<S: AsRef<[u8]>>(s: S) -> *mut c_char { pub fn rust_bytes_to_c_str_lossy<S: AsRef<[u8]>>(mm: MM, s: S) -> *mut c_char {
let malloc = mm.malloc;
let bytes = s.as_ref(); let bytes = s.as_ref();
unsafe { unsafe {
let buf = malloc(bytes.len() + 1); let buf = malloc(bytes.len() + 1);
...@@ -46,3 +53,17 @@ pub fn rust_bytes_to_c_str_lossy<S: AsRef<[u8]>>(s: S) -> *mut c_char { ...@@ -46,3 +53,17 @@ pub fn rust_bytes_to_c_str_lossy<S: AsRef<[u8]>>(s: S) -> *mut c_char {
buf as *mut c_char buf as *mut c_char
} }
} }
pub fn malloc_cleared<T>(mm: MM) -> Result<*mut T>
{
let malloc = mm.malloc;
let size = std::mem::size_of::<T>();
let buffer = unsafe { malloc(size) };
if buffer.is_null() {
return Err(Error::OutOfMemory("malloc".into(), size));
};
unsafe { libc::memset(buffer, 0, size) };
Ok(buffer as *mut T)
}
...@@ -9,6 +9,12 @@ pub type Free = unsafe extern "C" fn(*mut c_void); ...@@ -9,6 +9,12 @@ pub type Free = unsafe extern "C" fn(*mut c_void);
/// How to free the memory allocated by the callback. /// How to free the memory allocated by the callback.
pub type Malloc = unsafe extern "C" fn(size_t) -> *mut c_void; pub type Malloc = unsafe extern "C" fn(size_t) -> *mut c_void;
#[derive(Copy, Clone)]
pub struct MM {
pub malloc: Malloc,
pub free: Free,
}
// Wraps an ffi function. // Wraps an ffi function.
// //
// This wrapper allows the function to return a Result. The Ok // This wrapper allows the function to return a Result. The Ok
......
...@@ -105,7 +105,7 @@ use pep::{ ...@@ -105,7 +105,7 @@ use pep::{
Timestamp, Timestamp,
}; };
#[macro_use] mod ffi; #[macro_use] mod ffi;
use ffi::{Malloc, Free}; use ffi::MM;
mod keystore; mod keystore;
use keystore::Keystore; use keystore::Keystore;
...@@ -204,8 +204,8 @@ fn _pgp_get_decrypted_key_iter<'a, I>(iter: I, pass: Option<&Password>) ...@@ -204,8 +204,8 @@ fn _pgp_get_decrypted_key_iter<'a, I>(iter: I, pass: Option<&Password>)
// PEP_STATUS pgp_init(PEP_SESSION session, bool in_first) // PEP_STATUS pgp_init(PEP_SESSION session, bool in_first)
ffi!(fn pgp_init_(session: *mut Session, _in_first: bool, ffi!(fn pgp_init_(session: *mut Session, _in_first: bool,
per_user_directory: *const c_char, per_user_directory: *const c_char,
malloc: Malloc, malloc: ffi::Malloc,
free: Free, free: ffi::Free,
session_size: c_uint, session_size: c_uint,
session_cookie_offset: c_uint, session_cookie_offset: c_uint,
session_curr_passphrase_offset: c_uint, session_curr_passphrase_offset: c_uint,
...@@ -293,7 +293,7 @@ ffi!(fn pgp_init_(session: *mut Session, _in_first: bool, ...@@ -293,7 +293,7 @@ ffi!(fn pgp_init_(session: *mut Session, _in_first: bool,
}; };
let ks = keystore::Keystore::init(Path::new(per_user_directory))?; let ks = keystore::Keystore::init(Path::new(per_user_directory))?;
session.init(ks, malloc, free); session.init(MM { malloc, free }, ks);
Ok(()) Ok(())
}); });
...@@ -331,11 +331,13 @@ struct Helper<'a> { ...@@ -331,11 +331,13 @@ struct Helper<'a> {
impl<'a> Helper<'a> { impl<'a> Helper<'a> {
fn new(session: &'a mut Session) -> Self { fn new(session: &'a mut Session) -> Self {
let mm = session.mm();
Helper { Helper {
session: session, session: session,
secret_keys_called: false, secret_keys_called: false,
recipient_keylist: StringList::empty(), recipient_keylist: StringList::empty(mm),
signer_keylist: StringList::empty(), signer_keylist: StringList::empty(mm),
good_checksums: 0, good_checksums: 0,
malformed_signature: 0, malformed_signature: 0,
missing_keys: 0, missing_keys: 0,
...@@ -720,7 +722,8 @@ ffi!(fn pgp_decrypt_and_verify(session: *mut Session, ...@@ -720,7 +722,8 @@ ffi!(fn pgp_decrypt_and_verify(session: *mut Session,
-> Result<()> -> Result<()>
{ {
let session = Session::as_mut(session); let session = Session::as_mut(session);
let malloc = session.malloc(); let mm = session.mm();
let malloc = mm.malloc;
if ctext.is_null() { if ctext.is_null() {
return Err(Error::IllegalValue( return Err(Error::IllegalValue(
...@@ -796,13 +799,13 @@ ffi!(fn pgp_decrypt_and_verify(session: *mut Session, ...@@ -796,13 +799,13 @@ ffi!(fn pgp_decrypt_and_verify(session: *mut Session,
h.signer_keylist.append(&mut h.recipient_keylist); h.signer_keylist.append(&mut h.recipient_keylist);
unsafe { keylistp.as_mut() }.map(|p| { unsafe { keylistp.as_mut() }.map(|p| {
*p = mem::replace(&mut h.signer_keylist, StringList::empty()).to_c(); *p = mem::replace(&mut h.signer_keylist, StringList::empty(mm)).to_c();
}); });
if ! filename_ptr.is_null() { if ! filename_ptr.is_null() {
unsafe { filename_ptr.as_mut() }.map(|p| { unsafe { filename_ptr.as_mut() }.map(|p| {
if let Some(filename) = h.filename.as_ref() { if let Some(filename) = h.filename.as_ref() {
*p = rust_bytes_to_c_str_lossy(filename); *p = rust_bytes_to_c_str_lossy(mm, filename);
} else { } else {
*p = ptr::null_mut(); *p = ptr::null_mut();
} }
...@@ -846,6 +849,7 @@ ffi!(fn pgp_verify_text(session: *mut Session, ...@@ -846,6 +849,7 @@ ffi!(fn pgp_verify_text(session: *mut Session,
-> Result<()> -> Result<()>
{ {
let session = Session::as_mut(session); let session = Session::as_mut(session);
let mm = session.mm();
if size == 0 || sig_size == 0 { if size == 0 || sig_size == 0 {
return Err(Error::DecryptWrongFormat); return Err(Error::DecryptWrongFormat);
...@@ -921,7 +925,7 @@ ffi!(fn pgp_verify_text(session: *mut Session, ...@@ -921,7 +925,7 @@ ffi!(fn pgp_verify_text(session: *mut Session,
} }
h.signer_keylist.append(&mut h.recipient_keylist); h.signer_keylist.append(&mut h.recipient_keylist);
unsafe { keylistp.as_mut() }.map(|p| { unsafe { keylistp.as_mut() }.map(|p| {
*p = mem::replace(&mut h.signer_keylist, StringList::empty()).to_c(); *p = mem::replace(&mut h.signer_keylist, StringList::empty(mm)).to_c();
}); });
...@@ -963,7 +967,8 @@ ffi!(fn pgp_sign_only( ...@@ -963,7 +967,8 @@ ffi!(fn pgp_sign_only(
-> Result<()> -> Result<()>
{ {
let session = Session::as_mut(session); let session = Session::as_mut(session);
let malloc = session.malloc(); let mm = session.mm();
let malloc = mm.malloc;
if fpr.is_null() { if fpr.is_null() {
return Err(Error::IllegalValue( return Err(Error::IllegalValue(
...@@ -1069,7 +1074,8 @@ fn pgp_encrypt_sign_optional( ...@@ -1069,7 +1074,8 @@ fn pgp_encrypt_sign_optional(
tracer!(*crate::TRACE, "pgp_encrypt_sign_optional"); tracer!(*crate::TRACE, "pgp_encrypt_sign_optional");
let session = Session::as_mut(session); let session = Session::as_mut(session);
let malloc = session.malloc(); let mm = session.mm();
let malloc = mm.malloc;
if ptext.is_null() { if ptext.is_null() {
return Err(Error::IllegalValue( return Err(Error::IllegalValue(
...@@ -1088,7 +1094,7 @@ fn pgp_encrypt_sign_optional( ...@@ -1088,7 +1094,7 @@ fn pgp_encrypt_sign_optional(
let keystore = session.keystore(); let keystore = session.keystore();
let keylist = StringList::to_rust(keylist, false); let keylist = StringList::to_rust(mm, keylist, false);
t!("{} recipients.", keylist.len()); t!("{} recipients.", keylist.len());
for (i, v) in keylist.iter().enumerate() { for (i, v) in keylist.iter().enumerate() {
t!(" {}. {}", i, String::from_utf8_lossy(v.to_bytes())); t!(" {}. {}", i, String::from_utf8_lossy(v.to_bytes()));
...@@ -1246,6 +1252,8 @@ ffi!(fn _pgp_generate_keypair(session: *mut Session, ...@@ -1246,6 +1252,8 @@ ffi!(fn _pgp_generate_keypair(session: *mut Session,
-> Result<()> -> Result<()>
{ {
let session = Session::as_mut(session); let session = Session::as_mut(session);
let mm = session.mm();
let identity = if let Some(i) = PepIdentity::as_mut(identity) { let identity = if let Some(i) = PepIdentity::as_mut(identity) {
i i
} else { } else {
...@@ -1373,7 +1381,7 @@ ffi!(fn _pgp_generate_keypair(session: *mut Session, ...@@ -1373,7 +1381,7 @@ ffi!(fn _pgp_generate_keypair(session: *mut Session,
CannotCreateKey, CannotCreateKey,
"Saving new key")?; "Saving new key")?;
identity.set_fingerprint(fpr); identity.set_fingerprint(mm, fpr);
Ok(()) Ok(())
}); });
...@@ -1584,6 +1592,7 @@ ffi!(fn pgp_import_keydata(session: *mut Session, ...@@ -1584,6 +1592,7 @@ ffi!(fn pgp_import_keydata(session: *mut Session,
-> Result<()> -> Result<()>
{ {
let session = Session::as_mut(session); let session = Session::as_mut(session);
let mm = session.mm();
if imported_keysp.is_null() && ! changed_key_indexp.is_null() { if imported_keysp.is_null() && ! changed_key_indexp.is_null() {
return Err(Error::IllegalValue( return Err(Error::IllegalValue(
...@@ -1601,11 +1610,11 @@ ffi!(fn pgp_import_keydata(session: *mut Session, ...@@ -1601,11 +1610,11 @@ ffi!(fn pgp_import_keydata(session: *mut Session,
}; };
// We add(!) to the existing lists. // We add(!) to the existing lists.
let mut identity_list = unsafe { identity_listp.as_mut() } let mut identity_list = unsafe { identity_listp.as_mut() }
.map(|p| PepIdentityList::to_rust(*p, false)) .map(|p| PepIdentityList::to_rust(mm, *p, false))
.unwrap_or_else(|| PepIdentityList::empty()); .unwrap_or_else(|| PepIdentityList::empty(mm));
let mut imported_keys = unsafe { imported_keysp.as_mut() } let mut imported_keys = unsafe { imported_keysp.as_mut() }
.map(|p| StringList::to_rust(*p, false)) .map(|p| StringList::to_rust(mm, *p, false))
.unwrap_or_else(|| StringList::empty()); .unwrap_or_else(|| StringList::empty(mm));
let mut changed_key_index: u64 = unsafe { changed_key_indexp.as_mut() } let mut changed_key_index: u64 = unsafe { changed_key_indexp.as_mut() }
.map(|p| *p) .map(|p| *p)
.unwrap_or(0); .unwrap_or(0);
...@@ -1701,7 +1710,8 @@ ffi!(fn pgp_export_keydata(session: *mut Session, ...@@ -1701,7 +1710,8 @@ ffi!(fn pgp_export_keydata(session: *mut Session,
-> Result<()> -> Result<()>
{ {
let session = Session::as_mut(session); let session = Session::as_mut(session);
let malloc = session.malloc(); let mm = session.mm();
let malloc = mm.malloc;
if fpr.is_null() { if fpr.is_null() {
return Err(Error::IllegalValue("fpr must not be NULL".into())); return Err(Error::IllegalValue("fpr must not be NULL".into()));
...@@ -1778,6 +1788,7 @@ fn list_keys(session: *mut Session, ...@@ -1778,6 +1788,7 @@ fn list_keys(session: *mut Session,
tracer!(*crate::TRACE, "list_keys"); tracer!(*crate::TRACE, "list_keys");
let session = Session::as_mut(session); let session = Session::as_mut(session);
let mm = session.mm();
if pattern.is_null() { if pattern.is_null() {
return Err(Error::IllegalValue( return Err(Error::IllegalValue(
...@@ -1787,7 +1798,7 @@ fn list_keys(session: *mut Session, ...@@ -1787,7 +1798,7 @@ fn list_keys(session: *mut Session,
// XXX: What should we do if pattern is not valid UTF-8? // XXX: What should we do if pattern is not valid UTF-8?
let pattern = pattern.to_string_lossy(); let pattern = pattern.to_string_lossy();
let mut keylist = StringList::empty(); let mut keylist = StringList::empty(mm);
match session.keystore().list_keys(&pattern, private_only) { match session.keystore().list_keys(&pattern, private_only) {
Err(Error::KeyNotFound(_)) => { Err(Error::KeyNotFound(_)) => {
......
...@@ -31,13 +31,14 @@ use std::os::raw::{ ...@@ -31,13 +31,14 @@ use std::os::raw::{
}; };
use std::ptr; use std::ptr;
use libc::calloc;
use libc::free;
use sequoia_openpgp as openpgp; use sequoia_openpgp as openpgp;
use openpgp::Fingerprint; use openpgp::Fingerprint;
use crate::buffer::rust_str_to_c_str; use crate::buffer::{
malloc_cleared,
rust_str_to_c_str,
};
use crate::ffi::MM;
use crate::pep::{ use crate::pep::{
PepCommType, PepCommType,
PepEncFormat, PepEncFormat,
...@@ -94,19 +95,19 @@ impl PepIdentity { ...@@ -94,19 +95,19 @@ impl PepIdentity {
/// ///
/// The memory is allocated using the libc allocator. The caller /// The memory is allocated using the libc allocator. The caller
/// is responsible for freeing it explicitly. /// is responsible for freeing it explicitly.
pub fn new(template: &PepIdentityTemplate) pub fn new(mm: MM, template: &PepIdentityTemplate)
-> &'static mut Self -> &'static mut Self
{ {
let buffer = unsafe { calloc(1, std::mem::size_of::<Self>()) }; let buffer = if let Ok(buffer) = malloc_cleared::<Self>(mm) {
if buffer.is_null() { buffer
} else {
panic!("Out of memory allocating a PepIdentity"); panic!("Out of memory allocating a PepIdentity");
} };
let ident = unsafe { &mut *(buffer as *mut Self) }; let ident = unsafe { &mut *(buffer as *mut Self) };
ident.address = rust_str_to_c_str(&template.address); ident.address = rust_str_to_c_str(mm, &template.address);
ident.fpr = rust_str_to_c_str(&template.fpr.to_hex()); ident.fpr = rust_str_to_c_str(mm, &template.fpr.to_hex());
if let Some(username) = template.username.as_ref() { if let Some(username) = template.username.as_ref() {
ident.username = rust_str_to_c_str(username); ident.username = rust_str_to_c_str(mm, username);
} }
ident ident
} }
...@@ -137,9 +138,9 @@ impl PepIdentity { ...@@ -137,9 +138,9 @@ impl PepIdentity {
} }
/// Replaces the fingerprint. /// Replaces the fingerprint.
pub fn set_fingerprint(&mut self, fpr: Fingerprint) { pub fn set_fingerprint(&mut self, mm: MM, fpr: Fingerprint) {
unsafe { libc::free(self.fpr as *mut _) }; unsafe { libc::free(self.fpr as *mut _) };
self.fpr = rust_str_to_c_str(fpr.to_hex()); self.fpr = rust_str_to_c_str(mm, fpr.to_hex());
} }
/// Returns the username (in RFC 2822 speak: the display name). /// Returns the username (in RFC 2822 speak: the display name).
...@@ -188,12 +189,13 @@ impl PepIdentityListItem { ...@@ -188,12 +189,13 @@ impl PepIdentityListItem {
/// ///
/// The memory is allocated using the libc allocator. The caller /// The memory is allocated using the libc allocator. The caller
/// is responsible for freeing it explicitly. /// is responsible for freeing it explicitly.
fn new(ident: &'static mut PepIdentity) -> &'static mut Self fn new(mm: MM, ident: &'static mut PepIdentity) -> &'static mut Self
{ {
let buffer = unsafe { calloc(1, std::mem::size_of::<Self>()) }; let buffer = if let Ok(buffer) = malloc_cleared::<Self>(mm) {
if buffer.is_null() { buffer
} else {
panic!("Out of memory allocating a PepIdentityListItem"); panic!("Out of memory allocating a PepIdentityListItem");
} };
let item = unsafe { &mut *(buffer as *mut Self) }; let item = unsafe { &mut *(buffer as *mut Self) };
item.ident = ident as *mut _; item.ident = ident as *mut _;
item item
...@@ -211,6 +213,7 @@ impl PepIdentityListItem { ...@@ -211,6 +213,7 @@ impl PepIdentityListItem {
pub struct PepIdentityList { pub struct PepIdentityList {
head: *mut PepIdentityListItem, head: *mut PepIdentityListItem,
owned: bool, owned: bool,
mm: MM,
} }
impl PepIdentityList { impl PepIdentityList {
...@@ -219,10 +222,12 @@ impl PepIdentityList { ...@@ -219,10 +222,12 @@ impl PepIdentityList {
/// `owned` indicates whether the rust code should own the items. /// `owned` indicates whether the rust code should own the items.
/// If so, when the `PepIdentityList` is dropped, the items will /// If so, when the `PepIdentityList` is dropped, the items will
/// also be freed. /// also be freed.
pub fn to_rust(l: *mut PepIdentityListItem, owned: bool) -> Self { pub fn to_rust(mm: MM, l: *mut PepIdentityListItem, owned: bool) -> Self
{
Self { Self {
head: l, head: l,
owned, owned,
mm,
} }
} }
...@@ -239,10 +244,11 @@ impl PepIdentityList { ...@@ -239,10 +244,11 @@ impl PepIdentityList {
/// Any added items are owned by the `PepIdentityList`, and when /// Any added items are owned by the `PepIdentityList`, and when
/// it is dropped, they are freed. To take ownership of the /// it is dropped, they are freed. To take ownership of the
/// items, call `PepIdentityList::to_c`. /// items, call `PepIdentityList::to_c`.
pub fn empty() -> Self { pub fn empty(mm: MM) -> Self {
Self { Self {
head: ptr::null_mut(), head: ptr::null_mut(),
owned: true, owned: true,
mm,
} }
} }
...@@ -251,7 +257,8 @@ impl PepIdentityList { ...@@ -251,7 +257,8 @@ impl PepIdentityList {
/// The item's ownership is determined by the list's ownership /// The item's ownership is determined by the list's ownership
/// property. /// property.
pub fn add(&mut self, ident: &PepIdentityTemplate) { pub fn add(&mut self, ident: &PepIdentityTemplate) {
let ident = PepIdentityListItem::new(PepIdentity::new(ident)); let ident = PepIdentityListItem::new(
self.mm, PepIdentity::new(self.mm, ident));
ident.next = self.head; ident.next = self.head;
self.head = ident; self.head = ident;
} }
...@@ -259,6 +266,8 @@ impl PepIdentityList { ...@@ -259,6 +266,8 @@ impl PepIdentityList {
impl Drop for PepIdentityList { impl Drop for PepIdentityList {
fn drop(&mut self) { fn drop(&mut self) {
let free = self.mm.free;
let mut curr: *mut PepIdentityListItem = self.head; let mut curr: *mut PepIdentityListItem = self.head;
self.head = ptr::null_mut(); self.head = ptr::null_mut();
...@@ -288,13 +297,15 @@ mod tests { ...@@ -288,13 +297,15 @@ mod tests {
#[test] #[test]
fn identity() { fn identity() {
let mm = MM { malloc: libc::malloc, free: libc::free };
let address = "addr@ess"; let address = "addr@ess";
let fpr = Fingerprint::from_str( let fpr = Fingerprint::from_str(
"0123 4567 89AB CDEF 0000 0123 4567 89ab cdef 0000").unwrap(); "0123 4567 89AB CDEF 0000 0123 4567 89ab cdef 0000").unwrap();
let username = "User Name"; let username = "User Name";
let template = PepIdentityTemplate::new(address, fpr, Some(username)); let template = PepIdentityTemplate::new(address, fpr, Some(username));
let ident = PepIdentity::new(&template); let ident = PepIdentity::new(mm, &template);
assert_eq!(ident.address().map(|s| s.to_bytes()), assert_eq!(ident.address().map(|s| s.to_bytes()),
Some(address.as_bytes())); Some(address.as_bytes()));
...@@ -307,7 +318,9 @@ mod tests { ...@@ -307,7 +318,9 @@ mod tests {
#[test] #[test]
fn list() { fn list() {
let mut list = PepIdentityList::empty(); let mm = MM { malloc: libc::malloc, free: libc::free };
let mut list = PepIdentityList::empty(mm);
assert!(list.head.is_null()); assert!(list.head.is_null());
let address = "addr@ess"; let address = "addr@ess";
......
...@@ -11,14 +11,13 @@ use crate::Error; ...@@ -11,14 +11,13 @@ use crate::Error;
use crate::Keystore; use crate::Keystore;
use crate::PepCipherSuite; use crate::PepCipherSuite;
use crate::Result; use crate::Result;
use crate::{Malloc, Free}; use crate::ffi::MM;
const MAGIC: u64 = 0xE3F3_05AD_48EE_0DF5; const MAGIC: u64 = 0xE3F3_05AD_48EE_0DF5;
pub struct State { pub struct State {
ks: Keystore, ks: Keystore,
malloc: Malloc, mm: MM,
free: Free,
magic: u64, magic: u64,
} }
...@@ -75,8 +74,10 @@ impl Session { ...@@ -75,8 +74,10 @@ impl Session {
version: ptr::null(), version: ptr::null(),
state: Box::into_raw(Box::new(State { state: Box::into_raw(Box::new(State {
ks: Keystore::init_in_memory().unwrap(), ks: Keystore::init_in_memory().unwrap(),
malloc: libc::malloc, mm: MM {
free: libc::free, malloc: libc::malloc,
free: libc::free,
},
magic: MAGIC, magic: MAGIC,
})), })),
curr_passphrase: ptr::null(), curr_passphrase: ptr::null(),
...@@ -87,16 +88,14 @@ impl Session { ...@@ -87,16 +88,14 @@ impl Session {
} }
pub fn init(&mut self, pub fn init(&mut self,
ks: Keystore, mm: MM,
malloc: Malloc, ks: Keystore)
free: Free)
{ {
assert!(self.state.is_null()); assert!(self.state.is_null());
self.state = Box::into_raw(Box::new(State { self.state = Box::into_raw(Box::new(State {
ks: ks, ks: ks,
malloc: malloc, mm,
free: free,
magic: MAGIC, magic: MAGIC,
})); }));
} }
...@@ -121,14 +120,9 @@ impl Session { ...@@ -121,14 +120,9 @@ impl Session {
&mut State::as_mut(self.state).ks &mut State::as_mut(self.state).ks
} }
/// Returns the application's malloc routine. /// Returns the application's memory management routines.
pub fn malloc(&self) -> Malloc { pub fn mm(&self) -> MM {
State::as_mut(self.state).malloc State::as_mut(self.state).mm
}
/// Returns the application's free routine.
pub fn free(&self) -> Free {
State::as_mut(self.state).free
} }
/// Returns the value of curr_passphrase. /// Returns the value of curr_passphrase.
...@@ -209,8 +203,8 @@ mod tests { ...@@ -209,8 +203,8 @@ mod tests {
session.deinit(); session.deinit();
// If the state pointer is non-NULL, this will panic. // If the state pointer is non-NULL, this will panic.
session.init(Keystore::init_in_memory().unwrap(), session.init(MM { malloc: libc::malloc, free: libc::free },
libc::malloc, libc::free); Keystore::init_in_memory().unwrap());
let ks = session.keystore() as *mut _; let ks = session.keystore() as *mut _;
let ks2 = session.keystore() as *mut _; let ks2 = session.keystore() as *mut _;
assert!(ptr::eq(ks, ks2)); assert!(ptr::eq(ks, ks2));
......
...@@ -17,10 +17,12 @@ use std::ptr; ...@@ -17,10 +17,12 @@ use std::ptr;
use std::ffi::CStr; use std::ffi::CStr;
use libc::c_char; use libc::c_char;
use libc::calloc;
use libc::free;
use crate::buffer::rust_str_to_c_str; use crate::ffi::MM;
use crate::buffer::{
malloc_cleared,
rust_str_to_c_str,
};
#[repr(C)] #[repr(C)]
pub struct StringListItem { pub struct StringListItem {
...@@ -33,11 +35,12 @@ impl StringListItem { ...@@ -33,11 +35,12 @@ impl StringListItem {
/// ///
/// The memory is allocated using the libc allocator. The caller /// The memory is allocated using the libc allocator. The caller
/// is responsible for freeing it explicitly. /// is responsible for freeing it explicitly.
fn empty() -> &'static mut Self { fn empty(mm: MM) -> &'static mut Self {
let buffer = unsafe { calloc(1, std::mem::size_of::<Self>()) }; let buffer = if let Ok(buffer) = malloc_cleared::<Self>(mm) {
if buffer.is_null() { buffer
} else {
panic!("Out of memory allocating a StringListItem"); panic!("Out of memory allocating a StringListItem");
} };
unsafe { &mut *(buffer as *mut Self) } unsafe { &mut *(buffer as *mut Self) }
} }
...@@ -46,10 +49,10 @@ impl StringListItem { ...@@ -46,10 +49,10 @@ impl StringListItem {
/// ///
/// The memory is allocated using the libc allocator. The caller /// The memory is allocated using the libc allocator. The caller
/// is responsible for freeing it explicitly. /// is responsible for freeing it explicitly.
fn new<S: AsRef<str>>(value: S, next: *mut Self) -> &'static mut Self { fn new<S: AsRef<str>>(mm: MM, value: S, next: *mut Self) -> &'static mut Self {
let item = Self::empty(); let item = Self::empty(mm);
item.value = rust_str_to_c_str(value); item.value = rust_str_to_c_str(mm, value);
item.next = next; item.next = next;
item item
...@@ -69,6 +72,7 @@ pub struct StringList { ...@@ -69,6 +72,7 @@ pub struct StringList {
head: *mut StringListItem, head: *mut StringListItem,
// If set, when the StringList is dropped, the items are freed. // If set, when the StringList is dropped, the items are freed.
owned: bool, owned: bool,
mm: MM,
} }
impl StringList { impl StringList {
...@@ -77,10 +81,12 @@ impl StringList { ...@@ -77,10 +81,12 @@ impl StringList {
/// `owned` indicates whether the rust code should own the items. /// `owned` indicates whether the rust code should own the items.
/// If so, when the `StringList` is dropped, the items will also /// If so, when the `StringList` is dropped, the items will also
/// be freed. /// be freed.
pub fn to_rust(sl: *mut StringListItem, owned: bool) -> Self { pub fn to_rust(mm: MM, sl: *mut StringListItem, owned: bool) -> Self
{
StringList { StringList {
head: sl, head: sl,
owned, owned,
mm,
} }
} }
...@@ -97,10 +103,11 @@ impl StringList { ...@@ -97,10 +103,11 @@ impl StringList {
/// The items are owned by the `StringList`, and when it is /// The items are owned by the `StringList`, and when it is
/// dropped, they are freed. To take ownership of the items, call /// dropped, they are freed. To take ownership of the items, call
/// `StringList::to_c`. /// `StringList::to_c`.
pub fn new<S: AsRef<str>>(value: S) -> Self { pub fn new<S: AsRef<str>>(mm: MM, value: S) -> Self {
StringList { StringList {
head: StringListItem::new(value, ptr::null_mut()), head: StringListItem::new(mm, value, ptr::null_mut()),
owned: true, owned: true,
mm,
} }
} }
...@@ -109,10 +116,12 @@ impl StringList { ...@@ -109,10 +116,12 @@ impl StringList {
/// Any added items are owned by the `StringList`, and when it is /// Any added items are owned by the `StringList`, and when it is
/// dropped, they are freed. To take ownership of the items, call /// dropped, they are freed. To take ownership of the items, call
/// `StringList::to_c`. /// `StringList::to_c`.
pub fn empty() -> Self { pub fn empty(mm: MM) -> Self
{
StringList { StringList {
head: ptr::null_mut(), head: ptr::null_mut(),
owned: true, owned: true,
mm,
} }
} }
...@@ -121,9 +130,12 @@ impl StringList { ...@@ -121,9 +130,12 @@ impl StringList {
/// variant, which we use for testing. /// variant, which we use for testing.
#[cfg(test)] #[cfg(test)]
fn empty_alt() -> Self { fn empty_alt() -> Self {
let mm = MM { malloc: libc::malloc, free: libc::free };
StringList { StringList {
head: StringListItem::empty(), head: StringListItem::empty(mm),
owned: true, owned: true,
mm,
} }
} }
...@@ -147,6 +159,8 @@ impl StringList { ...@@ -147,6 +159,8 @@ impl StringList {
} }
fn add_<S: AsRef<str>>(&mut self, value: S, dedup: bool) { fn add_<S: AsRef<str>>(&mut self, value: S, dedup: bool) {
let mm = self.mm;
let value = value.as_ref(); let value = value.as_ref();
// See if the value already exists in the string list. // See if the value already exists in the string list.
...@@ -163,18 +177,18 @@ impl StringList { ...@@ -163,18 +177,18 @@ impl StringList {
let itemp = iter.item(); let itemp = iter.item();
if (*itemp).is_null() { if (*itemp).is_null() {
// 1. head is NULL (this is the case if item is NULL). // 1. head is NULL (this is the case if item is NULL).
*itemp = StringListItem::new(value, ptr::null_mut()); *itemp = StringListItem::new(mm, value, ptr::null_mut());
} else { } else {
let item: &mut StringListItem let item: &mut StringListItem
= StringListItem::as_mut(*itemp).expect("just checked"); = StringListItem::as_mut(*itemp).expect("just checked");
if item.value.is_null() { if item.value.is_null() {
// 2. head is not NULL, but head.value is NULL. // 2. head is not NULL, but head.value is NULL.
item.value = rust_str_to_c_str(value); item.value = rust_str_to_c_str(mm, value);
} else { } else {
// 3. neither head nor head.value are NULL. // 3. neither head nor head.value are NULL.
assert!(item.next.is_null()); assert!(item.next.is_null());
item.next = StringListItem::new(value, ptr::null_mut()); item.next = StringListItem::new(mm, value, ptr::null_mut());
} }
} }
} }
...@@ -200,6 +214,8 @@ impl StringList { ...@@ -200,6 +214,8 @@ impl StringList {
/// The items in other have the same ownership as items in `self`. /// The items in other have the same ownership as items in `self`.
/// `other` is reset to an empty list. /// `other` is reset to an empty list.
pub fn append(&mut self, other: &mut StringList) { pub fn append(&mut self, other: &mut StringList) {
let free = self.mm.free;
let mut iter = self.iter_mut(); let mut iter = self.iter_mut();
(&mut iter).last(); (&mut iter).last();
...@@ -229,6 +245,8 @@ impl StringList { ...@@ -229,6 +245,8 @@ impl StringList {
impl Drop for StringList { impl Drop for StringList {
fn drop(&mut self) { fn drop(&mut self) {
let free = self.mm.free;
let mut curr: *mut StringListItem = self.head; let mut curr: *mut StringListItem = self.head;
self.head = ptr::null_mut(); self.head = ptr::null_mut();
...@@ -320,27 +338,33 @@ mod tests { ...@@ -320,27 +338,33 @@ mod tests {
#[test] #[test]
fn empty() { fn empty() {
let mm = MM { malloc: libc::malloc, free: libc::free };
// There are two ways to make an empty list. Either head is // There are two ways to make an empty list. Either head is
// NULL or the string list item's value and next are NULL. // NULL or the string list item's value and next are NULL.
let empty = StringList { let empty = StringList {
head: ptr::null_mut(), head: ptr::null_mut(),
owned: true, owned: true,
mm: mm,
}; };
assert_eq!(empty.len(), 0); assert_eq!(empty.len(), 0);
let empty = StringList { let empty = StringList {
head: StringListItem::empty(), head: StringListItem::empty(mm),
owned: true, owned: true,
mm: mm,
}; };
assert_eq!(empty.len(), 0); assert_eq!(empty.len(), 0);
} }
#[test] #[test]
fn add() { fn add() {
let mm = MM { malloc: libc::malloc, free: libc::free };
for variant in 0..3 { for variant in 0..3 {
let (mut list, mut v) = match variant { let (mut list, mut v) = match variant {
0 => { 0 => {
let list = StringList::new("abc"); let list = StringList::new(mm, "abc");
assert_eq!(list.len(), 1); assert_eq!(list.len(), 1);
let mut v: Vec<String> = Vec::new(); let mut v: Vec<String> = Vec::new();
...@@ -348,7 +372,7 @@ mod tests { ...@@ -348,7 +372,7 @@ mod tests {
(list, v) (list, v)
}, },
1 => (StringList::empty(), Vec::new()), 1 => (StringList::empty(mm), Vec::new()),
2 => (StringList::empty_alt(), Vec::new()), 2 => (StringList::empty_alt(), Vec::new()),
_ => unreachable!(), _ => unreachable!(),
}; };
...@@ -374,10 +398,12 @@ mod tests { ...@@ -374,10 +398,12 @@ mod tests {
#[test] #[test]
fn add_unique() { fn add_unique() {
let mm = MM { malloc: libc::malloc, free: libc::free };
for variant in 0..3 { for variant in 0..3 {
let (mut list, mut v) = match variant { let (mut list, mut v) = match variant {
0 => { 0 => {
let list = StringList::new("abc"); let list = StringList::new(mm, "abc");
assert_eq!(list.len(), 1); assert_eq!(list.len(), 1);
let mut v: Vec<String> = Vec::new(); let mut v: Vec<String> = Vec::new();
...@@ -385,7 +411,7 @@ mod tests { ...@@ -385,7 +411,7 @@ mod tests {
(list, v) (list, v)
}, },
1 => (StringList::empty(), Vec::new()), 1 => (StringList::empty(mm), Vec::new()),
2 => (StringList::empty_alt(), Vec::new()), 2 => (StringList::empty_alt(), Vec::new()),
_ => unreachable!(), _ => unreachable!(),
}; };
...@@ -420,12 +446,14 @@ mod tests { ...@@ -420,12 +446,14 @@ mod tests {
#[test] #[test]
fn append() { fn append() {
let mm = MM { malloc: libc::malloc, free: libc::free };
for variant in 0..2 { for variant in 0..2 {
// Returns a list and a vector with `count` items whose // Returns a list and a vector with `count` items whose
// values are `prefix_0`, `prefix_1`, etc. // values are `prefix_0`, `prefix_1`, etc.
let list = |count: usize, prefix: &str| -> (StringList, Vec<String>) { let list = |count: usize, prefix: &str| -> (StringList, Vec<String>) {
let mut l = match variant { let mut l = match variant {
0 => StringList::empty(), 0 => StringList::empty(mm),
1 => StringList::empty_alt(), 1 => StringList::empty_alt(),
_ => unreachable!(), _ => unreachable!(),
}; };
......