diff --git a/arm_mte.h b/arm_mte.h
new file mode 100644
index 0000000..8deefb2
--- /dev/null
+++ b/arm_mte.h
@@ -0,0 +1,86 @@
+#ifndef ARM_MTE_H
+#define ARM_MTE_H
+
+#include <arm_acle.h>
+#include <stdint.h>
+
+// Returns a tagged pointer.
+// See https://developer.arm.com/documentation/ddi0602/2023-09/Base-Instructions/IRG--Insert-Random-Tag-
+static inline void *arm_mte_create_random_tag(void *p, u64 exclusion_mask) {
+    return __arm_mte_create_random_tag(p, exclusion_mask);
+}
+
+// Tag the memory region with the tag specified in tag bits of tagged_ptr. Memory region itself is
+// zeroed.
+// Arm's software optimization guide says:
+// "it is recommended to use STZGM (or DCZGVA) to set tag if data is not a concern." (STZGM and
+// DCZGVA are zeroing variants of tagging instructions).
+//
+// Contents of this function were copied from scudo:
+// https://android.googlesource.com/platform/external/scudo/+/refs/tags/android-14.0.0_r1/standalone/memtag.h#167
+static inline void arm_mte_store_tags_and_clear(void *tagged_ptr, size_t len) {
+    uintptr_t Begin = (uintptr_t) tagged_ptr;
+    uintptr_t End = Begin + len;
+    uintptr_t LineSize, Next, Tmp;
+    __asm__ __volatile__(
+        ".arch_extension memtag \n\t"
+
+        // Compute the cache line size in bytes (DCZID_EL0 stores it as the log2
+        // of the number of 4-byte words) and bail out to the slow path if DCZID_EL0
+        // indicates that the DC instructions are unavailable.
+        "DCZID .req %[Tmp] \n\t"
+        "mrs DCZID, dczid_el0 \n\t"
+        "tbnz DCZID, #4, 3f \n\t"
+        "and DCZID, DCZID, #15 \n\t"
+        "mov %[LineSize], #4 \n\t"
+        "lsl %[LineSize], %[LineSize], DCZID \n\t"
+        ".unreq DCZID \n\t"
+
+        // Our main loop doesn't handle the case where we don't need to perform any
+        // DC GZVA operations. If the size of our tagged region is less than
+        // twice the cache line size, bail out to the slow path since it's not
+        // guaranteed that we'll be able to do a DC GZVA.
+        "Size .req %[Tmp] \n\t"
+        "sub Size, %[End], %[Cur] \n\t"
+        "cmp Size, %[LineSize], lsl #1 \n\t"
+        "b.lt 3f \n\t"
+        ".unreq Size \n\t"
+
+        "LineMask .req %[Tmp] \n\t"
+        "sub LineMask, %[LineSize], #1 \n\t"
+
+        // STZG until the start of the next cache line.
+        "orr %[Next], %[Cur], LineMask \n\t"
+
+        "1:\n\t"
+        "stzg %[Cur], [%[Cur]], #16 \n\t"
+        "cmp %[Cur], %[Next] \n\t"
+        "b.lt 1b \n\t"
+
+        // DC GZVA cache lines until we have no more full cache lines.
+        "bic %[Next], %[End], LineMask \n\t"
+        ".unreq LineMask \n\t"
+
+        "2: \n\t"
+        "dc gzva, %[Cur] \n\t"
+        "add %[Cur], %[Cur], %[LineSize] \n\t"
+        "cmp %[Cur], %[Next] \n\t"
+        "b.lt 2b \n\t"
+
+        // STZG until the end of the tagged region. This loop is also used to handle
+        // slow path cases.
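The two helpers in arm_mte.h are meant to be used as a pair: pick a fresh random tag that avoids the tags in the exclusion mask, then store that tag over a 16-byte-aligned region whose length is a multiple of 16, zeroing the data in the same pass. A minimal usage sketch, not part of the patch, assuming a HAS_ARM_MTE build and the memtag.h header added later in this patch (the function name is hypothetical):

#include "memtag.h"

// Illustrative only: retag and zero a 16-byte-aligned region of len bytes
// (len must be a multiple of 16), never picking tag 0 or the reserved tag.
static void *example_retag_region(void *p, size_t len) {
    u64 exclusion_mask = (1 << 0) | (1 << RESERVED_TAG);  // u64 comes from util.h
    void *tagged = arm_mte_create_random_tag(p, exclusion_mask);
    arm_mte_store_tags_and_clear(tagged, len);
    return tagged;
}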
+ + "3: \n\t" + "cmp %[Cur], %[End] \n\t" + "b.ge 4f \n\t" + "stzg %[Cur], [%[Cur]], #16 \n\t" + "b 3b \n\t" + + "4: \n\t" + + : [Cur] "+&r"(Begin), [LineSize] "=&r"(LineSize), [Next] "=&r"(Next), [Tmp] "=&r"(Tmp) + : [End] "r"(End) + : "memory" + ); +} +#endif diff --git a/h_malloc.c b/h_malloc.c index 2dc0bde..098eb37 100644 --- a/h_malloc.c +++ b/h_malloc.c @@ -14,6 +14,7 @@ #include "h_malloc.h" #include "memory.h" +#include "memtag.h" #include "mutex.h" #include "pages.h" #include "random.h" @@ -66,6 +67,10 @@ static atomic_uint thread_arena_counter = 0; static const unsigned thread_arena = 0; #endif +#ifdef MEMTAG +bool __is_memtag_enabled = true; +#endif + static union { struct { void *slab_region_start; @@ -99,6 +104,18 @@ struct slab_metadata { #if SLAB_QUARANTINE u64 quarantine_bitmap[4]; #endif +#ifdef HAS_ARM_MTE + // arm_mte_tags is used as a u4 array (MTE tags are 4-bit wide) + // + // Its size is calculated by the following formula: + // (MAX_SLAB_SLOT_COUNT + 2) / 2 + // MAX_SLAB_SLOT_COUNT is currently 256, 2 extra slots are needed for branchless handling of + // edge slots in tag_and_clear_slab_slot() + // + // It's intentionally placed at the end of struct to improve locality: for most size classes, + // slot count is far lower than MAX_SLAB_SLOT_COUNT. + u8 arm_mte_tags[129]; +#endif }; static const size_t min_align = 16; @@ -506,6 +523,47 @@ static inline void stats_slab_deallocate(UNUSED struct size_class *c, UNUSED siz #endif } +static void *tag_and_clear_slab_slot(struct slab_metadata *metadata, void *slot_ptr, size_t slot_idx, size_t slot_size) { +#ifdef HAS_ARM_MTE + if (unlikely(!is_memtag_enabled())) { + return slot_ptr; + } + + // arm_mte_tags is an array of 4-bit unsigned integers stored as u8 array (MTE tags are 4-bit wide) + // + // It stores the most recent tag for each slab slot, or 0 if the slot was never used. + // Slab indices in arm_mte_tags array are shifted to the right by 1, and size of this array + // is (MAX_SLAB_SLOT_COUNT + 2). This means that first and last values of arm_mte_tags array + // are always 0, which allows to handle edge slots in a branchless way when tag exclusion mask + // is constructed. + u8 *slot_tags = metadata->arm_mte_tags; + + // Tag exclusion mask + u64 tem = (1 << 0) | (1 << RESERVED_TAG); + + // current or previous tag of left neighbor or 0 if there's no left neighbor or if it was never used + tem |= (1 << u4_arr_get(slot_tags, slot_idx)); + // previous tag of this slot or 0 if it was never used + tem |= (1 << u4_arr_get(slot_tags, slot_idx + 1)); + // current or previous tag of right neighbor or 0 if there's no right neighbor or if it was never used + tem |= (1 << u4_arr_get(slot_tags, slot_idx + 2)); + + void *tagged_ptr = arm_mte_create_random_tag(slot_ptr, tem); + // slot addresses and sizes are always aligned by 16 + arm_mte_store_tags_and_clear(tagged_ptr, slot_size); + + // store new tag of this slot + u4_arr_set(slot_tags, slot_idx + 1, get_pointer_tag(tagged_ptr)); + + return tagged_ptr; +#else + (void) metadata; + (void) slot_idx; + (void) slot_size; + return slot_ptr; +#endif +} + static inline void *allocate_small(unsigned arena, size_t requested_size) { struct size_info info = get_size_info(requested_size); size_t size = likely(info.size) ? 
@@ -534,6 +592,7 @@ static inline void *allocate_small(unsigned arena, size_t requested_size) {
         if (requested_size) {
             write_after_free_check(p, size - canary_size);
             set_canary(metadata, p, size);
+            p = tag_and_clear_slab_slot(metadata, p, slot, size);
         }
 
         stats_small_allocate(c, size);
@@ -566,6 +625,7 @@ static inline void *allocate_small(unsigned arena, size_t requested_size) {
         void *p = slot_pointer(size, slab, slot);
         if (requested_size) {
             set_canary(metadata, p, size);
+            p = tag_and_clear_slab_slot(metadata, p, slot, size);
         }
         stats_slab_allocate(c, slab_size);
         stats_small_allocate(c, size);
@@ -588,6 +648,7 @@ static inline void *allocate_small(unsigned arena, size_t requested_size) {
         void *p = slot_pointer(size, slab, slot);
         if (requested_size) {
             set_canary(metadata, p, size);
+            p = tag_and_clear_slab_slot(metadata, p, slot, size);
         }
         stats_slab_allocate(c, slab_size);
         stats_small_allocate(c, size);
@@ -612,6 +673,7 @@ static inline void *allocate_small(unsigned arena, size_t requested_size) {
         if (requested_size) {
             write_after_free_check(p, size - canary_size);
             set_canary(metadata, p, size);
+            p = tag_and_clear_slab_slot(metadata, p, slot, size);
         }
 
         stats_small_allocate(c, size);
@@ -694,7 +756,17 @@ static inline void deallocate_small(void *p, const size_t *expected_size) {
     if (likely(!is_zero_size)) {
         check_canary(metadata, p, size);
 
-        if (ZERO_ON_FREE) {
+        bool skip_zero = false;
+#ifdef HAS_ARM_MTE
+        if (likely(is_memtag_enabled())) {
+            arm_mte_store_tags_and_clear(set_pointer_tag(p, RESERVED_TAG), size);
+            // metadata->arm_mte_tags is intentionally not updated; it should keep the previous
+            // slot tag after the slot is freed
+            skip_zero = true;
+        }
+#endif
+
+        if (ZERO_ON_FREE && !skip_zero) {
             memset(p, 0, size - canary_size);
         }
     }
@@ -1123,8 +1195,15 @@ COLD static void init_slow_path(void) {
     if (unlikely(memory_protect_rw_metadata(ra->regions, ra->total * sizeof(struct region_metadata)))) {
         fatal_error("failed to unprotect memory for regions table");
     }
-
+#ifdef HAS_ARM_MTE
+    if (likely(is_memtag_enabled())) {
+        ro.slab_region_start = memory_map_mte(slab_region_size);
+    } else {
+        ro.slab_region_start = memory_map(slab_region_size);
+    }
+#else
     ro.slab_region_start = memory_map(slab_region_size);
+#endif
     if (unlikely(ro.slab_region_start == NULL)) {
         fatal_error("failed to allocate slab region");
     }
@@ -1368,6 +1447,11 @@ EXPORT void *h_calloc(size_t nmemb, size_t size) {
     if (!ZERO_ON_FREE && likely(p != NULL) && total_size && total_size <= max_slab_size_class) {
         memset(p, 0, total_size - canary_size);
     }
+#ifdef HAS_ARM_MTE
+    // use an assert instead of adding a conditional to the memset() above (freed memory is always
+    // zeroed when MTE is enabled)
+    static_assert(ZERO_ON_FREE, "disabling ZERO_ON_FREE reduces performance when ARM MTE is enabled");
+#endif
     return p;
 }
 
@@ -1385,11 +1469,14 @@ EXPORT void *h_realloc(void *old, size_t size) {
         }
     }
 
+    void *old_orig = old;
+    old = untag_pointer(old);
+
     size_t old_size;
     if (old < get_slab_region_end() && old >= ro.slab_region_start) {
         old_size = slab_usable_size(old);
         if (size <= max_slab_size_class && get_size_info(size).size == old_size) {
-            return old;
+            return old_orig;
         }
         thread_unseal_metadata();
     } else {
@@ -1502,7 +1589,7 @@ EXPORT void *h_realloc(void *old, size_t size) {
     if (copy_size > 0 && copy_size <= max_slab_size_class) {
         copy_size -= canary_size;
     }
-    memcpy(new, old, copy_size);
+    memcpy(new, old_orig, copy_size);
     if (old_size <= max_slab_size_class) {
         deallocate_small(old, NULL);
     } else {
@@ -1543,6 +1630,8 @@ EXPORT void h_free(void *p) {
         return;
     }
 
+    p = untag_pointer(p);
+
     if (p < get_slab_region_end() && p >= ro.slab_region_start) {
         thread_unseal_metadata();
         deallocate_small(p, NULL);
@@ -1566,6 +1655,8 @@ EXPORT void h_free_sized(void *p, size_t expected_size) {
         return;
     }
 
+    p = untag_pointer(p);
+
     expected_size = adjust_size_for_canary(expected_size);
 
     if (p < get_slab_region_end() && p >= ro.slab_region_start) {
@@ -1619,11 +1710,13 @@ static inline void memory_corruption_check_small(const void *p) {
     mutex_unlock(&c->lock);
 }
 
-EXPORT size_t h_malloc_usable_size(H_MALLOC_USABLE_SIZE_CONST void *p) {
-    if (p == NULL) {
+EXPORT size_t h_malloc_usable_size(H_MALLOC_USABLE_SIZE_CONST void *arg) {
+    if (arg == NULL) {
         return 0;
     }
 
+    void *p = untag_pointer((void *) (uintptr_t) arg);
+
     if (p < get_slab_region_end() && p >= ro.slab_region_start) {
         thread_unseal_metadata();
         memory_corruption_check_small(p);
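On the free path in deallocate_small() above, the freed slot is retagged with RESERVED_TAG and zeroed by the tag-store instructions, so the explicit memset() is skipped and any later access through a stale pointer, which still carries the old tag, trips a tag check fault. A hypothetical illustration of the effect (not part of the patch):

#include <stdlib.h>

// Illustrative only: the stale pointer keeps its old tag while the freed slot
// is retagged with the reserved tag, so the commented-out load would fault
// (SIGSEGV with SEGV_MTESERR in synchronous MTE mode) instead of silently
// reading reused memory.
static void example_use_after_free(void) {
    char *p = malloc(32);
    if (p == NULL) {
        return;
    }
    p[0] = 'x';
    free(p);         // slot is retagged with RESERVED_TAG and zeroed
    // (void) p[0];  // tag mismatch: caught by MTE
}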
diff --git a/memory.c b/memory.c
index 04afc23..5434060 100644
--- a/memory.c
+++ b/memory.c
@@ -28,6 +28,20 @@ void *memory_map(size_t size) {
     return p;
 }
 
+#ifdef HAS_ARM_MTE
+// Note that PROT_MTE can't be cleared via mprotect
+void *memory_map_mte(size_t size) {
+    void *p = mmap(NULL, size, PROT_MTE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
+    if (unlikely(p == MAP_FAILED)) {
+        if (errno != ENOMEM) {
+            fatal_error("non-ENOMEM MTE mmap failure");
+        }
+        return NULL;
+    }
+    return p;
+}
+#endif
+
 bool memory_map_fixed(void *ptr, size_t size) {
     void *p = mmap(ptr, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE|MAP_FIXED, -1, 0);
     bool ret = p == MAP_FAILED;
diff --git a/memory.h b/memory.h
index c04bfd9..6e4cd4d 100644
--- a/memory.h
+++ b/memory.h
@@ -11,6 +11,9 @@ int get_metadata_key(void);
 
 void *memory_map(size_t size);
+#ifdef HAS_ARM_MTE
+void *memory_map_mte(size_t size);
+#endif
 bool memory_map_fixed(void *ptr, size_t size);
 bool memory_unmap(void *ptr, size_t size);
 bool memory_protect_ro(void *ptr, size_t size);
diff --git a/memtag.h b/memtag.h
new file mode 100644
index 0000000..a768d41
--- /dev/null
+++ b/memtag.h
@@ -0,0 +1,52 @@
+#ifndef MEMTAG_H
+#define MEMTAG_H
+
+#include "util.h"
+
+#ifdef HAS_ARM_MTE
+#include "arm_mte.h"
+#define MEMTAG 1
+#define RESERVED_TAG 15
+#define TAG_WIDTH 4
+#endif
+
+#ifdef MEMTAG
+extern bool __is_memtag_enabled;
+#endif
+
+static inline bool is_memtag_enabled(void) {
+#ifdef MEMTAG
+    return __is_memtag_enabled;
+#else
+    return false;
+#endif
+}
+
+static inline void *untag_pointer(void *ptr) {
+#ifdef HAS_ARM_MTE
+    const uintptr_t mask = UINTPTR_MAX >> 8;
+    return (void *) ((uintptr_t) ptr & mask);
+#else
+    return ptr;
+#endif
+}
+
+static inline void *set_pointer_tag(void *ptr, u8 tag) {
+#ifdef HAS_ARM_MTE
+    return (void *) (((uintptr_t) tag << 56) | (uintptr_t) untag_pointer(ptr));
+#else
+    (void) tag;
+    return ptr;
+#endif
+}
+
+static inline u8 get_pointer_tag(void *ptr) {
+#ifdef HAS_ARM_MTE
+    return (((uintptr_t) ptr) >> 56) & 0xf;
+#else
+    (void) ptr;
+    return 0;
+#endif
+}
+
+#endif
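Finally, a small sanity-check sketch for the pointer-tag helpers defined in memtag.h (hypothetical test code, not part of the patch), which only holds on a build with HAS_ARM_MTE defined and an input pointer whose top byte is clear:

#include <assert.h>
#include "memtag.h"

// Illustrative only: set_pointer_tag() places the 4-bit tag in the top byte,
// get_pointer_tag() reads bits [59:56] back, and untag_pointer() clears the
// whole top byte again.
static void example_pointer_tag_roundtrip(void *p) {
    void *tagged = set_pointer_tag(p, 7);
    assert(get_pointer_tag(tagged) == 7);
    assert(untag_pointer(tagged) == p);
}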