Thread synchronization

Purpose

Trace calls to the following Windows synchronization APIs:

  • Critical sections: EnterCriticalSection, LeaveCriticalSection
  • Condition variables: WakeConditionVariable, SleepConditionVariableCS
  • WaitForSingleObject, ReleaseMutex

How to use

usage: threadsync.py [-h] [--host HOST] [-p PORT] --cr3 PROC_CR3 --image
                     PROC_IMAGE

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  --cr3 PROC_CR3        Process cr3
  --image PROC_IMAGE    Process image

Known limitations

N/A

Supported versions

REVEN 2.8+

Supported perimeter

Any Windows 10 on x86-64 scenario.

Dependencies

  • OSSI feature replayed, with symbols for ntdll.dll, kernelbase.dll and ntoskrnl.exe resolved.
  • Fast Search feature replayed.

Source

import argparse

import reven2
from reven2.address import LogicalAddress
from reven2.arch import x64


'''
# Thread synchronization

## Purpose

Trace calls to Windows Synchronization APIs
  - Critical sections: `EnterCriticalSection`, `LeaveCriticalSection`
  - Condition variables: `WakeConditionVariable`, `SleepConditionVariableCS`
  - `WaitForSingleObject`, `ReleaseMutex`

## How to use

```bash
usage: threadsync.py [-h] [--host HOST] [-p PORT] --cr3 PROC_CR3 --image
                     PROC_IMAGE

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
  --cr3 PROC_CR3        Process cr3
  --image PROC_IMAGE    Process image
```

## Known limitations

N/A

## Supported versions

REVEN 2.8+

## Supported perimeter

Any Windows 10 on x86-64 scenario.

## Dependencies

- OSSI (with symbols `ntdll.dll`, `kernelbase.dll`, `ntoskrnl.exe` resolved) feature replayed.
- Fast Search feature replayed.
'''


class SyncOSSI(object):
    """
    Resolve, once, the OSSI binaries and symbols searched by this script

    Attributes set by the constructor:
     - ki_swap_thread_symbol: ntoskrnl.exe!KiSwapThread
     - rtl_enter_critical_section, rtl_leave_critical_section,
       rtl_wake_condition_variable: symbols of ntdll.dll
     - wait_for_single_object, release_mutex, sleep_condition_variable_cs:
       symbols of kernelbase.dll
     - search_symbol: trace.search.symbol, used by the check_* helpers

    Raises LookupError naming the missing binary or symbol when OSSI cannot
    resolve one of them (instead of an anonymous StopIteration).
    """

    def __init__(self, ossi, trace):
        # basic OSSI
        ntoskrnl = self._binary(ossi, "ntoskrnl.exe")
        self.ki_swap_thread_symbol = self._symbol(ntoskrnl, "KiSwapThread")

        ntdll = self._binary(ossi, "c:/windows/system32/ntdll.dll")
        print(ntdll)
        self.rtl_enter_critical_section = self._symbol(
            ntdll, "RtlEnterCriticalSection"
        )
        print(self.rtl_enter_critical_section)
        self.rtl_leave_critical_section = self._symbol(
            ntdll, "RtlLeaveCriticalSection"
        )
        print(self.rtl_leave_critical_section)
        self.rtl_wake_condition_variable = self._symbol(
            ntdll, "RtlWakeConditionVariable"
        )
        print(self.rtl_wake_condition_variable)

        kernelbase = self._binary(ossi, "c:/windows/system32/kernelbase.dll")
        print(kernelbase)
        self.wait_for_single_object = self._symbol(
            kernelbase, "WaitForSingleObject"
        )
        print(self.wait_for_single_object)
        self.release_mutex = self._symbol(kernelbase, "ReleaseMutex")
        print(self.release_mutex)
        self.sleep_condition_variable_cs = self._symbol(
            kernelbase, "SleepConditionVariableCS"
        )
        print(self.sleep_condition_variable_cs)

        self.search_symbol = trace.search.symbol

    @staticmethod
    def _first(items, description):
        """Return the first element of *items*; raise LookupError naming *description* if empty."""
        for item in items:
            return item
        raise LookupError(
            "{} not found (are OSSI and its symbols available?)".format(description)
        )

    @classmethod
    def _binary(cls, ossi, pattern):
        """Return the first executed binary matching *pattern*."""
        return cls._first(
            ossi.executed_binaries(pattern), "binary {}".format(pattern)
        )

    @classmethod
    def _symbol(cls, binary, name):
        """Return the first symbol of *binary* matching *name*."""
        return cls._first(
            binary.symbols(name), "symbol {} in {}".format(name, binary)
        )


# tool functions
def context_cr3(ctxt):
    """Return the value held by the cr3 register in the context *ctxt*."""
    cr3_value = ctxt.read(x64.cr3)
    return cr3_value


def is_kernel_mode(ctxt):
    """Return True when the privilege level (low two bits of cs) of *ctxt* is 0."""
    privilege_bits = ctxt.read(x64.cs) & 0b11
    return not privilege_bits


def find_process_ranges(rvn, cr3, image):
    """
    Traverse the trace, looking for ranges executed under the given cr3

    The walk starts at the first context where *image* executes and cuts the
    trace at every cr3 change.

    Parameters:
     - rvn: RevenServer
     - cr3: int (value of the cr3 of the process)
     - image: str (name of the image, passed to ossi.executed_binaries)

    Output: yielding (ctxt_low, ctxt_high) ranges whose first context runs
    under *cr3*. Nothing is yielded (a message is printed instead) when the
    image is not found or never executes.
    """
    try:
        binary = next(rvn.ossi.executed_binaries(image))
    except StopIteration:
        print('Image {} not found'.format(image))
        # Stop here: continuing would use an unbound `binary`.
        return

    try:
        ctxt = next(rvn.trace.search.binary(binary))
    except StopIteration:
        print('Binary {} has no execution contexts'.format(binary))
        # Stop here: continuing would use an unbound `ctxt`.
        return

    ctxt_low = ctxt

    while True:
        ctxt_high = ctxt_low.find_register_change(
            x64.cr3, is_forward=True
        )
        if ctxt_high is None:
            # NOTE(review): the last segment (ctxt_low up to the end of the
            # trace) is not yielded even when it runs under `cr3` — confirm
            # whether that is intended.
            break

        if context_cr3(ctxt_low) == cr3:
            yield (ctxt_low, ctxt_high)

        ctxt_low = ctxt_high + 1


def find_usermode_ranges(ctxt_low, ctxt_high):
    """
    Split [ctxt_low, ctxt_high] into the sub-ranges executed in user mode

    The range is cut at every cs change. A segment is kept when its first
    context is not in kernel mode (see is_kernel_mode).

    Parameters:
     - ctxt_low: first context of the range
     - ctxt_high: last context of the range

    Output: yielding (low, high) context pairs, the last high clamped to ctxt_high
    """
    ctxt_current = ctxt_low

    while ctxt_current < ctxt_high:
        ctxt_next = ctxt_current.find_register_change(x64.cs, is_forward=True)

        if not is_kernel_mode(ctxt_current):
            if ctxt_next is None or ctxt_next > ctxt_high:
                yield (ctxt_current, ctxt_high)
                return
            yield (ctxt_current, ctxt_next - 1)
        elif ctxt_next is None:
            # cs never changes again: kernel mode until the end of the trace.
            # Stop here; otherwise ctxt_current would become None and the loop
            # condition `None < ctxt_high` would raise TypeError.
            return

        ctxt_current = ctxt_next


def get_tid(ctxt):
    """Read the 4 bytes at gs:0x48 in *ctxt*, used by this script as the thread id.

    NOTE(review): gs:0x48 is presumably TEB.ClientId.UniqueThread on x64
    Windows; confirm also whether read(..., 4) yields an int or raw bytes.
    """
    tid_location = LogicalAddress(0x48, x64.gs)
    return ctxt.read(tid_location, 4)


def caller_context_in_image(ctxt, image):
    """
    Tell whether the context just before *ctxt* executes inside *image*

    Parameters:
     - ctxt: context at the entry of a searched symbol
     - image: str filename to match, or None to accept every caller

    Output: True when image is None or the previous context's binary has
    filename *image*; False when there is no previous context or no binary.
    """
    if image is None:
        return True

    # NOTE(review): assumes `ctxt - 1` yields None (rather than raising)
    # at the start of the trace — confirm against the reven2 Context API.
    previous_ctxt = ctxt - 1
    binary = None if previous_ctxt is None else previous_ctxt.ossi.location().binary
    return binary is not None and binary.filename == image


def check_critical_section_lock(syncOSSI, ctxt_low, ctxt_high, image):
    """
    Find calls to RtlEnterCriticalSection in [ctxt_low, ctxt_high]

    Output: None when the symbol is never reached in the range, otherwise a
    list of (critical_section_addr, lock_flag, ctxt) for calls whose caller
    passes caller_context_in_image (the list may be empty).
    """
    enter_hits = list(
        syncOSSI.search_symbol(
            syncOSSI.rtl_enter_critical_section,
            from_context=ctxt_low,
            to_context=ctxt_high
        )
    )
    if not enter_hits:
        return None

    results = []
    for hit in enter_hits:
        if caller_context_in_image(hit, image):
            # first argument (the critical section) is passed in rcx
            cs_address = hit.read(x64.rcx)
            # the lock flag sits at offset 0x8 of the critical section
            # (presumably RTL_CRITICAL_SECTION.LockCount on x64 — confirm).
            # NOTE(review): `& 0x1` assumes read(..., 4) yields an int.
            raw_flag = hit.read(LogicalAddress(cs_address + 0x8), 4)
            results.append((cs_address, raw_flag & 0x1, hit))
    return results


def check_critical_section_unlock(syncOSSI, ctxt_low, ctxt_high, image):
    """
    Find calls to RtlLeaveCriticalSection in [ctxt_low, ctxt_high]

    Output: None when the symbol is never reached in the range, otherwise a
    list of (critical_section_addr, ctxt) for calls whose caller passes
    caller_context_in_image (the list may be empty).
    """
    leave_hits = list(
        syncOSSI.search_symbol(
            syncOSSI.rtl_leave_critical_section,
            from_context=ctxt_low,
            to_context=ctxt_high,
        )
    )
    if not leave_hits:
        return None

    # the critical section is the first argument, read from rcx
    return [
        (hit.read(x64.rcx), hit)
        for hit in leave_hits
        if caller_context_in_image(hit, image)
    ]


def check_wait_for_single_object(syncOSSI, ctxt_low, ctxt_high, image):
    """
    Find calls to kernelbase!WaitForSingleObject in [ctxt_low, ctxt_high]

    Output: None when the symbol is never reached in the range, otherwise a
    list of (object_handle, ctxt) for calls whose caller passes
    caller_context_in_image (the list may be empty).
    """
    wait_hits = list(
        syncOSSI.search_symbol(
            syncOSSI.wait_for_single_object,
            from_context=ctxt_low,
            to_context=ctxt_high,
        )
    )
    if not wait_hits:
        return None

    # the waited handle is the first argument, read from rcx
    return [
        (hit.read(x64.rcx), hit)
        for hit in wait_hits
        if caller_context_in_image(hit, image)
    ]


def check_mutex_release(syncOSSI, ctxt_low, ctxt_high, image):
    """
    Find calls to kernelbase!ReleaseMutex in [ctxt_low, ctxt_high]

    Output: None when the symbol is never reached in the range, otherwise a
    list of (mutex_handle, ctxt) for calls whose caller passes
    caller_context_in_image (the list may be empty).
    """
    release_hits = list(
        syncOSSI.search_symbol(
            syncOSSI.release_mutex,
            from_context=ctxt_low,
            to_context=ctxt_high,
        )
    )
    if not release_hits:
        return None

    # the mutex handle is the first argument, read from rcx
    return [
        (hit.read(x64.rcx), hit)
        for hit in release_hits
        if caller_context_in_image(hit, image)
    ]


def check_sleep_condition_variable(syncOSSI, ctxt_low, ctxt_high, image):
    """
    Find calls to kernelbase!SleepConditionVariableCS in [ctxt_low, ctxt_high]

    Output: None when the symbol is never reached in the range, otherwise a
    list of (condition_variable_addr, critical_section_addr, ctxt) for calls
    whose caller passes caller_context_in_image (the list may be empty).
    """
    sleep_hits = list(
        syncOSSI.search_symbol(
            syncOSSI.sleep_condition_variable_cs,
            from_context=ctxt_low,
            to_context=ctxt_high,
        )
    )
    if not sleep_hits:
        return None

    # first argument (condition variable) in rcx, second (critical section) in rdx
    return [
        (hit.read(x64.rcx), hit.read(x64.rdx), hit)
        for hit in sleep_hits
        if caller_context_in_image(hit, image)
    ]


def check_rtl_wake_condition_variable(syncOSSI, ctxt_low, ctxt_high, image):
    """
    Find calls to ntdll!RtlWakeConditionVariable in [ctxt_low, ctxt_high]

    Output: None when the symbol is never reached in the range, otherwise a
    list of (condition_variable_addr, ctxt) for calls whose caller passes
    caller_context_in_image (the list may be empty).
    """
    rtl_wake_cond_ctxts = list(
        syncOSSI.search_symbol(
            syncOSSI.rtl_wake_condition_variable,
            from_context=ctxt_low,
            to_context=ctxt_high,
        )
    )

    if not rtl_wake_cond_ctxts:
        return None

    rtl_wake_cond = []
    for ctxt in rtl_wake_cond_ctxts:
        if not caller_context_in_image(ctxt, image):
            continue

        # the condition variable is the first argument, passed in rcx
        cond_var = ctxt.read(x64.rcx)
        rtl_wake_cond.append((cond_var, ctxt))
    # Return the collected list: without this the function yields None and the
    # caller never reports any wake event.
    return rtl_wake_cond


def _report_events(events, message):
    """
    Print one tab-indented line per event and tell whether anything was printed

    Parameters:
     - events: None, or a list of tuples whose last item is the call context
     - message: format string; the leading tuple items fill its first fields
       and the transition id of the call fills the last one

    Output: True when at least one event was printed
    """
    if not events:
        return False
    for *values, event_ctxt in events:
        transition_id = event_ctxt.transition_after().id
        print("\t" + message.format(*values, transition_id))
    return True


def check_lock_unlock(syncOSSI, ctxt_low, ctxt_high, image):
    """
    Print every traced synchronization call in the range [ctxt_low, ctxt_high]

    Parameters:
     - syncOSSI: SyncOSSI (resolved symbols and the symbol search function)
     - ctxt_low, ctxt_high: bounds of a user-mode range
     - image: str filename used to filter callers, or None to keep every call

    Prints a header with the transition range and thread id, one line per
    call found, or "nothing found".
    """
    transition_low = ctxt_low.transition_after().id
    transition_high = ctxt_high.transition_after().id

    tid = get_tid(ctxt_low)

    print(
        "\n==== checking the transition range [{}, {}] (TID: {}) ====".format(
            transition_low, transition_high, tid
        )
    )

    # The check_* helpers return None when the symbol is never reached and a
    # possibly-empty list after caller filtering: only a non-empty result
    # counts as "found".
    found = False

    locks = check_critical_section_lock(syncOSSI, ctxt_low, ctxt_high, image)
    if locks:
        found = True
        for (cs, flag, lock_ctxt) in locks:
            transition_lock = lock_ctxt.transition_after().id
            # flag == 0: the critical section is already locked, so the thread
            # would be blocked; otherwise the thread locks the section.
            state = "trylock" if flag == 0 else "locked"
            print(
                "\tcritical section 0x{:x} {} at {}".format(
                    cs, state, transition_lock
                )
            )

    found |= _report_events(
        check_critical_section_unlock(syncOSSI, ctxt_low, ctxt_high, image),
        "critical section 0x{:x} unlocked at {}",
    )
    found |= _report_events(
        check_wait_for_single_object(syncOSSI, ctxt_low, ctxt_high, image),
        "wait for object 0x{:x} at {}",
    )
    found |= _report_events(
        check_mutex_release(syncOSSI, ctxt_low, ctxt_high, image),
        "release mutex 0x{:x} at {}",
    )
    found |= _report_events(
        check_sleep_condition_variable(syncOSSI, ctxt_low, ctxt_high, image),
        "sleep on condition variable 0x{:x} with critical section 0x{:x} at {}",
    )
    found |= _report_events(
        check_rtl_wake_condition_variable(syncOSSI, ctxt_low, ctxt_high, image),
        "wake condition variable 0x{:x} at {}",
    )

    if not found:
        print(
            "\tnothing found"
        )


def run(rvn, proc_cr3, proc_image):
    """
    Report synchronization calls for every user-mode range of the process

    Parameters:
     - rvn: RevenServer
     - proc_cr3: int (cr3 of the process)
     - proc_image: str (image used to locate the process in the trace)
    """
    sync_ossi = SyncOSSI(rvn.ossi, rvn.trace)
    for proc_low, proc_high in find_process_ranges(rvn, proc_cr3, proc_image):
        for user_low, user_high in find_usermode_ranges(proc_low, proc_high):
            # image=None disables the caller filter in the check_* helpers;
            # proc_image is only used above to locate the process.
            check_lock_unlock(sync_ossi, user_low, user_high, None)


def _parse_int(text):
    """Parse an int given in decimal or with a 0x/0o/0b prefix (e.g. a cr3 like 0x1aa000)."""
    try:
        return int(text, 0)
    except ValueError:
        # int(s, 0) rejects decimal literals with leading zeros ("0123");
        # keep accepting them as plain decimal like type=int did.
        return int(text, 10)


def parse_args():
    """Parse the command line: Reven host/port, process cr3 and process image."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help='Reven host, as a string (default: "localhost")',
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=13370,
        help="Reven port, as an int (default: 13370)",
    )
    parser.add_argument(
        "--cr3",
        metavar="PROC_CR3",
        # cr3 values are usually written in hex; accept both hex and decimal.
        type=_parse_int,
        required=True,
        help="Process cr3 (decimal or 0x-prefixed hex)",
    )
    parser.add_argument(
        "--image",
        metavar="PROC_IMAGE",
        type=str,
        required=True,
        help="Process image",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Connect to the Reven server given on the command line, then analyse the process.
    options = parse_args()
    run(reven2.RevenServer(options.host, options.port), options.cr3, options.image)