Searching for Use-of-Uninitialized-Memory vulnerabilities

This notebook searches for potential Use-of-Uninitialized-Memory vulnerabilities in a REVEN trace.

Prerequisites

  • This notebook should be run in a Jupyter notebook server equipped with a REVEN 2 Python kernel. REVEN comes with a Jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • This notebook depends on capstone being installed in the REVEN 2 Python kernel. To install capstone in the current environment, execute the Capstone Installation cell of this notebook.
  • This notebook requires the Memory History resource for your target scenario.

Running the notebook

Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.
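
As an illustration, the Parameters cell for a Windows user-mode scenario could be filled as sketched below. The process name is a hypothetical placeholder and the port is the notebook's default; both are assumptions to adapt to your own scenario.

```python
# Server connection (default values of the notebook)
reven_backend_host = '127.0.0.1'
reven_backend_port = 13370

# Analyze the whole trace
from_tr = None
to_tr = None

# Only keep allocations made by one process (hypothetical process name)
faulty_process_pid = None
faulty_process_name = 'vulnerable.exe'
faulty_binary = None

# Test all (filtered) allocations rather than a single address
faulty_physical_address = None

# Enable a single allocator: the Windows user-mode malloc/free
search_windows_malloc = True
search_pool_allocation = False
search_linux_malloc = False

# Only report UUMs that end up influencing the control flow
only_display_uum_changing_control_flow = True
```

Enabling a single allocator follows the recommendation given in the Parameters cell itself.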

Source

# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Searching for Use-of-Uninitialized-Memory vulnerabilities
#
# This notebook searches for potential Use-of-Uninitialized-Memory vulnerabilities in a REVEN trace.
#
# ## Prerequisites
#
# - This notebook should be run in a Jupyter notebook server equipped with a REVEN 2 Python kernel.
#   REVEN comes with a Jupyter notebook server accessible with the `Open Python` button
#   in the `Analyze` page of any scenario.
# - This notebook depends on capstone being installed in the REVEN 2 Python kernel.
#   To install capstone in the current environment, execute the Capstone Installation cell of this notebook.
# - This notebook requires the Memory History resource for your target scenario.
#
#
# ## Running the notebook
#
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.

# %% [markdown]
# # Capstone Installation
#
# Check for capstone's presence. If it is missing, attempt to install it with pip.

# %%
try:
    import capstone
    print("capstone already installed")
except ImportError:
    print("Could not find capstone, attempting to install it from pip")
    import sys
    import subprocess

    command = [sys.executable, "-m", "pip", "install", "capstone"]
    p = subprocess.run(command)

    if p.returncode != 0:
        raise RuntimeError("Error installing capstone")

    import capstone  # noqa
    print("Successfully installed capstone")

# %% [markdown]
# # Parameters

# %%
# Server connection

# Host of the REVEN server running the scenario.
# When running this notebook from the Project Manager, '127.0.0.1' should be the correct value.
reven_backend_host = '127.0.0.1'

# Port of the REVEN server running the scenario.
# After starting a REVEN server on your scenario, you can get its port on the Analyze page of that scenario.
reven_backend_port = 13370


# Range control

# First transition considered for the detection of allocation/deallocation pairs.
# If set to None, the detection starts at the first transition of the trace.
from_tr = None

# First transition **not** considered for the detection of allocation/deallocation pairs.
# If set to None, the detection runs up to the end of the trace.
to_tr = None


# Filter control

# Beware that, when filtering, if an allocation happens in the specified process and/or binary,
# the script will fail if the deallocation happens in a different process and/or binary.
# This issue should only happen for allocations in the kernel.

# Only keep the allocation/deallocation pairs occurring in the process with this PID.
# Set a value of None to not filter on the PID.
faulty_process_pid = None

# Only keep the allocation/deallocation pairs occurring in the process with this name.
# Set a value of None to not filter on the process name.
# If both a process PID and a process name are specified, please make sure that they both
# refer to the same process, otherwise all allocations will be filtered and no results
# will be produced.
faulty_process_name = None

# Only keep the allocation/deallocation pairs occurring in this binary.
# Set a value of None to not filter on the binary name.
# Only allocations/deallocations taking place in a binary whose filename,
# path or name is exactly equal to the specified value are kept.
# If filtering on both a process and a binary, please make sure that there are
# allocations taking place in that binary in the selected process, otherwise all
# allocations will be filtered and no result will be produced.
faulty_binary = None


# Address control

# Specify a **physical** address suspected of being faulty here,
# to only test the script for this specific address, instead of all (filtered) allocations.
# The address should still be returned by an allocation/deallocation pair.
# To get a physical address from a virtual address, find a context where the address
# is mapped, then use `virtual_address.translate(ctx)`.
faulty_physical_address = None

# Allocator control

# The script can use three allocators to find allocation/deallocation pairs.
# The following booleans enable the search for allocations by each of these
# allocators for a scenario.
# Generally, only a single allocator is expected to be enabled for a given
# scenario.

# To add your own allocator, please look at how the provided allocators were
# added.

# Whether or not to look for Windows malloc/free allocation/deallocation pairs.
search_windows_malloc = True

# Whether or not to look for ExAllocatePoolWithTag/ExFreePoolWithTag
# allocation/deallocation pairs.
# This allocator is used by the Windows kernel.
search_pool_allocation = False

# Whether or not to look for Linux malloc/free allocation/deallocation pairs.
search_linux_malloc = False


# Analysis control

# Whether or not the display should be restricted to UUMs impacting the control flow
only_display_uum_changing_control_flow = False

# %%
import struct  # noqa: E402
from collections import OrderedDict  # noqa: E402
from typing import Dict, List  # noqa: E402

import reven2  # noqa: E402
import reven2.preview.taint  # noqa: E402

# %%
# Python script to connect to this scenario:
server = reven2.RevenServer(reven_backend_host, reven_backend_port)
print(server.trace.transition_count)


# %%
class MemoryRange:
    page_size = 4096
    page_mask = ~(page_size - 1)

    def __init__(self, logical_address, size):
        self.logical_address = logical_address
        self.size = size

        self.pages = [{
            'logical_address': self.logical_address,
            'size': self.size,
            'physical_address': None,
            'ctx_physical_address_mapped': None,
        }]

        # Compute the pages
        while (((self.pages[-1]['logical_address'] & ~MemoryRange.page_mask) + self.pages[-1]['size'] - 1)
                >= MemoryRange.page_size):
            # Compute the size of the new page
            new_page_size = ((self.pages[-1]['logical_address'] & ~MemoryRange.page_mask) + self.pages[-1]['size']
                             - MemoryRange.page_size)

            # Reduce the size of the previous page and create the new one
            self.pages[-1]['size'] -= new_page_size
            self.pages.append({
                'logical_address': self.pages[-1]['logical_address'] + self.pages[-1]['size'],
                'size': new_page_size,
                'physical_address': None,
                'ctx_physical_address_mapped': None,
            })

    def try_translate_first_page(self, ctx):
        if self.pages[0]['physical_address'] is not None:
            return True

        physical_address = reven2.address.LogicalAddress(self.pages[0]['logical_address']).translate(ctx)

        if physical_address is None:
            return False

        self.pages[0]['physical_address'] = physical_address.offset
        self.pages[0]['ctx_physical_address_mapped'] = ctx

        return True

    def try_translate_all_pages(self, ctx):
        return_value = True

        for page in self.pages:
            if page['physical_address'] is not None:
                continue

            physical_address = reven2.address.LogicalAddress(page['logical_address']).translate(ctx)

            if physical_address is None:
                return_value = False
                continue

            page['physical_address'] = physical_address.offset
            page['ctx_physical_address_mapped'] = ctx

        return return_value

    def is_physical_address_range_in_translated_pages(self, physical_address, size):
        for page in self.pages:
            if page['physical_address'] is None:
                continue

            if (
                physical_address >= page['physical_address']
                and physical_address + size <= page['physical_address'] + page['size']
            ):
                return True

        return False

    def __repr__(self):
        return "MemoryRange(0x%x, %d)" % (self.logical_address, self.size)
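

# Illustrative sketch (not part of the original notebook): for a non-empty range, the page
# splitting performed in `MemoryRange.__init__` can also be written as a standalone pure function.
# `split_range_into_pages` is a hypothetical helper name; it assumes 4 KiB pages, like
# `MemoryRange.page_size`, and returns one `(logical_address, size)` tuple per touched page.
# For instance, `split_range_into_pages(0x1ff0, 0x20)` returns
# `[(0x1ff0, 0x10), (0x2000, 0x10)]`.
def split_range_into_pages(logical_address, size, page_size=4096):
    chunks = []
    current = logical_address
    remaining = size
    while remaining > 0:
        # Number of bytes left in the page containing `current`
        room_in_page = page_size - (current % page_size)
        chunk_size = min(remaining, room_in_page)
        chunks.append((current, chunk_size))
        current += chunk_size
        remaining -= chunk_size
    return chunks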


# Helper to find the physical address of a buffer that has just been allocated
#     - ctx should be the context where the address of the buffer is in `rax`
#     - memory_range should be the memory range of the newly allocated buffer
#
# We first try the translate API, but the address is sometimes not mapped yet right after
# the allocation. In that case, we taint `rax` forward and retry the translation at each
# access reported by the taint until it succeeds.
def translate_first_page_of_allocation(ctx, memory_range):
    if memory_range.try_translate_first_page(ctx):
        return

    tainter = reven2.preview.taint.Tainter(server.trace)
    taint = tainter.simple_taint(
        tag0="rax",
        from_context=ctx,
        to_context=None,
        is_forward=True
    )

    for access in taint.accesses(changes_only=False).all():
        if memory_range.try_translate_first_page(access.transition.context_after()):
            taint.cancel()
            return

    raise RuntimeError("Couldn't find the physical address of the first page")


# %%
class AllocEvent:
    def __init__(self, memory_range, tr_begin, tr_end):
        self.memory_range = memory_range
        self.tr_begin = tr_begin
        self.tr_end = tr_end


class FreeEvent:
    def __init__(self, logical_address, tr_begin, tr_end):
        self.logical_address = logical_address
        self.tr_begin = tr_begin
        self.tr_end = tr_end


def retrieve_events_for_symbol(
    alloc_dict,
    event_class,
    symbol,
    retrieve_event_info,
    event_filter=None,
):
    for ctx in server.trace.search.symbol(
        symbol,
        from_context=None if from_tr is None else server.trace.context_before(from_tr),
        to_context=None if to_tr is None else server.trace.context_before(to_tr)
    ):
        # Skip hits on exceptions (code page faults, hardware interrupts, etc.)
        if ctx.transition_after().exception is not None:
            continue

        previous_location = (ctx - 1).ossi.location()
        previous_process = (ctx - 1).ossi.process()

        # Filter by process pid/process name/binary name

        # Filter by process pid
        if faulty_process_pid is not None and previous_process.pid != faulty_process_pid:
            continue

        # Filter by process name
        if faulty_process_name is not None and previous_process.name != faulty_process_name:
            continue

        # Filter by binary name / filename / path
        if faulty_binary is not None and faulty_binary not in [
            previous_location.binary.name,
            previous_location.binary.filename,
            previous_location.binary.path
        ]:
            continue

        # Filter the event with the argument filter
        if event_filter is not None:
            if event_filter(ctx.ossi.location(), previous_location):
                continue

        # Retrieve the call/ret
        # The heuristic is that the ret is the end of our function
        #   - If the call is inlined it should be at the end of the caller function, so the ret is the ret of our
        #     function
        #   - If the call isn't inlined, the ret should be the ret of our function
        ctx_call = next(ctx.stack.frames()).creation_transition.context_after()
        ctx_ret = ctx_call.transition_before().find_inverse().context_before()

        # Build the event by reading the needed registers
        if event_class == AllocEvent:
            current_address, size = retrieve_event_info(ctx, ctx_ret)

            # Skip failed allocations
            if current_address == 0x0:
                continue

            memory_range = MemoryRange(current_address, size)
            try:
                translate_first_page_of_allocation(ctx_ret, memory_range)
            except RuntimeError:
                # If we can't translate the first page we assume that the buffer isn't used because
                # the heuristic to detect the call/ret failed
                continue

            if memory_range.pages[0]['physical_address'] not in alloc_dict:
                alloc_dict[memory_range.pages[0]['physical_address']] = []

            alloc_dict[memory_range.pages[0]['physical_address']].append(
                AllocEvent(
                    memory_range,
                    ctx.transition_after(), ctx_ret.transition_after()
                )
            )
        elif event_class == FreeEvent:
            current_address = retrieve_event_info(ctx, ctx_ret)

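            # Skip frees of NULL
            if current_address == 0x0:
                continue

            current_physical = reven2.address.LogicalAddress(current_address).translate(ctx)

            # Skip frees whose address cannot be translated at this context
            if current_physical is None:
                continue

            current_physical_address = current_physical.offset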

            if current_physical_address not in alloc_dict:
                alloc_dict[current_physical_address] = []

            alloc_dict[current_physical_address].append(
                FreeEvent(
                    current_address,
                    ctx.transition_after(), ctx_ret.transition_after()
                )
            )
        else:
            raise RuntimeError("Unknown event class: %s" % event_class.__name__)


# %%
# %%time

alloc_dict: Dict = {}


# Basic functions to retrieve the arguments
# They work for the allocation/free functions handled here, but won't work for arbitrary functions,
# in particular because on 32-bit x86 we neither handle the size of the arguments
# nor whether they are pushed left to right or right to left
def retrieve_first_argument(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rcx)
    else:
        esp = ctx.read(reven2.arch.x64.esp)
        return ctx.read(reven2.address.LogicalAddress(esp + 4, reven2.arch.x64.ss), 4)


def retrieve_second_argument(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rdx)
    else:
        esp = ctx.read(reven2.arch.x64.esp)
        return ctx.read(reven2.address.LogicalAddress(esp + 8, reven2.arch.x64.ss), 4)


def retrieve_first_argument_linux(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rdi)
    else:
        raise NotImplementedError("32-bit Linux is not supported")


def retrieve_second_argument_linux(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rsi)
    else:
        raise NotImplementedError("32-bit Linux is not supported")


def retrieve_return_value(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rax)
    else:
        return ctx.read(reven2.arch.x64.eax)


def retrieve_alloc_info_with_size_as_first_argument(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_first_argument_linux(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument_linux(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_second_argument(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_second_argument(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_second_argument_linux(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_second_argument_linux(ctx_begin)
    )


def retrieve_alloc_info_for_calloc(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument(ctx_begin) * retrieve_second_argument(ctx_begin)
    )


def retrieve_alloc_info_for_calloc_linux(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument_linux(ctx_begin) * retrieve_second_argument_linux(ctx_begin)
    )


def retrieve_free_info_with_address_as_first_argument(ctx_begin, ctx_end):
    return retrieve_first_argument(ctx_begin)


def retrieve_free_info_with_address_as_first_argument_linux(ctx_begin, ctx_end):
    return retrieve_first_argument_linux(ctx_begin)


if search_windows_malloc:
    def filter_in_realloc(location, caller_location):
        return (
            location.binary == caller_location.binary
            and caller_location.symbol is not None
            and caller_location.symbol.name == "realloc"
        )

    # Search for allocations with malloc
    for symbol in server.ossi.symbols(r'^_?malloc$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_first_argument,
                                   filter_in_realloc)

    # Search for allocations with calloc
    for symbol in server.ossi.symbols(r'^_?calloc(_crt)?$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_for_calloc)

    # Search for deallocations with free
    for symbol in server.ossi.symbols(r'^_?free$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument,
                                   filter_in_realloc)

    # Search for re-allocations with realloc
    for symbol in server.ossi.symbols(r'^_?realloc$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_second_argument)
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument)

if search_pool_allocation:
    # Search for allocations with ExAllocatePool...
    def filter_ex_allocate_pool(location, caller_location):
        return (
            location.binary == caller_location.binary
            and caller_location.symbol is not None
            and caller_location.symbol.name.startswith("ExAllocatePool")
        )

    for symbol in server.ossi.symbols(r'^ExAllocatePool', binary_hint=r'ntoskrnl.exe'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_second_argument,
                                   filter_ex_allocate_pool)

    # Search for deallocations with ExFreePool...
    def filter_ex_free_pool(location, caller_location):
        return (
            location.binary == caller_location.binary
            and caller_location.symbol is not None
            and caller_location.symbol.name == "ExFreePool"
        )

    for symbol in server.ossi.symbols(r'^ExFreePool', binary_hint=r'ntoskrnl.exe'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument,
                                   filter_ex_free_pool)

if search_linux_malloc:
    def filter_in_realloc(location, caller_location):
        return (
            location.binary == caller_location.binary
            and (
                caller_location.symbol is not None
                and caller_location.symbol.name in ["realloc", "__GI___libc_realloc"]
            )
        )

    # Search for allocations with malloc
    for symbol in server.ossi.symbols(r'^((__GI___libc_malloc)|(__libc_malloc))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol,
                                   retrieve_alloc_info_with_size_as_first_argument_linux, filter_in_realloc)

    # Search for allocations with calloc
    for symbol in server.ossi.symbols(r'^((__calloc)|(__libc_calloc))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_for_calloc_linux)

    # Search for deallocations with free
    for symbol in server.ossi.symbols(r'^((__GI___libc_free)|(cfree))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol,
                                   retrieve_free_info_with_address_as_first_argument_linux, filter_in_realloc)

    # Search for re-allocations with realloc
    for symbol in server.ossi.symbols(r'^((__GI___libc_realloc)|(realloc))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol,
                                   retrieve_alloc_info_with_size_as_second_argument_linux)
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol,
                                   retrieve_free_info_with_address_as_first_argument_linux)

# Sort the events of each address by transition, placing frees before allocs on ties
for physical_address in alloc_dict:
    alloc_dict[physical_address] = sorted(
        alloc_dict[physical_address],
        key=lambda event: (event.tr_begin.id, 0 if isinstance(event, FreeEvent) else 1)
    )

# Sort the dict by address
alloc_dict = OrderedDict(sorted(alloc_dict.items()))
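

# Illustrative sketch (not part of the original notebook): effect of the sort key used above.
# When `realloc` returns the address it was given, the free of the old buffer and the allocation
# of the new one are recorded under the same physical address with the same `tr_begin`, so the
# free has to come first for the pairing to stay consistent (this rationale is an assumption
# drawn from the code). `sort_events_sketch` is a hypothetical helper working on plain
# `(transition_id, kind)` tuples, where `kind` is "alloc" or "free".
def sort_events_sketch(events):
    # Same key as above: by transition id first, then frees (0) before allocs (1)
    return sorted(events, key=lambda event: (event[0], 0 if event[1] == "free" else 1))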


# %%
def get_alloc_free_pairs(events, errors=None):
    previous_event = None
    for event in events:
        if isinstance(event, AllocEvent):
            if previous_event is None:
                pass
            elif isinstance(previous_event, AllocEvent):
                if errors is not None:
                    errors.append("Two consecutive allocs found")
        elif isinstance(event, FreeEvent):
            if previous_event is None:
                continue
            elif isinstance(previous_event, FreeEvent):
                if errors is not None:
                    errors.append("Two consecutive frees found")
            elif isinstance(previous_event, AllocEvent):
                yield (previous_event, event)
        else:
            assert 0, ("Unknown event type: %s" % type(event))

        previous_event = event

    if isinstance(previous_event, AllocEvent):
        yield (previous_event, None)


# %%
# %%time

# Basic checks of the events

for physical_address, events in alloc_dict.items():
    for event in events:
        if not isinstance(event, (AllocEvent, FreeEvent)):
            raise RuntimeError("Unknown event type: %s" % type(event))

    errors: List[str] = []
    for (alloc_event, free_event) in get_alloc_free_pairs(events, errors):
        # Check the uniformity of the logical address between the alloc and the free
        if free_event is not None and alloc_event.memory_range.logical_address != free_event.logical_address:
            errors.append(
                "Phys:0x%x: Alloc #%d - Free #%d with different logical address: 0x%x != 0x%x" % (
                    physical_address,
                    alloc_event.tr_begin.id,
                    free_event.tr_begin.id,
                    alloc_event.memory_range.logical_address,
                    free_event.logical_address
                )
            )

        # Check size of 0x0
        if alloc_event.memory_range.size == 0x0 or alloc_event.memory_range.size is None:
            if free_event is None:
                errors.append("Phys:0x%x: Alloc #%d - Free N/A with weird size %s" % (
                    physical_address, alloc_event.tr_begin.id, alloc_event.memory_range.size
                ))
            else:
                errors.append("Phys:0x%x: Alloc #%d - Free #%d with weird size %s" % (
                    physical_address, alloc_event.tr_begin.id, free_event.tr_begin.id, alloc_event.memory_range.size
                ))

    if len(errors) > 0:
        print("Phys:0x%x: Error(s) detected:" % (physical_address))
        for error in errors:
            print("    - %s" % error)

# %%
# Print the events

for physical_address, events in alloc_dict.items():
    print("Phys:0x%x" % (physical_address))
    print("    Events:")
    for event in events:
        if isinstance(event, AllocEvent):
            print("        - Alloc at #%d (0x%x of size 0x%x)" % (
                event.tr_begin.id, event.memory_range.logical_address, event.memory_range.size
            ))
        elif isinstance(event, FreeEvent):
            print("        - Free at #%d (0x%x)" % (event.tr_begin.id, event.logical_address))

    print("    Pairs:")
    for (alloc_event, free_event) in get_alloc_free_pairs(events):
        if free_event is None:
            print("    - Allocated at #%d (0x%x of size 0x%x) and freed at N/A" % (
                alloc_event.tr_begin.id, alloc_event.memory_range.logical_address, alloc_event.memory_range.size
            ))
        else:
            print("    - Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (
                alloc_event.tr_begin.id, alloc_event.memory_range.logical_address, alloc_event.memory_range.size,
                free_event.tr_begin.id, free_event.logical_address
            ))

    print()

# %%
# Setup capstone
md_64 = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
md_64.detail = True

md_32 = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_32)
md_32.detail = True


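# Retrieve the immediate of the capstone operand as little-endian `bytes`
def get_mask_from_cs_op(mask_op):
    mask_formats = [
        None,
        'B',  # 1
        'H',  # 2
        None,
        'I',  # 4
        None,
        None,
        None,
        'Q',  # 8
    ]

    # Negative immediates are encoded as two's complement on `mask_op.size` bytes.
    # The byte order is forced to little-endian, as on x86, whatever the host running the notebook
    return struct.pack(
        '<' + mask_formats[mask_op.size],
        mask_op.imm & ((1 << (mask_op.size * 8)) - 1)
    )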


# This function returns a list containing either `True` or `False` for each byte
# of the memory access. Bytes marked `True` are filtered out, that is, they are
# not considered for a potential UUM
def filter_and_bytes(cs_insn, mem_access):
    # An `and` instruction with an immediate mask can be used to set some bytes to 0
    # Each byte of the mask tells us what to do
    #  - with 0x00 we should consider the write and not the read
    #  - with 0xFF we should consider neither of them
    #  - with everything else we should consider the reads and the writes
    filtered_bytes = [False] * mem_access.size

    dest_op = cs_insn.operands[0]
    mask_op = cs_insn.operands[1]

    if dest_op.type != capstone.x86.X86_OP_MEM or mask_op.type != capstone.x86.X86_OP_IMM:
        return filtered_bytes

    mask = get_mask_from_cs_op(mask_op)

    for i in range(0, mask_op.size):
        if mask[i] == 0x00 and mem_access.operation == reven2.memhist.MemoryAccessOperation.Read:
            filtered_bytes[i] = True
        elif mask[i] == 0xFF:
            filtered_bytes[i] = True

    return filtered_bytes
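

# Illustrative sketch (not part of the original notebook): the per-byte decision taken by
# `filter_and_bytes`, written without capstone or REVEN objects. `classify_and_mask` is a
# hypothetical helper; it assumes an immediate of `size` bytes laid out in little-endian
# order, as on x86. For instance, `classify_and_mask(0xffffff00, 4)` returns
# `["write only", "ignored", "ignored", "ignored"]`.
def classify_and_mask(imm, size):
    # Two's complement encoding of the (possibly negative) immediate, least significant byte first
    mask = (imm & ((1 << (8 * size)) - 1)).to_bytes(size, "little")
    actions = []
    for byte in mask:
        if byte == 0x00:
            # The resulting byte is always 0: the read is irrelevant, only the write matters
            actions.append("write only")
        elif byte == 0xFF:
            # The byte is left unchanged: neither the read nor the write matters
            actions.append("ignored")
        else:
            # Some bits of the original byte survive: both the read and the write matter
            actions.append("read and write")
    return actions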


# This function returns a list containing either `True` or `False` for each byte
# of the memory access. Bytes marked `True` are filtered out, that is, they are
# not considered for a potential UUM
def filter_or_bytes(cs_insn, mem_access):
    # An `or` instruction with an immediate mask can be used to set some bits to 1
    # Each byte of the mask tells us what to do
    #  - with 0x00 we should consider neither of them
    #  - with 0xFF we should consider the write and not the read
    #  - with everything else we should consider the reads and the writes

    filtered_bytes = [False] * mem_access.size

    dest_op = cs_insn.operands[0]
    mask_op = cs_insn.operands[1]

    if dest_op.type != capstone.x86.X86_OP_MEM or mask_op.type != capstone.x86.X86_OP_IMM:
        return filtered_bytes

    mask = get_mask_from_cs_op(mask_op)

    for i in range(0, mask_op.size):
        if mask[i] == 0x00:
            filtered_bytes[i] = True
        elif mask[i] == 0xFF and mem_access.operation == reven2.memhist.MemoryAccessOperation.Read:
            filtered_bytes[i] = True

    return filtered_bytes


# This function returns a list containing either `True` or `False` for each byte
# of the memory access.
# Only bytes whose index holds `False` will be considered for a potential UUM
def filter_bts_bytes(cs_insn, mem_access):
    # A `bts` instruction with an immediate only accesses one byte in memory
    # but can be encoded with a bigger access (e.g. `dword`)
    # We only consider the byte accessed by the `bts` instruction in this case
    filtered_bytes = [False] * mem_access.size

    dest_op = cs_insn.operands[0]
    bit_nb_op = cs_insn.operands[1]

    if dest_op.type != capstone.x86.X86_OP_MEM or bit_nb_op.type != capstone.x86.X86_OP_IMM:
        return filtered_bytes

    filtered_bytes = [True] * mem_access.size
    filtered_bytes[bit_nb_op.imm // 8] = False

    return filtered_bytes


# This function returns a list containing either `True` or `False` for each byte
# of the memory access. Only bytes marked `False` are considered for a potential UUM
def get_filtered_bytes(cs_insn, mem_access):
    if cs_insn.mnemonic in ["and", "lock and"]:
        return filter_and_bytes(cs_insn, mem_access)
    elif cs_insn.mnemonic in ["or", "lock or"]:
        return filter_or_bytes(cs_insn, mem_access)
    elif cs_insn.mnemonic in ["bts", "lock bts"]:
        return filter_bts_bytes(cs_insn, mem_access)

    return [False] * mem_access.size


class UUM:
    # This dict maps each capstone EFLAGS test flag to
    # the REVEN register to check
    test_eflags = {
        capstone.x86.X86_EFLAGS_TEST_OF: reven2.arch.x64.of,
        capstone.x86.X86_EFLAGS_TEST_SF: reven2.arch.x64.sf,
        capstone.x86.X86_EFLAGS_TEST_ZF: reven2.arch.x64.zf,
        capstone.x86.X86_EFLAGS_TEST_PF: reven2.arch.x64.pf,
        capstone.x86.X86_EFLAGS_TEST_CF: reven2.arch.x64.cf,
        capstone.x86.X86_EFLAGS_TEST_NT: reven2.arch.x64.nt,
        capstone.x86.X86_EFLAGS_TEST_DF: reven2.arch.x64.df,
        capstone.x86.X86_EFLAGS_TEST_RF: reven2.arch.x64.rf,
        capstone.x86.X86_EFLAGS_TEST_IF: reven2.arch.x64.if_,
        capstone.x86.X86_EFLAGS_TEST_TF: reven2.arch.x64.tf,
        capstone.x86.X86_EFLAGS_TEST_AF: reven2.arch.x64.af,
    }

    def __init__(self, alloc_event, free_event, memaccess, uum_bytes):
        self.alloc_event = alloc_event
        self.free_event = free_event
        self.memaccess = memaccess
        self.bytes = uum_bytes

        # Store conditionals depending on uninitialized memory
        # - 'transition': the transition
        # - 'flag': the flag which is uninitialized
        self.conditionals = None

    @property
    def nb_uum_bytes(self):
        return sum(1 for byte in self.bytes if byte)

    def analyze_usage(self):
        # Initialize an array of what to taint based on the uninitialized bytes
        taint_tags = []
        for i in range(0, self.memaccess.size):
            if not self.bytes[i]:
                continue

            taint_tags.append(reven2.preview.taint.TaintedMemories(self.memaccess.physical_address + i, 1))

        # Start a taint of just the first instruction (the memory access)
        # We don't want to keep the memory tainted: if the memory is accessed again later,
        # it will be reported as another UUM anyway. So we use the state of this first taint,
        # drop the initially tainted memory, and start a new taint from after the first
        # instruction to the end of the trace
        tainter = reven2.preview.taint.Tainter(server.trace)
        taint = tainter.simple_taint(
            tag0=taint_tags,
            from_context=self.memaccess.transition.context_before(),
            to_context=self.memaccess.transition.context_after() + 1,
            is_forward=True
        )

        state_after_first_instruction = taint.state_at(self.memaccess.transition.context_after())

        # We assume that we won't have other tainted memories than the uninitialized memories
        # after the first instruction, so we can just keep the registers from
        # `state_after_first_instruction` and not the memories
        # In the future, we should keep the inverse of the intersection of the uninitialized memories
        # and the memories in the `state_after_first_instruction`
        taint = tainter.simple_taint(
            tag0=list(map(
                lambda x: x[0],
                state_after_first_instruction.tainted_registers()
            )),
            from_context=self.memaccess.transition.context_after(),
            to_context=None,
            is_forward=True
        )

        conditionals = []
        for access in taint.accesses(changes_only=False).all():
            ctx = access.transition.context_before()
            md = md_64 if ctx.is64b() else md_32
            cs_insn = next(md.disasm(access.transition.instruction.raw, access.transition.instruction.size))

            # Test conditional jump & move
            for flag, reg in self.test_eflags.items():
                if not cs_insn.eflags & flag:
                    continue

                if not UUM._is_register_tainted_in_taint_state(
                    taint.state_at(access.transition.context_after()),
                    reg
                ):
                    continue

                conditionals.append({
                    'transition': access.transition,
                    'flag': reg,
                })

        self.conditionals = conditionals

    @staticmethod
    def _is_register_tainted_in_taint_state(taint_state, reg):
        for tainted_reg, _ in taint_state.tainted_registers():
            if tainted_reg.register == reg:
                return True
        return False

    def __str__(self):
        desc = ""

        if self.free_event is None:
            desc += "Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at N/A\n" % (
                self.alloc_event.memory_range.pages[0]['physical_address'],
                self.alloc_event.tr_begin.id, self.alloc_event.memory_range.logical_address,
                self.alloc_event.memory_range.size,
            )
            desc += "\tAlloc in: %s / %s\n\n" % (
                (self.alloc_event.tr_begin - 1).context_before().ossi.location(),
                (self.alloc_event.tr_begin - 1).context_before().ossi.process(),
            )
        else:
            desc += "Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)\n" % (
                self.alloc_event.memory_range.pages[0]['physical_address'],
                self.alloc_event.tr_begin.id, self.alloc_event.memory_range.logical_address,
                self.alloc_event.memory_range.size,
                self.free_event.tr_begin.id, self.free_event.logical_address,
            )
            desc += "\tAlloc in: %s / %s\n" % (
                (self.alloc_event.tr_begin - 1).context_before().ossi.location(),
                (self.alloc_event.tr_begin - 1).context_before().ossi.process(),
            )
            desc += "\tFree in: %s / %s\n\n" % (
                (self.free_event.tr_begin - 1).context_before().ossi.location(),
                (self.free_event.tr_begin - 1).context_before().ossi.process(),
            )
        desc += "\tUUM of %d byte(s) first read at:\n" % self.nb_uum_bytes
        desc += "\t\t%s / %s\n" % (
            self.memaccess.transition.context_before().ossi.location(),
            self.memaccess.transition.context_before().ossi.process(),
        )
        desc += "\t\t%s" % (self.memaccess.transition)

        if self.conditionals is None:
            return desc
        elif len(self.conditionals) == 0:
            desc += "\n\n\tNot impacting the control flow"
            return desc

        desc += "\n\n\tThe control flow depends on uninitialized value(s):"

        conditionals = []
        for conditional in self.conditionals:
            conditional_str = "\n\t\tFlag '%s' depends on uninitialized memory\n" % conditional['flag'].name
            conditional_str += "\t\t%s / %s\n" % (
                conditional['transition'].context_before().ossi.location(),
                conditional['transition'].context_before().ossi.process()
            )
            conditional_str += "\t\t%s" % (conditional['transition'])

            conditionals.append(conditional_str)

        desc += "\n".join(conditionals)

        return desc


def analyze_one_memaccess(alloc_event, free_event, pages, pages_bytes_written, memaccess):
    if (
        memaccess.transition > alloc_event.tr_begin
        and memaccess.transition < alloc_event.tr_end
        and memaccess.operation == reven2.memhist.MemoryAccessOperation.Read
    ):
        # We assume that read accesses during the allocator are okay
        # as the allocator should know what it is doing
        return None

    ctx = memaccess.transition.context_before()
    md = md_64 if ctx.is64b() else md_32
    cs_insn = next(md.disasm(memaccess.transition.instruction.raw, memaccess.transition.instruction.size))

    filtered_bytes = get_filtered_bytes(cs_insn, memaccess)
    uum_bytes = [False] * memaccess.size

    for i in range(0, memaccess.size):
        if filtered_bytes[i]:
            continue

        possible_pages = list(filter(
            lambda page: (
                memaccess.physical_address.offset + i >= page['physical_address']
                and memaccess.physical_address.offset + i < page['physical_address'] + page['size']
            ),
            pages
        ))

        if len(possible_pages) > 1:
            # Should not be possible to have a byte in multiple pages
            raise AssertionError("Single byte access across multiple pages")
        elif len(possible_pages) == 0:
            # Access partially outside the buffer
            continue

        phys_addr = possible_pages[0]['physical_address']
        byte_offset_in_page = memaccess.physical_address.offset + i - possible_pages[0]['physical_address']

        if memaccess.operation == reven2.memhist.MemoryAccessOperation.Read:
            byte_written = pages_bytes_written[phys_addr][byte_offset_in_page]

            if not byte_written:
                uum_bytes[i] = True

        elif memaccess.operation == reven2.memhist.MemoryAccessOperation.Write:
            pages_bytes_written[phys_addr][byte_offset_in_page] = True

    if any(uum_bytes):
        return UUM(alloc_event, free_event, memaccess, uum_bytes)

    return None


def uum_analyze_function(physical_address, alloc_events):
    uum_count = 0

    for (alloc_event, free_event) in get_alloc_free_pairs(alloc_events, errors):
        # We are trying to translate all the pages and will construct
        # an array of translated pages.
        # We don't check UUM on pages we couldn't translate
        alloc_event.memory_range.try_translate_all_pages(
            free_event.tr_begin.context_before()
            if free_event is not None else
            alloc_event.tr_end.context_before()
        )

        pages = list(filter(
            lambda page: page['physical_address'] is not None,
            alloc_event.memory_range.pages
        ))

        # An iterator of all the memory accesses of all the translated pages.
        # The search starts at the beginning of the alloc rather than at its end because
        # in some cases we want the accesses made inside it: for example, a `calloc`
        # writes the memory during its execution. That's also why we ignore the read
        # memory accesses made during the alloc function.
        mem_accesses = reven2.util.collate(map(
            lambda page: server.trace.memory_accesses(
                reven2.address.PhysicalAddress(page['physical_address']),
                page['size'],
                from_transition=alloc_event.tr_begin,
                to_transition=free_event.tr_begin if free_event is not None else None,
                is_forward=True,
                operation=None
            ),
            pages
        ), key=lambda access: access.transition)

        # This will contain, for each page, an array of booleans recording
        # whether each byte has already been written
        pages_bytes_written = {}
        for page in pages:
            pages_bytes_written[page['physical_address']] = [False] * page['size']

        for memaccess in mem_accesses:
            if all(all(bytes_written) for bytes_written in pages_bytes_written.values()):
                # All the bytes have been set in the memory
                # we no longer need to track the memory accesses
                break

            # Do we have a UUM on this memory access?
            uum = analyze_one_memaccess(alloc_event, free_event, pages, pages_bytes_written, memaccess)
            if uum is None:
                continue

            uum.analyze_usage()

            if only_display_uum_changing_control_flow and len(uum.conditionals) == 0:
                continue

            print(str(uum))
            print()

            uum_count += 1

    return uum_count
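

# %%
# Illustrative sketch, not part of the REVEN API: the per-byte bookkeeping performed by
# `analyze_one_memaccess`, reduced to plain Python so that it can be read without a trace.
# `demo_find_uninitialized_reads` and its `(operation, offset, size)` tuples are
# hypothetical names introduced for this example only. It assumes a single contiguous
# buffer, whereas the real analysis works on translated physical pages.
def demo_find_uninitialized_reads(buffer_size, accesses):
    """Return `(access index, uninitialized offsets)` for each read of never-written bytes."""
    written = [False] * buffer_size
    findings = []

    for index, (operation, offset, size) in enumerate(accesses):
        # Ignore the part of an access that falls outside the buffer
        touched = range(max(offset, 0), min(offset + size, buffer_size))

        if operation == "write":
            for i in touched:
                written[i] = True
        elif operation == "read":
            uninitialized = [i for i in touched if not written[i]]
            if uninitialized:
                findings.append((index, uninitialized))

    return findings


# Example: bytes 0..3 are written, then an 8-byte read covers 4 never-written bytes:
# demo_find_uninitialized_reads(8, [("write", 0, 4), ("read", 0, 8)])
# returns [(1, [4, 5, 6, 7])]
# As in the analysis above, a read does not mark bytes as written, so a later read
# of the same bytes would be reported again.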


# %%
# %%time

count = 0

if faulty_physical_address is None:
    for physical_address, alloc_events in alloc_dict.items():
        count += uum_analyze_function(physical_address, alloc_events)
else:
    if faulty_physical_address not in alloc_dict:
        raise KeyError("The passed physical address was not detected during the allocation search")
    count += uum_analyze_function(faulty_physical_address, alloc_dict[faulty_physical_address])

print("---------------------------------------------------------------------------------")
begin_range = "the beginning of the trace" if from_tr is None else "#{}".format(from_tr)
end_range = "the end of the trace" if to_tr is None else "#{}".format(to_tr)
final_range = ("on the whole trace" if from_tr is None and to_tr is None else
               "between {} and {}".format(begin_range, end_range))

range_size = server.trace.transition_count
if from_tr is not None:
    range_size -= from_tr
if to_tr is not None:
    range_size -= server.trace.transition_count - to_tr

if faulty_physical_address is None:
    searched_memory_addresses = "with {} searched memory addresses".format(len(alloc_dict))
else:
    searched_memory_addresses = "on {:#x}".format(faulty_physical_address)

print("{} UUM(s) found {} ({} transitions) {}".format(
    count, final_range, range_size, searched_memory_addresses
))
print("---------------------------------------------------------------------------------")