Use-after-Free vulnerabilities detection

This notebook searches for potential Use-after-Free vulnerabilities in a REVEN trace.

Prerequisites

  • This notebook should be run in a jupyter notebook server equipped with a REVEN 2 python kernel. REVEN comes with a jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • This notebook depends on capstone being installed in the REVEN 2 python kernel. To install capstone in the current environment, please execute the capstone cell of this notebook.
  • This notebook requires the Memory History resource for your target scenario.

Running the notebook

Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.

Note

Although the script is designed to limit false positives, you may still encounter a few. Please check the results and feel free to report any issue, be it a false positive or a false negative ;-).

Source

# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Use-after-Free vulnerabilities detection
#
# This notebook searches for potential Use-after-Free vulnerabilities in a REVEN trace.
#
# ## Prerequisites
#
# - This notebook should be run in a jupyter notebook server equipped with a REVEN 2 python kernel.
#   REVEN comes with a jupyter notebook server accessible with the `Open Python` button in the `Analyze` page of any
#   scenario.
#
# - This notebook depends on `capstone` being installed in the REVEN 2 python kernel.
#   To install capstone in the current environment, please execute the [capstone cell](#Capstone-Installation) of this
#   notebook.
#
# - This notebook requires the Memory History resource for your target scenario.
#
#
# ## Running the notebook
#
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.
#
#
# ## Note
#
# Although the script is designed to limit false positives, you may still encounter a few. Please check
# the results and feel free to report any issue, be it a false positive or a false negative ;-).

# %% [markdown]
# # Capstone Installation
#
# Check whether capstone is available. If it is missing, attempt to install it with pip.

# %%
try:
    import capstone
    print("capstone already installed")
except ImportError:
    print("Could not find capstone, attempting to install it from pip")
    import sys
    import subprocess

    command = [sys.executable, "-m", "pip", "install", "capstone"]
    p = subprocess.run(command)

    if p.returncode != 0:
        raise RuntimeError("Error installing capstone")

    import capstone  # noqa
    print("Successfully installed capstone")

# %% [markdown]
# # Parameters

# %%
# Server connection

# Host of the REVEN server running the scenario.
# When running this notebook from the Project Manager, '127.0.0.1' should be the correct value.
reven_backend_host = '127.0.0.1'

# Port of the REVEN server running the scenario.
# After starting a REVEN server on your scenario, you can get its port on the Analyze page of that scenario.
reven_backend_port = 13370


# Range control

# First transition considered for the detection of allocation/deallocation pairs
# If set to None, the first transition of the trace is used
uaf_from_tr = None
# uaf_from_tr = 3984021008  # ex: vlc scenario
# uaf_from_tr = 8655210  # ex: bksod scenario

# First transition **not** considered for the detection of allocation/deallocation pairs
# If set to None, the detection runs up to the end of the trace
uaf_to_tr = None
# uaf_to_tr = 3986760400  # ex: vlc scenario
# uaf_to_tr = 169673257  # ex: bksod scenario


# Filter control

# Beware that, when filtering, if an allocation happens in the specified process and/or binary,
# the script will fail if the deallocation happens in a different process and/or binary.
# This issue should only happen for allocations in the kernel.

# Specify on which PID the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the PID.
faulty_process_pid = None
# faulty_process_pid = 2620  # ex: vlc scenario
# faulty_process_pid = None  # ex: bksod
# We can't filter on the process for BKSOD, as allocations happen in one process and frees in another:
# termdd.sys and the svchost.exe process (pid: 1004) use kernel workers to free some resources,
# so the allocations are done in svchost.exe and some of the frees in the "System" process (pid: 4)

# Specify on which process name the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the process name.
# If both a process PID and a process name are specified, please make sure that they both
# refer to the same process, otherwise all allocations will be filtered and no results
# will be produced.
faulty_process_name = None
# faulty_process_name = "vlc.exe"  # ex: vlc scenario
# faulty_process_name = None  # ex: bksod scenario

# Specify on which binary name the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the binary name.
# Only allocations/deallocations taking place in a binary whose name,
# filename or path is exactly the specified value are kept.
# If filtering on both a process and a binary, please make sure that there are
# allocations taking place in that binary in the selected process, otherwise all
# allocations will be filtered and no result will be produced.
faulty_binary = None
# faulty_binary = "libmkv_plugin.dll"  # ex: vlc scenario
# faulty_binary = "termdd.sys"  # ex: bksod scenario


# Address control

# Specify a **physical** address suspected of being faulty here,
# to only test UaF for this specific address, instead of all (filtered) allocations.
# The address should still be returned by an allocation/deallocation pair.
# To get a physical address from a virtual address, find a context where the address
# is mapped, then use `virtual_address.translate(ctx)`.
uaf_faulty_physical_address = None
# uaf_faulty_physical_address = 0x5a36cd20  # ex: vlc scenario
# uaf_faulty_physical_address = 0x7fb65010  # ex: bksod scenario


# Allocator control

# The script can use two allocators to find allocation/deallocation pairs.
# The following booleans enable the search for allocations made by each of
# these allocators in a scenario.
# Generally, only a single allocator is expected to be enabled for a given
# scenario.

# To add your own allocator, please look at how the two provided allocators were
# added.

# Whether or not to look for malloc/free allocation/deallocation pairs.
search_malloc = True
# search_malloc = True  # ex: vlc scenario, user space scenario
# search_malloc = False  # ex: bksod scenario

# Whether or not to look for ExAllocatePoolWithTag/ExFreePoolWithTag
# allocation/deallocation pairs.
# This allocator is used by the Windows kernel.
search_pool_allocation = False
# search_pool_allocation = False  # ex: vlc scenario
# search_pool_allocation = True  # ex: bksod scenario, kernel scenario


# Taint control

# Technical parameter: number of accesses in a taint after which the script gives up on
# that particular taint.
# As long taints degrade the performance of the script significantly, it is recommended to
# give up on a taint after it exceeds a certain number of operations.
# If you experience missing results, try increasing the value.
# If you experience a very long runtime for the script, try decreasing the value.
# The default value should be appropriate for most cases.
taint_max_length = 100000

# %%
import itertools  # noqa: E402
from collections import OrderedDict  # noqa: E402
from typing import Dict, List  # noqa: E402

import reven2  # noqa: E402
import reven2.preview.taint  # noqa: E402

# %%
# Python script to connect to this scenario:
server = reven2.RevenServer(reven_backend_host, reven_backend_port)
print(server.trace.transition_count)


# %%
class MemoryRange:
    page_size = 4096
    page_mask = ~(page_size - 1)

    def __init__(self, logical_address, size):
        self.logical_address = logical_address
        self.size = size

        self.pages = [{
            'logical_address': self.logical_address,
            'size': self.size,
            'physical_address': None,
            'ctx_physical_address_mapped': None,
        }]

        # Compute the pages
        while (((self.pages[-1]['logical_address'] & ~MemoryRange.page_mask) + self.pages[-1]['size'] - 1)
                >= MemoryRange.page_size):
            # Compute the size of the new page
            new_page_size = ((self.pages[-1]['logical_address'] & ~MemoryRange.page_mask) + self.pages[-1]['size']
                             - MemoryRange.page_size)

            # Reduce the size of the previous page and create the new one
            self.pages[-1]['size'] -= new_page_size
            self.pages.append({
                'logical_address': self.pages[-1]['logical_address'] + self.pages[-1]['size'],
                'size': new_page_size,
                'physical_address': None,
                'ctx_physical_address_mapped': None,
            })

    def try_translate_first_page(self, ctx):
        if self.pages[0]['physical_address'] is not None:
            return True

        physical_address = reven2.address.LogicalAddress(self.pages[0]['logical_address']).translate(ctx)

        if physical_address is None:
            return False

        self.pages[0]['physical_address'] = physical_address.offset
        self.pages[0]['ctx_physical_address_mapped'] = ctx

        return True

    def try_translate_all_pages(self, ctx):
        return_value = True

        for page in self.pages:
            if page['physical_address'] is not None:
                continue

            physical_address = reven2.address.LogicalAddress(page['logical_address']).translate(ctx)

            if physical_address is None:
                return_value = False
                continue

            page['physical_address'] = physical_address.offset
            page['ctx_physical_address_mapped'] = ctx

        return return_value

    def is_physical_address_range_in_translated_pages(self, physical_address, size):
        for page in self.pages:
            if page['physical_address'] is None:
                continue

            if (
                physical_address >= page['physical_address']
                and physical_address + size <= page['physical_address'] + page['size']
            ):
                return True

        return False

    def __repr__(self):
        return "MemoryRange(0x%x, %d)" % (self.logical_address, self.size)

# Utility to find the physical address of a buffer that was just allocated
#     - ctx should be the context where the address of the buffer is in `rax`
#     - memory_range should be the memory range of the newly allocated buffer
#
# We first try the translate API, but the address is sometimes not mapped yet right after
# the allocation. In that case, we taint `rax` forward and retry the translation at each
# access of the taint until it succeeds.


def translate_first_page_of_allocation(ctx, memory_range):
    if memory_range.try_translate_first_page(ctx):
        return

    tainter = reven2.preview.taint.Tainter(server.trace)
    taint = tainter.simple_taint(
        tag0="rax",
        from_context=ctx,
        to_context=None,
        is_forward=True
    )

    for access in taint.accesses(changes_only=False).all():
        if memory_range.try_translate_first_page(access.transition.context_after()):
            taint.cancel()
            return

    raise RuntimeError("Couldn't find the physical address of the first page")
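
# %%
# Illustration (not used by the detection): a minimal, standalone sketch of the page splitting
# performed by `MemoryRange.__init__` above, for non-empty ranges. The helper name
# `_example_split_into_pages` and the sample addresses are hypothetical; the sketch assumes
# 4 KiB pages, like `MemoryRange.page_size`.
def _example_split_into_pages(address, size, page_size=4096):
    """Split [address, address + size) into (address, size) chunks that never cross a page."""
    chunks = []
    while size > 0:
        # Number of bytes left in the page that contains `address`
        room_in_page = page_size - (address % page_size)
        chunk_size = min(size, room_in_page)
        chunks.append((address, chunk_size))
        address += chunk_size
        size -= chunk_size
    return chunks


# A 0x20-byte buffer starting 0x10 bytes before a page boundary spans two pages,
# so each page needs its own virtual-to-physical translation
print(_example_split_into_pages(0x1ff0, 0x20))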


# %%
class AllocEvent:
    def __init__(self, memory_range, tr_begin, tr_end):
        self.memory_range = memory_range
        self.tr_begin = tr_begin
        self.tr_end = tr_end


class FreeEvent:
    def __init__(self, logical_address, tr_begin, tr_end):
        self.logical_address = logical_address
        self.tr_begin = tr_begin
        self.tr_end = tr_end


def retrieve_events_for_symbol(
    alloc_dict,
    event_class,
    symbol,
    retrieve_event_info,
    event_filter=None,
):
    for ctx in server.trace.search.symbol(
        symbol,
        from_context=None if uaf_from_tr is None else server.trace.context_before(uaf_from_tr),
        to_context=None if uaf_to_tr is None else server.trace.context_before(uaf_to_tr)
    ):
        # Skip hits on exceptions (code page faults, hardware interrupts, etc.)
        if ctx.transition_after().exception is not None:
            continue

        previous_location = (ctx - 1).ossi.location()
        previous_process = (ctx - 1).ossi.process()

        # Filter by process pid/process name/binary name

        # Filter by process pid
        if faulty_process_pid is not None and previous_process.pid != faulty_process_pid:
            continue

        # Filter by process name
        if faulty_process_name is not None and previous_process.name != faulty_process_name:
            continue

        # Filter by binary name / filename / path
        if faulty_binary is not None and faulty_binary not in [
            previous_location.binary.name,
            previous_location.binary.filename,
            previous_location.binary.path
        ]:
            continue

        # Filter the event with the argument filter
        if event_filter is not None:
            if event_filter(ctx.ossi.location(), previous_location):
                continue

        # Retrieve the call/ret
        # The heuristic is that the ret is the end of our function
        #   - If the call is inlined it should be at the end of the caller function, so the ret is the ret of our
        #     function
        #   - If the call isn't inlined, the ret should be the ret of our function
        ctx_call = next(ctx.stack.frames()).creation_transition.context_after()

        tr_ret = ctx_call.transition_before().find_inverse()
        # Finding the inverse operation can fail
        # for instance if the end of the allocation function was not recorded in the trace
        if tr_ret is None:
            continue
        ctx_ret = tr_ret.context_before()

        # Build the event by reading the needed registers
        if event_class == AllocEvent:
            current_address, size = retrieve_event_info(ctx, ctx_ret)

            # Skip failed allocations (NULL returned)
            if current_address == 0x0:
                continue

            memory_range = MemoryRange(current_address, size)
            translate_first_page_of_allocation(ctx_ret, memory_range)

            if memory_range.pages[0]['physical_address'] not in alloc_dict:
                alloc_dict[memory_range.pages[0]['physical_address']] = []

            alloc_dict[memory_range.pages[0]['physical_address']].append(
                AllocEvent(
                    memory_range,
                    ctx.transition_after(), ctx_ret.transition_after()
                )
            )
        elif event_class == FreeEvent:
            current_address = retrieve_event_info(ctx, ctx_ret)

            # Skip frees of NULL
            if current_address == 0x0:
                continue

            current_physical_address = reven2.address.LogicalAddress(current_address).translate(ctx).offset

            if current_physical_address not in alloc_dict:
                alloc_dict[current_physical_address] = []

            alloc_dict[current_physical_address].append(
                FreeEvent(
                    current_address,
                    ctx.transition_after(), ctx_ret.transition_after()
                )
            )
        else:
            raise RuntimeError("Unknown event class: %s" % event_class.__name__)

# %%
# %%time


alloc_dict: Dict = {}


# Basic functions to retrieve the arguments
# They work for the allocation/free functions but won't work for all functions,
# in particular because on x86 we handle neither the size of the arguments
# nor whether they are pushed left to right or right to left
def retrieve_first_argument(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rcx)
    else:
        esp = ctx.read(reven2.arch.x64.esp)
        return ctx.read(reven2.address.LogicalAddress(esp + 4, reven2.arch.x64.ss), 4)


def retrieve_second_argument(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rdx)
    else:
        esp = ctx.read(reven2.arch.x64.esp)
        return ctx.read(reven2.address.LogicalAddress(esp + 8, reven2.arch.x64.ss), 4)


def retrieve_return_value(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rax)
    else:
        return ctx.read(reven2.arch.x64.eax)


def retrieve_alloc_info_with_size_as_first_argument(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_second_argument(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_second_argument(ctx_begin)
    )


def retrieve_alloc_info_for_calloc(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument(ctx_begin) * retrieve_second_argument(ctx_begin)
    )


def retrieve_free_info_with_address_as_first_argument(ctx_begin, ctx_end):
    return retrieve_first_argument(ctx_begin)


if search_malloc:
    def filter_in_realloc(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name == "realloc"

    # Search for allocations with malloc
    for symbol in server.ossi.symbols(r'^_?malloc$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_first_argument,
                                   filter_in_realloc)

    # Search for allocations with calloc
    for symbol in server.ossi.symbols(r'^_?calloc(_crt)?$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_for_calloc)

    # Search for deallocations with free
    for symbol in server.ossi.symbols(r'^_?free$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument,
                                   filter_in_realloc)

    # Search for re-allocations with realloc
    for symbol in server.ossi.symbols(r'^_?realloc$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_second_argument)
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument)

if search_pool_allocation:
    # Search for allocations with ExAllocatePool...
    def filter_ex_allocate_pool(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name.startswith("ExAllocatePool")

    for symbol in server.ossi.symbols(r'^ExAllocatePool', binary_hint=r'ntoskrnl.exe'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_second_argument,
                                   filter_ex_allocate_pool)

    # Search for deallocations with ExFreePool...
    def filter_ex_free_pool(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name == "ExFreePool"

    for symbol in server.ossi.symbols(r'^ExFreePool', binary_hint=r'ntoskrnl.exe'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument,
                                   filter_ex_free_pool)

# Sort the events per address and event type
for physical_address in alloc_dict.keys():
    alloc_dict[physical_address] = list(sorted(
        alloc_dict[physical_address],
        key=lambda event: (event.tr_begin.id, 0 if isinstance(event, FreeEvent) else 1)
    ))

# Sort the dict by address
alloc_dict = OrderedDict(sorted(alloc_dict.items()))
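
# %%
# Illustration (not used by the detection): a minimal sketch of the per-address sort key used
# above, on hypothetical (transition id, kind) tuples instead of real events. When `realloc`
# returns the same address, it records both a free and an allocation starting at the same
# transition; the key orders the free first, so that the old buffer is closed before the new
# one is opened.
_example_events = [(30, "alloc"), (10, "alloc"), (30, "free"), (50, "free")]
_example_sorted = sorted(
    _example_events,
    key=lambda event: (event[0], 0 if event[1] == "free" else 1)
)
print(_example_sorted)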


# %%
def get_alloc_free_pairs(events, errors=None):
    previous_event = None
    for event in events:
        if isinstance(event, AllocEvent):
            if previous_event is None:
                pass
            elif isinstance(previous_event, AllocEvent):
                if errors is not None:
                    errors.append("Two consecutive allocs found")
        elif isinstance(event, FreeEvent):
            if previous_event is None:
                continue
            elif isinstance(previous_event, FreeEvent):
                if errors is not None:
                    errors.append("Two consecutive frees found")
            elif isinstance(previous_event, AllocEvent):
                yield (previous_event, event)
        else:
            assert 0, ("Unknown event type: %s" % type(event))

        previous_event = event
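
# %%
# Illustration (not used by the detection): a simplified, standalone sketch of the pairing done
# by `get_alloc_free_pairs` above, without the error reporting. Events are hypothetical
# (kind, transition id) tuples sorted by transition; a pair is produced each time a free
# directly follows an allocation.
def _example_pair_events(events):
    pairs = []
    previous = None
    for event in events:
        if event[0] == "free" and previous is not None and previous[0] == "alloc":
            pairs.append((previous, event))
        previous = event
    return pairs


# The leading free has no allocation in the trace and the last allocation is never freed:
# neither of them produces a pair
print(_example_pair_events([("free", 5), ("alloc", 10), ("free", 20), ("alloc", 30)]))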


# %%
# %%time

# Basic checks of the events

for physical_address, events in alloc_dict.items():
    for event in events:
        if not isinstance(event, AllocEvent) and not isinstance(event, FreeEvent):
            raise RuntimeError("Unknown event type: %s" % type(event))

    errors: List[str] = []
    for (alloc_event, free_event) in get_alloc_free_pairs(events, errors):
        # Check the uniformity of the logical address between the alloc and the free
        if alloc_event.memory_range.logical_address != free_event.logical_address:
            errors.append("Phys:0x%x: Alloc #%d - Free #%d with different logical address: 0x%x != 0x%x" % (
                physical_address,
                alloc_event.tr_begin.id, free_event.tr_begin.id,
                alloc_event.memory_range.logical_address, free_event.logical_address))

        # Check size of 0x0
        if alloc_event.memory_range.size == 0x0 or alloc_event.memory_range.size is None:
            errors.append("Phys:0x%x: Alloc #%d - Free #%d with weird size %s" % (
                physical_address,
                alloc_event.tr_begin.id, free_event.tr_begin.id,
                alloc_event.memory_range.size))

    if len(errors) > 0:
        print("Phys:0x%x: Error(s) detected:" % (physical_address))
        for error in errors:
            print("    - %s" % error)

# %%
# Print the events

for physical_address, events in alloc_dict.items():
    print("Phys:0x%x" % (physical_address))
    print("    Events:")
    for event in events:
        if isinstance(event, AllocEvent):
            print("        - Alloc at #%d (0x%x of size 0x%x)" % (event.tr_begin.id,
                  event.memory_range.logical_address, event.memory_range.size))
        elif isinstance(event, FreeEvent):
            print("        - Free at #%d (0x%x)" % (event.tr_begin.id, event.logical_address))

    print("    Pairs:")
    for (alloc_event, free_event) in get_alloc_free_pairs(events):
        print("    - Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (alloc_event.tr_begin.id,
              alloc_event.memory_range.logical_address, alloc_event.memory_range.size, free_event.tr_begin.id,
              free_event.logical_address))

    print()


# %%
# This function is used to ignore the changes made by `free` in the taint of the address:
# `free` stores the address in internal structures that are reused by later calls to malloc
# for other addresses, which would lead to false positives

def start_alloc_address_taint(server, alloc_event, free_event):
    # Setup the first taint [alloc; free]
    tainter = reven2.preview.taint.Tainter(server.trace)
    taint = tainter.simple_taint(
        tag0="rax" if alloc_event.tr_end.mode == reven2.trace.Mode.X86_64 else "eax",
        from_context=alloc_event.tr_end.context_before(),
        to_context=free_event.tr_begin.context_before() + 1,
        is_forward=True
    )

    # Setup the second taint [free; [
    state_before_free = taint.state_at(free_event.tr_begin.context_before())

    # In theory, `rcx` is lost during the execution of the free
    # `rsp` is not useful even if it is tainted
    tag0_regs = filter(
        lambda x: x[0].register.name not in ['rcx', 'rsp'],
        state_before_free.tainted_registers()
    )

    # We don't want to keep memory inside the allocated object, as accessing it will trigger a UAF anyway
    # This also removes some false positives, because that memory will be used by the alloc/free functions
    # TODO: Handle addresses different than PhysicalAddress (is that even possible?)
    tag0_mems = filter(
        lambda x: not alloc_event.memory_range.is_physical_address_range_in_translated_pages(
            x[0].address.offset, x[0].size),
        state_before_free.tainted_memories()
    )

    # Only keep the slices
    tag0 = map(
        lambda x: x[0],
        itertools.chain(tag0_regs, tag0_mems)
    )

    tainter = reven2.preview.taint.Tainter(server.trace)
    return tainter.simple_taint(
        tag0=list(tag0),
        from_context=free_event.tr_end.context_before(),
        to_context=None,
        is_forward=True
    )


# %%
# Capstone utilities
def get_reven_register_from_name(name):
    for reg in reven2.arch.helpers.x64_registers():
        if reg.name == name:
            return reg

    raise RuntimeError("Unknown register: %s" % name)


def read_reg(tr, reg):
    if reg in [reven2.arch.x64.rip, reven2.arch.x64.eip]:
        return tr.pc
    else:
        return tr.context_before().read(reg)


def compute_dereferenced_address(tr, cs_insn, cs_op):
    dereferenced_address = 0

    if cs_op.value.mem.base != 0:
        base_reg = get_reven_register_from_name(cs_insn.reg_name(cs_op.value.mem.base))
        dereferenced_address += read_reg(tr, base_reg)

    if cs_op.value.mem.index != 0:
        index_reg = get_reven_register_from_name(cs_insn.reg_name(cs_op.value.mem.index))
        dereferenced_address += (cs_op.value.mem.scale * read_reg(tr, index_reg))

    dereferenced_address += cs_op.value.mem.disp

    return dereferenced_address & 0xFFFFFFFFFFFFFFFF
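
# %%
# Illustration (not used by the detection): a minimal sketch of the effective address arithmetic
# in `compute_dereferenced_address` above, on plain integers. The register values are
# hypothetical. An x86 memory operand such as `[rbx + rcx*8 - 0x10]` resolves to
# base + index * scale + disp, and the 64-bit mask wraps a negative result back into the
# unsigned address space.
def _example_effective_address(base, index, scale, disp):
    return (base + index * scale + disp) & 0xFFFFFFFFFFFFFFFF


print(hex(_example_effective_address(base=0x1000, index=2, scale=8, disp=-0x10)))
print(hex(_example_effective_address(base=0, index=0, scale=1, disp=-8)))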


# %%
def uaf_analyze_function(physical_address, alloc_events):
    uaf_count = 0

    # Setup capstone
    md_64 = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
    md_64.detail = True

    md_32 = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_32)
    md_32.detail = True

    errors = []
    for (alloc_event, free_event) in get_alloc_free_pairs(alloc_events, errors):
        # Get the memory accesses to the allocated block after the free
        # The optimization is disabled if we can't translate all the pages of the memory range
        mem_access = None
        mem_accesses = None

        if alloc_event.memory_range.try_translate_all_pages(free_event.tr_begin.context_before()):
            mem_accesses = reven2.util.collate(map(
                lambda page: server.trace.memory_accesses(
                    reven2.address.PhysicalAddress(page['physical_address']),
                    page['size'],
                    from_transition=free_event.tr_end,
                    to_transition=None,
                    is_forward=True,
                    operation=None
                ),
                alloc_event.memory_range.pages
            ), key=lambda access: access.transition)

            try:
                mem_access = next(mem_accesses)
            except StopIteration:
                continue

        else:
            print("Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (physical_address,
                  alloc_event.tr_begin.id, alloc_event.memory_range.logical_address, alloc_event.memory_range.size,
                  free_event.tr_begin.id, free_event.logical_address))
            print("    Warning: Memory history optimization disabled because we couldn't "
                  "translate all the memory pages")
            print()

        # Setup the slicing
        taint = start_alloc_address_taint(server, alloc_event, free_event)

        # Iterate on the slice
        access_count = 0
        for access in taint.accesses(changes_only=False).all():
            access_count += 1
            access_transition = access.transition

            if access_count > taint_max_length:
                print("Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (physical_address,
                      alloc_event.tr_begin.id, alloc_event.memory_range.logical_address, alloc_event.memory_range.size,
                      free_event.tr_begin.id, free_event.logical_address))
                print("    Warning: Allocation skipped: post-free taint stopped after %d accesses" % access_count)
                print()
                break

            if mem_accesses is not None:
                # Check that we have a memory access on the same transition as the taint access;
                # if not, the memory operand cannot be a UAF, so we can skip it

                while mem_access.transition < access_transition:
                    try:
                        mem_access = next(mem_accesses)
                    except StopIteration:
                        break

                if mem_access.transition != access_transition:
                    continue

            md = md_64 if access_transition.mode == reven2.trace.Mode.X86_64 else md_32
            cs_insn = next(md.disasm(access_transition.instruction.raw, access_transition.instruction.size))

            # Skip `lea` instructions as they do not really read/write memory, and the taint
            # propagates to the computed value anyway, so we will see its dereference later
            if cs_insn.mnemonic == "lea":
                continue

            registers_in_state = {}
            for reg_slice, _ in access.state_before().tainted_registers():
                registers_in_state[reg_slice.register.name] = reg_slice

            for cs_op in cs_insn.operands:
                if cs_op.type != capstone.x86.X86_OP_MEM:
                    continue

                uaf_reg = None

                if cs_op.value.mem.base != 0:
                    base_reg_name = cs_insn.reg_name(cs_op.value.mem.base)
                    if base_reg_name in registers_in_state:
                        uaf_reg = registers_in_state[base_reg_name]

                if uaf_reg is None and cs_op.value.mem.index != 0:
                    index_reg_name = cs_insn.reg_name(cs_op.value.mem.index)
                    if index_reg_name in registers_in_state:
                        uaf_reg = registers_in_state[index_reg_name]

                if uaf_reg is None:
                    continue

                dereferenced_address = compute_dereferenced_address(access_transition, cs_insn, cs_op)

                if mem_accesses is None:
                    # As we don't have the memory access optimization we need to check if the dereferenced address
                    # is in the allocated buffer
                    # We only check on the translated pages as the taint won't return an access with a pagefault
                    # so the dereferenced address should be translated

                    dereferenced_physical_address = reven2.address.LogicalAddress(dereferenced_address).translate(
                        access_transition.context_before())

                    if dereferenced_physical_address is None:
                        continue

                    if not alloc_event.memory_range.is_physical_address_range_in_translated_pages(
                            dereferenced_physical_address.offset, 1):
                        continue

                print("Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (physical_address,
                      alloc_event.tr_begin.id, alloc_event.memory_range.logical_address, alloc_event.memory_range.size,
                      free_event.tr_begin.id, free_event.logical_address))
                print("    UAF coming from reg %s[%d-%d] leading to dereferenced address = 0x%x" % (
                      uaf_reg.register.name, uaf_reg.begin, uaf_reg.end, dereferenced_address))
                print("    ", end="")
                print(access_transition, end=" ")
                print(access_transition.context_before().ossi.location())
                print("    Accessed %d transitions after the free" % (access_transition.id - free_event.tr_end.id))
                print()

                uaf_count += 1

    if len(errors) > 0:
        print("Phys:0x%x: Error(s) detected:" % (physical_address))
        for error in errors:
            print("    - %s" % error)

    return uaf_count
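
# %%
# Illustration (not used by the detection): a minimal sketch of how `uaf_analyze_function` above
# keeps the memory accesses in step with the taint accesses. Both streams are sorted by
# transition; here they are hypothetical lists of transition ids. Only the transitions present
# in both streams are kept, which is the idea behind the memory history optimization.
def _example_common_transitions(taint_transitions, memory_transitions):
    common = []
    memory_iter = iter(memory_transitions)
    current = next(memory_iter, None)
    for transition in taint_transitions:
        # Advance the memory accesses until they catch up with the taint access
        while current is not None and current < transition:
            current = next(memory_iter, None)
        if current is None:
            break  # No memory access left: nothing else can match
        if current == transition:
            common.append(transition)
    return common


print(_example_common_transitions([3, 7, 8, 12, 20], [1, 7, 7, 12, 15]))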


# %%
# %%time

uaf_count = 0

if uaf_faulty_physical_address is None:
    for physical_address, alloc_events in alloc_dict.items():
        uaf_count += uaf_analyze_function(physical_address, alloc_events)

else:
    if uaf_faulty_physical_address not in alloc_dict:
        raise KeyError("The passed physical address was not detected during the allocation search")
    uaf_count += uaf_analyze_function(uaf_faulty_physical_address, alloc_dict[uaf_faulty_physical_address])


print("---------------------------------------------------------------------------------")
uaf_begin_range = "the beginning of the trace" if uaf_from_tr is None else "#{}".format(uaf_from_tr)
uaf_end_range = "the end of the trace" if uaf_to_tr is None else "#{}".format(uaf_to_tr)
uaf_range = ("on the whole trace" if uaf_from_tr is None and uaf_to_tr is None else
             "between {} and {}".format(uaf_begin_range, uaf_end_range))

uaf_range_size = server.trace.transition_count
if uaf_from_tr is not None:
    uaf_range_size -= uaf_from_tr
if uaf_to_tr is not None:
    uaf_range_size -= server.trace.transition_count - uaf_to_tr

if uaf_faulty_physical_address is None:
    searched_memory_addresses = "with {} searched memory addresses".format(len(alloc_dict))
else:
    searched_memory_addresses = "on {:#x}".format(uaf_faulty_physical_address)

print("{} UAF(s) found {} ({} transitions) {}".format(
    uaf_count, uaf_range, uaf_range_size, searched_memory_addresses
))
print("---------------------------------------------------------------------------------")
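
# %%
# Illustration (not used by the detection): a minimal sketch of the range size arithmetic above,
# with hypothetical numbers. The searched range is [from_tr, to_tr), so its size is
# to_tr - from_tr, with a missing bound replaced by the corresponding bound of the trace.
def _example_range_size(transition_count, from_tr=None, to_tr=None):
    size = transition_count
    if from_tr is not None:
        size -= from_tr
    if to_tr is not None:
        size -= transition_count - to_tr
    return size


print(_example_range_size(1000))
print(_example_range_size(1000, from_tr=100, to_tr=400))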

# %%