Detect data race

This notebook checks for possible data races that may occur when critical sections are used as the synchronization primitive.


  • Supported versions:
  • REVEN 2.9+
  • This notebook should be run in a jupyter notebook server equipped with a REVEN 2 Python kernel. REVEN comes with a Jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • The following resources are needed for the analyzed scenario:
  • Trace
  • OSSI
  • Memory History
  • Backtrace (Stack Events included)
  • Fast Search


  • Windows 10 64-bit


  • Only the following critical section APIs are supported as lock/unlock operations:
  • RtlEnterCriticalSection, RtlTryEnterCriticalSection
  • RtlLeaveCriticalSection


Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.


# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Detect data race
# This notebook checks for possible data races that may occur when critical sections are used as the synchronization
# primitive.
# ## Prerequisites
#  - Supported versions:
#     - REVEN 2.9+
#  - This notebook should be run in a jupyter notebook server equipped with a REVEN 2 Python kernel.
# REVEN comes with a Jupyter notebook server accessible with the `Open Python` button in the `Analyze`
# page of any scenario.
#  - The following resources are needed for the analyzed scenario:
#     - Trace
#     - OSSI
#     - Memory History
#     - Backtrace (Stack Events included)
#     - Fast Search
# ## Perimeter:
#  - Windows 10 64-bit
# ## Limits
#  - Only the following critical section APIs are supported as lock/unlock operations:
#     - `RtlEnterCriticalSection`, `RtlTryEnterCriticalSection`
#     - `RtlLeaveCriticalSection`
# ## Running
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.

# %%
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union

# Reven specific
import reven2
from reven2.address import LinearAddress, LogicalAddress
from reven2.arch import x64
from reven2.memhist import MemoryAccess, MemoryAccessOperation
from reven2.ossi.ossi import Symbol
from reven2.trace import Context, Trace, Transition
from reven2.util import collate

# %% [markdown]
# # Parameters

# %%
# Host and port of the REVEN server running the scenario.
# NOTE(review): `host` is an empty string here; set it to the address of your REVEN server before running.
host = ''
port = 41309

# The PID and the name of the binary of interest (optional): if the binary name is given (i.e. not None), then only
# locks and unlocks called directly from the binary are counted.
pid = 2460
binary = None

# Optional ids of the transitions delimiting the analyzed part of the trace.
# NOTE(review): presumably `None` means "no explicit bound" -- confirm against the cell that consumes them.
begin_trans_id = None
end_trans_id = None

# Do not show common memory accesses (from different threads) which are synchronized (i.e. mutually excluded) by some
# critical section(s).
hide_synchronized_accesses = True

# Do not show common memory accesses (from different threads) which are dynamically free from critical section(s).
suppress_unknown_primitives = False

# %%
# Helper class which wraps REVEN's runtime objects and provides methods helping to get information about calls to
# RtlEnterCriticalSection and RtlLeaveCriticalSection
class RuntimeHelper:
    """Connection to a REVEN scenario plus the OSSI symbols needed by the analysis.

    Attributes:
        trace: the trace of the scenario.
        ossi: the OSSI entry point of the scenario.
        search_symbol: callable searching the trace for calls to a symbol.
        search_binary: callable searching the trace for contexts executing a binary.
        symbols: map from symbol name to the resolved symbol, or None when it cannot be resolved.
        has_debugger: whether step-over/step-out style queries are available.
    """

    def __init__(self, host: str, port: int):
        """Connect to the scenario served at `host:port` and resolve the symbols of interest.

        Raises:
            RuntimeError: if the connection fails, if ntoskrnl.exe is not found, or if one of the critical
                section enter/leave symbols cannot be resolved.
        """
        try:
            server = reven2.RevenServer(host, port)
        except RuntimeError as err:
            raise RuntimeError(f'Cannot connect to the scenario at {host}:{port}') from err

        self.trace = server.trace
        self.ossi = server.ossi
        self.search_symbol = self.trace.search.symbol
        self.search_binary = self.trace.search.binary

        # basic OSSI
        # The symbol names below are the ones looked up through `symbols[...]` by the rest of this notebook.
        bin_symbol_names = {
            'c:/windows/system32/ntoskrnl.exe': {
                'KiSwapContext',
            },
            'c:/windows/system32/ntdll.dll': {
                # critical section
                'RtlInitializeCriticalSection',
                'RtlInitializeCriticalSectionEx',
                'RtlEnterCriticalSection',
                'RtlTryEnterCriticalSection',
                'RtlLeaveCriticalSection',
                # initialize/shutdown thread
                'LdrpInitializeThread',
                'LdrShutdownThread',
            },
            'c:/windows/system32/basesrv.dll': {
                'BaseSrvCreateThread',
            },
            'c:/windows/system32/csrsrv.dll': {
                'CsrCreateThread',
                'CsrThreadRefcountZero',
                'CsrDereferenceThread',
            },
        }

        self.symbols: Dict[str, Optional[Symbol]] = {}
        for (bin_path, symbol_names) in bin_symbol_names.items():
            try:
                exec_bin = next(self.ossi.executed_binaries(f'^{bin_path}$'))
            except StopIteration:
                # Only the kernel is mandatory; symbols of the other binaries are recorded as missing.
                if bin_path == 'c:/windows/system32/ntoskrnl.exe':
                    raise RuntimeError(f'{bin_path} not found')
                exec_bin = None

            for name in symbol_names:
                if exec_bin is None:
                    self.symbols[name] = None
                    continue

                try:
                    sym = next(exec_bin.symbols(f'^{name}$'))
                    self.symbols[name] = sym
                except StopIteration:
                    msg = f'{name} not found in {bin_path}'

                    # The analysis cannot work without the lock/unlock APIs; the other symbols only feed heuristics.
                    if name in {
                        'RtlEnterCriticalSection', 'RtlTryEnterCriticalSection', 'RtlLeaveCriticalSection'
                    }:
                        raise RuntimeError(msg)
                    else:
                        self.symbols[name] = None
                        print(f'Warning: {msg}')

        self.has_debugger = True
        try:
            # NOTE(review): probe reconstructed from context (the warning below and the later use of `step_out`);
            # confirm that this call is the intended availability check.
            self.trace.first_transition.step_over()
        except RuntimeError:
            print('Warning: the debugger interface is not available, so the script cannot determine \
function return values.\nMake sure the stack events and PC range resources are replayed for this scenario.')
            self.has_debugger = False

    def get_memory_accesses(self, from_context: Context, to_context: Context) -> Iterator[MemoryAccess]:
        """Yield the memory accesses between two contexts that are relevant to the analysis.

        Accesses without a virtual address, without an instruction, or whose instruction starts with the `lock`
        prefix byte (0xf0) are skipped. Nothing is yielded when the bounding transitions do not exist.
        """
        try:
            from_trans = from_context.transition_after()
            to_trans = to_context.transition_before()
        except IndexError:
            return

        accesses = self.trace.memory_accesses(from_transition=from_trans, to_transition=to_trans)
        for access in accesses:
            # Skip the access without virtual address
            if access.virtual_address is None:
                continue

            # Skip the access without instruction
            if access.transition.instruction is None:
                continue

            # Skip the access of `lock` prefixed instruction
            ins_bytes = access.transition.instruction.raw
            if ins_bytes[0] == 0xf0:
                continue

            yield access

    @staticmethod
    def get_lock_handle(ctxt: Context) -> int:
        """Return the critical section handle of a lock/unlock call, read from `rcx` at the call's entry context.

        NOTE(review): assumes the handle is the first argument under the Windows x64 calling convention.
        """
        return ctxt.read(x64.rcx)

    @staticmethod
    def thread_id(ctxt: Context) -> int:
        """Return the id of the thread running at `ctxt`, read as 4 bytes at `gs:0x48`.

        NOTE(review): the offset was reconstructed; it presumably targets the thread id field of the 64-bit TEB.
        """
        return ctxt.read(LogicalAddress(0x48, x64.gs), 4)

    @staticmethod
    def is_kernel_mode(ctxt: Context) -> bool:
        """Return True when the two low bits of `cs` are zero at `ctxt`."""
        return ctxt.read(x64.cs) & 0x3 == 0

# %%
# Look for a possible first execution context of a binary
# Find the lower/upper bound of contexts on which the data race detection processes
def find_begin_end_context(sco: RuntimeHelper, pid: int, binary: Optional[str],
                           begin_id: Optional[int], end_id: Optional[int]) -> Tuple[Context, Context]:
    """Compute the context range to analyze.

    The lower bound is, in order of preference: the context after transition `begin_id`, the first context found
    executing `binary` in process `pid`, or the first context of the trace. The upper bound is the context before
    transition `end_id`, or the last context of the trace.

    Raises:
        RuntimeError: if the upper bound is not strictly after the lower bound.
    """
    begin_ctxt = None
    if begin_id is not None:
        try:
            begin_trans = sco.trace.transition(begin_id)
            begin_ctxt = begin_trans.context_after()
        except IndexError:
            begin_ctxt = None

    if begin_ctxt is None:
        if binary is not None:
            for name in sco.ossi.executed_binaries(binary):
                for ctxt in sco.search_binary(name):
                    ctx_process = ctxt.ossi.process()
                    assert ctx_process is not None
                    if ctx_process.pid == pid:
                        begin_ctxt = ctxt
                        break
                # Stop at the first binary for which a matching context was found.
                if begin_ctxt is not None:
                    break

    if begin_ctxt is None:
        begin_ctxt = sco.trace.first_context

    end_ctxt = None
    if end_id is not None:
        try:
            end_trans = sco.trace.transition(end_id)
            end_ctxt = end_trans.context_before()
        except IndexError:
            end_ctxt = None

    if end_ctxt is None:
        end_ctxt = sco.trace.last_context

    if (end_ctxt <= begin_ctxt):
        raise RuntimeError("The begin transition must be smaller than the end.")

    return (begin_ctxt, end_ctxt)

# Get all execution contexts of a given process
def find_process_ranges(sco: RuntimeHelper, pid: int, first_context: Context, last_context: Context) \
        -> Iterator[Tuple[Context, Optional[Context]]]:
    """Yield the context ranges between `first_context` and `last_context` that belong to process `pid`.

    The range is cut at every `KiSwapContext` call found by the symbol search; a piece is yielded when the process
    at its lower bound has the requested pid. The upper bound of the last piece is None when the lower bound is
    already `last_context`.
    """
    if last_context <= first_context:
        return

    ctxt_low = first_context
    assert sco.symbols['KiSwapContext'] is not None
    ki_swap_context = sco.symbols['KiSwapContext']
    for ctxt in sco.search_symbol(ki_swap_context, from_context=first_context,
                                  to_context=None if last_context == sco.trace.last_context else last_context):
        ctx_process = ctxt_low.ossi.process()
        assert ctx_process is not None

        if ctx_process.pid == pid:
            yield (ctxt_low, ctxt)
        ctxt_low = ctxt

    # Handle the piece after the last context switch.
    ctx_process = ctxt_low.ossi.process()
    assert ctx_process is not None

    if ctx_process.pid == pid:
        if ctxt_low < last_context:
            yield (ctxt_low, last_context)
        else:
            # So ctxt_low == last_context, using None for the upper bound.
            # This happens only when last_context is in the process, and is also the last context of the trace.
            yield (ctxt_low, None)

# Starting from a transition, return the first transition that carries an instruction, or None if there isn't one.
def ignore_non_instructions(trans: Transition, trace: Trace) -> Optional[Transition]:
    """Walk forward from `trans` until a transition with an instruction is found.

    Returns None when the last transition of the trace is reached without finding one.
    """
    candidate = trans
    while True:
        if candidate.instruction is not None:
            return candidate
        if candidate == trace.last_transition:
            # Nothing left to inspect after the last transition.
            return None
        candidate = candidate + 1

# Extract user mode only context ranges from a context range (which may include also kernel mode ranges)
def find_usermode_ranges(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Optional[Context]) \
        -> Iterator[Tuple[Context, Context]]:
    """Yield the sub-ranges of `[ctxt_low, ctxt_high]` during which the CPU is not in kernel mode.

    The range is walked from one change of the `cs` register to the next. Nothing is yielded when `ctxt_high` is
    None or when no transition with an instruction exists from `ctxt_low` on.
    """
    if ctxt_high is None:
        return

    trans = ignore_non_instructions(ctxt_low.transition_after(), sco.trace)
    if trans is None:
        return
    ctxt_current = trans.context_before()

    while ctxt_current < ctxt_high:
        ctxt_next = ctxt_current.find_register_change(x64.cs, is_forward=True)

        if not RuntimeHelper.is_kernel_mode(ctxt_current):
            if ctxt_next is None or ctxt_next > ctxt_high:
                # The user mode range extends up to (or beyond) the upper bound: clamp it and stop.
                yield (ctxt_current, ctxt_high)
                break
            else:
                # It's safe to decrease ctxt_next by 1 because it was obtained from a forward find_register_change
                yield (ctxt_current, ctxt_next - 1)

        if ctxt_next is None:
            break

        ctxt_current = ctxt_next

# Get user mode only execution contexts of a given process
def find_process_usermode_ranges(trace: RuntimeHelper, pid: int, first_ctxt: Context, last_ctxt: Context) \
        -> Iterator[Tuple[Context, Context]]:
    """Yield every user mode context range of process `pid` between `first_ctxt` and `last_ctxt`.

    Each range of the process (from `find_process_ranges`) is refined by `find_usermode_ranges`.
    """
    for process_range in find_process_ranges(trace, pid, first_ctxt, last_ctxt):
        (range_low, range_high) = process_range
        yield from find_usermode_ranges(trace, range_low, range_high)

def build_ordered_api_calls(sco: RuntimeHelper, binary: Optional[str],
                            ctxt_low: Context, ctxt_high: Context, apis: List[str]) -> Iterator[Tuple[str, Context]]:
    """Yield `(api_name, call_context)` for the calls to `apis` between `ctxt_low` and `ctxt_high`.

    The per-API search results are merged with `collate`, keyed on the call context. When `binary` is given, only
    calls whose caller (the context just before the call context) is located in that binary are kept; the binary
    may be designated by its name, filename or path.
    """
    def gen(api):
        return (
            (api, ctxt)
            for ctxt in sco.search_symbol(sco.symbols[api], from_context=ctxt_low,
                                          to_context=None if ctxt_high == sco.trace.last_context else ctxt_high)
        )

    # Symbols that could not be resolved are stored as None and cannot be searched for.
    api_contexts = (
        gen(api)
        for api in apis if sco.symbols.get(api) is not None
    )

    if binary is None:
        for api_ctxt in collate(api_contexts, key=lambda name_ctxt: name_ctxt[1]):
            yield api_ctxt
    else:
        for (api, ctxt) in collate(api_contexts, key=lambda name_ctxt: name_ctxt[1]):
            try:
                caller_ctxt = ctxt - 1
            except IndexError:
                continue

            caller_location = caller_ctxt.ossi.location()
            if caller_location is None:
                continue

            caller_binary = caller_location.binary
            if caller_binary is None:
                continue

            if binary in [caller_binary.name, caller_binary.filename, caller_binary.path]:
                yield (api, ctxt)

def get_return_value(sco: RuntimeHelper, ctxt: Context) -> Optional[int]:
    """Return the value of `rax` once the function entered at `ctxt` has returned.

    Returns None when the debugger interface is unavailable, when there is no transition after `ctxt`, or when
    stepping out of the function fails.
    """
    if not sco.has_debugger:
        return None

    try:
        trans_after = ctxt.transition_after()
    except IndexError:
        return None

    if trans_after is None:
        return None

    trans_ret = trans_after.step_out()
    if trans_ret is None:
        return None

    ctxt_ret = trans_ret.context_after()
    return ctxt_ret.read(x64.rax)

class SynchronizationAction(Enum):
    """Kind of operation performed on a critical section by an API call."""

    # The critical section is acquired.
    LOCK = 1
    # The critical section is released.
    UNLOCK = 0

def get_locks_unlocks(sco: RuntimeHelper, binary: Optional[str], ranges: List[Tuple[Context, Context]]) \
        -> List[Tuple[SynchronizationAction, Context]]:
    """Collect the lock/unlock operations performed in the given context ranges.

    `RtlEnterCriticalSection` is a lock and `RtlLeaveCriticalSection` an unlock. `RtlTryEnterCriticalSection` is a
    lock only when its return value is nonzero; a call that fails, or whose return value cannot be determined, is
    left out of the result.
    """
    lock_unlock_apis = {
        'RtlEnterCriticalSection': SynchronizationAction.LOCK,
        'RtlLeaveCriticalSection': SynchronizationAction.UNLOCK,
        'RtlTryEnterCriticalSection': None,
    }

    critical_section_actions = []

    for ctxt_low, ctxt_high in ranges:
        for name, ctxt in build_ordered_api_calls(sco, binary, ctxt_low, ctxt_high, list(lock_unlock_apis.keys())):
            # either lock, unlock, or none
            action = lock_unlock_apis[name]

            if action is None:
                # need to check the return value of the API to get the action
                api_ret_val = get_return_value(sco, ctxt)
                if api_ret_val is None:
                    print(f'Warning: failed to get the return value, {name} is omitted')
                    continue
                else:
                    # RtlTryEnterCriticalSection: the return value is nonzero if the thread succeeds
                    # in entering the critical section
                    if api_ret_val != 0:
                        action = SynchronizationAction.LOCK
                    else:
                        # A failed attempt takes no lock; recording it with a None action would make it look like
                        # an unlock to consumers that only test for LOCK.
                        continue

            critical_section_actions.append((action, ctxt))

    return critical_section_actions

# Get locks which are effective at a transition
def get_live_locks(sco: RuntimeHelper,
                   locks_unlocks: List[Tuple[SynchronizationAction, Context]], trans: Transition) -> List[Context]:
    """Return the contexts of the lock calls that are still held just before `trans`.

    `locks_unlocks` is replayed in order up to the context before `trans`: a lock is pushed, and an unlock removes
    the most recent pushed lock having the same handle.
    """
    protection_locks: List[Context] = []
    trans_ctxt = trans.context_before()

    for (action, ctxt) in locks_unlocks:
        # Operations after the transition cannot protect it.
        if ctxt > trans_ctxt:
            return protection_locks

        if action == SynchronizationAction.LOCK:
            protection_locks.append(ctxt)
            continue

        unlock_handle = RuntimeHelper.get_lock_handle(ctxt)
        # look for the latest corresponding lock
        for (idx, lock_ctxt) in reversed(list(enumerate(protection_locks))):
            lock_handle = RuntimeHelper.get_lock_handle(lock_ctxt)
            if lock_handle == unlock_handle:
                del protection_locks[idx]
                # One unlock releases exactly one lock.
                break

    return protection_locks

# %%
AccessType = TypeVar("AccessType")

# A contiguous range of memory: `size` bytes starting at `address`.
# Declared as a dataclass because instances are built positionally, e.g. `MemorySegment(address, size)`.
@dataclass
class MemorySegment:
    # First address of the range.
    address: Union[LinearAddress, LogicalAddress]
    # Number of bytes in the range.
    size: int

# A memory segment together with the accesses touching it.
# Declared as a dataclass because instances are built positionally, e.g. `MemorySegmentAccess(segment, accesses)`.
@dataclass
class MemorySegmentAccess(Generic[AccessType]):
    # The accessed memory range.
    segment: MemorySegment
    # The accesses recorded on that range; their concrete type is chosen by the user of the class.
    accesses: List[AccessType]

# Insert a segment into a current list of segments, start at position; if the position is None then start from the
# head of the list.
# Precondition: `new_segment.segment.size > 0`
# Return the position where the segment is inserted.
# The complexity of the insertion is about O(n).
def insert_memory_segment_access(new_segment: MemorySegmentAccess, segments: List[MemorySegmentAccess],
                                 position: Optional[int]) -> int:
    """Merge `new_segment` into `segments`, splitting segments where they overlap.

    `segments` is modified in place. Parts of the new segment that overlap an existing segment get the accesses of
    both; the other parts keep only their own accesses.

    Returns:
        The index of the first entry of `segments` that holds a part of the new segment.
    """
    new_address = new_segment.segment.address
    new_size = new_segment.segment.size
    new_accesses = new_segment.accesses

    if not segments:
        segments.append(MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
        return 0

    if position is None:
        position = 0

    index = position
    first_found_index = None
    while index < len(segments):
        # Loop invariant `new_size > 0`
        #  - True at the first iteration
        #  - In the subsequent iterations, either
        #     - the invariant is always kept through cases (2), (5), (6), (7)
        #     - the loop returns directly in cases (1), (3), (4)

        address = segments[index].segment.address
        size = segments[index].segment.size
        accesses = segments[index].accesses

        if new_address < address:
            # Case 1
            #                   |--------------|
            # |--------------|
            if new_address + new_size <= address:
                segments.insert(index,
                                MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))

                return first_found_index if first_found_index is not None else index
            # Case 2
            #              |--------------|
            # |--------------|
            else:
                # Insert the different part of the new segment
                segments.insert(index,
                                MemorySegmentAccess(
                                    MemorySegment(new_address, address.offset - new_address.offset), new_accesses))

                if first_found_index is None:
                    first_found_index = index

                # The common part will be handled in the next iteration by either case (3), (4), or (5)
                # Since
                #  - `new_address + new_size > address` (else condition of case 2), so
                #  - `new_size  > address.offset - new_address.offset`
                # then invariant `new_size > 0`
                new_size -= (address.offset - new_address.offset)
                new_address = address
                index += 1
        elif new_address == address:
            # case 3
            # |--------------|
            # |--------|
            if new_address + new_size < address + size:
                # The common part comes first, followed by the remainder of the current segment.
                segments.insert(index,
                                MemorySegmentAccess(MemorySegment(new_address, new_size), accesses + new_accesses))
                segments[index + 1] = MemorySegmentAccess(
                    MemorySegment(new_address + new_size, size - new_size), accesses)

                return first_found_index if first_found_index is not None else index
            # case 4
            # |--------------|
            # |--------------|
            elif new_address + new_size == address + size:
                segments[index] = MemorySegmentAccess(MemorySegment(new_address, new_size), accesses + new_accesses)

                return first_found_index if first_found_index is not None else index
            # case 5
            # |--------------|
            # |------------------|
            # new_address + new_size > address + size
            else:
                # Current segment's accesses are augmented by new segment's accesses
                segments[index] = MemorySegmentAccess(MemorySegment(address, size), accesses + new_accesses)

                if first_found_index is None:
                    first_found_index = index

                # The different part of the new segment will be handled in the next iteration
                # Since:
                #  - `new_address == address` and `new_address + new_size > address + size`, so
                #  - `new_size > size`
                # then invariant `new_size > 0`
                new_address = address + size
                new_size = new_size - size
                index += 1
        # new_address > address
        else:
            # case 6
            # |--------------|
            #                    |-----------|
            if new_address >= address + size:
                index += 1
            # case 7
            # |--------------|
            #            |-----------|
            else:
                # Split the current segment into:
                #  - the different part
                segments[index] = MemorySegmentAccess(
                    MemorySegment(address, new_address.offset - address.offset), accesses)
                #  - the common part
                segments.insert(index + 1, MemorySegmentAccess(
                    MemorySegment(new_address, address.offset + size - new_address.offset), accesses))

                # The `new_segment` will be handled in the next iteration by either case (3), (4) or (5)
                index += 1

    # What remains of the new segment lies after every existing segment.
    segments.append(MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
    index = len(segments) - 1

    return first_found_index if first_found_index is not None else index

ThreadId = int

# Get the common memory accesses of two threads of the same process.
# Preconditions on both parameters:
#  1. The list of MemorySegmentAccess is sorted by address and size
#  2. For each MemorySegmentAccess access, `access.segment.size > 0`
# Each thread has a sorted (by address and size) memory segment access list; segments are not empty (i.e `size > 0`).
def get_common_memory_accesses(
    first_thread_segment_accesses: Tuple[ThreadId,
                                         List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]],
    second_thread_segment_accesses: Tuple[ThreadId,
                                          List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]]) \
        -> List[Tuple[MemorySegment,
                Tuple[List[Tuple[Transition, MemoryAccessOperation]],
                      List[Tuple[Transition, MemoryAccessOperation]]]]]:
    """Return the memory segments accessed by both threads, with the accesses of each thread.

    Args:
        first_thread_segment_accesses: `(thread_id, segment accesses)` of the first thread.
        second_thread_segment_accesses: `(thread_id, segment accesses)` of the second thread.

    Returns:
        A list of `(segment, (first_thread_accesses, second_thread_accesses))` where both access lists are
        non-empty.
    """
    (first_thread_id, first_segment_accesses) = first_thread_segment_accesses
    (second_thread_id, second_segment_accesses) = second_thread_segment_accesses

    # Merge these lists into a new list
    merged_accesses: List[MemorySegmentAccess[Tuple[ThreadId, Transition, MemoryAccessOperation]]] = []
    i = j = 0
    while i < len(first_segment_accesses) and j < len(second_segment_accesses):
        (first_segment, first_accesses) = (first_segment_accesses[i].segment, first_segment_accesses[i].accesses)
        (second_segment, second_accesses) = (second_segment_accesses[j].segment, second_segment_accesses[j].accesses)

        # Tag each access with the thread it comes from, so that they can be separated again after the merge.
        first_threaded_accesses = [
            (first_thread_id, trans, mem_operation) for (trans, mem_operation) in first_accesses]
        second_threaded_accesses = [
            (second_thread_id, trans, mem_operation) for (trans, mem_operation) in second_accesses]

        (first_address, first_size) = (first_segment.address, first_segment.size)
        (second_address, second_size) = (second_segment.address, second_segment.size)

        if (first_address, first_size) < (second_address, second_size):
            merged_accesses.append(MemorySegmentAccess(first_segment, first_threaded_accesses))
            i += 1
        elif (first_address, first_size) > (second_address, second_size):
            merged_accesses.append(MemorySegmentAccess(second_segment, second_threaded_accesses))
            j += 1
        else:
            # Same address and size: a single merged entry carries the accesses of both threads.
            merged_accesses.append(MemorySegmentAccess(
                first_segment, first_threaded_accesses + second_threaded_accesses))
            i += 1
            j += 1

    while i < len(first_segment_accesses):
        (first_segment, first_accesses) = (first_segment_accesses[i].segment, first_segment_accesses[i].accesses)
        first_threaded_accesses = [
            (first_thread_id, trans, mem_operation) for (trans, mem_operation) in first_accesses]
        merged_accesses.append(MemorySegmentAccess(first_segment, first_threaded_accesses))
        i += 1

    while j < len(second_segment_accesses):
        (second_segment, second_accesses) = (second_segment_accesses[j].segment, second_segment_accesses[j].accesses)
        second_threaded_accesses = [
            (second_thread_id, trans, mem_operation) for (trans, mem_operation) in second_accesses]
        merged_accesses.append(MemorySegmentAccess(second_segment, second_threaded_accesses))
        j += 1

    # The merged list needs to be segmented again to handle the case of overlapping segments coming from different
    # threads.
    # We start from an empty list `refined_accesses`, and gradually insert new segments into it.
    # The list merged_accesses is sorted already, so the position where a segment is inserted will always be larger
    # than or equal to the inserting position of the previous segment. We can use the inserting position of a segment
    # as the starting position when looking for the position to insert the next segment.
    # Though the complexity of `insert_memory_segment_access` is O(n) in the average case, the insertion in the loop
    # below happens mostly at the end of the list (with the complexity O(1)). The complexity of the loop is still O(n).
    refined_accesses: List[MemorySegmentAccess[Tuple[ThreadId, Transition, MemoryAccessOperation]]] = []
    last_inserted_index = None
    for refined_access in merged_accesses:
        last_inserted_index = insert_memory_segment_access(refined_access, refined_accesses, last_inserted_index)

    common_accesses = []
    for refined_access in refined_accesses:
        first_thread_rws = []
        second_thread_rws = []

        for (thread_id, transition, operation) in refined_access.accesses:
            if thread_id == first_thread_id:
                first_thread_rws.append((transition, operation))
            else:
                second_thread_rws.append((transition, operation))

        # Only the segments touched by both threads are of interest.
        if first_thread_rws and second_thread_rws:
            common_accesses.append((refined_access.segment, (first_thread_rws, second_thread_rws)))

    return common_accesses

# Return True if the transition should be excluded by the caller for data race detection.
def transition_excluded_by_heuristics(sco: RuntimeHelper, trans: Transition, blacklist: Set[int]) -> bool:
    """Tell whether the memory accesses of `trans` should be ignored.

    A transition is excluded when its id is in `blacklist`, when it executes one of the critical section APIs
    themselves, or when one of its stack frames starts in a thread creation/shutdown routine or a critical section
    API. Only code located in ntdll.dll, csrsrv.dll and basesrv.dll is considered by these heuristics.
    """
    if trans.id in blacklist:
        return True

    # Memory accesses of the first transition are considered non-protected
    try:
        trans_ctxt = trans.context_before()
    except IndexError:
        return False

    trans_location = trans_ctxt.ossi.location()
    if trans_location is None:
        return False

    if trans_location.binary is None:
        return False

    trans_binary_path = trans_location.binary.path
    process_thread_core_dlls = [
        'c:/windows/system32/ntdll.dll', 'c:/windows/system32/csrsrv.dll', 'c:/windows/system32/basesrv.dll'
    ]
    if trans_binary_path not in process_thread_core_dlls:
        return False

    if trans_location.symbol is None:
        return False

    symbols = sco.symbols

    # Accesses of the synchronization APIs themselves are not counted.
    thread_sync_apis = [
        symbols[name] for name in [
            'RtlInitializeCriticalSection', 'RtlInitializeCriticalSectionEx',
            'RtlEnterCriticalSection', 'RtlTryEnterCriticalSection', 'RtlLeaveCriticalSection'
        ]
        if symbols[name] is not None
    ]
    if trans_location.symbol in thread_sync_apis:
        return True

    trans_frames = trans_ctxt.stack.frames()

    # Accesses of thread create/shutdown are not counted
    thread_create_apis = [
        api for api in [
            symbols['LdrpInitializeThread'], symbols['BaseSrvCreateThread'], symbols['CsrCreateThread']
        ]
        if api is not None
    ]
    thread_shutdown_apis = [
        api for api in [
            symbols['CsrDereferenceThread'], symbols['CsrThreadRefcountZero'], symbols['LdrShutdownThread']
        ]
        if api is not None
    ]
    for frame in trans_frames:
        frame_ctxt = frame.first_context
        frame_location = frame_ctxt.ossi.location()

        # Frames that cannot be attributed to a symbol of the core DLLs are skipped.
        if frame_location is None:
            continue

        if frame_location.binary is None:
            continue

        if frame_location.binary.path not in process_thread_core_dlls:
            continue

        if frame_location.symbol is None:
            continue

        if (
            frame_location.symbol in thread_create_apis
            or frame_location.symbol in thread_shutdown_apis
            or frame_location.symbol in thread_sync_apis
        ):
            return True

    return False

def get_threads_locks_unlocks(trace: RuntimeHelper, binary: Optional[str],
                              process_ranges: List[Tuple[Context, Context]]) \
        -> Dict[ThreadId, List[Tuple[SynchronizationAction, Context]]]:
    """Group the context ranges by thread, then collect the lock/unlock operations of each thread.

    The thread owning a range is the one running at the lower bound of the range.
    """
    ranges_by_thread: Dict[int, List[Tuple[Context, Context]]] = {}
    for (range_low, range_high) in process_ranges:
        owner = RuntimeHelper.thread_id(range_low)
        ranges_by_thread.setdefault(owner, []).append((range_low, range_high))

    locks_unlocks_by_thread: Dict[ThreadId, List[Tuple[SynchronizationAction, Context]]] = {
        owner: get_locks_unlocks(trace, binary, owned_ranges)
        for owner, owned_ranges in ranges_by_thread.items()
    }
    return locks_unlocks_by_thread

# Get all segmented memory accesses of each thread given a list of context ranges.
# Return a map: thread_id -> list(((mem_addr, mem_size), list((transition, mem_operation))))
def get_threads_segmented_memory_accesses(trace: RuntimeHelper, process_ranges: List[Tuple[Context, Context]]) \
        -> Dict[ThreadId, List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]]:
    """Build, for each thread, its list of non-overlapping memory segments with the accesses made to them.

    The thread owning a range is the one running at the lower bound of the range.
    """
    sorted_threads_segmented_memory_accesses: Dict[
        ThreadId,
        List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]
    ] = {}
    for (ctxt_lo, ctxt_hi) in process_ranges:
        tid = RuntimeHelper.thread_id(ctxt_lo)
        if tid not in sorted_threads_segmented_memory_accesses:
            sorted_threads_segmented_memory_accesses[tid] = []

        for access in trace.get_memory_accesses(ctxt_lo, ctxt_hi):
            access_virtual_address = access.virtual_address
            if access_virtual_address is None:
                # DMA access
                continue

            sorted_threads_segmented_memory_accesses[tid].append(MemorySegmentAccess(
                MemorySegment(access_virtual_address, access.size),
                [(access.transition, access.operation)]))

    # The memory segment accesses of each thread are sorted by address and size of segments, that will help to improve
    # the performance of looking for common accesses of two threads.
    # The complexity of sorting is O(n * logn)
    # Note that segments can be still overlapped.
    for tid in sorted_threads_segmented_memory_accesses.keys():
        sorted_threads_segmented_memory_accesses[tid].sort(key=lambda x: (x.segment.address, x.segment.size))

    threads_segmented_memory_accesses: Dict[
        ThreadId,
        List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]
    ] = {}
    # The non-overlapped memory segment access list of each thread is built by:
    #   - start from an empty list
    #   - gradually insert segment into the list using `insert_memory_segment_access`
    # Since the original segment lists are sorted, the complexity of the build is O(n). The total complexity of
    # constructing the non-overlapped memory segment accesses of each thread is O(n * logn) then.
    for tid in sorted_threads_segmented_memory_accesses.keys():
        threads_segmented_memory_accesses[tid] = []
        last_mem_acc_index = None
        for seg_mem_acc in sorted_threads_segmented_memory_accesses[tid]:
            last_mem_acc_index = insert_memory_segment_access(
                seg_mem_acc, threads_segmented_memory_accesses[tid], last_mem_acc_index
            )

    return threads_segmented_memory_accesses

TransitionId = int
InstructionPointer = int
LockHandle = int

def detect_data_race(trace: RuntimeHelper, pid: int, binary: Optional[str],
                     begin_transition_id: Optional[TransitionId], end_transition_id: Optional[TransitionId]):
    # Look for possible data races between the threads of process `pid` and print the findings.
    #
    # The memory accesses and the lock/unlock calls of each thread are collected over the user-mode ranges of the
    # process, then the accesses that two different threads make to a common memory segment are compared pairwise.
    #
    # Parameters:
    #   - trace: helper object giving access to the analyzed trace
    #   - pid: identifier of the process to analyze
    #   - binary: optional binary, forwarded to `find_begin_end_context` and `get_threads_locks_unlocks`
    #   - begin_transition_id, end_transition_id: optional bounds, forwarded to `find_begin_end_context`
    #
    # Besides its parameters, the function reads the names `hide_synchronized_accesses` and
    # `suppress_unknown_primitives`, which are not defined here (presumably set in the Parameters cell - confirm).
    #
    # NOTE(review): several statements below look truncated in this copy (e.g. `if in cached_pcs:`, `cached_pcs[]`,
    # `if` headers without a body, unbalanced parentheses/braces, `transition #{}` in the messages). Confirm against
    # the original notebook before running.

    # Format lock handles as a comma-separated list of hexadecimal numbers.
    def handlers_to_string(handlers: Iterable[int]):
        msgs = [f'{handle:#x}' for handle in handlers]
        message = ', '.join(msgs)
        return message

    # Return the contexts given by `get_live_locks` for `trans`, together with the set of lock handles extracted
    # from these contexts by `RuntimeHelper.get_lock_handle`.
    def get_live_lock_handles(
        sco: RuntimeHelper, locks_unlocks: List[Tuple[SynchronizationAction, Context]], trans: Transition
    ) -> Tuple[List[Context], Set[LockHandle]]:
        locks = get_live_locks(sco, locks_unlocks, trans)
        return (locks, {RuntimeHelper.get_lock_handle(ctxt) for ctxt in locks})

    # Bound the analysis with `find_begin_end_context`, then keep the ranges produced by
    # `find_process_usermode_ranges` for `pid` between these two contexts.
    (first_context, last_context) = find_begin_end_context(trace, pid, binary, begin_transition_id, end_transition_id)
    process_ranges = list(find_process_usermode_ranges(trace, pid, first_context, last_context))

    # Per-thread inputs of the comparison, both indexed by thread id below.
    threads_memory_accesses = get_threads_segmented_memory_accesses(trace, process_ranges)
    threads_locks_unlocks = get_threads_locks_unlocks(trace, binary, process_ranges)

    # Thread ids to skip as `second_tid` in the pairwise loop below.
    # NOTE(review): nothing visible in this function adds to this set; presumably it is meant to hold the threads
    # already used as `first_tid` so that each pair is compared once - confirm.
    thread_ids: Set[ThreadId] = set()

    # map from a transition (of a given thread) to a tuple where
    #  - the first element is a list of contexts of calls (e.g. RtlEnterCriticalSection) to lock
    #  - the second element is the set of handles used by these calls
    cached_live_locks_handles: Dict[Tuple[ThreadId, TransitionId], Tuple[List[Context], Set[LockHandle]]] = {}

    # map from a transition id to the instruction pointer cached for it
    cached_pcs: Dict[TransitionId, InstructionPointer] = {}

    # Set of transition ids handed to `transition_excluded_by_heuristics` at each call
    # (presumably used by that helper as a cache of excluded transitions - confirm).
    filtered_trans_ids: Set[TransitionId] = set()

    # Compare the accesses of every pair of threads.
    for first_tid, first_thread_accesses in threads_memory_accesses.items():
        for second_tid, second_thread_accesses in threads_memory_accesses.items():
            if second_tid in thread_ids:

            # Memory segments accessed by both threads, each with the accesses of the first and of the second thread.
            common_accesses = get_common_memory_accesses(
                (first_tid, first_thread_accesses), (second_tid, second_thread_accesses)

            for (segment, (first_accesses, second_accesses)) in common_accesses:
                (segment_address, segment_size) = (segment.address, segment.size)

                # Pairs of instruction pointers already handled for this segment, used to skip duplicates below.
                race_pcs: Set[Tuple[InstructionPointer, InstructionPointer]] = set()

                for (first_transition, first_rw) in first_accesses:
                    # Instruction pointer of the first access, taken from `cached_pcs` when available.
                    if in cached_pcs:
                        first_pc = cached_pcs[]
                        first_pc = first_transition.context_before().read(

                    # Live locks of the first access are computed once, then looked up in the cache.
                    if (first_tid, not in cached_live_locks_handles:
                        critical_section_locks_unlocks = threads_locks_unlocks[first_tid]

                        cached_live_locks_handles[(first_tid,] = get_live_lock_handles(
                            trace, critical_section_locks_unlocks, first_transition)

                    for (second_transition, second_rw) in second_accesses:
                        # Instruction pointer of the second access, taken from `cached_pcs` when available.
                        if in cached_pcs:
                            second_pc = cached_pcs[]
                            second_pc = second_transition.context_before().read(
                            cached_pcs[] = second_pc

                        # the execution of an instruction is considered atomic
                        if first_pc == second_pc:

                        # Live locks of the second access are computed once, then looked up in the cache.
                        if (second_tid, not in cached_live_locks_handles:
                            critical_section_locks_unlocks = threads_locks_unlocks[second_tid]

                            cached_live_locks_handles[(second_tid,] = get_live_lock_handles(
                                trace, critical_section_locks_unlocks, second_transition)

                        # Skip if both are the same operation (i.e. read/read or write/write)
                        if first_rw == second_rw:

                        first_is_write = first_rw == MemoryAccessOperation.Write
                        second_is_write = second_rw == MemoryAccessOperation.Write

                        # Name the two accesses "before"/"after" according to the order of their transitions;
                        # these names are used by the messages below.
                        if first_transition < second_transition:
                            before_mem_opr_str = 'write' if first_is_write else 'read'
                            after_mem_opr_str = 'write' if second_is_write else 'read'

                            before_trans = first_transition
                            after_trans = second_transition

                            before_tid = first_tid
                            after_tid = second_tid

                            before_mem_opr_str = 'write' if second_is_write else 'read'
                            after_mem_opr_str = 'write' if first_is_write else 'read'

                            before_trans = second_transition
                            after_trans = first_transition

                            before_tid = second_tid
                            after_tid = first_tid

                        # Duplicated
                        if (first_pc, second_pc) in race_pcs:

                        race_pcs.add((first_pc, second_pc))

                        # ===== No data race

                        message = f'No data race during {after_mem_opr_str} at \
[{segment_address.offset:#x}, {segment_size}] (transition #{}) by thread {after_tid}\n\
with a previous {before_mem_opr_str} (transition #{}) by thread {before_tid}'

                        # There is a common critical section synchronizing the accesses
                        # (the lock handle sets come from the second element of the cached tuples)
                        live_critical_sections: Dict[ThreadId, Set[LockHandle]] = {
                            first_tid: cached_live_locks_handles[(first_tid,][1],
                            second_tid: cached_live_locks_handles[(second_tid,][1]

                        common_critical_sections = live_critical_sections[first_tid].intersection(
                        if common_critical_sections:
                            if not hide_synchronized_accesses:
                                handlers_str = handlers_to_string(common_critical_sections)
                                print(f'{message}: synchronized by critical section(s) ({handlers_str}).\n')

                        # Accesses are excluded since they are in thread create/shutdown
                        # (both transitions must be excluded by the heuristics)
                        are_excluded = transition_excluded_by_heuristics(
                            trace, first_transition, filtered_trans_ids
                        ) and transition_excluded_by_heuristics(trace, second_transition, filtered_trans_ids)
                        if are_excluded:
                            if not hide_synchronized_accesses:
                                print(f'{message}: shared accesses in create/shutdown threads')

                        # ===== Data race
                        # The un-synchronized accesses are reported as a possible data race, in two cases:
                        #   1. at least one of them is protected by one or several locks
                        #   2. neither of them is protected by any lock

                        message = f'Possible data race during {after_mem_opr_str} at \
[{segment_address.offset:#x}, {segment_size}] (transition #{}) by thread {after_tid}\n\
with a previous {before_mem_opr_str} (transition #{}) by thread {before_tid}'

                        if live_critical_sections[first_tid] or live_critical_sections[second_tid]:
                            # Print the handles held by each of the two threads that holds at least one.
                            for tid in {first_tid, second_tid}:
                                if live_critical_sections[tid]:
                                    handlers_str = handlers_to_string(live_critical_sections[tid])
                                    print(f'\tCritical section handles(s) used by {tid}: {handlers_str}\n')
                        # Neither access holds a critical section: reported unless `suppress_unknown_primitives` is set.
                        elif not suppress_unknown_primitives:
                            print(f'{message}: no critical section used.\n')

# %%
# Build the trace helper from `host`/`port` and run the detection with the notebook parameters
# (`host`, `port`, `pid`, `binary`, `begin_trans_id`, `end_trans_id` are expected to be defined by an earlier cell,
# presumably the Parameters cell).
trace = RuntimeHelper(host, port)
detect_data_race(trace, pid, binary, begin_trans_id, end_trans_id)