Detecting critical section deadlocks

This notebook checks for potential deadlocks that may occur when using critical sections as the synchronization primitive.

Prerequisites

  • This notebook should be run in a jupyter notebook server equipped with a REVEN 2 python kernel. REVEN comes with a jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • The following resources must be replayed for the analyzed scenario:
  • Trace
  • OSSI
  • Fast Search

Limits

  • Only supports RtlEnterCriticalSection and RtlLeaveCriticalSection as the lock (resp. unlock) primitives.
  • Locks on the same critical section must not be nested: a critical section can be locked and then unlocked any number of times, but it must not be locked again while it is still locked, e.g.
  • lock A, lock B, unlock A, lock A => OK
  • lock A, lock B, lock A => not OK (since A is locked again)

Running

Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.

Source

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.11.2
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Detecting critical section deadlocks
# This notebook checks for potential deadlocks that may occur when using critical sections as the synchronization
# primitive.
#
# ## Prerequisites
#  - This notebook should be run in a jupyter notebook server equipped with a REVEN 2 python kernel.
# REVEN comes with a jupyter notebook server accessible with the `Open Python` button
# in the `Analyze` page of any scenario.
#  - The following resources must be replayed for the analyzed scenario:
#     - Trace
#     - OSSI
#     - Fast Search
#
# ## Limits
#  - Only supports `RtlEnterCriticalSection` and `RtlLeaveCriticalSection` as the lock (resp. unlock) primitives.
#  - Locks on the same critical section must not be nested: a critical section can be locked and then unlocked any
# number of times, but it must not be locked again while it is still locked, e.g.
#     - lock A, lock B, unlock A, lock A => OK
#     - lock A, lock B, lock A => not OK (since A is locked again)
#
# ## Running
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.

# %%
# For Python's type annotation
from typing import Dict, Iterator, List, Optional, Set, Tuple

# Reven specific
import reven2
from reven2.address import LogicalAddress
from reven2.arch import x64
from reven2.trace import Context, Trace, Transition
from reven2.util import collate

# %% [markdown]
# # Parameters

# %%
# Host and port of the REVEN server running the scenario
host = '127.0.0.1'
port = 35083


# The PID and (optionally) the name of the binary of interest: if the binary name is given (not None),
# only locks and unlocks called directly from the binary are counted.
pid = 2044
binary = None

# The begin and end transition numbers between which the deadlock detection runs.
# If None, start from the first context found in `binary` for this PID when `binary` is given,
# otherwise from the first transition of the trace.
begin_trans_id = None
end_trans_id = None    # if None, stop at the last transition


# %%
# Helper class which wraps REVEN's runtime objects and provides methods to get information about calls to
# RtlEnterCriticalSection and RtlLeaveCriticalSection
class RuntimeHelper:
    # Connect to the scenario at host:port and resolve the two ntdll symbols used as lock/unlock primitives.
    # Raises RuntimeError if the connection fails, or if ntdll.dll or one of the symbols cannot be found.
    def __init__(self, host: str, port: int):
        try:
            server = reven2.RevenServer(host, port)
        except RuntimeError as error:
            raise RuntimeError(f'Cannot connect to the scenario at {host}:{port}') from error

        self.trace = server.trace
        self.ossi = server.ossi
        self.search_symbol = self.trace.search.symbol
        self.search_binary = self.trace.search.binary

        try:
            ntdll = next(self.ossi.executed_binaries('^c:/windows/system32/ntdll.dll$'))
        except StopIteration as error:
            raise RuntimeError('ntdll.dll not found') from error

        try:
            self.__rtl_enter_critical_section = next(ntdll.symbols("^RtlEnterCriticalSection$"))
            self.__rtl_leave_critical_section = next(ntdll.symbols("^RtlLeaveCriticalSection$"))
        except StopIteration as error:
            raise RuntimeError('Rtl(Enter|Leave)CriticalSection symbol not found') from error

    # Yield, for every context found by the symbol search between from_context and to_context, the context
    # just before it (i.e. the context from which the symbol was reached).
    def _get_symbol_call_sites(self, symbol, from_context: Context, to_context: Context) -> Iterator[Context]:
        # For ease, if from_context is the first context of the trace and also the entry of the symbol,
        # then omit it since the caller is unknown
        if from_context == self.trace.first_context:
            if from_context == self.trace.last_context:
                return
            from_context = from_context + 1

        for ctxt in self.search_symbol(symbol, from_context, to_context):
            # Any context found is an entry of the symbol; since the first context of the trace is never
            # searched, decreasing by 1 is safe.
            yield ctxt - 1

    # Yield the contexts just before each entry of RtlEnterCriticalSection in [from_context, to_context]
    def get_critical_section_locks(self, from_context: Context, to_context: Context) -> Iterator[Context]:
        return self._get_symbol_call_sites(self.__rtl_enter_critical_section, from_context, to_context)

    # Yield the contexts just before each entry of RtlLeaveCriticalSection in [from_context, to_context]
    def get_critical_section_unlocks(self, from_context: Context, to_context: Context) -> Iterator[Context]:
        return self._get_symbol_call_sites(self.__rtl_leave_critical_section, from_context, to_context)

    # Value of rcx at the given context; at a lock/unlock call site this is used as the critical section handle.
    @staticmethod
    def get_critical_section_handle(ctxt: Context) -> int:
        return ctxt.read(x64.rcx)

    # 4-byte value read at gs:0x48, used throughout this notebook as the id of the current thread.
    # NOTE(review): presumably the thread id field of the TEB - confirm against the Windows version analyzed.
    @staticmethod
    def thread_id(ctxt: Context) -> int:
        return ctxt.read(LogicalAddress(0x48, x64.gs), 4)

    # True if the two low bits of cs are zero at the given context.
    @staticmethod
    def is_kernel_mode(ctxt: Context) -> bool:
        return (ctxt.read(x64.cs) & 0x3) == 0


# %%
# Compute the (begin, end) pair of contexts between which the deadlock detection runs.
#  - begin: the context after transition `begin_id` when it exists; otherwise the first context found in `binary`
#    that belongs to process `pid` (when `binary` is given); otherwise the first context of the trace.
#  - end: the context before transition `end_id` when it exists; otherwise the last context of the trace.
# Raises RuntimeError if the end context is not strictly after the begin context.
def find_begin_end_context(sco: RuntimeHelper, pid: int, binary: Optional[str],
                           begin_id: Optional[int], end_id: Optional[int]) -> Tuple[Context, Context]:
    def first_context_of_process_in_binary() -> Optional[Context]:
        if binary is None:
            return None
        for executed_binary in sco.ossi.executed_binaries(binary):
            for candidate in sco.search_binary(executed_binary):
                if candidate.ossi.process().pid == pid:
                    return candidate
        return None

    lower_bound = None
    if begin_id is not None:
        try:
            lower_bound = sco.trace.transition(begin_id).context_after()
        except IndexError:
            # unknown transition id: fall back to the defaults below
            lower_bound = None
    if lower_bound is None:
        lower_bound = first_context_of_process_in_binary()
    if lower_bound is None:
        lower_bound = sco.trace.first_context

    upper_bound = None
    if end_id is not None:
        try:
            upper_bound = sco.trace.transition(end_id).context_before()
        except IndexError:
            # unknown transition id: fall back to the end of the trace
            upper_bound = None
    if upper_bound is None:
        upper_bound = sco.trace.last_context

    if upper_bound <= lower_bound:
        raise RuntimeError("The begin transition must be smaller than the end.")

    return (lower_bound, upper_bound)


# Get all execution context ranges of a given process between first_ctxt and last_context.
# The trace is split at each change of cr3; a range is yielded when the process running at its first context
# has the requested pid. The last range is clamped to last_context - 1.
def find_process_ranges(sco: RuntimeHelper, pid: int, first_ctxt: Context, last_context: Context) \
        -> Iterator[Tuple[Context, Context]]:
    ctxt_low = first_ctxt
    ctxt_high: Optional[Context] = None

    while True:
        # NOTE(review): process() is guarded against None in case the OSSI cannot resolve the process;
        # such a range is treated as not belonging to `pid`.
        process = ctxt_low.ossi.process()
        current_pid = process.pid if process is not None else None
        ctxt_high = ctxt_low.find_register_change(x64.cr3, is_forward=True)

        if ctxt_high is None or ctxt_high >= last_context:
            # No more cr3 change before the end bound: the current range runs up to the end bound. It must be
            # reported in both cases, otherwise the tail of the process execution would be silently dropped
            # whenever cr3 changes somewhere after last_context.
            if current_pid == pid and ctxt_low < last_context - 1:
                yield (ctxt_low, last_context - 1)
            break

        if current_pid == pid:
            yield (ctxt_low, ctxt_high)

        ctxt_low = ctxt_high + 1


# Starting from `trans`, return the first transition that carries an instruction, or None if the last
# transition of the trace is reached without finding one.
def ignore_non_instructions(trans: Transition, trace: Trace) -> Optional[Transition]:
    candidate = trans
    while True:
        if candidate.instruction is not None:
            return candidate
        if candidate == trace.last_transition:
            # ran out of transitions
            return None
        candidate = candidate + 1


# Split the context range [ctxt_low, ctxt_high] at each change of cs and yield only the sub-ranges
# that do not start in kernel mode.
def find_usermode_ranges(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Context) \
        -> Iterator[Tuple[Context, Context]]:
    first_instruction = ignore_non_instructions(ctxt_low.transition_after(), sco.trace)
    if first_instruction is None:
        return

    cursor = first_instruction.context_before()
    while cursor < ctxt_high:
        cs_change = cursor.find_register_change(x64.cs, is_forward=True)

        if not RuntimeHelper.is_kernel_mode(cursor):
            if cs_change is None or cs_change > ctxt_high:
                # user mode until the end of the range: this is the last sub-range
                yield (cursor, ctxt_high)
                return
            # It's safe to decrease cs_change by 1 because it was obtained from a forward find_register_change
            yield (cursor, cs_change - 1)

        if cs_change is None:
            return

        cursor = cs_change


# For every execution range of process `pid` between first_ctxt and last_ctxt, yield its user mode sub-ranges.
def find_process_usermode_ranges(trace: RuntimeHelper, pid: int, first_ctxt: Context, last_ctxt: Context) \
        -> Iterator[Tuple[Context, Context]]:
    for process_range in find_process_ranges(trace, pid, first_ctxt, last_ctxt):
        range_begin, range_end = process_range
        yield from find_usermode_ranges(trace, range_begin, range_end)


# Tell whether a context executes inside the given binary.
#  - no context => False
#  - no binary filter (None) => True
#  - otherwise True only if the OSSI location of the context has a binary whose name, filename or path
#    equals `binary`.
def context_is_in_binary(ctxt: Optional[Context], binary: Optional[str]) -> bool:
    if ctxt is None:
        return False
    if binary is None:
        return True

    location = ctxt.ossi.location()
    owner = None if location is None else location.binary
    if owner is None:
        return False

    return binary in (owner.name, owner.filename, owner.path)


# Yield the lock call sites (i.e. calls to RtlEnterCriticalSection) found between ctxt_low and ctxt_high
# that are located in the binary.
# Note that a process also runs locks called by the libraries it loads; when a binary name is given, those
# calls are considered "uninteresting" and are filtered out. When `binary` is None, every lock is kept.
def get_in_binary_locks(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Context, binary: Optional[str]) \
        -> Iterator[Context]:
    yield from (
        lock_ctxt
        for lock_ctxt in sco.get_critical_section_locks(ctxt_low, ctxt_high)
        if context_is_in_binary(lock_ctxt, binary)
    )


# Yield the unlock call sites (i.e. calls to RtlLeaveCriticalSection) found between ctxt_low and ctxt_high
# that are located in the binary.
# Note that a process also runs unlocks called by the libraries it loads; when a binary name is given, those
# calls are considered "uninteresting" and are filtered out. When `binary` is None, every unlock is kept.
def get_in_binary_unlocks(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Context, binary: Optional[str]) \
        -> Iterator[Context]:
    yield from (
        unlock_ctxt
        for unlock_ctxt in sco.get_critical_section_unlocks(ctxt_low, ctxt_high)
        if context_is_in_binary(unlock_ctxt, binary)
    )


# Combine the lock and unlock call sites of a range of contexts into a single stream, using `collate`
# keyed on the context so that both kinds of events come out in context order.
# Return an iterator of pairs (bool, Context): True for lock, False for unlock.
def get_in_binary_locks_unlocks(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Context, binary: Optional[str]) \
        -> Iterator[Tuple[bool, Context]]:
    tagged_locks = ((True, lock_ctxt) for lock_ctxt in get_in_binary_locks(sco, ctxt_low, ctxt_high, binary))
    tagged_unlocks = ((False, unlock_ctxt)
                      for unlock_ctxt in get_in_binary_unlocks(sco, ctxt_low, ctxt_high, binary))

    return collate([tagged_locks, tagged_unlocks], key=lambda tagged: tagged[1])


# Chain, range after range, all locks and unlocks called by the binary in the given context ranges.
def get_thread_usermode_in_binary_locks_unlocks(sco: RuntimeHelper, ranges: List[Tuple[Context, Context]],
                                                binary: Optional[str]) -> Iterator[Tuple[bool, Context]]:
    for context_range in ranges:
        range_begin, range_end = context_range
        yield from get_in_binary_locks_unlocks(sco, range_begin, range_end, binary)


# %%
# Check that, inside one thread, critical sections are unlocked in the reverse order of their locks (the order
# that RAII-style lock guards naturally produce), and warn when this is violated since it is good practice.
#
# In synchronization with critical sections, unlocking in reverse order is an idiomatic technique used to keep the
# lock order consistent, which avoids deadlocks when there is a single total order of locks between threads.
# For example, the following threads, where the lock/unlock of critical sections A and B follow that discipline
#  - thread 0: lock A, lock B, unlock B, unlock A, lock A, lock B
#  - thread 1: lock A, lock B, unlock B, unlock A, lock A, lock B
# are deadlock free. But
#  - thread 0: lock A, lock B, unlock A, lock A, unlock B
#  - thread 1: lock A, lock B, unlock A, lock A, unlock B
# can deadlock. Let's consider the interleaving:
#  - lock A, lock B, unlock A (thread 0), lock A (thread 1)
# now thread 1 tries to lock B (but it cannot since B is already locked by thread 0), and thread 0 cannot unlock B
# either since it needs to lock A first (but it cannot since A is already locked by thread 1).
#
# Note that this discipline cannot guarantee deadlock-free synchronization if the condition about a single total
# order of critical sections is not satisfied.
#
# `thread_ranges` are the context ranges of one thread. Warnings are printed once per distinct pair of
# (lock address, unlock address), or per unlock address for unlocks without any lock; 'OK' is printed when no
# warning was emitted.
def check_intrathread_lock_unlock_matching(trace: RuntimeHelper, thread_ranges: List[Tuple[Context, Context]],
                                           binary: Optional[str]):
    locks_unlocks = get_thread_usermode_in_binary_locks_unlocks(trace, thread_ranges, binary)

    # stack of the lock contexts that have not been matched by an unlock yet
    corresponding_lock_stack: List[Context] = []
    mismatch_lock_unlock_pcs: Set[Tuple[int, int]] = set()
    mismatch_unlocks_pcs: Set[int] = set()
    ok = True
    for (is_lock, ctxt) in locks_unlocks:
        if is_lock:
            # push lock context
            corresponding_lock_stack.append(ctxt)
            continue

        # Lock/unlock contexts are the contexts just before the call, so rip is the address of the call site
        # and rcx already holds the critical section handle.
        current_unlock_handle = RuntimeHelper.get_critical_section_handle(ctxt)
        in_binary_unlock_pc = ctxt.read(x64.rip)

        if not corresponding_lock_stack:
            # unlock while no lock is pending
            if in_binary_unlock_pc in mismatch_unlocks_pcs:
                continue

            mismatch_unlocks_pcs.add(in_binary_unlock_pc)

            print(f'Warning:\n\t#{ctxt.transition_after().id}: unlock at 0x{in_binary_unlock_pc:x} '
                  f'(on 0x{current_unlock_handle:x}) without any lock')
            ok = False
            continue

        last_lock_ctxt = corresponding_lock_stack[-1]
        last_lock_handle = RuntimeHelper.get_critical_section_handle(last_lock_ctxt)

        if last_lock_handle == current_unlock_handle:
            # lock and unlock on the same critical section
            corresponding_lock_stack.pop()
            continue

        # unlock of a critical section which is not the most recently locked one
        in_binary_lock_pc = last_lock_ctxt.read(x64.rip)
        if (in_binary_lock_pc, in_binary_unlock_pc) in mismatch_lock_unlock_pcs:
            continue

        mismatch_lock_unlock_pcs.add((in_binary_lock_pc, in_binary_unlock_pc))

        print(f'Warning:\n\t#{last_lock_ctxt.transition_after().id}: lock at 0x{in_binary_lock_pc:x} '
              f'(on critical section handle 0x{last_lock_handle:x}) followed by\n\t'
              f'#{ctxt.transition_after().id}: '
              f'unlock at 0x{in_binary_unlock_pc:x} '
              f'(on different critical section handle 0x{current_unlock_handle:x})')

        ok = False
    if ok:
        print('OK')


# Build a dependency graph of locks: a lock A is followed by a lock B if, in the same thread, A is still locked
# when B is locked.
# For example:
#  - lock A, lock B => A followed by B
#  - lock A, unlock A, lock B => A is not followed by B (since A is already unlocked when B is locked)
#
# Return a pair (graph, labels):
#  - graph maps a critical section handle to the list of handles locked right after it (duplicates possible),
#  - labels maps an edge (handle, next handle) to the list of (lock context, next lock context) pairs that
#    produced it.
def build_locks_unlocks_order_graph_next(locks_unlocks: Iterator[Tuple[bool, Context]]) \
        -> Tuple[Dict[int, List[int]], Dict[Tuple[int, int], List[Tuple[Context, Context]]]]:

    order_graph: Dict[int, List[int]] = {}
    order_graph_label: Dict[Tuple[int, int], List[Tuple[Context, Context]]] = {}

    # Locks currently held, oldest first, as (lock context, critical section handle, thread id).
    # The handle and the thread id are read once and kept here to avoid reading them again from the trace.
    held_locks: List[Tuple[Context, int, int]] = []
    for (is_lock, ctxt) in locks_unlocks:
        if not is_lock and not held_locks:
            # unlock while nothing is tracked as locked: nothing to release
            continue

        current_handle = RuntimeHelper.get_critical_section_handle(ctxt)
        current_threadid = RuntimeHelper.thread_id(ctxt)

        if not is_lock:
            # release the most recent lock of the same critical section taken by the same thread, if any
            for i in range(len(held_locks) - 1, -1, -1):
                (_, held_handle, held_threadid) = held_locks[i]
                if held_handle == current_handle and held_threadid == current_threadid:
                    del held_locks[i]
                    break
            continue

        # The events of all threads are interleaved, so the most recent held lock may belong to another thread:
        # look for the most recent lock still held by the *current* thread.
        previous_lock = next((held for held in reversed(held_locks) if held[2] == current_threadid), None)
        if previous_lock is not None:
            (previous_ctxt, previous_handle, _) = previous_lock

            # create the edge: previous_lock -> current_lock
            order_graph.setdefault(previous_handle, []).append(current_handle)

            # create (or update) the label of the edge
            order_graph_label.setdefault((previous_handle, current_handle), []).append((previous_ctxt, ctxt))

        held_locks.append((ctxt, current_handle, current_threadid))

    return (order_graph, order_graph_label)


# Check if there are cycles in the lock dependency graph; such a cycle is considered a potential deadlock.
#
# `graph` and `labels` are the values returned by build_locks_unlocks_order_graph_next. A cycle is a list of edges
# (handle, next handle, thread id) in which every edge comes from a different thread. Each distinct cycle
# (up to rotation) is printed once; 'Not found.' is printed when there is none.
def check_order_graph_cycle(graph: Dict[int, List[int]], labels: Dict[Tuple[int, int], List[Tuple[Context, Context]]]):
    # Thread ids of the locks that produced each edge, read from the trace only once per edge.
    edge_tids_cache: Dict[Tuple[int, int], Set[int]] = {}

    def edge_tids(hd: int, tl: int) -> Set[int]:
        if (hd, tl) not in edge_tids_cache:
            edge_tids_cache[(hd, tl)] = {RuntimeHelper.thread_id(ctxt_hd) for (ctxt_hd, _) in labels[(hd, tl)]}
        return edge_tids_cache[(hd, tl)]

    # Depth-first search of cycles. `path` is the chain of edges followed so far, `starting_node` its last node
    # and `visited_nodes` the nodes already entered (shared by the whole search from one root).
    # Return the list of cycles found when closing the path, or None.
    def dfs(path: List[Tuple[int, int, int]], starting_node: int, visited_nodes: Set[int]) \
            -> Optional[List[List[Tuple[int, int, int]]]]:
        if starting_node not in graph:
            return None

        next_nodes_to_visit = set(graph[starting_node]) - visited_nodes
        if not next_nodes_to_visit:
            return None

        nodes_on_path = set()
        tids_on_path = set()
        for (hd, tl, tid) in path:
            nodes_on_path.add(hd)
            nodes_on_path.add(tl)
            tids_on_path.add(tid)

        # check if we can build a cycle of locks by trying visiting a node
        back_nodes = next_nodes_to_visit & nodes_on_path
        for back_node in back_nodes:
            tids_starting_back = edge_tids(starting_node, back_node)

            # sub-path starting from the back-node to the starting-node
            sub_path = []
            sub_path_tids = set()
            node = back_node
            for (hd, tl, tid) in path:
                if hd == node:
                    sub_path.append((hd, tl, tid))
                    sub_path_tids.add(tid)
                    node = tl

                if tl == starting_node:
                    # the closing edge must come from a thread which is not on the sub-path yet
                    diff_tids = tids_starting_back - sub_path_tids
                    if not diff_tids:
                        return None
                    # one cycle per closing thread; each cycle is its own list so that they stay distinct
                    return [sub_path + [(starting_node, back_node, closing_tid)]
                            for closing_tid in sorted(diff_tids)]

        for next_node in next_nodes_to_visit:
            # only follow the edge through threads that are not on the path yet
            tids_to_visit = edge_tids(starting_node, next_node) - tids_on_path
            for tid_to_visit in tids_to_visit:
                visited_nodes.add(next_node)
                # extend a copy of the path, so that sibling branches are not polluted by this one
                next_path = path + [(starting_node, next_node, tid_to_visit)]
                some_cycles = dfs(next_path, next_node, visited_nodes)
                if some_cycles is not None:
                    return some_cycles

        return None

    # Two cycles are the same if one is a rotation of the other.
    def compare_cycles(c0: List[Tuple[int, int, int]], c1: List[Tuple[int, int, int]]) -> bool:
        if len(c0) != len(c1):
            return False

        if len(c0) == 0:
            return True

        try:
            pivot = c1.index(c0[0])
        except ValueError:
            # the first edge of c0 is not in c1: the cycles differ
            return False

        return c1[pivot:] + c1[:pivot] == c0

    ok = True
    distinct_cycles: List[List[Tuple[int, int, int]]] = []
    for root in graph:
        cycles = dfs([], root, set())
        if not cycles:
            continue

        for cycle in cycles:
            if any(compare_cycles(known_cycle, cycle) for known_cycle in distinct_cycles):
                continue

            distinct_cycles.append(cycle)
            ok = False

            print('Potential deadlock(s):')
            for (node, next_node, _) in cycle:
                distinct_labels: Set[Tuple[int, int]] = set()
                for (node_ctxt, next_node_ctxt) in labels[(node, next_node)]:
                    # Lock contexts are the contexts just before the call, so rip is the address of the
                    # call site (same convention as in the lock/unlock matching check).
                    in_binary_lock_pc = node_ctxt.read(x64.rip)
                    in_binary_next_lock_pc = next_node_ctxt.read(x64.rip)

                    if (in_binary_lock_pc, in_binary_next_lock_pc) in distinct_labels:
                        continue

                    distinct_labels.add((in_binary_lock_pc, in_binary_next_lock_pc))

                    print(f'\t#{node_ctxt.transition_after().id}: lock at 0x{in_binary_lock_pc:x} '
                          f'(on thread {RuntimeHelper.thread_id(node_ctxt)}, critical section handle '
                          f'0x{RuntimeHelper.get_critical_section_handle(node_ctxt):x}) followed by\n\t'
                          f'#{next_node_ctxt.transition_after().id}: lock at 0x{in_binary_next_lock_pc:x} '
                          f'(on thread {RuntimeHelper.thread_id(next_node_ctxt)}, '
                          f'critical section handle '
                          f'0x{RuntimeHelper.get_critical_section_handle(next_node_ctxt):x})')
                print(
                    '\t=============================================================='
                )
    if ok:
        print('Not found.')
    return None


# Collect the user mode locks and unlocks called by the binary in the given ranges, build the lock dependency
# graph from them, then report its cycles.
def check_lock_cycle(trace: RuntimeHelper, threads_ranges: List[Tuple[Context, Context]], binary: Optional[str]):
    graph_and_labels = build_locks_unlocks_order_graph_next(
        get_thread_usermode_in_binary_locks_unlocks(trace, threads_ranges, binary)
    )
    dependency_graph, edge_labels = graph_and_labels
    check_order_graph_cycle(dependency_graph, edge_labels)


# Run both checks on process `pid`: first the lock/unlock matching of each thread, then the search of
# potential deadlocks over the whole process.
def detect_deadlocks(trace: RuntimeHelper, pid: int, binary: Optional[str],
                     first_id: Optional[int], last_id: Optional[int]):
    (first_context, last_context) = find_begin_end_context(trace, pid, binary, first_id, last_id)
    process_ranges = list(find_process_usermode_ranges(trace, pid, first_context, last_context))

    # group the ranges by the thread running at their first context
    thread_ranges: Dict[int, List[Tuple[Context, Context]]] = {}
    for (range_begin, range_end) in process_ranges:
        owner_tid = RuntimeHelper.thread_id(range_begin)
        thread_ranges.setdefault(owner_tid, []).append((range_begin, range_end))

    for owner_tid in thread_ranges:
        print('\n============ checking lock/unlock matching on thread {} ============'.format(owner_tid))
        check_intrathread_lock_unlock_matching(trace, thread_ranges[owner_tid], binary)

    print('\n\n============ checking potential deadlocks on process ============')
    check_lock_cycle(trace, process_ranges, binary)


# %%
# Connect to the scenario configured in the Parameters cell, then run the lock/unlock matching check and
# the potential deadlock detection on the selected process.
trace = RuntimeHelper(host, port)
detect_deadlocks(trace, pid, binary, begin_trans_id, end_trans_id)