Detect data race

This notebook checks for possible data races that may occur when using critical sections as the synchronization primitive.

Prerequisites

  • Supported versions:
      • REVEN 2.9+
  • This notebook should be run in a Jupyter notebook server equipped with a REVEN 2 Python kernel. REVEN comes with a Jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • The following resources are needed for the analyzed scenario:
      • Trace
      • OSSI
      • Memory History
      • Backtrace (Stack Events included)
      • Fast Search

Perimeter:

  • Windows 10 64-bit

Limits

  • Only the following critical section APIs are supported as lock/unlock operations:
      • RtlEnterCriticalSection, RtlTryEnterCriticalSection
      • RtlLeaveCriticalSection

Running

Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.

Source

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Detect data race
# This notebook checks for possible data races that may occur when using critical sections as the synchronization
# primitive.
#
# ## Prerequisites
#  - Supported versions:
#     - REVEN 2.9+
#  - This notebook should be run in a Jupyter notebook server equipped with a REVEN 2 Python kernel.
# REVEN comes with a Jupyter notebook server accessible with the `Open Python` button in the `Analyze`
# page of any scenario.
#  - The following resources are needed for the analyzed scenario:
#     - Trace
#     - OSSI
#     - Memory History
#     - Backtrace (Stack Events included)
#     - Fast Search
#
# ## Perimeter:
#  - Windows 10 64-bit
#
# ## Limits
#  - Only the following critical section APIs are supported as lock/unlock operations:
#     - `RtlEnterCriticalSection`, `RtlTryEnterCriticalSection`
#     - `RtlLeaveCriticalSection`
#
#
# ## Running
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.


# %%
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union


# Reven specific
import reven2
from reven2.address import LinearAddress, LogicalAddress
from reven2.arch import x64
from reven2.memhist import MemoryAccess, MemoryAccessOperation
from reven2.ossi.ossi import Symbol
from reven2.trace import Context, Trace, Transition
from reven2.util import collate

# %% [markdown]
# # Parameters

# %%
# Host and port of the REVEN server running the scenario.
host = '127.0.0.1'
port = 41309

# The PID and the name of the binary of interest (optional): if the binary name is given (i.e. not None), then only
# locks and unlocks called directly from the binary are counted.
pid = 2460
binary = None

# The transition range to analyze (optional): if None, the range starts at the first execution of the binary
# (when one is given) or at the beginning of the trace, and ends at the end of the trace.
begin_trans_id = None
end_trans_id = None

# Do not show common memory accesses (from different threads) which are synchronized (i.e. mutually excluded) by some
# critical section(s).
hide_synchronized_accesses = True

# Do not show common memory accesses (from different threads) that are not protected by any critical section at
# access time (they may be synchronized by a primitive unknown to this notebook).
suppress_unknown_primitives = False


# %%
# Helper class wrapping REVEN's runtime objects; it provides methods to retrieve information about calls to
# RtlEnterCriticalSection and RtlLeaveCriticalSection
class RuntimeHelper:
    def __init__(self, host: str, port: int):
        try:
            server = reven2.RevenServer(host, port)
        except RuntimeError:
            raise RuntimeError(f'Cannot connect to the scenario at {host}:{port}')

        self.trace = server.trace
        self.ossi = server.ossi
        self.search_symbol = self.trace.search.symbol
        self.search_binary = self.trace.search.binary

        # basic OSSI
        bin_symbol_names = {
            'c:/windows/system32/ntoskrnl.exe': {
                'KiSwapContext'
            },

            'c:/windows/system32/ntdll.dll': {
                # critical section
                'RtlInitializeCriticalSection',
                'RtlInitializeCriticalSectionEx',
                'RtlInitializeCriticalSectionAndSpinCount',
                'RtlEnterCriticalSection',
                'RtlTryEnterCriticalSection',
                'RtlLeaveCriticalSection',

                # initialize/shutdown thread
                'LdrpInitializeThread',
                'LdrShutdownThread'
            },

            'c:/windows/system32/basesrv.dll': {
                'BaseSrvCreateThread'
            },

            'c:/windows/system32/csrsrv.dll': {
                'CsrCreateThread',
                'CsrThreadRefcountZero',
                'CsrDereferenceThread'
            }
        }

        self.symbols: Dict[str, Optional[Symbol]] = {}
        for (bin, symbol_names) in bin_symbol_names.items():
            try:
                exec_bin = next(self.ossi.executed_binaries(f'^{bin}$'))
            except StopIteration:
                if bin == 'c:/windows/system32/ntoskrnl.exe':
                    raise RuntimeError(f'{bin} not found')
                exec_bin = None

            for name in symbol_names:
                if exec_bin is None:
                    self.symbols[name] = None
                    continue

                try:
                    sym = next(exec_bin.symbols(f'^{name}$'))
                    self.symbols[name] = sym
                except StopIteration:
                    msg = f'{name} not found in {bin}'

                    if name in {
                        'KiSwapContext',
                        'RtlEnterCriticalSection', 'RtlTryEnterCriticalSection', 'RtlLeaveCriticalSection'
                    }:
                        raise RuntimeError(msg)
                    else:
                        self.symbols[name] = None
                        print(f'Warning: {msg}')

        self.has_debugger = True
        try:
            self.trace.first_transition.step_over()
        except RuntimeError:
            print('Warning: the debugger interface is not available, so the script cannot determine \
function return values.\nMake sure the stack events and PC range resources are replayed for this scenario.')
            self.has_debugger = False

    def get_memory_accesses(self, from_context: Context, to_context: Context) -> Iterator[MemoryAccess]:
        try:
            from_trans = from_context.transition_after()
            to_trans = to_context.transition_before()
        except IndexError:
            return

        accesses = self.trace.memory_accesses(from_transition=from_trans, to_transition=to_trans)
        for access in accesses:
            # Skip accesses without a virtual address
            if access.virtual_address is None:
                continue

            # Skip accesses without an instruction
            if access.transition.instruction is None:
                continue

            # Skip accesses of `lock`-prefixed instructions (their memory operations are atomic)
            ins_bytes = access.transition.instruction.raw
            if ins_bytes[0] == 0xf0:
                continue

            yield access

    # The CRITICAL_SECTION object is the first argument of the lock/unlock APIs, passed in rcx in the x64
    # calling convention; its address is used as the lock handle.
    @staticmethod
    def get_lock_handle(ctxt: Context) -> int:
        return ctxt.read(x64.rcx)

    # gs:[0x48] holds the current thread id (ClientId.UniqueThread in the TEB)
    @staticmethod
    def thread_id(ctxt: Context) -> int:
        return ctxt.read(LogicalAddress(0x48, x64.gs), 4)

    # The two low bits of cs hold the CPL: 0 in kernel mode, 3 in user mode
    @staticmethod
    def is_kernel_mode(ctxt: Context) -> bool:
        return ctxt.read(x64.cs) & 0x3 == 0


# %%
# Look for a possible first execution context of a binary
# Find the lower/upper bounds of the contexts on which the data race detection runs
def find_begin_end_context(sco: RuntimeHelper, pid: int, binary: Optional[str],
                           begin_id: Optional[int], end_id: Optional[int]) -> Tuple[Context, Context]:
    begin_ctxt = None
    if begin_id is not None:
        try:
            begin_trans = sco.trace.transition(begin_id)
            begin_ctxt = begin_trans.context_after()
        except IndexError:
            begin_ctxt = None

    if begin_ctxt is None:
        if binary is not None:
            for name in sco.ossi.executed_binaries(binary):
                for ctxt in sco.search_binary(name):
                    ctx_process = ctxt.ossi.process()
                    assert ctx_process is not None
                    if ctx_process.pid == pid:
                        begin_ctxt = ctxt
                        break
                if begin_ctxt is not None:
                    break

    if begin_ctxt is None:
        begin_ctxt = sco.trace.first_context

    end_ctxt = None
    if end_id is not None:
        try:
            end_trans = sco.trace.transition(end_id)
            end_ctxt = end_trans.context_before()
        except IndexError:
            end_ctxt = None

    if end_ctxt is None:
        end_ctxt = sco.trace.last_context

    if end_ctxt <= begin_ctxt:
        raise RuntimeError("The begin transition must come before the end transition.")

    return (begin_ctxt, end_ctxt)


# Get all execution contexts of a given process
def find_process_ranges(sco: RuntimeHelper, pid: int, first_context: Context, last_context: Context) \
        -> Iterator[Tuple[Context, Optional[Context]]]:
    if last_context <= first_context:
        return iter(())

    ctxt_low = first_context
    assert sco.symbols['KiSwapContext'] is not None
    ki_swap_context = sco.symbols['KiSwapContext']
    for ctxt in sco.search_symbol(ki_swap_context, from_context=first_context,
                                  to_context=None if last_context == sco.trace.last_context else last_context):
        ctx_process = ctxt_low.ossi.process()
        assert ctx_process is not None

        if ctx_process.pid == pid:
            yield (ctxt_low, ctxt)
        ctxt_low = ctxt

    ctx_process = ctxt_low.ossi.process()
    assert ctx_process is not None

    if ctx_process.pid == pid:
        if ctxt_low < last_context:
            yield (ctxt_low, last_context)
        else:
            # So ctxt_low == last_context, using None for the upper bound.
            # This happens only when last_context is in the process, and is also the last context of the trace.
            yield (ctxt_low, None)


# Start from a transition, return the first transition that has an instruction, or None if there isn't one.
def ignore_non_instructions(trans: Transition, trace: Trace) -> Optional[Transition]:
    while trans.instruction is None:
        if trans == trace.last_transition:
            return None
        trans = trans + 1

    return trans


# Extract user-mode-only context ranges from a context range (which may also include kernel mode ranges)
def find_usermode_ranges(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Optional[Context]) \
        -> Iterator[Tuple[Context, Context]]:
    if ctxt_high is None:
        return

    trans = ignore_non_instructions(ctxt_low.transition_after(), sco.trace)
    if trans is None:
        return
    ctxt_current = trans.context_before()

    while ctxt_current < ctxt_high:
        ctxt_next = ctxt_current.find_register_change(x64.cs, is_forward=True)

        if not RuntimeHelper.is_kernel_mode(ctxt_current):
            if ctxt_next is None or ctxt_next > ctxt_high:
                yield (ctxt_current, ctxt_high)
                break
            else:
                # It's safe to decrease ctxt_next by 1 because it was obtained from a forward find_register_change
                yield (ctxt_current, ctxt_next - 1)

        if ctxt_next is None:
            break

        ctxt_current = ctxt_next


# Get user mode only execution contexts of a given process
def find_process_usermode_ranges(trace: RuntimeHelper, pid: int, first_ctxt: Context, last_ctxt: Context) \
        -> Iterator[Tuple[Context, Context]]:
    for (ctxt_low, ctxt_high) in find_process_ranges(trace, pid, first_ctxt, last_ctxt):
        usermode_ranges = find_usermode_ranges(trace, ctxt_low, ctxt_high)
        for usermode_range in usermode_ranges:
            yield usermode_range


def build_ordered_api_calls(sco: RuntimeHelper, binary: Optional[str],
                            ctxt_low: Context, ctxt_high: Context, apis: List[str]) -> Iterator[Tuple[str, Context]]:
    def gen(api):
        return (
            (api, ctxt)
            for ctxt in sco.search_symbol(sco.symbols[api], from_context=ctxt_low,
                                          to_context=None if ctxt_high == sco.trace.last_context else ctxt_high)
        )

    api_contexts = (
        gen(api)
        for api in apis if api in sco.symbols
    )

    if binary is None:
        for api_ctxt in collate(api_contexts, key=lambda name_ctxt: name_ctxt[1]):
            yield api_ctxt
    else:
        for (api, ctxt) in collate(api_contexts, key=lambda name_ctxt: name_ctxt[1]):
            try:
                caller_ctxt = ctxt - 1
            except IndexError:
                continue

            caller_location = caller_ctxt.ossi.location()
            if caller_location is None:
                continue

            caller_binary = caller_location.binary
            if caller_binary is None:
                continue

            if binary in [caller_binary.name, caller_binary.filename, caller_binary.path]:
                yield (api, ctxt)


def get_return_value(sco: RuntimeHelper, ctxt: Context) -> Optional[int]:
    if not sco.has_debugger:
        return None

    try:
        trans_after = ctxt.transition_after()
    except IndexError:
        return None

    if trans_after is None:
        return None

    trans_ret = trans_after.step_out()
    if trans_ret is None:
        return None

    ctxt_ret = trans_ret.context_after()
    return ctxt_ret.read(x64.rax)


class SynchronizationAction(Enum):
    LOCK = 1
    UNLOCK = 0


def get_locks_unlocks(sco: RuntimeHelper, binary: Optional[str], ranges: List[Tuple[Context, Context]]) \
        -> List[Tuple[SynchronizationAction, Context]]:

    lock_unlock_apis = {
        'RtlEnterCriticalSection': SynchronizationAction.LOCK,
        'RtlLeaveCriticalSection': SynchronizationAction.UNLOCK,
        'RtlTryEnterCriticalSection': None,
    }

    critical_section_actions = []

    for ctxt_low, ctxt_high in ranges:
        for name, ctxt in build_ordered_api_calls(sco, binary, ctxt_low, ctxt_high, list(lock_unlock_apis.keys())):
            # either lock, unlock, or none
            action = lock_unlock_apis[name]

            if action is None:
                # need to check the return value of the API to get the action
                api_ret_val = get_return_value(sco, ctxt)
                if api_ret_val is None:
                    print(f'Warning: failed to get the return value, {name} is omitted')
                    continue
                else:
                    # RtlTryEnterCriticalSection: the return value is nonzero if the thread successfully
                    # entered the critical section
                    if api_ret_val != 0:
                        action = SynchronizationAction.LOCK
                    else:
                        continue

            critical_section_actions.append((action, ctxt))

    return critical_section_actions


# Get the locks which are still held (live) at a given transition
def get_live_locks(sco: RuntimeHelper,
                   locks_unlocks: List[Tuple[SynchronizationAction, Context]], trans: Transition) -> List[Context]:
    protection_locks: List[Context] = []
    trans_ctxt = trans.context_before()

    for (action, ctxt) in locks_unlocks:
        if ctxt > trans_ctxt:
            return protection_locks

        if action == SynchronizationAction.LOCK:
            protection_locks.append(ctxt)
            continue

        unlock_handle = RuntimeHelper.get_lock_handle(ctxt)
        # look for the latest corresponding lock: matching each unlock with the latest lock on the same
        # handle also handles recursive acquisitions
        for (idx, lock_ctxt) in reversed(list(enumerate(protection_locks))):
            lock_handle = RuntimeHelper.get_lock_handle(lock_ctxt)
            if lock_handle == unlock_handle:
                del protection_locks[idx]
                break

    return protection_locks


# %%
AccessType = TypeVar("AccessType")


@dataclass
class MemorySegment:
    address: Union[LinearAddress, LogicalAddress]
    size: int


@dataclass
class MemorySegmentAccess(Generic[AccessType]):
    segment: MemorySegment
    accesses: List[AccessType]


# Insert a segment into the current list of segments, starting at `position`; if `position` is None, start from
# the head of the list.
# Precondition: `new_segment.segment.size > 0`
# Return the position where the segment is inserted.
# The complexity of the insertion is about O(n).
def insert_memory_segment_access(new_segment: MemorySegmentAccess, segments: List[MemorySegmentAccess],
                                 position: Optional[int]) -> int:
    new_address = new_segment.segment.address
    new_size = new_segment.segment.size
    new_accesses = new_segment.accesses

    if not segments:
        segments.append(MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
        return 0

    if position is None:
        position = 0

    index = position
    first_found_index = None
    while index < len(segments):
        # Loop invariant `new_size > 0`
        #  - True at the first iteration
        #  - In the subsequent iterations, either
        #     - the invariant is always kept through cases (2), (5), (6), (7)
        #     - the loop returns directly in cases (1), (3), (4)

        address = segments[index].segment.address
        size = segments[index].segment.size
        accesses = segments[index].accesses

        if new_address < address:
            # Case 1
            #                   |--------------|
            # |--------------|
            if new_address + new_size <= address:
                segments.insert(index,
                                MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))

                return first_found_index if first_found_index is not None else index
            # Case 2
            #              |--------------|
            # |--------------|
            else:
                # Insert the different part of the new segment
                segments.insert(index,
                                MemorySegmentAccess(
                                    MemorySegment(new_address, address.offset - new_address.offset), new_accesses))

                if first_found_index is None:
                    first_found_index = index

                # The common part will be handled in the next iteration by either case (3), (4), or (5)
                #
                # Since
                #  - `new_address + new_size > address` (else condition of case 2), so
                #  - `new_size  > address.offset - new_address.offset`
                # then invariant `new_size > 0`
                new_size -= (address.offset - new_address.offset)
                new_address = address
                index += 1
        elif new_address == address:
            # case 3
            # |--------------|
            # |--------|
            if new_address + new_size < address + size:
                segments.insert(index,
                                MemorySegmentAccess(MemorySegment(new_address, new_size), accesses + new_accesses))
                segments[index + 1] = MemorySegmentAccess(
                    MemorySegment(new_address + new_size, size - new_size), accesses)

                return first_found_index if first_found_index is not None else index
            # case 4
            # |--------------|
            # |--------------|
            elif new_address + new_size == address + size:
                segments[index] = MemorySegmentAccess(MemorySegment(new_address, new_size), accesses + new_accesses)

                return first_found_index if first_found_index is not None else index
            # case 5
            # |--------------|
            # |------------------|
            # new_address + new_size > address + size
            else:
                # Current segment's accesses are augmented by new segment's accesses
                segments[index] = MemorySegmentAccess(MemorySegment(address, size), accesses + new_accesses)

                if first_found_index is None:
                    first_found_index = index

                # The different part of the new segment will be handled in the next iteration
                #
                # Since:
                #  - `new_address == address` and `new_address + new_size > address + size`, so
                #  - `new_size > size`
                # then invariant `new_size > 0`
                new_address = address + size
                new_size = new_size - size
                index += 1
        # new_address > address
        else:
            # case 6
            # |--------------|
            #                    |-----------|
            if new_address >= address + size:
                index += 1
            # case 7
            # |--------------|
            #            |-----------|
            else:
                # Split the current segment into:
                #  - the different part
                segments[index] = MemorySegmentAccess(
                    MemorySegment(address, new_address.offset - address.offset), accesses)
                #  - the common part
                segments.insert(index + 1, MemorySegmentAccess(
                    MemorySegment(new_address, address.offset + size - new_address.offset), accesses))

                # The `new_segment` will be handled in the next iteration by either case (3), (4) or (5)
                index += 1

    segments.append(MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
    index = len(segments) - 1

    return first_found_index if first_found_index is not None else index
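

# %%
# A minimal, self-contained sketch of how `insert_memory_segment_access` splits overlapping segments.
# The string tags 'a' and 'b' are hypothetical stand-ins for the (Transition, MemoryAccessOperation)
# pairs used by the real analysis, and LinearAddress arithmetic is assumed to behave as in the reven2 API.
def _demo_insert_memory_segment_access() -> None:
    segments: List[MemorySegmentAccess] = []
    position: Optional[int] = None
    # Insert two overlapping 8-byte segments, reusing the returned position as the next starting point
    for (offset, size, tag) in [(0x1000, 8, 'a'), (0x1004, 8, 'b')]:
        position = insert_memory_segment_access(
            MemorySegmentAccess(MemorySegment(LinearAddress(offset), size), [tag]), segments, position)

    # Expected output: [0x1000, 4] -> ['a'], [0x1004, 4] -> ['a', 'b'], [0x1008, 4] -> ['b']
    for seg_acc in segments:
        print(f'[{seg_acc.segment.address.offset:#x}, {seg_acc.segment.size}] -> {seg_acc.accesses}')


# Uncomment to run the sketch:
# _demo_insert_memory_segment_access()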


ThreadId = int


# Get the common memory accesses of two threads of the same process.
# Preconditions on both parameters:
#  1. The list of MemorySegmentAccess is sorted by address and size
#  2. For each MemorySegmentAccess access, `access.segment.size > 0` (i.e. segments are not empty)
def get_common_memory_accesses(
    first_thread_segment_accesses: Tuple[ThreadId,
                                         List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]],
    second_thread_segment_accesses: Tuple[ThreadId,
                                          List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]]) \
        -> List[Tuple[MemorySegment,
                Tuple[List[Tuple[Transition, MemoryAccessOperation]],
                      List[Tuple[Transition, MemoryAccessOperation]]]]]:

    (first_thread_id, first_segment_accesses) = first_thread_segment_accesses
    (second_thread_id, second_segment_accesses) = second_thread_segment_accesses

    # Merge these lists into a new list
    merged_accesses: List[MemorySegmentAccess[Tuple[ThreadId, Transition, MemoryAccessOperation]]] = []
    i = j = 0
    while i < len(first_segment_accesses) and j < len(second_segment_accesses):
        (first_segment, first_accesses) = (first_segment_accesses[i].segment, first_segment_accesses[i].accesses)
        (second_segment, second_accesses) = (second_segment_accesses[j].segment, second_segment_accesses[j].accesses)

        first_threaded_accesses = [
            (first_thread_id, trans, mem_operation) for (trans, mem_operation) in first_accesses]
        second_threaded_accesses = [
            (second_thread_id, trans, mem_operation) for (trans, mem_operation) in second_accesses]

        (first_address, first_size) = (first_segment.address, first_segment.size)
        (second_address, second_size) = (second_segment.address, second_segment.size)

        if (first_address, first_size) < (second_address, second_size):
            merged_accesses.append(MemorySegmentAccess(first_segment, first_threaded_accesses))
            i += 1
        elif (first_address, first_size) > (second_address, second_size):
            merged_accesses.append(MemorySegmentAccess(second_segment, second_threaded_accesses))
            j += 1
        else:
            merged_accesses.append(MemorySegmentAccess(
                first_segment, first_threaded_accesses + second_threaded_accesses))
            i += 1
            j += 1

    while i < len(first_segment_accesses):
        (first_segment, first_accesses) = (first_segment_accesses[i].segment, first_segment_accesses[i].accesses)
        first_threaded_accesses = [
            (first_thread_id, trans, mem_operation) for (trans, mem_operation) in first_accesses]
        merged_accesses.append(MemorySegmentAccess(first_segment, first_threaded_accesses))
        i += 1

    while j < len(second_segment_accesses):
        (second_segment, second_accesses) = (second_segment_accesses[j].segment, second_segment_accesses[j].accesses)
        second_threaded_accesses = [
            (second_thread_id, trans, mem_operation) for (trans, mem_operation) in second_accesses]
        merged_accesses.append(MemorySegmentAccess(second_segment, second_threaded_accesses))
        j += 1

    # The merged list needs to be segmented again to handle the case of overlapping segments coming from different
    # threads.
    # We start from an empty list `refined_accesses` and gradually insert new segments into it.
    #
    # Since `merged_accesses` is already sorted, the position where a segment is inserted is always greater than or
    # equal to the insertion position of the previous segment. We can therefore use the insertion position of a
    # segment as the starting position when looking for the position to insert the next one.
    #
    # Though the complexity of `insert_memory_segment_access` is O(n) on average, the insertions in the loop below
    # happen mostly at the end of the list (where they cost O(1)), so the complexity of the loop is still O(n).
    refined_accesses: List[MemorySegmentAccess[Tuple[ThreadId, Transition, MemoryAccessOperation]]] = []
    last_inserted_index = None
    for refined_access in merged_accesses:
        last_inserted_index = insert_memory_segment_access(refined_access, refined_accesses, last_inserted_index)

    common_accesses = []
    for refined_access in refined_accesses:
        first_thread_rws = []
        second_thread_rws = []

        for (thread_id, transition, operation) in refined_access.accesses:
            if thread_id == first_thread_id:
                first_thread_rws.append((transition, operation))
            else:
                second_thread_rws.append((transition, operation))

        if first_thread_rws and second_thread_rws:
            common_accesses.append((refined_access.segment, (first_thread_rws, second_thread_rws)))

    return common_accesses
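

# %%
# A minimal sketch of `get_common_memory_accesses` on two hypothetical threads, with string tags
# standing in for (Transition, MemoryAccessOperation) pairs. Thread 1 accesses [0x1000, 4) and
# thread 2 accesses [0x1002, 4), so only the overlapping part [0x1002, 2) is reported as common.
def _demo_get_common_memory_accesses() -> None:
    first = (1, [MemorySegmentAccess(MemorySegment(LinearAddress(0x1000), 4), [('t10', 'R')])])
    second = (2, [MemorySegmentAccess(MemorySegment(LinearAddress(0x1002), 4), [('t20', 'W')])])

    # Expected output: [0x1002, 2]: thread 1 [('t10', 'R')], thread 2 [('t20', 'W')]
    for (segment, (first_rws, second_rws)) in get_common_memory_accesses(first, second):
        print(f'[{segment.address.offset:#x}, {segment.size}]: thread 1 {first_rws}, thread 2 {second_rws}')


# Uncomment to run the sketch:
# _demo_get_common_memory_accesses()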


# Return True if the transition should be excluded by the caller for data race detection.
def transition_excluded_by_heuristics(sco: RuntimeHelper, trans: Transition, blacklist: Set[int]) -> bool:
    if trans.id in blacklist:
        return True

    # Memory accesses of the first transition are considered non-protected
    try:
        trans_ctxt = trans.context_before()
    except IndexError:
        return False

    trans_location = trans_ctxt.ossi.location()
    if trans_location is None:
        return False

    if trans_location.binary is None:
        return False

    trans_binary_path = trans_location.binary.path
    process_thread_core_dlls = [
        'c:/windows/system32/ntdll.dll', 'c:/windows/system32/csrsrv.dll', 'c:/windows/system32/basesrv.dll'
    ]
    if trans_binary_path not in process_thread_core_dlls:
        return False

    if trans_location.symbol is None:
        return False

    symbols = sco.symbols

    # Accesses of the synchronization APIs themselves are not counted.
    thread_sync_apis = [
        symbols[name] for name in [
            'RtlInitializeCriticalSection', 'RtlInitializeCriticalSectionEx',
            'RtlInitializeCriticalSectionAndSpinCount',
            'RtlEnterCriticalSection', 'RtlTryEnterCriticalSection', 'RtlLeaveCriticalSection'
        ]
        if symbols[name] is not None
    ]
    if trans_location.symbol in thread_sync_apis:
        blacklist.add(trans.id)
        return True

    trans_frames = trans_ctxt.stack.frames()

    # Accesses of thread create/shutdown are not counted
    thread_create_apis = [
        api for api in [
            symbols['LdrpInitializeThread'], symbols['BaseSrvCreateThread'], symbols['CsrCreateThread']
        ]
        if api is not None
    ]
    thread_shutdown_apis = [
        api for api in [
            symbols['CsrDereferenceThread'], symbols['CsrThreadRefcountZero'], symbols['LdrShutdownThread']
        ]
        if api is not None
    ]
    for frame in trans_frames:
        frame_ctxt = frame.first_context
        frame_location = frame_ctxt.ossi.location()

        if frame_location is None:
            continue

        if frame_location.binary is None:
            continue

        if frame_location.binary.path not in process_thread_core_dlls:
            continue

        if frame_location.symbol is None:
            continue

        if (
            frame_location.symbol in thread_create_apis
            or frame_location.symbol in thread_shutdown_apis
            or frame_location.symbol in thread_sync_apis
        ):
            blacklist.add(trans.id)
            return True

    return False


def get_threads_locks_unlocks(trace: RuntimeHelper, binary: Optional[str],
                              process_ranges: List[Tuple[Context, Context]]) \
        -> Dict[ThreadId, List[Tuple[SynchronizationAction, Context]]]:

    threads_ranges: Dict[int, List[Tuple[Context, Context]]] = {}
    for (ctxt_lo, ctxt_hi) in process_ranges:
        tid = RuntimeHelper.thread_id(ctxt_lo)
        if tid not in threads_ranges:
            threads_ranges[tid] = []

        threads_ranges[tid].append((ctxt_lo, ctxt_hi))

    threads_locks_unlocks: Dict[ThreadId, List[Tuple[SynchronizationAction, Context]]] = {}
    for tid, ranges in threads_ranges.items():
        threads_locks_unlocks[tid] = get_locks_unlocks(trace, binary, ranges)

    return threads_locks_unlocks


# Get all segmented memory accesses of each thread given a list of context ranges.
# Return a map: thread_id -> list of MemorySegmentAccess, each holding its (transition, mem_operation) pairs
def get_threads_segmented_memory_accesses(trace: RuntimeHelper, process_ranges: List[Tuple[Context, Context]]) \
        -> Dict[ThreadId, List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]]:
    sorted_threads_segmented_memory_accesses: Dict[
        ThreadId,
        List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]
    ] = {}
    for (ctxt_lo, ctxt_hi) in process_ranges:
        tid = RuntimeHelper.thread_id(ctxt_lo)
        if tid not in sorted_threads_segmented_memory_accesses:
            sorted_threads_segmented_memory_accesses[tid] = []

        for access in trace.get_memory_accesses(ctxt_lo, ctxt_hi):
            access_virtual_address = access.virtual_address
            if access_virtual_address is None:
                # DMA access
                continue

            sorted_threads_segmented_memory_accesses[tid].append(MemorySegmentAccess(
                MemorySegment(access_virtual_address, access.size),
                [(access.transition, access.operation)]))

    # The memory segment accesses of each thread are sorted by address and size of segments, which helps improve
    # the performance of looking for the common accesses of two threads.
    # The complexity of the sorting is O(n log n).
    # Note that segments may still overlap.
    for tid in sorted_threads_segmented_memory_accesses.keys():
        sorted_threads_segmented_memory_accesses[tid].sort(key=lambda x: (x.segment.address, x.segment.size))

    threads_segmented_memory_accesses: Dict[
        ThreadId,
        List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]
    ] = {}
    # The non-overlapping memory segment access list of each thread is built by:
    #   - starting from an empty list
    #   - gradually inserting segments into the list using `insert_memory_segment_access`
    # Since the original segment lists are sorted, the complexity of the build is O(n). The total complexity of
    # constructing the non-overlapping memory segment accesses of each thread is then O(n log n).
    for tid in sorted_threads_segmented_memory_accesses.keys():
        threads_segmented_memory_accesses[tid] = []
        last_mem_acc_index = None
        for seg_mem_acc in sorted_threads_segmented_memory_accesses[tid]:
            last_mem_acc_index = insert_memory_segment_access(
                seg_mem_acc, threads_segmented_memory_accesses[tid], last_mem_acc_index
            )

    return threads_segmented_memory_accesses


TransitionId = int
InstructionPointer = int
LockHandle = int
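

# A pair of accesses is reported as a possible data race when (see the checks in `detect_data_race` below):
#   - the accesses come from two different threads of the process and hit a common memory segment,
#   - they are performed by different instructions (the execution of an instruction is considered atomic),
#   - one of them is a read and the other a write (identical operations are skipped),
#   - the accesses do not happen during thread creation/shutdown (heuristics), and
#   - no common critical section is live in both threads at the time of the accesses.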


def detect_data_race(trace: RuntimeHelper, pid: int, binary: Optional[str],
                     begin_transition_id: Optional[TransitionId], end_transition_id: Optional[TransitionId]):

    def handlers_to_string(handlers: Iterable[int]):
        msgs = [f'{handle:#x}' for handle in handlers]
        message = ', '.join(msgs)
        return message

    def get_live_lock_handles(
        sco: RuntimeHelper, locks_unlocks: List[Tuple[SynchronizationAction, Context]], trans: Transition
    ) -> Tuple[List[Context], Set[LockHandle]]:
        locks = get_live_locks(sco, locks_unlocks, trans)
        return (locks, {RuntimeHelper.get_lock_handle(ctxt) for ctxt in locks})

    (first_context, last_context) = find_begin_end_context(trace, pid, binary, begin_transition_id, end_transition_id)
    process_ranges = list(find_process_usermode_ranges(trace, pid, first_context, last_context))

    threads_memory_accesses = get_threads_segmented_memory_accesses(trace, process_ranges)
    threads_locks_unlocks = get_threads_locks_unlocks(trace, binary, process_ranges)

    thread_ids: Set[ThreadId] = set()

    # map from a transition (of a given thread) to a tuple whose
    #  - first element is the list of contexts of the calls that took the live locks (e.g. RtlEnterCriticalSection)
    #  - second element is the set of lock handles used by these calls
    cached_live_locks_handles: Dict[Tuple[ThreadId, TransitionId], Tuple[List[Context], Set[LockHandle]]] = {}

    cached_pcs: Dict[TransitionId, InstructionPointer] = {}

    filtered_trans_ids: Set[TransitionId] = set()

    for first_tid, first_thread_accesses in threads_memory_accesses.items():
        thread_ids.add(first_tid)
        for second_tid, second_thread_accesses in threads_memory_accesses.items():
            if second_tid in thread_ids:
                continue

            common_accesses = get_common_memory_accesses(
                (first_tid, first_thread_accesses), (second_tid, second_thread_accesses)
            )

            for (segment, (first_accesses, second_accesses)) in common_accesses:
                (segment_address, segment_size) = (segment.address, segment.size)

                race_pcs: Set[Tuple[InstructionPointer, InstructionPointer]] = set()

                for (first_transition, first_rw) in first_accesses:
                    if first_transition.id in cached_pcs:
                        first_pc = cached_pcs[first_transition.id]
                    else:
                        first_pc = first_transition.context_before().read(x64.rip)
                        cached_pcs[first_transition.id] = first_pc

                    if (first_tid, first_transition.id) not in cached_live_locks_handles:
                        critical_section_locks_unlocks = threads_locks_unlocks[first_tid]

                        cached_live_locks_handles[(first_tid, first_transition.id)] = get_live_lock_handles(
                            trace, critical_section_locks_unlocks, first_transition)

                    for (second_transition, second_rw) in second_accesses:
                        if second_transition.id in cached_pcs:
                            second_pc = cached_pcs[second_transition.id]
                        else:
                            second_pc = second_transition.context_before().read(x64.rip)
                            cached_pcs[second_transition.id] = second_pc

                        # the execution of an instruction is considered atomic
                        if first_pc == second_pc:
                            continue

                        if (second_tid, second_transition.id) not in cached_live_locks_handles:
                            critical_section_locks_unlocks = threads_locks_unlocks[second_tid]

                            cached_live_locks_handles[(second_tid, second_transition.id)] = get_live_lock_handles(
                                trace, critical_section_locks_unlocks, second_transition)

                        # Skip if both are the same operation (i.e. read/read or write/write)
                        if first_rw == second_rw:
                            continue

                        first_is_write = first_rw == MemoryAccessOperation.Write
                        second_is_write = second_rw == MemoryAccessOperation.Write

                        if first_transition < second_transition:
                            before_mem_opr_str = 'write' if first_is_write else 'read'
                            after_mem_opr_str = 'write' if second_is_write else 'read'

                            before_trans = first_transition
                            after_trans = second_transition

                            before_tid = first_tid
                            after_tid = second_tid

                        else:
                            before_mem_opr_str = 'write' if second_is_write else 'read'
                            after_mem_opr_str = 'write' if first_is_write else 'read'

                            before_trans = second_transition
                            after_trans = first_transition

                            before_tid = second_tid
                            after_tid = first_tid

                        # Skip (PC, PC) pairs that have already been reported
                        if (first_pc, second_pc) in race_pcs:
                            continue

                        race_pcs.add((first_pc, second_pc))

                        # ===== No data race

                        message = f'No data race during {after_mem_opr_str} at \
[{segment_address.offset:#x}, {segment_size}] (transition #{after_trans.id}) by thread {after_tid}\n\
with a previous {before_mem_opr_str} (transition #{before_trans.id}) by thread {before_tid}'

                        # There is a common critical section synchronizing the accesses
                        live_critical_sections: Dict[ThreadId, Set[LockHandle]] = {
                            first_tid: cached_live_locks_handles[(first_tid, first_transition.id)][1],
                            second_tid: cached_live_locks_handles[(second_tid, second_transition.id)][1]
                        }

                        common_critical_sections = live_critical_sections[first_tid].intersection(
                            live_critical_sections[second_tid]
                        )
                        if common_critical_sections:
                            if not hide_synchronized_accesses:
                                handlers_str = handlers_to_string(common_critical_sections)
                                print(f'{message}: synchronized by critical section(s) ({handlers_str}).\n')
                            continue

                        # Exclude accesses performed during thread creation/shutdown
                        are_excluded = transition_excluded_by_heuristics(
                            trace, first_transition, filtered_trans_ids
                        ) and transition_excluded_by_heuristics(trace, second_transition, filtered_trans_ids)
                        if are_excluded:
                            if not hide_synchronized_accesses:
                                print(f'{message}: shared accesses during thread creation/shutdown')
                            continue

                        # ===== Data race
                        # Show the un-synchronized accesses in the following order:
                        #   1. at least one of them is protected by one or several locks
                        #   2. neither of them is protected by any lock

                        message = f'Possible data race during {after_mem_opr_str} at \
[{segment_address.offset:#x}, {segment_size}] (transition #{after_trans.id}) by thread {after_tid}\n\
with a previous {before_mem_opr_str} (transition #{before_trans.id}) by thread {before_tid}'

                        if live_critical_sections[first_tid] or live_critical_sections[second_tid]:
                            print(f'{message}.')
                            for tid in {first_tid, second_tid}:
                                if live_critical_sections[tid]:
                                    handlers_str = handlers_to_string(live_critical_sections[tid])
                                    print(f'\tCritical section handle(s) used by {tid}: {handlers_str}\n')
                        elif not suppress_unknown_primitives:
                            print(f'{message}: no critical section used.\n')


# %%
trace = RuntimeHelper(host, port)
detect_data_race(trace, pid, binary, begin_trans_id, end_trans_id)