REVEN API Examples

This book gathers the examples of the REVEN Python API.

These examples are distributed with REVEN and can be found in the package.

Other scripts for the REVEN Python API, such as ltrace and File activity, are available on GitHub.

Analyze examples

The examples in this section use the reven2 Python API to analyze a REVEN scenario.

Bookmarks to WinDbg breakpoints

Purpose

This notebook and script are designed to convert the bookmarks of a scenario to WinDbg breakpoints. The meat of the script uses the ability of the API to iterate on the bookmarks of a REVEN scenario, as well as the OSSI location, to generate a list of breakpoint commands for WinDbg where the addresses are independent of the REVEN scenario itself:

for bookmark in self._server.bookmarks.all():
    location = bookmark.transition.context_before().ossi.location()
    print(f"bp {location.binary.name}+{location.rva:#x}\r\n")

The output of the script is a list of WinDbg breakpoint commands corresponding to the relative virtual address of the location of each of the bookmarks. This list of commands can either be copy-pasted into WinDbg or output to a file, which can then be executed in WinDbg using the following syntax:

$<breakpoints.txt

How to use

Bookmarks can be converted from this notebook or from the command line. The script can also be imported as a module for use from your own script or notebook.

From the notebook

  1. Upload the bk2bp.ipynb file in Jupyter.
  2. Fill out the parameters cell of this notebook according to your scenario and desired output.
  3. Run the full notebook.

From the command line

  1. Make sure that you are in an environment that can run REVEN scripts.
  2. Run python bk2bp.py --help to get a tour of available arguments.
  3. Run python bk2bp.py --host <your_host> --port <your_port> [<other_option>] with your arguments of choice.

Imported in your own script or notebook

  1. Make sure that you are in an environment that can run REVEN scripts.
  2. Make sure that bk2bp.py is in the same directory as your script or notebook.
  3. Add import bk2bp to your script or notebook. You can access the various functions and classes exposed by the module from the bk2bp namespace.
  4. Refer to the Argument parsing cell for an example of use in a script, and to the Parameters cell and below for an example of use in a notebook (you just need to prepend bk2bp in front of the functions and classes from the script). See also the sketch after this list.
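
For reference, here is a minimal sketch of driving the module from your own code, assuming bk2bp.py is importable from your environment; the connection parameters and output path are placeholders to adapt:

import reven2
import bk2bp

# Connect to the REVEN server hosting the scenario (adapt host and port).
server = reven2.RevenServer("localhost", 13370)

# Write the WinDbg breakpoint commands to a file; pass None to print them instead.
bk2bp.bk2bp(server, "breakpoints.txt")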

Known limitations

  • For the breakpoints to be resolved by WinDbg, the debugged program/machine/REVEN scenario needs to be in a state where the corresponding modules have been loaded. Otherwise, WinDbg will add the breakpoints in an unresolved state, and may mix up modules and symbols.
  • When importing breakpoints generated from the bookmarks of a scenario using this script in WinDbg, make sure that the debugged system is "similar enough" to the VM that was used to record the scenario. In particular, if a binary changed and has symbols at different offsets in the debugged system, importing the breakpoints will not lead to the correct location in the binary, and may render the debugged system unstable.

Supported versions

REVEN 2.8+

Supported perimeter

Any Windows REVEN scenario.

Dependencies

The script requires that the target REVEN scenario have:

  • The Fast Search feature replayed.
  • The OSSI feature replayed.

Source

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Bookmarks to WinDbg breakpoints
#
# ## Purpose
#
# This notebook and script are designed to convert the bookmarks of a scenario to WinDbg breakpoints.
#
# The meat of the script uses the ability of the API to iterate on the bookmarks of a REVEN scenario, as well as the
# OSSI location, to generate a list of breakpoint commands for WinDbg where the addresses are independent of the REVEN
# scenario itself:
#
# ```py
# for bookmark in self._server.bookmarks.all():
#     location = bookmark.transition.context_before().ossi.location()
#     print(f"bp {location.binary.name}+{location.rva:#x}\r\n")
# ```
#
# The output of the script is a list of WinDbg breakpoint commands corresponding to the relative virtual address
# of the location of each of the bookmarks.
#
# This list of commands can either be copy-pasted into WinDbg or output to a file, which can then be executed in WinDbg
# using the following syntax:
#
# ```kd
# $<breakpoints.txt
# ```
#
# ## How to use
#
# Bookmarks can be converted from this notebook or from the command line.
# The script can also be imported as a module for use from your own script or notebook.
#
#
# ### From the notebook
#
# 1. Upload the `bk2bp.ipynb` file in Jupyter.
# 2. Fill out the [parameters](#Parameters) cell of this notebook according to your scenario and desired output.
# 3. Run the full notebook.
#
#
# ### From the command line
#
# 1. Make sure that you are in an
#    [environment](http://doc.tetrane.com/professional/latest/Python-API/Installation.html#on-the-reven-server)
#    that can run REVEN scripts.
# 2. Run `python bk2bp.py --help` to get a tour of available arguments.
# 3. Run `python bk2bp.py --host <your_host> --port <your_port> [<other_option>]` with your arguments of
#    choice.
#
# ### Imported in your own script or notebook
#
# 1. Make sure that you are in an
#    [environment](http://doc.tetrane.com/professional/latest/Python-API/Installation.html#on-the-reven-server)
#    that can run REVEN scripts.
# 2. Make sure that `bk2bp.py` is in the same directory as your script or notebook.
# 3. Add `import bk2bp` to your script or notebook. You can access the various functions and classes
#    exposed by the module from the `bk2bp` namespace.
# 4. Refer to the [Argument parsing](#Argument-parsing) cell for an example of use in a script, and to the
#    [Parameters](#Parameters) cell and below for an example of use in a notebook (you just need to prepend
#    `bk2bp` in front of the functions and classes from the script).
#
# ## Known limitations
#
# - For the breakpoints to be resolved by WinDbg, the debugged program/machine/REVEN scenario needs to be in a state
#   where the corresponding modules have been loaded. Otherwise, WinDbg will add the breakpoints in an unresolved state,
#   and may mix up modules and symbols.
#
# - When importing breakpoints generated from the bookmarks of a scenario using this script in WinDbg,
#   make sure that the debugged system is "similar enough" to the VM that was used to record the scenario.
#   In particular, if a binary changed and has symbols at different offsets in the debugged system, importing
#   the breakpoints will not lead to the correct location in the binary, and may render the debugged system unstable.
#
# ## Supported versions
#
# REVEN 2.8+
#
# ## Supported perimeter
#
# Any Windows REVEN scenario.
#
# ## Dependencies
#
# The script requires that the target REVEN scenario have:
#
# * The Fast Search feature replayed.
# * The OSSI feature replayed.

# %% [markdown]
# ### Package imports

# %%
import argparse
from typing import Optional

import reven2  # type: ignore


# %% [markdown]
# ### Utility functions

# %%
# Detect if we are currently running a Jupyter notebook.
#
# This is used e.g. to display rendered results inline in Jupyter when we are executing in the context of a Jupyter
# notebook, or to display raw results on the standard output when we are executing in the context of a script.
def in_notebook():
    try:
        from IPython import get_ipython  # type: ignore

        if get_ipython() is None or ("IPKernelApp" not in get_ipython().config):
            return False
    except ImportError:
        return False
    return True


# %% [markdown]
# ### Main function

# %%
def bk2bp(server: reven2.RevenServer, output: Optional[str]):
    text = ""
    for bookmark in server.bookmarks.all():
        ossi = bookmark.transition.context_before().ossi
        if ossi is None:
            continue
        location = ossi.location()
        if location is None:
            continue
        if location.binary is None:
            continue
        if location.rva is None:
            continue
        name = location.binary.name
        # WinDbg requires the precise name of the kernel, which is difficult to get.
        # WinDbg seems to always accept "nt" as name for the kernel, so replace that.
        if name == "ntoskrnl":
            name = "nt"
        text += f"bp {name}+{location.rva:#x}\r\n"  # for windows it is safest to have the \r
    if output is None:
        print(text)
    else:
        try:
            with open(output, "w") as f:
                f.write(text)
        except OSError as ose:
            raise ValueError(f"Could not open file {output}: {ose}")


# %% [markdown]
# ### Argument parsing
#
# Argument parsing function for use in the script context.

# %%
def script_main():
    parser = argparse.ArgumentParser(
        description="Convert the bookmarks of a scenario to a WinDbg breakpoints commands.",
        epilog="Requires the Fast Search and the OSSI features replayed.",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        required=False,
        help='REVEN host, as a string (default: "localhost")',
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default="13370",
        required=False,
        help="REVEN port, as an int (default: 13370)",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        type=str,
        required=False,
        help="The target file of the script. If absent, the results will be printed on the standard output.",
    )

    args = parser.parse_args()

    try:
        server = reven2.RevenServer(args.host, args.port)
    except RuntimeError:
        raise RuntimeError(f"Could not connect to the server on {args.host}:{args.port}.")

    bk2bp(server, args.output_file)


# %% [markdown]
# ### Parameters
#
# These parameters have to be filled out for use in the notebook context.

# %%
# Server connection
#
host = "localhost"
port = 13370


# Output target
#
# If set to a path, writes the breakpoint commands file there
output_file = None  # display bp commands inline in the Jupyter Notebook
# output_file = "breakpoints.txt"  # write bp commands to a file named "breakpoints.txt" in the current directory

# %% [markdown]
# ### Execution cell
#
# This cell executes according to the [parameters](#Parameters) when in notebook context, or according to the
# [parsed arguments](#Argument-parsing) when in script context.
#
# When in notebook context, if the `output_file` parameter is `None`, then the output will be displayed in the last cell of
# the notebook.

# %%
if __name__ == "__main__":
    if in_notebook():
        try:
            server = reven2.RevenServer(host, port)
        except RuntimeError:
            raise RuntimeError(f"Could not connect to the server on {host}:{port}.")
        bk2bp(server, output_file)
    else:
        script_main()

# %%

Percent

Purpose

Get the transition that performs the opposite operation to the given transition.

The opposite operations are the following:

  • The transition switches between user and kernel land. Examples:
    • a syscall transition => the related sysret transition
    • a sysret transition => the related syscall transition
    • an exception transition => the related iretq transition
    • an iretq transition => the related exception transition
  • The transition does memory accesses:
    • case 1: a unique access. The access is selected.
    • case 2: multiple write accesses. The first one is selected.
    • case 3: multiple read accesses. The first one is selected.
    • case 4: multiple read and write accesses. The first write access is selected. This enables getting the matching ret transition from an indirect call transition, e.g. call [rax + 10]. If the selected access is a write, then the next read access on the same memory is searched for. If the selected access is a read, then the previous write access on the same memory is searched for.

Examples, percent on:

  • a call transition => the related ret transition.
  • a ret transition => the related call transition.
  • a push transition => the related pop transition.
  • a pop transition => the related push transition.

If no related transition is found, None is returned.
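
As an illustration, here is a minimal sketch of calling the helper from your own code, following the example given in the script's docstring and assuming percent.py is importable from your environment; it combines percent with the backtrace to jump to the end of the current function (the connection parameters and transition id are placeholders):

import reven2
from percent import percent

server = reven2.RevenServer("localhost", 13370)

# Pick a transition of interest (the id below is only an example).
transition = server.trace.transition(10000000)

# The creation transition of the top stack frame is the call that opened the
# current function; percent then looks for the matching ret transition.
call_transition = transition.context_before().stack.frames[0].creation_transition
ret_transition = percent(server, call_transition)
print(ret_transition)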

How to use

usage: percent.py [-h] [--host HOST] [-p PORT] transition

positional arguments:
  transition            Transition id, as an int

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)

Known limitations

percent is a heuristic that sometimes doesn't produce the expected result.

Supported versions

REVEN 2.2+. For REVEN 2.5+, prefer to use the Transition.find_inverse method.

Supported perimeter

Any REVEN scenario.

Dependencies

The script requires that the target REVEN scenario have the Memory History feature replayed.

Source

import argparse

import reven2

import reven_api

"""
# Percent

## Purpose

Get the transition that performs the `opposite` operation to the given transition.

The opposite operations are the following:

  * The transition switches between user and kernel land.
    Examples:
    * a `syscall` transition => the related `sysret` transition
    * a `sysret` transition => the related `syscall` transition
    * an exception transition => the related `iretq` transition
    * an `iretq` transition => the related exception transition
  * The transition does memory accesses:
    * case 1: a unique access.
              The access is selected.
    * case 2: multiple write accesses.
              The first one is selected.
    * case 3: multiple read accesses.
              The first one is selected.
    * case 4: multiple read and write accesses.
              The first write access is selected.
              This enables getting the matching `ret` transition
              from an indirect call transition, e.g. `call [rax + 10]`.
      If the selected access is a write, then the next read access
      on the same memory is searched for.
      If the selected access is a read, then the previous write access
      on the same memory is searched for.

Examples, percent on:
  * a `call` transition => the related `ret` transition.
  * a `ret` transition => the related `call` transition.
  * a `push` transition => the related `pop` transition.
  * a `pop` transition => the related `push` transition.

If no related transition is found, `None` is returned.

## How to use

```bash
usage: percent.py [-h] [--host HOST] [-p PORT] transition

positional arguments:
  transition            Transition id, as an int

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
```

## Known limitations

`percent` is a heuristic that sometimes doesn't produce the expected result.

## Supported versions

REVEN 2.2+. For REVEN 2.5+, prefer to use the `Transition.find_inverse` method.

## Supported perimeter

Any REVEN scenario.

## Dependencies

The script requires that the target REVEN scenario have the Memory History feature replayed.
"""


def previous_register_change(reven, register, from_transition):
    """
    Get the previous transition where the register's value changed.
    """
    range_size = 5000000
    start = reven_api.execution_point(from_transition.id)
    stop = reven_api.execution_point(max(from_transition.id - range_size, 0))
    result = reven._rvn.run_search_next_register_use(
        start, forward=False, read=False, write=True, register_name=register.name, stop=stop
    )
    while result == stop:
        start = stop
        stop = reven_api.execution_point(max(start.sequence_identifier - range_size, 0))
        result = reven._rvn.run_search_next_register_use(
            start, forward=False, read=False, write=True, register_name=register.name, stop=stop
        )
    if result.valid():
        return reven.trace.transition(result.sequence_identifier)
    return None


def next_register_change(reven, register, from_transition):
    """
    Get the next transition where the register's value changed.
    """
    range_size = 5000000
    start = reven_api.execution_point(from_transition.id)
    stop = reven_api.execution_point(from_transition.id + range_size)
    result = reven._rvn.run_search_next_register_use(
        start, forward=True, read=False, write=True, register_name=register.name, stop=stop
    )
    while result == stop:
        start = stop
        stop = reven_api.execution_point(start.sequence_identifier + range_size)
        result = reven._rvn.run_search_next_register_use(
            start, forward=True, read=False, write=True, register_name=register.name, stop=stop
        )
    if result.valid():
        return reven.trace.transition(result.sequence_identifier)
    return None


def previous_memory_use(reven, address, size, from_transition, operation=None):
    """
    Get the previous transition where the memory range [address ; size] is used (read/write).
    """
    try:
        access = next(
            reven.trace.memory_accesses(address, size, from_transition, is_forward=False, operation=operation)
        )
        return access.transition
    except StopIteration:
        return None


def next_memory_use(reven, address, size, from_transition, operation=None):
    """
    Get the next transition where the memory range [address ; size] is used (read/write).
    """
    try:
        access = next(
            reven.trace.memory_accesses(address, size, from_transition, is_forward=True, operation=operation)
        )
        return access.transition
    except StopIteration:
        return None


def percent(reven, transition):
    """
    This function is a helper to get the transition that performs
    the `opposite` operation to the given transition.

    If no opposite transition is found, `None` is returned.

    Opposite operations
    ===================

    * The transition switches between user and kernel land.
      Examples:
          * a `syscall` transition => the related `sysret` transition
          * a `sysret` transition => the related `syscall` transition
          * an exception transition => the related `iretq` transition
          * an `iretq` transition => the related exception transition

    * The transition does memory accesses:
        * case 1: a unique access.
                  The access is selected.
        * case 2: multiple write accesses.
                  The first one is selected.
        * case 3: multiple read accesses.
                  The first one is selected.
        * case 4: multiple read and write accesses.
                  The first write access is selected.
                  This enables getting the matching `ret` transition
                  from an indirect call transition, e.g. `call [rax + 10]`.
      If the selected access is a write, then the next read access
      on the same memory is searched for.
      If the selected access is a read, then the previous write access
      on the same memory is searched for.

      Examples, percent on:
          * a `call` transition => the related `ret` transition.
          * a `ret` transition => the related `call` transition.
          * a `push` transition => the related `pop` transition.
          * a `pop` transition => the related `push` transition.

    Dependencies
    ============

    The script requires that the target REVEN scenario have the Memory History feature replayed.

    Usage
    =====

    It can be combined with other features like backtrace to obtain interesting results.

    For example, to jump to the end of the current function:
        >>> import reven2
        >>> from percent import percent
        >>> reven_server = reven2.RevenServer('localhost', 13370)
        >>> current_transition = reven_server.trace.transition(10000000)
        >>> ret_transition = percent(reven_server,
        ...                          current_transition.context_before().stack.frames[0].creation_transition)
    """
    ctx_b = transition.context_before()
    ctx_a = transition.context_after()

    # cs basic heuristic to handle sysenter/sysexit

    cs_b = ctx_b.read(reven2.arch.x64.cs)
    cs_a = ctx_a.read(reven2.arch.x64.cs)

    if cs_b > cs_a:
        # cs is modified by the transition
        return next_register_change(reven, reven2.arch.x64.cs, transition)
    if cs_b < cs_a:
        # cs is modified by the transition
        return previous_register_change(reven, reven2.arch.x64.cs, transition)

    # memory heuristic

    # first: check write accesses (get the first one)
    # this is to avoid failure on indirect call (1 read access then 1 write access)
    for access in transition.memory_accesses(operation=reven2.memhist.MemoryAccessOperation.Write):
        if access.virtual_address is None:
            # ignoring physical access
            continue
        return next_memory_use(
            reven, access.virtual_address, access.size, transition, reven2.memhist.MemoryAccessOperation.Read
        )

    # second: check read accesses (get the first one)
    for access in transition.memory_accesses(operation=reven2.memhist.MemoryAccessOperation.Read):
        if access.virtual_address is None:
            # ignoring physical access
            continue
        return previous_memory_use(
            reven, access.virtual_address, access.size, transition, reven2.memhist.MemoryAccessOperation.Write
        )

    return None


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost", help='Reven host, as a string (default: "localhost")')
    parser.add_argument("-p", "--port", type=int, default="13370", help="Reven port, as an int (default: 13370)")
    parser.add_argument("transition", type=int, help="Transition id, as an int")
    args = parser.parse_args()

    rvn = reven2.RevenServer(args.host, args.port)
    transition = rvn.trace.transition(args.transition)

    result = percent(rvn, transition)
    if result is not None:
        if result >= transition:
            print("=> {}".format(transition))
            print("<= {}".format(result))
        else:
            print("<= {}".format(transition))
            print("=> {}".format(result))
    else:
        print("No result found for {}".format(transition))

Search in memory

Purpose

Search the memory at a specific context for a string or for an array of bytes.

The memory range to search in is defined by a starting address and a search_size.

All unmapped addresses are ignored during the search.
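
For instance, here is a minimal sketch of using the underlying MemoryFinder helper from your own code, assuming search_in_memory.py is importable from your environment; the connection parameters, transition id, start address and pattern are placeholders:

import reven2
from search_in_memory import MemoryFinder

server = reven2.RevenServer("localhost", 13370)

# Search the memory of the context just before an arbitrary transition.
context = server.trace.context_before(10000)

# Look for an ASCII string starting at ds:0xfffff123123, over the default 1000MB range.
finder = MemoryFinder(context, 0xfffff123123)
for address, progress in finder.query("some string"):
    if address is not None:
        print("found match at {}".format(address))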

How to use

usage: search_in_memory.py [-h] --host HOST -p PORT --transition TRANSITION
                           --address ADDRESS --pattern PATTERN
                           [--search-size SEARCH_SIZE] [--backward]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  --transition TRANSITION
                        transition id. the context before this id will be
                        searched
  --address ADDRESS     The start address of the memory area to search in. It
                        can be a hex offset as 0xfff123 (same as ds:0xfff123),
                        hex offset prefixed by segment register as
                        gs:0xfff123, hex offset prefixed by hex segment index
                        as 0x20:0xfff123, hex offset prefixed by 'lin' for
                        linear address, or offset prefixed by 'phy' for
                        physical address.
  --pattern PATTERN     pattern that will be searched. It can be a normal
                        string as 'test', or a string of bytes as
                        '\x01\x02\x03\x04'. Maximum accepted length is 4096
  --search-size SEARCH_SIZE
                        The size of memory area to search in. accepted value
                        can take a suffix, like 1000, 10kb or 10mb. Default
                        value is 1000mb
  --backward            If present the search will go in backward direction.

Known limitations

  • Currently, this script cannot handle logical addresses that are not aligned on a memory page (4K) with their corresponding physical address. In 64 bits, this can happen mainly for the gs and fs segment registers. If you encounter this limitation, you can manually translate your virtual address using its translate method, and then restart the search on the resulting physical address (limiting the search range to 4K, so as to remain within the boundaries of the virtual page); see the sketch after this list.

  • Pattern length must be less than or equal to the page size (4k).
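
A sketch of this workaround, assuming search_in_memory.py is importable from your environment; the gs-based offset and transition id below are only illustrative:

import reven2
from reven2.address import LogicalAddress
from reven2.arch import x64
from search_in_memory import MemoryFinder

server = reven2.RevenServer("localhost", 13370)
context = server.trace.context_before(10000)

# gs-based addresses may not be page-aligned with their physical translation.
virtual = LogicalAddress(0x180, x64.gs)
physical = virtual.translate(context)
if physical is not None:
    # Restrict the search to the remainder of the 4K page.
    remaining = 0x1000 - (physical.offset % 0x1000)
    finder = MemoryFinder(context, physical, remaining)
    for address, progress in finder.query("some string"):
        if address is not None:
            print("found match at {}".format(address))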

Supported versions

REVEN 2.6+.

Supported perimeter

Any REVEN scenario.

Dependencies

None.

Source

import argparse
import sys
from copy import copy

import reven2
import reven2.address as _address
import reven2.arch.x64 as x64_regs


"""
# Search in memory

## Purpose

Search the memory at a specific context for a string or for an array of bytes.

The memory range to search in is defined by a starting address and a search_size.

All unmapped addresses are ignored during the search.

## How to use

```bash
usage: search_in_memory.py [-h] --host HOST -p PORT --transition TRANSITION
                           --address ADDRESS --pattern PATTERN
                           [--search-size SEARCH_SIZE] [--backward]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  --transition TRANSITION
                        transition id. the context before this id will be
                        searched
  --address ADDRESS     The start address of the memory area to search in. It
                        can be a hex offset as 0xfff123 (same as ds:0xfff123),
                        hex offset prefixed by segment register as
                        gs:0xfff123, hex offset prefixed by hex segment index
                        as 0x20:0xfff123, hex offset prefixed by 'lin' for
                        linear address, or offset prefixed by 'phy' for
                        physical address.
  --pattern PATTERN     pattern that will be searched. It can be a normal
                        string as 'test', or a string of bytes as
                        '\x01\x02\x03\x04'. Maximum accepted length is 4096
  --search-size SEARCH_SIZE
                        The size of memory area to search in. accepted value
                        can take a suffix, like 1000, 10kb or 10mb. Default
                        value is 1000mb
  --backward            If present the search will go in backward direction.
```

## Known limitations

- Currently, this script cannot handle logical addresses that are not aligned on a memory page (4K)
  with their corresponding physical address. In 64 bits, this can happen mainly for
  the `gs` and `fs` segment registers.
  If you encounter this limitation, you can manually translate your virtual address
  using its `translate` method, and then restart the search on the resulting physical address
  (limiting the search range to 4K, so as to remain in the boundaries of the virtual page).

- Pattern length must be less than or equal to the page size (4k).

## Supported versions

REVEN 2.6+.

## Supported perimeter

Any REVEN scenario.

## Dependencies

None.
"""


class MemoryFinder(object):
    r"""
    This class is a helper class to search the memory at a specific context for a string or for an array of bytes.
    The memory range to search in is defined by a starting address and a search_size.

    The matching addresses are returned.

    Known limitation
    ================

    Currently, this class cannot handle logical addresses that are not aligned on a memory page (4K)
    with their corresponding physical address. In 64 bits, this can happen mainly for
    the `gs` and `fs` segment registers.

    If you encounter this limitation, you can manually translate your virtual address
    using its `translate` method, and then restart the search on the resulting physical address
    (limiting the search range to 4K, so as to remain in the boundaries of the virtual page).

    Pattern length must be less than or equal to page size (4k).

    Examples
    ========

    >>> # Search the first context starting from the address ds:0xfffff123123 for the string 'string'
    >>> # Search_size default value is 1000MB.
    >>> # Memory range to search in is: [ds:0xfffff123123, ds:0xfffff123123 + 1000MB]
    >>> finder = MemoryFinder(context, 0xfffff123123)
    >>> for address, progress in finder.query('string'):
    ...     sys.stderr.write("progress: %d%s\r" % (int(progress / finder.search_size * 100), '%'))
    ...     if address:
    ...         print("found match at {}".format(address))
    found match at ds:0xfffff123444
    ...

    >>> # Search the first context starting from the address lin:0xfffff123123 for the
    >>> # array of bytes '\\x35\\xfe\\x0e\\x4a'
    >>> # Search size default value is 1000MB
    >>> # Memory range to search in is: [lin:0xfffff123123, lin:0xfffff123123 + 1000MB]
    >>> address = reven2.address.LinearAddress(0xfffff123123)
    >>> finder = MemoryFinder(context, address)
    >>> for address, progress in finder.query('\\x35\\xfe\\x0e\\x4a'):
    ...     sys.stderr.write("progress: %d%s\r" % (int(progress / finder.search_size * 100), '%'))
    ...     if address:
    ...         print("found match at {}".format(address))
    found match at ds:0xfffff125229
    ...

    >>> # Search the first context starting from the address gs:0x180 for the string 'string'
    >>> # Search size value is 100MB
    >>> # Memory range to search in is: [gs:0x180, gs:0x180 + 100MB]
    >>> address = reven2.address.LogicalAddress(0x180, reven2.arch.x64.gs)
    >>> finder = MemoryFinder(context, address, 100*1024*1024)
    >>> for address, progress in finder.query('string'):
    ...     sys.stderr.write("progress: %d%s\r" % (int(progress / finder.search_size * 100), '%'))
    ...     if address:
    ...         print("found match at {}".format(address))
    found match at ds:0xfffff123444
    ...

    >>> # Search the first context starting from the address ds:0xfffff123123 for the string 'string'
    >>> # in backward direction.
    >>> # Search_size default value is 1000MB.
    >>> # Memory range to search in is: [ds:0xfffff123123 - 1000MB, ds:0xfffff123123]
    >>> finder = MemoryFinder(context, 0xfffff123123)
    >>> for address, progress in finder.query('string', False):
    ...     sys.stderr.write("progress: %d%s\r" % (int(progress / finder.search_size * 100), '%'))
    ...     if address:
    ...         print("found match at {}".format(address))
    found match at ds:0xfffff123004
    ...
    """
    page_size = 0x1000
    progress_step = 0x10000

    def __init__(self, context, address, search_size=1000 * 1024**2):
        r"""
        Initialize a C{MemoryFinder} from context and address

        Information
        ===========

        @param context: C{reven2.trace.Context} where searching will be done.
        @param address: a class from C{reven2.address} the address where the search will be started.
        @param search_size: an C{Integer} representing the size, in bytes, of the search range.


        @raises TypeError: if context is not a C{reven2.trace.Context} or address is not a C{Integer} or
                           one of the address classes on C{reven2.address}.
        @raises RuntimeError: If the address is a virtual address that is not aligned to its
                              corresponding physical address.
        """
        if not isinstance(context, reven2.trace.Context):
            raise TypeError("context must be an instance of reven2.trace.Context class")
        self._context = context

        search_addr = copy(address)
        if not isinstance(search_addr, _address._AbstractAddress):
            try:
                # if address is of type int make it a logical address with ds as segment register
                search_addr = _address.LogicalAddress(address)
            except TypeError:
                raise TypeError(
                    "address must be an instance of a class from reven2.address " "module or an integer value."
                )

        self._search_size = search_size
        self._start = search_addr

    @property
    def search_size(self):
        return self._search_size

    def query(self, pattern, is_forward=True):
        r"""
        Iterate the search range looking for the specified pattern.

        This method returns a generator of tuples, of the form C{(A, processed_bytes)}, such that:

        - C{processed_bytes} indicates the number of bytes already processed in the search range.
        - C{A} is either an address of the same type as the input address, or C{None}.
          If an address is returned, it corresponds to an address matching the searched pattern.
          C{None} is returned every 64KB of the search range, as a means of indicating progress.

        Information
        ===========

        @param pattern: A C{str} or C{bytearray}. The pattern to look for in memory.
                        Note: C{str} pattern is converted to bytearray using ascii encoding.

        @param is_forward: C{bool}, C{True} to search in forward direction and C{False}
                           to search in backward direction

        @returns: a generator of tuples, where the tuples are either:
                - C{(None, processed_bytes)} every 64KB of the search range,
                - C{(matching_address, processed_bytes)} each time a matching_address is found.
        """

        # pattern is a byte array or a string
        search_pattern = copy(pattern)
        if not isinstance(search_pattern, bytearray):
            if isinstance(search_pattern, str):
                search_pattern = bytearray(str.encode(pattern))
            else:
                raise RuntimeError("Cannot parse pattern, bad format.")

        if len(search_pattern) > self.page_size:
            raise RuntimeError("Maximum length of pattern must be less than or equal to %d." % self.page_size)

        return self._search(search_pattern, is_forward)

    def _search(self, pattern, is_forward):
        def loop_condition(curr, end):
            return curr < end if is_forward else curr > end

        cross_page_addition = len(pattern) - 1

        iteration_step = self.page_size if is_forward else -self.page_size
        curr = self._start
        end = curr + self._search_size if is_forward else curr - self._search_size
        prev = None
        progress = 0

        # first loop detects the first mapped address, then tests if it is aligned
        # this step is only applied for logical address
        if not isinstance(curr, reven2.address.PhysicalAddress):
            while loop_condition(curr, end):
                phy = curr.translate(self._context)
                if phy is None:
                    curr += iteration_step
                    progress += self.page_size
                    if progress % self.progress_step == 0:
                        yield None, progress
                    continue
                # linear -> physical alignment is guaranteed on 4k boundary:
                # If linear is 0xxxxx123, physical will be 0xyyyy123
                # logical -> linear alignment is not guaranteed because segment offset goes down to the byte
                # (or at least down to less than 4k): logical gs:0x123 could be linear 0xzzzzz456
                # Problem is: gs:0x0 might not be at start of page, 0x0:0x1000 might span on two pages
                # instead of one. To solve: we need to translate logical -> physical for start address,
                # and take note of offset to use that to compute actual start of page
                # currently, we don't treat the case where logical -> linear alignment isn't valid.
                if curr.offset % self.page_size != phy.offset % self.page_size:
                    raise RuntimeError(
                        "The provided address is not aligned on a memory page (4K)"
                        "with  their corresponding physical address. Only aligned "
                        "addresses can be handled."
                    )
                break

        # second loop starts the search
        while loop_condition(curr, end):
            # get offset between current address and the start of the page
            # This offset is zero, except in the first iteration where it may be non-zero
            offset = curr.offset % self.page_size
            # compute the length of the buffer to read.
            # This buffer length equals the page size, except in the first iteration where it may be smaller
            buffer_length = self.page_size if offset == 0 else (self.page_size - offset if is_forward else offset)
            # the iteration step to go forward or backward
            iteration_step = buffer_length if is_forward else -buffer_length
            # compute the address to read from:
            # when going forward this address is the current address,
            # when going backward we read up to the current address, so it is current - buffer length
            read_address = curr if is_forward else curr - buffer_length
            # if the read buffer will exceed the search range adjust it
            if is_forward and read_address + buffer_length > end:
                buffer_length = end.offset - read_address.offset
            elif not is_forward and read_address < end:
                read_address = end
                buffer_length = curr.offset - read_address.offset

            try:
                buffer = self._context.read(read_address, buffer_length, raw=True)
            except Exception:
                curr += iteration_step
                progress += self.page_size
                prev = None
                if progress % self.progress_step == 0:
                    yield None, progress
                continue

            # Add necessary bytes from previous page to allow cross-page matches
            addr_offset = 0
            if prev is not None:
                if is_forward:
                    prev_buf_len = -len(prev) if cross_page_addition > len(prev) else -cross_page_addition
                    buffer = prev[prev_buf_len:] + buffer if prev_buf_len < 0 else buffer
                    addr_offset = prev_buf_len
                else:
                    prev_buf_len = len(prev) if cross_page_addition > len(prev) else cross_page_addition
                    buffer = buffer + prev[:prev_buf_len]

            index = 0
            addr_res = []
            while True:
                index = buffer.find(pattern, index)
                if index == -1:
                    break
                addr_res.append(read_address + index + addr_offset)
                index += 1

            for addr in addr_res if is_forward else reversed(addr_res):
                yield addr, progress

            progress += self.page_size
            prev = buffer
            curr += iteration_step


def parse_address(string_address):
    segments = [x64_regs.ds, x64_regs.cs, x64_regs.es, x64_regs.ss, x64_regs.gs, x64_regs.fs]

    def _str_to_seg(str_reg):
        for segment in segments:
            if str_reg == segment.name:
                return segment
        return None

    try:
        # Try to parse address as offset only as 0xfff123.
        return _address.LogicalAddress(int(string_address, base=16))
    except ValueError:
        pass
    # Try to parse address as prefix:offset as 0x32:0xfff123, gs:0xfff123, lin:0xfff123 or phy:0xff123.
    res = string_address.split(":")
    if len(res) != 2:
        raise RuntimeError("Cannot parse address, bad format")

    try:
        offset = int(res[1].strip(), base=16)
    except ValueError:
        raise RuntimeError("Cannot parse address, bad format")

    try:
        # Try to parse it as 0x32:0xfff123.
        segment_index = int(res[0].strip(), base=16)
        return _address.LogicalAddressSegmentIndex(segment_index, offset)
    except ValueError:
        pass

    lower_res0 = res[0].lower().strip()
    # Try to parse it as ds:0xfff123, cs:0xfff123, es:0xfff123, ss:0xfff123, gs:0xfff123 or fs:0xfff123.
    sreg = _str_to_seg(lower_res0)
    if sreg:
        return _address.LogicalAddress(offset, sreg)
    elif lower_res0 == "lin":
        # Try parse it as lin:0xfff123.
        return _address.LinearAddress(offset)
    elif lower_res0 == "phy":
        # Try parse it as phy:0xfff123.
        return _address.PhysicalAddress(offset)
    else:
        raise RuntimeError("Cannot parse address, bad format")


def parse_search_size(string_size):
    try:
        # try to convert it to int
        return int(string_size)
    except ValueError:
        pass

    # try to convert it to int without the two last char
    lower_string = string_size.lower()
    ssize = lower_string[:-2]
    try:
        size = int(ssize)
    except ValueError:
        raise RuntimeError("Cannot parse search size, bad format")
    # convert it according to its suffix
    if lower_string.endswith("kb"):
        return size * 1024
    elif lower_string.endswith("mb"):
        return size * 1024 * 1024
    else:
        raise RuntimeError("Cannot parse search size, bad format")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host", type=str, default="localhost", required=True, help='Reven host, as a string (default: "localhost")'
    )
    parser.add_argument(
        "-p", "--port", type=int, default="13370", required=True, help="Reven port, as an int (default: 13370)"
    )
    parser.add_argument(
        "--transition", type=int, required=True, help="transition id. the context before this id will be searched"
    )
    parser.add_argument(
        "--address",
        type=str,
        required=True,
        help="The start address of the memory area to search in. "
        "It can be a hex offset as 0xfff123 (same as ds:0xfff123), "
        "hex offset prefixed by segment register as gs:0xfff123, "
        "hex offset prefixed by hex segment index as 0x20:0xfff123, "
        "hex offset prefixed by 'lin' for linear address, "
        "or offset prefixed by 'phy' for physical address.",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        required=True,
        help="pattern that will be searched. "
        "It can be a normal string as 'test', "
        "or a string of bytes as '\\x01\\x02\\x03\\x04'."
        "Maximum accepted length is 4096",
    )
    parser.add_argument(
        "--search-size",
        type=str,
        default="1000mb",
        help="The size of memory area to search in. "
        "accepted value can take a suffix, like 1000, 10kb or 10mb."
        "Default value is 1000mb",
    )
    parser.add_argument(
        "--backward", default=False, action="store_true", help="If present the search will go in backward direction."
    )

    args = parser.parse_args()

    try:
        pattern = bytearray(map(ord, bytearray(map(ord, args.pattern.strip())).decode("unicode_escape")))
    except Exception as e:
        raise RuntimeError("Cannot parse pattern, bad format(%s)" % str(e))

    address = parse_address(args.address.strip())

    reven_server = reven2.RevenServer(args.host, args.port)
    context = reven_server.trace.context_before(args.transition)

    finder = MemoryFinder(context, address, parse_search_size(args.search_size.strip()))

    for address, progress in finder.query(pattern, not args.backward):
        sys.stderr.write("progress: %d%s\r" % (int(progress / finder.search_size * 100), "%"))
        if address:
            print("found match at {}".format(address))

OSSI

Examples in this section demonstrate OS-Specific Information capabilities, such as browsing processes, binaries and symbols.
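
For example, here is a minimal sketch of browsing executed binaries and their symbols through the API; the connection parameters and patterns are only illustrative:

import reven2

server = reven2.RevenServer("localhost", 13370)

# Browse the binaries executed in the trace whose path matches a pattern.
for binary in server.ossi.executed_binaries("ntdll"):
    print(binary.path)

# Browse the symbols of one of these binaries.
ntdll = next(server.ossi.executed_binaries("c:/windows/system32/ntdll.dll"))
for symbol in ntdll.symbols("RtlUserThreadStart"):
    print(symbol.name)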

Thread id

Purpose

Detect the current thread at a point in the trace and find when the thread is created.
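
The detection itself reads the TEB through the gs segment; here is a minimal sketch of that part, following the script's ThreadInfo class (the connection parameters and transition id are placeholders):

import reven2
from reven2.address import LogicalAddress
from reven2.arch import x64

server = reven2.RevenServer("localhost", 13370)
context = server.trace.transition(10000).context_before()

# In Windows x64 user land, gs points to the TEB: the PID is at offset 0x40,
# the TID at offset 0x48 (each read as a 4-byte integer).
pid = context.read(LogicalAddress(0x40, x64.gs), 4)
tid = context.read(LogicalAddress(0x48, x64.gs), 4)
print("PID: {}, TID: {}".format(pid, tid))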

How to use

usage: thread_id.py [-h] [--host HOST] [-p PORT] TRANSITION_ID

positional arguments:
  TRANSITION_ID         Get thread id at transition (before)

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)

Known limitations

  • Current thread is not detected if the given point is in ring0.

Supported versions

REVEN 2.8+

Supported perimeter

Any Windows 10 on x86-64 scenario. Given point must be in a 64 bit process.

Dependencies

The script requires that the target REVEN scenario have:

  • The Fast Search feature replayed.
  • The OSSI feature replayed.
  • An access to the binary 'ntdll.dll' and its PDB file.

Source

import argparse

import reven2
from reven2.address import LogicalAddress
from reven2.arch import x64

"""
# Thread id

## Purpose

Detect the current thread at a point in the trace
and find when the thread is created.

## How to use

```bash
usage: thread_id.py [-h] [--host HOST] [-p PORT] TRANSITION_ID

positional arguments:
  TRANSITION_ID         Get thread id at transition (before)

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
```

## Known limitations

- Current thread is not detected if the given point is in ring0.

## Supported versions

REVEN 2.8+

## Supported perimeter

Any Windows 10 on x86-64 scenario. Given point must be in a 64 bit process.

## Dependencies

The script requires that the target REVEN scenario have:
    * The Fast Search feature replayed.
    * The OSSI feature replayed.
    * An access to the binary 'ntdll.dll' and its PDB file.
"""


class ThreadInfo(object):
    def __init__(self, ctxt):
        self.cr3 = ctxt.read(x64.cr3)
        self.pid = ctxt.read(LogicalAddress(0x40, x64.gs), 4)
        self.tid = ctxt.read(LogicalAddress(0x48, x64.gs), 4)

    def __eq__(self, other):
        return (self.cr3, self.pid, self.tid) == (
            other.cr3,
            other.pid,
            other.tid,
        )

    def __ne__(self, other):
        return not self == other


def context_ring(ctxt):
    return ctxt.read(x64.cs) & 0x3


def all_start_thread_calls(ossi, trace):
    # look for RtlUserThreadStart
    ntdll_dll = next(ossi.executed_binaries("c:/windows/system32/ntdll.dll"))
    rtl_user_thread_start = next(ntdll_dll.symbols("RtlUserThreadStart"))
    return trace.search.symbol(rtl_user_thread_start)


def thread_search_pc(trace, thread_info, pc, from_context=None, to_context=None):
    matches = trace.search.pc(pc, from_context=from_context, to_context=to_context)
    for ctxt in matches:
        # ensure current match is in requested thread
        if ThreadInfo(ctxt) == thread_info:
            yield ctxt


def find_thread_starting_transition(rvn, thread_info):
    for start_thread_ctxt in all_start_thread_calls(rvn.ossi, rvn.trace):
        if ThreadInfo(start_thread_ctxt) == thread_info:
            # the first argument is the start address of the thread
            thread_start_address = start_thread_ctxt.read(x64.rcx)

            matches = thread_search_pc(
                rvn.trace,
                thread_info,
                pc=thread_start_address,
                from_context=start_thread_ctxt,
            )
            for match in matches:
                return match.transition_after()
    return None


def print_thread_info(rvn, tr_id):
    ctxt = rvn.trace.transition(tr_id).context_before()

    if context_ring(ctxt) == 0:
        print("(User) thread may not count in ring 0")
        return

    # pid, tid at the transition
    thread = ThreadInfo(ctxt)

    start_transition = find_thread_starting_transition(rvn, thread)
    if start_transition is None:
        print("TID: {thread.tid} (PID: {thread.pid}) starting transition not found".format(thread=thread))
        return

    print(
        "TID: {thread.tid} (PID: {thread.pid}), starts at: {transition}".format(
            thread=thread, transition=start_transition
        )
    )


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help='Reven host, as a string (default: "localhost")',
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default="13370",
        help="Reven port, as an int (default: 13370)",
    )
    parser.add_argument(
        "transition_id",
        metavar="TRANSITION_ID",
        type=int,
        help="Get thread id at transition (before)",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    rvn = reven2.RevenServer(args.host, args.port)
    tr_id = args.transition_id
    print_thread_info(rvn, tr_id)

Search

Purpose

Search a whole trace for one of the following points of interest (a usage sketch follows the list):

  • An executed symbol.
  • An executed binary.
  • An executed virtual address.
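
The search function itself can also be imported and driven from your own code, assuming search.py is importable from your environment; a minimal sketch, with illustrative connection parameters, patterns and address:

import reven2
from search import search

server = reven2.RevenServer("localhost", 13370)

# Contexts executing the CreateProcessW symbol of kernelbase.dll.
for ctx in search(server, symbol="^CreateProcessW$", binary=r"kernelbase\.dll"):
    print(ctx)

# Contexts executing a specific virtual address.
for ctx in search(server, pc=0x7fff57263b2f):
    print(ctx)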

How to use

usage: search.py [-h] [--host HOST] [-p PORT] [-s SYMBOL] [-b BINARY] [-a PC]
                 [--case-sensitive]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  -s SYMBOL, --symbol SYMBOL
                        symbol pattern
  -b BINARY, --binary BINARY
                        binary pattern
  -a PC, --pc PC        pc address
  --case-sensitive      case sensitive symbol search

Known limitations

N/A

Supported versions

REVEN 2.2+

Supported perimeter

Any REVEN scenario.

Dependencies

The script requires that the target REVEN scenario have:

  • The Fast Search feature replayed.
  • The OSSI feature replayed.

Source

import argparse

import reven2


"""
# Search

## Purpose

Search a whole trace for one of the following points of interest:
  * An executed symbol.
  * An executed binary.
  * An executed virtual address.

## How to use

```bash
usage: search.py [-h] [--host HOST] [-p PORT] [-s SYMBOL] [-b BINARY] [-a PC]
                 [--case-sensitive]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  -s SYMBOL, --symbol SYMBOL
                        symbol pattern
  -b BINARY, --binary BINARY
                        binary pattern
  -a PC, --pc PC        pc address
  --case-sensitive      case sensitive symbol search
```

## Known limitations

N/A

## Supported versions

REVEN 2.2+

## Supported perimeter

Any REVEN scenario.

## Dependencies

The script requires that the target REVEN scenario have:
  * The Fast Search feature replayed.
  * The OSSI feature replayed.
"""


def search(reven_server, symbol=None, binary=None, pc=None, case_sensitive=False):
    r"""
    This function is a helper to search easily one of the following points of interest:
    * executed symbols
    * executed binaries
    * an executed virtual address

    The matching contexts are returned in ascending order.

    Examples
    ========

    >>> # Search for RIP = 0x7fff57263b2f
    >>> for ctx in search(reven_server, pc=0x7fff57263b2f):
    ...     print(ctx)
    Context before #240135
    Context before #281211
    Context before #14608067
    Context before #14690369
    Context before #15756067
    Context before #15787089
    ...

    >>> # Search for binary "kernelbase.dll"
    >>> for ctx in search(reven_server, binary=r'kernelbase\.dll'):
    ...     print(ctx)
    Context before #240135
    Context before #240136
    Context before #240137
    Context before #240138
    Context before #240139
    Context before #240140
    Context before #240141
    ...

    >>> # Search for binaries that contains ".exe"
    >>> for ctx in search(reven_server, binary=r'\.exe'):
    ...     print(ctx)
    Context before #1537879110
    Context before #1537879111
    Context before #1537879112
    Context before #1537879113
    Context before #1537879372
    Context before #1537879373
    Context before #1537879374
    ...

    >>> # Search for all symbols that contain "acpi"
    >>> for ctx in search(reven_server, symbol='acpi'):
    ...     print(ctx)
    Context before #1471900961
    Context before #1471903808
    Context before #1471908093
    Context before #1471914935
    Context before #1472413834
    Context before #1472416173
    Context before #1472419063
    ...

    >>> # Search for symbol "CreateProcessW" in binary "kernelbase.dll"
    >>> for ctx in search(reven_server, symbol='^CreateProcessW$', binary=r'kernelbase\.dll'):
    ...     print(ctx)
    Context before #23886919
    Context before #1370448535
    Context before #2590849986

    Information
    ===========

    @param reven_server: A C{reven2.RevenServer} instance.
    @param symbol: A symbol regex pattern.
                   Can be combined with the `binary` argument.
    @param binary: A binary regex pattern.
    @param pc: A virtual address integer.
    @param case_sensitive: Whether the symbol pattern comparison is case sensitive or not.

    @return: A generator of C{reven2.trace.Context} instances.
    """
    search = reven_server.trace.search
    if pc is not None:
        return search.pc(pc)

    if binary is not None:
        if symbol is not None:
            queries = [
                search.symbol(rsymbol)
                for rsymbol in reven_server.ossi.symbols(
                    pattern=symbol, binary_hint=binary, case_sensitive=case_sensitive
                )
            ]
        else:
            queries = [search.binary(rbinary) for rbinary in reven_server.ossi.executed_binaries(pattern=binary)]
        return reven2.util.collate(queries)

    if symbol is not None:
        queries = [
            search.symbol(rsymbol)
            for rsymbol in reven_server.ossi.symbols(pattern=symbol, case_sensitive=case_sensitive)
        ]
        return reven2.util.collate(queries)

    raise ValueError("You must provide something to search")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost", help='Reven host, as a string (default: "localhost")')
    parser.add_argument("-p", "--port", type=int, default="13370", help="Reven port, as an int (default: 13370)")
    parser.add_argument("-s", "--symbol", type=str, help="symbol pattern")
    parser.add_argument("-b", "--binary", type=str, help="binary pattern")
    parser.add_argument("-a", "--pc", type=lambda a: int(a, 0), help="pc address")
    parser.add_argument("--case-sensitive", action="store_true", help="case sensitive symbol search")
    args = parser.parse_args()

    reven_server = reven2.RevenServer(args.host, args.port)
    for ctx in search(
        reven_server, symbol=args.symbol, binary=args.binary, pc=args.pc, case_sensitive=args.case_sensitive
    ):
        try:
            tr = ctx.transition_after()
            print("#{}: {}".format(tr.id, ctx.ossi.location()))
        except IndexError:
            tr = ctx.transition_before()
            print("#{}: {}".format(tr.id + 1, ctx.ossi.location()))

Trace symbol coverage

Purpose

Display a list of executed binaries and symbols in a REVEN scenario, indicating how many transitions were spent in each symbol/binary.
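
For instance, here is a minimal sketch of computing and printing the coverage from your own code, assuming trace_coverage.py is importable from your environment; the connection parameters and transition limit are placeholders:

import reven2
from trace_coverage import trace_coverage, print_coverages

server = reven2.RevenServer("localhost", 13370)

# Restrict the coverage to the first million transitions to keep the run short.
coverages = trace_coverage(server, max_transition=1000000)
print_coverages(coverages)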

How to use

usage: trace_coverage.py [-h] [--host HOST] [-p PORT] [-m MAX_TRANSITION]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  -m MAX_TRANSITION, --max-transition MAX_TRANSITION
                        Maximum number of transitions

Known limitations

N/A

Supported versions

REVEN 2.2+

Supported perimeter

Any REVEN scenario.

Dependencies

None.

Source

import argparse

import reven2 as reven

"""
# Trace symbol coverage

## Purpose

Display a list of executed binaries and symbols in a REVEN scenario,
indicating how many transitions were spent in each symbol/binary.

## How to use

```bash
usage: trace_coverage.py [-h] [--host HOST] [-p PORT] [-m MAX_TRANSITION]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  -m MAX_TRANSITION, --max-transition MAX_TRANSITION
                        Maximum number of transitions
```

## Known limitations

N/A

## Supported versions

REVEN 2.2+

## Supported perimeter

Any REVEN scenario.

## Dependencies

None.
"""


def trace_coverage(reven_server, max_transition=None):
    if max_transition is None:
        max_transition = reven_server.trace.transition_count
    else:
        max_transition = min(max_transition, reven_server.trace.transition_count)

    transition_id = 0
    coverages = {}
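    # coverages maps asid (CR3 value) -> {binary: [transition count, {symbol: transition count} or None]}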
    while transition_id < max_transition:
        ctx = reven_server.trace.context_before(transition_id)
        transition_id += 1
        asid = ctx.read(reven.arch.x64.cr3)
        loc = ctx.ossi.location()
        unknown = loc is None
        binary = "unknown" if unknown else loc.binary.path
        symbol = "unknown" if unknown or loc.symbol is None else loc.symbol.name

        try:
            asid_coverage = coverages[asid]
        except KeyError:
            coverages[asid] = {}
            asid_coverage = coverages[asid]

        try:
            binary_coverage = asid_coverage[binary]
            binary_coverage[0] += 1
            if binary == "unknown":
                continue
        except KeyError:
            if binary == "unknown":
                asid_coverage[binary] = [1, None]
                continue
            asid_coverage[binary] = [1, {}]
            binary_coverage = asid_coverage[binary]

        try:
            binary_coverage[1][symbol] += 1
        except KeyError:
            binary_coverage[1][symbol] = 1
    return coverages


def print_coverages(coverages):
    for (asid, asid_coverage) in coverages.items():
        print("***** Coverage for CR3 = {:#x} *****\n".format(asid))
        for (binary, binary_coverage) in asid_coverage.items():
            print("- {}: {}".format(binary, binary_coverage[0]))
            if binary_coverage[1] is None:
                continue
            for (symbol, symbol_coverage) in binary_coverage[1].items():
                print("    - {}: {}".format(symbol, symbol_coverage))
            print("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost", help='Reven host, as a string (default: "localhost")')
    parser.add_argument("-p", "--port", type=int, default="13370", help="Reven port, as an int (default: 13370)")
    parser.add_argument("-m", "--max-transition", type=int, help="Maximum number of transitions")
    args = parser.parse_args()

    reven_server = reven.RevenServer(args.host, args.port)
    coverages = trace_coverage(reven_server, args.max_transition)
    print_coverages(coverages)

List created processes

Purpose

List all processes created in the trace.
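
The script locates every call to CreateProcessInternalW with the fast-search API, then reads the process name from the call's arguments and the PID/TID from the returned PROCESS_INFORMATION structure. A condensed sketch of the symbol search it relies on, assuming a connected reven2.RevenServer instance (host and port are illustrative):

import reven2

server = reven2.RevenServer("localhost", 13370)  # illustrative host/port

# One search query per matching symbol, in both DLLs that may host the function
queries = [
    server.trace.search.symbol(symbol)
    for dll in ("kernelbase.dll", "kernel32.dll")
    for symbol in server.ossi.symbols(pattern="CreateProcessInternalW", binary_hint=dll)
]

# Merge the per-symbol result streams in trace order and print each call site
for ctx in reven2.util.collate(queries):
    print(ctx.transition_before(), ctx.ossi.location())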

How to use

usage: list_created_processes.py [-h] [--host HOST] [-p PORT]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)

Known limitations

N/A

Supported versions

REVEN 2.2+

Supported perimeter

Any Windows x64 scenario.

Dependencies

The script requires that the target REVEN scenario have:

  • The Memory History feature replayed.
  • The Fast Search feature replayed.
  • The OSSI feature replayed.
  • An access to the binaries 'kernel32.dll' and 'kernelbase.dll' and their PDB files.

Source

import argparse

import reven2


"""
# List created processes

## Purpose

List all processes created in the trace.

## How to use

```bash
usage: list_created_processes.py [-h] [--host HOST] [-p PORT]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
```

## Known limitations

N/A

## Supported versions

REVEN 2.2+

## Supported perimeter

Any Windows x64 scenario.

## Dependencies

The script requires that the target REVEN scenario have:
  * The Memory History feature replayed.
  * The Fast Search feature replayed.
  * The OSSI feature replayed.
  * An access to the binaries 'kernel32.dll' and 'kernelbase.dll' and their PDB files.
"""


class Process(object):
    def __init__(self, name, pid, tid):
        self.name = name
        self.pid = pid
        self.tid = tid


def created_processes(reven):
    """
    Get all processes that were created during the trace (Windows 64 only).

    This is based on the call to the `CreateProcessInternalW` function of `kernelbase.dll` and `kernel32.dll`.

    ```
    BOOL CreateProcessInternalW(
        LPCWSTR lpApplicationName, (rdx)
        LPWSTR lpCommandLine,      (r8)
        ...,
        LPPROCESS_INFORMATION lpProcessInformation (rsp + 0x58)
    )

    typedef struct _PROCESS_INFORMATION {
        HANDLE hProcess;   (+0x0)
        HANDLE hThread;    (+0x8)
        DWORD dwProcessId; (+0x10)
        DWORD dwThreadId;  (+0x14)
    } PROCESS_INFORMATION, *PPROCESS_INFORMATION, *LPPROCESS_INFORMATION;
    ```

    Dependencies
    ============

    The script requires that the target REVEN scenario have:
      * The Memory History feature replayed.
      * The Fast Search feature replayed.
      * The OSSI feature replayed.
      * An access to the binaries 'kernel32.dll' and 'kernelbase.dll' and their PDB files.
    """
    queries = [
        rvn.trace.search.symbol(symbol)
        for symbol in rvn.ossi.symbols(pattern="CreateProcessInternalW", binary_hint="kernelbase.dll")
    ]
    queries += [
        rvn.trace.search.symbol(symbol)
        for symbol in rvn.ossi.symbols(pattern="CreateProcessInternalW", binary_hint="kernel32.dll")
    ]

    for match in reven2.util.collate(queries):
        call_tr = match.transition_before()
        instruction = call_tr.instruction
        if instruction is None:
            # This case should not happen since the RIP register of the context after an exception transition
            # is generally pointing to exception handling code, not `kernel32!CreateProcessInternalW` or
            # `kernelbase!CreateProcessInternalW` code.
            continue
        if instruction.mnemonic != "call":
            # Certainly comes from a code page fault on the call instruction.
            # Never seen but possible.
            continue

        # Get process name from arguments lpApplicationName or lpCommandLine
        try:
            name = match.deref(
                reven2.arch.x64.rdx,
                reven2.types.Pointer(reven2.types.CString(encoding=reven2.types.Encoding.Utf16, max_size=256)),
            )
        except RuntimeError:
            name = match.deref(
                reven2.arch.x64.r8,
                reven2.types.Pointer(reven2.types.CString(encoding=reven2.types.Encoding.Utf16, max_size=256)),
            )

        # Get pointer to PROCESS_INFORMATION struct
        stack_pointer = match.read(reven2.arch.x64.rsp, reven2.types.Pointer(reven2.types.USize))
        process_info_pointer = match.read(stack_pointer + 0x58, reven2.types.Pointer(reven2.types.USize))

        # Go to the end of the function
        return_tr = call_tr.find_inverse()
        if return_tr is None:
            # Create process does not finish before the end of the trace
            continue

        return_value = return_tr.context_after().read(reven2.arch.x64.rax)
        if return_value == 0:
            # Create process failed
            continue

        # Get PID and TID from PROCESS_INFORMATION structure
        pid = return_tr.context_before().read(process_info_pointer + 0x10, 4)
        tid = return_tr.context_before().read(process_info_pointer + 0x14, 4)

        yield (call_tr, Process(name, pid, tid))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost", help='Reven host, as a string (default: "localhost")')
    parser.add_argument("-p", "--port", type=int, default="13370", help="Reven port, as an int (default: 13370)")
    args = parser.parse_args()

    rvn = reven2.RevenServer(args.host, args.port)
    for transition, process in created_processes(rvn):
        print("#{}: name = {}, pid = {}, tid = {}".format(transition.id, process.name, process.pid, process.tid))

Binary Coverage

Purpose

This script is designed to build the coverage of a binary executed in a REVEN scenario.
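
The coverage is obtained by searching for every context executed inside the chosen binary and counting the symbol of its OSSI location (the full script below groups the counts per address space). A minimal sketch of that search, assuming a connected reven2.RevenServer instance and an illustrative binary pattern:

from collections import defaultdict

import reven2

server = reven2.RevenServer("localhost", 13370)  # illustrative host/port

# Pick the first executed binary matching the (illustrative) pattern
binary = next(server.ossi.executed_binaries("ntoskrnl"))

symbol_hits = defaultdict(int)
for ctx in server.trace.search.binary(binary):
    loc = ctx.ossi.location()
    symbol = loc.symbol.name if loc.symbol is not None else "unknown"
    symbol_hits[symbol] += 1

for symbol, count in sorted(symbol_hits.items(), key=lambda item: -item[1]):
    print(f"{symbol}: {count}")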

How to use

usage: bin_coverage.py [-h] [--host HOST] [-p PORT] binary

positional arguments:
  binary                Binary on which to compute coverage

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)

Known limitations

N/A

Supported versions

REVEN 2.2+

Supported perimeter

Any REVEN scenario.

Dependencies

None.

Source

import argparse
import builtins
from collections import defaultdict

import reven2 as reven

# %% [markdown]
# # Binary Coverage
#
# ## Purpose
#
# This script is designed to build the coverage of a binary executed in a REVEN scenario.
#
# ## How to use
#
# ```bash
# usage: bin_coverage.py [-h] [--host HOST] [-p PORT] binary
#
# positional arguments:
#   binary                Binary on which to compute coverage
#
# optional arguments:
#   -h, --help            show this help message and exit
#   --host HOST           Reven host, as a string (default: "localhost")
#   -p PORT, --port PORT  Reven port, as an int (default: 13370)
# ```
#
# ## Known limitations
#
# N/A
#
# ## Supported versions
#
# REVEN 2.2+
#
# ## Supported perimeter
#
# Any REVEN scenario.
#
# ## Dependencies
#
# None.


def find_and_choose_binary(ossi, requested_binary):
    binaries = list(ossi.executed_binaries(requested_binary))
    if len(binaries) == 0:
        raise RuntimeError('Binary "{}" not executed in the trace.'.format(requested_binary))

    if len(binaries) == 1:
        return binaries[0]

    print('Multiple matches for "{}":'.format(requested_binary))
    for (index, binary) in enumerate(binaries):
        print("{}: {}".format(index, binary.path))
    answer = builtins.input("Please choose one binary: ")
    return binaries[int(answer)]


def compute_binary_coverages(trace, binary):
    coverages = {}
    for ctx in trace.search.binary(binary):
        asid = ctx.read(reven.arch.x64.cr3)
        loc = ctx.ossi.location()
        symbol = "unknown" if loc.symbol is None else loc.symbol.name
        asid_coverage = coverages.setdefault(asid, [loc.base_address, defaultdict(int)])[1]
        asid_coverage[symbol] += 1
    return coverages


def binary_coverage(reven_server, binary):
    binary = find_and_choose_binary(reven_server.ossi, binary)
    return compute_binary_coverages(reven_server.trace, binary)


def print_binary_coverages(coverages):
    for (asid, asid_coverage) in coverages.items():
        print("***** Coverage for CR3 = {:#x}: base address = {:#x} *****\n".format(asid, asid_coverage[0]))
        for (symbol, symbol_coverage) in asid_coverage[1].items():
            print("  {}: {}".format(symbol, symbol_coverage))
        print("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost", help='Reven host, as a string (default: "localhost")')
    parser.add_argument("-p", "--port", type=int, default="13370", help="Reven port, as an int (default: 13370)")
    parser.add_argument("binary", type=str, help="Binary on which to compute coverage")
    args = parser.parse_args()

    reven_server = reven.RevenServer(args.host, args.port)
    coverages = binary_coverage(reven_server, args.binary)
    print_binary_coverages(coverages)

Vulnerability detection

Examples in this section attempt to detect some vulnerabilities present in REVEN scenarios.

Searching for Buffer-Overflow vulnerabilities

This notebook allows searching for potential Buffer-Overflow vulnerabilities in a REVEN trace.
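
At a high level, the notebook pairs allocations with their matching frees (located with the fast-search API), then follows the returned pointer with the taint API and reports dereferences that land just outside the allocated buffer. A condensed sketch of the taint step, assuming a connected server and hypothetical allocation/free contexts (in the notebook they come from the detection code further below):

import reven2
import reven2.preview.taint

server = reven2.RevenServer("127.0.0.1", 13370)  # illustrative host/port

# Hypothetical contexts used for illustration only
alloc_ctx = server.trace.context_before(1000)  # context at the allocation's return
free_ctx = server.trace.context_before(2000)   # context at the matching free (or None)

# Taint the register holding the allocated pointer, forward from the allocation to the free
tainter = reven2.preview.taint.Tainter(server.trace)
taint = tainter.simple_taint(
    tag0="rax",
    from_context=alloc_ctx,
    to_context=free_ctx,
    is_forward=True,
)

# Each access is a point where the tainted pointer (or a value derived from it) is used;
# the notebook inspects the dereferenced addresses of these accesses
for access in taint.accesses(changes_only=False).all():
    print(access.transition, access.transition.context_before().ossi.location())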

Prerequisites

  • This notebook should be run in a jupyter notebook server equipped with a REVEN 2 python kernel. REVEN comes with a jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • This notebook depends on capstone being installed in the REVEN 2 python kernel. To install capstone in the current environment, please execute the capstone cell of this notebook.

Running the notebook

Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.

Note

Although the script is designed to limit false positive results, you may still encounter a few. Please check the results and feel free to report any issue, be it a false positive or a false negative ;-).

Source

# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Searching for Buffer-Overflow vulnerabilities
#
# This notebook allows searching for potential Buffer-Overflow vulnerabilities in a REVEN trace.
#
# ## Prerequisites
#
# - This notebook should be run in a jupyter notebook server equipped with a REVEN 2 python kernel.
#   REVEN comes with a jupyter notebook server accessible with the `Open Python` button in the `Analyze` page of any
#   scenario.
#
# - This notebook depends on `capstone` being installed in the REVEN 2 python kernel.
#   To install capstone in the current environment, please execute the [capstone cell](#Capstone-Installation) of this
#   notebook.
#
#
# ## Running the notebook
#
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.
#
#
# ## Note
#
# Although the script is designed to limit false positive results, you may still encounter a few. Please check
# the results and feel free to report any issue, be it a false positive or a false negative ;-).

# %% [markdown]
# # Capstone Installation
#
# Check for capstone's presence. If missing, attempt to get it from pip

# %%
try:
    import capstone
    print("capstone already installed")
except ImportError:
    print("Could not find capstone, attempting to install it from pip")
    import sys
    import subprocess

    command = [f"{sys.executable}", "-m", "pip", "install", "capstone"]
    p = subprocess.run(command)

    if int(p.returncode) != 0:
        raise RuntimeError("Error installing capstone")

    import capstone  # noqa
    print("Successfully installed capstone")

# %% [markdown]
# # Parameters

# %%
# Server connection

# Host of the REVEN server running the scenario.
# When running this notebook from the Project Manager, '127.0.0.1' should be the correct value.
reven_backend_host = '127.0.0.1'

# Port of the REVEN server running the scenario.
# After starting a REVEN server on your scenario, you can get its port on the Analyze page of that scenario.
reven_backend_port = 13370


# Range control

# First transition considered for the detection of allocation/deallocation pairs
# If set to None, then the first transition of the trace
bof_from_tr = None
# bof_from_tr = 5300000  # ex: CVE-2020-17087 scenario

# First transition **not** considered for the detection of allocation/deallocation pairs
# If set to None, then the last transition of the trace
bof_to_tr = None
# bof_to_tr = 6100000  # ex: CVE-2020-17087 scenario


# Filter control

# Beware that, when filtering, if an allocation happens in the specified process and/or binary,
# the script will fail if the deallocation happens in a different process and/or binary.
# This issue should only happen for allocations in the kernel.

# Specify on which PID the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the PID.
faulty_process_pid = None
# faulty_process_pid = 756  # ex: CVE-2020-17087 scenario
# faulty_process_pid = 466  # ex: CVE-2021-3156 scenario

# Specify on which process name the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the process name.
# If both a process PID and a process name are specified, please make sure that they both
# refer to the same process, otherwise all allocations will be filtered and no results
# will be produced.
faulty_process_name = None
# faulty_process_name = "cve-2020-17087.exe"  # ex: CVE-2020-17087 scenario
# faulty_process_name = "sudoedit"        # ex: CVE-2021-3156 scenario

# Specify on which binary name the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the binary name.
# Only allocations/deallocations taking place in binaries whose filename,
# path or name contains the specified value are kept.
# If filtering on both a process and a binary, please make sure that there are
# allocations taking place in that binary in the selected process, otherwise all
# allocations will be filtered and no result will be produced.
faulty_binary = None
# faulty_binary = "cng.sys"     # ex: CVE-2020-17087 scenario
# faulty_binary = "sudoers.so"  # ex: CVE-2021-3156 scenario


# Address control

# Specify a **physical** address suspected of being faulty here,
# to only test BoF for this specific address, instead of all (filtered) allocations.
# The address should still be returned by an allocation/deallocation pair.
# To get a physical address from a virtual address, find a context where the address
# is mapped, then use `virtual_address.translate(ctx)`.
bof_faulty_physical_address = None
# bof_faulty_physical_address = 0x7c5a1450   # ex: CVE-2020-17087 scenario
# bof_faulty_physical_address = 0x132498cd0  # ex: CVE-2021-3156 scenario

# Allocator control

# The script can use several allocators to find allocation/deallocation pairs.
# The following booleans enable the search for allocations by each of these
# allocators for a scenario.
# Generally, only a single allocator is expected to be enabled for a given
# scenario.

# To add your own allocator, please look at how the two provided allocators were
# added.

# Whether or not to look for windows malloc/free allocation/deallocation pairs.
search_windows_malloc = True
# search_windows_malloc = False  # ex: CVE-2020-17087 scenario
# search_windows_malloc = False  # ex: CVE-2021-3156 scenario

# Whether or not to look for ExAllocatePoolWithTag/ExFreePoolWithTag
# allocation/deallocation pairs.
# This allocator is used by the Windows kernel.
search_pool_allocation = True
# search_pool_allocation = True  # ex: CVE-2020-17087 scenario, kernel scenario
# search_pool_allocation = False  # ex: CVE-2021-3156 scenario

# Whether or not to look for linux malloc/free allocation/deallocation pairs.
search_linux_malloc = False
# search_linux_malloc = False  # ex: CVE-2020-17087 scenario, kernel scenario
# search_linux_malloc = True   # ex: CVE-2021-3156 scenario

# Taint control

# Technical parameter: number of accesses in a taint after which the script gives up on
# that particular taint.
# Because long taints degrade the performance of the script significantly, it is recommended to
# give up on a taint after it exceeds a certain number of operations.
# If you experience missing results, try increasing the value.
# If you experience a very long runtime for the script, try decreasing the value.
# The default value should be appropriate for most cases.
taint_max_length = 100000

# Technical parameter: number of bytes that we consider around an allocated buffer to determine if an access is a BoF
# (or underflow).
# Adjust this value to limit the number of false positives.
bof_overflow_limit = 1024

# %%
from collections import OrderedDict  # noqa: E402
from typing import Dict, List  # noqa: E402

import reven2  # noqa: E402
import reven2.preview.taint  # noqa: E402


# %%
# Python script to connect to this scenario:
server = reven2.RevenServer(reven_backend_host, reven_backend_port)
print(server.trace.transition_count)


# %%
class MemoryRange:
    page_size = 4096
    page_mask = ~(page_size - 1)

    def __init__(self, logical_address, size):
        self.logical_address = logical_address
        self.size = size

        self.pages = [{
            'logical_address': self.logical_address,
            'size': self.size,
            'physical_address': None,
            'ctx_physical_address_mapped': None,
        }]

        # Compute the pages
        while (((self.pages[-1]['logical_address'] & ~MemoryRange.page_mask) + self.pages[-1]['size'] - 1)
                >= MemoryRange.page_size):
            # Compute the size of the new page
            new_page_size = ((self.pages[-1]['logical_address'] & ~MemoryRange.page_mask) + self.pages[-1]['size']
                             - MemoryRange.page_size)

            # Reduce the size of the previous page and create the new one
            self.pages[-1]['size'] -= new_page_size
            self.pages.append({
                'logical_address': self.pages[-1]['logical_address'] + self.pages[-1]['size'],
                'size': new_page_size,
                'physical_address': None,
                'ctx_physical_address_mapped': None,
            })

    def try_translate_first_page(self, ctx):
        if self.pages[0]['physical_address'] is not None:
            return True

        physical_address = reven2.address.LogicalAddress(self.pages[0]['logical_address']).translate(ctx)

        if physical_address is None:
            return False

        self.pages[0]['physical_address'] = physical_address.offset
        self.pages[0]['ctx_physical_address_mapped'] = ctx

        return True

    def try_translate_all_pages(self, ctx):
        return_value = True

        for page in self.pages:
            if page['physical_address'] is not None:
                continue

            physical_address = reven2.address.LogicalAddress(page['logical_address']).translate(ctx)

            if physical_address is None:
                return_value = False
                continue

            page['physical_address'] = physical_address.offset
            page['ctx_physical_address_mapped'] = ctx

        return return_value

    def is_physical_address_range_in_translated_pages(self, physical_address, size):
        for page in self.pages:
            if page['physical_address'] is None:
                continue

            if (
                physical_address >= page['physical_address']
                and physical_address + size <= page['physical_address'] + page['size']
            ):
                return True

        return False

    def __repr__(self):
        return "MemoryRange(0x%x, %d)" % (self.logical_address, self.size)

# Utility to translate to a physical address the address of a buffer that was just allocated
#     - ctx should be the context where the address is located in `rax`
#     - memory_range should be the memory range of the newly allocated buffer
#
# We use the translate API to do the translation, but sometimes the address isn't mapped yet just
# after the allocation. In that case we taint the returned pointer and retry the translation at
# each access of the resulting slice.


def translate_first_page_of_allocation(ctx, memory_range):
    if memory_range.try_translate_first_page(ctx):
        return

    tainter = reven2.preview.taint.Tainter(server.trace)
    taint = tainter.simple_taint(
        tag0="rax",
        from_context=ctx,
        to_context=None,
        is_forward=True
    )

    for access in taint.accesses(changes_only=False).all():
        if memory_range.try_translate_first_page(access.transition.context_after()):
            taint.cancel()
            return

    raise RuntimeError("Couldn't find the physical address of the first page")


# %%
class AllocEvent:
    def __init__(self, memory_range, tr_begin, tr_end):
        self.memory_range = memory_range
        self.tr_begin = tr_begin
        self.tr_end = tr_end


class FreeEvent:
    def __init__(self, logical_address, tr_begin, tr_end):
        self.logical_address = logical_address
        self.tr_begin = tr_begin
        self.tr_end = tr_end


def retrieve_events_for_symbol(
    alloc_dict,
    event_class,
    symbol,
    retrieve_event_info,
    event_filter=None,
):
    for ctx in server.trace.search.symbol(
        symbol,
        from_context=None if bof_from_tr is None else server.trace.context_before(bof_from_tr),
        to_context=None if bof_to_tr is None else server.trace.context_before(bof_to_tr)
    ):
        # We don't want hits on exceptions (code pagefaults, hardware interrupts, etc.)
        if ctx.transition_after().exception is not None:
            continue

        previous_location = (ctx - 1).ossi.location()
        previous_process = (ctx - 1).ossi.process()

        # Filter by process pid/process name/binary name

        # Filter by process pid
        if faulty_process_pid is not None and previous_process.pid != faulty_process_pid:
            continue

        # Filter by process name
        if faulty_process_name is not None and previous_process.name != faulty_process_name:
            continue

        # Filter by binary name / filename / path
        if faulty_binary is not None and faulty_binary not in [
            previous_location.binary.name,
            previous_location.binary.filename,
            previous_location.binary.path
        ]:
            continue

        # Filter the event with the argument filter
        if event_filter is not None:
            if event_filter(ctx.ossi.location(), previous_location):
                continue

        # Retrieve the call/ret
        # The heuristic is that the ret is the end of our function
        #   - If the call is inlined it should be at the end of the caller function, so the ret is the ret of our
        #     function
        #   - If the call isn't inlined, the ret should be the ret of our function
        ctx_call = next(ctx.stack.frames()).creation_transition.context_after()
        ctx_ret = ctx_call.transition_before().find_inverse().context_before()

        # Build the event by reading the needed registers
        if event_class == AllocEvent:
            current_address, size = retrieve_event_info(ctx, ctx_ret)

            # Filter the alloc failing
            if current_address == 0x0:
                continue

            memory_range = MemoryRange(current_address, size)
            try:
                translate_first_page_of_allocation(ctx_ret, memory_range)
            except RuntimeError:
                # If we can't translate the first page we assume that the buffer isn't used because
                # the heuristic to detect the call/ret failed
                continue

            if memory_range.pages[0]['physical_address'] not in alloc_dict:
                alloc_dict[memory_range.pages[0]['physical_address']] = []

            alloc_dict[memory_range.pages[0]['physical_address']].append(
                AllocEvent(
                    memory_range,
                    ctx.transition_after(), ctx_ret.transition_after()
                )
            )
        elif event_class == FreeEvent:
            current_address = retrieve_event_info(ctx, ctx_ret)

            # Filter the free of NULL
            if current_address == 0x0:
                continue

            current_physical_address = reven2.address.LogicalAddress(current_address).translate(ctx).offset

            if current_physical_address not in alloc_dict:
                alloc_dict[current_physical_address] = []

            alloc_dict[current_physical_address].append(
                FreeEvent(
                    current_address,
                    ctx.transition_after(), ctx_ret.transition_after()
                )
            )
        else:
            raise RuntimeError("Unknown event class: %s" % event_class.__name__)


# %%
# %%time

alloc_dict: Dict = {}


# Basic functions to retrieve the arguments.
# They work for the allocation/free functions but won't work for all functions,
# particularly because on x86 we don't handle the size of the arguments
# nor whether they are pushed left to right or right to left.
def retrieve_first_argument(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rcx)
    else:
        esp = ctx.read(reven2.arch.x64.esp)
        return ctx.read(reven2.address.LogicalAddress(esp + 4, reven2.arch.x64.ss), 4)


def retrieve_second_argument(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rdx)
    else:
        esp = ctx.read(reven2.arch.x64.esp)
        return ctx.read(reven2.address.LogicalAddress(esp + 8, reven2.arch.x64.ss), 4)


def retrieve_first_argument_linux(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rdi)
    else:
        raise NotImplementedError("Linux 32bits")


def retrieve_second_argument_linux(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rsi)
    else:
        raise NotImplementedError("Linux 32bits")


def retrieve_return_value(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rax)
    else:
        return ctx.read(reven2.arch.x64.eax)


def retrieve_alloc_info_with_size_as_first_argument(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_first_argument_linux(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument_linux(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_second_argument(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_second_argument(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_second_argument_linux(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_second_argument_linux(ctx_begin)
    )


def retrieve_alloc_info_for_calloc(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument(ctx_begin) * retrieve_second_argument(ctx_begin)
    )


def retrieve_alloc_info_for_calloc_linux(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument_linux(ctx_begin) * retrieve_second_argument_linux(ctx_begin)
    )


def retrieve_free_info_with_address_as_first_argument(ctx_begin, ctx_end):
    return retrieve_first_argument(ctx_begin)


def retrieve_free_info_with_address_as_first_argument_linux(ctx_begin, ctx_end):
    return retrieve_first_argument_linux(ctx_begin)


if search_windows_malloc:
    def filter_in_realloc(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name == "realloc"

    # Search for allocations with malloc
    for symbol in server.ossi.symbols(r'^_?malloc$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_first_argument,
                                   filter_in_realloc)

    # Search for allocations with calloc
    for symbol in server.ossi.symbols(r'^_?calloc(_crt)?$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_for_calloc)

    # Search for deallocations with free
    for symbol in server.ossi.symbols(r'^_?free$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument,
                                   filter_in_realloc)

    # Search for re-allocations with realloc
    for symbol in server.ossi.symbols(r'^_?realloc$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_second_argument)
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument)

if search_pool_allocation:
    # Search for allocations with ExAllocatePool...
    def filter_ex_allocate_pool(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name.startswith("ExAllocatePool")

    for symbol in server.ossi.symbols(r'^ExAllocatePool', binary_hint=r'ntoskrnl.exe'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_second_argument,
                                   filter_ex_allocate_pool)

    # Search for deallocations with ExFreePool...
    def filter_ex_free_pool(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name == "ExFreePool"

    for symbol in server.ossi.symbols(r'^ExFreePool', binary_hint=r'ntoskrnl.exe'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument,
                                   filter_ex_free_pool)

if search_linux_malloc:
    def filter_in_realloc(location, caller_location):
        return (location.binary == caller_location.binary
                and caller_location.symbol is not None
                and caller_location.symbol.name in ["realloc", "__GI___libc_realloc"])

    # Search for allocations with malloc
    for symbol in server.ossi.symbols(r'^((__GI___libc_malloc)|(__libc_malloc))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol,
                                   retrieve_alloc_info_with_size_as_first_argument_linux, filter_in_realloc)

    # Search for allocations with calloc
    for symbol in server.ossi.symbols(r'^((__calloc)|(__libc_calloc))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_for_calloc_linux)

    # Search for deallocations with free
    for symbol in server.ossi.symbols(r'^((__GI___libc_free)|(cfree))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol,
                                   retrieve_free_info_with_address_as_first_argument_linux, filter_in_realloc)

    # Search for re-allocations with realloc
    for symbol in server.ossi.symbols(r'^((__GI___libc_realloc)|(realloc))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol,
                                   retrieve_alloc_info_with_size_as_second_argument_linux)
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol,
                                   retrieve_free_info_with_address_as_first_argument_linux)


# Sort the events per address and event type
for physical_address in alloc_dict.keys():
    alloc_dict[physical_address] = list(sorted(
        alloc_dict[physical_address],
        key=lambda event: (event.tr_begin.id, 0 if isinstance(event, FreeEvent) else 1)
    ))

# Sort the dict by address
alloc_dict = OrderedDict(sorted(alloc_dict.items()))


# %%
def get_alloc_free_pairs(events, errors=None):
    previous_event = None
    for event in events:
        if isinstance(event, AllocEvent):
            if previous_event is None:
                pass
            elif isinstance(previous_event, AllocEvent):
                if errors is not None:
                    errors.append("Two consecutives allocs found")
        elif isinstance(event, FreeEvent):
            if previous_event is None:
                continue
            elif isinstance(previous_event, FreeEvent):
                if errors is not None:
                    errors.append("Two consecutives frees found")
            elif isinstance(previous_event, AllocEvent):
                yield (previous_event, event)
        else:
            assert 0, ("Unknown event type: %s" % type(event))

        previous_event = event

    if isinstance(previous_event, AllocEvent):
        yield (previous_event, None)


# %%
# %%time

# Basic checks of the events

for physical_address, events in alloc_dict.items():
    for event in events:
        if not isinstance(event, AllocEvent) and not isinstance(event, FreeEvent):
            raise RuntimeError("Unknown event type: %s" % type(event))

    errors: List[str] = []
    for (alloc_event, free_event) in get_alloc_free_pairs(events, errors):
        # Check the uniformity of the logical address between the alloc and the free
        if free_event is not None and alloc_event.memory_range.logical_address != free_event.logical_address:
            errors.append("Phys:0x%x: Alloc #%d - Free #%d with different logical address: 0x%x != 0x%x" % (
                physical_address,
                alloc_event.tr_begin.id, free_event.tr_begin.id,
                alloc_event.memory_range.logical_address, free_event.logical_address))

        # Check size of 0x0
        if alloc_event.memory_range.size == 0x0 or alloc_event.memory_range.size is None:
            if free_event is None:
                errors.append("Phys:0x%x: Alloc #%d - Free N/A with weird size %s" % (
                    physical_address,
                    alloc_event.tr_begin.id,
                    alloc_event.memory_range.size))
            else:
                errors.append("Phys:0x%x: Alloc #%d - Free #%d with weird size %s" % (
                    physical_address,
                    alloc_event.tr_begin.id, free_event.tr_begin.id,
                    alloc_event.memory_range.size))

    if len(errors) > 0:
        print("Phys:0x%x: Error(s) detected:" % (physical_address))
        for error in errors:
            print("    - %s" % error)

# %%
# Print the events

for physical_address, events in alloc_dict.items():
    print("Phys:0x%x" % (physical_address))
    print("    Events:")
    for event in events:
        if isinstance(event, AllocEvent):
            print("        - Alloc at #%d (0x%x of size 0x%x)" % (event.tr_begin.id,
                  event.memory_range.logical_address, event.memory_range.size))
        elif isinstance(event, FreeEvent):
            print("        - Free at #%d (0x%x)" % (event.tr_begin.id, event.logical_address))

    print("    Pairs:")
    for (alloc_event, free_event) in get_alloc_free_pairs(events):
        if free_event is None:
            print("    - Allocated at #%d (0x%x of size 0x%x) and freed at N/A" % (alloc_event.tr_begin.id,
                  alloc_event.memory_range.logical_address, alloc_event.memory_range.size))
        else:
            print("    - Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (alloc_event.tr_begin.id,
                  alloc_event.memory_range.logical_address, alloc_event.memory_range.size, free_event.tr_begin.id,
                  free_event.logical_address))

    print()


# %%
# Capstone utilities
def get_reven_register_from_name(name):
    for reg in reven2.arch.helpers.x64_registers():
        if reg.name == name:
            return reg

    raise RuntimeError("Unknown register: %s" % name)


def compute_dereferenced_address(ctx, cs_insn, cs_op):
    dereferenced_address = 0

    if cs_op.value.mem.base != 0:
        dereferenced_address += ctx.read(get_reven_register_from_name(cs_insn.reg_name(cs_op.value.mem.base)))

    if cs_op.value.mem.index != 0:
        dereferenced_address += (cs_op.value.mem.scale
                                 * ctx.read(get_reven_register_from_name(cs_insn.reg_name(cs_op.value.mem.index))))

    dereferenced_address += cs_op.value.mem.disp

    return dereferenced_address & 0xFFFFFFFFFFFFFFFF


# %%
# Function to compute the range of the intersection between two ranges
def range_intersect(r1, r2):
    return range(max(r1.start, r2.start), min(r1.stop, r2.stop)) or None


def bof_analyze_function(physical_address, alloc_events):
    bof_count = 0

    # Setup capstone
    md_64 = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
    md_64.detail = True

    md_32 = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_32)
    md_32.detail = True

    errors = []
    for (alloc_event, free_event) in get_alloc_free_pairs(alloc_events, errors):
        # Setup the first taint [alloc; free]
        tainter = reven2.preview.taint.Tainter(server.trace)
        taint = tainter.simple_taint(
            tag0="rax" if alloc_event.tr_end.context_before().is64b() else "eax",
            from_context=alloc_event.tr_end.context_before(),
            to_context=free_event.tr_begin.context_before() + 1 if free_event is not None else None,
            is_forward=True
        )

        # Iterate on the slice
        access_count = 0
        for access in taint.accesses(changes_only=False).all():
            access_count += 1

            if access_count > taint_max_length:
                if free_event is None:
                    print("Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at N/A" % (physical_address,
                          alloc_event.tr_begin.id, alloc_event.memory_range.logical_address,
                          alloc_event.memory_range.size))
                else:
                    print("Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)"
                          % (physical_address, alloc_event.tr_begin.id, alloc_event.memory_range.logical_address,
                             alloc_event.memory_range.size, free_event.tr_begin.id, free_event.logical_address))
                print("    Warning: Allocation skipped: taint stopped after %d accesses" % access_count)
                print()
                break

            ctx = access.transition.context_before()
            md = md_64 if ctx.is64b() else md_32
            cs_insn = next(md.disasm(access.transition.instruction.raw, access.transition.instruction.size))

            # Skip `lea` instructions as they are not really memory reads/writes; the taint
            # will propagate anyway so that we will see the dereference of the computed value
            if cs_insn.mnemonic == "lea":
                continue

            registers_in_state = {}
            for reg_slice, _ in access.state_before().tainted_registers():
                registers_in_state[reg_slice.register.name] = reg_slice

            for cs_op in cs_insn.operands:
                if cs_op.type != capstone.x86.X86_OP_MEM:
                    continue

                bof_reg = None

                if cs_op.value.mem.base != 0:
                    base_reg_name = cs_insn.reg_name(cs_op.value.mem.base)
                    if base_reg_name in registers_in_state:
                        bof_reg = registers_in_state[base_reg_name]

                if bof_reg is None and cs_op.value.mem.index != 0:
                    index_reg_name = cs_insn.reg_name(cs_op.value.mem.index)
                    if index_reg_name in registers_in_state:
                        bof_reg = registers_in_state[index_reg_name]

                if bof_reg is None:
                    continue

                dereferenced_address = compute_dereferenced_address(ctx, cs_insn, cs_op)

                # We only check on the translated pages as the taint won't return an access with a pagefault
                # so the dereferenced address should be translated

                dereferenced_physical_address = reven2.address.LogicalAddress(dereferenced_address).translate(ctx)

                if dereferenced_physical_address is None:
                    continue

                operand_range = range(dereferenced_address, dereferenced_address + cs_op.size)
                before_buffer_range = range(alloc_event.memory_range.logical_address - bof_overflow_limit,
                                            alloc_event.memory_range.logical_address)
                after_buffer_range = range(alloc_event.memory_range.logical_address + alloc_event.memory_range.size,
                                           alloc_event.memory_range.logical_address + alloc_event.memory_range.size
                                           + bof_overflow_limit)

                if (range_intersect(operand_range, before_buffer_range) is None
                   and range_intersect(operand_range, after_buffer_range) is None):
                    continue

                if free_event is None:
                    print("Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at N/A" % (physical_address,
                          alloc_event.tr_begin.id, alloc_event.memory_range.logical_address,
                          alloc_event.memory_range.size))
                else:
                    print("Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)"
                          % (physical_address, alloc_event.tr_begin.id, alloc_event.memory_range.logical_address,
                             alloc_event.memory_range.size, free_event.tr_begin.id, free_event.logical_address))
                print("    BOF coming from reg %s[%d-%d] leading to dereferenced address = 0x%x"
                      % (bof_reg.register.name, bof_reg.begin, bof_reg.end, dereferenced_address))
                print("    ", end="")
                print(access.transition, end=" ")
                print(ctx.ossi.location())
                print()

                bof_count += 1

    if len(errors) > 0:
        print("Phys:0x%x: Error(s) detected:" % (physical_address))
        for error in errors:
            print("    - %s" % error)

    return bof_count


# %%
# %%time

bof_count = 0

if bof_faulty_physical_address is None:
    for physical_address, alloc_events in alloc_dict.items():
        bof_count += bof_analyze_function(physical_address, alloc_events)

else:
    if bof_faulty_physical_address not in alloc_dict:
        raise KeyError("The passed physical address was not detected during the allocation search")
    bof_count += bof_analyze_function(bof_faulty_physical_address, alloc_dict[bof_faulty_physical_address])


print("---------------------------------------------------------------------------------")
bof_begin_range = "the beginning of the trace" if bof_from_tr is None else "#{}".format(bof_from_tr)
bof_end_range = "the end of the trace" if bof_to_tr is None else "#{}".format(bof_to_tr)
bof_range = ("on the whole trace" if bof_from_tr is None and bof_to_tr is None else
             "between {} and {}".format(bof_begin_range, bof_end_range))

bof_range_size = server.trace.transition_count
if bof_from_tr is not None:
    bof_range_size -= bof_from_tr
if bof_to_tr is not None:
    bof_range_size -= server.trace.transition_count - bof_to_tr

if bof_faulty_physical_address is None:
    searched_memory_addresses = "with {} searched memory addresses".format(len(alloc_dict))
else:
    searched_memory_addresses = "on {:#x}".format(bof_faulty_physical_address)

print("{} BOF(s) found {} ({} transitions) {}".format(
    bof_count, bof_range, bof_range_size, searched_memory_addresses
))
print("---------------------------------------------------------------------------------")

Detecting critical section deadlocks

This notebook checks for potential deadlocks that may occur when using critical sections as the synchronization primitive.
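
Lock and unlock events are located with the fast-search API on the RtlEnterCriticalSection and RtlLeaveCriticalSection symbols of ntdll.dll, and the critical section handle is read from rcx at the call site. A minimal sketch of that lookup, assuming a connected reven2.RevenServer instance (host and port are illustrative):

import reven2
from reven2.arch import x64

server = reven2.RevenServer("127.0.0.1", 13370)  # illustrative host/port

ntdll = next(server.ossi.executed_binaries('^c:/windows/system32/ntdll.dll$'))
rtl_enter = next(ntdll.symbols("^RtlEnterCriticalSection$"))

for ctx in server.trace.search.symbol(rtl_enter):
    # At the entry of RtlEnterCriticalSection, rcx holds the critical section handle
    handle = ctx.read(x64.rcx)
    print(f"{ctx}: lock on critical section {handle:#x}")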

Prerequisites

  • This notebook should be run in a jupyter notebook server equipped with a REVEN 2 python kernel. REVEN comes with a jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • The following resources must be replayed for the analyzed scenario:
  • Trace
  • OSSI
  • Fast Search

Limits

  • Only supports RtlEnterCriticalSection and RtlLeaveCriticalSection as the lock (resp. unlock) primitive.
  • Locks and unlocks must not be nested: a critical section can be locked and later unlocked, but it must not be locked again while it is already held, e.g.
  • lock A, lock B, unlock A, lock A => OK
  • lock A, lock B, lock A => not OK (since A is locked again)

Running

Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.

Source

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Detecting critical section deadlocks
# This notebook checks for potential deadlocks that may occur when using critical sections as the synchronization
# primitive.
#
# ## Prerequisites
#  - This notebook should be run in a jupyter notebook server equipped with a REVEN 2 python kernel.
# REVEN comes with a jupyter notebook server accessible with the `Open Python` button
# in the `Analyze` page of any scenario.
#  - The following resources must be replayed for the analyzed scenario:
#     - Trace
#     - OSSI
#     - Fast Search
#
# ## Limits
#  - Only supports `RtlEnterCriticalSection` and `RtlLeaveCriticalSection` as the lock (resp. unlock) primitive.
#  - Locks and unlocks must not be nested: a critical section can be locked and later unlocked, but it must not be
# locked again while it is already held, e.g.
#     - lock A, lock B, unlock A, lock A => OK
#     - lock A, lock B, lock A => not OK (since A is locked again)
#
# ## Running
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.

# %%
# For Python's type annotation
from typing import Dict, Iterator, List, Optional, Set, Tuple

# Reven specific
import reven2
from reven2.address import LogicalAddress
from reven2.arch import x64
from reven2.trace import Context, Trace, Transition
from reven2.util import collate

# %% [markdown]
# # Parameters

# %%
# Host and port of the REVEN server running the scenario
host = '127.0.0.1'
port = 35083


# The PID and (or) the name of the binary of interest: if the binary name is given (not None),
# only locks and unlocks called directly from the binary are counted.
pid = 2044
binary = None

# The begin and end transition numbers between which the deadlock detection runs.
begin_trans_id = None  # if None, start from the first transition of the trace
end_trans_id = None    # if None, stop at the last transition


# %%
# Helper class which wraps Reven's runtime objects and give methods helping get information about calls to
# RtlEnterCriticalSection and RtlLeaveCriticalSection
class RuntimeHelper:
    def __init__(self, host: str, port: int):
        try:
            server = reven2.RevenServer(host, port)
        except RuntimeError:
            raise RuntimeError(f'Cannot connect to the scenario at {host}:{port}')

        self.trace = server.trace
        self.ossi = server.ossi
        self.search_symbol = self.trace.search.symbol
        self.search_binary = self.trace.search.binary

        try:
            ntdll = next(self.ossi.executed_binaries('^c:/windows/system32/ntdll.dll$'))
        except StopIteration:
            raise RuntimeError('ntdll.dll not found')

        try:
            self.__rtl_enter_critical_section = next(ntdll.symbols("^RtlEnterCriticalSection$"))
            self.__rtl_leave_critical_section = next(ntdll.symbols("^RtlLeaveCriticalSection$"))
        except StopIteration:
            raise RuntimeError('Rtl(Enter|Leave)CriticalSection symbol not found')

    def get_critical_section_locks(self, from_context: Context, to_context: Context) -> Iterator[Context]:
        # For ease, if the from_context is the first context of the trace and also the beginning of a lock,
        # then omit it since the caller is unknown
        if from_context == self.trace.first_context:
            if from_context == self.trace.last_context:
                return
            else:
                from_context = from_context + 1

        for ctxt in self.search_symbol(self.__rtl_enter_critical_section, from_context, to_context):
            # Any context corresponds to the entry of RtlEnterCriticalSection; since the first context
            # does not count, decreasing by 1 is safe.
            yield ctxt - 1

    def get_critical_section_unlocks(self, from_context: Context, to_context: Context) -> Iterator[Context]:
        # For ease, if the from_context is the first context of the trace and also the beginning of an unlock,
        # then omit it since the caller is unknown
        if from_context == self.trace.first_context:
            if from_context == self.trace.last_context:
                return
            else:
                from_context = from_context + 1

        for ctxt in self.search_symbol(self.__rtl_leave_critical_section, from_context, to_context):
            # Any context corresponds to the entry of RtlLeaveCriticalSection; since the first context
            # does not count, decreasing by 1 is safe.
            yield ctxt - 1

    @staticmethod
    def get_critical_section_handle(ctxt: Context) -> int:
        return ctxt.read(x64.rcx)

    @staticmethod
    def thread_id(ctxt: Context) -> int:
        return ctxt.read(LogicalAddress(0x48, x64.gs), 4)

    @staticmethod
    def is_kernel_mode(ctxt: Context) -> bool:
        return ctxt.read(x64.cs) & 0x3 == 0


# %%
# Find the lower/upper bound of contexts on which the deadlock detection runs
def find_begin_end_context(sco: RuntimeHelper, pid: int, binary: Optional[str],
                           begin_id: Optional[int], end_id: Optional[int]) -> Tuple[Context, Context]:
    begin_ctxt = None
    if begin_id is not None:
        try:
            begin_trans = sco.trace.transition(begin_id)
            begin_ctxt = begin_trans.context_after()
        except IndexError:
            begin_ctxt = None

    if begin_ctxt is None:
        if binary is not None:
            for name in sco.ossi.executed_binaries(binary):
                for ctxt in sco.search_binary(name):
                    ctx_process = ctxt.ossi.process()
                    assert ctx_process is not None

                    if ctx_process.pid == pid:
                        begin_ctxt = ctxt
                        break
                if begin_ctxt is not None:
                    break

    if begin_ctxt is None:
        begin_ctxt = sco.trace.first_context

    end_ctxt = None
    if end_id is not None:
        try:
            end_trans = sco.trace.transition(end_id)
            end_ctxt = end_trans.context_before()
        except IndexError:
            end_ctxt = None

    if end_ctxt is None:
        end_ctxt = sco.trace.last_context

    if (end_ctxt <= begin_ctxt):
        raise RuntimeError("The begin transition must be smaller than the end.")

    return (begin_ctxt, end_ctxt)


# Get all execution contexts of a given process
def find_process_ranges(sco: RuntimeHelper, pid: int, first_ctxt: Context, last_context: Context) \
        -> Iterator[Tuple[Context, Context]]:
    ctxt_low = first_ctxt
    ctxt_high: Optional[Context] = None

    while True:
        current_process = ctxt_low.ossi.process()
        assert current_process is not None

        current_pid = current_process.pid
        ctxt_high = ctxt_low.find_register_change(x64.cr3, is_forward=True)

        if ctxt_high is None:
            if current_pid == pid:
                if ctxt_low < last_context - 1:
                    yield (ctxt_low, last_context - 1)
            break

        if ctxt_high >= last_context:
            break

        if current_pid == pid:
            yield (ctxt_low, ctxt_high)

        ctxt_low = ctxt_high + 1


# Starting from a transition, return the first transition that executes an instruction, or None if there isn't one.
def ignore_non_instructions(trans: Transition, trace: Trace) -> Optional[Transition]:
    while trans.instruction is None:
        if trans == trace.last_transition:
            return None
        trans = trans + 1

    return trans


# Extract user mode only context ranges from a context range (which may include also kernel mode ranges)
def find_usermode_ranges(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Context) \
        -> Iterator[Tuple[Context, Context]]:
    trans = ignore_non_instructions(ctxt_low.transition_after(), sco.trace)
    if trans is None:
        return
    ctxt_current = trans.context_before()

    while ctxt_current < ctxt_high:
        ctxt_next = ctxt_current.find_register_change(x64.cs, is_forward=True)

        if not RuntimeHelper.is_kernel_mode(ctxt_current):
            if ctxt_next is None or ctxt_next > ctxt_high:
                yield (ctxt_current, ctxt_high)
                break
            else:
                # It's safe to decrease ctxt_next by 1 because it was obtained from a forward find_register_change
                yield (ctxt_current, ctxt_next - 1)

        if ctxt_next is None:
            break

        ctxt_current = ctxt_next


# Get user mode only execution contexts of a given process
def find_process_usermode_ranges(trace: RuntimeHelper, pid: int, first_ctxt: Context, last_ctxt: Context) \
        -> Iterator[Tuple[Context, Context]]:
    for (ctxt_low, ctxt_high) in find_process_ranges(trace, pid, first_ctxt, last_ctxt):
        usermode_ranges = find_usermode_ranges(trace, ctxt_low, ctxt_high)
        for usermode_range in usermode_ranges:
            yield usermode_range


def context_is_in_binary(ctxt: Optional[Context], binary: Optional[str]) -> bool:
    if ctxt is None:
        return False

    if binary is None:
        return True

    ctxt_loc = ctxt.ossi.location()
    if ctxt_loc is None:
        return False

    ctxt_binary = ctxt_loc.binary
    if ctxt_binary is not None:
        return (binary in [ctxt_binary.name, ctxt_binary.filename, ctxt_binary.path])

    return False


# Get locks (i.e. RtlEnterCriticalSection calls) made by the binary in a range of contexts (defined by ctxt_low and
# ctxt_high).
# Note that for a process (of a binary), there are also locks called by the libraries loaded by the PE loader; such
# calls are considered "uninteresting" when a binary name is given.
def get_in_binary_locks(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Context, binary: Optional[str]) \
        -> Iterator[Context]:
    for ctxt in sco.get_critical_section_locks(ctxt_low, ctxt_high):
        if context_is_in_binary(ctxt, binary):
            yield ctxt


# Get unlocks (i.e. RtlLeaveCriticalSection calls) made by the binary in a range of contexts (defined by ctxt_low and
# ctxt_high).
# Note that for a process (of a binary), there are also unlocks called by the libraries loaded by the PE loader; such
# calls are considered "uninteresting" when a binary name is given.
def get_in_binary_unlocks(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Context, binary: Optional[str]) \
        -> Iterator[Context]:
    for ctxt in sco.get_critical_section_unlocks(ctxt_low, ctxt_high):
        if context_is_in_binary(ctxt, binary):
            yield ctxt


# Sort lock and unlock contexts (called in a range of contexts) into chronological order.
# Return a generator of pairs (bool, Context): True for a lock, False for an unlock.
def get_in_binary_locks_unlocks(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Context, binary: Optional[str]) \
        -> Iterator[Tuple[bool, Context]]:
    def generate_locks():
        for ctxt in get_in_binary_locks(sco, ctxt_low, ctxt_high, binary):
            yield (True, ctxt)

    def generate_unlocks():
        for ctxt in get_in_binary_unlocks(sco, ctxt_low, ctxt_high, binary):
            yield (False, ctxt)

    return collate([generate_locks(), generate_unlocks()], key=lambda bool_context: bool_context[1])


# Generate all locks and unlocks called by the binary
def get_thread_usermode_in_binary_locks_unlocks(sco: RuntimeHelper, ranges: List[Tuple[Context, Context]],
                                                binary: Optional[str]) -> Iterator[Tuple[bool, Context]]:
    for (ctxt_low, ctxt_high) in ranges:
        for lock_unlock in get_in_binary_locks_unlocks(sco, ctxt_low, ctxt_high, binary):
            yield lock_unlock


# %%
# Check for RAII in locks/unlocks (i.e. critical sections are unlocked in the reverse order of their locking), and
# warn when RAII is violated, since following it is good practice.
#
# In synchronization with critical sections, RAII is an idiomatic technique used to keep the lock order consistent,
# thereby avoiding deadlocks as long as there is a single total order of locks between threads. For example, the
# following threads, where the lock/unlock of critical sections A and B follows RAII,
#  - thread 0: lock A, lock B, unlock B, unlock A, lock A, lock B
#  - thread 1: lock A, lock B, unlock B, unlock A, lock A, lock B
# are deadlock free. But
#  - thread 0: lock A, lock B, unlock A, lock A, unlock B
#  - thread 1: lock A, lock B, unlock A, lock A, unlock B
# can deadlock. Consider the interleaving:
#  - lock A, lock B, unlock A (thread 0), lock A (thread 1)
# Now thread 1 tries to lock B (but cannot, since B is already locked by thread 0), and thread 0 cannot unlock B
# either, since it needs to lock A first (but cannot, since A is already locked by thread 1).
#
# Note that RAII cannot guarantee deadlock-free synchronization if the condition about a single total order of
# critical sections is not satisfied.
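#
# A minimal illustrative sketch (not part of this notebook's analysis): with Python threading.Lock objects
# standing in for critical sections, the non-RAII ordering above can hang, which is exactly the situation
# this check tries to surface. It is left commented out on purpose, since running it may deadlock:
#
#     import threading
#
#     A, B = threading.Lock(), threading.Lock()
#
#     def worker():
#         A.acquire(); B.acquire()
#         A.release()      # A is released before B: violates the RAII ordering
#         A.acquire()      # may block forever if the other thread holds A and waits on B
#         B.release(); A.release()
#
#     for t in [threading.Thread(target=worker) for _ in range(2)]:
#         t.start()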
def check_intrathread_lock_unlock_matching(trace: RuntimeHelper, thread_ranges: List[Tuple[Context, Context]],
                                           binary: Optional[str]):
    locks_unlocks = get_thread_usermode_in_binary_locks_unlocks(trace, thread_ranges, binary)

    # lock/unlock should (though are not required to) follow RAII
    corresponding_lock_stack: List[Context] = []
    mismatch_lock_unlock_pcs: Set[Tuple[int, int]] = set()
    mismatch_unlocks_pcs: Set[int] = set()
    ok = True
    for (is_lock, ctxt) in locks_unlocks:
        if is_lock:
            # push lock context
            corresponding_lock_stack.append(ctxt)
        else:
            if corresponding_lock_stack:
                last_lock_ctxt = corresponding_lock_stack[-1]

                last_lock_handle = RuntimeHelper.get_critical_section_handle(last_lock_ctxt)
                current_unlock_handle = RuntimeHelper.get_critical_section_handle(ctxt)

                if last_lock_handle == current_unlock_handle:
                    # lock and unlock on the same critical section
                    corresponding_lock_stack.pop()
                else:
                    # It's safe to decrease by 1 since the first context of the trace is never counted as a lock
                    # nor unlock (c.f. RuntimeHelper::get_critical_section_(locks|unlocks)).
                    in_binary_lock_pc = last_lock_ctxt.read(x64.rip)
                    in_binary_unlock_pc = ctxt.read(x64.rip)

                    if (in_binary_lock_pc, in_binary_unlock_pc) in mismatch_lock_unlock_pcs:
                        continue

                    mismatch_lock_unlock_pcs.add((in_binary_lock_pc, in_binary_unlock_pc))

                    print(f'Warning:\n\t#{last_lock_ctxt.transition_after().id}: lock at 0x{in_binary_lock_pc:x} \
(on critical section handle 0x{last_lock_handle:x}) followed by\n\t#{ctxt.transition_after().id}: \
unlock at 0x{in_binary_unlock_pc:x} (on different critical section handle 0x{current_unlock_handle:x})')

                    ok = False
            else:
                in_binary_unlock_pc = ctxt.read(x64.rip)
                if in_binary_unlock_pc in mismatch_unlocks_pcs:
                    continue

                mismatch_unlocks_pcs.add(in_binary_unlock_pc)

                current_unlock_handle = RuntimeHelper.get_critical_section_handle(ctxt)
                print(f'Warning:\n\t#{ctxt.transition_after().id}: unlock at 0x{in_binary_unlock_pc:x} \
(on critical section handle 0x{current_unlock_handle:x}) without any lock')
                ok = False
    if ok:
        print('OK')


# Build a dependency graph of locks: a lock A is followed by a lock B if A is still locked when B is locked.
# For example:
#  - lock A, lock B => A followed by B
#  - lock A, unlock A, lock B => A is not followed by B (since A is already unlocked when B is locked)
def build_locks_unlocks_order_graph_next(locks_unlocks: Iterator[Tuple[bool, Context]]) \
        -> Tuple[Dict[int, List[int]], Dict[Tuple[int, int], List[Tuple[Context, Context]]]]:

    order_graph: Dict[int, List[int]] = {}
    order_graph_label: Dict[Tuple[int, int], List[Tuple[Context, Context]]] = {}

    lock_stack: List[Context] = []
    for (is_lock, ctxt) in locks_unlocks:
        if not lock_stack:
            if is_lock:
                lock_stack.append(ctxt)
            continue

        current_lock_unlock_handle = RuntimeHelper.get_critical_section_handle(ctxt)
        current_lock_unlock_threadid = RuntimeHelper.thread_id(ctxt)
        if not is_lock:
            # looking for the last lock in the stack
            i = len(lock_stack) - 1
            while True:
                lock_handle_i = RuntimeHelper.get_critical_section_handle(lock_stack[i])
                lock_threadid_i = RuntimeHelper.thread_id(lock_stack[i])
                if lock_handle_i == current_lock_unlock_handle and lock_threadid_i == current_lock_unlock_threadid:
                    del lock_stack[i]
                    break
                if i == 0:
                    break
                i -= 1
            continue

        last_lock_ctxt = lock_stack[-1]
        # check if the last lock and the current lock are in the same thread
        if RuntimeHelper.thread_id(last_lock_ctxt) == RuntimeHelper.thread_id(ctxt):
            # create the edge: last_lock -> current_lock
            last_lock_handle = RuntimeHelper.get_critical_section_handle(last_lock_ctxt)
            if last_lock_handle not in order_graph:
                order_graph[last_lock_handle] = []
            order_graph[last_lock_handle].append(current_lock_unlock_handle)

            # create (or update) the label of the edge
            if (last_lock_handle, current_lock_unlock_handle) not in order_graph_label:
                order_graph_label[(last_lock_handle, current_lock_unlock_handle)] = []
            order_graph_label[(last_lock_handle, current_lock_unlock_handle)].append((last_lock_ctxt, ctxt))

        lock_stack.append(ctxt)

    return (order_graph, order_graph_label)


# Check if there are cycles in the lock dependency graph, such a cycle is considered a potential deadlock.
def check_order_graph_cycle(graph: Dict[int, List[int]], labels: Dict[Tuple[int, int], List[Tuple[Context, Context]]]):
    def dfs(path: List[Tuple[int, int, int]], starting_node: int, visited_nodes: Set[int]) \
            -> Optional[List[List[Tuple[int, int, int]]]]:
        if starting_node not in graph:
            return None

        next_nodes_to_visit = set(graph[starting_node]) - visited_nodes
        if not next_nodes_to_visit:
            return None

        nodes_on_path = set()
        tids_on_path = set()
        for (hd, tl, tid) in path:
            nodes_on_path.add(hd)
            nodes_on_path.add(tl)
            tids_on_path.add(tid)

        # check if we can build a cycle of locks by trying to visit a node
        back_nodes = next_nodes_to_visit & nodes_on_path
        for node in back_nodes:
            back_node = node
            tids_starting_back = set()
            for (ctxt_hd, _) in labels[(starting_node, back_node)]:
                tids_starting_back.add(RuntimeHelper.thread_id(ctxt_hd))

            # sub-path starting from the back-node to the starting-node
            sub_path = []
            sub_path_tids = set()
            for (hd, tl, tid) in path:
                if hd == node:
                    sub_path.append((hd, tl, tid))
                    sub_path_tids.add(tid)
                    node = tl

                if tl == starting_node:
                    diff_tids = tids_starting_back - sub_path_tids
                    # there is an edge whose TID is not on the sub-path yet
                    cycles = []
                    if diff_tids:
                        for tid in diff_tids:
                            for (ctxt_hd, _) in labels[(starting_node, back_node)]:
                                if RuntimeHelper.thread_id(ctxt_hd) == tid:
                                    sub_path.append((starting_node, back_node, tid))
                                    cycles.append(sub_path)
                                    break
                        return cycles
                    else:
                        return None

        for next_node in next_nodes_to_visit:
            tids = set()
            for (ctxt_hd, _,) in labels[(starting_node, next_node)]:
                tid = RuntimeHelper.thread_id(ctxt_hd)
                tids.add(tid)

            tids_to_visit = tids - tids_on_path
            if tids_to_visit:
                for tid_to_visit in tids_to_visit:
                    for (ctxt_hd, _) in labels[(starting_node, next_node)]:
                        if RuntimeHelper.thread_id(ctxt_hd) == tid_to_visit:
                            next_path = path
                            next_path.append((starting_node, next_node, tid_to_visit))
                            visited_nodes.add(next_node)
                            some_cycles = dfs(next_path, next_node, visited_nodes)
                            if some_cycles is not None:
                                return some_cycles

        return None

    def compare_cycles(c0: List[Tuple[int, int, int]], c1: List[Tuple[int, int, int]]) -> bool:
        if len(c0) != len(c1):
            return False

        if len(c0) == 0:
            return True

        def circular_generator(c: List[Tuple[int, int, int]], e: Tuple[int, int, int]) \
                -> Optional[Iterator[Tuple[int, int, int]]]:
            clen = len(c)
            i = 0
            for elem in c:
                if elem == e:
                    while True:
                        yield c[i]
                        i = i + 1
                        if i == clen:
                            i = 0
                i = i + 1

            return None

        c0_gen = circular_generator(c0, c0[0])
        c1_gen = circular_generator(c1, c0[0])

        if c1_gen is None or c0_gen is None:
            return False

        i = 0
        while i < len(c0):
            e0 = next(c0_gen)
            e1 = next(c1_gen)
            if e0 != e1:
                return False
            i = i + 1

        return True

    ok = True
    distinct_cycles: List[List[Tuple[int, int, int]]] = []
    for node in graph:
        cycles = dfs([], node, set())
        if cycles is None or not cycles:
            continue

        for cycle in cycles:
            duplicated = False
            if not distinct_cycles:
                duplicated = False
            else:
                for ec in distinct_cycles:
                    if compare_cycles(ec, cycle):
                        duplicated = True
                        break
            if duplicated:
                continue

            distinct_cycles.append(cycle)

            print('Potential deadlock(s):')
            for (node, next_node, tid) in cycle:
                label = labels[(node, next_node)]
                distinct_labels: Set[Tuple[int, int]] = set()
                for (node_ctxt, next_node_ctxt) in label:
                    in_binary_lock_pc = (node_ctxt - 1).read(x64.rip)
                    in_binary_next_lock_pc = (next_node_ctxt - 1).read(x64.rip)

                    if (in_binary_lock_pc, in_binary_next_lock_pc) in distinct_labels:
                        continue

                    distinct_labels.add((in_binary_lock_pc, in_binary_next_lock_pc))

                    print(f'\t#{node_ctxt.transition_after().id}: lock at 0x{in_binary_lock_pc:x} \
(on thread {RuntimeHelper.thread_id(node_ctxt)}, critical section handle \
0x{RuntimeHelper.get_critical_section_handle(node_ctxt):x}) followed by\n\t\
#{next_node_ctxt.transition_after().id}: lock at 0x{in_binary_next_lock_pc:x} \
(on thread {RuntimeHelper.thread_id(next_node_ctxt)}, \
critical section handle 0x{RuntimeHelper.get_critical_section_handle(next_node_ctxt):x})')
                    ok = False
                print(
                    '\t=============================================================='
                )
    if ok:
        print('Not found.')
    return None


# Get user mode locks (and unlocks) called by the binary, build the dependency graph, then check for cycles.
def check_lock_cycle(trace: RuntimeHelper, threads_ranges: List[Tuple[Context, Context]], binary: Optional[str]):
    locks_unlocks = get_thread_usermode_in_binary_locks_unlocks(trace, threads_ranges, binary)
    (order_graph, order_graph_labels) = build_locks_unlocks_order_graph_next(locks_unlocks)
    check_order_graph_cycle(order_graph, order_graph_labels)


# Combination of checking lock/unlock RAII and deadlock
def detect_deadlocks(trace: RuntimeHelper, pid: int, binary: Optional[str],
                     first_id: Optional[int], last_id: Optional[int]):
    (first_context, last_context) = find_begin_end_context(trace, pid, binary, first_id, last_id)
    process_ranges = list(find_process_usermode_ranges(trace, pid, first_context, last_context))
    thread_ranges: Dict[int, List[Tuple[Context, Context]]] = {}

    for (ctxt_low, ctxt_high) in process_ranges:
        tid = RuntimeHelper.thread_id(ctxt_low)
        if tid not in thread_ranges:
            thread_ranges[tid] = []
        thread_ranges[tid].append((ctxt_low, ctxt_high))

    for (tid, ranges) in thread_ranges.items():
        print('\n============ checking lock/unlock matching on thread {} ============'.format(tid))
        check_intrathread_lock_unlock_matching(trace, ranges, binary)

    print('\n\n============ checking potential deadlocks on process ============')
    check_lock_cycle(trace, process_ranges, binary)


# %%
trace = RuntimeHelper(host, port)
detect_deadlocks(trace, pid, binary, begin_trans_id, end_trans_id)

Searching for Use-of-Uninitialized-Memory vulnerabilities

This notebook allows searching for potential Use-of-Uninitialized-Memory vulnerabilities in a REVEN trace.

Prerequisites

  • This notebook should be run in a Jupyter notebook server equipped with a REVEN 2 Python kernel. REVEN comes with a Jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • This notebook depends on capstone being installed in the REVEN 2 python kernel. To install capstone in the current environment, please execute the capstone cell of this notebook.
  • This notebook requires the Memory History resource for your target scenario.

Running the notebook

Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.
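
For example, a minimal parameter fill-in might look as follows (the values are purely illustrative; the parameter names come from the Parameters cell reproduced in the source below):

reven_backend_host = '127.0.0.1'
reven_backend_port = 13370
faulty_process_name = 'notepad.exe'  # hypothetical process of interest
search_windows_malloc = True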

Source

# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Searching for Use-of-Uninitialized-Memory vulnerabilities
#
# This notebook allows searching for potential Use-of-Uninitialized-Memory vulnerabilities in a REVEN trace.
#
# ## Prerequisites
#
# - This notebook should be run in a jupyter notebook server equipped with a REVEN 2 python kernel.
#   REVEN comes with a jupyter notebook server accessible with the `Open Python` button
#   in the `Analyze` page of any scenario.
# - This notebook depends on capstone being installed in the REVEN 2 python kernel.
#   To install capstone in the current environment, please execute the capstone cell of this notebook.
# - This notebook requires the Memory History resource for your target scenario.
#
#
# ## Running the notebook
#
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.

# %% [markdown]
# # Capstone Installation
#
# Check for capstone's presence. If missing, attempt to get it from pip

# %%
try:
    import capstone
    print("capstone already installed")
except ImportError:
    print("Could not find capstone, attempting to install it from pip")
    import sys
    import subprocess

    command = [f"{sys.executable}", "-m", "pip", "install", "capstone"]
    p = subprocess.run(command)

    if int(p.returncode) != 0:
        raise RuntimeError("Error installing capstone")

    import capstone  # noqa
    print("Successfully installed capstone")

# %% [markdown]
# # Parameters

# %%
# Server connection

# Host of the REVEN server running the scenario.
# When running this notebook from the Project Manager, '127.0.0.1' should be the correct value.
reven_backend_host = '127.0.0.1'

# Port of the REVEN server running the scenario.
# After starting a REVEN server on your scenario, you can get its port on the Analyze page of that scenario.
reven_backend_port = 13370


# Range control

# First transition considered for the detection of allocation/deallocation pairs
# If set to None, then the first transition of the trace
from_tr = None

# First transition **not** considered for the detection of allocation/deallocation pairs
# If set to None, then the last transition of the trace
to_tr = None


# Filter control

# Beware that, when filtering, if an allocation happens in the specified process and/or binary,
# the script will fail if the deallocation happens in a different process and/or binary.
# This issue should only happen for allocations in the kernel.

# Specify on which PID the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the PID.
faulty_process_pid = None

# Specify on which process name the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the process name.
# If both a process PID and a process name are specified, please make sure that they both
# refer to the same process, otherwise all allocations will be filtered and no results
# will be produced.
faulty_process_name = None

# Specify on which binary name the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the binary name.
# Only allocation/deallocation taking place in the binaries whose filename,
# path or name contain the specified value are kept.
# If filtering on both a process and a binary, please make sure that there are
# allocations taking place in that binary in the selected process, otherwise all
# allocations will be filtered and no result will be produced.
faulty_binary = None


# Address control

# Specify a **physical** address suspected of being faulty here,
# to only test the script for this specific address, instead of all (filtered) allocations.
# The address should still be returned by an allocation/deallocation pair.
# To get a physical address from a virtual address, find a context where the address
# is mapped, then use `virtual_address.translate(ctx)`.
faulty_physical_address = None
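# A purely hypothetical illustration (the transition id and virtual address below are made up;
# `server` is created in the connection cell further down):
#     ctx = server.trace.context_before(1_000_000)
#     faulty_physical_address = reven2.address.LogicalAddress(0x7ffe0340).translate(ctx).offset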

# Allocator control

# The script can search for allocation/deallocation pairs from several allocators
# (Windows malloc/free, the Windows kernel pool allocator, Linux malloc/free).
# The following booleans enable the search for allocations made by each of these
# allocators in a scenario.
# Generally it is expected to have only a single allocator enabled for a given
# scenario.

# To add your own allocator, please look at how the provided allocators were
# added.
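#
# For instance, a hypothetical custom allocator could be hooked by following the same pattern as the
# allocators handled further down in this notebook (the symbol and binary names here are made up):
#
#     if search_my_allocator:
#         for symbol in server.ossi.symbols(r'^my_alloc$', binary_hint=r'my_driver.sys'):
#             retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol,
#                                        retrieve_alloc_info_with_size_as_first_argument)
#         for symbol in server.ossi.symbols(r'^my_free$', binary_hint=r'my_driver.sys'):
#             retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol,
#                                        retrieve_free_info_with_address_as_first_argument)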

# Whether or not to look for windows malloc/free allocation/deallocation pairs.
search_windows_malloc = True

# Whether or not to look for ExAllocatePoolWithTag/ExFreePoolWithTag
# allocation/deallocation pairs.
# This allocator is used by the Windows kernel.
search_pool_allocation = False

# Whether or not to look for linux malloc/free allocation/deallocation pairs.
search_linux_malloc = False


# Analysis control

# Whether or not the display should be restricted to UUMs impacting the control flow
only_display_uum_changing_control_flow = False

# %%
import struct  # noqa: E402
from collections import OrderedDict  # noqa: E402
from typing import Dict, List  # noqa: E402

import reven2  # noqa: E402
import reven2.preview.taint  # noqa: E402

# %%
# Python script to connect to this scenario:
server = reven2.RevenServer(reven_backend_host, reven_backend_port)
print(server.trace.transition_count)


# %%
class MemoryRange:
    page_size = 4096
    page_mask = ~(page_size - 1)

    def __init__(self, logical_address, size):
        self.logical_address = logical_address
        self.size = size

        self.pages = [{
            'logical_address': self.logical_address,
            'size': self.size,
            'physical_address': None,
            'ctx_physical_address_mapped': None,
        }]

        # Compute the pages
        while (((self.pages[-1]['logical_address'] & ~MemoryRange.page_mask) + self.pages[-1]['size'] - 1)
                >= MemoryRange.page_size):
            # Compute the size of the new page
            new_page_size = ((self.pages[-1]['logical_address'] & ~MemoryRange.page_mask) + self.pages[-1]['size']
                             - MemoryRange.page_size)

            # Reduce the size of the previous page and create the new one
            self.pages[-1]['size'] -= new_page_size
            self.pages.append({
                'logical_address': self.pages[-1]['logical_address'] + self.pages[-1]['size'],
                'size': new_page_size,
                'physical_address': None,
                'ctx_physical_address_mapped': None,
            })

    def try_translate_first_page(self, ctx):
        if self.pages[0]['physical_address'] is not None:
            return True

        physical_address = reven2.address.LogicalAddress(self.pages[0]['logical_address']).translate(ctx)

        if physical_address is None:
            return False

        self.pages[0]['physical_address'] = physical_address.offset
        self.pages[0]['ctx_physical_address_mapped'] = ctx

        return True

    def try_translate_all_pages(self, ctx):
        return_value = True

        for page in self.pages:
            if page['physical_address'] is not None:
                continue

            physical_address = reven2.address.LogicalAddress(page['logical_address']).translate(ctx)

            if physical_address is None:
                return_value = False
                continue

            page['physical_address'] = physical_address.offset
            page['ctx_physical_address_mapped'] = ctx

        return return_value

    def is_physical_address_range_in_translated_pages(self, physical_address, size):
        for page in self.pages:
            if page['physical_address'] is None:
                continue

            if (
                physical_address >= page['physical_address']
                and physical_address + size <= page['physical_address'] + page['size']
            ):
                return True

        return False

    def __repr__(self):
        return "MemoryRange(0x%x, %d)" % (self.logical_address, self.size)


# Util to translate the physical address of a buffer that was just allocated
#     - ctx should be the context where the allocated address is located in `rax`
#     - memory_range should be the memory range of the newly allocated buffer
#
# We use the translate API, but sometimes, just after the allocation, the address isn't
# mapped yet. In that case we taint `rax` forward and, for each access of the taint,
# we try to translate the address again.
def translate_first_page_of_allocation(ctx, memory_range):
    if memory_range.try_translate_first_page(ctx):
        return

    tainter = reven2.preview.taint.Tainter(server.trace)
    taint = tainter.simple_taint(
        tag0="rax",
        from_context=ctx,
        to_context=None,
        is_forward=True
    )

    for access in taint.accesses(changes_only=False).all():
        if memory_range.try_translate_first_page(access.transition.context_after()):
            taint.cancel()
            return

    raise RuntimeError("Couldn't find the physical address of the first page")


# %%
class AllocEvent:
    def __init__(self, memory_range, tr_begin, tr_end):
        self.memory_range = memory_range
        self.tr_begin = tr_begin
        self.tr_end = tr_end


class FreeEvent:
    def __init__(self, logical_address, tr_begin, tr_end):
        self.logical_address = logical_address
        self.tr_begin = tr_begin
        self.tr_end = tr_end


def retrieve_events_for_symbol(
    alloc_dict,
    event_class,
    symbol,
    retrieve_event_info,
    event_filter=None,
):
    for ctx in server.trace.search.symbol(
        symbol,
        from_context=None if from_tr is None else server.trace.context_before(from_tr),
        to_context=None if to_tr is None else server.trace.context_before(to_tr)
    ):
        # We don't want hits on exceptions (code pagefaults, hardware interrupts, etc.)
        if ctx.transition_after().exception is not None:
            continue

        previous_location = (ctx - 1).ossi.location()
        previous_process = (ctx - 1).ossi.process()

        # Filter by process pid/process name/binary name

        # Filter by process pid
        if faulty_process_pid is not None and previous_process.pid != faulty_process_pid:
            continue

        # Filter by process name
        if faulty_process_name is not None and previous_process.name != faulty_process_name:
            continue

        # Filter by binary name / filename / path
        if faulty_binary is not None and faulty_binary not in [
            previous_location.binary.name,
            previous_location.binary.filename,
            previous_location.binary.path
        ]:
            continue

        # Filter the event with the argument filter
        if event_filter is not None:
            if event_filter(ctx.ossi.location(), previous_location):
                continue

        # Retrieve the call/ret
        # The heuristic is that the ret is the end of our function
        #   - If the call is inlined it should be at the end of the caller function, so the ret is the ret of our
        #     function
        #   - If the call isn't inlined, the ret should be the ret of our function
        ctx_call = next(ctx.stack.frames()).creation_transition.context_after()
        ctx_ret = ctx_call.transition_before().find_inverse().context_before()

        # Build the event by reading the needed registers
        if event_class == AllocEvent:
            current_address, size = retrieve_event_info(ctx, ctx_ret)

            # Filter the alloc failing
            if current_address == 0x0:
                continue

            memory_range = MemoryRange(current_address, size)
            try:
                translate_first_page_of_allocation(ctx_ret, memory_range)
            except RuntimeError:
                # If we can't translate the first page we assume that the buffer isn't used because
                # the heuristic to detect the call/ret failed
                continue

            if memory_range.pages[0]['physical_address'] not in alloc_dict:
                alloc_dict[memory_range.pages[0]['physical_address']] = []

            alloc_dict[memory_range.pages[0]['physical_address']].append(
                AllocEvent(
                    memory_range,
                    ctx.transition_after(), ctx_ret.transition_after()
                )
            )
        elif event_class == FreeEvent:
            current_address = retrieve_event_info(ctx, ctx_ret)

            # Filter the free of NULL
            if current_address == 0x0:
                continue

            current_physical_address = reven2.address.LogicalAddress(current_address).translate(ctx).offset

            if current_physical_address not in alloc_dict:
                alloc_dict[current_physical_address] = []

            alloc_dict[current_physical_address].append(
                FreeEvent(
                    current_address,
                    ctx.transition_after(), ctx_ret.transition_after()
                )
            )
        else:
            raise RuntimeError("Unknown event class: %s" % event_class.__name__)


# %%
# %%time

alloc_dict: Dict = {}


# Basic functions to retrieve the arguments.
# They work for the allocation/free functions but won't work for all functions,
# particularly because on x86 we handle neither the size of the arguments
# nor whether they are pushed left to right or right to left.
def retrieve_first_argument(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rcx)
    else:
        esp = ctx.read(reven2.arch.x64.esp)
        return ctx.read(reven2.address.LogicalAddress(esp + 4, reven2.arch.x64.ss), 4)


def retrieve_second_argument(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rdx)
    else:
        esp = ctx.read(reven2.arch.x64.esp)
        return ctx.read(reven2.address.LogicalAddress(esp + 8, reven2.arch.x64.ss), 4)


def retrieve_first_argument_linux(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rdi)
    else:
        raise NotImplementedError("Linux 32bits")


def retrieve_second_argument_linux(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rsi)
    else:
        raise NotImplementedError("Linux 32bits")


def retrieve_return_value(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rax)
    else:
        return ctx.read(reven2.arch.x64.eax)


def retrieve_alloc_info_with_size_as_first_argument(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_first_argument_linux(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument_linux(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_second_argument(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_second_argument(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_second_argument_linux(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_second_argument_linux(ctx_begin)
    )


def retrieve_alloc_info_for_calloc(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument(ctx_begin) * retrieve_second_argument(ctx_begin)
    )


def retrieve_alloc_info_for_calloc_linux(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument_linux(ctx_begin) * retrieve_second_argument_linux(ctx_begin)
    )


def retrieve_free_info_with_address_as_first_argument(ctx_begin, ctx_end):
    return retrieve_first_argument(ctx_begin)


def retrieve_free_info_with_address_as_first_argument_linux(ctx_begin, ctx_end):
    return retrieve_first_argument_linux(ctx_begin)


if search_windows_malloc:
    def filter_in_realloc(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name == "realloc"

    # Search for allocations with malloc
    for symbol in server.ossi.symbols(r'^_?malloc$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_first_argument,
                                   filter_in_realloc)

    # Search for allocations with calloc
    for symbol in server.ossi.symbols(r'^_?calloc(_crt)?$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_for_calloc)

    # Search for deallocations with free
    for symbol in server.ossi.symbols(r'^_?free$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument,
                                   filter_in_realloc)

    # Search for re-allocations with realloc
    for symbol in server.ossi.symbols(r'^_?realloc$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_second_argument)
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument)

if search_pool_allocation:
    # Search for allocations with ExAllocatePool...
    def filter_ex_allocate_pool(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name.startswith("ExAllocatePool")

    for symbol in server.ossi.symbols(r'^ExAllocatePool', binary_hint=r'ntoskrnl.exe'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_second_argument,
                                   filter_ex_allocate_pool)

    # Search for deallocations with ExFreePool...
    def filter_ex_free_pool(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name == "ExFreePool"

    for symbol in server.ossi.symbols(r'^ExFreePool', binary_hint=r'ntoskrnl.exe'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument,
                                   filter_ex_free_pool)

if search_linux_malloc:
    def filter_in_realloc(location, caller_location):
        return (
            location.binary == caller_location.binary
            and (
                caller_location.symbol is not None
                and caller_location.symbol.name in ["realloc", "__GI___libc_realloc"]
            )
        )

    # Search for allocations with malloc
    for symbol in server.ossi.symbols(r'^((__GI___libc_malloc)|(__libc_malloc))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol,
                                   retrieve_alloc_info_with_size_as_first_argument_linux, filter_in_realloc)

    # Search for allocations with calloc
    for symbol in server.ossi.symbols(r'^((__calloc)|(__libc_calloc))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_for_calloc_linux)

    # Search for deallocations with free
    for symbol in server.ossi.symbols(r'^((__GI___libc_free)|(cfree))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol,
                                   retrieve_free_info_with_address_as_first_argument_linux, filter_in_realloc)

    # Search for re-allocations with realloc
    for symbol in server.ossi.symbols(r'^((__GI___libc_realloc)|(realloc))$', binary_hint=r'libc-.*.so'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol,
                                   retrieve_alloc_info_with_size_as_second_argument_linux)
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol,
                                   retrieve_free_info_with_address_as_first_argument_linux)

# Sort the events per address and event type
for physical_address in alloc_dict.keys():
    alloc_dict[physical_address] = list(sorted(
        alloc_dict[physical_address],
        key=lambda event: (event.tr_begin.id, 0 if isinstance(event, FreeEvent) else 1)
    ))

# Sort the dict by address
alloc_dict = OrderedDict(sorted(alloc_dict.items()))


# %%
def get_alloc_free_pairs(events, errors=None):
    previous_event = None
    for event in events:
        if isinstance(event, AllocEvent):
            if previous_event is None:
                pass
            elif isinstance(previous_event, AllocEvent):
                if errors is not None:
                    errors.append("Two consecutives allocs found")
        elif isinstance(event, FreeEvent):
            if previous_event is None:
                continue
            elif isinstance(previous_event, FreeEvent):
                if errors is not None:
                    errors.append("Two consecutives frees found")
            elif isinstance(previous_event, AllocEvent):
                yield (previous_event, event)
        else:
            assert 0, ("Unknown event type: %s" % type(event))

        previous_event = event

    if isinstance(previous_event, AllocEvent):
        yield (previous_event, None)


# %%
# %%time

# Basic checks of the events

for physical_address, events in alloc_dict.items():
    for event in events:
        if not isinstance(event, AllocEvent) and not isinstance(event, FreeEvent):
            raise RuntimeError("Unknown event type: %s" % type(event))

    errors: List[str] = []
    for (alloc_event, free_event) in get_alloc_free_pairs(events, errors):
        # Check the uniformity of the logical address between the alloc and the free
        if free_event is not None and alloc_event.memory_range.logical_address != free_event.logical_address:
            errors.append(
                "Phys:0x%x: Alloc #%d - Free #%d with different logical address: 0x%x != 0x%x" % (
                    physical_address,
                    alloc_event.tr_begin.id,
                    free_event.tr_begin.id,
                    alloc_event.memory_range.logical_address,
                    free_event.logical_address
                )
            )

        # Check size of 0x0
        if alloc_event.memory_range.size == 0x0 or alloc_event.memory_range.size is None:
            if free_event is None:
                errors.append("Phys:0x%x: Alloc #%d - Free N/A with weird size %s" % (
                    physical_address, alloc_event.tr_begin.id, alloc_event.memory_range.size
                ))
            else:
                errors.append("Phys:0x%x: Alloc #%d - Free #%d with weird size %s" % (
                    physical_address, alloc_event.tr_begin.id, free_event.tr_begin.id, alloc_event.memory_range.size
                ))

    if len(errors) > 0:
        print("Phys:0x%x: Error(s) detected:" % (physical_address))
        for error in errors:
            print("    - %s" % error)

# %%
# Print the events

for physical_address, events in alloc_dict.items():
    print("Phys:0x%x" % (physical_address))
    print("    Events:")
    for event in events:
        if isinstance(event, AllocEvent):
            print("        - Alloc at #%d (0x%x of size 0x%x)" % (
                event.tr_begin.id, event.memory_range.logical_address, event.memory_range.size
            ))
        elif isinstance(event, FreeEvent):
            print("        - Free at #%d (0x%x)" % (event.tr_begin.id, event.logical_address))

    print("    Pairs:")
    for (alloc_event, free_event) in get_alloc_free_pairs(events):
        if free_event is None:
            print("    - Allocated at #%d (0x%x of size 0x%x) and freed at N/A" % (
                alloc_event.tr_begin.id, alloc_event.memory_range.logical_address, alloc_event.memory_range.size
            ))
        else:
            print("    - Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (
                alloc_event.tr_begin.id, alloc_event.memory_range.logical_address, alloc_event.memory_range.size,
                free_event.tr_begin.id, free_event.logical_address
            ))

    print()

# %%
# Setup capstone
md_64 = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
md_64.detail = True

md_32 = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_32)
md_32.detail = True


# Retrieve a `bytes` array from the capstone operand
def get_mask_from_cs_op(mask_op):
    mask_formats = [
        None,
        'B',  # 1
        'H',  # 2
        None,
        'I',  # 4
        None,
        None,
        None,
        'Q',  # 8
    ]

    return struct.pack(
        mask_formats[mask_op.size],
        mask_op.imm if mask_op.imm >= 0 else ((1 << (mask_op.size * 8)) + mask_op.imm)
    )


# This function will return an array containing either `True` or `False` for each byte
# of the memory access, telling which bytes should be considered for a UUM.
# Only bytes whose index returns `False` will be considered for potential UUM
def filter_and_bytes(cs_insn, mem_access):
    # The `and` instruction can be used to set some bytes to 0 with an immediate mask.
    # Bytes in the mask tell us what to do:
    #  - with 0x00 we should consider the write and not the read
    #  - with 0xFF we should consider neither of them
    #  - with everything else we should consider the reads and the writes
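    #
    # A worked example (illustrative only, with a made-up instruction, assuming a little-endian host):
    # for `and dword ptr [rax], 0x00FF00FF`, the packed mask bytes are FF 00 FF 00, so bytes 0 and 2 of
    # the access keep their old value (reads and writes both filtered out), while bytes 1 and 3 are
    # forced to zero (reads filtered out, writes kept as initializations).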
    filtered_bytes = [False] * mem_access.size

    dest_op = cs_insn.operands[0]
    mask_op = cs_insn.operands[1]

    if dest_op.type != capstone.x86.X86_OP_MEM or mask_op.type != capstone.x86.X86_OP_IMM:
        return filtered_bytes

    mask = get_mask_from_cs_op(mask_op)

    for i in range(0, mask_op.size):
        if mask[i] == 0x00 and mem_access.operation == reven2.memhist.MemoryAccessOperation.Read:
            filtered_bytes[i] = True
        elif mask[i] == 0xFF:
            filtered_bytes[i] = True

    return filtered_bytes


# This function will return an array containing either `True` or `False` for each byte
# of the memory access, telling which bytes should be considered for a UUM.
# Only bytes whose index returns `False` will be considered for potential UUM
def filter_or_bytes(cs_insn, mem_access):
    # The `or` instruction can be used to set some bytes to 0xFF with an immediate mask.
    # Bytes in the mask tell us what to do:
    #  - with 0x00 we should consider neither of them
    #  - with 0xFF we should consider the write and not the read
    #  - with everything else we should consider the reads and the writes

    filtered_bytes = [False] * mem_access.size

    dest_op = cs_insn.operands[0]
    mask_op = cs_insn.operands[1]

    if dest_op.type != capstone.x86.X86_OP_MEM or mask_op.type != capstone.x86.X86_OP_IMM:
        return filtered_bytes

    mask = get_mask_from_cs_op(mask_op)

    for i in range(0, mask_op.size):
        if mask[i] == 0x00:
            filtered_bytes[i] = True
        elif mask[i] == 0xFF and mem_access.operation == reven2.memhist.MemoryAccessOperation.Read:
            filtered_bytes[i] = True

    return filtered_bytes


# This function will return an array containing either `True` or `False` for each byte
# of the memory access to know which one should be considered for an UUM.
# Only bytes whose index returns `False` will be considered for potential UUM
def filter_bts_bytes(cs_insn, mem_access):
    # The `bts` instruction with an immediate only accesses one byte in memory,
    # but it may be written with a bigger access size (e.g. `dword`).
    # We only consider the byte accessed by the `bts` instruction in this case.
    filtered_bytes = [False] * mem_access.size

    dest_op = cs_insn.operands[0]
    bit_nb_op = cs_insn.operands[1]

    if dest_op.type != capstone.x86.X86_OP_MEM or bit_nb_op.type != capstone.x86.X86_OP_IMM:
        return filtered_bytes

    filtered_bytes = [True] * mem_access.size
    filtered_bytes[bit_nb_op.imm // 8] = False

    return filtered_bytes


# This function will return an array containing either `True` or `False` for each byte
# of the memory access to know which one should be considered for an UUM
def get_filtered_bytes(cs_insn, mem_access):
    if cs_insn.mnemonic in ["and", "lock and"]:
        return filter_and_bytes(cs_insn, mem_access)
    elif cs_insn.mnemonic in ["or", "lock or"]:
        return filter_or_bytes(cs_insn, mem_access)
    elif cs_insn.mnemonic in ["bts", "lock bts"]:
        return filter_bts_bytes(cs_insn, mem_access)

    return [False] * mem_access.size


class UUM:
    # This array contains the relation between the capstone flag and
    # the reven register to check
    test_eflags = {
        capstone.x86.X86_EFLAGS_TEST_OF: reven2.arch.x64.of,
        capstone.x86.X86_EFLAGS_TEST_SF: reven2.arch.x64.sf,
        capstone.x86.X86_EFLAGS_TEST_ZF: reven2.arch.x64.zf,
        capstone.x86.X86_EFLAGS_TEST_PF: reven2.arch.x64.pf,
        capstone.x86.X86_EFLAGS_TEST_CF: reven2.arch.x64.cf,
        capstone.x86.X86_EFLAGS_TEST_NT: reven2.arch.x64.nt,
        capstone.x86.X86_EFLAGS_TEST_DF: reven2.arch.x64.df,
        capstone.x86.X86_EFLAGS_TEST_RF: reven2.arch.x64.rf,
        capstone.x86.X86_EFLAGS_TEST_IF: reven2.arch.x64.if_,
        capstone.x86.X86_EFLAGS_TEST_TF: reven2.arch.x64.tf,
        capstone.x86.X86_EFLAGS_TEST_AF: reven2.arch.x64.af,
    }

    def __init__(self, alloc_event, free_event, memaccess, uum_bytes):
        self.alloc_event = alloc_event
        self.free_event = free_event
        self.memaccess = memaccess
        self.bytes = uum_bytes

        # Store conditionals depending on uninitialized memory
        # - 'transition': the transition
        # - 'flag': the flag register which depends on uninitialized memory
        self.conditionals = None

    @property
    def nb_uum_bytes(self):
        return len(list(filter(lambda byte: byte, self.bytes)))

    def analyze_usage(self):
        # Initialize an array of what to taint based on the uninitialized bytes
        taint_tags = []
        for i in range(0, self.memaccess.size):
            if not self.bytes[i]:
                continue

            taint_tags.append(reven2.preview.taint.TaintedMemories(self.memaccess.physical_address + i, 1))

        # Start a taint of just the first instruction (the memory access)
        # We don't want to keep the memory tainted as if the memory is accessed later we
        # will have another UUM anyway. So we are using the state of this first taint
        # and we remove the initial tainted memory to start a new taint from after the first
        # instruction to the end of the trace
        tainter = reven2.preview.taint.Tainter(server.trace)
        taint = tainter.simple_taint(
            tag0=taint_tags,
            from_context=self.memaccess.transition.context_before(),
            to_context=self.memaccess.transition.context_after() + 1,
            is_forward=True
        )

        state_after_first_instruction = taint.state_at(self.memaccess.transition.context_after())

        # We assume that we won't have other tainted memories than the uninitialized memories
        # after the first instruction, so we can just keep the registers from
        # `state_after_first_instruction` and not the memories
        # In the future, we should keep the inverse of the intersection of the uninitialized memories
        # and the memories in the `state_after_first_instruction`
        taint = tainter.simple_taint(
            tag0=list(map(
                lambda x: x[0],
                state_after_first_instruction.tainted_registers()
            )),
            from_context=self.memaccess.transition.context_after(),
            to_context=None,
            is_forward=True
        )

        conditionals = []
        for access in taint.accesses(changes_only=False).all():
            ctx = access.transition.context_before()
            md = md_64 if ctx.is64b() else md_32
            cs_insn = next(md.disasm(access.transition.instruction.raw, access.transition.instruction.size))

            # Test conditional jump & move
            for flag, reg in self.test_eflags.items():
                if not cs_insn.eflags & flag:
                    continue

                if not UUM._is_register_tainted_in_taint_state(
                    taint.state_at(access.transition.context_after()),
                    reg
                ):
                    continue

                conditionals.append({
                    'transition': access.transition,
                    'flag': reg,
                })

        self.conditionals = conditionals

    def _is_register_tainted_in_taint_state(taint_state, reg):
        for tainted_reg, _ in taint_state.tainted_registers():
            if tainted_reg.register == reg:
                return True
        return False

    def __str__(self):
        desc = ""

        if self.free_event is None:
            desc += "Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at N/A\n" % (
                self.alloc_event.memory_range.pages[0]['physical_address'],
                self.alloc_event.tr_begin.id, self.alloc_event.memory_range.logical_address,
                self.alloc_event.memory_range.size,
            )
            desc += "\tAlloc in: %s / %s\n\n" % (
                (self.alloc_event.tr_begin - 1).context_before().ossi.location(),
                (self.alloc_event.tr_begin - 1).context_before().ossi.process(),
            )
        else:
            desc += "Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) ad freed at #%d (0x%x)\n" % (
                self.alloc_event.memory_range.pages[0]['physical_address'],
                self.alloc_event.tr_begin.id, self.alloc_event.memory_range.logical_address,
                self.alloc_event.memory_range.size,
                self.free_event.tr_begin.id, self.free_event.logical_address,
            )
            desc += "\tAlloc in: %s / %s\n" % (
                (self.alloc_event.tr_begin - 1).context_before().ossi.location(),
                (self.alloc_event.tr_begin - 1).context_before().ossi.process(),
            )
            desc += "\tFree in: %s / %s\n\n" % (
                (self.free_event.tr_begin - 1).context_before().ossi.location(),
                (self.free_event.tr_begin - 1).context_before().ossi.process(),
            )
        desc += "\tUUM of %d byte(s) first read at:\n" % self.nb_uum_bytes
        desc += "\t\t%s / %s\n" % (
            self.memaccess.transition.context_before().ossi.location(),
            self.memaccess.transition.context_before().ossi.process(),
        )
        desc += "\t\t%s" % (self.memaccess.transition)

        if self.conditionals is None:
            return desc
        elif len(self.conditionals) == 0:
            desc += "\n\n\tNot impacting the control flow"
            return desc

        desc += "\n\n\tThe control flow depends on uninitialized value(s):"

        conditionals = []
        for conditional in self.conditionals:
            conditional_str = "\n\t\tFlag '%s' depends on uninitialized memory\n" % conditional['flag'].name
            conditional_str += "\t\t%s / %s\n" % (
                conditional['transition'].context_before().ossi.location(),
                conditional['transition'].context_before().ossi.process()
            )
            conditional_str += "\t\t%s" % (conditional['transition'])

            conditionals.append(conditional_str)

        desc += "\n".join(conditionals)

        return desc


def analyze_one_memaccess(alloc_event, free_event, pages, pages_bytes_written, memaccess):
    if (
        memaccess.transition > alloc_event.tr_begin
        and memaccess.transition < alloc_event.tr_end
        and memaccess.operation == reven2.memhist.MemoryAccessOperation.Read
    ):
        # We assume that read accesses during the allocator are okay
        # as the allocator should know what it is doing
        return None

    ctx = memaccess.transition.context_before()
    md = md_64 if ctx.is64b() else md_32
    cs_insn = next(md.disasm(memaccess.transition.instruction.raw, memaccess.transition.instruction.size))

    filtered_bytes = get_filtered_bytes(cs_insn, memaccess)
    uum_bytes = [False] * memaccess.size

    for i in range(0, memaccess.size):
        if filtered_bytes[i]:
            continue

        possible_pages = list(filter(
            lambda page: (
                memaccess.physical_address.offset + i >= page['physical_address']
                and memaccess.physical_address.offset + i < page['physical_address'] + page['size']
            ),
            pages
        ))

        if len(possible_pages) > 1:
            # Should not be possible to have a byte in multiple pages
            raise AssertionError("Single byte access across multiple pages")
        elif len(possible_pages) == 0:
            # Access partially outside the buffer
            continue

        phys_addr = possible_pages[0]['physical_address']
        byte_offset_in_page = memaccess.physical_address.offset + i - possible_pages[0]['physical_address']

        if memaccess.operation == reven2.memhist.MemoryAccessOperation.Read:
            byte_written = pages_bytes_written[phys_addr][byte_offset_in_page]

            if not byte_written:
                uum_bytes[i] = True

        elif memaccess.operation == reven2.memhist.MemoryAccessOperation.Write:
            pages_bytes_written[phys_addr][byte_offset_in_page] = True

    if any(uum_bytes):
        return UUM(alloc_event, free_event, memaccess, uum_bytes)

    return None


def uum_analyze_function(physical_address, alloc_events):
    uum_count = 0

    for (alloc_event, free_event) in get_alloc_free_pairs(alloc_events):
        # We are trying to translate all the pages and will construct
        # an array of translated pages.
        # We don't check UUM on pages we couldn't translate
        alloc_event.memory_range.try_translate_all_pages(
            free_event.tr_begin.context_before()
            if free_event is not None else
            alloc_event.tr_end.context_before()
        )

        pages = list(filter(
            lambda page: page['physical_address'] is not None,
            alloc_event.memory_range.pages
        ))

        # An iterator of all the memory accesses of all the translated pages.
        # The `from` transition is the start of the alloc, not its end, because in some
        # cases we want the accesses made during it. For example, a `calloc` will write the memory
        # during its execution. That's also why we ignore the read memory accesses made
        # during the alloc function.
        mem_accesses = reven2.util.collate(map(
            lambda page: server.trace.memory_accesses(
                reven2.address.PhysicalAddress(page['physical_address']),
                page['size'],
                from_transition=alloc_event.tr_begin,
                to_transition=free_event.tr_begin if free_event is not None else None,
                is_forward=True,
                operation=None
            ),
            pages
        ), key=lambda access: access.transition)

        # This will contain, for each page, an array of booleans representing
        # whether each byte has been written before or not
        pages_bytes_written = {}
        for page in pages:
            pages_bytes_written[page['physical_address']] = [False] * page['size']

        for memaccess in mem_accesses:
            if all([all(bytes_written) for bytes_written in pages_bytes_written.values()]):
                # All the bytes have been set in the memory
                # we no longer need to track the memory accesses
                break

            # Do we have a UUM on this memory access?
            uum = analyze_one_memaccess(alloc_event, free_event, pages, pages_bytes_written, memaccess)
            if uum is None:
                continue

            uum.analyze_usage()

            if only_display_uum_changing_control_flow and len(uum.conditionals) == 0:
                continue

            print(str(uum))
            print()

            uum_count += 1

    return uum_count


# %%
# %%time

count = 0

if faulty_physical_address is None:
    for physical_address, alloc_events in alloc_dict.items():
        count += uum_analyze_function(physical_address, alloc_events)
else:
    if faulty_physical_address not in alloc_dict:
        raise KeyError("The passed physical address was not detected during the allocation search")
    count += uum_analyze_function(faulty_physical_address, alloc_dict[faulty_physical_address])

print("---------------------------------------------------------------------------------")
begin_range = "the beginning of the trace" if from_tr is None else "#{}".format(from_tr)
end_range = "the end of the trace" if to_tr is None else "#{}".format(to_tr)
final_range = ("on the whole trace" if from_tr is None and to_tr is None else
               "between {} and {}".format(begin_range, end_range))

range_size = server.trace.transition_count
if from_tr is not None:
    range_size -= from_tr
if to_tr is not None:
    range_size -= server.trace.transition_count - to_tr

if faulty_physical_address is None:
    searched_memory_addresses = "with {} searched memory addresses".format(len(alloc_dict))
else:
    searched_memory_addresses = "on {:#x}".format(faulty_physical_address)

print("{} UUM(s) found {} ({} transitions) {}".format(
    count, final_range, range_size, searched_memory_addresses
))
print("---------------------------------------------------------------------------------")

Detect data race

This notebook checks for possible data races that may occur when using critical sections as the synchronization primitive.
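
At its core, the script locates every execution of the critical section APIs (listed in the Limits section below) with the symbol search API and reads the critical section handle passed in rcx. The following minimal sketch illustrates this idea on a 64-bit Windows scenario (the host, port and output format are placeholders, not part of the script); the full script below additionally splits the calls per thread, handles the return value of RtlTryEnterCriticalSection, and compares the memory accesses of the different threads:

import reven2
from reven2.arch import x64

# Placeholder connection values: use the host/port of your own scenario
server = reven2.RevenServer('127.0.0.1', 41309)

ntdll = next(server.ossi.executed_binaries('^c:/windows/system32/ntdll.dll$'))
enter = next(ntdll.symbols('^RtlEnterCriticalSection$'))
leave = next(ntdll.symbols('^RtlLeaveCriticalSection$'))

for ctx in server.trace.search.symbol(enter):
    print(f"lock   at #{ctx.transition_after().id}, handle {ctx.read(x64.rcx):#x}")
for ctx in server.trace.search.symbol(leave):
    print(f"unlock at #{ctx.transition_after().id}, handle {ctx.read(x64.rcx):#x}")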

Prerequisites

  • Supported versions:
      • REVEN 2.9+
  • This notebook should be run in a Jupyter notebook server equipped with a REVEN 2 Python kernel. REVEN comes with a Jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • The following resources are needed for the analyzed scenario:
      • Trace
      • OSSI
      • Memory History
      • Backtrace (Stack Events included)
      • Fast Search

Perimeter:

  • Windows 10 64-bit

Limits

  • Only supports the following critical section APIs as the lock/unlock operations.
      • RtlEnterCriticalSection, RtlTryEnterCriticalSection
      • RtlLeaveCriticalSection

Running

Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.

Source

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Detect data race
# This notebook checks for possible data race that may occur when using critical sections as the synchronization
# primitive.
#
# ## Prerequisites
#  - Supported versions:
#     - REVEN 2.9+
#  - This notebook should be run in a jupyter notebook server equipped with a REVEN 2 Python kernel.
# REVEN comes with a Jupyter notebook server accessible with the `Open Python` button in the `Analyze`
# page of any scenario.
#  - The following resources are needed for the analyzed scenario:
#     - Trace
#     - OSSI
#     - Memory History
#     - Backtrace (Stack Events included)
#     - Fast Search
#
# ## Perimeter:
#  - Windows 10 64-bit
#
# ## Limits
#  - Only supports the following critical section APIs as the lock/unlock operations.
#     - `RtlEnterCriticalSection`, `RtlTryEnterCriticalSection`
#     - `RtlLeaveCriticalSection`
#
#
# ## Running
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.


# %%
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union


# Reven specific
import reven2
from reven2.address import LinearAddress, LogicalAddress
from reven2.arch import x64
from reven2.memhist import MemoryAccess, MemoryAccessOperation
from reven2.ossi.ossi import Symbol
from reven2.trace import Context, Trace, Transition
from reven2.util import collate

# %% [markdown]
# # Parameters

# %%
# Host and port of the REVEN server running the scenario.
host = '127.0.0.1'
port = 41309

# The PID and the name of the binary of interest (optional): if the binary name is given (i.e. not None), then only
# locks and unlocks called directly from the binary are counted.
pid = 2460
binary = None

begin_trans_id = None
end_trans_id = None

# Do not show common memory accesses (from different threads) which are synchronized (i.e. mutually excluded) by some
# critical section(s).
hide_synchronized_accesses = True

# Do not show common memory accesses (from different threads) which are dynamically free from critical section(s).
suppress_unknown_primitives = False


# %%
# Helper class which wraps REVEN's runtime objects and provides methods to get information about calls to
# RtlEnterCriticalSection and RtlLeaveCriticalSection
class RuntimeHelper:
    def __init__(self, host: str, port: int):
        try:
            server = reven2.RevenServer(host, port)
        except RuntimeError:
            raise RuntimeError(f'Cannot connect to the scenario at {host}:{port}')

        self.trace = server.trace
        self.ossi = server.ossi
        self.search_symbol = self.trace.search.symbol
        self.search_binary = self.trace.search.binary

        # basic OSSI
        bin_symbol_names = {
            'c:/windows/system32/ntoskrnl.exe': {
                'KiSwapContext'
            },

            'c:/windows/system32/ntdll.dll': {
                # critical section
                'RtlInitializeCriticalSection',
                'RtlInitializeCriticalSectionEx',
                'RtlInitializeCriticalSectionAndSpinCount',
                'RtlEnterCriticalSection',
                'RtlTryEnterCriticalSection',
                'RtlLeaveCriticalSection',

                # initialize/shutdown thread
                'LdrpInitializeThread',
                'LdrShutdownThread'
            },

            'c:/windows/system32/basesrv.dll': {
                'BaseSrvCreateThread'
            },

            'c:/windows/system32/csrsrv.dll': {
                'CsrCreateThread',
                'CsrThreadRefcountZero',
                'CsrDereferenceThread'
            }
        }

        self.symbols: Dict[str, Optional[Symbol]] = {}
        for (bin, symbol_names) in bin_symbol_names.items():
            try:
                exec_bin = next(self.ossi.executed_binaries(f'^{bin}$'))
            except StopIteration:
                if bin == 'c:/windows/system32/ntoskrnl.exe':
                    raise RuntimeError(f'{bin} not found')
                exec_bin = None

            for name in symbol_names:
                if exec_bin is None:
                    self.symbols[name] = None
                    continue

                try:
                    sym = next(exec_bin.symbols(f'^{name}$'))
                    self.symbols[name] = sym
                except StopIteration:
                    msg = f'{name} not found in {bin}'

                    if name in {
                        'KiSwapContext',
                        'RtlEnterCriticalSection', 'RtlTryEnterCriticalSection', 'RtlLeaveCriticalSection'
                    }:
                        raise RuntimeError(msg)
                    else:
                        self.symbols[name] = None
                        print(f'Warning: {msg}')

        self.has_debugger = True
        try:
            self.trace.first_transition.step_over()
        except RuntimeError:
            print('Warning: the debugger interface is not available, so the script cannot determine \
function return values.\nMake sure the stack events and PC range resources are replayed for this scenario.')
            self.has_debugger = False

    def get_memory_accesses(self, from_context: Context, to_context: Context) -> Iterator[MemoryAccess]:
        try:
            from_trans = from_context.transition_after()
            to_trans = to_context.transition_before()
        except IndexError:
            return

        accesses = self.trace.memory_accesses(from_transition=from_trans, to_transition=to_trans)
        for access in accesses:
            # Skip the access without virtual address
            if access.virtual_address is None:
                continue

            # Skip the access without instruction
            if access.transition.instruction is None:
                continue

            # Skip the access of `lock` prefixed instruction
            ins_bytes = access.transition.instruction.raw
            if ins_bytes[0] == 0xf0:
                continue

            yield access

    @staticmethod
    def get_lock_handle(ctxt: Context) -> int:
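        # In the Windows x64 calling convention, the first argument (here the critical section pointer)
        # is passed in rcx.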
        return ctxt.read(x64.rcx)

    @staticmethod
    def thread_id(ctxt: Context) -> int:
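        # On Windows x64, gs points to the TEB; offset 0x48 holds ClientId.UniqueThread, i.e. the
        # current thread id.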
        return ctxt.read(LogicalAddress(0x48, x64.gs), 4)

    @staticmethod
    def is_kernel_mode(ctxt: Context) -> bool:
        return ctxt.read(x64.cs) & 0x3 == 0


# %%
# Look for a possible first execution context of a binary
# Find the lower/upper bound of contexts on which the data race detection operates
def find_begin_end_context(sco: RuntimeHelper, pid: int, binary: Optional[str],
                           begin_id: Optional[int], end_id: Optional[int]) -> Tuple[Context, Context]:
    begin_ctxt = None
    if begin_id is not None:
        try:
            begin_trans = sco.trace.transition(begin_id)
            begin_ctxt = begin_trans.context_after()
        except IndexError:
            begin_ctxt = None

    if begin_ctxt is None:
        if binary is not None:
            for name in sco.ossi.executed_binaries(binary):
                for ctxt in sco.search_binary(name):
                    ctx_process = ctxt.ossi.process()
                    assert ctx_process is not None
                    if ctx_process.pid == pid:
                        begin_ctxt = ctxt
                        break
                if begin_ctxt is not None:
                    break

    if begin_ctxt is None:
        begin_ctxt = sco.trace.first_context

    end_ctxt = None
    if end_id is not None:
        try:
            end_trans = sco.trace.transition(end_id)
            end_ctxt = end_trans.context_before()
        except IndexError:
            end_ctxt = None

    if end_ctxt is None:
        end_ctxt = sco.trace.last_context

    if (end_ctxt <= begin_ctxt):
        raise RuntimeError("The begin transition must be smaller than the end.")

    return (begin_ctxt, end_ctxt)


# Get all execution contexts of a given process
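# The trace is split at each execution of KiSwapContext (a context switch); the code assumes the
# running process does not change between two consecutive switches, so every range whose process
# pid matches is yielded.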
def find_process_ranges(sco: RuntimeHelper, pid: int, first_context: Context, last_context: Context) \
        -> Iterator[Tuple[Context, Optional[Context]]]:
    if last_context <= first_context:
        return iter(())

    ctxt_low = first_context
    assert sco.symbols['KiSwapContext'] is not None
    ki_swap_context = sco.symbols['KiSwapContext']
    for ctxt in sco.search_symbol(ki_swap_context, from_context=first_context,
                                  to_context=None if last_context == sco.trace.last_context else last_context):
        ctx_process = ctxt_low.ossi.process()
        assert ctx_process is not None

        if ctx_process.pid == pid:
            yield (ctxt_low, ctxt)
        ctxt_low = ctxt

    ctx_process = ctxt_low.ossi.process()
    assert ctx_process is not None

    if ctx_process.pid == pid:
        if ctxt_low < last_context:
            yield (ctxt_low, last_context)
        else:
            # So ctxt_low == last_context, using None for the upper bound.
            # This happens only when last_context is in the process, and is also the last context of the trace.
            yield (ctxt_low, None)


# Starting from a transition, return the first transition that has an instruction, or None if there isn't one.
def ignore_non_instructions(trans: Transition, trace: Trace) -> Optional[Transition]:
    while trans.instruction is None:
        if trans == trace.last_transition:
            return None
        trans = trans + 1

    return trans


# Extract user mode only context ranges from a context range (which may include also kernel mode ranges)
def find_usermode_ranges(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Optional[Context]) \
        -> Iterator[Tuple[Context, Context]]:
    if ctxt_high is None:
        return

    trans = ignore_non_instructions(ctxt_low.transition_after(), sco.trace)
    if trans is None:
        return
    ctxt_current = trans.context_before()

    while ctxt_current < ctxt_high:
        ctxt_next = ctxt_current.find_register_change(x64.cs, is_forward=True)

        if not RuntimeHelper.is_kernel_mode(ctxt_current):
            if ctxt_next is None or ctxt_next > ctxt_high:
                yield (ctxt_current, ctxt_high)
                break
            else:
                # It's safe to decrease ctxt_next by 1 because it was obtained from a forward find_register_change
                yield (ctxt_current, ctxt_next - 1)

        if ctxt_next is None:
            break

        ctxt_current = ctxt_next


# Get user mode only execution contexts of a given process
def find_process_usermode_ranges(trace: RuntimeHelper, pid: int, first_ctxt: Context, last_ctxt: Context) \
        -> Iterator[Tuple[Context, Context]]:
    for (ctxt_low, ctxt_high) in find_process_ranges(trace, pid, first_ctxt, last_ctxt):
        usermode_ranges = find_usermode_ranges(trace, ctxt_low, ctxt_high)
        for usermode_range in usermode_ranges:
            yield usermode_range


def build_ordered_api_calls(sco: RuntimeHelper, binary: Optional[str],
                            ctxt_low: Context, ctxt_high: Context, apis: List[str]) -> Iterator[Tuple[str, Context]]:
    def gen(api):
        return (
            (api, ctxt)
            for ctxt in sco.search_symbol(sco.symbols[api], from_context=ctxt_low,
                                          to_context=None if ctxt_high == sco.trace.last_context else ctxt_high)
        )

    api_contexts = (
        gen(api)
        for api in apis if api in sco.symbols
    )

    if binary is None:
        for api_ctxt in collate(api_contexts, key=lambda name_ctxt: name_ctxt[1]):
            yield api_ctxt
    else:
        for (api, ctxt) in collate(api_contexts, key=lambda name_ctxt: name_ctxt[1]):
            try:
                caller_ctxt = ctxt - 1
            except IndexError:
                continue

            caller_location = caller_ctxt.ossi.location()
            if caller_location is None:
                continue

            caller_binary = caller_location.binary
            if caller_binary is None:
                continue

            if binary in [caller_binary.name, caller_binary.filename, caller_binary.path]:
                yield (api, ctxt)


def get_return_value(sco: RuntimeHelper, ctxt: Context) -> Optional[int]:
    if not sco.has_debugger:
        return None

    try:
        trans_after = ctxt.transition_after()
    except IndexError:
        return None

    if trans_after is None:
        return None

    trans_ret = trans_after.step_out()
    if trans_ret is None:
        return None

    ctxt_ret = trans_ret.context_after()
    return ctxt_ret.read(x64.rax)


class SynchronizationAction(Enum):
    LOCK = 1
    UNLOCK = 0


def get_locks_unlocks(sco: RuntimeHelper, binary: Optional[str], ranges: List[Tuple[Context, Context]]) \
        -> List[Tuple[SynchronizationAction, Context]]:

    lock_unlock_apis = {
        'RtlEnterCriticalSection': SynchronizationAction.LOCK,
        'RtlLeaveCriticalSection': SynchronizationAction.UNLOCK,
        'RtlTryEnterCriticalSection': None,
    }

    critical_section_actions = []

    for ctxt_low, ctxt_high in ranges:
        for name, ctxt in build_ordered_api_calls(sco, binary, ctxt_low, ctxt_high, list(lock_unlock_apis.keys())):
            # either lock, unlock, or none
            action = lock_unlock_apis[name]

            if action is None:
                # need to check the return value of the API to get the action
                api_ret_val = get_return_value(sco, ctxt)
                if api_ret_val is None:
                    print(f'Warning: failed to get the return value, {name} is omitted')
                    continue
                else:
                    # RtlTryEnterCriticalSection: the return value is nonzero if the thread successfully
                    # entered the critical section
                    if api_ret_val != 0:
                        action = SynchronizationAction.LOCK
                    else:
                        continue

            critical_section_actions.append((action, ctxt))

    return critical_section_actions


# Get locks which are effective at a transition
def get_live_locks(sco: RuntimeHelper,
                   locks_unlocks: List[Tuple[SynchronizationAction, Context]], trans: Transition) -> List[Context]:
    protection_locks: List[Context] = []
    trans_ctxt = trans.context_before()

    for (action, ctxt) in locks_unlocks:
        if ctxt > trans_ctxt:
            return protection_locks

        if action == SynchronizationAction.LOCK:
            protection_locks.append(ctxt)
            continue

        unlock_handle = RuntimeHelper.get_lock_handle(ctxt)
        # look for the latest corresponding lock
        for (idx, lock_ctxt) in reversed(list(enumerate(protection_locks))):
            lock_handle = RuntimeHelper.get_lock_handle(lock_ctxt)
            if lock_handle == unlock_handle:
                del protection_locks[idx]
                break

    return protection_locks
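
# For example (hypothetical handles): if the ordered actions before `trans` are
#   LOCK(handle 0x10), LOCK(handle 0x20), UNLOCK(handle 0x10)
# then the live locks at `trans` are the contexts of [LOCK(handle 0x20)]: the unlock removed the
# latest previous lock taken on the same handle.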


# %%
AccessType = TypeVar("AccessType")


@dataclass
class MemorySegment:
    address: Union[LinearAddress, LogicalAddress]
    size: int


@dataclass
class MemorySegmentAccess(Generic[AccessType]):
    segment: MemorySegment
    accesses: List[AccessType]


# Insert a segment into the current list of segments, starting at `position`; if `position` is None then start from
# the head of the list.
# Precondition: `new_segment.segment.size > 0`
# Return the position where the segment is inserted.
# The complexity of the insertion is about O(n).
def insert_memory_segment_access(new_segment: MemorySegmentAccess, segments: List[MemorySegmentAccess],
                                 position: Optional[int]) -> int:
    new_address = new_segment.segment.address
    new_size = new_segment.segment.size
    new_accesses = new_segment.accesses

    if not segments:
        segments.append(MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
        return 0

    if position is None:
        position = 0

    index = position
    first_found_index = None
    while index < len(segments):
        # Loop invariant `new_size > 0`
        #  - True at the first iteration
        #  - In the subsequent iterations, either
        #     - the invariant is always kept through cases (2), (5), (6), (7)
        #     - the loop returns directly in cases (1), (3), (4)

        address = segments[index].segment.address
        size = segments[index].segment.size
        accesses = segments[index].accesses

        if new_address < address:
            # Case 1
            #                   |--------------|
            # |--------------|
            if new_address + new_size <= address:
                segments.insert(index,
                                MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))

                return first_found_index if first_found_index is not None else index
            # Case 2
            #              |--------------|
            # |--------------|
            else:
                # Insert the different part of the new segment
                segments.insert(index,
                                MemorySegmentAccess(
                                    MemorySegment(new_address, address.offset - new_address.offset), new_accesses))

                if first_found_index is None:
                    first_found_index = index

                # The common part will be handled in the next iteration by either case (3), (4), or (5)
                #
                # Since
                #  - `new_address + new_size > address` (else condition of case 2), so
                #  - `new_size  > address.offset - new_address.offset`
                # then invariant `new_size > 0`
                new_size -= (address.offset - new_address.offset)
                new_address = address
                index += 1
        elif new_address == address:
            # case 3
            # |--------------|
            # |--------|
            if new_address + new_size < address + size:
                segments.insert(index,
                                MemorySegmentAccess(MemorySegment(new_address, new_size), accesses + new_accesses))
                segments[index + 1] = MemorySegmentAccess(
                    MemorySegment(new_address + new_size, size - new_size), accesses)

                return first_found_index if first_found_index is not None else index
            # case 4
            # |--------------|
            # |--------------|
            elif new_address + new_size == address + size:
                segments[index] = MemorySegmentAccess(MemorySegment(new_address, new_size), accesses + new_accesses)

                return first_found_index if first_found_index is not None else index
            # case 5
            # |--------------|
            # |------------------|
            # new_address + new_size > address + size
            else:
                # Current segment's accesses are augmented by new segment's accesses
                segments[index] = MemorySegmentAccess(MemorySegment(address, size), accesses + new_accesses)

                if first_found_index is None:
                    first_found_index = index

                # The different part of the new segment will be handled in the next iteration
                #
                # Since:
                #  - `new_address == address` and `new_address + new_size > address + size`, so
                #  - `new_size > size`
                # then invariant `new_size > 0`
                new_address = address + size
                new_size = new_size - size
                index += 1
        # new_address > address
        else:
            # case 6
            # |--------------|
            #                    |-----------|
            if new_address >= address + size:
                index += 1
            # case 7
            # |--------------|
            #            |-----------|
            else:
                # Split the current segment into:
                #  - the different part
                segments[index] = MemorySegmentAccess(
                    MemorySegment(address, new_address.offset - address.offset), accesses)
                #  - the common part
                segments.insert(index + 1, MemorySegmentAccess(
                    MemorySegment(new_address, address.offset + size - new_address.offset), accesses))

                # The `new_segment` will be handled in the next iteration by either case (3), (4) or (5)
                index += 1

    segments.append(MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
    index = len(segments) - 1

    return first_found_index if first_found_index is not None else index
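
# A small illustration of the splitting behaviour (hypothetical addresses and access payloads, not
# part of the original notebook): inserting two overlapping segments yields three non-overlapping
# segments whose accesses are merged on the common part.
#
#     segs: List[MemorySegmentAccess] = []
#     insert_memory_segment_access(
#         MemorySegmentAccess(MemorySegment(LogicalAddress(0x1000), 8), ['a']), segs, None)
#     insert_memory_segment_access(
#         MemorySegmentAccess(MemorySegment(LogicalAddress(0x1004), 8), ['b']), segs, None)
#     # segs now holds:
#     #     [0x1000, 0x1004) -> ['a']
#     #     [0x1004, 0x1008) -> ['a', 'b']
#     #     [0x1008, 0x100c) -> ['b']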


ThreadId = int


# Get the common memory accesses of two threads of the same process.
# Preconditions on both parameters:
#  1. The list of MemorySegmentAccess is sorted by address and size
#  2. For each MemorySegmentAccess access, `access.segment.size > 0`
# Each thread has a sorted (by address and size) memory segment access list; segments are not empty (i.e `size > 0`).
def get_common_memory_accesses(
    first_thread_segment_accesses: Tuple[ThreadId,
                                         List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]],
    second_thread_segment_accesses: Tuple[ThreadId,
                                          List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]]) \
        -> List[Tuple[MemorySegment,
                Tuple[List[Tuple[Transition, MemoryAccessOperation]],
                      List[Tuple[Transition, MemoryAccessOperation]]]]]:

    (first_thread_id, first_segment_accesses) = first_thread_segment_accesses
    (second_thread_id, second_segment_accesses) = second_thread_segment_accesses

    # Merge these lists into a new list
    merged_accesses: List[MemorySegmentAccess[Tuple[ThreadId, Transition, MemoryAccessOperation]]] = []
    i = j = 0
    while i < len(first_segment_accesses) and j < len(second_segment_accesses):
        (first_segment, first_accesses) = (first_segment_accesses[i].segment, first_segment_accesses[i].accesses)
        (second_segment, second_accesses) = (second_segment_accesses[j].segment, second_segment_accesses[j].accesses)

        first_threaded_accesses = [
            (first_thread_id, trans, mem_operation) for (trans, mem_operation) in first_accesses]
        second_threaded_accesses = [
            (second_thread_id, trans, mem_operation) for (trans, mem_operation) in second_accesses]

        (first_address, first_size) = (first_segment.address, first_segment.size)
        (second_address, second_size) = (second_segment.address, second_segment.size)

        if (first_address, first_size) < (second_address, second_size):
            merged_accesses.append(MemorySegmentAccess(first_segment, first_threaded_accesses))
            i += 1
        elif (first_address, first_size) > (second_address, second_size):
            merged_accesses.append(MemorySegmentAccess(second_segment, second_threaded_accesses))
            j += 1
        else:
            merged_accesses.append(MemorySegmentAccess(
                first_segment, first_threaded_accesses + second_threaded_accesses))
            i += 1
            j += 1

    while i < len(first_segment_accesses):
        (first_segment, first_accesses) = (first_segment_accesses[i].segment, first_segment_accesses[i].accesses)
        first_threaded_accesses = [
            (first_thread_id, trans, mem_operation) for (trans, mem_operation) in first_accesses]
        merged_accesses.append(MemorySegmentAccess(first_segment, first_threaded_accesses))
        i += 1

    while j < len(second_segment_accesses):
        (second_segment, second_accesses) = (second_segment_accesses[j].segment, second_segment_accesses[j].accesses)
        second_threaded_accesses = [
            (second_thread_id, trans, mem_operation) for (trans, mem_operation) in second_accesses]
        merged_accesses.append(MemorySegmentAccess(second_segment, second_threaded_accesses))
        j += 1

    # The merged list needs to be segmented again to handle the case of overlapping segments coming from different
    # threads.
    # We start from an empty list `refined_accesses`, and gradually insert new segments into it.
    #
    # Since the list merged_accesses is already sorted, the position where a segment is inserted is always larger
    # than or equal to the insertion position of the previous segment. We can therefore use the insertion position
    # of a segment as the starting position when looking for the position to insert the next segment.
    #
    # Though the complexity of `insert_memory_segment_access` is O(n) in the average case, the insertions in the loop
    # below mostly happen at the end of the list (with complexity O(1)). The complexity of the loop is still O(n).
    refined_accesses: List[MemorySegmentAccess[Tuple[ThreadId, Transition, MemoryAccessOperation]]] = []
    last_inserted_index = None
    for refined_access in merged_accesses:
        last_inserted_index = insert_memory_segment_access(refined_access, refined_accesses, last_inserted_index)

    common_accesses = []
    for refined_access in refined_accesses:
        first_thread_rws = []
        second_thread_rws = []

        for (thread_id, transition, operation) in refined_access.accesses:
            if thread_id == first_thread_id:
                first_thread_rws.append((transition, operation))
            else:
                second_thread_rws.append((transition, operation))

        if first_thread_rws and second_thread_rws:
            common_accesses.append((refined_access.segment, (first_thread_rws, second_thread_rws)))

    return common_accesses


# Return True if the transition should be excluded by the caller for data race detection.
def transition_excluded_by_heuristics(sco: RuntimeHelper, trans: Transition, blacklist: Set[int]) -> bool:
    if trans.id in blacklist:
        return True

    # Memory accesses of the first transition are considered non-protected
    try:
        trans_ctxt = trans.context_before()
    except IndexError:
        return False

    trans_location = trans_ctxt.ossi.location()
    if trans_location is None:
        return False

    if trans_location.binary is None:
        return False

    trans_binary_path = trans_location.binary.path
    process_thread_core_dlls = [
        'c:/windows/system32/ntdll.dll', 'c:/windows/system32/csrsrv.dll', 'c:/windows/system32/basesrv.dll'
    ]
    if trans_binary_path not in process_thread_core_dlls:
        return False

    if trans_location.symbol is None:
        return False

    symbols = sco.symbols

    # Accesses of the synchronization APIs themselves are not counted.
    thread_sync_apis = [
        symbols[name] for name in [
            'RtlInitializeCriticalSection', 'RtlInitializeCriticalSectionEx',
            'RtlInitializeCriticalSectionAndSpinCount',
            'RtlEnterCriticalSection', 'RtlTryEnterCriticalSection', 'RtlLeaveCriticalSection'
        ]
        if symbols[name] is not None
    ]
    if trans_location.symbol in thread_sync_apis:
        blacklist.add(trans.id)
        return True

    trans_frames = trans_ctxt.stack.frames()

    # Accesses of thread create/shutdown are not counted
    thread_create_apis = [
        api for api in [
            symbols['LdrpInitializeThread'], symbols['BaseSrvCreateThread'], symbols['CsrCreateThread']
        ]
        if api is not None
    ]
    thread_shutdown_apis = [
        api for api in [
            symbols['CsrDereferenceThread'], symbols['CsrThreadRefcountZero'], symbols['LdrShutdownThread']
        ]
        if api is not None
    ]
    for frame in trans_frames:
        frame_ctxt = frame.first_context
        frame_location = frame_ctxt.ossi.location()

        if frame_location is None:
            continue

        if frame_location.binary is None:
            continue

        if frame_location.binary.path not in process_thread_core_dlls:
            continue

        if frame_location.symbol is None:
            continue

        if (
            frame_location.symbol in thread_create_apis
            or frame_location.symbol in thread_shutdown_apis
            or frame_location.symbol in thread_sync_apis
        ):
            blacklist.add(trans.id)
            return True

    return False


def get_threads_locks_unlocks(trace: RuntimeHelper, binary: Optional[str],
                              process_ranges: List[Tuple[Context, Context]]) \
        -> Dict[ThreadId, List[Tuple[SynchronizationAction, Context]]]:

    threads_ranges: Dict[int, List[Tuple[Context, Context]]] = {}
    for (ctxt_lo, ctxt_hi) in process_ranges:
        tid = RuntimeHelper.thread_id(ctxt_lo)
        if tid not in threads_ranges:
            threads_ranges[tid] = []

        threads_ranges[tid].append((ctxt_lo, ctxt_hi))

    threads_locks_unlocks: Dict[ThreadId, List[Tuple[SynchronizationAction, Context]]] = {}
    for tid, ranges in threads_ranges.items():
        threads_locks_unlocks[tid] = get_locks_unlocks(trace, binary, ranges)

    return threads_locks_unlocks


# Get all segmented memory accesses of each thread given a list of context ranges.
# Return a map: thread_id -> list(((mem_addr, mem_size), list((transition, mem_operation))))
def get_threads_segmented_memory_accesses(trace: RuntimeHelper, process_ranges: List[Tuple[Context, Context]]) \
        -> Dict[ThreadId, List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]]:
    sorted_threads_segmented_memory_accesses: Dict[
        ThreadId,
        List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]
    ] = {}
    for (ctxt_lo, ctxt_hi) in process_ranges:
        tid = RuntimeHelper.thread_id(ctxt_lo)
        if tid not in sorted_threads_segmented_memory_accesses:
            sorted_threads_segmented_memory_accesses[tid] = []

        for access in trace.get_memory_accesses(ctxt_lo, ctxt_hi):
            access_virtual_address = access.virtual_address
            if access_virtual_address is None:
                # DMA access
                continue

            sorted_threads_segmented_memory_accesses[tid].append(MemorySegmentAccess(
                MemorySegment(access_virtual_address, access.size),
                [(access.transition, access.operation)]))

    # The memory segment accesses of each thread are sorted by address and size of segments, which helps improve
    # the performance of looking for common accesses of two threads.
    # The complexity of sorting is O(n * log n).
    # Note that segments can still overlap.
    for tid in sorted_threads_segmented_memory_accesses.keys():
        sorted_threads_segmented_memory_accesses[tid].sort(key=lambda x: (x.segment.address, x.segment.size))

    threads_segmented_memory_accesses: Dict[
        ThreadId,
        List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]
    ] = {}
    # The non-overlapping memory segment access list of each thread is built by:
    #   - starting from an empty list
    #   - gradually inserting segments into the list using `insert_memory_segment_access`
    # Since the original segment lists are sorted, the complexity of the build is O(n). The total complexity of
    # constructing the non-overlapping memory segment accesses of each thread is therefore O(n * log n).
    for tid in sorted_threads_segmented_memory_accesses.keys():
        threads_segmented_memory_accesses[tid] = []
        last_mem_acc_index = None
        for seg_mem_acc in sorted_threads_segmented_memory_accesses[tid]:
            last_mem_acc_index = insert_memory_segment_access(
                seg_mem_acc, threads_segmented_memory_accesses[tid], last_mem_acc_index
            )

    return threads_segmented_memory_accesses


TransitionId = int
InstructionPointer = int
LockHandle = int


def detect_data_race(trace: RuntimeHelper, pid: int, binary: Optional[str],
                     begin_transition_id: Optional[TransitionId], end_transition_id: Optional[TransitionId]):

    def handlers_to_string(handlers: Iterable[int]):
        msgs = [f'{handle:#x}' for handle in handlers]
        message = ', '.join(msgs)
        return message

    def get_live_lock_handles(
        sco: RuntimeHelper, locks_unlocks: List[Tuple[SynchronizationAction, Context]], trans: Transition
    ) -> Tuple[List[Context], Set[LockHandle]]:
        locks = get_live_locks(sco, locks_unlocks, trans)
        return (locks, {RuntimeHelper.get_lock_handle(ctxt) for ctxt in locks})

    (first_context, last_context) = find_begin_end_context(trace, pid, binary, begin_transition_id, end_transition_id)
    process_ranges = list(find_process_usermode_ranges(trace, pid, first_context, last_context))

    threads_memory_accesses = get_threads_segmented_memory_accesses(trace, process_ranges)
    threads_locks_unlocks = get_threads_locks_unlocks(trace, binary, process_ranges)

    thread_ids: Set[ThreadId] = set()

    # map from a transition (of a given thread) to a tuple whose
    #  - first element is the list of contexts of the calls that took a lock (e.g. RtlEnterCriticalSection)
    #  - second element is the set of handles used by these calls
    cached_live_locks_handles: Dict[Tuple[ThreadId, TransitionId], Tuple[List[Context], Set[LockHandle]]] = {}

    cached_pcs: Dict[TransitionId, InstructionPointer] = {}

    filtered_trans_ids: Set[TransitionId] = set()

    for first_tid, first_thread_accesses in threads_memory_accesses.items():
        thread_ids.add(first_tid)
        for second_tid, second_thread_accesses in threads_memory_accesses.items():
            if second_tid in thread_ids:
                continue

            common_accesses = get_common_memory_accesses(
                (first_tid, first_thread_accesses), (second_tid, second_thread_accesses)
            )

            for (segment, (first_accesses, second_accesses)) in common_accesses:
                (segment_address, segment_size) = (segment.address, segment.size)

                race_pcs: Set[Tuple[InstructionPointer, InstructionPointer]] = set()

                for (first_transition, first_rw) in first_accesses:
                    if first_transition.id in cached_pcs:
                        first_pc = cached_pcs[first_transition.id]
                    else:
                        first_pc = first_transition.context_before().read(x64.rip)
                        cached_pcs[first_transition.id] = first_pc

                    if (first_tid, first_transition.id) not in cached_live_locks_handles:
                        critical_section_locks_unlocks = threads_locks_unlocks[first_tid]

                        cached_live_locks_handles[(first_tid, first_transition.id)] = get_live_lock_handles(
                            trace, critical_section_locks_unlocks, first_transition)

                    for (second_transition, second_rw) in second_accesses:
                        if second_transition.id in cached_pcs:
                            second_pc = cached_pcs[second_transition.id]
                        else:
                            second_pc = second_transition.context_before().read(x64.rip)
                            cached_pcs[second_transition.id] = second_pc

                        # the execution of an instruction is considered atomic
                        if first_pc == second_pc:
                            continue

                        if (second_tid, second_transition.id) not in cached_live_locks_handles:
                            critical_section_locks_unlocks = threads_locks_unlocks[second_tid]

                            cached_live_locks_handles[(second_tid, second_transition.id)] = get_live_lock_handles(
                                trace, critical_section_locks_unlocks, second_transition)

                        # Skip if both are the same operation (i.e. read/read or write/write)
                        if first_rw == second_rw:
                            continue

                        first_is_write = first_rw == MemoryAccessOperation.Write
                        second_is_write = second_rw == MemoryAccessOperation.Write

                        if first_transition < second_transition:
                            before_mem_opr_str = 'write' if first_is_write else 'read'
                            after_mem_opr_str = 'write' if second_is_write else 'read'

                            before_trans = first_transition
                            after_trans = second_transition

                            before_tid = first_tid
                            after_tid = second_tid

                        else:
                            before_mem_opr_str = 'write' if second_is_write else 'read'
                            after_mem_opr_str = 'write' if first_is_write else 'read'

                            before_trans = second_transition
                            after_trans = first_transition

                            before_tid = second_tid
                            after_tid = first_tid

                        # Duplicated
                        if (first_pc, second_pc) in race_pcs:
                            continue

                        race_pcs.add((first_pc, second_pc))

                        # ===== No data race

                        message = f'No data race during {after_mem_opr_str} at \
[{segment_address.offset:#x}, {segment_size}] (transition #{after_trans.id}) by thread {after_tid}\n\
with a previous {before_mem_opr_str} (transition #{before_trans.id}) by thread {before_tid}'

                        # There is a common critical section synchronizing the accesses
                        live_critical_sections: Dict[ThreadId, Set[LockHandle]] = {
                            first_tid: cached_live_locks_handles[(first_tid, first_transition.id)][1],
                            second_tid: cached_live_locks_handles[(second_tid, second_transition.id)][1]
                        }

                        common_critical_sections = live_critical_sections[first_tid].intersection(
                            live_critical_sections[second_tid]
                        )
                        if common_critical_sections:
                            if not hide_synchronized_accesses:
                                handlers_str = handlers_to_string(common_critical_sections)
                                print(f'{message}: synchronized by critical section(s) ({handlers_str}).\n')
                            continue

                        # Accesses are excluded since they are in thread create/shutdown
                        are_excluded = transition_excluded_by_heuristics(
                            trace, first_transition, filtered_trans_ids
                        ) and transition_excluded_by_heuristics(trace, second_transition, filtered_trans_ids)
                        if are_excluded:
                            if not hide_synchronized_accesses:
                                print(f'{message}: shared accesses in create/shutdown threads')
                            continue

                        # ===== Data race
                        # The accesses are not synchronized by any common critical section: report the possible
                        # data race, showing the un-synchronized accesses in the following order
                        #   1. at least one of them is protected by one or several locks
                        #   2. neither of them is protected by any lock

                        message = f'Possible data race during {after_mem_opr_str} at \
[{segment_address.offset:#x}, {segment_size}] (transition #{after_trans.id}) by thread {after_tid}\n\
with a previous {before_mem_opr_str} (transition #{before_trans.id}) by thread {before_tid}'

                        if live_critical_sections[first_tid] or live_critical_sections[second_tid]:
                            print(f'{message}.')
                            for tid in {first_tid, second_tid}:
                                if live_critical_sections[tid]:
                                    handlers_str = handlers_to_string(live_critical_sections[tid])
                                    print(f'\tCritical section handle(s) used by {tid}: {handlers_str}\n')
                        elif not suppress_unknown_primitives:
                            print(f'{message}: no critical section used.\n')


# %%
trace = RuntimeHelper(host, port)
detect_data_race(trace, pid, binary, begin_trans_id, end_trans_id)

Use-after-Free vulnerabilities detection

This notebook allows searching for potential Use-after-Free vulnerabilities in a REVEN trace.

Prerequisites

  • This notebook should be run in a Jupyter notebook server equipped with a REVEN 2 Python kernel. REVEN comes with a Jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • This notebook depends on capstone being installed in the REVEN 2 Python kernel. To install capstone in the current environment, please execute the capstone cell of this notebook.
  • This notebook requires the Memory History resource for your target scenario.

Running the notebook

Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.

Note

Although the script is designed to limit false positive results, you may still encounter a few. Please check the results and feel free to report any issue, be it a false positive or a false negative ;-).

Source

# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Use-after-Free vulnerabilities detection
#
# This notebook allows searching for potential Use-after-Free vulnerabilities in a REVEN trace.
#
# ## Prerequisites
#
# - This notebook should be run in a jupyter notebook server equipped with a REVEN 2 python kernel.
#   REVEN comes with a jupyter notebook server accessible with the `Open Python` button in the `Analyze` page of any
#   scenario.
#
# - This notebook depends on `capstone` being installed in the REVEN 2 python kernel.
#   To install capstone in the current environment, please execute the [capstone cell](#Capstone-Installation) of this
#   notebook.
#
# - This notebook requires the Memory History resource for your target scenario.
#
#
# ## Running the notebook
#
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.
#
#
# ## Note
#
# Although the script is designed to limit false positive results, you may still encounter a few. Please check
# the results and feel free to report any issue, be it a false positive or a false negative ;-).

# %% [markdown]
# # Capstone Installation
#
# Check for capstone's presence. If missing, attempt to get it from pip

# %%
try:
    import capstone
    print("capstone already installed")
except ImportError:
    print("Could not find capstone, attempting to install it from pip")
    import sys
    import subprocess

    command = [f"{sys.executable}", "-m", "pip", "install", "capstone"]
    p = subprocess.run(command)

    if int(p.returncode) != 0:
        raise RuntimeError("Error installing capstone")

    import capstone  # noqa
    print("Successfully installed capstone")

# %% [markdown]
# # Parameters

# %%
# Server connection

# Host of the REVEN server running the scenario.
# When running this notebook from the Project Manager, '127.0.0.1' should be the correct value.
reven_backend_host = '127.0.0.1'

# Port of the REVEN server running the scenario.
# After starting a REVEN server on your scenario, you can get its port on the Analyze page of that scenario.
reven_backend_port = 13370


# Range control

# First transition considered for the detection of allocation/deallocation pairs
# If set to None, then the first transition of the trace
uaf_from_tr = None
# uaf_from_tr = 3984021008  # ex: vlc scenario
# uaf_from_tr = 8655210  # ex: bksod scenario

# First transition **not** considered for the detection of allocation/deallocation pairs
# If set to None, then the last transition of the trace
uaf_to_tr = None
# uaf_to_tr = 3986760400  # ex: vlc scenario
# uaf_to_tr = 169673257  # ex: bksod scenario


# Filter control

# Beware that, when filtering, if an allocation happens in the specified process and/or binary,
# the script will fail if the deallocation happens in a different process and/or binary.
# This issue should only happen for allocations in the kernel.

# Specify on which PID the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the PID.
faulty_process_pid = None
# faulty_process_pid = 2620  # ex: vlc scenario
# faulty_process_pid = None  # ex: bksod
# We can't filter on process for BKSOD because the allocation happens in one process and the free in another:
# termdd.sys and the process svchost.exe (pid: 1004) use kernel workers to free some resources,
# so the allocs are done in the process svchost.exe and some of the frees in the "System" process (pid: 4)

# Specify on which process name the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the process name.
# If both a process PID and a process name are specified, please make sure that they both
# refer to the same process, otherwise all allocations will be filtered and no results
# will be produced.
faulty_process_name = None
# faulty_process_name = "vlc.exe"  # ex: vlc scenario
# faulty_process_name = None  # ex: bksod scenario

# Specify on which binary name the allocation/deallocation pairs should be kept.
# Set a value of None to not filter on the binary name.
# Only allocation/deallocation taking place in the binaries whose filename,
# path or name contain the specified value are kept.
# If filtering on both a process and a binary, please make sure that there are
# allocations taking place in that binary in the selected process, otherwise all
# allocations will be filtered and no result will be produced.
faulty_binary = None
# faulty_binary = "libmkv_plugin.dll"  # ex: vlc scenario
# faulty_binary = "termdd.sys"  # ex: bksod scenario


# Address control

# Specify a **physical** address suspected of being faulty here,
# to only test UaF for this specific address, instead of all (filtered) allocations.
# The address should still be returned by an allocation/deallocation pair.
# To get a physical address from a virtual address, find a context where the address
# is mapped, then use `virtual_address.translate(ctx)`.
uaf_faulty_physical_address = None
# uaf_faulty_physical_address = 0x5a36cd20  # ex: vlc scenario
# uaf_faulty_physical_address = 0x7fb65010  # ex: bksod scenario
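# For instance, a hypothetical translation (the transition number and virtual address below are
# placeholders, not values from a real scenario):
# ctx = server.trace.transition(1000000).context_before()
# physical = reven2.address.LogicalAddress(0x7fff12340000).translate(ctx)
# uaf_faulty_physical_address = physical.offset if physical is not None else None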


# Allocator control

# The script can use two allocators to find allocation/deallocation pairs.
# The following booleans allow enabling the search for allocations by these
# allocators for a scenario.
# Generally it is expected to have only a single allocator enabled for a given
# scenario.

# To add your own allocator, please look at how the two provided allocators were
# added.

# Whether or not to look for malloc/free allocation/deallocation pairs.
search_malloc = True
# search_malloc = True  # ex: vlc scenario, user space scenario
# search_malloc = False  # ex: bksod scenario

# Whether or not to look for ExAllocatePoolWithTag/ExFreePoolWithTag
# allocation/deallocation pairs.
# This allocator is used by the Windows kernel.
search_pool_allocation = False
# search_pool_allocation = False  # ex: vlc scenario
# search_pool_allocation = True  # ex: bksod scenario, kernel scenario


# Taint control

# Technical parameter: number of accesses in a taint after which the script gives up on
# that particular taint.
# Since long taints degrade the performance of the script significantly, it is recommended to
# give up on a taint after it exceeds a certain number of operations.
# If you experience missing results, try increasing the value.
# If you experience a very long runtime for the script, try decreasing the value.
# The default value should be appropriate for most cases.
taint_max_length = 100000

# %%
import itertools  # noqa: E402
from collections import OrderedDict  # noqa: E402
from typing import Dict, List  # noqa: E402

import reven2  # noqa: E402
import reven2.preview.taint  # noqa: E402

# %%
# Python script to connect to this scenario:
server = reven2.RevenServer(reven_backend_host, reven_backend_port)
print(server.trace.transition_count)


# %%
class MemoryRange:
    page_size = 4096
    page_mask = ~(page_size - 1)

    def __init__(self, logical_address, size):
        self.logical_address = logical_address
        self.size = size

        self.pages = [{
            'logical_address': self.logical_address,
            'size': self.size,
            'physical_address': None,
            'ctx_physical_address_mapped': None,
        }]

        # Compute the pages
        while (((self.pages[-1]['logical_address'] & ~MemoryRange.page_mask) + self.pages[-1]['size'] - 1)
                >= MemoryRange.page_size):
            # Compute the size of the new page
            new_page_size = ((self.pages[-1]['logical_address'] & ~MemoryRange.page_mask) + self.pages[-1]['size']
                             - MemoryRange.page_size)

            # Reduce the size of the previous page and create the new one
            self.pages[-1]['size'] -= new_page_size
            self.pages.append({
                'logical_address': self.pages[-1]['logical_address'] + self.pages[-1]['size'],
                'size': new_page_size,
                'physical_address': None,
                'ctx_physical_address_mapped': None,
            })

    def try_translate_first_page(self, ctx):
        if self.pages[0]['physical_address'] is not None:
            return True

        physical_address = reven2.address.LogicalAddress(self.pages[0]['logical_address']).translate(ctx)

        if physical_address is None:
            return False

        self.pages[0]['physical_address'] = physical_address.offset
        self.pages[0]['ctx_physical_address_mapped'] = ctx

        return True

    def try_translate_all_pages(self, ctx):
        return_value = True

        for page in self.pages:
            if page['physical_address'] is not None:
                continue

            physical_address = reven2.address.LogicalAddress(page['logical_address']).translate(ctx)

            if physical_address is None:
                return_value = False
                continue

            page['physical_address'] = physical_address.offset
            page['ctx_physical_address_mapped'] = ctx

        return return_value

    def is_physical_address_range_in_translated_pages(self, physical_address, size):
        for page in self.pages:
            if page['physical_address'] is None:
                continue

            if (
                physical_address >= page['physical_address']
                and physical_address + size <= page['physical_address'] + page['size']
            ):
                return True

        return False

    def __repr__(self):
        return "MemoryRange(0x%x, %d)" % (self.logical_address, self.size)

# Utility to retrieve the physical address of a buffer that was just allocated
#     - ctx should be the context where the allocated address is located in `rax`
#     - memory_range should be the memory range of the newly allocated buffer
#
# We use the translate API to translate the address, but just after the allocation it is sometimes
# not mapped yet. To work around this, we taint `rax` forward and try to translate the address at
# each access of the taint.


def translate_first_page_of_allocation(ctx, memory_range):
    if memory_range.try_translate_first_page(ctx):
        return

    tainter = reven2.preview.taint.Tainter(server.trace)
    taint = tainter.simple_taint(
        tag0="rax",
        from_context=ctx,
        to_context=None,
        is_forward=True
    )

    for access in taint.accesses(changes_only=False).all():
        if memory_range.try_translate_first_page(access.transition.context_after()):
            taint.cancel()
            return

    raise RuntimeError("Couldn't find the physical address of the first page")


# %%
class AllocEvent:
    def __init__(self, memory_range, tr_begin, tr_end):
        self.memory_range = memory_range
        self.tr_begin = tr_begin
        self.tr_end = tr_end


class FreeEvent:
    def __init__(self, logical_address, tr_begin, tr_end):
        self.logical_address = logical_address
        self.tr_begin = tr_begin
        self.tr_end = tr_end


def retrieve_events_for_symbol(
    alloc_dict,
    event_class,
    symbol,
    retrieve_event_info,
    event_filter=None,
):
    for ctx in server.trace.search.symbol(
        symbol,
        from_context=None if uaf_from_tr is None else server.trace.context_before(uaf_from_tr),
        to_context=None if uaf_to_tr is None else server.trace.context_before(uaf_to_tr)
    ):
        # We don't want hits on exceptions (code page faults, hardware interrupts, etc.)
        if ctx.transition_after().exception is not None:
            continue

        previous_location = (ctx - 1).ossi.location()
        previous_process = (ctx - 1).ossi.process()

        # Filter by process pid/process name/binary name

        # Filter by process pid
        if faulty_process_pid is not None and previous_process.pid != faulty_process_pid:
            continue

        # Filter by process name
        if faulty_process_name is not None and previous_process.name != faulty_process_name:
            continue

        # Filter by binary name / filename / path
        if faulty_binary is not None and faulty_binary not in [
            previous_location.binary.name,
            previous_location.binary.filename,
            previous_location.binary.path
        ]:
            continue

        # Filter the event with the argument filter
        if event_filter is not None:
            if event_filter(ctx.ossi.location(), previous_location):
                continue

        # Retrieve the call/ret
        # The heuristic is that the ret is the end of our function
        #   - If the call is inlined it should be at the end of the caller function, so the ret is the ret of our
        #     function
        #   - If the call isn't inlined, the ret should be the ret of our function
        ctx_call = next(ctx.stack.frames()).creation_transition.context_after()

        tr_ret = ctx_call.transition_before().find_inverse()
        # Finding the inverse operation can fail
        # for instance if the end of the allocation function was not recorded in the trace
        if tr_ret is None:
            continue
        ctx_ret = tr_ret.context_before()

        # Build the event by reading the needed registers
        if event_class == AllocEvent:
            current_address, size = retrieve_event_info(ctx, ctx_ret)

            # Filter out failed allocations
            if current_address == 0x0:
                continue

            memory_range = MemoryRange(current_address, size)
            translate_first_page_of_allocation(ctx_ret, memory_range)

            if memory_range.pages[0]['physical_address'] not in alloc_dict:
                alloc_dict[memory_range.pages[0]['physical_address']] = []

            alloc_dict[memory_range.pages[0]['physical_address']].append(
                AllocEvent(
                    memory_range,
                    ctx.transition_after(), ctx_ret.transition_after()
                )
            )
        elif event_class == FreeEvent:
            current_address = retrieve_event_info(ctx, ctx_ret)

            # Filter out frees of NULL
            if current_address == 0x0:
                continue

            current_physical_address = reven2.address.LogicalAddress(current_address).translate(ctx).offset

            if current_physical_address not in alloc_dict:
                alloc_dict[current_physical_address] = []

            alloc_dict[current_physical_address].append(
                FreeEvent(
                    current_address,
                    ctx.transition_after(), ctx_ret.transition_after()
                )
            )
        else:
            raise RuntimeError("Unknown event class: %s" % event_class.__name__)

# %%
# %%time


alloc_dict: Dict = {}


# Basic functions to retrieve the arguments
# They work for the allocation/free functions but won't work for all functions,
# in particular because on x86 we don't handle the size of the arguments
# nor whether they are pushed left to right or right to left
def retrieve_first_argument(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rcx)
    else:
        esp = ctx.read(reven2.arch.x64.esp)
        return ctx.read(reven2.address.LogicalAddress(esp + 4, reven2.arch.x64.ss), 4)


def retrieve_second_argument(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rdx)
    else:
        esp = ctx.read(reven2.arch.x64.esp)
        return ctx.read(reven2.address.LogicalAddress(esp + 8, reven2.arch.x64.ss), 4)


def retrieve_return_value(ctx):
    if ctx.is64b():
        return ctx.read(reven2.arch.x64.rax)
    else:
        return ctx.read(reven2.arch.x64.eax)


def retrieve_alloc_info_with_size_as_first_argument(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument(ctx_begin)
    )


def retrieve_alloc_info_with_size_as_second_argument(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_second_argument(ctx_begin)
    )


def retrieve_alloc_info_for_calloc(ctx_begin, ctx_end):
    return (
        retrieve_return_value(ctx_end),
        retrieve_first_argument(ctx_begin) * retrieve_second_argument(ctx_begin)
    )


def retrieve_free_info_with_address_as_first_argument(ctx_begin, ctx_end):
    return retrieve_first_argument(ctx_begin)


if search_malloc:
    def filter_in_realloc(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name == "realloc"

    # Search for allocations with malloc
    for symbol in server.ossi.symbols(r'^_?malloc$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_first_argument,
                                   filter_in_realloc)

    # Search for allocations with calloc
    for symbol in server.ossi.symbols(r'^_?calloc(_crt)?$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_for_calloc)

    # Search for deallocations with free
    for symbol in server.ossi.symbols(r'^_?free$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument,
                                   filter_in_realloc)

    # Search for re-allocations with realloc
    for symbol in server.ossi.symbols(r'^_?realloc$', binary_hint=r'msvcrt.dll'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_second_argument)
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument)

if search_pool_allocation:
    # Search for allocations with ExAllocatePool...
    def filter_ex_allocate_pool(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name.startswith("ExAllocatePool")

    for symbol in server.ossi.symbols(r'^ExAllocatePool', binary_hint=r'ntoskrnl.exe'):
        retrieve_events_for_symbol(alloc_dict, AllocEvent, symbol, retrieve_alloc_info_with_size_as_second_argument,
                                   filter_ex_allocate_pool)

    # Search for deallocations with ExFreePool...
    def filter_ex_free_pool(location, caller_location):
        return location.binary == caller_location.binary and caller_location.symbol.name == "ExFreePool"

    for symbol in server.ossi.symbols(r'^ExFreePool', binary_hint=r'ntoskrnl.exe'):
        retrieve_events_for_symbol(alloc_dict, FreeEvent, symbol, retrieve_free_info_with_address_as_first_argument,
                                   filter_ex_free_pool)

# Sort the events per address and event type
for physical_address in alloc_dict.keys():
    alloc_dict[physical_address] = list(sorted(
        alloc_dict[physical_address],
        key=lambda event: (event.tr_begin.id, 0 if isinstance(event, FreeEvent) else 1)
    ))

# Sort the dict by address
alloc_dict = OrderedDict(sorted(alloc_dict.items()))


# %%
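# Yield (AllocEvent, FreeEvent) pairs from the chronologically sorted events of a single
# physical address. When an `errors` list is provided, a message is appended whenever two
# consecutive allocs or two consecutive frees are encountered.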
def get_alloc_free_pairs(events, errors=None):
    previous_event = None
    for event in events:
        if isinstance(event, AllocEvent):
            if previous_event is None:
                pass
            elif isinstance(previous_event, AllocEvent):
                if errors is not None:
                    errors.append("Two consecutives allocs found")
        elif isinstance(event, FreeEvent):
            if previous_event is None:
                continue
            elif isinstance(previous_event, FreeEvent):
                if errors is not None:
                    errors.append("Two consecutives frees found")
            elif isinstance(previous_event, AllocEvent):
                yield (previous_event, event)
        else:
            assert 0, ("Unknown event type: %s" % type(event))

        previous_event = event


# %%
# %%time

# Basic checks of the events

for physical_address, events in alloc_dict.items():
    for event in events:
        if not isinstance(event, AllocEvent) and not isinstance(event, FreeEvent):
            raise RuntimeError("Unknown event type: %s" % type(event))

    errors: List[str] = []
    for (alloc_event, free_event) in get_alloc_free_pairs(events, errors):
        # Check the uniformity of the logical address between the alloc and the free
        if alloc_event.memory_range.logical_address != free_event.logical_address:
            errors.append("Phys:0x%x: Alloc #%d - Free #%d with different logical address: 0x%x != 0x%x" % (
                physical_address,
                alloc_event.tr_begin.id, free_event.tr_begin.id,
                alloc_event.memory_range.logical_address, free_event.logical_address))

        # Check size of 0x0
        if alloc_event.memory_range.size == 0x0 or alloc_event.memory_range.size is None:
            errors.append("Phys:0x%x: Alloc #%d - Free #%d with weird size %s" % (
                physical_address,
                alloc_event.tr_begin.id, free_event.tr_begin.id,
                alloc_event.memory_range.size))

    if len(errors) > 0:
        print("Phys:0x%x: Error(s) detected:" % (physical_address))
        for error in errors:
            print("    - %s" % error)

# %%
# Print the events

for physical_address, events in alloc_dict.items():
    print("Phys:0x%x" % (physical_address))
    print("    Events:")
    for event in events:
        if isinstance(event, AllocEvent):
            print("        - Alloc at #%d (0x%x of size 0x%x)" % (event.tr_begin.id,
                  event.memory_range.logical_address, event.memory_range.size))
        elif isinstance(event, FreeEvent):
            print("        - Free at #%d (0x%x)" % (event.tr_begin.id, event.logical_address))

    print("    Pairs:")
    for (alloc_event, free_event) in get_alloc_free_pairs(events):
        print("    - Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (alloc_event.tr_begin.id,
              alloc_event.memory_range.logical_address, alloc_event.memory_range.size, free_event.tr_begin.id,
              free_event.logical_address))

    print()


# %%
# This function is used to ignore the changes made by the `free` in the taint of the address,
# as the `free` stores the address in some internal structures that are reused to allocate
# other addresses, leading to false positives

def start_alloc_address_taint(server, alloc_event, free_event):
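    # Taint the allocated address from the end of the alloc up to the free, then restart a
    # taint after the free from the still-tainted register/memory slices, excluding state
    # that the free itself clobbers or that lives inside the freed buffer.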
    # Setup the first taint [alloc; free]
    tainter = reven2.preview.taint.Tainter(server.trace)
    taint = tainter.simple_taint(
        tag0="rax" if alloc_event.tr_end.mode == reven2.trace.Mode.X86_64 else "eax",
        from_context=alloc_event.tr_end.context_before(),
        to_context=free_event.tr_begin.context_before() + 1,
        is_forward=True
    )

    # Setup the second taint [free; [
    state_before_free = taint.state_at(free_event.tr_begin.context_before())

    # `rcx` is in theory lost during the execution of the free
    # `rsp` is not useful even if it is tainted
    tag0_regs = filter(
        lambda x: x[0].register.name not in ['rcx', 'rsp'],
        state_before_free.tainted_registers()
    )

    # We don't want to keep memory inside the allocated object as accessing it will trigger a UAF anyway
    # It is also used to remove some false-positive because the memory will be used by the alloc/free functions
    # TODO: Handle addresses different than PhysicalAddress (is that even possible?)
    tag0_mems = filter(
        lambda x: not alloc_event.memory_range.is_physical_address_range_in_translated_pages(
            x[0].address.offset, x[0].size),
        state_before_free.tainted_memories()
    )

    # Only keep the slices
    tag0 = map(
        lambda x: x[0],
        itertools.chain(tag0_regs, tag0_mems)
    )

    tainter = reven2.preview.taint.Tainter(server.trace)
    return tainter.simple_taint(
        tag0=list(tag0),
        from_context=free_event.tr_end.context_before(),
        to_context=None,
        is_forward=True
    )


# %%
# Capstone utilities
def get_reven_register_from_name(name):
    for reg in reven2.arch.helpers.x64_registers():
        if reg.name == name:
            return reg

    raise RuntimeError("Unknown register: %s" % name)


def read_reg(tr, reg):
    if reg in [reven2.arch.x64.rip, reven2.arch.x64.eip]:
        return tr.pc
    else:
        return tr.context_before().read(reg)


def compute_dereferenced_address(tr, cs_insn, cs_op):
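    # Recompute the linear address referenced by a Capstone memory operand
    # (base + index * scale + displacement) using the register values at the given transition.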
    dereferenced_address = 0

    if cs_op.value.mem.base != 0:
        base_reg = get_reven_register_from_name(cs_insn.reg_name(cs_op.value.mem.base))
        dereferenced_address += read_reg(tr, base_reg)

    if cs_op.value.mem.index != 0:
        index_reg = get_reven_register_from_name(cs_insn.reg_name(cs_op.value.mem.index))
        dereferenced_address += (cs_op.value.mem.scale * read_reg(tr, index_reg))

    dereferenced_address += cs_op.value.mem.disp

    return dereferenced_address & 0xFFFFFFFFFFFFFFFF


# %%
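# Analyze the alloc/free pairs recorded for one physical address: taint the freed pointer
# forward from the free, and report every memory operand that dereferences a tainted register
# inside the freed buffer. Returns the number of UAF accesses found.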
def uaf_analyze_function(physical_address, alloc_events):
    uaf_count = 0

    # Setup capstone
    md_64 = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_64)
    md_64.detail = True

    md_32 = capstone.Cs(capstone.CS_ARCH_X86, capstone.CS_MODE_32)
    md_32.detail = True

    errors = []
    for (alloc_event, free_event) in get_alloc_free_pairs(alloc_events, errors):
        # Get the memory accesses of the allocated block after the free
        # The optimization is disabled if we can't translate all the pages of the memory range
        mem_access = None
        mem_accesses = None

        if alloc_event.memory_range.try_translate_all_pages(free_event.tr_begin.context_before()):
            mem_accesses = reven2.util.collate(map(
                lambda page: server.trace.memory_accesses(
                    reven2.address.PhysicalAddress(page['physical_address']),
                    page['size'],
                    from_transition=free_event.tr_end,
                    to_transition=None,
                    is_forward=True,
                    operation=None
                ),
                alloc_event.memory_range.pages
            ), key=lambda access: access.transition)

            try:
                mem_access = next(mem_accesses)
            except StopIteration:
                continue

        else:
            print("Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (physical_address,
                  alloc_event.tr_begin.id, alloc_event.memory_range.logical_address, alloc_event.memory_range.size,
                  free_event.tr_begin.id, free_event.logical_address))
            print("    Warning: Memory history optimization disabled because we couldn't "
                  "translate all the memory pages")
            print()

        # Setup the slicing
        taint = start_alloc_address_taint(server, alloc_event, free_event)

        # Iterate on the slice
        access_count = 0
        for access in taint.accesses(changes_only=False).all():
            access_count += 1
            access_transition = access.transition

            if access_count > taint_max_length:
                print("Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (physical_address,
                      alloc_event.tr_begin.id, alloc_event.memory_range.logical_address, alloc_event.memory_range.size,
                      free_event.tr_begin.id, free_event.logical_address))
                print("    Warning: Allocation skipped: post-free taint stopped after %d accesses" % access_count)
                print()
                break

            if mem_accesses is not None:
                # Check that we have an access on the same transition as the taint access
                # if not, the memory operand won't be a UAF anyway, so we can skip it

                while mem_access.transition < access_transition:
                    try:
                        mem_access = next(mem_accesses)
                    except StopIteration:
                        break

                if mem_access.transition != access_transition:
                    continue

            md = md_64 if access_transition.mode == reven2.trace.Mode.X86_64 else md_32
            cs_insn = next(md.disasm(access_transition.instruction.raw, access_transition.instruction.size))

            # Skip `lea` instructions as they are not really memory reads/writes, and the taint
            # will propagate anyway so that we will see the dereference of the computed value
            if cs_insn.mnemonic == "lea":
                continue

            registers_in_state = {}
            for reg_slice, _ in access.state_before().tainted_registers():
                registers_in_state[reg_slice.register.name] = reg_slice

            for cs_op in cs_insn.operands:
                if cs_op.type != capstone.x86.X86_OP_MEM:
                    continue

                uaf_reg = None

                if cs_op.value.mem.base != 0:
                    base_reg_name = cs_insn.reg_name(cs_op.value.mem.base)
                    if base_reg_name in registers_in_state:
                        uaf_reg = registers_in_state[base_reg_name]

                if uaf_reg is None and cs_op.value.mem.index != 0:
                    index_reg_name = cs_insn.reg_name(cs_op.value.mem.index)
                    if index_reg_name in registers_in_state:
                        uaf_reg = registers_in_state[index_reg_name]

                if uaf_reg is None:
                    continue

                dereferenced_address = compute_dereferenced_address(access_transition, cs_insn, cs_op)

                if mem_accesses is None:
                    # As we don't have the memory access optimization we need to check if the dereferenced address
                    # is in the allocated buffer
                    # We only check on the translated pages as the taint won't return an access with a pagefault
                    # so the dereferenced address should be translated

                    dereferenced_physical_address = reven2.address.LogicalAddress(dereferenced_address).translate(
                        access_transition.context_before())

                    if dereferenced_physical_address is None:
                        continue

                    if not alloc_event.memory_range.is_physical_address_range_in_translated_pages(
                            dereferenced_physical_address.offset, 1):
                        continue

                print("Phys:0x%x: Allocated at #%d (0x%x of size 0x%x) and freed at #%d (0x%x)" % (physical_address,
                      alloc_event.tr_begin.id, alloc_event.memory_range.logical_address, alloc_event.memory_range.size,
                      free_event.tr_begin.id, free_event.logical_address))
                print("    UAF coming from reg %s[%d-%d] leading to dereferenced address = 0x%x" % (
                      uaf_reg.register.name, uaf_reg.begin, uaf_reg.end, dereferenced_address))
                print("    ", end="")
                print(access_transition, end=" ")
                print(access_transition.context_before().ossi.location())
                print("    Accessed %d transitions after the free" % (access_transition.id - free_event.tr_end.id))
                print()

                uaf_count += 1

    if len(errors) > 0:
        print("Phys:0x%x: Error(s) detected:" % (physical_address))
        for error in errors:
            print("    - %s" % error)

    return uaf_count


# %%
# %%time

uaf_count = 0

if uaf_faulty_physical_address is None:
    for physical_address, alloc_events in alloc_dict.items():
        uaf_count += uaf_analyze_function(physical_address, alloc_events)

else:
    if uaf_faulty_physical_address not in alloc_dict:
        raise KeyError("The passed physical address was not detected during the allocation search")
    uaf_count += uaf_analyze_function(uaf_faulty_physical_address, alloc_dict[uaf_faulty_physical_address])


print("---------------------------------------------------------------------------------")
uaf_begin_range = "the beginning of the trace" if uaf_from_tr is None else "#{}".format(uaf_from_tr)
uaf_end_range = "the end of the trace" if uaf_to_tr is None else "#{}".format(uaf_to_tr)
uaf_range = ("on the whole trace" if uaf_from_tr is None and uaf_to_tr is None else
             "between {} and {}".format(uaf_begin_range, uaf_end_range))

uaf_range_size = server.trace.transition_count
if uaf_from_tr is not None:
    uaf_range_size -= uaf_from_tr
if uaf_to_tr is not None:
    uaf_range_size -= server.trace.transition_count - uaf_to_tr

if uaf_faulty_physical_address is None:
    searched_memory_addresses = "with {} searched memory addresses".format(len(alloc_dict))
else:
    searched_memory_addresses = "on {:#x}".format(uaf_faulty_physical_address)

print("{} UAF(s) found {} ({} transitions) {}".format(
    uaf_count, uaf_range, uaf_range_size, searched_memory_addresses
))
print("---------------------------------------------------------------------------------")

# %%

Find symbols that access a specific memory range

Purpose

This notebook and script are designed to find all symbols that access a specific memory range. The script searches a Reven trace for all symbols that accessed this memory range. The results can be filtered by processes, ring, included binaries, excluded binaries, excluded symbols, context range and memory access operation. The script can generate two kinds of results:

  • process, binary and symbol information for each memory access.
  • for each symbol, all the memory accesses that occurred in that symbol. Note that this option can take a long time to start showing results, especially when there are many nested functions or many functions that don't end in the trace.

Note that:

  • accesses will be reported as belonging to the innermost symbol that has not been excluded and whose binary has not been excluded in the configuration.
  • we consider that we are "in a symbol" when the corresponding context.location.symbol returns this symbol. REVEN returns the closest symbol with an rva lower than ours; we are not trying to determine the exact bounds of the function with that symbol as its name. In particular, when there are missing symbols, this may report a symbol we saw a long time ago rather than <unknown>.

How to use

Results can be generated from this notebook or from the command line. The script can also be imported as a module for use from your own script or notebook.

From the notebook

  1. Upload the symbols_access_memory_range.ipynb file in Jupyter.
  2. Fill out the parameters cell of this notebook according to your scenario and desired output.
  3. Run the full notebook.

From the command line

  1. Make sure that you are in an environment that can run REVEN scripts.
  2. Run python symbols_access_memory_range.py --help to get a tour of available arguments.
  3. Run python symbols_access_memory_range.py --host <your_host> --port <your_port> [<other_option>] with your arguments of choice.
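
For instance, assuming a REVEN server reachable on localhost, port 13370, and reusing the memory range, context and process filter from the example in the source below (all of which you should adapt to your own scenario), a typical invocation could be:

python symbols_access_memory_range.py --host localhost --port 13370 --memory-range "[ds:0xfffff8800115e180 ; 128]" --context 410545055 --ring 0 --processes svchost.exe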

Imported in your own script or notebook

  1. Make sure that you are in an environment that can run REVEN scripts.
  2. Make sure that symbols_access_memory_range.py is in the same directory as your script or notebook.
  3. Add import symbols_access_memory_range to your script or notebook. You can access the various functions and classes exposed by the module from the symbols_access_memory_range namespace.
  4. Refer to the Argument parsing cell for an example of use in a script, and to the Parameters cell and below for an example of use in a notebook (you just need to prepend symbols_access_memory_range in front of the functions and classes from the script).
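
For instance, here is a minimal sketch of such a script, assuming a REVEN server reachable on localhost, port 13370, and reusing the memory range, context and process filter from the example in the source below (adapt all of them to your own scenario):

from reven2.memory_range import MemoryRange
from reven2.prelude import RevenServer

import symbols_access_memory_range

# Connect to the REVEN server hosting the scenario (placeholder host/port).
server = RevenServer("localhost", 13370)

# Virtual memory range and translation context taken from the example in the source below;
# replace them with values relevant to your own scenario.
memory_range = MemoryRange.from_string("[ds:0xfffff8800115e180 ; 128]")

symbols_access_memory_range.symbols_access_memory_range(
    server=server,
    memory_range=memory_range,
    context=410545055,          # context used to translate the virtual range
    processes=["svchost.exe"],  # optional process name filter
)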

Known limitations

N/A.

Supported versions

REVEN 2.10+

Supported perimeter

Any REVEN scenario.

Dependencies

The script requires that the target REVEN scenario have:

  • The OSSI feature replayed.
  • The memory history feature replayed.
  • The pandas Python module.

Source

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Find symbols that access a specific memory range
#
# ## Purpose
#
# This notebook and script are designed to find all symbols that access a specific memory range.
#
# This script searches a Reven trace for all symbols that accessed a specific memory range.
# The script can filter the results by processes, ring, included binaries, excluded binaries, excluded
# symbols, context range and memory access operation.
#
# The script can generate two kinds of results:
# - process, binary and symbol information for each memory access.
# - for each symbol, all the memory accesses that occurred in that symbol.
#   Note that this option can take a long time to start showing results,
#   especially when there are many nested functions or many functions that don't end in the trace.
#
# Note that:
# - accesses will be reported as belonging to the innermost symbol that has not been excluded
#   and whose binary has not been excluded in the configuration.
# - we consider that we are "in a symbol" when the corresponding context.location.symbol returns this symbol.
#   REVEN returns the closest symbol with an `rva` lower than ours. Note that we are not trying to determine the
#   exact bounds of the function with that symbol as its name. In particular, when there are missing symbols,
#   this may report a symbol we saw a long time ago rather than <unknown>.
#
#
#
# ## How to use
#
# Results can be generated from this notebook or from the command line.
# The script can also be imported as a module for use from your own script or notebook.
#
#
# ### From the notebook
#
# 1. Upload the `symbols_access_memory_range.ipynb` file in Jupyter.
# 2. Fill out the [parameters](#Parameters) cell of this notebook according to your scenario and desired output.
# 3. Run the full notebook.
#
#
# ### From the command line
#
# 1. Make sure that you are in an
#    [environment](http://doc.tetrane.com/professional/latest/Python-API/Installation.html#on-the-reven-server)
#    that can run REVEN scripts.
# 2. Run `python symbols_access_memory_range.py --help` to get a tour of available arguments.
# 3. Run `python symbols_access_memory_range.py --host <your_host> --port <your_port> [<other_option>]` with your
# arguments of choice.
#
# ### Imported in your own script or notebook
#
# 1. Make sure that you are in an
#    [environment](http://doc.tetrane.com/professional/latest/Python-API/Installation.html#on-the-reven-server)
#    that can run REVEN scripts.
# 2. Make sure that `symbols_access_memory_range.py` is in the same directory as your script or notebook.
# 3. Add `import symbols_access_memory_range` to your script or notebook. You can access the various functions and
#    classes exposed by the module from the `symbols_access_memory_range` namespace.
# 4. Refer to the [Argument parsing](#Argument-parsing) cell for an example of use in a script, and to the
#    [Parameters](#Parameters) cell and below for an example of use in a notebook (you just need to prepend
#    `symbols_access_memory_range` in front of the functions and classes from the script).
#
# ## Known limitations
#
# N/A.
#
# ## Supported versions
#
# REVEN 2.10+
#
# ## Supported perimeter
#
# Any REVEN scenario.
#
# ## Dependencies
#
# The script requires that the target REVEN scenario have:
#
# * The OSSI feature replayed.
# * The memory history feature replayed.
# * The pandas Python module.

# %% [markdown]
# ### Package imports

# %%
import argparse
from enum import Enum
from typing import Iterable as _Iterable, List
from typing import Optional as _Optional
from typing import cast as _cast

from IPython.core.display import display  # type: ignore

import pandas

import reven2.address as _address
import reven2.arch as _arch
from reven2.filter import RingPolicy
from reven2.memhist import MemoryAccess, MemoryAccessOperation
from reven2.memory_range import MemoryRange
from reven2.ossi import Binary, Process, Symbol
from reven2.ossi.thread import Thread
from reven2.prelude import RevenServer
from reven2.stack import Stack
from reven2.trace import Context, Trace
from reven2.util import collate as _collate


# %% [markdown]
# ### Utility functions


# %%
# Detect if we are currently running a Jupyter notebook.
#
# This is used e.g. to display rendered results inline in Jupyter when we are executing in the context of a Jupyter
# notebook, or to display raw results on the standard output when we are executing in the context of a script.
def in_notebook():
    try:
        from IPython import get_ipython  # type: ignore

        if get_ipython() is None or ("IPKernelApp" not in get_ipython().config):
            return False
    except ImportError:
        return False
    return True


# %% [markdown]
# ### Helper classes for results

# %%
class CallSymbol:
    r"""
    CallSymbol is a helper class used to represent a symbol with its start and end context
    """

    def __init__(self, symbol: _Optional[Symbol], start: Context, end: _Optional[Context] = None) -> None:
        self._symbol = symbol
        self._start = start
        self._end = end

    @property
    def symbol(self) -> _Optional[Symbol]:
        r"""
        B{Property:} The symbol of the call symbol. None if the symbol is unknown.
        """
        return self._symbol

    @property
    def start_context(self) -> Context:
        r"""
        B{Property:} The start context of the call symbol.
        """
        return self._start

    @property
    def end_context(self) -> _Optional[Context]:
        r"""
        B{Property:} The end excluded context of the call symbol. None if the end context isn't in the trace.
        """
        return self._end

    def __eq__(self, other: "CallSymbol") -> bool:  # type: ignore
        return self._symbol == other._symbol and self._start == other._start and self._end == other._end

    def __ne__(self, other: "CallSymbol") -> bool:  # type: ignore
        return not (self == other)


class MemoryRangeSymbolResult:
    r"""
    MemoryRangeSymbolResult is a helper class that represents one result of the search.
    """

    def __init__(
        self,
        call_symbol: CallSymbol,
        memory_access: _Optional[MemoryAccess],
        ring: int,
        process: _Optional[Process],
        thread: _Optional[Thread],
        binary: _Optional[Binary],
    ) -> None:
        self._call_symbol = call_symbol
        self._memory_accesses = [] if memory_access is None else [memory_access]
        self._ring = ring
        self._process = process
        self._thread = thread
        self._binary = binary

    def copy(self) -> "MemoryRangeSymbolResult":
        r"""
        Return a copy of this object.

        It makes a shallow copy of all attributes, except for the memory accesses, for which a new list is created.
        """
        new_obj = MemoryRangeSymbolResult(
            call_symbol=self._call_symbol,
            memory_access=None,
            ring=self._ring,
            process=self._process,
            thread=self._thread,
            binary=self._binary,
        )
        if self._memory_accesses is not None:
            new_obj._memory_accesses += self._memory_accesses
        return new_obj

    @property
    def call_symbol(self) -> CallSymbol:
        r"""
        B{Property:} The call symbol of the result.
        """
        return self._call_symbol

    @property
    def memory_accesses(self) -> List[MemoryAccess]:
        r"""
        B{Property:} The memory accesses of the result.
        """
        return self._memory_accesses

    @property
    def ring(self) -> int:
        r"""
        B{Property:} The ring of the result.
        """
        return self._ring

    @property
    def process(self) -> _Optional[Process]:
        r"""
        B{Property:} The process of the result.
        """
        return self._process

    @property
    def binary(self) -> _Optional[Binary]:
        r"""
        B{Property:} The binary of the result. None if the binary is unknown.
        """
        return self._binary

    @property
    def thread(self) -> _Optional[Thread]:
        r"""
        B{Property:} The thread of the result.
        """
        return self._thread

    def __eq__(self, other: "MemoryRangeSymbolResult") -> bool:  # type: ignore
        return (
            self._ring == other._ring
            and self._process is not None
            and other._process is not None
            and self._process.name == other._process.name
            and self._process.pid == other._process.pid
            and self._process.ppid == other._process.ppid
            and self._thread is not None
            and other._thread is not None
            and self._thread.id == other._thread.id
            and self._thread.owner_process_id == other._thread.owner_process_id
            and (
                (self._binary is None and other._binary is None)
                or (self._binary is not None and other._binary is not None and self._binary.path == other._binary.path)
            )
            and self._call_symbol == other._call_symbol
        )

    def __ne__(self, other: "MemoryRangeSymbolResult") -> bool:  # type: ignore
        return not (self == other)

    def __str__(self) -> str:
        memory_accesses = "\nmemory accesses:"
        for m in self._memory_accesses:
            memory_accesses += f"\n\t{m}, "
        memory_accesses += "\n"

        return (
            f"ring: {self._ring}, process: {self._process}, "
            f"thread: {self._thread}, binary: {self._binary}, "
            f"symbol: {self._call_symbol.symbol}[{self._call_symbol.start_context}, "
            f"{self._call_symbol.end_context}[ {memory_accesses}"
        )

    def format_as_html(self):
        r"""
        This method gets an html formatting string representation for this class instance.

        Information
        ===========
        @returns: C{String}
        """
        memory_accesses = "<p>memory accesses:</p><ol>"
        for m in self._memory_accesses:
            memory_accesses += f"<li>{m.format_as_html()}</li>"
        memory_accesses += "</ol>"

        return (
            f"ring: {self._ring}, process: {self._process if self._process is not None else 'unknown'}, "
            f" thread: {self._thread if self._thread is not None else 'unknown'}, binary: {self._binary}, "
            f"symbol: {self._call_symbol.symbol}[{self._call_symbol.start_context}, "
            f"{self._call_symbol.end_context}[ {memory_accesses}"
        )

    def _repr_html_(self):
        r"""
        Representation used by Jupyter Notebook when an instance of this class is displayed in a cell.
        """
        return "<p>{}</p>".format(self.format_as_html())


# %% [markdown]
# ### MemoryRangeSymbolFinder
#
# This class represents the main logic of this script

# %%
class MemoryRangeSymbolFinder(object):
    r"""
        This class is a helper class to search for all symbols that access a specific memory range.
        Results can be filtered by processes, ring, binaries, excluded binaries, excluded symbols
        and a context range.

        The symbols that access this memory range are returned.

        Examples
        ========

        >>> # Search all symbols that access the memory range [ds:0xfffff8800115e180 ; 128]
        >>> # filtered by the process `svchost.exe` at the context #410545055.
        >>> processes = server.ossi.executed_processes('svchost.exe')
        >>> memory_range = MemoryRange.from_string("[ds:0xfffff8800115e180 ; 128]")
        >>> context = server.trace.context_before(410545055)
        >>> symbol_mem_finder = MemoryRangeSymbolFinder(
        ...     trace=server.trace, memory_range=memory_range,
        ...     context=context, processes=processes)
        >>> for r in symbol_mem_finder.query():
        ...     print(r)
        ring: 0, process: svchost.exe (1004), thread: 2256, binary: c:/windows/system32/drivers/cng.sys,
    symbol: cng!AesCbcDecrypt[Context before #25208343, Context before #25212457[
    memory accesses:
        [#25208488 xor r8d, dword ptr ds:[r11+rax*4+0x800]]Read access at
    @phy:0x36411e0 (virtual address: lin:0xfffff8800115e1e0) of size 4,
        ...

    """

    def __init__(
        self,
        trace: Trace,
        memory_range: MemoryRange,
        translation_context: _Optional[Context] = None,
        from_context: _Optional[Context] = None,
        to_context: _Optional[Context] = None,
        ring_policy: RingPolicy = RingPolicy.All,
        processes: _Optional[_Iterable[Process]] = None,
        included_binaries: _Optional[_Iterable[Binary]] = None,
        excluded_binaries: _Optional[_Iterable[Binary]] = None,
        excluded_symbols: _Optional[_Iterable[Symbol]] = None,
        operation: _Optional[MemoryAccessOperation] = None,
    ) -> None:
        r"""
        Initialize a C{MemoryRangeSymbolFinder}

        Information
        ===========

        @param trace: the trace where symbols will be looked for.
        @param memory_range: the memory range that are accessed by the returned symbols.
        @param translation_context: context used to translate the memory range when it is virtual.
        @param from_context: the context where the search will be started.
        @param to_context: the context where the search will be ended.
        @param ring_policy: ring policy to search for.
        @param processes: processes to limit the search to. If None, all processes will be searched.
        @param included_binaries: binaries that must be included in the search.
                                  If None, all binaries will be included.
                                  When a binary is not included, all its symbols and their memory accesses are ignored.
        @param excluded_binaries: binaries that must be excluded from the search. If None nothing will be excluded.
                                  Accesses performed in this binary are reported, but using the first caller
                                  binary that is not excluded. Note that inclusion is applied before exclusion.
        @param excluded_symbols: symbols that must be excluded from the search. If None nothing will be excluded.
                                 Accesses performed in this symbol are reported, but using the first caller
                                 symbol that is not excluded.
        @param operation: limit results to accesses performing the specified operation.


        @raises TypeError: if trace is not a C{reven2.trace.Trace}.
        @raises ValueError: If provided memory range is virtual and the translation_context is None.
        """
        if not isinstance(trace, Trace):
            raise TypeError("You must provide a valid trace")
        self._trace = trace

        if isinstance(memory_range.address, _address.PhysicalAddress):
            self._physical_memory_ranges = [_cast(MemoryRange[_address.PhysicalAddress], memory_range)]
        elif translation_context is None:
            raise ValueError("You must provide a context for the translation if the memory range is virtual")
        else:
            self._physical_memory_ranges = [mem_range for mem_range in memory_range.translate(translation_context)]

        self._from_context = from_context
        self._to_context = to_context

        self._ring_policy = ring_policy
        self._processes = None if processes is None else [process for process in processes]
        self._included_binaries = None if included_binaries is None else {binary.name for binary in included_binaries}
        self._excluded_binaries = set() if excluded_binaries is None else {binary.name for binary in excluded_binaries}
        self._excluded_symbols = set() if excluded_symbols is None else {symbol.name for symbol in excluded_symbols}
        self._operation = operation

    def filter_by_processes(self, processes: _Iterable[Process]) -> "MemoryRangeSymbolFinder":
        r"""
        Extend the list of processes to limit the search in, and return the self object.

        Information
        ===========

        @param processes: processes to limit the search in.
        @returns : self object
        """
        if self._processes is None:
            self._processes = []
        self._processes += [process for process in processes]

        return self

    def filter_by_ring(self, ring_policy: RingPolicy) -> "MemoryRangeSymbolFinder":
        r"""
        Update the ring policy to search for and return the `self` object.

        Information
        ===========

        @param ring_policy: ring policy to search for.
        @returns : self object
        """
        self._ring_policy = ring_policy
        return self

    def from_context(self, context: Context) -> "MemoryRangeSymbolFinder":
        r"""
        Update the context where the search will be started and return the `self` object.

        Information
        ===========

        @param context: context where the search will be started.
        @returns : self object
        """
        self._from_context = context
        return self

    def to_context(self, context: Context) -> "MemoryRangeSymbolFinder":
        r"""
        Update the context where the search will be ended and return the `self` object.

        Information
        ===========

        @param context: context where the search will be ended.
        @returns : self object
        """
        self._to_context = context
        return self

    def include_binaries(self, binaries: _Iterable[Binary]) -> "MemoryRangeSymbolFinder":
        r"""
        Extend the list of binaries that must be included in the search and return the `self` object.

        Information
        ===========

        @param binaries: binaries that must be included in the search.
        @returns : self object
        """
        if self._included_binaries is None:
            self._included_binaries = {binary.name for binary in binaries}
        else:
            self._included_binaries.update([binary.name for binary in binaries])
        return self

    def exclude_binaries(self, binaries: _Iterable[Binary]) -> "MemoryRangeSymbolFinder":
        r"""
        Extend the list of binaries that must be excluded from the search and return the `self` object.

        Information
        ===========

        @param binaries: binaries that must be excluded from the search.
        @returns : self object
        """
        self._excluded_binaries.update([binary.name for binary in binaries])
        return self

    def exclude_symbols(self, symbols: _Iterable[Symbol]) -> "MemoryRangeSymbolFinder":
        r"""
        Extend the list of symbols that must be excluded from the search and return the `self` object.

        Information
        ===========

        @param symbols: symbols that must be excluded from the search.
        @returns : self object
        """
        self._excluded_symbols.update([symbol.name for symbol in symbols])
        return self

    def filter_by_memory_access_operation(
        self, operation: _Optional[MemoryAccessOperation] = None
    ) -> "MemoryRangeSymbolFinder":
        r"""
        Update the memory access operation to limit results to accesses performing this
        operation and return the `self` object.

        Information
        ===========

        @param operation: limit results to accesses performing the specified operation.
        @returns : self object
        """
        self._operation = operation
        return self

    def _is_the_same_stack(self, stack1: Stack, stack2: Stack) -> bool:
        # we assume that two stacks are the same if the first contexts of their first frames are the same

        frame1 = next(stack1.frames())
        frame2 = next(stack2.frames())

        return frame1.first_context == frame2.first_context

    def query(self) -> _Iterable[MemoryRangeSymbolResult]:
        r"""
        Iterate over all filtered contexts and yield symbols.

        Note: the same symbol can be yielded several times with different memory accesses.
        """

        # Make a copy of the variables that can modify the generated results
        operation = self._operation
        included_binaries = None if self._included_binaries is None else self._included_binaries.copy()
        excluded_binaries = self._excluded_binaries.copy()
        excluded_symbols = self._excluded_symbols.copy()

        # store last handled stack to use it if we are in the same stack
        last_stack: _Optional[Stack] = None
        # store last result to use it if we are in the same stack
        last_result: _Optional[MemoryRangeSymbolResult] = None

        # Iterate over all context range filtered by ring, processes, from_context and to_context
        for context_range in self._trace.filter(
            processes=self._processes,
            ring_policy=self._ring_policy,
            from_context=self._from_context,
            to_context=self._to_context,
        ):
            from_transition = (
                context_range.begin.transition_before()
                if context_range.begin == self._trace.last_context
                else context_range.begin.transition_after()
            )
            to_transition = (
                context_range.last.transition_before()
                if context_range.last == self._trace.last_context
                else context_range.last.transition_after()
            )

            # iterate over physical memory range
            iterators = [
                self._trace.memory_accesses(
                    address=memory_range.address,
                    size=memory_range.size,
                    from_transition=from_transition,
                    to_transition=to_transition,
                )
                for memory_range in self._physical_memory_ranges
            ]
            # iterate over all memory accesses in this range
            for memory_access in _collate(iterators, key=lambda x: x.transition.id):
                # apply filter by operation here instead of in the query, because currently
                # operation-constrained queries are not optimized in the backend
                if operation is not None and operation != memory_access.operation:
                    continue
                # get the stack at this transition
                current_context: Context = memory_access.transition.context_before()

                stack = current_context.stack
                if last_result is not None and last_stack is not None and self._is_the_same_stack(last_stack, stack):
                    # update the memory access of the last result and yield it
                    last_result._memory_accesses = [memory_access]
                    yield last_result
                    continue

                last_stack = stack

                # exclude symbols and binary
                handled_binary = None
                handled_symbol = None
                handled_symbol_found = False
                frames = [frame for frame in stack.frames()]
                frames.reverse()
                for frame in frames:
                    loc = frame.first_context.ossi.location()
                    if loc is not None and (
                        loc.binary.name in excluded_binaries
                        or (loc.symbol is not None and loc.symbol.name in excluded_symbols)
                        or ("unknown" in excluded_symbols)
                    ):
                        break

                    if loc is not None:
                        if loc.binary is not None:
                            handled_binary = loc.binary
                        if loc.symbol is not None:
                            handled_symbol = loc.symbol
                    handled_symbol_found = True

                    first_context = frame.first_context
                    handled_process = frame.first_context.ossi.process()
                    handled_thread = frame.first_context.ossi.thread()

                # ignore the symbol if it is in the excluded symbols or if its binary is in the excluded binaries
                if not handled_symbol_found:
                    continue

                # ignore symbol if its binary isn't in the included binaries
                if (
                    included_binaries is not None
                    and handled_binary is not None
                    and handled_binary.name not in included_binaries
                ):
                    continue

                # get the end of symbol
                end_transition = (
                    first_context.transition_after().step_out()
                    if first_context != self._trace.last_context
                    else first_context.transition_before().step_out()
                )
                end_context = None if end_transition is None else end_transition.context_before()

                # get the ring of the symbol
                handled_ring = first_context.read(_arch.x64.cs) & 0x3

                last_result = MemoryRangeSymbolResult(
                    call_symbol=CallSymbol(handled_symbol, first_context, end_context),
                    memory_access=memory_access,
                    ring=handled_ring,
                    process=handled_process,
                    thread=handled_thread,
                    binary=handled_binary,
                )
                yield last_result

    def group_by_symbol_query(self) -> _Iterable[MemoryRangeSymbolResult]:
        r"""
        Iterate over all filtered contexts and yield symbols.

        Note: each symbol will be yielded only once, with a group of all its memory accesses.
        """
        # Add symbols to a stack and pop it when it is finished
        result_stack = []  # type: List[MemoryRangeSymbolResult]
        for result in self.query():
            if len(result_stack) > 0:
                # firstly, we verify if we can pop the last item from the stack
                # Item will be yielded if its end context isn't None and the current result of
                # the query has a memory access such that the before context of its transition
                # >= the context of the last symbol in the stack
                if (
                    result.call_symbol.end_context is not None
                    # len(result.memory_accesses) > 0 because the results of `query`
                    # contain exactly one memory_access by construction.
                    and result.memory_accesses[0].transition.context_before() >= result.call_symbol.end_context
                ):
                    res = result_stack.pop(-1)
                    yield res

            # Here we observe symbols change,
            # if the symbol is changed (result_stack[-1] != result) we add the new symbol to the stack.
            # (len(result_stack) == 0 is only to handle the case of the first result)
            if len(result_stack) == 0 or result_stack[-1] != result:
                # store a deep copy of the result
                result_stack.append(result.copy())
                continue

            # the symbol didn't change, so we add the memory access of the current result
            # to the last item in the stack
            result_stack[-1]._memory_accesses += result.memory_accesses

        # yield all symbols with None end context
        for result in result_stack:
            yield result


# %% [markdown]
#
# ### OutputFormat


# %%
class OutputFormat(Enum):
    r"""
    Enum describing the various possible output formats of the results
     - RAW: The results will be output using its string representation.
     - TABLE: The results will be output using pandas table format.
     - CSV: The results will be output as csv.
     - HTML: The results will be output as html table.
    """
    RAW = 0
    TABLE = 1
    CSV = 2
    HTML = 3


# %% [markdown]
# ### Main function
#
# This function is called with parameters from the [Parameters](#Parameters) cell in the notebook context,
# or with parameters from the command line in the script context.


# %%
def symbols_access_memory_range(
    server: RevenServer,
    memory_range: MemoryRange,
    context: _Optional[int],
    from_context: _Optional[int] = None,
    to_context: _Optional[int] = None,
    ring_policy: RingPolicy = RingPolicy.All,
    processes: _Optional[_Iterable[str]] = None,
    included_binaries: _Optional[_Iterable[str]] = None,
    excluded_binaries: _Optional[_Iterable[str]] = None,
    excluded_symbols: _Optional[_Iterable[str]] = None,
    operation: _Optional[MemoryAccessOperation] = None,
    grouped_by_symbol: bool = False,
    output_format: OutputFormat = OutputFormat.RAW,
    output_file: _Optional[str] = None,
) -> None:
    # declare symbol finder.
    memory_range_symbols_finder = MemoryRangeSymbolFinder(
        trace=server.trace,
        memory_range=memory_range,
        translation_context=(None if context is None else server.trace.context_before(context)),
        from_context=(None if from_context is None else server.trace.context_before(from_context)),
        to_context=(None if to_context is None else server.trace.context_before(to_context)),
        ring_policy=ring_policy,
        operation=operation,
    )

    # filter by processes
    if processes is not None:
        for process in processes:
            memory_range_symbols_finder.filter_by_processes(server.ossi.executed_processes(process))

    # include binaries
    if included_binaries is not None:
        for binary in included_binaries:
            memory_range_symbols_finder.include_binaries(server.ossi.executed_binaries(binary))

    # exclude binaries
    if excluded_binaries is not None:
        for binary in excluded_binaries:
            memory_range_symbols_finder.exclude_binaries(server.ossi.executed_binaries(binary))

    # exclude symbols
    if excluded_symbols is not None:
        for symbol in excluded_symbols:
            memory_range_symbols_finder.exclude_symbols(server.ossi.symbols(symbol))

    query = (
        memory_range_symbols_finder.group_by_symbol_query()
        if grouped_by_symbol
        else memory_range_symbols_finder.query()
    )

    if output_format == OutputFormat.RAW:
        print_func = display if in_notebook() else print
        if output_file is not None:
            file = open(output_file, "w")

            def fprint_func(s: MemoryRangeSymbolResult) -> None:
                file.write(str(s))
                file.write("\n")

            print_func = fprint_func
        for result in query:
            print_func(result)

        if output_file is not None:
            file.close()
    else:
        results = {  # type: ignore
            "Ring": [],
            "Process": [],
            "Thread": [],
            "Binary": [],
            "Symbol": [],
            "Start context": [],
            "Access transition": [],
            "Access operation": [],
            "Access physical": [],
            "Access linear": [],
            "Access size": [],
        }
        for result in query:
            for mem_access in result.memory_accesses:
                results["Ring"].append(result.ring)
                results["Process"].append(str(result.process) if result.process is not None else "unknown")
                results["Thread"].append(str(result.thread) if result.thread is not None else "unknown")
                results["Binary"].append(result.binary.name if result.binary is not None else "unknown")
                results["Symbol"].append(
                    result.call_symbol.symbol.name if result.call_symbol.symbol is not None else "unknown"
                )
                results["Start context"].append(str(result.call_symbol.start_context))
                results["Access transition"].append(mem_access.transition.id)
                results["Access operation"].append(mem_access.operation.name)
                results["Access physical"].append(mem_access.physical_address)
                results["Access linear"].append(mem_access.virtual_address)
                results["Access size"].append(mem_access.size)

        # A type stub is installed for the pandas module but it is a WIP:
        # it doesn't know the `from_dict` method of the `DataFrame` class,
        # so we ignore the type here.
        df = pandas.DataFrame.from_dict(results)  # type: ignore
        if output_format == OutputFormat.TABLE:
            if output_file is not None:
                with open(output_file, "w") as file:
                    file.write(str(df))
            else:
                print(df)
        elif output_format == OutputFormat.CSV:
            print(df.to_csv()) if output_file is None else df.to_csv(output_file)
        elif output_format == OutputFormat.HTML:
            print(df.to_html()) if output_file is None else df.to_html(output_file)


# %% [markdown]
# ### Argument parsing
#
# Argument parsing function for use in the script context.

# %%
def get_memory_access_operation(operation: str) -> MemoryAccessOperation:
    if operation is None:
        return None
    if operation.lower() == "read":
        return MemoryAccessOperation.Read
    if operation.lower() == "write":
        return MemoryAccessOperation.Write
    raise ValueError(f"'operation' value should be 'read' or 'write'. Received '{operation}'.")


def get_ring_policy(ring: int) -> RingPolicy:
    if ring is None:
        return RingPolicy.All
    if ring == 0:
        return RingPolicy.R0Only
    if ring == 3:
        return RingPolicy.R3Only
    raise ValueError(f"'ring_policy' value should be '0' or '1'. Received '{ring_policy}'.")


def get_output_format(format: str) -> OutputFormat:
    if format.lower() == "raw":
        return OutputFormat.RAW
    if format.lower() == "table":
        return OutputFormat.TABLE
    if format.lower() == "html":
        return OutputFormat.HTML
    if format.lower() == "csv":
        return OutputFormat.CSV
    raise ValueError(f"'output format' value should be 'raw', 'table', 'html', or 'csv'. Received '{format}'.")


def script_main():
    parser = argparse.ArgumentParser(description="Find all symbols that access a memory range")
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        required=False,
        help='REVEN host, as a string (default: "localhost")',
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default="13370",
        required=False,
        help="REVEN port, as an int (default: 13370)",
    )
    parser.add_argument(
        "-m",
        "--memory-range",
        type=str,
        required=True,
        help="The memory range whose accesses to look for in symbols (e.g. [ds:0xfff5000; 2])",
    )
    parser.add_argument(
        "-C",
        "--context",
        type=int,
        required=False,
        help="The context used to translate the memory range if it is virtual",
    )
    parser.add_argument(
        "--from-context",
        type=int,
        required=False,
        help="The context from where the search starts",
    )
    parser.add_argument(
        "--to-context",
        type=int,
        required=False,
        help="The context(not included) at which the search stops",
    )
    parser.add_argument(
        "--ring",
        type=int,
        required=False,
        help="Show symbols in this ring only, can be (0=ring0, 3=ring3)",
    )
    parser.add_argument(
        "--processes",
        required=False,
        nargs="*",
        help="Show symbols in these processes only",
    )
    parser.add_argument(
        "--include-binaries",
        required=False,
        nargs="*",
        help="Show symbols in these binaries only",
    )
    parser.add_argument(
        "--exclude-binaries",
        required=False,
        nargs="*",
        help="Don't show symbols in these binaries, accesses that belong to these symbols will be reported with "
        "the innermost symbol such that it or its binary don't excluded",
    )
    parser.add_argument(
        "--exclude-symbols",
        required=False,
        nargs="*",
        help="Don't show these symbols, accesses that belong to these symbols will be reported with "
        "the innermost non excluded symbol",
    )
    parser.add_argument(
        "--memory-access-operation",
        choices=["read", "write"],
        required=False,
        help="Only show symbols that access the memory range using this operation",
    )
    parser.add_argument(
        "--grouped-by-symbol",
        action="store_true",
        required=False,
        default=False,
        help="Group results by symbol",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        type=str,
        required=False,
        help="The target file of the results. If absent, the results will be printed on the standard output",
    )
    parser.add_argument(
        "--output-format",
        choices=["raw", "table", "csv", "html"],
        required=False,
        default="raw",
        help="Output format of the results",
    )

    args = parser.parse_args()

    try:
        server = RevenServer(args.host, args.port)
    except RuntimeError:
        raise RuntimeError(f"Could not connect to the server on {args.host}:{args.port}.")

    symbols_access_memory_range(
        server=server,
        memory_range=MemoryRange.from_string(args.memory_range),
        context=args.context,
        from_context=args.from_context,
        to_context=args.to_context,
        ring_policy=get_ring_policy(args.ring),
        processes=args.processes,
        included_binaries=args.include_binaries,
        excluded_binaries=args.exclude_binaries,
        excluded_symbols=args.exclude_symbols,
        operation=get_memory_access_operation(args.memory_access_operation),
        grouped_by_symbol=args.grouped_by_symbol,
        output_format=get_output_format(args.output_format),
        output_file=args.output_file,
    )


# %% [markdown]
# ## Parameters
#
# These parameters have to be filled out for use in the notebook context.

# %%
# Server connection
#
host = "localhost"
port = 37103

# Input data

memory_range = MemoryRange(address=_address.LogicalAddress(offset=0xFFFFF8800115E180), size=1)
# Or use the MemoryRange.from_string method
# memory_range = MemoryRange.from_string("[ds:0xFFFFF8800115E180; 1]")


context = 100
# context = None # can be None only when the memory range is defined by a physical address


# Output filter

from_context = None
# from_context = 10


to_context = None
# to_context = 10


ring_policy = RingPolicy.All
# ring_policy = RingPolicy.R0Only
# ring_policy = RingPolicy.R3Only

processes = None  # display result for all processes in the trace
# processes = ["xxx",]

included_binaries = None
# included_binaries = ["xxx",]
excluded_binaries = None
# excluded_binaries = ["xxx",]

excluded_symbols = None
# excluded_symbols = "xxx"

memory_access_operation = None
# memory_access_operation = MemoryAccessOperation.Write
# memory_access_operation = MemoryAccessOperation.Read

# Output target
#
output_file = None  # display results inline
# output_file = "res.csv"  # write results formatted as `csv` to a file named "res.csv" in the current directory


# Output control
#
# group results by symbol
grouped_by_symbol = False
# pandas output type
output_format: OutputFormat = OutputFormat.RAW


# %% [markdown]
# ### Pandas module
#
# This cell verifies whether the pandas module is installed, and installs it if needed.


# %%
if in_notebook():
    try:
        import pandas  # noqa

        print("pandas already installed")
    except ImportError:
        print("Could not find pandas, attempting to install it from pip")
        import sys
        import subprocess

        command = [f"{sys.executable}", "-m", "pip", "install", "pandas"]
        p = subprocess.run(command)

        if int(p.returncode) != 0:
            raise RuntimeError("Error installing pandas")
        import pandas  # noqa

        print("Successfully installed pandas")
else:
    import pandas  # noqa


# %% [markdown]
# ### Execution cell
#
# This cell executes according to the [parameters](#Parameters) when in notebook context, or according to the
# [parsed arguments](#Argument-parsing) when in script context.
#
# When in notebook context, if the `output_file` parameter is `None`, then the report will be displayed in the last cell of
# the notebook.

# %%
if __name__ == "__main__":
    if in_notebook():
        try:
            server = RevenServer(host, port)
        except RuntimeError:
            raise RuntimeError(f"Could not connect to the server on {host}:{port}.")

        symbols_access_memory_range(
            server=server,
            memory_range=memory_range,
            context=context,
            from_context=from_context,
            to_context=to_context,
            ring_policy=ring_policy,
            processes=processes,
            included_binaries=included_binaries,
            excluded_binaries=excluded_binaries,
            excluded_symbols=excluded_symbols,
            operation=memory_access_operation,
            grouped_by_symbol=grouped_by_symbol,
            output_format=output_format,
            output_file=output_file,
        )
    else:
        script_main()
# %%

Find all memory accesses that are accessed by a given symbol

Purpose

This notebook and script are designed to find all memory accesses that are accessed by a given symbol. The script searches a REVEN trace for all memory accesses performed in calls of the given symbol, and can filter the results by processes, threads, ring, context range and memory access operation. The script can generate two kinds of results:

  • for each call of the symbol: the process, binary and symbol call information, together with all of its memory accesses.
  • results grouped by symbol: the symbol followed by all of its calls, each with the memory accesses that occurred in that call. Note that this option can take a long time to start showing results.

Note that the script allows including or excluding memory accesses that occurred in children symbol calls of each symbol call.

How to use

Results can be generated from this notebook or from the command line. The script can also be imported as a module for use from your own script or notebook.

From the notebook

  1. Upload the memory_ranges_accessed_by_a_symbol.ipynb file in Jupyter.
  2. Fill out the parameters cell of this notebook according to your scenario and desired output.
  3. Run the full notebook.

From the command line

  1. Make sure that you are in an environment that can run REVEN scripts.
  2. Run python memory_ranges_accessed_by_a_symbol.py --help to get a tour of available arguments.
  3. Run python memory_ranges_accessed_by_a_symbol.py --host <your_host> --port <your_port> [<other_option>] with your arguments of choice.

Imported in your own script or notebook

  1. Make sure that you are in an environment that can run REVEN scripts.
  2. Make sure that memory_ranges_accessed_by_a_symbol.py is in the same directory as your script or notebook.
  3. Add import memory_ranges_accessed_by_a_symbol to your script or notebook. You can access the various functions and classes exposed by the module from the memory_ranges_accessed_by_a_symbol namespace.
  4. Refer to the Argument parsing cell for an example of use in a script, and to the Parameters cell and below for an example of use in a notebook (you just need to prepend memory_ranges_accessed_by_a_symbol in front of the functions and classes from the script). A short usage sketch follows this list.
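
A minimal usage sketch of that imported setting (the host, port, symbol name and binary hint below are placeholders to adapt to your scenario):

import reven2
import memory_ranges_accessed_by_a_symbol as mras

# Connect to the REVEN server of the scenario (placeholder host/port).
server = reven2.RevenServer("localhost", 13370)

# Write every access performed in calls of the chosen symbol to a CSV file.
mras.memory_ranges_accessed_by_a_symbol(
    server=server,
    symbol="ExAllocatePoolWithTag",  # placeholder symbol name
    binary_hint="ntoskrnl",  # placeholder binary name hint
    output_format=mras.OutputFormat.CSV,
    output_file="accesses.csv",
)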

Known limitations

When using the "table", "csv", "html" output format, this script might require a large quantity of RAM due to the data being retained in memory. If you notice an important RAM usage, you can try the following:

  • Restart with the "raw" format
  • Split the results using the from_context and to_context parameters (see the sketch after this list)
  • Use the provided filters (ring, processes, threads) to reduce the number of results
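
As a rough sketch of the second point, one option is to run the search on successive context slices and write one file per slice (this reuses the server connection and mras module from the sketch above; the symbol and chunk size are placeholders):

# Split the search into fixed-size context ranges to bound memory usage.
chunk = 50_000_000
last = server.trace.last_transition.id
for start in range(0, last, chunk):
    mras.memory_ranges_accessed_by_a_symbol(
        server=server,
        symbol="ExAllocatePoolWithTag",  # placeholder symbol name
        from_context=start,
        to_context=min(start + chunk, last),
        output_format=mras.OutputFormat.CSV,
        output_file=f"accesses_{start}.csv",
    )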

Supported versions

REVEN 2.10+

Supported perimeter

Any REVEN scenario.

Dependencies

The script requires that the target REVEN scenario have:

  • The OSSI feature replayed.
  • The memory history feature replayed.
  • The pandas Python module.

Source

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Find all memory accesses that are accessed by a given symbol
#
# ## Purpose
#
# This notebook and script are designed to find all memory accesses that are accessed by a given symbol.
#
# This script searches a Reven trace for all memory accesses that are accessed by a given symbol.
# The script can filter the results by processes, threads, ring, context range and memory access operation.
#
# The script can generate two kinds of results:
# - for each call of the symbol: the process, binary and symbol call information, together with all of its
#   memory accesses.
# - results grouped by symbol: the symbol followed by all of its calls, each with the memory accesses that occurred
#   in that call. Note that this option can take a long time to start showing results.
#
# Note that:
# - this script allows including or excluding memory accesses that occurred in children symbol calls of each
#   symbol call.
#
#
#
# ## How to use
#
# Results can be generated from this notebook or from the command line.
# The script can also be imported as a module for use from your own script or notebook.
#
#
# ### From the notebook
#
# 1. Upload the `memory_ranges_accessed_by_a_symbol.ipynb` file in Jupyter.
# 2. Fill out the [parameters](#Parameters) cell of this notebook according to your scenario and desired output.
# 3. Run the full notebook.
#
#
# ### From the command line
#
# 1. Make sure that you are in an
#    [environment](http://doc.tetrane.com/professional/latest/Python-API/Installation.html#on-the-reven-server)
#    that can run REVEN scripts.
# 2. Run `python memory_ranges_accessed_by_a_symbol.py --help` to get a tour of available arguments.
# 3. Run `python memory_ranges_accessed_by_a_symbol.py --host <your_host> --port <your_port> [<other_option>]`
#    with your arguments of choice.
#
# ### Imported in your own script or notebook
#
# 1. Make sure that you are in an
#    [environment](http://doc.tetrane.com/professional/latest/Python-API/Installation.html#on-the-reven-server)
#    that can run REVEN scripts.
# 2. Make sure that `memory_ranges_accessed_by_a_symbol.py` is in the same directory as your script or notebook.
# 3. Add `import memory_ranges_accessed_by_a_symbol` to your script or notebook. You can access the various functions
#     and classes exposed by the module from the `memory_ranges_accessed_by_a_symbol` namespace.
# 4. Refer to the [Argument parsing](#Argument-parsing) cell for an example of use in a script, and to the
#    [Parameters](#Parameters) cell and below for an example of use in a notebook (you just need to prepend
#    `memory_ranges_accessed_by_a_symbol` in front of the functions and classes from the script).
#
# ## Known limitations
#
# When using the "table", "csv", "html" output format, this script might require a large quantity of RAM due to the
# data being retained in memory. If you notice high RAM usage, you can try the following:
#
# - Restart with the "raw" format
# - Split the results using the `from_context` and `to_context` parameters
# - Use the provided filters (ring, processes, threads) to reduce the number of results
#
# ## Supported versions
#
# REVEN 2.10+
#
# ## Supported perimeter
#
# Any REVEN scenario.
#
# ## Dependencies
#
# The script requires that the target REVEN scenario have:
#
# * The OSSI feature replayed.
# * The memory history feature replayed.
# * The pandas Python module.

# %% [markdown]
# ### Package imports

# %%
import argparse
import re
import sys
from enum import Enum
from typing import Callable as _Callable, Iterable as _Iterable, Optional as _Optional

from IPython.core.display import display  # type: ignore

import reven2.arch as _arch
from reven2.filter import RingPolicy
from reven2.memhist import MemoryAccess, MemoryAccessOperation
from reven2.ossi import Binary, Process, Symbol
from reven2.ossi.thread import Thread
from reven2.prelude import RevenServer
from reven2.stack import Stack
from reven2.trace import Context, Trace, Transition

# %% [markdown]
# ### Utility functions


# %%
# Detect if we are currently running a Jupyter notebook.
#
# This is used e.g. to display rendered results inline in Jupyter when we are executing in the context of a Jupyter
# notebook, or to display raw results on the standard output when we are executing in the context of a script.
def in_notebook():
    try:
        from IPython import get_ipython  # type: ignore

        if get_ipython() is None or ("IPKernelApp" not in get_ipython().config):
            return False
    except ImportError:
        return False
    return True


# %% [markdown]
# ### Helper classes for results

# %%
class CallSymbol:
    r"""
    CallSymbol is a helper class used to represent a symbol with its start and end context
    """

    def __init__(self, symbol: _Optional[Symbol], start: Context, end: _Optional[Context] = None) -> None:
        self._symbol = symbol
        self._start = start
        self._end = end

    @property
    def symbol(self) -> _Optional[Symbol]:
        r"""
        B{Property:} The symbol of the call symbol. None if the symbol is unknown.
        """
        return self._symbol

    @property
    def start_context(self) -> Context:
        r"""
        B{Property:} The start context of the call symbol.
        """
        return self._start

    @property
    def end_context(self) -> _Optional[Context]:
        r"""
        B{Property:} The end context (excluded) of the call symbol. None if the end context isn't in the trace.
        """
        return self._end

    def __eq__(self, other: "CallSymbol") -> bool:  # type: ignore
        return self._symbol == other._symbol and self._start == other._start and self._end == other._end

    def __ne__(self, other: "CallSymbol") -> bool:  # type: ignore
        return not (self == other)


class HtmlStr:
    r"""
    Helper class used with notebook special `display` function to
    consider the HTML string as HTML.
    """

    def __init__(self, html: str) -> None:
        self._html = html

    def _repr_html_(self):
        return self._html


class MemoryRangeSymbolResult:
    r"""
    MemoryRangeSymbolResult is a helper class that represents one result of the search.
    """

    def __init__(
        self,
        call_symbol: CallSymbol,
        memory_accesses: _Iterable[MemoryAccess],
        ring: int,
        process: _Optional[Process],
        thread: _Optional[Thread],
        binary: _Optional[Binary],
    ) -> None:
        self._call_symbol = call_symbol
        self._memory_accesses = memory_accesses
        self._ring = ring
        self._process = process
        self._thread = thread
        self._binary = binary

    @property
    def call_symbol(self) -> CallSymbol:
        r"""
        B{Property:} The call symbol of the result.
        """
        return self._call_symbol

    @property
    def memory_accesses(self) -> _Iterable[MemoryAccess]:
        r"""
        B{Property:} The memory accesses of the result.

        Calling this property will consume the generator
        """
        return self._memory_accesses

    @property
    def ring(self) -> int:
        r"""
        B{Property:} The ring of the result.
        """
        return self._ring

    @property
    def process(self) -> _Optional[Process]:
        r"""
        B{Property:} The process of the result.
        """
        return self._process

    @property
    def binary(self) -> _Optional[Binary]:
        r"""
        B{Property:} The binary of the result, None if unknown.
        """
        return self._binary

    @property
    def thread(self) -> _Optional[Thread]:
        r"""
        B{Property:} The thread of the result.
        """
        return self._thread

    def __eq__(self, other: "MemoryRangeSymbolResult") -> bool:  # type: ignore
        return (
            self._ring == other._ring
            and self._process is not None
            and other._process is not None
            and self._process.name == other._process.name
            and self._process.pid == other._process.pid
            and self._process.ppid == other._process.ppid
            and self._thread is not None
            and other._thread is not None
            and self._thread.id == other._thread.id
            and self._thread.owner_process_id == other._thread.owner_process_id
            and (
                (self._binary is None and other._binary is None)
                or (self._binary is not None and other._binary is not None and self._binary.path == other._binary.path)
            )
            and self._call_symbol == other._call_symbol
        )

    def __ne__(self, other: "MemoryRangeSymbolResult") -> bool:  # type: ignore
        return not (self == other)

    def output(self, print_func: _Callable, is_in_notebook: bool):
        r"""
        Output this result using the `print_func` or call `output_in_notebook` if `is_in_notebook`
        argument is true.
        """
        if is_in_notebook:
            self.output_in_notebook()
        else:
            print_func(
                f"ring: {self._ring}, process: {self._process}, "
                f"thread: {self._thread}, binary: {self._binary}, "
                f"symbol: {self._call_symbol.symbol}[{self._call_symbol.start_context}, "
                f"{self._call_symbol.end_context}["
            )
            print_func("\nmemory accesses:")
            for m in self._memory_accesses:
                print_func(f"\n\t{m}, ")
            print_func("\n")

    def output_in_notebook(self) -> None:
        r"""
        Output this result using the special notebook `display` function
        """
        end_context = self._call_symbol.end_context
        display(
            HtmlStr(
                f"<p>ring: {self._ring}, process: {self._process if self._process is not None else 'unknown'}, "
                f"thread: {self._thread if self._thread is not None else 'unknown'}, "
                f"binary: {self._binary if self._binary is not None else 'unknwon'}, "
                f"symbol: {self._call_symbol.symbol}[{self._call_symbol.start_context.format_as_html()}, "
                f"{None if end_context is None else end_context.format_as_html()}[</p>"
            )
        )
        display(HtmlStr("<p>memory accesses:</p>"))
        for m in self._memory_accesses:
            display(HtmlStr(f'<p style="text-indent: 2em;"> &#9673; {m.format_as_html()}</p>'))


class GroupedMemoryRangeSymbolResult:
    r"""
    GroupedMemoryRangeSymbolResult is a helper class that represents results of the search grouped by symbol.
    """

    def __init__(
        self,
        symbol: Symbol,
        memory_range_symbol_result: _Iterable[MemoryRangeSymbolResult],
    ) -> None:
        self._symbol = symbol
        self._memory_range_symbol_result = memory_range_symbol_result

    def output(self, print_func: _Callable, is_in_notebook: bool) -> None:
        r"""
        Output these results using the `print_func` or call `output_in_notebook` if `is_in_notebook`
        argument is true.
        """
        if is_in_notebook:
            self.output_in_notebook()
        else:
            print_func(f"{self._symbol}\n\nCalls")
            for res in self._memory_range_symbol_result:
                print_func(
                    f"\n\t[{res.call_symbol.start_context.format_as_html()}, "
                    f"{None if res.call_symbol.end_context is None else res.call_symbol.end_context.format_as_html()}]"
                    f"\n\tring: {res.ring}, process: {res.process if res.process is not None else 'unknown'}, "
                    f" thread: {res.thread if res.thread is not None else 'unknown'}, binary: {res.binary}, "
                )
                print_func("\n\tmemory accesses:")
                for m in res.memory_accesses:
                    print_func(f"\n\t\t{m}, ")
                print_func("\n")

    def output_in_notebook(self) -> None:
        r"""
        Output these results using the special notebook `display` function
        """
        display(HtmlStr(f'<p>{self._symbol}</p><p style="font-weight: bolder;">Calls</p>'))
        for res in self._memory_range_symbol_result:
            end_context = res.call_symbol.end_context
            display(
                HtmlStr(
                    f'<p style="text-indent: 2em;"> &#9678; [{res.call_symbol.start_context.format_as_html()}, '
                    f"{None if end_context is None else end_context.format_as_html()}]</p>"
                    f'<p style="text-indent: 3em;">ring: {res.ring}, process: '
                    f"{res.process if res.process is not None else 'unknown'}, "
                    f" thread: {res.thread if res.thread is not None else 'unknown'}, binary: {res.binary}</p>"
                )
            )
            display(HtmlStr('<p style="text-indent: 3em; font-weight: bolder;">memory accesses:</p>'))
            for m in res.memory_accesses:
                display(HtmlStr(f'<p style="text-indent: 4em;"> &#9673; {m.format_as_html()}</p>'))


# %% [markdown]
# ### SymbolMemoryAccessesFinder
#
# This class represents the main logic of this script

# %%
class SymbolMemoryAccessesFinder(object):
    r"""
        This class is a helper class to search for all memory accesses that are accessed by a given symbol.
        Results can be filtered by processes, ring, threads, memory access operation and a context range.

        The memory accesses that are accessed by the given symbol are returned.

        Examples
        ========
        >>> # search all memory accesses that are accessed by the calls of `ExCompareExchangeCallBack`
        >>> # symbol
        >>> import reven2
        >>> server = reven2.RevenServer('localhost', 46445)
        >>> symbol = next(server.ossi.symbols(pattern="ExCompareExchangeCallBack"))
        >>> finder = SymbolMemoryAccessesFinder(server.trace, symbol, with_children_symbols=False)
        >>> for res in finder.query():
        ...     print(res)
        ring: 0, process: System (4), thread: 932, binary: c:/windows/system32/ntoskrnl.exe,
        symbol: ntoskrnl!ExCompareExchangeCallBack[Context before #492474770, Context before #492474830[
    memory accesses:
            [#492474770 mov qword ptr ss:[rsp+0x8], rbx]Write access at
            @phy:0x7297a8c0 (virtual address: lin:0xfffff880045ca8c0) of size 8,
            [#492474771 mov qword ptr ss:[rsp+0x10], rbp]Write access at
            @phy:0x7297a8c8 (virtual address: lin:0xfffff880045ca8c8) of size 8,
            [#492474772 mov qword ptr ss:[rsp+0x18], rsi]Write access at
            @phy:0x7297a8d0 (virtual address: lin:0xfffff880045ca8d0) of size 8,
            ...
    """

    def __init__(
        self,
        trace: Trace,
        symbol: Symbol,
        with_children_symbols: bool = True,
        from_context: _Optional[Context] = None,
        to_context: _Optional[Context] = None,
        ring_policy: RingPolicy = RingPolicy.All,
        processes: _Optional[_Iterable[Process]] = None,
        threads: _Optional[_Iterable[int]] = None,
        memory_access_operation: _Optional[MemoryAccessOperation] = None,
    ) -> None:
        r"""
        Initialize a C{SymbolMemoryAccessesFinder}

        Information
        ===========

        @param trace: the trace where memory accesses will be looked for.
        @param symbol: the symbol whose memory accesses will be returned.
        @param with_children_symbols: whether to also include memory accesses performed in children symbol calls.
        @param from_context: the context where the search will be started.
        @param to_context: the context where the search will be ended.
        @param ring_policy: ring policy to search for.
        @param processes: processes to limit the search to. If None, the search covers all processes.
        @param threads: thread ids to limit the search to. If None, the search covers all threads.
        @param memory_access_operation: limit results to accesses performing the specified operation.


        @raises TypeError: if trace is not a C{reven2.trace.Trace}.
        """

        if not isinstance(trace, Trace):
            raise TypeError("You must provide a valid trace")
        self._trace = trace
        self._symbol = symbol
        self._with_children_symbols = with_children_symbols
        self._from_context = from_context
        self._to_context = to_context
        self._ring_policy = ring_policy
        self._memory_access_operation = memory_access_operation
        self._processes = None if processes is None else [process for process in processes]
        self._threads = None if threads is None else [thread for thread in threads]

    def _without_children_accesses(
        self, first_transition: Transition, last_transition: _Optional[Transition]
    ) -> _Iterable[MemoryAccess]:
        if last_transition is None:
            last_transition = self._trace.last_transition
        sub_first_transition = sub_last_transition = next_transition = first_transition
        while sub_last_transition is not None and sub_last_transition < last_transition:
            next_transition_opt = next_transition.step_over()
            if next_transition_opt is None:
                return
            next_transition = next_transition_opt
            if next_transition.id - sub_last_transition.id > 1:
                # request memory accesses for the range [first_transition, sub_last_transition +1[
                for mem_access in self._trace.memory_accesses(
                    from_transition=sub_first_transition, to_transition=sub_last_transition + 1
                ):
                    yield mem_access

                sub_last_transition = sub_first_transition = next_transition
            else:
                sub_last_transition = next_transition

    def _with_children_accesses(
        self, first_transition: Transition, last_transition: _Optional[Transition]
    ) -> _Iterable[MemoryAccess]:
        # get memory access
        for mem_access in self._trace.memory_accesses(from_transition=first_transition, to_transition=last_transition):
            yield mem_access

    def _is_the_same_stack(self, stack1: Stack, stack2: Stack) -> bool:
        # we assume that two stacks are the same if the first contexts of their first frames are the same

        frame1 = next(stack1.frames())
        frame2 = next(stack2.frames())

        return frame1.first_context == frame2.first_context

    def _query_mem_accesses(
        self,
        first_transition: Transition,
        last_transition: _Optional[Transition],
        with_children_symbols: bool,
        memory_access_operation: _Optional[MemoryAccessOperation],
    ) -> _Iterable[MemoryAccess]:
        mem_query = (
            self._with_children_accesses(first_transition, last_transition)
            if with_children_symbols
            else self._without_children_accesses(first_transition, last_transition)
        )

        # store last handled stack to use it if we are in the same stack
        last_stack: _Optional[Stack] = None
        last_stack_is_taken: bool = True

        for mem_access in mem_query:
            if memory_access_operation is not None and mem_access.operation != memory_access_operation:
                continue
            # ignore accesses without linear address
            if mem_access.virtual_address is None:
                continue

            stack = mem_access.transition.context_before().stack
            # if the current stack is the same as the last stack we can yield the access
            if last_stack is not None and self._is_the_same_stack(last_stack, stack):
                if last_stack_is_taken:
                    yield mem_access
                continue

            last_stack = stack

            # we need to be sure that we are always in the current symbol
            # a simple method is to test if the current symbol is in the backtrace of
            # this access's context.
            take_access = False
            for frame in stack.frames():
                loc = frame.first_context.ossi.location()
                if loc is not None and self._symbol == loc.symbol:
                    take_access = True
                    break

            if not take_access:
                last_stack_is_taken = False
                continue
            last_stack_is_taken = True
            yield mem_access

    def filter_by_threads(self, threads: _Iterable[int]) -> "SymbolMemoryAccessesFinder":
        r"""
        Extend the list of threads to limit the search in, and return the self object.

        Information
        ===========

        @param threads: threads to limit the search in.
        @returns : self object
        """
        if self._threads is None:
            self._threads = []
        self._threads += [thread for thread in threads]

        return self

    def filter_by_processes(self, processes: _Iterable[Process]) -> "SymbolMemoryAccessesFinder":
        r"""
        Extend the list of processes to limit the search in, and return the self object.

        Information
        ===========

        @param processes: processes to limit the search in.
        @returns : self object
        """
        if self._processes is None:
            self._processes = []
        self._processes += [process for process in processes]

        return self

    def filter_by_ring(self, ring_policy: RingPolicy) -> "SymbolMemoryAccessesFinder":
        r"""
        Update the ring policy to search for and return the `self` object.

        Information
        ===========

        @param ring_policy: ring policy to search for.
        @returns : self object
        """
        self._ring_policy = ring_policy
        return self

    def from_context(self, context: Context) -> "SymbolMemoryAccessesFinder":
        r"""
        Update the context where the search will be started and return the `self` object.

        Information
        ===========

        @param context: context where the search will be started.
        @returns : self object
        """
        self._from_context = context
        return self

    def to_context(self, context: Context) -> "SymbolMemoryAccessesFinder":
        r"""
        Update the context where the search will be ended and return the `self` object.

        Information
        ===========

        @param context: context where the search will be ended.
        @returns : self object
        """
        self._to_context = context
        return self

    def filter_by_memory_access_operation(
        self, operation: _Optional[MemoryAccessOperation] = None
    ) -> "SymbolMemoryAccessesFinder":
        r"""
        Update the memory access operation to limit results to accesses performing this
        operation and return the `self` object.

        Information
        ===========

        @param operation: limit results to accesses performing the specified operation.
        @returns : self object
        """
        self._memory_access_operation = operation
        return self

    def query(self) -> _Iterable[MemoryRangeSymbolResult]:
        with_children_symbols = self._with_children_symbols
        memory_access_operation = self._memory_access_operation
        thread_ids = None if self._threads is None else self._threads.copy()
        # filter by process
        for context_range in self._trace.filter(
            processes=self._processes,
            ring_policy=self._ring_policy,
            from_context=self._from_context,
            to_context=self._to_context,
        ):

            last_transition: _Optional[Transition] = None
            # search symbol call
            for context in self._trace.search.symbol(self._symbol, context_range.begin, context_range.end):
                # ignore results that aren't in the list of thread
                if thread_ids is not None:
                    thread = context.ossi.thread()
                    if thread is None or thread.id not in thread_ids:
                        continue
                # we need also to ignore symbols that are recursively called.
                if last_transition is not None and last_transition.context_before() > context:
                    continue
                # if with_children_symbols is true, that means we need to consider
                # the memory accesses in children symbols.
                # use step_out to go out of the symbol
                first_transition = context.transition_after()
                last_transition = first_transition.step_out()
                last_context = None if last_transition is None else last_transition.context_before()
                end_range_transition = (
                    self._trace.last_transition
                    if context_range.end is None or context_range.end == self._trace.last_context
                    else context_range.end.transition_after()
                )
                last_transition = (
                    last_transition
                    if last_transition is None or last_transition < end_range_transition
                    else end_range_transition
                )

                # get the ring of the symbol
                curr_ring = context.read(_arch.x64.cs) & 0x3
                curr_process = context.ossi.process()
                curr_thread = context.ossi.thread()
                curr_location = context.ossi.location()
                curr_binary = None if curr_location is None else curr_location.binary

                yield MemoryRangeSymbolResult(
                    call_symbol=CallSymbol(self._symbol, context, last_context),
                    memory_accesses=self._query_mem_accesses(
                        first_transition,
                        last_transition,
                        with_children_symbols,
                        memory_access_operation,
                    ),
                    ring=curr_ring,
                    process=curr_process,
                    thread=curr_thread,
                    binary=curr_binary,
                )


# %% [markdown]
#
# ### OutputType


# %%
class OutputFormat(Enum):
    r"""
    Enum describing the various possible output formats of the results
     - RAW: The results will be output using their string representation.
     - TABLE: The results will be output using the pandas table format.
     - CSV: The results will be output as CSV.
     - HTML: The results will be output as an HTML table.
    """
    RAW = 0
    TABLE = 1
    CSV = 2
    HTML = 3


# %% [markdown]
# ### Main function
#
# This function is called with parameters from the [Parameters](#Parameters) cell in the notebook context,
# or with parameters from the command line in the script context.


# %%
def memory_ranges_accessed_by_a_symbol(
    server: RevenServer,
    symbol: str,
    binary_hint: _Optional[str] = None,
    with_children_symbols: bool = True,
    from_context: _Optional[int] = None,
    to_context: _Optional[int] = None,
    ring_policy: RingPolicy = RingPolicy.All,
    processes: _Optional[_Iterable[str]] = None,
    threads: _Optional[_Iterable[int]] = None,
    operation: _Optional[MemoryAccessOperation] = None,
    grouped_by_symbol: bool = False,
    output_format: OutputFormat = OutputFormat.RAW,
    output_file: _Optional[str] = None,
) -> None:
    # get the symbol from the OSSI server and raise if it doesn't exist
    trace_symbol = None
    symbol_count = 0
    for sym in server.ossi.symbols(pattern=re.escape(symbol), binary_hint=binary_hint):
        if sym.name == symbol:
            symbol_count += 1
            if trace_symbol is None:
                trace_symbol = sym
            if symbol_count == 2:
                print(trace_symbol, file=sys.stderr)
            if symbol_count >= 2:
                print(sym, file=sys.stderr)

    if trace_symbol is None:
        raise ValueError(f"The requested symbol '{symbol}' could not be found")

    if symbol_count > 1:
        sys.exit(
            "Several symbols exist with the provided symbol name; you may need to provide the symbol's "
            "binary name. Please pick one from the list above."
        )

    # declare memory accesses finder.
    symbol_memory_ranges_finder = SymbolMemoryAccessesFinder(
        trace=server.trace,
        symbol=trace_symbol,
        with_children_symbols=with_children_symbols,
        from_context=(None if from_context is None else server.trace.context_before(from_context)),
        to_context=(None if to_context is None else server.trace.context_before(to_context)),
        threads=threads,
        ring_policy=ring_policy,
        memory_access_operation=operation,
    )

    # filter by processes
    if processes is not None:
        for process in processes:
            symbol_memory_ranges_finder.filter_by_processes(server.ossi.executed_processes(process))

    if output_format == OutputFormat.RAW:
        is_in_notebook = output_file is None and in_notebook()

        def std_print_func(s: str) -> None:
            print(s)

        print_func = std_print_func
        if output_file is not None:
            file = open(output_file, "w")

            def fprint_func(s: str) -> None:
                file.write(s)

            print_func = fprint_func
        if grouped_by_symbol:
            grouped_result = GroupedMemoryRangeSymbolResult(trace_symbol, symbol_memory_ranges_finder.query())

            grouped_result.output(print_func, is_in_notebook)
        else:
            for result in symbol_memory_ranges_finder.query():
                result.output(print_func, is_in_notebook)

        if output_file is not None:
            file.close()
    else:
        column_headers = [
            "Ring",
            "Process",
            "Thread",
            "Binary",
            "Symbol",
            "Start context",
            "Access transition",
            "Access operation",
            "Access physical",
            "Access linear",
            "Access size",
        ]

        def data_generator():
            for result in symbol_memory_ranges_finder.query():
                for mem_access in result.memory_accesses:
                    yield (
                        result.ring,
                        str(result.process) if result.process is not None else "unknown",
                        str(result.thread) if result.thread is not None else "unknown",
                        result.binary.name if result.binary is not None else "unknown",
                        result.call_symbol.symbol.name,
                        str(result.call_symbol.start_context),
                        mem_access.transition.id,
                        mem_access.operation.name,
                        mem_access.physical_address,
                        mem_access.virtual_address,
                        mem_access.size,
                    )

        df = pandas.DataFrame(data=data_generator(), columns=column_headers)
        if output_format == OutputFormat.TABLE:
            if output_file is not None:
                with open(output_file, "w") as file:
                    file.write(str(df))
            else:
                print(df)
        elif output_format == OutputFormat.CSV:
            print(df.to_csv()) if output_file is None else df.to_csv(output_file)
        elif output_format == OutputFormat.HTML:
            print(df.to_html()) if output_file is None else df.to_html(output_file)


# %% [markdown]
# ### Argument parsing
#
# Argument parsing function for use in the script context.

# %%
def get_memory_access_operation(operation: str) -> MemoryAccessOperation:
    if operation is None:
        return None
    if operation.lower() == "read":
        return MemoryAccessOperation.Read
    if operation.lower() == "write":
        return MemoryAccessOperation.Write
    raise ValueError(f"'operation' value should be 'read' or 'write'. Received '{operation}'.")


def get_ring_policy(ring: int) -> RingPolicy:
    if ring is None:
        return RingPolicy.All
    if ring == 0:
        return RingPolicy.R0Only
    if ring == 3:
        return RingPolicy.R3Only
    raise ValueError(f"'ring_policy' value should be '0' or '1'. Received '{ring_policy}'.")


def get_output_format(format: str) -> OutputFormat:
    if format.lower() == "raw":
        return OutputFormat.RAW
    if format.lower() == "table":
        return OutputFormat.TABLE
    if format.lower() == "html":
        return OutputFormat.HTML
    if format.lower() == "csv":
        return OutputFormat.CSV
    raise ValueError(f"'output format' value should be 'raw', 'table', 'html', or 'csv'. Received '{format}'.")


def script_main():
    parser = argparse.ArgumentParser(description="Find all memory accesses that are accessed a given symbol")
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        required=False,
        help='REVEN host, as a string (default: "localhost")',
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default="13370",
        required=False,
        help="REVEN port, as an int (default: 13370)",
    )
    parser.add_argument(
        "-s",
        "--symbol",
        type=str,
        required=True,
        help="The symbol whose accesses are looked for (e.g. WriteFile)",
    )
    parser.add_argument(
        "-b",
        "--binary-hint",
        type=str,
        required=False,
        help="The symbol's binary name hint (e.g. ntoskrnl)",
    )
    parser.add_argument(
        "--with-children-symbols",
        action="store_true",
        required=False,
        default=False,
        help="Show accesses from children calls",
    )
    parser.add_argument(
        "--from-context",
        type=int,
        required=False,
        help="The context from where the search starts",
    )
    parser.add_argument(
        "--to-context",
        type=int,
        required=False,
        help="The context(not included) at which the search stops",
    )
    parser.add_argument(
        "--ring",
        type=int,
        required=False,
        help="Show symbol's accesses if it is in this ring only, can be (0=ring0, 3=ring3)",
    )
    parser.add_argument(
        "--processes",
        required=False,
        nargs="*",
        help="Show symbol's accesses if it is in these processes only",
    )
    parser.add_argument(
        "--threads",
        type=int,
        required=False,
        nargs="*",
        help="Show symbol's accesses if it is in these threads only",
    )
    parser.add_argument(
        "--memory-access-operation",
        choices=["read", "write"],
        required=False,
        help="Only show symbols that access the memory range using this operation",
    )
    parser.add_argument(
        "--grouped-by-symbol",
        action="store_true",
        required=False,
        default=False,
        help="Group results by symbol",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        type=str,
        required=False,
        help="The target file of the results. If absent, the results will be printed on the standard output",
    )
    parser.add_argument(
        "--output-format",
        choices=["raw", "table", "csv", "html"],
        required=False,
        default="raw",
        help="Output format of the results",
    )

    args = parser.parse_args()

    try:
        server = RevenServer(args.host, args.port)
    except RuntimeError:
        raise RuntimeError(f"Could not connect to the server on {args.host}:{args.port}.")

    memory_ranges_accessed_by_a_symbol(
        server=server,
        symbol=args.symbol,
        binary_hint=args.binary_hint,
        with_children_symbols=args.with_children_symbols,
        from_context=args.from_context,
        to_context=args.to_context,
        ring_policy=get_ring_policy(args.ring),
        processes=args.processes,
        threads=args.threads,
        operation=get_memory_access_operation(args.memory_access_operation),
        grouped_by_symbol=args.grouped_by_symbol,
        output_format=get_output_format(args.output_format),
        output_file=args.output_file,
    )


# %% [markdown]
# ## Parameters
#
# These parameters have to be filled out for use in the notebook context.

# %%
# Server connection
#
host = "localhost"
port = 37103

# Input data

symbol = "xxx"  # symbol name

binary_hint = None  # symbol's binary name hint

with_children_symbols = True

# Output filter

from_context = None
# from_context = 10


to_context = None
# to_context = 10


ring_policy = RingPolicy.All
# ring_policy = RingPolicy.R0Only
# ring_policy = RingPolicy.R3Only

processes = None  # display result for all processes in the trace
# processes = ["xxx",]

threads = None  # display result for all threads in the trace
# threads = [thread_id,]

memory_access_operation = None
# memory_access_operation = MemoryAccessOperation.Write
# memory_access_operation = MemoryAccessOperation.Read

# Output target
#
output_file = None  # display results inline
# output_file = "res.csv"  # write results formatted as `csv` to a file named "res.csv" in the current directory


# Output control
#
# group results by symbol
grouped_by_symbol = False
# pandas output type
output_format: OutputFormat = OutputFormat.RAW


# %% [markdown]
# ### Pandas module
#
# This cell verifies whether the pandas module is installed, and installs it if needed.


# %%
if in_notebook():
    try:
        import pandas  # noqa

        print("pandas already installed")
    except ImportError:
        print("Could not find pandas, attempting to install it from pip")
        import subprocess

        command = [f"{sys.executable}", "-m", "pip", "install", "pandas"]
        p = subprocess.run(command)

        if int(p.returncode) != 0:
            raise RuntimeError("Error installing pandas")
        import pandas  # noqa

        print("Successfully installed pandas")
else:
    import pandas  # noqa


# %% [markdown]
# ### Execution cell
#
# This cell executes according to the [parameters](#Parameters) when in notebook context, or according to the
# [parsed arguments](#Argument-parsing) when in script context.
#
# When in notebook context, if the `output_file` parameter is `None`, then the report will be displayed in the last cell of
# the notebook.

# %%
if __name__ == "__main__":
    if in_notebook():
        try:
            server = RevenServer(host, port)
        except RuntimeError:
            raise RuntimeError(f"Could not connect to the server on {host}:{port}.")
        memory_ranges_accessed_by_a_symbol(
            server=server,
            symbol=symbol,
            binary_hint=binary_hint,
            with_children_symbols=with_children_symbols,
            from_context=from_context,
            to_context=to_context,
            ring_policy=ring_policy,
            processes=processes,
            threads=threads,
            operation=memory_access_operation,
            grouped_by_symbol=grouped_by_symbol,
            output_format=output_format,
            output_file=output_file,
        )
    else:
        script_main()
# %%

Migration scripts

Scripts in this directory make it easier to migrate from one version of REVEN to another.

Migrate bookmarks from 2.5 to 2.6

Purpose

We fixed an issue in REVEN 2.6 leading to some changes in the transition number for QEMU scenarios.

This script is here to help you migrate your bookmarks if they are off after replaying your scenario with REVEN 2.6.

How to use

Launch after updating the trace resource for your scenario to REVEN 2.6+.

usage: migrate_bookmarks_2.5_to_2.6.py [-h] [--host HOST] [-p PORT]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)

Known limitations

  • The script does not attempt to determine whether the scenario needs to be upgraded or not. Applying the script when bookmarks don't need to be upgraded will actually put them at the wrong position. Apply only if you notice that the bookmarks have been put at the wrong position after updating.

Supported versions

REVEN 2.6+

Supported perimeter

REVEN scenarios recorded with QEMU.

Dependencies

None.

Source

#!/usr/bin/env python3

import argparse
import sys

import reven2

"""
# Migrate bookmarks from 2.5 to 2.6

## Purpose

We fixed an issue in REVEN 2.6 leading to some changes in the transition number for QEMU scenarios.

This script is here to help you migrate your bookmarks if they are off after replaying your scenario with REVEN 2.6.

## How to use

Launch after updating the trace resource for your scenario to REVEN 2.6+.

```bash
usage: migrate_bookmarks_2.5_to_2.6.py [-h] [--host HOST] [-p PORT]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
```

## Known limitations

- The script does not attempt to determine whether the scenario needs to be upgraded or not.
Applying the script when bookmarks don't need to be upgraded will actually put them at the wrong position.
Apply only if you notice that the bookmarks have been put at the wrong position after updating.

## Supported versions

REVEN 2.6+

## Supported perimeter

REVEN scenarios recorded with QEMU.

## Dependencies

None.
"""


def migrate_bookmarks(reven_server):
    offset_table = {}

    def get_offset(transition_id):
        lower_bound = None
        for key in offset_table.keys():
            if key <= transition_id and (lower_bound is None or lower_bound < key):
                lower_bound = key
        return offset_table[lower_bound] if lower_bound is not None else 0

    print("Generating offset table...")

    c = reven_server.trace.context_before(0)

    counter = 0
    while c is not None:
        c = c.find_register_change(reven2.arch.x64.cr2)

        if c is None:
            continue

        t_exception = c.transition_before()
        t_before_exception = t_exception - 1

        ctx_before = t_exception.context_before()
        ctx_after = t_exception.context_after()

        # We are looking for a code pagefault, so CR2 will contain the PC
        cr2 = ctx_after.read(reven2.arch.x64.cr2)

        if ctx_before.is64b():
            if cr2 != ctx_before.read(reven2.arch.x64.rip):
                continue
        else:
            if cr2 != ctx_before.read(reven2.arch.x64.eip):
                continue

        try:
            # The issue occurred when the previous instruction was only doing reads and not any write
            # Warning: This may not work if the trace is desynchronized with the memory
            next(t_before_exception.memory_accesses(operation=reven2.memhist.MemoryAccessOperation.Write))
        except StopIteration:
            counter += 1
            offset_table[t_exception.id - counter] = counter

    print("Migrating bookmarks...")

    for bookmark in list(reven_server.bookmarks.all()):
        offset = get_offset(bookmark.transition.id)

        print(
            "    id: %d | %60.60s | Transition #%d => #%d (+%d)"
            % (bookmark.id, bookmark.description, bookmark.transition.id, bookmark.transition.id + offset, offset)
        )

        if offset == 0:
            continue

        reven_server.bookmarks.add(bookmark.transition + offset, bookmark.description)
        reven_server.bookmarks.remove(bookmark)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost", help='Reven host, as a string (default: "localhost")')
    parser.add_argument("-p", "--port", type=int, default=13370, help="Reven port, as an int (default: 13370)")
    args = parser.parse_args()

    answer = ""
    while answer not in ["y", "n"]:
        answer = input(
            "This script should be launched on a QEMU scenario with the trace generated with REVEN 2.6 with bookmarks"
            " that were added in REVEN 2.5 or older. Do you want to continue [Y/N]? "
        ).lower()

    if answer == "n":
        print("Aborting")
        sys.exit(0)

    reven_server = reven2.RevenServer(args.host, args.port)
    migrate_bookmarks(reven_server)
    print("Bookmarks migrated!")

Import classic bookmarks

Purpose

Import classic bookmarks (created using up to REVEN 2.4) from ".rbm" files to the "server-side" bookmarks system (from REVEN 2.5+).

How to use

usage: import_bookmarks.py [-h] [--host HOST] [-p PORT] [-f FILENAME]
                           [--prepend-symbol]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  -f FILENAME, --filename FILENAME
                        Path to the classic bookmark file (*.rbm).
  --prepend-symbol      If set, prepend the OSSI symbol as stored in the
                        classic symbol file to the description of the bookmark
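
import_bookmarks can also be called from your own script or notebook. A minimal sketch, assuming import_bookmarks.py is importable from the working directory (host, port and the .rbm path are placeholders):

import reven2
from import_bookmarks import import_bookmarks

# Connect to the REVEN server of the scenario (placeholder host/port).
server = reven2.RevenServer("localhost", 13370)

# Import the classic bookmarks, prepending the recorded OSSI symbol to each description.
import_bookmarks(server, "bookmarks.rbm", prepend_symbol=True)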

Known limitations

N/A

Supported versions

REVEN 2.5+

Supported perimeter

Any REVEN scenario for which a .rbm is available.

Dependencies

None.

Source

import argparse
import json

import reven2


"""
# Import classic bookmarks

## Purpose

Import classic bookmarks (created using up to REVEN 2.4) from ".rbm" files to the "server-side" bookmarks system
(from REVEN 2.5+).

## How to use

```bash
usage: import_bookmarks.py [-h] [--host HOST] [-p PORT] [-f FILENAME]
                           [--prepend-symbol]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  -f FILENAME, --filename FILENAME
                        Path to the classic bookmark file (*.rbm).
  --prepend-symbol      If set, prepend the OSSI symbol as stored in the
                        classic symbol file to the description of the bookmark
```

## Known limitations

N/A

## Supported versions

REVEN 2.5+

## Supported perimeter

Any REVEN scenario for which a .rbm is available.

## Dependencies

None.
"""


def import_bookmarks(reven_server, rbm_path, prepend_symbol=False):
    r"""
    This function is a helper to import classic bookmarks from ".rbm" files to the new "server-side" bookmarks system.

    Examples
    ========

    >>> # Import bookmarks
    >>> f = "Reven2/2.5.0-rc2-1-ga1b971b/Scenarios/bksod_ff34e5e1-dfaa-41fe-88b0-fdad14993fe3/UserData/bookmarks.rbm"
    >>> import_bookmarks(reven_server, f)
    >>> for bookmark in reven_server.bookmarks.all():
    ...     print(bookmark)
    #169672818: 'mst120 deallocated by network'
    #8655429: 'mst120 allocated by system'
    #8627412: 'IcaRawInput looks nice to see decrypted data'
    #1231549571: 'ica find channel on this pointer???'
    #1141851788: 'Same pointer reallocated to something else'
    #1231549773: 'crash'

    >>> # Import bookmarks, prepending the known symbol before the description
    >>> import_bookmarks(reven_server, f, prepend_symbol=True)
    >>> for bookmark in reven_server.bookmarks.all():
    ...     print(bookmark)
    #169672818: 'ExFreePoolWithTag+0x0 - ntoskrnl.exe: mst120 deallocated by network'
    #8655429: 'ExAllocatePoolWithTag+0x1df - ntoskrnl.exe: mst120 allocated by system'
    #8627412: 'IcaRawInput+0x0 - termdd.sys: IcaRawInput looks nice to see decrypted data'
    #1231549571: 'IcaFindChannel+0x3d - termdd.sys: ica find channel on this pointer???'
    #1141851788: 'ExAllocatePoolWithTag+0x1df - ntoskrnl.exe: Same pointer reallocated to something else'
    #1231549773: 'ExpCheckForIoPriorityBoost+0xa7 - ntoskrnl.exe: crash'

    Information
    ===========

    @param reven_server: The C{reven2.RevenServer} instance on which you wish to import the bookmarks.
    @param rbm_path: Path to the classic bookmark file.
    @param prepend_symbol: If C{True}, prepend the OSSI symbol as stored in the classic symbol file to the description
                           of the bookmark.
    """
    with open(rbm_path) as f:
        json_bookmarks = json.load(f)
        for json_bookmark in json_bookmarks.values():
            try:
                transition = reven_server.trace.transition(int(json_bookmark["identifier"]))
                description_prefix = (json_bookmark["symbol"] + ": ") if prepend_symbol else ""
                description = description_prefix + json_bookmark["description"]
                reven_server.bookmarks.add(transition, str(description))
            except IndexError:
                print(
                    "Skipping import of bookmark at transition {} which is out of range".format(
                        json_bookmark["identifier"]
                    )
                )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost", help='Reven host, as a string (default: "localhost")')
    parser.add_argument("-p", "--port", type=int, default="13370", help="Reven port, as an int (default: 13370)")
    parser.add_argument("-f", "--filename", type=str, help="Path to the classic bookmark file (*.rbm).")
    parser.add_argument(
        "--prepend-symbol",
        action="store_true",
        help="If set, prepend the OSSI symbol as stored in the classic symbol file to the "
        "description of the bookmark",
    )
    args = parser.parse_args()

    reven_server = reven2.RevenServer(args.host, args.port)
    import_bookmarks(reven_server, args.filename, args.prepend_symbol)
    print("Bookmarks imported!")

Reporting

Examples in this section gather information to output synthetic reports.

Crash Detection

Purpose

Detect and report crashes and exceptions that occur during the trace. This script detects system crashes that occur inside of a trace, as well as exceptions thrown in user space.
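
The detection relies on the symbol search of the API: system crashes are found by locating every call to KeBugCheckEx in ntoskrnl, and user exceptions by locating every call to KiUserExceptionDispatch in ntdll. A minimal sketch of the system-crash half, where the host and port are placeholders (see the Source below for the full report):

import reven2

reven_server = reven2.RevenServer("localhost", 13370)
ntoskrnl = next(reven_server.ossi.executed_binaries("ntoskrnl"))
ke_bug_check_ex = next(ntoskrnl.symbols("KeBugCheckEx"))
for ctx in reven_server.trace.search.symbol(ke_bug_check_ex):
    # the bug check code is the first argument of KeBugCheckEx (ecx)
    bug_check_code = ctx.read(reven2.arch.x64.ecx)
    print(f"{ctx}: KeBugCheckEx called with bug check code {bug_check_code:#x}")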

How to use

Usage: crash_detection.py [-h] [--host host] [--port port] [--mode mode]
                          [--header]

Detect and report crashes and exceptions that appear during a REVEN scenario.

optional arguments:
  -h, --help   show this help message and exit
  --host host  Reven host, as a string (default: "localhost")
  --port port  Reven port, as an int (default: 13370)
  --mode mode  Whether to look for "user" crash, "system" crash, or "all"
  --header     If present, display a header with the meaning of each column

Known limitations

Because user space processes can catch exceptions, a user exception reported by this script does not necessarily mean that the involved user space process crashed after causing the exception.

Supported versions

REVEN 2.7+

Supported perimeter

Any Windows 10 x64 REVEN scenario.

Dependencies

The script requires that the target REVEN scenario have:

  • The Fast Search feature replayed.
  • The OSSI feature replayed.
  • The Backtrace feature replayed.

Source

#!/usr/bin/env python3

import argparse

import reven2

# %% [markdown]
# # Crash Detection
#
# ## Purpose
#
# Detect and report crashes and exceptions that occur during the trace.
#
# This script detects system crashes that occur inside of a trace, as well as exceptions thrown in user space.
#
# ## How to use
#
# ```bash
# Usage: crash_detection.py [-h] [--host host] [--port port] [--mode mode]
#                           [--header]
#
# Detect and report crashes and exceptions that appear during a REVEN scenario.
#
# optional arguments:
#   -h, --help   show this help message and exit
#   --host host  Reven host, as a string (default: "localhost")
#   --port port  Reven port, as an int (default: 13370)
#   --mode mode  Whether to look for "user" crash, "system" crash, or "all"
#   --header     If present, display a header with the meaning of each column
# ```
#
# ## Known limitations
#
# Because user space processes can catch exceptions, a user exception reported by this script does not necessarily
# mean that the involved user space process crashed after causing the exception.
#
# ## Supported versions
#
# REVEN 2.7+
#
# ## Supported perimeter
#
# Any Windows 10 x64 REVEN scenario.
#
# ## Dependencies
#
# The script requires that the target REVEN scenario have:
#   - The Fast Search feature replayed.
#   - The OSSI feature replayed.
#   - The Backtrace feature replayed.

HIGH_LEVEL_EXCEPTION_CODES = {
    0x80000003: "breakpoint",
    0x80000004: "single step debug",
    0xC000001D: "illegal instruction",
    0xC0000094: "integer division by zero",
    0xC0000005: "access violation",
    0xC0000409: "stack buffer overrun",
}

# Obtained by reversing the transformations performed on high level exception codes
# Some of the crashes find the low-level exception codes rather than the high-level ones
LOW_LEVEL_EXCEPTION_CODES = {
    0x80000003: "breakpoint",
    0x80000004: "single step debug",
    0x10000002: "illegal instruction",
    0x10000003: "integer division by zero",
    0x10000004: "access violation",
}


class SystemCrash:
    # Code values recovered here:
    # https://docs.microsoft.com/en-us/windows-hardware/drivers/debugger/bug-check-code-reference2
    PF_BUG_CHECK_CODES = [0x50, 0xCC, 0xCD, 0xD5, 0xD6]
    EXCEPTION_BUG_CHECK_CODES = [0x1E, 0x7E, 0x8E, 0x135, 0x1000007E, 0x1000008E]
    SYSTEM_SERVICE_EXCEPTION = 0x3B
    KERNEL_SECURITY_CHECK_FAILURE = 0x139

    def __init__(self, trace, dispatcher_ctx):
        self._trace = trace
        self._dispatch_ctx = dispatcher_ctx
        self._bug_check_code = dispatcher_ctx.read(reven2.arch.x64.ecx)
        self._error_code = None
        self._page_fault_address = None
        self._page_fault_operation = None
        self._process = dispatcher_ctx.ossi.process()
        if self._bug_check_code in SystemCrash.PF_BUG_CHECK_CODES:
            # page fault address is the 2nd parameter of KeBugCheckEx call for PAGE_FAULT bug checks.
            # operation is 3rd parameter of KeBugCheckEx. See for instance:
            # https://docs.microsoft.com/en-us/windows-hardware/drivers/debugger/bug-check-0xcc--page-fault-in-freed-special-pool
            self._page_fault_address = dispatcher_ctx.read(reven2.arch.x64.rdx)
            self._page_fault_operation = dispatcher_ctx.read(reven2.arch.x64.r8)
        elif self._bug_check_code in SystemCrash.EXCEPTION_BUG_CHECK_CODES:
            # error code is the 2nd parameter of KeBugCheckEx call for EXCEPTION bug checks. See for instance:
            # https://docs.microsoft.com/en-us/windows-hardware/drivers/debugger/bug-check-0x1e--kmode-exception-not-handled
            self._error_code = dispatcher_ctx.read(reven2.arch.x64.edx)
        elif self._bug_check_code == SystemCrash.KERNEL_SECURITY_CHECK_FAILURE:
            # error code can be found as the first member of the exception structure that is 4th parameter of
            # KeBugCheckEx call. See:
            # https://docs.microsoft.com/en-us/windows-hardware/drivers/debugger/bug-check-0x139--kernel-security-check-failure
            # https://docs.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-exception_record
            self._error_code = dispatcher_ctx.deref(reven2.arch.x64.r9, reven2.types.Pointer(reven2.types.U32))
        elif self._bug_check_code == SystemCrash.SYSTEM_SERVICE_EXCEPTION:
            self._error_code = dispatcher_ctx.read(reven2.arch.x64.edx)
        # Look for the exception transition in the backtrace, if any is found
        self._exception_transition = None
        for frame in dispatcher_ctx.transition_before().context_before().stack.frames():
            if frame.creation_transition is not None and frame.creation_transition.exception:
                self._exception_transition = frame.creation_transition
                break

    @property
    def dispatch_ctx(self):
        return self._dispatch_ctx

    @property
    def exception_transition(self):
        return self._exception_transition

    @property
    def error_code(self):
        return self._error_code

    @property
    def bug_check_code(self):
        return self._bug_check_code

    @property
    def page_fault_address(self):
        return self._page_fault_address

    @property
    def page_fault_operation(self):
        return self._page_fault_operation

    @property
    def process(self):
        return self._process


class UserCrash:
    def __init__(self, trace, dispatcher_ctx):
        self._trace = trace
        self._dispatch_ctx = dispatcher_ctx
        self._exception_transition = None
        self._process = dispatcher_ctx.ossi.process()
        try:
            frame = next(dispatcher_ctx.transition_before().context_before().stack.frames())
            self._exception_transition = frame.first_context.transition_before()
        except StopIteration:
            pass

        self._error_code = None
        # heuristic: go back some transitions to get good stack trace
        ctx_before_rsp_changed = dispatcher_ctx - 8
        frames = ctx_before_rsp_changed.stack.frames()
        try:
            ki_exception_dispatch_ctx = next(frames).first_context
            self._error_code = ki_exception_dispatch_ctx.read(reven2.arch.x64.ecx)
        except StopIteration:
            pass

    @property
    def dispatch_ctx(self):
        return self._dispatch_ctx

    @property
    def exception_transition(self):
        return self._exception_transition

    @property
    def error_code(self):
        return self._error_code

    @property
    def process(self):
        return self._process


def detect_system_crashes(server):
    try:
        ntoskrnl = next(server.ossi.executed_binaries("ntoskrnl"))
    except StopIteration:
        raise RuntimeError("Could not find the ntoskrnl binary. " "Is this a Windows 10 trace with OSSI enabled?")

    try:
        ke_bug_check_ex = next(ntoskrnl.symbols("KeBugCheckEx"))
    except StopIteration:
        raise RuntimeError(
            "Could not find the KeBugCheckEx symbol in ntoskrnl. " "Is this a Windows 10 trace with OSSI enabled?"
        )

    for call in server.trace.search.symbol(ke_bug_check_ex):
        yield SystemCrash(server.trace, call)


def detect_user_crashes(server):
    try:
        ntdll = next(server.ossi.executed_binaries("ntdll"))
    except StopIteration:
        raise RuntimeError("Could not find the ntdll binary. " "Is this a Windows 10 trace with OSSI enabled?")
    try:
        ki_user_exception_dispatcher = next(ntdll.symbols("KiUserExceptionDispatch"))
    except StopIteration:
        raise RuntimeError(
            "Could not find the KiUserExceptionDispatch symbol in ntdll. "
            "Is this a Windows 10 trace with OSSI enabled?"
        )

    for call in server.trace.search.symbol(ki_user_exception_dispatcher):
        yield UserCrash(server.trace, call)


def format_exception_code(error_code):
    if error_code is None:
        return None
    if error_code in HIGH_LEVEL_EXCEPTION_CODES:
        return "{} ({:#x})".format(HIGH_LEVEL_EXCEPTION_CODES[error_code], error_code)
    elif error_code in LOW_LEVEL_EXCEPTION_CODES:
        return "{} ({:#x})".format(LOW_LEVEL_EXCEPTION_CODES[error_code], error_code)
    return "unknown or incorrect exception code: {:#x}".format(error_code)


def format_page_fault(page_fault_address, page_fault_operation):
    if page_fault_address is None:
        return None
    # operations changed recently for bug check 0x50. It should work with any version though. See:
    # https://docs.microsoft.com/en-us/windows-hardware/drivers/debugger/bug-check-0x50--page-fault-in-nonpaged-area#page_fault_in_nonpaged_area-parameters
    if page_fault_operation == 0:
        operation = "reading"
    elif page_fault_operation == 1 or page_fault_operation == 2:
        operation = "writing"
    elif page_fault_operation == 10:
        operation = "executing"
    else:
        return "page fault on address {:#x}".format(page_fault_address)

    return "page fault while {} address {:#x}".format(operation, page_fault_address)


def format_cause(error_code=None, page_fault_address=None, page_fault_operation=None):
    exception_fmt = format_exception_code(error_code)
    if exception_fmt is not None:
        return "{}".format(exception_fmt)
    page_fault_fmt = format_page_fault(page_fault_address, page_fault_operation)
    if page_fault_fmt is not None:
        return "{}".format(page_fault_fmt)

    return "Unknown"


def detect_crashes(server, has_system, has_user, has_header=False):
    if has_header:
        print("Mode | Process | Context | BugCheck | Cause | Exception transition")
        print("-----|---------|---------|----------|-------|---------------------")

    if has_system:
        for system_crash in detect_system_crashes(server):
            print(
                "System | {} | {} | {:#x} | {} | {}".format(
                    system_crash.process,
                    system_crash.dispatch_ctx,
                    system_crash.bug_check_code,
                    format_cause(
                        system_crash.error_code, system_crash.page_fault_address, system_crash.page_fault_operation
                    ),
                    system_crash.exception_transition,
                )
            )

    if has_user:
        for user_crash in detect_user_crashes(server):
            print(
                "User | {} | {} | N/A | {} | {}".format(
                    user_crash.process,
                    user_crash.dispatch_ctx,
                    format_exception_code(user_crash.error_code),
                    user_crash.exception_transition,
                )
            )


CRASH_MODE_DICT = {"all": (True, True), "user": (False, True), "system": (True, False)}


def parse_args():
    parser = argparse.ArgumentParser(
        description="Detect and report crashes and exceptions that appear during a " "REVEN scenario.\n",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--host",
        metavar="host",
        dest="host",
        help='Reven host, as a string (default: "localhost")',
        default="localhost",
        type=str,
    )
    parser.add_argument(
        "--port", metavar="port", dest="port", help="Reven port, as an int (default: 13370)", type=int, default=13370
    )
    parser.add_argument(
        "--mode",
        metavar="mode",
        dest="mode",
        help='Whether to look for "user" crash, "system" crash, or "all"',
        type=str,
        default="all",
    )
    parser.add_argument(
        "--header",
        action="store_true",
        dest="header",
        help="If present, display a header with the meaning of each column",
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()

    if args.mode not in CRASH_MODE_DICT:
        raise ValueError(
            'Wrong "mode" value "{}". Mode must be "all",' ' "user" or "system" (defaults to "all").'.format(args.mode)
        )

    (has_system, has_user) = CRASH_MODE_DICT[args.mode]

    # Get a server instance
    reven_server = reven2.RevenServer(args.host, args.port)

    detect_crashes(reven_server, has_system, has_user, args.header)

Thread synchronization

Purpose

Trace calls to Windows Synchronization APIs

  • Critical sections: 'InitializeCriticalSection', 'InitializeCriticalSectionEx', 'EnterCriticalSection', 'LeaveCriticalSection', 'SleepConditionVariableCS', 'SleepConditionVariableSRW'

  • Slim reader/writer locks: 'RtlInitializeSRWLock', 'RtlAcquireSRWLockShared', 'RtlAcquireSRWLockExclusive', 'RtlReleaseSRWLockShared', 'RtlReleaseSRWLockExclusive'

  • Mutex: 'CreateMutexW', 'OpenMutexW', 'WaitForSingleObject', 'ReleaseMutex'

  • Condition variables: 'InitializeConditionVariable', 'WakeConditionVariable', 'WakeAllConditionVariable', 'RtlWakeConditionVariable', 'RtlWakeAllConditionVariable'
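
At its core, the script resolves each API symbol through the OSSI, enumerates every call with the symbol search, and reads the synchronization object handle from the calling-convention registers. A minimal sketch for a single API, where the host and port are placeholders (see the Source below for the complete version with process and thread filtering):

import reven2
from reven2.arch import x64

reven_server = reven2.RevenServer("localhost", 13370)
ntdll = next(reven_server.ossi.executed_binaries("^c:/windows/system32/ntdll.dll$"))
enter_cs = next(ntdll.symbols("^RtlEnterCriticalSection$"))
for ctxt in reven_server.trace.search.symbol(enter_cs):
    # the critical section handle is the first argument (rcx)
    cs_handle = ctxt.read(x64.rcx)
    print(f"{ctxt}: RtlEnterCriticalSection, critical section 0x{cs_handle:x}")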

How to use

usage: threadsync.py [-h] [--host HOST] [-p PORT] --pid PID
                     [--from_id FROM_ID] [--to_id TO_ID] [--binary BINARY]
                     [--sync PRIMITIVE]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  --pid PID             Process id
  --from_id FROM_ID     Transition to start searching from
  --to_id TO_ID         Transition to start searching to
  --binary BINARY       Process binary
  --sync PRIMITIVE      Synchronization primitives to check.
                        Each repeated use of `--sync <primitive>` adds the corresponding primitive to check.
                        If no `--sync` option is passed, then all supported primitives are checked.
                        Supported primitives:
                         - cs:  critical section
                         - cv:  condition variable
                         - mx:  mutex
                         - srw: slim read/write lock

Known limitations

N/A

Supported versions

REVEN 2.8+

Supported perimeter

Any Windows 10 on x86-64 scenario.

Dependencies

  • OSSI (with symbols ntdll.dll, kernelbase.dll, ntoskrnl.exe resolved) feature replayed.
  • Fast Search feature replayed.

Source

import argparse

import reven2
from reven2.address import LogicalAddress
from reven2.arch import x64
from reven2.util import collate

"""
# Thread synchronization

## Purpose

Trace calls to Windows Synchronization APIs
  - Critical sections:        'InitializeCriticalSection', 'InitializeCriticalSectionEx',
                              'EnterCriticalSection', 'LeaveCriticalSection',
                              'SleepConditionVariableCS',  'SleepConditionVariableSRW'

  - Slim reader/writer locks: 'RtlInitializeSRWLock',
                              'RtlAcquireSRWLockShared', 'RtlAcquireSRWLockExclusive',
                              'RtlReleaseSRWLockShared', 'RtlReleaseSRWLockExclusive'

  - Mutex:                    'CreateMutexW', 'OpenMutexW', 'WaitForSingleObject', 'ReleaseMutex'

  - Condition variables:      'InitializeConditionVariable',
                              'WakeConditionVariable', 'WakeAllConditionVariable',
                              'RtlWakeConditionVariable', 'RtlWakeAllConditionVariable'

## How to use

```bash
usage: threadsync.py [-h] [--host HOST] [-p PORT] --pid PID
                     [--from_id FROM_ID] [--to_id TO_ID] [--binary BINARY]
                     [--sync PRIMITIVE]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  --pid PID             Process id
  --from_id FROM_ID     Transition to start searching from
  --to_id TO_ID         Transition to start searching to
  --binary BINARY       Process binary
  --sync PRIMITIVE      Synchronization primitives to check.
                        Each repeated use of `--sync <primitive>` adds the corresponding primitive to check.
                        If no `--sync` option is passed, then all supported primitives are checked.
                        Supported primitives:
                         - cs:  critical section
                         - cv:  condition variable
                         - mx:  mutex
                         - srw: slim read/write lock
```

## Known limitations

N/A

## Supported versions

REVEN 2.8+

## Supported perimeter

Any Windows 10 on x86-64 scenario.

## Dependencies

- OSSI (with symbols `ntdll.dll`, `kernelbase.dll`, `ntoskrnl.exe` resolved) feature replayed.
- Fast Search feature replayed.
"""


class SyncOSSI(object):
    def __init__(self, ossi, trace):
        # basic OSSI
        bin_symbol_names = {
            "c:/windows/system32/ntdll.dll": {
                # critical section
                "RtlInitializeCriticalSection",
                "RtlInitializeCriticalSectionEx",
                "RtlEnterCriticalSection",
                "RtlLeaveCriticalSection",
                # slim read/write lock
                "RtlInitializeSRWLock",
                "RtlAcquireSRWLockShared",
                "RtlAcquireSRWLockExclusive",
                "RtlReleaseSRWLockShared",
                "RtlReleaseSRWLockExclusive",
                # condition variable (general)
                "RtlWakeConditionVariable",
                "RtlWakeAllConditionVariable",
            },
            "c:/windows/system32/kernelbase.dll": {
                # mutex
                "CreateMutexW",
                "OpenMutexW",
                "ReleaseMutex",
                "WaitForSingleObject",
                # condition variable (general)
                "InitializeConditionVariable",
                "WakeConditionVariable",
                "WakeAllConditionVariable",
                # condition variable on critical section
                "SleepConditionVariableCS",
                # condition variable on slim read/write lock
                "SleepConditionVariableSRW",
            },
        }

        self.symbols = {}

        for (bin, symbol_names) in bin_symbol_names.items():
            try:
                exec_bin = next(ossi.executed_binaries(f"^{bin}$"))
            except StopIteration:
                raise RuntimeError(f"{bin} not found")

            for name in symbol_names:
                try:
                    sym = next(exec_bin.symbols(f"^{name}$"))
                    self.symbols[name] = sym
                except StopIteration:
                    print(f"Warning: {name} not found in {bin}")

        self.has_debugger = True
        try:
            trace.first_transition.step_over()
        except RuntimeError:
            print(
                "Warning: the debugger interface is not available, so the script cannot determine \
function return values.\nMake sure the stack events and PC range resources are replayed for this scenario."
            )
            self.has_debugger = False

        self.search_symbol = trace.search.symbol
        self.trace = trace


# tool functions
def context_cr3(ctxt):
    return ctxt.read(x64.cr3)


def is_kernel_mode(ctxt):
    return ctxt.read(x64.cs) & 0x3 == 0


def find_process_ranges(rvn, pid, from_id, to_id):
    """
    Traverse the trace, looking for the context ranges executed by the process of interest.

    Parameters:
     - rvn: RevenServer
     - pid: int (process id)
     - from_id: int or None (transition id to start searching from)
     - to_id: int or None (transition id to stop searching at)

    Output: yields (context_low, context_high) pairs, where context_high may be None
    """
    try:
        ntoskrnl = next(rvn.ossi.executed_binaries("^c:/windows/system32/ntoskrnl.exe$"))
    except StopIteration:
        raise RuntimeError("ntoskrnl.exe not found")

    try:
        ki_swap_context = next(ntoskrnl.symbols("^KiSwapContext$"))
    except StopIteration:
        raise RuntimeError("KiSwapContext not found")

    if from_id is None:
        ctxt_low = rvn.trace.first_context
    else:
        try:
            ctxt_low = rvn.trace.transition(from_id).context_before()
        except IndexError:
            raise RuntimeError(f"Transition of id {from_id} not found")

    if to_id is None:
        ctxt_hi = None
    else:
        try:
            ctxt_hi = rvn.trace.transition(to_id).context_before()
        except IndexError:
            raise RuntimeError(f"Transition of id {to_id} not found")

    if ctxt_hi is not None and ctxt_low >= ctxt_hi:
        return

    for ctxt in rvn.trace.search.symbol(
        ki_swap_context,
        from_context=ctxt_low,
        to_context=None if ctxt_hi is None or ctxt_hi == rvn.trace.last_context else ctxt_hi,
    ):
        if ctxt_low.ossi.process().pid == pid:
            yield (ctxt_low, ctxt)
        ctxt_low = ctxt

    if ctxt_low.ossi.process().pid == pid:
        if ctxt_hi is not None and ctxt_low < ctxt_hi:
            yield (ctxt_low, ctxt_hi)
        else:
            yield (ctxt_low, None)


def find_usermode_ranges(ctxt_low, ctxt_high):
    if ctxt_high is None:
        if not is_kernel_mode(ctxt_low):
            yield (ctxt_low, None)
        return

    ctxt_current = ctxt_low
    while ctxt_current < ctxt_high:
        ctxt_next = ctxt_current.find_register_change(x64.cs, is_forward=True)

        if not is_kernel_mode(ctxt_current):
            if ctxt_next is None or ctxt_next > ctxt_high:
                yield (ctxt_current, ctxt_high)
                break
            else:
                yield (ctxt_current, ctxt_next - 1)

        if ctxt_next is None:
            break

        ctxt_current = ctxt_next


def get_tid(ctxt):
    return ctxt.read(LogicalAddress(0x48, x64.gs), 4)


def caller_context_in_binary(ctxt, binary):
    if binary is None:
        return True

    caller_ctxt = ctxt - 1
    if caller_ctxt is None:
        return False

    caller_binary = caller_ctxt.ossi.location().binary
    if caller_binary is None:
        return False

    return binary in [caller_binary.name, caller_binary.filename, caller_binary.path]


def build_ordered_api_calls(syncOSSI, ctxt_low, ctxt_high, apis):
    def gen(api):
        return (
            (api, ctxt)
            for ctxt in syncOSSI.search_symbol(syncOSSI.symbols[api], from_context=ctxt_low, to_context=ctxt_high)
        )

    api_contexts = (
        # Using a generator expression directly appears to give wrong results,
        # while using a function works as expected.
        gen(api)
        for api in apis
        if api in syncOSSI.symbols
    )

    return collate(api_contexts, key=lambda name_ctxt: name_ctxt[1])


def get_return_value(syncOSSI, ctxt):
    if not syncOSSI.has_debugger:
        return False

    try:
        trans_after = ctxt.transition_after()
    except IndexError:
        return None

    if trans_after is None:
        return None

    trans_ret = trans_after.step_out()
    if trans_ret is None:
        return None

    ctxt_ret = trans_ret.context_after()
    return ctxt_ret.read(x64.rax)


# Values of primitive_handle and return_value fall into one of the following:
#  - None
#  - a numeric value
#  - a string literal
# where None is used usually in cases the script cannot reliably get the actual
# value of the object.
def print_csv(csv_file, ctxt, sync_primitive, primitive_handle, return_value):
    try:
        transition_id = ctxt.transition_before().id + 1
    except IndexError:
        transition_id = 0

    if primitive_handle is None:
        formatted_primitive_handle = "unknown"
    else:
        formatted_primitive_handle = (
            primitive_handle if isinstance(primitive_handle, str) else f"{primitive_handle:#x}"
        )

    if return_value is None:
        formatted_ret_value = "unknown"
    else:
        formatted_ret_value = return_value if isinstance(return_value, str) else f"{return_value:#x}"

    output = f"{transition_id}, {sync_primitive}, {formatted_primitive_handle}, {formatted_ret_value}\n"

    try:
        csv_file.write(output)
    except OSError:
        print(f"Failed to write to {csv_file}")


def check_critical_section(syncOSSI, ctxt_low, ctxt_high, binary, csv_file):
    critical_section_apis = {
        "RtlInitializeCriticalSection",
        "RtlInitializeCriticalSectionEx",
        "RtlEnterCriticalSection",
        "RtlLeaveCriticalSection",
        "SleepConditionVariableCS",
    }

    found = False

    ordered_api_contexts = build_ordered_api_calls(syncOSSI, ctxt_low, ctxt_high, critical_section_apis)
    for (api, ctxt) in ordered_api_contexts:
        if not caller_context_in_binary(ctxt, binary):
            continue

        found = True

        if api in {
            "RtlInitializeCriticalSection",
            "RtlInitializeCriticalSectionEx",
            "RtlEnterCriticalSection",
            "RtlLeaveCriticalSection",
        }:
            # the critical section handle is the first argument
            cs_handle = ctxt.read(x64.rcx)

            if csv_file is None:
                print(f"{ctxt}: {api}, critical section 0x{cs_handle:x}")
            else:
                # any API of this group returns void, so "unused" is passed as the return value
                print_csv(csv_file, ctxt, "cs", cs_handle, "unused")

        elif api in {"SleepConditionVariableCS"}:
            # the condition variable is the first argument
            cv_handle = ctxt.read(x64.rcx)

            # the critical section is the second argument
            cs_handle = ctxt.read(x64.rdx)

            if csv_file is None:
                print(f"{ctxt}: {api}, critical section 0x{cs_handle:x}, condition variable 0x{cv_handle:x}")
            else:
                # go to the return point (if possible) to get the return value
                ret_val = get_return_value(syncOSSI, ctxt)

                print_csv(csv_file, ctxt, "cs", cs_handle, ret_val)
                print_csv(csv_file, ctxt, "cv", cv_handle, ret_val)

    return found


def check_srw_lock(syncOSSI, ctxt_low, ctxt_high, binary, csv_file):
    srw_apis = [
        "RtlInitializeSRWLock",
        "RtlAcquireSRWLockShared",
        "RtlAcquireSRWLockExclusive",
        "RtlReleaseSRWLockShared",
        "RtlReleaseSRWLockExclusive",
    ]

    found = False

    ordered_api_contexts = build_ordered_api_calls(syncOSSI, ctxt_low, ctxt_high, srw_apis)
    for (api, ctxt) in ordered_api_contexts:
        if not caller_context_in_binary(ctxt, binary):
            continue

        found = True

        # the srw lock handle is the first argument
        srw_handle = ctxt.read(x64.rcx)

        if csv_file is None:
            print(f"{ctxt}: {api}, lock 0x{srw_handle:x}")
        else:
            # any API of this group returns void, so "unused" is passed as the return value
            print_csv(csv_file, ctxt, "srw", srw_handle, "unused")

    return found


def check_mutex(syncOSSI, ctxt_low, ctxt_high, binary, used_handles, csv_file):
    mutex_apis = {"CreateMutexW", "OpenMutexW", "WaitForSingleObject", "ReleaseMutex"}

    found = False

    ordered_api_contexts = build_ordered_api_calls(syncOSSI, ctxt_low, ctxt_high, mutex_apis)
    for (api, ctxt) in ordered_api_contexts:
        if not caller_context_in_binary(ctxt, binary):
            continue

        found = True

        if api in {"CreateMutexW", "OpenMutexW"}:
            # go to the return point (if possible) to get the mutex handle
            mx_handle = get_return_value(syncOSSI, ctxt)
            if mx_handle is not None and mx_handle != 0:
                used_handles.add(mx_handle)

            if csv_file is None:
                if mx_handle is None:
                    print(f"{ctxt}: {api}, mutex handle unknown")
                else:
                    if mx_handle == 0:
                        print(f"{ctxt}: {api}, failed")
                    else:
                        print(f"{ctxt}: {api}, mutex handle 0x{mx_handle:x}")
            else:
                print_csv(csv_file, ctxt, "mx", mx_handle, mx_handle)

        elif api in {"ReleaseMutex"}:
            mx_handle = ctxt.read(x64.rcx)
            used_handles.add(mx_handle)

            if csv_file is None:
                print(f"{ctxt}: {api}, mutex handle 0x{mx_handle:x}")
            else:
                # go to the return point (if possible) to get the return value
                ret_val = get_return_value(syncOSSI, ctxt)
                print_csv(csv_file, ctxt, "mx", mx_handle, ret_val)

        elif api in {"WaitForSingleObject"}:
            handle = ctxt.read(x64.rcx)
            if handle in used_handles:
                if csv_file is None:
                    print(f"{ctxt}: {api}, mutex handle 0x{handle:x}")
                else:
                    ret_val = get_return_value(syncOSSI, ctxt)
                    print_csv(csv_file, ctxt, "mx", handle, ret_val)

    return found


def check_condition_variable(syncOSSI, ctxt_low, ctxt_high, binary, csv_file):
    cond_var_apis = {
        "InitializeConditionVariable",
        "WakeConditionVariable",
        "WakeAllConditionVariable",
        "RtlWakeConditionVariable",
        "RtlWakeAllConditionVariable",
    }

    found = False

    ordered_api_contexts = build_ordered_api_calls(syncOSSI, ctxt_low, ctxt_high, cond_var_apis)
    for (api, ctxt) in ordered_api_contexts:
        if not caller_context_in_binary(ctxt, binary):
            continue

        found = True

        # the condition variable is the first argument
        cv_handle = ctxt.read(x64.rcx)

        if csv_file is None:
            print(f"{ctxt}: {api}, condition variable 0x{cv_handle:x}")
        else:
            # any API of this group returns void, so "unused" is passed as the return value
            print_csv(csv_file, ctxt, "cv", cv_handle, "unused")

    return found


def check_lock_unlock(syncOSSI, ctxt_low, ctxt_high, binary, sync_primitives, used_mutex_handles, csv_file):
    transition_low = ctxt_low.transition_after().id
    transition_high = ctxt_high.transition_after().id if ctxt_high is not None else syncOSSI.trace.last_transition.id

    tid = get_tid(ctxt_low)

    if csv_file is None:
        print(
            "\n==== checking the transition range [#{}, #{}] (thread id: {}) ====".format(
                transition_low, transition_high, tid
            )
        )

    found = False

    if "cs" in sync_primitives:
        cs_found = check_critical_section(syncOSSI, ctxt_low, ctxt_high, binary, csv_file)
        found = found or cs_found

    if "srw" in sync_primitives:
        srw_found = check_srw_lock(syncOSSI, ctxt_low, ctxt_high, binary, csv_file)
        found = found or srw_found

    if "cv" in sync_primitives:
        cv_found = check_condition_variable(syncOSSI, ctxt_low, ctxt_high, binary, csv_file)
        found = found or cv_found

    if "mx" in sync_primitives:
        mx_found = check_mutex(syncOSSI, ctxt_low, ctxt_high, binary, used_mutex_handles, csv_file)
        found = found or mx_found

    if not found and csv_file is None:
        print("\tnothing found")


def run(rvn, proc_id, proc_bin, from_id, to_id, sync_primitives, csv_file):
    syncOSSI = SyncOSSI(rvn.ossi, rvn.trace)
    used_mutex_handles = set()
    process_ranges = find_process_ranges(rvn, proc_id, from_id, to_id)
    for (low, high) in process_ranges:
        user_mode_ranges = find_usermode_ranges(low, high)
        for (low_usermode, high_usermode) in user_mode_ranges:
            check_lock_unlock(
                syncOSSI, low_usermode, high_usermode, proc_bin, sync_primitives, used_mutex_handles, csv_file
            )


def parse_args():
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help='Reven host, as a string (default: "localhost")',
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default="13370",
        help="Reven port, as an int (default: 13370)",
    )
    parser.add_argument("--pid", type=int, required=True, help="Process id")
    parser.add_argument("--from_id", type=int, default=None, help="Transition to start searching from")
    parser.add_argument("--to_id", type=int, default=None, help="Transition to start searching to")
    parser.add_argument(
        "--binary",
        type=str,
        default=None,
        help="Process binary",
    )
    parser.add_argument(
        "--sync",
        type=str,
        metavar="PRIMITIVE",
        action="append",
        choices=["cs", "cv", "mx", "srw"],
        default=None,
        help="Synchronization primitives to check.\n"
        + "Each repeated use of `--sync <primitive>` adds the corresponding primitive to check.\n"
        + "If no `--sync` option is passed, then all supported primitives are checked.\n"
        + "Supported primitives:\n"
        + " - cs:  critical section\n"
        + " - cv:  condition variable\n"
        + " - mx:  mutex\n"
        + " - srw: slim read/write lock",
    )
    parser.add_argument("--raw", type=str, default=None, help="CSV file as output")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if args.sync is None or not args.sync:
        args.sync = ["cs", "cv", "mx", "srw"]

    if args.raw is None:
        csv_file = None
    else:
        try:
            csv_file = open(args.raw, "w")
        except OSError:
            raise RuntimeError(f"Failed to open {args.raw}")

        try:
            csv_file.write("transition id, primitive, handle, return value\n")
        except OSError:
            raise RuntimeError(f"Failed to write to {csv_file}")

    rvn = reven2.RevenServer(args.host, args.port)
    run(rvn, args.pid, args.binary, args.from_id, args.to_id, args.sync, csv_file)

Strings statistics

Purpose

Display statistics about string accesses, such as:

  • the binaries that read/write a string.
  • the number of read/write accesses each binary performs on a string.
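
The statistics are built on top of the string search of the API: each matching string exposes its memory accesses, and the OSSI gives the binary performing each access. A minimal sketch, where the host, port and the "password" pattern are only placeholders (see the Source below for the aggregated per-binary counts):

import reven2

reven_server = reven2.RevenServer("localhost", 13370)
# "password" is only a placeholder pattern; an empty pattern matches every string
for string in reven_server.trace.strings("password"):
    for access in string.memory_accesses():
        location = access.transition.context_before().ossi.location()
        if location is None:
            continue
        operation = "Read" if access.operation == reven2.memhist.MemoryAccessOperation.Read else "Write"
        print(f"{string.data!r}: {operation} by {location.binary.filename}")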

How to use

usage: strings_stat.py [-h] [--host HOST] [--port PORT] [-v] [pattern]

Display statistics about strings accesses.

positional arguments:
  pattern        Pattern of the string, looking for "*pattern*", does not
                 support Regular Expression. If no pattern provided, all
                 strings will be used.

optional arguments:
  -h, --help     show this help message and exit
  --host HOST    Reven host, as a string (default: "localhost")
  --port PORT    Reven port, as an int (default: 13370)
  -v, --verbose  Increase output verbosity

Known limitations

N/A

Supported versions

REVEN 2.2+

Supported perimeter

Any REVEN scenario.

Dependencies

The script requires that the target REVEN scenario have:

  • The Strings feature replayed.
  • The OSSI feature replayed.

Source

import argparse
import builtins
import logging

import reven2


"""
# Strings statistics

## Purpose

Display statistics about string accesses, such as:
  * the binaries that read/write a string.
  * the number of read/write accesses each binary performs on a string.

## How to use

```bash
usage: strings_stat.py [-h] [--host HOST] [--port PORT] [-v] [pattern]

Display statistics about strings accesses.

positional arguments:
  pattern        Pattern of the string, looking for "*pattern*", does not
                 support Regular Expression. If no pattern provided, all
                 strings will be used.

optional arguments:
  -h, --help     show this help message and exit
  --host HOST    Reven host, as a string (default: "localhost")
  --port PORT    Reven port, as an int (default: 13370)
  -v, --verbose  Increase output verbosity
```

## Known limitations

N/A

## Supported versions

REVEN 2.2+

## Supported perimeter

Any REVEN scenario.

## Dependencies

The script requires that the target REVEN scenario have:
    * The Strings feature replayed.
    * The OSSI feature replayed.
"""


class BinaryStringOperations(object):
    def __init__(self, binary):
        self.binary = binary
        self.read_count = 0
        self.write_count = 0


def strings_stat(reven_server, pattern=""):
    for string in reven_server.trace.strings(pattern):
        binaries = builtins.dict()
        try:
            # iterates on all accesses to all binaries that access to the string (read/write).
            for memory_access in string.memory_accesses():
                ctx = memory_access.transition.context_before()
                if ctx.ossi.location() is None:
                    continue
                binary = ctx.ossi.location().binary
                try:
                    binary_operations = binaries[binary.path]
                except KeyError:
                    binaries[binary.path] = BinaryStringOperations(binary)
                    binary_operations = binaries[binary.path]
                if memory_access.operation == reven2.memhist.MemoryAccessOperation.Read:
                    binary_operations.read_count += 1
                else:
                    binary_operations.write_count += 1
        except RuntimeError:
            # Limitation of the `memory_accesses` method, which raises a RuntimeError
            # when the service times out.
            pass
        yield (string, binaries)


def parse_args():
    parser = argparse.ArgumentParser(description="Display statistics about strings accesses.")
    parser.add_argument(
        "--host", dest="host", help='Reven host, as a string (default: "localhost")', default="localhost", type=str
    )
    parser.add_argument("--port", dest="port", help="Reven port, as an int (default: 13370)", type=int, default=13370)
    parser.add_argument(
        "-v",
        "--verbose",
        dest="log_level",
        help="Increase output verbosity",
        action="store_const",
        const=logging.DEBUG,
        default=logging.INFO,
    )
    parser.add_argument(
        "pattern",
        nargs="?",
        help="Pattern of the string, looking for "
        '"*pattern*", does not support Regular Expression. If no pattern provided, '
        "all strings will be used.",
        default="",
        type=str,
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    logging.basicConfig(format="%(message)s", level=args.log_level)

    logging.debug('##### Getting stat for all strings containing "{0}" #####\n'.format(args.pattern))

    # Get a server instance
    reven_server = reven2.RevenServer(args.host, args.port)

    # Print strings
    for string, binaries in strings_stat(reven_server, args.pattern):
        logging.info('"{}":'.format(string.data))
        for binary_operations in binaries.values():
            logging.info(
                "\t- {} (Read: {} - Write: {})".format(
                    binary_operations.binary.filename, binary_operations.read_count, binary_operations.write_count
                )
            )
        logging.info("")
    logging.debug("##### Done #####")

Export bookmarks

Purpose

This notebook and script are designed to export the bookmarks of a scenario, for example for inclusion in a report. The meat of the script uses the ability of the API to iterate on the bookmarks of a REVEN scenario:

for bookmark in self._server.bookmarks.all():
    # do something with the bookmark.id, bookmark.transition and bookmark.description

See the Document class and in particular its add_bookmarks function for details.

How to use

Bookmarks can be exported from this notebook or from the command line. The script can also be imported as a package for use from your own script or notebook.

From the notebook

  1. Upload the export_bookmarks.ipynb file in Jupyter.
  2. Fill out the parameters cell of this notebook according to your scenario and desired output.
  3. Run the full notebook.

From the command line

  1. Make sure that you are in an environment that can run REVEN scripts.
  2. Run python export_bookmarks.py --help to get a tour of available arguments.
  3. Run python export_bookmarks.py --host <your_host> --port <your_port> [<other_option>] with your arguments of choice.

Imported in your own script or notebook

  1. Make sure that you are in an environment that can run REVEN scripts.
  2. Make sure that export_bookmarks.py is in the same directory as your script or notebook.
  3. Add import export_bookmarks to your script or notebook. You can access the various functions and classes exposed by export_bookmarks.py from the export_bookmarks namespace.
  4. Refer to the Argument parsing cell for an example of use in a script, and to the Parameters cell and below for an example of use in a notebook (you just need to prepend export_bookmarks in front of the functions and classes from the script).
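
For instance, once imported, the Document class can be driven directly from your own code. A minimal sketch, where the host, port and option values are only examples; the cell that finally renders the accumulated text comes later in the script and is not reproduced here:

import reven2
import export_bookmarks

server = reven2.RevenServer("localhost", 13370)
document = export_bookmarks.Document(
    server=server,
    sort=export_bookmarks.SortOrder.Transition,
    context=2,  # number of transitions displayed before and after each bookmark
    header=export_bookmarks.HeaderOption.Simple,
    format=export_bookmarks.OutputFormat.Markdown,
    output=None,  # None renders to the notebook or to the standard output
    escape_description=True,
)
document.add_header()
document.add_bookmarks()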

Customizing the notebook/script

To add a new format or change the output, you may want to:

  • Modify the various enumeration types that control the output to add your new format or option.
  • Modify the Formatter class to account for your new format.
  • Modify the Document class to account for your new output control option.
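
For example, supporting a hypothetical AsciiDoc output would mean adding a member to OutputFormat and a corresponding branch to each Formatter method. A standalone sketch of what the header case could look like (AsciiDoc is only an illustration and not part of the script):

from enum import Enum


class OutputFormat(Enum):
    Raw = 0
    Markdown = 1
    Html = 2
    AsciiDoc = 3  # hypothetical new format


def header(format: OutputFormat, title: str) -> str:
    # standalone mirror of Formatter.header, extended with the new branch
    if format == OutputFormat.AsciiDoc:
        return f"= {title}\n\n"
    elif format == OutputFormat.Html:
        return f"<h1>{title}</h1>"
    elif format == OutputFormat.Markdown:
        return f"# {title}\n\n"
    elif format == OutputFormat.Raw:
        return f"{title}\n\n"
    raise NotImplementedError(f"'header' with {format}")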

Known limitations

N/A.

Supported versions

REVEN 2.8+

Supported perimeter

Any REVEN scenario.

Dependencies

None.

Source

# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Export bookmarks
#
# ## Purpose
#
# This notebook and script are designed to export the bookmarks of a scenario, for example for inclusion in a report.
#
# The meat of the script uses the ability of the API to iterate on the bookmarks of a REVEN scenario:
#
# ```py
# for bookmark in self._server.bookmarks.all():
#     # do something with the bookmark.id, bookmark.transition and bookmark.description
# ```
#
# See the [Document](#Document) class and in particular its `add_bookmarks` function for details.
#
# ## How to use
#
# Bookmarks can be exported from this notebook or from the command line.
# The script can also be imported as a package for use from your own script or notebook.
#
# ### From the notebook
#
# 1. Upload the `export_bookmarks.ipynb` file in Jupyter.
# 2. Fill out the [parameters](#Parameters) cell of this notebook according to your scenario and desired output.
# 3. Run the full notebook.
#
#
# ### From the command line
#
# 1. Make sure that you are in an
#    [environment](http://doc.tetrane.com/professional/latest/Python-API/Installation.html#on-the-reven-server)
#    that can run REVEN scripts.
# 2. Run `python export_bookmarks.py --help` to get a tour of available arguments.
# 3. Run `python export_bookmarks.py --host <your_host> --port <your_port> [<other_option>]` with your arguments of
#    choice.
#
# ### Imported in your own script or notebook
#
# 1. Make sure that you are in an
#    [environment](http://doc.tetrane.com/professional/latest/Python-API/Installation.html#on-the-reven-server)
#    that can run REVEN scripts.
# 2. Make sure that `export_bookmarks.py` is in the same directory as your script or notebook.
# 3. Add `import export_bookmarks` to your script or notebook. You can access the various functions and classes
#    exposed by `export_bookmarks.py` from the `export_bookmarks` namespace.
# 4. Refer to the [Argument parsing](#Argument-parsing) cell for an example of use in a script, and to the
#    [Parameters](#Parameters) cell and below for an example of use in a notebook (you just need to prepend
#    `export_bookmarks` in front of the functions and classes from the script).
#
# ## Customizing the notebook/script
#
# To add a new format or change the output, you may want to:
#
# - Modify the various [enumeration types](#Output-option-types) that control the output to add your new format or
#   option.
# - Modify the [Formatter](#Formatter) class to account for your new format.
# - Modify the [Document](#Document) class to account for your new output control option.
#
#
# ## Known limitations
#
# N/A.
#
# ## Supported versions
#
# REVEN 2.8+
#
# ## Supported perimeter
#
# Any REVEN scenario.
#
# ## Dependencies
#
# None.

# %% [markdown]
# ### Package imports

# %%
import argparse  # for argument parsing
import datetime  # Date generation
import sys  # printing to stderr
from enum import Enum
from html import escape as html_escape
from typing import Iterable, Optional

import reven2  # type: ignore

try:
    # Jupyter rendering
    from IPython.display import display, HTML, Markdown  # type: ignore
except ImportError:
    pass


# %% [markdown]
# ### Utility functions

# %%
# Detect if we are currently running a Jupyter notebook.
#
# This is used to display rendered results inline in Jupyter when we are executing in the context of a Jupyter
# notebook, or to display raw results on the standard output when we are executing in the context of a script.
def in_notebook():
    try:
        from IPython import get_ipython  # type: ignore

        if get_ipython() is None or ("IPKernelApp" not in get_ipython().config):
            return False
    except ImportError:
        return False
    return True


# %% [markdown]
# ### Output option types
#
# The enum types below are used to control the output of the script.
#
# Modify these enums to add more options if you want to add e.g. new output formats.

# %%
class HeaderOption(Enum):
    NoHeader = 0
    Simple = 1


class OutputFormat(Enum):
    Raw = 0
    Markdown = 1
    Html = 2


class SortOrder(Enum):
    Transition = 0
    Creation = 1


# %% [markdown]
# ### Formatter
#
# This is the rendering boilerplate.
#
# Modify this if you e.g. need to add new output formats.

# %%
class Formatter:
    def __init__(
        self,
        format: OutputFormat,
    ):
        self._format = format

    def header(self, title: str) -> str:
        if self._format == OutputFormat.Html:
            return f"<h1>{title}</h1>"
        elif self._format == OutputFormat.Markdown:
            return f"# {title}\n\n"
        elif self._format == OutputFormat.Raw:
            return f"{title}\n\n"
        raise NotImplementedError(f"'header' with {self._format}")

    def paragraph(self, paragraph: str) -> str:
        if self._format == OutputFormat.Html:
            return f"<p>{paragraph}</p>"
        elif self._format == OutputFormat.Markdown:
            return f"\n\n{paragraph}\n\n"
        elif self._format == OutputFormat.Raw:
            return f"\n{paragraph}\n"
        raise NotImplementedError(f"'paragraph' with {self._format}")

    def horizontal_ruler(self) -> str:
        if self._format == OutputFormat.Html:
            return "<hr/>"
        elif self._format == OutputFormat.Markdown:
            return "\n---\n"
        elif self._format == OutputFormat.Raw:
            return "\n---\n"
        raise NotImplementedError(f"'horizontal_ruler' with {self._format}")

    def transition(self, transition: reven2.trace.Transition) -> str:
        if transition.instruction is not None:
            tr_desc = str(transition.instruction)
        else:
            tr_desc = str(transition.exception)
        if self._format == OutputFormat.Html:
            if in_notebook():
                tr_id = f"{transition.format_as_html()}"
            else:
                tr_id = f"#{transition.id} "
            return f"{tr_id} <code>{tr_desc}</code>"
        elif self._format == OutputFormat.Markdown:
            return f"`#{transition.id}` `{tr_desc}`"
        elif self._format == OutputFormat.Raw:
            return f"#{transition.id}\t{tr_desc}"
        raise NotImplementedError(f"'transition' with {self._format}")

    def newline(self) -> str:
        if self._format == OutputFormat.Html:
            return "<br/>"
        elif self._format == OutputFormat.Markdown:
            return "  \n"  # EOL spaces to have a newline in markdown
        elif self._format == OutputFormat.Raw:
            return "\n"
        raise NotImplementedError(f"'newline' with {self._format}")

    def paragraph_begin(self) -> str:
        if self._format == OutputFormat.Html:
            return "<p>"
        elif self._format == OutputFormat.Markdown:
            return "\n\n"
        elif self._format == OutputFormat.Raw:
            return "\n"
        raise NotImplementedError(f"'paragraph_begin' with {self._format}")

    def paragraph_end(self) -> str:
        if self._format == OutputFormat.Html:
            return "</p>"
        elif self._format == OutputFormat.Markdown:
            return "\n\n"
        elif self._format == OutputFormat.Raw:
            return "\n"
        raise NotImplementedError(f"'paragraph_end' with {self._format}")

    def important(self, important: str) -> str:
        if self._format == OutputFormat.Html:
            return f"<strong>{important}</strong>"
        elif self._format == OutputFormat.Markdown:
            return f"**{important}**"
        elif self._format == OutputFormat.Raw:
            return f"{important} <- HERE"
        raise NotImplementedError(f"'important' with {self._format}")

    def warning(self, warning: str) -> str:
        if self._format == OutputFormat.Html:
            return f'<div class="alert alert-warning"><strong>Warning:</strong> {warning}</div>'
        elif self._format == OutputFormat.Markdown:
            return f"**Warning: {warning}**"
        elif self._format == OutputFormat.Raw:
            return f"WARNING: {warning}"
        raise NotImplementedError(f"'warning' with {self._format}")

    def code(self, code: str) -> str:
        if self._format == OutputFormat.Html:
            return f"<code>{code}</code>"
        elif self._format == OutputFormat.Markdown:
            return f"`{code}`"
        elif self._format == OutputFormat.Raw:
            return f"{code}"
        raise NotImplementedError(f"'code' with {self._format}")

    def render_error(self, text):
        if text == "":
            return
        if in_notebook():
            if self._format == OutputFormat.Html:
                display(HTML(text))
            elif self._format == OutputFormat.Markdown:
                display(Markdown(text))
            elif self._format == OutputFormat.Raw:
                display(text)
            else:
                raise NotImplementedError(f"inline error rendering with {self._format}")
        else:
            print(text, file=sys.stderr)

    def render(self, text, output):
        if text == "":
            return
        if output is None:
            if in_notebook():
                if self._format == OutputFormat.Html:
                    display(HTML(text))
                elif self._format == OutputFormat.Markdown:
                    display(Markdown(text))
                elif self._format == OutputFormat.Raw:
                    display(text)
                else:
                    raise NotImplementedError(f"inline rendering with {self._format}")
            else:
                print(text)
        else:
            try:
                with open(output, "w") as f:
                    f.write(text)
            except OSError as ose:
                raise ValueError(f"Could not open file {output}: {ose}")


# %% [markdown]
# ### Document
#
# This is the main logic of the script.

# %%
class Document:
    def __init__(
        self,
        server: reven2.RevenServer,
        sort: SortOrder,
        context: Optional[int],
        header: HeaderOption,
        format: OutputFormat,
        output: Optional[str],
        escape_description: bool,
    ):
        self._text = ""
        self._warning = ""
        self._server = server
        if context is None:
            self._context = 0
        else:
            self._context = context
        self._header_opt = header
        self._escape_description = escape_description
        self._output = output
        self._sort = sort
        self._formatter = Formatter(format)

    def add_bookmarks(self):
        if self._sort == SortOrder.Creation:
            for bookmark in sorted(self._server.bookmarks.all(), key=lambda bookmark: bookmark.id):
                self.add_bookmark(bookmark)
        else:
            for bookmark in sorted(self._server.bookmarks.all(), key=lambda bookmark: bookmark.transition):
                self.add_bookmark(bookmark)

    def add_bookmark(self, bookmark: reven2.bookmark.Bookmark):
        self._text += self._formatter.paragraph_begin()
        self.add_bookmark_header(bookmark)
        self.add_location(bookmark.transition)
        if bookmark.transition.id < self._context:
            first_transition = self._server.trace.first_transition
        else:
            first_transition = bookmark.transition - self._context
        self.add_transitions(
            transition for transition in self._server.trace.transitions(first_transition, bookmark.transition)
        )
        self.add_bookmark_transition(bookmark.transition)
        # Catch possible transitions that would fall outside of the trace due to the value of context
        if bookmark.transition != self._server.trace.last_transition:
            if bookmark.transition.id + self._context > self._server.trace.last_transition.id:
                last_transition = self._server.trace.last_transition
            else:
                last_transition = bookmark.transition + 1 + self._context
            self.add_transitions(
                transition for transition in self._server.trace.transitions(bookmark.transition + 1, last_transition)
            )
        self._text += self._formatter.paragraph_end()
        self._text += self._formatter.horizontal_ruler()

    def add_header(self):
        if self._header_opt == HeaderOption.NoHeader:
            return
        elif self._header_opt == HeaderOption.Simple:
            scenario_name = self._server.scenario_name
            self._text += self._formatter.header(f"Bookmarks for scenario {scenario_name}")
            date = datetime.datetime.now()
            self._text += self._formatter.paragraph(f"Generated on {str(date)}")
            self._text += self._formatter.horizontal_ruler()

    def add_transitions(self, transitions: Iterable[reven2.trace.Transition]):
        for transition in transitions:
            self._text += self._formatter.transition(transition)
            self._text += self._formatter.newline()

    def add_bookmark_transition(self, transition: reven2.trace.Transition):
        tr_format = self._formatter.transition(transition)
        alone = self._context == 0
        self._text += self._formatter.important(tr_format) if not alone else tr_format
        self._text += self._formatter.newline()

    def add_bookmark_header(self, bookmark: reven2.bookmark.Bookmark):
        if self._escape_description:
            bookmark_description = html_escape(bookmark.description)
        else:
            bookmark_description = bookmark.description
        self._text += f"{bookmark_description}"
        self._text += self._formatter.newline()

    def add_location(self, transition: reven2.trace.Transition):
        ossi = transition.context_before().ossi
        try:
            if ossi and ossi.location():
                location = self._formatter.code(html_escape(str(ossi.location())))
                self._text += self._formatter.paragraph(f"Location: {location}")
        except RuntimeError:
            pass

    def add_warnings(self):
        ossi = self._server.trace.first_context.ossi
        try:
            if ossi and ossi.location():
                pass
        except RuntimeError:
            self._warning += self._formatter.warning("OSSI not replayed, locations not available in bookmarks.")

    def render(self):
        self._formatter.render_error(self._warning)
        self._formatter.render(self._text, self._output)


# %% [markdown]
# ### Main function
#
# This function is called with parameters from the [Parameters](#Parameters) cell in the notebook context,
# or with parameters from the command line in the script context.

# %%
def export_bookmarks(
    server: reven2.RevenServer,
    sort: SortOrder,
    context: Optional[int],
    header: HeaderOption,
    format: OutputFormat,
    escape_description: bool,
    suppress_warnings: bool,
    output: Optional[str],
):
    document = Document(
        server,
        sort=sort,
        context=context,
        header=header,
        format=format,
        output=output,
        escape_description=escape_description,
    )
    if not suppress_warnings:
        document.add_warnings()
    document.add_header()
    document.add_bookmarks()
    document.render()


# %% [markdown]
# ### Argument parsing
#
# Argument parsing function for use in the script context.

# %%
def get_sort(sort: str) -> SortOrder:
    if sort.lower() == "transition":
        return SortOrder.Transition
    if sort.lower() in ["creation", "id"]:
        return SortOrder.Creation
    raise ValueError(f"'order' value should be 'transition' or 'creation'. Received '{sort}'.")


def get_header(header: str) -> HeaderOption:
    if header.lower() == "no":
        return HeaderOption.NoHeader
    elif header.lower() == "simple":
        return HeaderOption.Simple
    raise ValueError(f"'header' value should be 'no' or 'simple'. Received '{header}'.")


def get_format(format: str) -> OutputFormat:
    if format.lower() == "html":
        return OutputFormat.Html
    elif format.lower() == "md" or format.lower() == "markdown":
        return OutputFormat.Markdown
    elif format.lower() == "raw" or format.lower() == "text":
        return OutputFormat.Raw
    raise ValueError("'format' value should be one of 'html', 'md' or 'raw'. Received '{format}'.")


def script_main():
    parser = argparse.ArgumentParser(description="Export the bookmarks of a scenario to a report.")
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        required=False,
        help='REVEN host, as a string (default: "localhost")',
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default="13370",
        required=False,
        help="REVEN port, as an int (default: 13370)",
    )
    parser.add_argument(
        "-C",
        "--context",
        type=int,
        required=False,
        help="Print CONTEXT lines of surrounding context around the bookmark's instruction",
    )
    parser.add_argument(
        "--header",
        type=str,
        default="no",
        required=False,
        choices=["no", "simple"],
        help="Whether to preprend the output with a header or not (default: no)",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="html",
        required=False,
        choices=["html", "md", "raw"],
        help="The output format (default: html).",
    )
    parser.add_argument(
        "--order",
        type=str,
        default="transition",
        choices=["transition", "creation"],
        required=False,
        help="The sort order of bookmarks in the report (default: transition).",
    )
    parser.add_argument(
        "--no-escape-description",
        action="store_true",
        default=False,
        required=False,
        help="If present, don't escape the HTML in the bookmark descriptions.",
    )
    parser.add_argument(
        "--suppress-warnings",
        action="store_true",
        default=False,
        required=False,
        help="If present, don't print warnings to the standard error output.",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        type=str,
        required=False,
        help="The target file of the report. If absent, the report will be printed on the standard output.",
    )

    args = parser.parse_args()

    try:
        server = reven2.RevenServer(args.host, args.port)
    except RuntimeError:
        raise RuntimeError(f"Could not connect to the server on {args.host}:{args.port}.")

    sort = get_sort(args.order)
    header = get_header(args.header)
    format = get_format(args.format)

    export_bookmarks(
        server,
        sort,
        args.context,
        header,
        format,
        escape_description=(not args.no_escape_description),
        suppress_warnings=args.suppress_warnings,
        output=args.output_file,
    )


# %% [markdown]
# ## Parameters
#
# These parameters have to be filled out for use in the notebook context.

# %%
# Server connection
#
host = "localhost"
port = 37103


# Output target
#
# If set to a path, writes the report file there
output_file = None  # display report inline in the Jupyter Notebook
# output_file = "report.html"  # export report to a file named "report.html" in the current directory


# Output control
#
# Sort order of bookmarks
order = SortOrder.Transition  # Bookmarks will be displayed in increasing transition number.
# order = SortOrder.Creation  # Bookmarks will be displayed in their order of creation.

# Number of transitions to display around the transition of each bookmark
context = 0  # Only display the bookmark transition
# context = 3  # Displays 3 lines above and 3 lines below the bookmark transition

# Whether to prepend a header at the top of the report
header = HeaderOption.Simple  # Display a simple header with the scenario name and generation date
# header = HeaderOption.NoHeader  # Don't display any header

# The format of the report.
# When the output target is set to a file, this specifies the format of that file.
# When the output target is `None` (report rendered inline), the difference between HTML and Markdown
# mostly influences how the description of the bookmarks is interpreted.
format = OutputFormat.Html  # Bookmark description and output file rendered as HTML
# format = OutputFormat.Markdown  # Bookmark description and output file rendered as Markdown
# format = OutputFormat.Raw  # Everything rendered as raw text

# Whether to escape HTML in the description of bookmarks.
escape_description = False  # HTML will not be escaped in description
# escape_description = True   # HTML will be escaped in description

# Whether or not to suppress the warnings that can be displayed (e.g. in case of missing OSSI)
suppress_warnings = False  # Display warnings at the top of the report
# suppress_warnings = True  # Don't display warnings at the top of the report


# %% [markdown]
# ### Execution cell
#
# This cell executes according to the [parameters](#Parameters) when in notebook context, or according to the
# [parsed arguments](#Argument-parsing) when in script context.
#
# When in notebook context, if the `output_file` parameter is `None`, then the report will be displayed in the last
# the notebook.

# %%
if __name__ == "__main__":
    if in_notebook():
        try:
            server = reven2.RevenServer(host, port)
        except RuntimeError:
            raise RuntimeError(f"Could not connect to the server on {host}:{port}.")

        export_bookmarks(server, order, context, header, format, escape_description, suppress_warnings, output_file)
    else:
        script_main()

Networking examples

The examples in this section analyze the network activity of a REVEN scenario to produce useful information, such as PCAP files that can then be analyzed in Wireshark.

Dump PCAP

Purpose

Generate a PCAP file containing all network packets that were sent/received in a trace.

The timestamp of packets is replaced by the transition id where the packet was sent/received.
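
At its core, the script reads the packet buffer from guest memory at the context where it was sent or received, rebuilds an Ethernet frame with scapy, and stores the transition id in the frame's timestamp before appending it to the PCAP file. A minimal sketch of that step, assuming ctx is such a context and buf holds the raw bytes read from memory:

from scapy.all import Ether, wrpcap

# Rebuild an Ethernet frame from the raw bytes dumped from guest memory
packet = Ether(bytes(buf))

# Store the transition id as the capture timestamp, so that Wireshark can display it
# via View->Time display format->Seconds since 1970-01-01
packet.time = ctx.transition_before().id

# Append the frame to the output PCAP file
wrpcap("output.pcap", packet, append=True)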

How to use

usage: dump_pcap.py [-h] [--host host] [--port port] [--filename file_name]
                    [--fix-checksum]

Dump a PCAP file from a Windows 10 x64 trace. To get the time as transition ID in Wireshark, select:
View->Time display format->Seconds since 1970-01-01

optional arguments:
  -h, --help            show this help message and exit
  --host host           Reven host, as a string (default: "localhost")
  --port port           Reven port, as an int (default: 13370)
  --filename file_name  the output file name (default: "output.pcap"). Will be created if it doesn't exist
  --fix-checksum        If not specified, the packet checksum won't be fixed: you will get the buffer
                        exactly as it was dumped from memory, which results in many malformed-looking
                        packets in Wireshark that can be ignored if needed.
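
For example, the following invocation (host, port and file name are placeholders to adapt to your setup) dumps all packets of the scenario into trace.pcap with fixed checksums:

python dump_pcap.py --host localhost --port 13370 --filename trace.pcap --fix-checksum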

Known limitations

N/A

Supported versions

REVEN 2.2+

Supported perimeter

Any Windows 10 x64 scenario.

Dependencies

  • The script requires the scapy package
  • The network_packet_tools.py file distributed alongside this example must be provided (e.g. in the same directory).
  • The script requires that the target REVEN scenario have:
    • The Fast Search feature replayed.
    • The OSSI feature replayed.
    • An access to the binary e1g6032e.sys and its PDB file.

Source

#!/usr/bin/env python3

import argparse
import itertools
import os
from typing import Iterator as _Iterator, List as _List, Optional as _Optional, Tuple as _Tuple

import network_packet_tools as nw_tools

import reven2

from scapy.all import Ether, TCP, wrpcap

"""
# Dump PCAP

## Purpose

Generate a PCAP file containing all network packets that were sent/received in a trace.

The timestamp of packets is replaced by the transition id where the packet was sent/received.

## How to use

```bash
usage: dump_pcap.py [-h] [--host host] [--port port] [--filename file_name]
                    [--fix-checksum]

Dump a PCAP file from a Windows 10 x64 trace. To get the time as transition ID in Wireshark, select:
View->Time display format->Seconds since 1970-01-01

optional arguments:
  -h, --help            show this help message and exit
  --host host           Reven host, as a string (default: "localhost")
  --port port           Reven port, as an int (default: 13370)
  --filename file_name  the output file name (default: "output.pcap"). Will be created if it doesn't exist
  --fix-checksum        If not specified, the packet checksum won't be fixed: you will get the buffer
                        exactly as it was dumped from memory, which results in many malformed-looking
                        packets in Wireshark that can be ignored if needed.
```

## Known limitations

N/A

## Supported versions

REVEN 2.2+

## Supported perimeter

Any Windows 10 x64 scenario.

## Dependencies

- The script requires the scapy package
- The `network_packet_tools.py` file distributed alongside this example must be provided (e.g. in the same directory).
- The script requires that the target REVEN scenario have:
    - The Fast Search feature replayed.
    - The OSSI feature replayed.
    - An access to the binary `e1g6032e.sys` and its PDB file.
"""


def parse_args():
    parser = argparse.ArgumentParser(
        description="Dump a PCAP file from a Windows 10 x64 trace. "
        "To get the time as transition ID in "
        "wireshark, select:\nView->Time display format->Seconds since "
        "1970-01-01\n",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--host",
        metavar="host",
        dest="host",
        help='Reven host, as a string (default: "localhost")',
        default="localhost",
        type=str,
    )
    parser.add_argument(
        "--port", metavar="port", dest="port", help="Reven port, as an int (default: 13370)", type=int, default=13370
    )
    parser.add_argument(
        "--filename",
        metavar="file_name",
        dest="file_name",
        help="the output " 'file name (default: "output.pcap"). Will be created if it doesn\'t exist',
        default="output.pcap",
    )
    parser.add_argument(
        "--fix-checksum",
        dest="fix_checksum",
        action="store_true",
        help="If not specified, the packet checksum won't be fixed and you will have the buffer that \
                        has been dumped from memory, and a lot of ugly packets in Wireshark, that you can also ignore \
                        if needed.",
    )

    args = parser.parse_args()
    return args


def get_network_buffer_recv_RxPacketAssemble(
    ctx: reven2.trace.Context,
) -> _Tuple[_List[reven2.MemoryRange[reven2.address._AbstractAddress]], _Optional[bytearray]]:
    packet_memory_range = nw_tools.get_memory_address_and_size_of_received_network_packet(ctx)
    Buffer: _Optional[bytearray] = None
    sources = []
    if packet_memory_range is not None:
        sources = [packet_memory_range]
        # Get the buffer
        Buffer = ctx.read(packet_memory_range, raw=True)
    return sources, Buffer


def get_network_buffer_send_NdisSendNetBufferLists(
    reven_server: reven2.RevenServer, ctx: reven2.trace.Context
) -> _Tuple[_List[reven2.MemoryRange[reven2.address._AbstractAddress]], _Optional[bytearray]]:
    packet_memory_ranges = nw_tools.get_memory_addresses_and_sizes_of_sent_network_packet(reven_server, ctx)
    Buffer: _Optional[bytearray] = None
    sources = []
    # Read the buffers and join them
    for memory_range in packet_memory_ranges:
        sources.append(memory_range)
        if Buffer is None:
            Buffer = ctx.read(memory_range, raw=True)
        else:
            Buffer += ctx.read(memory_range, raw=True)

    return sources, Buffer


def get_all_send_recv(reven_server: reven2.RevenServer) -> _Iterator[_Tuple[reven2.trace.Context, str]]:
    print("[+] Get all sent/received packets...")

    send_queries, recv_queries = nw_tools.get_all_send_recv_packet_context(reven_server)

    # `reven2.util.collate` makes it possible to iterate over multiple generators in a sorted way
    send_results = zip(reven2.util.collate(send_queries), itertools.repeat("send"))
    recv_results = zip(reven2.util.collate(recv_queries), itertools.repeat("recv"))

    # Return a generator over both results, sorted by their context
    return reven2.util.collate([send_results, recv_results], lambda ctx_type: ctx_type[0])


def dump_pcap(reven_server: reven2.RevenServer, output_file: str = "output.pcap", fix_checksum: bool = False) -> None:
    if os.path.isfile(output_file):
        raise RuntimeError(
            '"{}" already exists. Choose an other output file or remove it before running the script.'.format(
                output_file
            )
        )

    print("[+] Creating pcap from trace...")

    # Get all send and recv from the trace
    results = list(get_all_send_recv(reven_server))
    if len(results) == 0:
        print("[+] Finished: no network packets were sent/received in the trace")
        return

    # Get packets buffers and create the pcap file.
    print("[+] Convert packets to pcap format and write to file...")
    for ctx, ty in results:
        # Just detect if send or recv context
        if ty == "send":
            sources, buf = get_network_buffer_send_NdisSendNetBufferLists(reven_server, ctx)
        else:
            sources, buf = get_network_buffer_recv_RxPacketAssemble(ctx)

        if buf is not None:
            packet = Ether(bytes(buf))

            # Here we check whether or not we have to fix the checksum.
            if fix_checksum:
                if TCP in packet:
                    del packet[TCP].chksum

            # Replace the time in the packet by the transition ID, so that we get
            # it in Wireshark in a nice way.
            transition = ctx.transition_before().id
            packet.time = transition

            # Write packet to pcap file
            wrpcap(output_file, packet, append=True)

            # Print packet information
            sources_str = ", ".join(
                [
                    "{size} bytes at {address}".format(size=memory_range.size, address=memory_range.address)
                    for memory_range in sources
                ]
            )
            print("#{transition} [{type}] {sources}".format(transition=transition, type=ty, sources=sources_str))

    print("[+] Finished: PCAP file is '{}'.".format(output_file))


if __name__ == "__main__":
    args = parse_args()

    # Get a server instance
    reven_server = reven2.RevenServer(args.host, args.port)

    # Generate the PCAP file
    dump_pcap(reven_server, output_file=args.file_name, fix_checksum=args.fix_checksum)

Taint PCAP

Purpose

Display the list of functions that handle each network packet sent/received during a REVEN scenario.
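
Under the hood, the script taints the memory range of each packet buffer (forward for received packets, backward for sent ones) and prints the symbol of every function that accesses the tainted data. A minimal sketch of that idea for a received packet, assuming a connected reven2.RevenServer named server, a context ctx where the packet was assembled, and the address and size of its buffer:

from reven2.preview.taint import TaintedMemories, Tainter

tainter = Tainter(server.trace)

# Taint the received packet buffer forward from the context where it was assembled
taint = tainter.simple_taint(tag0=TaintedMemories(address, size), from_context=ctx, is_forward=True)

# Print the symbol of each transition that changes the tainted data
for change in taint.accesses(changes_only=True).all():
    location = change.transition.context_before().ossi.location()
    if location is not None and location.symbol is not None:
        print("{}: {}".format(change.transition, location.symbol))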

How to use

usage: taint_pcap.py [-h] [--host host] [--port port] [--symbol symbols]
                     [--recv-only]

Taint every packet from a trace.

optional arguments:
  -h, --help        show this help message and exit
  --host host       Reven host, as a string (default: "localhost")
  --port port       Reven port, as an int (default: 13370)
  --symbol symbols  Symbol name that the packet may go through, as a
                    string (default ""). This argument can be repeated.
  --recv-only       If specified, handle received packets only
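
For example, to follow only received packets and restrict the output to functions whose symbol name contains a given string (host, port and symbol are placeholders to adapt):

python taint_pcap.py --host localhost --port 13370 --symbol RxPacketAssemble --recv-only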

Known limitations

N/A

Supported versions

REVEN 2.2+

Supported perimeter

Any Windows 10 x64 scenario.

Dependencies

  • The script requires the scapy package.
  • The network_packet_tools.py file distributed alongside this example must be provided (e.g. in the same directory).
  • The script requires that the target REVEN scenario have:
    • The Fast Search feature replayed.
    • The OSSI feature replayed.
    • An access to the binary 'e1g6032e.sys' and its PDB file.

Source

import argparse
import itertools

import network_packet_tools as nw_tools

import reven2
from reven2.preview.taint import TaintedMemories, Tainter

"""
# Taint PCAP

## Purpose

Display the list of functions that handle each network packet sent/received during a REVEN scenario.

## How to use

```bash
usage: taint_pcap.py [-h] [--host host] [--port port] [--symbol symbols]
                     [--recv-only]

Taint every packet from a trace.

optional arguments:
  -h, --help        show this help message and exit
  --host host       Reven host, as a string (default: "localhost")
  --port port       Reven port, as an int (default: 13370)
  --symbol symbols  Symbol name that the packet may go through, as a
                    string (default ""). This argument can be repeated.
  --recv-only       If specified, handle received packets only
```

## Known limitations

N/A

## Supported versions

REVEN 2.2+

## Supported perimeter

Any Windows 10 x64 scenario.

## Dependencies

- The script requires the scapy package.
- The `network_packet_tools.py` file distributed alongside this example must be provided (e.g. in the same directory).
- The script requires that the target REVEN scenario have:
  - The Fast Search feature replayed.
  - The OSSI feature replayed.
  - An access to the binary 'e1g6032e.sys' and its PDB file.
"""


def parse_args():
    parser = argparse.ArgumentParser(description="Taint every pcap from a trace.")
    parser.add_argument(
        "--host",
        metavar="host",
        dest="host",
        help='Reven host, as a string (default: "localhost")',
        default="localhost",
        type=str,
    )
    parser.add_argument(
        "--port", metavar="port", dest="port", help="Reven port, as an int (default: 13370)", type=int, default=13370
    )
    parser.add_argument(
        "--symbol",
        metavar="symbols",
        action="append",
        dest="symbols",
        help="Symbol name that maybe the packet "
        'goes through, as a string (default ""). This argument can be repeated.',
        type=str,
        default=[],
    )

    parser.add_argument(
        "--recv-only", dest="recv_only", action="store_true", help="If specified, handle received" "packets only "
    )

    args = parser.parse_args()
    return args


def get_memory_range_of_received_network_packet(ctx):
    info = nw_tools.get_memory_address_and_size_of_received_network_packet(ctx)
    if info is None:
        return None
    return TaintedMemories(info[0], info[1])


def get_memory_range_of_sent_network_packet(ctx):
    infos = nw_tools.get_memory_addresses_and_sizes_of_sent_network_packet(ctx)
    tainted_mems = []
    for info in infos:
        tainted_mems.append(TaintedMemories(info[0], info[1]))
    return tainted_mems


def get_all_send_recv(rvn, recv_only=False):
    print("[+] Get all sent/received packets...")

    send_queries, recv_queries = nw_tools.get_all_send_recv_packet_context(rvn)

    # `reven2.util.collate` makes it possible to iterate over multiple generators in a sorted way
    if recv_only:
        return zip(reven2.util.collate(recv_queries), itertools.repeat("recv"))

    send_results = zip(reven2.util.collate(send_queries), itertools.repeat("send"))
    recv_results = zip(reven2.util.collate(recv_queries), itertools.repeat("recv"))

    # Return a generator over both results, sorted by their context
    return reven2.util.collate([send_results, recv_results], lambda ctx_type: ctx_type[0])


def found_symbol(current_symbol, user_symbols):
    for sym in user_symbols:
        if sym.lower() in current_symbol.name.lower():
            return True
    return False


def taint_pcap(reven_server, recv_only=False, user_symbols=[]):
    # Initialize Tainter
    tainter = Tainter(reven_server.trace)
    # Get all send and recv from the trace
    results = list(get_all_send_recv(reven_server, recv_only))
    if len(results) == 0:
        print("[+] Finished: no network packets were sent/received in the trace")
        return

    # Get packets memory range
    for ctx, ty in results:
        # Just detect if send or recv context
        # Taint the packet forward when it was received, and backward when it was sent
        is_forward = ty != "send"
        mem_range = (
            get_memory_range_of_sent_network_packet(ctx)
            if ty == "send"
            else get_memory_range_of_received_network_packet(ctx)
        )

        if mem_range is None or (isinstance(mem_range, list) and len(mem_range) == 0):
            continue

        taint = tainter.simple_taint(tag0=mem_range, from_context=ctx, is_forward=is_forward)

        print("\n=====================================================================================")
        print(
            "[+]{} - {} packet at address {}".format(
                ctx.transition_before(),
                "Received" if is_forward else "Sent",
                mem_range if is_forward else ["{}".format(mem) for mem in mem_range],
            )
        )

        last_symbol = None
        for change in taint.accesses(changes_only=True).all():
            loc = change.transition.context_before().ossi.location()
            if loc is None:
                continue
            symbol = loc.symbol
            if symbol is None or (last_symbol is not None and symbol == last_symbol):
                continue
            if len(user_symbols) == 0 or found_symbol(symbol, user_symbols):
                # If user_symbols is an empty list, no symbols were requested: don't filter the output
                print("{}: {}".format(change.transition, symbol))
                last_symbol = symbol

        print("=====================================================================================\n")

    print("[+] Finished: tainting all pcap in the trace")


if __name__ == "__main__":

    args = parse_args()

    print("[+] Start tainting pcap from trace...")

    # Get a server instance
    rvn = reven2.RevenServer(args.host, args.port)

    taint_pcap(rvn, args.recv_only, args.symbols)

This module introduces two functions: the first one returns the memory address and size of the buffer used to receive a network packet, while the second one returns a list of memory addresses and sizes of the buffers used to send a network packet.
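
A minimal sketch of how these helpers are typically combined, as done in the Dump PCAP example above (host and port are placeholders to adapt):

import network_packet_tools as nw_tools

import reven2

# Connect to the REVEN server serving the scenario
server = reven2.RevenServer("localhost", 13370)

# Contexts where packets were sent/received (requires the PDB of e1g6032e.sys)
send_queries, recv_queries = nw_tools.get_all_send_recv_packet_context(server)

# Received packets: at most one memory range per call, read as raw bytes
for ctx in reven2.util.collate(recv_queries):
    memory_range = nw_tools.get_memory_address_and_size_of_received_network_packet(ctx)
    if memory_range is not None:
        buf = ctx.read(memory_range, raw=True)

# Sent packets: possibly several memory ranges per packet
for ctx in reven2.util.collate(send_queries):
    memory_ranges = nw_tools.get_memory_addresses_and_sizes_of_sent_network_packet(server, ctx)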

Source

"""
This module introduces two functions: the first one returns the memory address and size of the buffer
used to receive a network packet, while the second one returns a list of memory addresses and
sizes of the buffers used to send a network packet.

"""

from typing import Iterator as _Iterator, List as _List, Optional as _Optional, Tuple as _Tuple, cast as _cast

import reven2


def get_memory_address_and_size_of_received_network_packet(
    ctx: reven2.trace.Context,
) -> _Optional[reven2.MemoryRange[reven2.address._AbstractAddress]]:
    """
    This function returns the memory range (address, size) of the received packet,
    or None if this call doesn't carry a buffer to fetch.

    'ctx' must be a context resulting from searching the "RxPacketAssemble" symbol in the trace

    Information
    ===========

    @param ctx: the context used to retrieve the memory address and size of the received packet buffer
    """

    # To get the memory address of received packets, we need to dereference several pointers in memory.
    # The first one is rcx, passed as an argument. It points to a huge structure.
    # We don't know its type, so we can't use the type API.
    # At rcx+0x308 is a pointer to a structure, which contains the size at +8.
    # At rcx+0x328 is a pointer to an index that points to the right structure to get for the buffer.
    # At rcx+0x328 + 8 * index is a pointer to the network buffer.

    # At rcx+0x308, then deref +0xc, is a byte that is tested against 1 and 2. If it is 0, then the call to
    # RxPacketAssemble is the last one of a series and doesn't contain any buffer to fetch.

    # Get a pointer to the huge structure
    pHugeStruct = reven2.address.LogicalAddress(ctx.read(reven2.arch.x64.rcx, reven2.types.USize))

    # Deref to get a pointer to the structure that contains the size
    pSizeStruct: reven2.address.LogicalAddress = reven2.address.LogicalAddress(
        ctx.read(pHugeStruct + 0x308, reven2.types.USize)
    )

    u8Flag: int = ctx.read(pSizeStruct + 0xC, reven2.types.U8)

    # Is last packet part?
    if (u8Flag & 0x3) == 0:
        return None

    u32Size: int = ctx.read(pSizeStruct + 0x8, reven2.types.U32)

    # Next get the index in the structure
    pu32IndexRaw = reven2.address.LogicalAddress(ctx.read(pHugeStruct + 0x328, reven2.types.USize))

    # The index is a dword (eax is used)
    u32IndexRaw: int = ctx.read(pu32IndexRaw, reven2.types.U32)

    # Now, the system performs an operation on this index.
    u32Index = (u32IndexRaw + u32IndexRaw * 2) * 2

    # Now get a pointer to the buffer
    pArray = reven2.address.LogicalAddress(ctx.read(pHugeStruct + 0x328, reven2.types.USize) + 8)
    pBuffer: reven2.address.LogicalAddress = reven2.address.LogicalAddress(
        ctx.read(pArray + 8 * u32Index + 0x20, reven2.types.USize)
    )

    return reven2.MemoryRange(pBuffer, u32Size)


def get_memory_addresses_and_sizes_of_sent_network_packet(
    reven_server: reven2.RevenServer, ctx: reven2.trace.Context
) -> _List[reven2.MemoryRange[reven2.address._AbstractAddress]]:
    """
    This function returns a list of memory ranges (address, size) of a sent packet

    'ctx' must be a context resulting from searching the "E1000SendNetBufferLists" symbol in the trace

    Information
    ===========

    @param ctx: the context used to retrieve the list of memory addresses and sizes of the sent packet buffers
    """
    ndis = next(reven_server.ossi.executed_binaries("ndis.sys"))
    net_buffer_list_type = _cast(reven2.types.Struct, ndis.exact_type("_NET_BUFFER_LIST"))
    net_buffer_list: reven2.types.StructInstance = ctx.deref(
        reven2.arch.x64.rdx, reven2.types.Pointer(net_buffer_list_type)
    )
    net_buffer = net_buffer_list.field("FirstNetBuffer").deref_struct()

    mdl = net_buffer.field("CurrentMdl").deref_struct()
    mdlOffset = net_buffer.field("CurrentMdlOffset").read_int()

    packet_memory_addresses = []
    packet_memory_addresses.append(_get_network_packet_address_from_mdl(ctx, mdl, mdlOffset=mdlOffset))

    pNextMdl = mdl.field("Next").read_ptr().assert_struct()
    while pNextMdl.address.offset != 0:
        nextMdl = pNextMdl.deref()
        packet_memory_addresses.append(_get_network_packet_address_from_mdl(ctx, nextMdl))

        pNextMdl = nextMdl.field("Next").read_ptr().assert_struct()

    return packet_memory_addresses


def get_all_send_recv_packet_context(
    reven_server: reven2.RevenServer,
) -> _Tuple[_List[_Iterator[reven2.trace.Context]], _List[_Iterator[reven2.trace.Context]]]:
    """
    This function returns a list of all contexts used to send or receive network packets.

    To get these contexts, this function searches the symbol `E1000SendNetBufferLists` to get contexts of
    sent network packets, and searches the symbol `RxPacketAssemble` to get contexts of received network packets.

    This function requires that the trace has the PDB of the `e1g6032e.sys` binary, otherwise no context will be found.

    'reven_server' is the L{reven2.RevenServer} instance on which to perform the search

    Information
    ===========

    @param reven_server: L{reven2.RevenServer} instance on which to search packets
    """
    # Get generators of search results
    send_queries = [
        reven_server.trace.search.symbol(symbol)
        for symbol in reven_server.ossi.symbols(pattern="E1000SendNetBufferLists", binary_hint="e1g6032e.sys")
    ]
    recv_queries = [
        reven_server.trace.search.symbol(symbol)
        for symbol in reven_server.ossi.symbols(pattern="RxPacketAssemble", binary_hint="e1g6032e.sys")
    ]

    if len(send_queries) == 0 and len(recv_queries) == 0:
        print(
            "No network packets exist in this trace, make sure that this trace is a network trace,"
            " and if it is, make sure that the PDB of `e1g6032e.sys` binary is available in the scenario"
        )

    return send_queries, recv_queries


def _get_network_packet_address_from_mdl(
    ctx: reven2.trace.Context, mdl: reven2.types.StructInstance, mdlOffset: int = 0
) -> reven2.MemoryRange[reven2.address._AbstractAddress]:
    pBufferStartVa = mdl.field("MappedSystemVa").read_ptr()
    u32Size = mdl.field("ByteCount").read_int()

    return reven2.MemoryRange(pBufferStartVa.address + mdlOffset, u32Size - mdlOffset)