Thread id

Purpose

Detect the current thread at a point in the trace and find when the thread is created.

How to use

usage: thread_id.py [-h] [--host HOST] [-p PORT] TRANSITION_ID

positional arguments:
  TRANSITION_ID         Get thread id at transition (before)

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)

Known limitations

  • Current thread is not detected if the given point is in ring0.

Supported versions

REVEN 2.8+

Supported perimeter

Any Windows 10 x86-64 scenario. The given point must be in a 64-bit process.

Dependencies

The script requires that the target REVEN scenario have: * The Fast Search feature replayed. * The OSSI feature replayed. * Access to the binary 'ntdll.dll' and its PDB file.

Source

import argparse

import reven2
from reven2.address import LogicalAddress
from reven2.arch import x64

"""
# Thread id

## Purpose

Detect the current thread at a point in the trace
and find when the thread is created.

## How to use

```bash
usage: thread_id.py [-h] [--host HOST] [-p PORT] TRANSITION_ID

positional arguments:
  TRANSITION_ID         Get thread id at transition (before)

optional arguments:
  -h, --help            show this help message and exit
  --host HOST           Reven host, as a string (default: "localhost")
  -p PORT, --port PORT  Reven port, as an int (default: 13370)
```

## Known limitations

- Current thread is not detected if the given point is in ring0.

## Supported versions

REVEN 2.8+

## Supported perimeter

Any Windows 10 x86-64 scenario. The given point must be in a 64-bit process.

## Dependencies

The script requires that the target REVEN scenario have:
    * The Fast Search feature replayed.
    * The OSSI feature replayed.
    * Access to the binary 'ntdll.dll' and its PDB file.
"""


class ThreadInfo(object):
    """Identity of the thread executing at a given REVEN context.

    A thread is identified by the triple ``(cr3, pid, tid)``:

    - ``cr3``: value of the CR3 register in the context;
    - ``pid``: 4-byte value read at ``gs:0x40``;
    - ``tid``: 4-byte value read at ``gs:0x48``.

    NOTE(review): in Windows x64 user mode ``gs`` presumably points to the TEB,
    and these offsets presumably correspond to ``ClientId.UniqueProcess`` and
    ``ClientId.UniqueThread``. Confirm against the TEB layout. This is why the
    script refuses ring 0 contexts.

    Instances compare equal when the three values match, and they are hashable
    so they can be used in sets and as dict keys.
    """

    def __init__(self, ctxt):
        self.cr3 = ctxt.read(x64.cr3)
        self.pid = ctxt.read(LogicalAddress(0x40, x64.gs), 4)
        self.tid = ctxt.read(LogicalAddress(0x48, x64.gs), 4)

    def _key(self):
        """Return the tuple used for equality and hashing."""
        return (self.cr3, self.pid, self.tid)

    def __eq__(self, other):
        # Let Python fall back to its default comparison for foreign types
        # instead of raising AttributeError on missing attributes.
        if not isinstance(other, ThreadInfo):
            return NotImplemented
        return self._key() == other._key()

    def __ne__(self, other):
        # Kept explicit for Python 2 compatibility (``class ...(object)``).
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Must agree with __eq__: equal objects hash equally.
        return hash(self._key())

    def __repr__(self):
        return "ThreadInfo(cr3={!r}, pid={!r}, tid={!r})".format(
            self.cr3, self.pid, self.tid
        )


def context_ring(ctxt):
    """Return the privilege ring (0-3) the CPU is in at ``ctxt``.

    The ring is taken from the two low bits of the CS segment selector.
    """
    cs_selector = ctxt.read(x64.cs)
    return cs_selector & 0b11


def all_start_thread_calls(ossi, trace):
    """Return a trace search over every call to ``ntdll!RtlUserThreadStart``.

    Args:
        ossi: the OSSI object of the REVEN server (``rvn.ossi``).
        trace: the trace object of the REVEN server (``rvn.trace``).

    Returns:
        Whatever ``trace.search.symbol`` returns for the ``RtlUserThreadStart``
        symbol. The caller iterates over it to get one context per call.

    Raises:
        RuntimeError: if ``c:/windows/system32/ntdll.dll`` is not among the
            executed binaries, or if ``RtlUserThreadStart`` is not found in it.
            Previously a bare ``StopIteration`` leaked out in both cases.
    """
    ntdll_dll = next(ossi.executed_binaries("c:/windows/system32/ntdll.dll"), None)
    if ntdll_dll is None:
        raise RuntimeError(
            "'c:/windows/system32/ntdll.dll' was not found among the executed binaries"
        )

    rtl_user_thread_start = next(ntdll_dll.symbols("RtlUserThreadStart"), None)
    if rtl_user_thread_start is None:
        raise RuntimeError(
            "symbol 'RtlUserThreadStart' was not found in ntdll.dll "
            "(is its PDB file available?)"
        )

    return trace.search.symbol(rtl_user_thread_start)


def thread_search_pc(trace, thread_info, pc, from_context=None, to_context=None):
    """Lazily yield the contexts where ``pc`` is executed by a given thread.

    Runs ``trace.search.pc`` over ``[from_context, to_context]`` and keeps only
    the hits whose ``ThreadInfo`` equals ``thread_info``. This is a generator,
    so the underlying search starts on the first ``next()``.
    """
    for candidate in trace.search.pc(pc, from_context=from_context, to_context=to_context):
        # Skip hits executed by any other thread.
        if ThreadInfo(candidate) != thread_info:
            continue
        yield candidate


def find_thread_starting_transition(rvn, thread_info):
    """Find where the thread described by ``thread_info`` begins executing its code.

    Each call to ``RtlUserThreadStart`` made in the requested thread is
    examined in turn. The thread start address is read from ``rcx`` (the first
    argument), and the first execution of that address by the same thread,
    searched from the call onward, gives the answer.

    Returns:
        The transition executed from that first matching context, or ``None``
        when no such context is found.
    """
    for start_call in all_start_thread_calls(rvn.ossi, rvn.trace):
        if ThreadInfo(start_call) != thread_info:
            continue

        # The first argument of RtlUserThreadStart is the thread's start address.
        entry_point = start_call.read(x64.rcx)
        first_hit = next(
            iter(
                thread_search_pc(
                    rvn.trace,
                    thread_info,
                    pc=entry_point,
                    from_context=start_call,
                )
            ),
            None,
        )
        if first_hit is not None:
            return first_hit.transition_after()

    return None


def print_thread_info(rvn, tr_id):
    """Print the TID/PID of the thread running at transition ``tr_id`` and where it starts.

    The thread is read from the context *before* transition ``tr_id``. Ring 0
    contexts are not supported (see "Known limitations"); a message is printed
    instead.
    """
    context = rvn.trace.transition(tr_id).context_before()

    if context_ring(context) == 0:
        print("(User) thread may not count in ring 0")
        return

    # Identify the thread executing at the requested point.
    thread = ThreadInfo(context)
    start = find_thread_starting_transition(rvn, thread)

    if start is not None:
        message = "TID: {thread.tid} (PID: {thread.pid}), starts at: {transition}".format(
            thread=thread, transition=start
        )
    else:
        message = "TID: {thread.tid} (PID: {thread.pid}) starting transition not found".format(
            thread=thread
        )
    print(message)


def parse_args(argv=None):
    """Parse the script's command-line arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with ``host`` (str), ``port`` (int) and
        ``transition_id`` (int).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help='Reven host, as a string (default: "localhost")',
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        # An int default, rather than relying on argparse converting a string default.
        default=13370,
        help="Reven port, as an int (default: 13370)",
    )
    parser.add_argument(
        "transition_id",
        metavar="TRANSITION_ID",
        type=int,
        help="Get thread id at transition (before)",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Connect to the REVEN server given on the command line, then report the
    # thread executing at the requested transition.
    args = parse_args()
    server = reven2.RevenServer(args.host, args.port)
    print_thread_info(server, args.transition_id)