Thread synchronization

Purpose

Trace calls to Windows Synchronization APIs

  • Critical sections: 'InitializeCriticalSection', 'InitializeCriticalSectionEx', 'EnterCriticalSection', 'LeaveCriticalSection', 'SleepConditionVariableCS', 'SleepConditionVariableSRW'

  • Slim reader/writer locks: 'RtlInitializeSRWLock', 'RtlAcquireSRWLockShared', 'RtlAcquireSRWLockExclusive', 'RtlReleaseSRWLockShared', 'RtlReleaseSRWLockExclusive'

  • Mutex: 'CreateMutexW', 'OpenMutexW', 'WaitForSingleObject', 'ReleaseMutex'

  • Condition variables: 'InitializeConditionVariable', 'WakeConditionVariable', 'WakeAllConditionVariable', 'RtlWakeConditionVariable', 'RtlWakeAllConditionVariable'

How to use

usage: threadsync.py [-h] [--host HOST] [-p PORT] --pid PID
                     [--from_id FROM_ID] [--to_id TO_ID] [--binary BINARY]
                     [--sync PRIMITIVE] [--raw RAW]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  --pid PID             Process id
  --from_id FROM_ID     Transition to start searching from
  --to_id TO_ID         Transition to start searching to
  --binary BINARY       Process binary
  --sync PRIMITIVE      Synchronization primitives to check.
                        Each repeated use of `--sync <primitive>` adds the corresponding primitive to check.
                        If no `--sync` option is passed, then all supported primitives are checked.
                        Supported primitives:
                         - cs:  critical section
                         - cv:  condition variable
                         - mx:  mutex
                         - srw: slim read/write lock
  --raw RAW             CSV file as output

Known limitations

N/A

Supported versions

REVEN 2.8+

Supported perimeter

Any Windows 10 on x86-64 scenario.

Dependencies

  • OSSI (with symbols ntdll.dll, kernelbase.dll, ntoskrnl.exe resolved) feature replayed.
  • Fast Search feature replayed.

Source

import argparse

import reven2
from reven2.address import LogicalAddress
from reven2.arch import x64
from reven2.util import collate

"""
# Thread synchronization

## Purpose

Trace calls to Windows Synchronization APIs
  - Critical sections:        'InitializeCriticalSection', 'InitializeCriticalSectionEx',
                              'EnterCriticalSection', 'LeaveCriticalSection',
                              'SleepConditionVariableCS',  'SleepConditionVariableSRW'

  - Slim reader/writer locks: 'RtlInitializeSRWLock',
                              'RtlAcquireSRWLockShared', 'RtlAcquireSRWLockExclusive',
                              'RtlReleaseSRWLockShared', 'RtlReleaseSRWLockExclusive'

  - Mutex:                    'CreateMutexW', 'OpenMutexW', 'WaitForSingleObject', 'ReleaseMutex'

  - Condition variables:      'InitializeConditionVariable',
                              'WakeConditionVariable', 'WakeAllConditionVariable',
                              'RtlWakeConditionVariable', 'RtlWakeAllConditionVariable'

## How to use

```bash
usage: threadsync.py [-h] [--host HOST] [-p PORT] --pid PID
                     [--from_id FROM_ID] [--to_id TO_ID] [--binary BINARY]
                     [--sync PRIMITIVE] [--raw RAW]

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  --pid PID             Process id
  --from_id FROM_ID     Transition to start searching from
  --to_id TO_ID         Transition to start searching to
  --binary BINARY       Process binary
  --sync PRIMITIVE      Synchronization primitives to check.
                        Each repeated use of `--sync <primitive>` adds the corresponding primitive to check.
                        If no `--sync` option is passed, then all supported primitives are checked.
                        Supported primitives:
                         - cs:  critical section
                         - cv:  condition variable
                         - mx:  mutex
                         - srw: slim read/write lock
  --raw RAW             CSV file as output
```

## Known limitations

N/A

## Supported versions

REVEN 2.8+

## Supported perimeter

Any Windows 10 on x86-64 scenario.

## Dependencies

- OSSI (with symbols `ntdll.dll`, `kernelbase.dll`, `ntoskrnl.exe` resolved) feature replayed.
- Fast Search feature replayed.
"""


class SyncOSSI(object):
    """
    Resolve the synchronization API symbols used by the checks and probe the debugger.

    Attributes set by the constructor:
     - symbols: dict mapping an API name to the symbol found in its executed binary.
       APIs that cannot be found are left out, and a warning is printed.
     - has_debugger: False when ``trace.first_transition.step_over()`` raised
       RuntimeError, i.e. return values cannot be obtained by stepping.
     - search_symbol: shortcut to ``trace.search.symbol``.
     - trace: the trace given to the constructor.
    """

    def __init__(self, ossi, trace):
        """
        Parameters:
         - ossi: OSSI object of the REVEN server (``rvn.ossi``)
         - trace: trace object of the REVEN server (``rvn.trace``)

        Raises RuntimeError if one of the required binaries was not executed in the trace.
        """
        # Binaries (OSSI paths) and the API names to resolve in each of them.
        bin_symbol_names = {
            "c:/windows/system32/ntdll.dll": {
                # critical section
                "RtlInitializeCriticalSection",
                "RtlInitializeCriticalSectionEx",
                "RtlEnterCriticalSection",
                "RtlLeaveCriticalSection",
                # slim read/write lock
                "RtlInitializeSRWLock",
                "RtlAcquireSRWLockShared",
                "RtlAcquireSRWLockExclusive",
                "RtlReleaseSRWLockShared",
                "RtlReleaseSRWLockExclusive",
                # condition variable (general)
                "RtlWakeConditionVariable",
                "RtlWakeAllConditionVariable",
            },
            "c:/windows/system32/kernelbase.dll": {
                # mutex
                "CreateMutexW",
                "OpenMutexW",
                "ReleaseMutex",
                "WaitForSingleObject",
                # condition variable (general)
                "InitializeConditionVariable",
                "WakeConditionVariable",
                "WakeAllConditionVariable",
                # condition variable on critical section
                "SleepConditionVariableCS",
                # condition variable on slim read/write lock
                "SleepConditionVariableSRW",
            },
        }

        self.symbols = {}

        for binary_path, symbol_names in bin_symbol_names.items():
            # Anchored patterns: only the exact binary path / symbol name is wanted.
            exec_bin = next(ossi.executed_binaries(f"^{binary_path}$"), None)
            if exec_bin is None:
                raise RuntimeError(f"{binary_path} not found")

            for name in symbol_names:
                sym = next(exec_bin.symbols(f"^{name}$"), None)
                if sym is None:
                    # A missing API is not fatal: the checks skip unresolved names.
                    print(f"Warning: {name} not found in {binary_path}")
                else:
                    self.symbols[name] = sym

        # Stepping is required by get_return_value; a RuntimeError from step_over()
        # means the debugger interface cannot be used on this trace.
        self.has_debugger = True
        try:
            trace.first_transition.step_over()
        except RuntimeError:
            print(
                "Warning: the debugger interface is not available, so the script cannot determine "
                "function return values.\nMake sure the stack events and PC range resources are replayed for this scenario."
            )
            self.has_debugger = False

        self.search_symbol = trace.search.symbol
        self.trace = trace


# tool functions
def context_cr3(ctxt):
    """Return the value of the ``cr3`` register read in ``ctxt``."""
    cr3_value = ctxt.read(x64.cr3)
    return cr3_value


def is_kernel_mode(ctxt):
    """
    Tell whether ``ctxt`` executes at privilege level 0 (kernel mode).

    The privilege level is taken from the two lowest bits of ``cs``.
    """
    privilege_level = ctxt.read(x64.cs) & 0x3
    return privilege_level == 0


def find_process_ranges(rvn, pid, from_id, to_id):
    """
    Traverse the trace and yield the context ranges during which process ``pid`` runs.

    Ranges are delimited by the calls to ``KiSwapContext`` found by the symbol search.

    Parameters:
     - rvn: RevenServer
     - pid: int (process id)
     - from_id: int or None, id of the transition to start from (None: trace start)
     - to_id: int or None, id of the transition to stop at (None: trace end)

    Output: yielding (low, high) context pairs. high is None only when the range
    extends to the end of the trace, which happens only when to_id is None.

    Raises RuntimeError when ntoskrnl.exe or KiSwapContext cannot be found, or when
    from_id / to_id does not designate a transition of the trace.
    """
    ntoskrnl = next(rvn.ossi.executed_binaries("^c:/windows/system32/ntoskrnl.exe$"), None)
    if ntoskrnl is None:
        raise RuntimeError("ntoskrnl.exe not found")

    ki_swap_context = next(ntoskrnl.symbols("^KiSwapContext$"), None)
    if ki_swap_context is None:
        raise RuntimeError("KiSwapContext not found")

    if from_id is None:
        ctxt_low = rvn.trace.first_context
    else:
        try:
            ctxt_low = rvn.trace.transition(from_id).context_before()
        except IndexError:
            raise RuntimeError(f"Transition of id {from_id} not found") from None

    if to_id is None:
        ctxt_hi = None
    else:
        try:
            ctxt_hi = rvn.trace.transition(to_id).context_before()
        except IndexError:
            raise RuntimeError(f"Transition of id {to_id} not found") from None

    # Empty requested range.
    if ctxt_hi is not None and ctxt_low >= ctxt_hi:
        return

    # The search is left unbounded when the upper bound is the last context of the trace.
    for ctxt in rvn.trace.search.symbol(
        ki_swap_context,
        from_context=ctxt_low,
        to_context=None if ctxt_hi is None or ctxt_hi == rvn.trace.last_context else ctxt_hi,
    ):
        # [ctxt_low, ctxt] is one slice, attributed to the process running at ctxt_low.
        if ctxt_low.ossi.process().pid == pid:
            yield (ctxt_low, ctxt)
        ctxt_low = ctxt

    # Last slice: from the last context switch found up to the requested bound.
    if ctxt_low.ossi.process().pid != pid:
        return
    if ctxt_hi is None:
        yield (ctxt_low, None)
    elif ctxt_low < ctxt_hi:
        yield (ctxt_low, ctxt_hi)
    # Otherwise nothing remains before ctxt_hi. Yielding (ctxt_low, None) here would
    # make the analysis run past to_id up to the end of the trace.


def find_usermode_ranges(ctxt_low, ctxt_high):
    """
    Split the context range [ctxt_low, ctxt_high] into its user-mode sub-ranges.

    Mode changes are located with ``find_register_change`` on ``cs``.

    Parameters:
     - ctxt_low: Context, start of the range
     - ctxt_high: Context or None; None means the range extends to the end of the trace

    Output: yielding (low, high) context pairs. high is None only for a last
    sub-range that extends to the end of the trace.
    """
    # An unbounded range is scanned like a bounded one. Ranges from find_process_ranges
    # usually start at a KiSwapContext call (kernel mode), so looking only at ctxt_low
    # would skip all the user-mode code that follows it.
    ctxt_current = ctxt_low
    while ctxt_high is None or ctxt_current < ctxt_high:
        ctxt_next = ctxt_current.find_register_change(x64.cs, is_forward=True)

        if not is_kernel_mode(ctxt_current):
            if ctxt_next is None or (ctxt_high is not None and ctxt_next > ctxt_high):
                # User mode lasts until the end of the requested range (or of the trace).
                yield (ctxt_current, ctxt_high)
                return
            yield (ctxt_current, ctxt_next - 1)

        if ctxt_next is None:
            # No further mode change: the remainder stays in kernel mode.
            return

        ctxt_current = ctxt_next


def get_tid(ctxt):
    """
    Return the thread id read in ``ctxt``.

    Reads 4 bytes at ``gs:0x48``. In user mode ``gs`` presumably points to the TEB,
    whose offset 0x48 holds ``ClientId.UniqueThread`` on x64 Windows; verify this
    if the targeted OS version changes.
    """
    tid_address = LogicalAddress(0x48, x64.gs)
    return ctxt.read(tid_address, 4)


def caller_context_in_binary(ctxt, binary):
    """
    Tell whether the code executed just before ``ctxt`` belongs to ``binary``.

    ``ctxt`` is the entry of an API found by a symbol search. The previous context is
    presumably the call site, so its OSSI location identifies the caller.

    Parameters:
     - ctxt: Context at the entry of the called function
     - binary: str or None. It is compared with the caller binary's name, filename and
       path; None accepts every caller.

    Output: bool
    """
    if binary is None:
        return True

    try:
        caller_ctxt = ctxt - 1
    except IndexError:
        # ctxt is the first context of the trace: no caller can be inspected.
        return False
    if caller_ctxt is None:
        return False

    # Guard against OSSI being unable to resolve a location for the caller.
    location = caller_ctxt.ossi.location()
    if location is None:
        return False

    caller_binary = location.binary
    if caller_binary is None:
        return False

    return binary in (caller_binary.name, caller_binary.filename, caller_binary.path)


def build_ordered_api_calls(syncOSSI, ctxt_low, ctxt_high, apis):
    """
    Merge the calls to several APIs within a context range into one stream ordered by context.

    Parameters:
     - syncOSSI: SyncOSSI (resolved symbols and symbol search)
     - ctxt_low, ctxt_high: bounds given to the symbol search (ctxt_high may be None)
     - apis: iterable of API names; names without a resolved symbol are skipped

    Output: iterator of (api name, context) pairs, merged by ``collate`` on the context
    """

    def tagged_calls(api_name):
        # api_name is bound as a parameter on purpose. In a nested generator expression
        # the inner element would read the outer loop variable lazily, and every call
        # would presumably be tagged with whatever API the outer loop reached last.
        matches = syncOSSI.search_symbol(syncOSSI.symbols[api_name], from_context=ctxt_low, to_context=ctxt_high)
        return ((api_name, match) for match in matches)

    per_api_streams = (tagged_calls(name) for name in apis if name in syncOSSI.symbols)

    return collate(per_api_streams, key=lambda call: call[1])


def get_return_value(syncOSSI, ctxt):
    """
    Return the value of ``rax`` right after the function entered at ``ctxt`` returns.

    Parameters:
     - syncOSSI: SyncOSSI
     - ctxt: Context at the entry of the called function

    Output: int, or None when the value cannot be determined. This happens when the
    debugger interface is unavailable, the trace ends, or step_out() finds no return.
    """
    # Return None ("unknown"), not False, when stepping is impossible. False is an int
    # equal to 0: print_csv would write it as 0x0, and check_mutex would report a failed
    # CreateMutexW/OpenMutexW.
    if not syncOSSI.has_debugger:
        return None

    try:
        trans_after = ctxt.transition_after()
    except IndexError:
        return None

    if trans_after is None:
        return None

    trans_ret = trans_after.step_out()
    if trans_ret is None:
        return None

    return trans_ret.context_after().read(x64.rax)


# Values of primitive_handle and return_value fall into one of the following:
#  - None
#  - a numeric value
#  - a string literal
# where None is used usually in cases the script cannot reliably get the actual
# value of the object.
def print_csv(csv_file, ctxt, sync_primitive, primitive_handle, return_value):
    """
    Append one CSV row "transition id, primitive, handle, return value" to ``csv_file``.

    ``primitive_handle`` and ``return_value`` are each None (written as "unknown"),
    a string (written verbatim), or an integer (written in hex with a ``0x`` prefix).
    A failed write is reported on stdout instead of being raised.
    """

    def render(value):
        if value is None:
            return "unknown"
        if isinstance(value, str):
            return value
        return f"{value:#x}"

    # The row is keyed by the transition that starts at ctxt. The first context of
    # the trace has no transition before it, hence id 0.
    try:
        row_id = ctxt.transition_before().id + 1
    except IndexError:
        row_id = 0

    row = f"{row_id}, {sync_primitive}, {render(primitive_handle)}, {render(return_value)}\n"

    try:
        csv_file.write(row)
    except OSError:
        print(f"Failed to write to {csv_file}")


def check_critical_section(syncOSSI, ctxt_low, ctxt_high, binary, csv_file):
    """
    Report the critical-section API calls made in the range [ctxt_low, ctxt_high].

    Parameters:
     - syncOSSI: SyncOSSI
     - ctxt_low, ctxt_high: context range to search (ctxt_high may be None)
     - binary: str or None, only keep calls whose caller lies in this binary
     - csv_file: file object for CSV output, or None to print readable lines

    Output: True if at least one matching call was found
    """
    # APIs whose only reported argument is the critical section (first argument).
    cs_only_apis = {
        "RtlInitializeCriticalSection",
        "RtlInitializeCriticalSectionEx",
        "RtlEnterCriticalSection",
        "RtlLeaveCriticalSection",
    }
    critical_section_apis = cs_only_apis | {"SleepConditionVariableCS"}

    found = False

    ordered_api_contexts = build_ordered_api_calls(syncOSSI, ctxt_low, ctxt_high, critical_section_apis)
    for api, ctxt in ordered_api_contexts:
        if not caller_context_in_binary(ctxt, binary):
            continue

        found = True

        if api in cs_only_apis:
            # Windows x64 calling convention: the first argument is in rcx.
            cs_handle = ctxt.read(x64.rcx)

            if csv_file is None:
                print(f"{ctxt}: {api}, critical section 0x{cs_handle:x}")
            else:
                # The return value of these APIs is not reported, hence "unused".
                print_csv(csv_file, ctxt, "cs", cs_handle, "unused")

        elif api == "SleepConditionVariableCS":
            # SleepConditionVariableCS(ConditionVariable, CriticalSection, dwMilliseconds):
            # the condition variable is the first argument (rcx) and the critical
            # section the second (rdx).
            cv_handle = ctxt.read(x64.rcx)
            cs_handle = ctxt.read(x64.rdx)

            if csv_file is None:
                print(f"{ctxt}: {api}, critical section 0x{cs_handle:x}, condition variable 0x{cv_handle:x}")
            else:
                # go to the return point (if possible) to get the return value
                ret_val = get_return_value(syncOSSI, ctxt)

                print_csv(csv_file, ctxt, "cs", cs_handle, ret_val)
                print_csv(csv_file, ctxt, "cv", cv_handle, ret_val)

    return found


def check_srw_lock(syncOSSI, ctxt_low, ctxt_high, binary, csv_file):
    """
    Report the slim reader/writer lock API calls made in the range [ctxt_low, ctxt_high].

    This includes SleepConditionVariableSRW, which SyncOSSI resolves in kernelbase.dll.
    It is reported like SleepConditionVariableCS: one line for the lock and one for
    the condition variable.

    Parameters:
     - syncOSSI: SyncOSSI
     - ctxt_low, ctxt_high: context range to search (ctxt_high may be None)
     - binary: str or None, only keep calls whose caller lies in this binary
     - csv_file: file object for CSV output, or None to print readable lines

    Output: True if at least one matching call was found
    """
    # APIs whose only reported argument is the lock (first argument).
    lock_only_apis = [
        "RtlInitializeSRWLock",
        "RtlAcquireSRWLockShared",
        "RtlAcquireSRWLockExclusive",
        "RtlReleaseSRWLockShared",
        "RtlReleaseSRWLockExclusive",
    ]
    srw_apis = lock_only_apis + ["SleepConditionVariableSRW"]

    found = False

    ordered_api_contexts = build_ordered_api_calls(syncOSSI, ctxt_low, ctxt_high, srw_apis)
    for api, ctxt in ordered_api_contexts:
        if not caller_context_in_binary(ctxt, binary):
            continue

        found = True

        if api == "SleepConditionVariableSRW":
            # SleepConditionVariableSRW(ConditionVariable, SRWLock, dwMilliseconds, Flags):
            # the condition variable is in rcx and the lock in rdx.
            cv_handle = ctxt.read(x64.rcx)
            srw_handle = ctxt.read(x64.rdx)

            if csv_file is None:
                print(f"{ctxt}: {api}, lock 0x{srw_handle:x}, condition variable 0x{cv_handle:x}")
            else:
                # go to the return point (if possible) to get the return value
                ret_val = get_return_value(syncOSSI, ctxt)

                print_csv(csv_file, ctxt, "srw", srw_handle, ret_val)
                print_csv(csv_file, ctxt, "cv", cv_handle, ret_val)
        else:
            # the srw lock handle is the first argument (rcx)
            srw_handle = ctxt.read(x64.rcx)

            if csv_file is None:
                print(f"{ctxt}: {api}, lock 0x{srw_handle:x}")
            else:
                # The return value of these APIs is not reported, hence "unused".
                print_csv(csv_file, ctxt, "srw", srw_handle, "unused")

    return found


def check_mutex(syncOSSI, ctxt_low, ctxt_high, binary, used_handles, csv_file):
    """
    Report the mutex API calls made in the range [ctxt_low, ctxt_high].

    Handles are recorded in ``used_handles``, which is mutated in place. A handle is
    recorded when it is returned non-null by CreateMutexW/OpenMutexW, or when it is
    passed to ReleaseMutex. A WaitForSingleObject call is reported only when its handle
    is already in that set, since the handle may refer to any waitable object.

    Parameters:
     - syncOSSI: SyncOSSI
     - ctxt_low, ctxt_high: context range to search (ctxt_high may be None)
     - binary: str or None, only keep calls whose caller lies in this binary
     - used_handles: set of known mutex handles
     - csv_file: file object for CSV output, or None to print readable lines

    Output: True if at least one call from ``binary`` was found (including
    WaitForSingleObject calls on non-mutex handles)
    """
    mutex_apis = {"CreateMutexW", "OpenMutexW", "WaitForSingleObject", "ReleaseMutex"}

    any_call = False

    for api, ctxt in build_ordered_api_calls(syncOSSI, ctxt_low, ctxt_high, mutex_apis):
        if not caller_context_in_binary(ctxt, binary):
            continue

        any_call = True

        if api == "CreateMutexW" or api == "OpenMutexW":
            # The mutex handle is the return value; None when it cannot be obtained.
            handle = get_return_value(syncOSSI, ctxt)
            if handle is not None and handle != 0:
                used_handles.add(handle)

            if csv_file is not None:
                print_csv(csv_file, ctxt, "mx", handle, handle)
            elif handle is None:
                print(f"{ctxt}: {api}, mutex handle unknown")
            elif handle == 0:
                print(f"{ctxt}: {api}, failed")
            else:
                print(f"{ctxt}: {api}, mutex handle 0x{handle:x}")

        elif api == "ReleaseMutex":
            # first argument (rcx): the mutex handle
            handle = ctxt.read(x64.rcx)
            used_handles.add(handle)

            if csv_file is not None:
                print_csv(csv_file, ctxt, "mx", handle, get_return_value(syncOSSI, ctxt))
            else:
                print(f"{ctxt}: {api}, mutex handle 0x{handle:x}")

        elif api == "WaitForSingleObject":
            handle = ctxt.read(x64.rcx)
            if handle not in used_handles:
                # not known to be a mutex handle
                continue

            if csv_file is not None:
                print_csv(csv_file, ctxt, "mx", handle, get_return_value(syncOSSI, ctxt))
            else:
                print(f"{ctxt}: {api}, mutex handle 0x{handle:x}")

    return any_call


def check_condition_variable(syncOSSI, ctxt_low, ctxt_high, binary, csv_file):
    """
    Report the condition-variable API calls made in the range [ctxt_low, ctxt_high].

    Parameters:
     - syncOSSI: SyncOSSI
     - ctxt_low, ctxt_high: context range to search (ctxt_high may be None)
     - binary: str or None, only keep calls whose caller lies in this binary
     - csv_file: file object for CSV output, or None to print readable lines

    Output: True if at least one matching call was found
    """
    cond_var_apis = {
        "InitializeConditionVariable",
        "WakeConditionVariable",
        "WakeAllConditionVariable",
        "RtlWakeConditionVariable",
        "RtlWakeAllConditionVariable",
    }

    seen_any = False

    calls = build_ordered_api_calls(syncOSSI, ctxt_low, ctxt_high, cond_var_apis)
    for api, ctxt in calls:
        if caller_context_in_binary(ctxt, binary):
            seen_any = True

            # first argument (rcx): the condition variable
            cv_handle = ctxt.read(x64.rcx)

            if csv_file is not None:
                # The return value is not reported for this group, hence "unused".
                print_csv(csv_file, ctxt, "cv", cv_handle, "unused")
            else:
                print(f"{ctxt}: {api}, condition variable 0x{cv_handle:x}")

    return seen_any


def check_lock_unlock(syncOSSI, ctxt_low, ctxt_high, binary, sync_primitives, used_mutex_handles, csv_file):
    """
    Run every requested synchronization check on the range [ctxt_low, ctxt_high].

    Parameters:
     - syncOSSI: SyncOSSI
     - ctxt_low, ctxt_high: context range (ctxt_high None means up to the trace end)
     - binary: str or None, only keep calls whose caller lies in this binary
     - sync_primitives: collection of "cs", "srw", "cv", "mx" selecting the checks
     - used_mutex_handles: set of known mutex handles, forwarded to check_mutex
     - csv_file: file object for CSV output, or None to print readable lines

    In text mode a header with the transition range and thread id is printed first.
    "nothing found" is printed when no check reports a call.
    """
    first_id = ctxt_low.transition_after().id
    if ctxt_high is None:
        last_id = syncOSSI.trace.last_transition.id
    else:
        last_id = ctxt_high.transition_after().id

    thread_id = get_tid(ctxt_low)

    if csv_file is None:
        print(f"\n==== checking the transition range [#{first_id}, #{last_id}] (thread id: {thread_id}) ====")

    # Every selected check runs (no short-circuit), in the order cs, srw, cv, mx.
    results = []
    if "cs" in sync_primitives:
        results.append(check_critical_section(syncOSSI, ctxt_low, ctxt_high, binary, csv_file))
    if "srw" in sync_primitives:
        results.append(check_srw_lock(syncOSSI, ctxt_low, ctxt_high, binary, csv_file))
    if "cv" in sync_primitives:
        results.append(check_condition_variable(syncOSSI, ctxt_low, ctxt_high, binary, csv_file))
    if "mx" in sync_primitives:
        results.append(check_mutex(syncOSSI, ctxt_low, ctxt_high, binary, used_mutex_handles, csv_file))

    if csv_file is None and not any(results):
        print("\tnothing found")


def run(rvn, proc_id, proc_bin, from_id, to_id, sync_primitives, csv_file):
    """
    Check the synchronization calls of process ``proc_id`` over every user-mode
    range in which it runs between transitions ``from_id`` and ``to_id``.
    """
    sync_ossi = SyncOSSI(rvn.ossi, rvn.trace)

    # One set for the whole run, so that WaitForSingleObject calls in later ranges
    # can be matched with mutex handles seen in earlier ranges.
    mutex_handles = set()

    for proc_low, proc_high in find_process_ranges(rvn, proc_id, from_id, to_id):
        for user_low, user_high in find_usermode_ranges(proc_low, proc_high):
            check_lock_unlock(sync_ossi, user_low, user_high, proc_bin, sync_primitives, mutex_handles, csv_file)


def parse_args(argv=None):
    """
    Parse the command line.

    Parameters:
     - argv: list of argument strings, or None to use ``sys.argv[1:]`` (argparse default)

    Output: argparse.Namespace with host, port, pid, from_id, to_id, binary, sync, raw.
    ``sync`` is None when no ``--sync`` option is given.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help='Reven host, as a string (default: "localhost")',
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        # An int default, matching type=int (a string default is only converted implicitly).
        default=13370,
        help="Reven port, as an int (default: 13370)",
    )
    parser.add_argument("--pid", type=int, required=True, help="Process id")
    parser.add_argument("--from_id", type=int, default=None, help="Transition to start searching from")
    parser.add_argument("--to_id", type=int, default=None, help="Transition to start searching to")
    parser.add_argument(
        "--binary",
        type=str,
        default=None,
        help="Process binary",
    )
    parser.add_argument(
        "--sync",
        type=str,
        metavar="PRIMITIVE",
        action="append",
        choices=["cs", "cv", "mx", "srw"],
        default=None,
        help="Synchronization primitives to check.\n"
        + "Each repeated use of `--sync <primitive>` adds the corresponding primitive to check.\n"
        + "If no `--sync` option is passed, then all supported primitives are checked.\n"
        + "Supported primitives:\n"
        + " - cs:  critical section\n"
        + " - cv:  condition variable\n"
        + " - mx:  mutex\n"
        + " - srw: slim read/write lock",
    )
    parser.add_argument("--raw", type=str, default=None, help="CSV file as output")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    # No --sync option: check every supported primitive.
    if not args.sync:
        args.sync = ["cs", "cv", "mx", "srw"]

    csv_file = None
    if args.raw is not None:
        try:
            csv_file = open(args.raw, "w")
        except OSError as err:
            raise RuntimeError(f"Failed to open {args.raw}") from err

    # Always close the CSV file, even if the analysis fails, so buffered rows are flushed.
    try:
        if csv_file is not None:
            try:
                csv_file.write("transition id, primitive, handle, return value\n")
            except OSError as err:
                raise RuntimeError(f"Failed to write to {args.raw}") from err

        rvn = reven2.RevenServer(args.host, args.port)
        run(rvn, args.pid, args.binary, args.from_id, args.to_id, args.sync, csv_file)
    finally:
        if csv_file is not None:
            csv_file.close()