Detect data race
This notebook checks for possible data races that may occur when using critical sections as the synchronization primitive.
Prerequisites
- Supported versions:
- REVEN 2.9+
- This notebook should be run in a jupyter notebook server equipped with a REVEN 2 Python kernel.
REVEN comes with a Jupyter notebook server accessible with the `Open Python` button in the `Analyze` page of any scenario.
- The following resources are needed for the analyzed scenario:
- Trace
- OSSI
- Memory History
- Backtrace (Stack Events included)
- Fast Search
Perimeter:
- Windows 10 64-bit
Limits
- Only the following critical section APIs are supported as lock/unlock operations:
  - `RtlEnterCriticalSection`, `RtlTryEnterCriticalSection`
  - `RtlLeaveCriticalSection`
Running
Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.
Source
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: reven
# language: python
# name: reven-python3
# ---
# %% [markdown]
# # Detect data race
# This notebook checks for possible data races that may occur when using critical sections as the synchronization
# primitive.
#
# ## Prerequisites
# - Supported versions:
# - REVEN 2.9+
# - This notebook should be run in a jupyter notebook server equipped with a REVEN 2 Python kernel.
# REVEN comes with a Jupyter notebook server accessible with the `Open Python` button in the `Analyze`
# page of any scenario.
# - The following resources are needed for the analyzed scenario:
# - Trace
# - OSSI
# - Memory History
# - Backtrace (Stack Events included)
# - Fast Search
#
# ## Perimeter:
# - Windows 10 64-bit
#
# ## Limits
# - Only support the following critical section API(s) as the lock/unlock operations.
# - `RtlEnterCriticalSection`, `RtlTryEnterCriticalSection`
# - `RtlLeaveCriticalSection`
#
#
# ## Running
# Fill out the parameters in the [Parameters cell](#Parameters) below, then run all the cells of this notebook.
# %%
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
# Reven specific
import reven2
from reven2.address import LogicalAddress
from reven2.arch import x64
from reven2.memhist import MemoryAccess, MemoryAccessOperation
from reven2.ossi.ossi import Symbol
from reven2.trace import Context, Trace, Transition
from reven2.util import collate
# %% [markdown]
# # Parameters
# %%
# Host and port of the running scenario.
host = '127.0.0.1'
port = 41309

# The PID and the name of the binary of interest (optional): if the binary name is given (i.e. not None), then only
# locks and unlocks called directly from the binary are counted.
pid = 2460
binary = None

# Optional ids of the transitions bounding the analysis. When left to None, the analysis starts at the first context
# where the binary of interest runs in the process (or at the first context of the trace if no binary is given or
# none is found), and stops at the last context of the trace.
begin_trans_id = None
end_trans_id = None

# Do not show common memory accesses (from different threads) which are synchronized (i.e. mutually excluded) by some
# critical section(s).
hide_synchronized_accesses = True

# Do not show possible data races where neither of the two accesses is performed while holding a critical section.
suppress_unknown_primitives = False
# %%
# Helper class which wraps REVEN's runtime objects and provides methods that help to get information about calls to
# RtlEnterCriticalSection and RtlLeaveCriticalSection
class RuntimeHelper:
    """Wrap the objects of a REVEN server and resolve the OSSI symbols needed by the detection.

    Attributes set by the constructor:
        trace, ossi: the trace and OSSI objects of the connected server.
        search_symbol, search_binary: shortcuts to `trace.search.symbol` and `trace.search.binary`.
        symbols: map from a symbol name to the resolved symbol, or None when it could not be resolved.
        has_debugger: False when `step_over` is not usable on the first transition of the trace.
    """

    # Legacy x86 instruction prefixes (lock/rep, segment overrides, operand-size and address-size overrides).
    # They may appear in any order in front of an instruction.
    _LEGACY_PREFIXES = frozenset((0xf0, 0xf2, 0xf3, 0x2e, 0x36, 0x3e, 0x26, 0x64, 0x65, 0x66, 0x67))

    def __init__(self, host: str, port: int):
        """Connect to the REVEN server at `host:port` and resolve the symbols of interest.

        Raises:
            RuntimeError: if the connection fails, if ntoskrnl.exe is not among the executed binaries, or if one
                of KiSwapContext, RtlEnterCriticalSection, RtlTryEnterCriticalSection, RtlLeaveCriticalSection
                cannot be found in a binary that was found.
        """
        try:
            server = reven2.RevenServer(host, port)
        except RuntimeError as err:
            raise RuntimeError(f'Cannot connect to the scenario at {host}:{port}') from err

        self.trace = server.trace
        self.ossi = server.ossi
        self.search_symbol = self.trace.search.symbol
        self.search_binary = self.trace.search.binary

        # basic OSSI
        kernel_path = 'c:/windows/system32/ntoskrnl.exe'
        bin_symbol_names = {
            kernel_path: {
                'KiSwapContext'
            },
            'c:/windows/system32/ntdll.dll': {
                # critical section
                'RtlInitializeCriticalSection',
                'RtlInitializeCriticalSectionEx',
                'RtlInitializeCriticalSectionAndSpinCount',
                'RtlEnterCriticalSection',
                'RtlTryEnterCriticalSection',
                'RtlLeaveCriticalSection',
                # initialize/shutdown thread
                'LdrpInitializeThread',
                'LdrShutdownThread'
            },
            'c:/windows/system32/basesrv.dll': {
                'BaseSrvCreateThread'
            },
            'c:/windows/system32/csrsrv.dll': {
                'CsrCreateThread',
                'CsrThreadRefcountZero',
                'CsrDereferenceThread'
            }
        }
        # Symbols whose absence (in a binary that was found) is fatal.
        required_symbol_names = {
            'KiSwapContext',
            'RtlEnterCriticalSection', 'RtlTryEnterCriticalSection', 'RtlLeaveCriticalSection'
        }

        self.symbols: Dict[str, Optional[Symbol]] = {}
        for (binary_path, symbol_names) in bin_symbol_names.items():
            exec_bin = next(self.ossi.executed_binaries(f'^{binary_path}$'), None)
            if exec_bin is None and binary_path == kernel_path:
                raise RuntimeError(f'{binary_path} not found')

            for name in symbol_names:
                if exec_bin is None:
                    # The binary was never executed in the trace: none of its symbols can be resolved.
                    self.symbols[name] = None
                    continue

                sym = next(exec_bin.symbols(f'^{name}$'), None)
                if sym is not None:
                    self.symbols[name] = sym
                    continue

                msg = f'{name} not found in {binary_path}'
                if name in required_symbol_names:
                    raise RuntimeError(msg)
                self.symbols[name] = None
                print(f'Warning: {msg}')

        # Probe the debugger interface once; it is needed later to read function return values.
        self.has_debugger = True
        try:
            self.trace.first_transition.step_over()
        except RuntimeError:
            print('Warning: the debugger interface is not available, so the script cannot determine '
                  'function return values.\nMake sure the stack events and PC range resources are replayed '
                  'for this scenario.')
            self.has_debugger = False

    @staticmethod
    def _has_lock_prefix(raw) -> bool:
        """Return True if the raw instruction bytes carry a `lock` (0xf0) prefix.

        Only the leading run of legacy prefix bytes is inspected, so a 0xf0 byte appearing later in the encoding
        (opcode, ModRM, immediate, ...) is not mistaken for a prefix. Empty input yields False.
        """
        for byte in raw:
            if byte == 0xf0:
                return True
            if byte not in RuntimeHelper._LEGACY_PREFIXES:
                return False
        return False

    def get_memory_accesses(self, from_context: Context, to_context: Context) -> Iterator[MemoryAccess]:
        """Yield the memory accesses of the transitions between the two contexts.

        Accesses without a virtual address, accesses of transitions without an instruction, and accesses made by
        `lock`-prefixed instructions are skipped. Nothing is yielded when the bounding transitions do not exist.
        """
        try:
            from_trans = from_context.transition_after()
            to_trans = to_context.transition_before()
        except IndexError:
            return

        accesses = self.trace.memory_accesses(from_transition=from_trans, to_transition=to_trans)
        for access in accesses:
            # Skip the access without virtual address
            if access.virtual_address is None:
                continue

            # Skip the access without instruction
            instruction = access.transition.instruction
            if instruction is None:
                continue

            # Skip the access of `lock` prefixed instruction
            if RuntimeHelper._has_lock_prefix(instruction.raw):
                continue

            yield access

    @staticmethod
    def get_lock_handle(ctxt: Context) -> int:
        """Read `rcx` at the context; at the entry of the critical section APIs this is the first argument."""
        return ctxt.read(x64.rcx)

    @staticmethod
    def thread_id(ctxt: Context) -> int:
        """Read the 4-byte value at `gs:0x48`, used as the thread id of the context."""
        return ctxt.read(LogicalAddress(0x48, x64.gs), 4)

    @staticmethod
    def is_kernel_mode(ctxt: Context) -> bool:
        """Return True when the two low bits of `cs` are zero at the context."""
        return ctxt.read(x64.cs) & 0x3 == 0
# %%
# Look for a possible first execution context of a binary
# Find the lower/upper bounds of the contexts over which the data race detection runs
def find_begin_end_context(sco: RuntimeHelper, pid: int, binary: Optional[str],
                           begin_id: Optional[int], end_id: Optional[int]) -> Tuple[Context, Context]:
    """Compute the pair of contexts delimiting the analysis.

    The lower bound is, in order of preference: the context after transition `begin_id`, the first context where
    `binary` is executed by process `pid`, the first context of the trace. The upper bound is the context before
    transition `end_id`, or the last context of the trace.

    Raises:
        RuntimeError: if the upper bound is not strictly after the lower bound.
    """
    lower = None
    if begin_id is not None:
        try:
            lower = sco.trace.transition(begin_id).context_after()
        except IndexError:
            lower = None

    if lower is None and binary is not None:
        # First context, over all binaries matching `binary`, which belongs to the requested process.
        in_process = (
            ctxt
            for executed_binary in sco.ossi.executed_binaries(binary)
            for ctxt in sco.search_binary(executed_binary)
            if ctxt.ossi.process().pid == pid
        )
        lower = next(in_process, None)

    if lower is None:
        lower = sco.trace.first_context

    upper = None
    if end_id is not None:
        try:
            upper = sco.trace.transition(end_id).context_before()
        except IndexError:
            upper = None

    if upper is None:
        upper = sco.trace.last_context

    if upper <= lower:
        raise RuntimeError("The begin transition must be smaller than the end.")

    return (lower, upper)
# Get all execution contexts of a given process
def find_process_ranges(sco: RuntimeHelper, pid: int, first_context: Context, last_context: Context) \
        -> Iterator[Tuple[Context, Optional[Context]]]:
    """Yield the context ranges, between `first_context` and `last_context`, executed by process `pid`.

    The interval is cut at every call to `KiSwapContext`; a piece is yielded when the process at its lower bound
    has the requested pid. The upper bound of the last yielded range is None when the lower bound is not strictly
    before `last_context`. Nothing is yielded when `last_context <= first_context`.
    """
    if last_context <= first_context:
        # This is a generator function: a bare return is enough to produce an empty iteration.
        return

    ctxt_low = first_context
    ki_swap_context = sco.symbols['KiSwapContext']
    # The search is left unbounded (None) when the upper bound is the last context of the trace.
    search_end = None if last_context == sco.trace.last_context else last_context
    for ctxt in sco.search_symbol(ki_swap_context, from_context=first_context, to_context=search_end):
        if ctxt_low.ossi.process().pid == pid:
            yield (ctxt_low, ctxt)
        ctxt_low = ctxt

    if ctxt_low.ossi.process().pid == pid:
        if ctxt_low < last_context:
            yield (ctxt_low, last_context)
        else:
            # So ctxt_low == last_context, using None for the upper bound.
            # This happens only when last_context is in the process, and is also the last context of the trace.
            yield (ctxt_low, None)
# Start from a transition, return the first transition that is not a non-instruction, or None if there isn't one.
def ignore_non_instructions(trans: Transition, trace: Trace) -> Optional[Transition]:
    """Return the first transition at or after `trans` which has an instruction, or None if the trace ends first."""
    candidate = trans
    while True:
        if candidate.instruction is not None:
            return candidate
        if candidate == trace.last_transition:
            # Reached the end of the trace without meeting any instruction.
            return None
        candidate = candidate + 1
# Extract user mode only context ranges from a context range (which may include also kernel mode ranges)
def find_usermode_ranges(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Optional[Context]) \
        -> Iterator[Tuple[Context, Context]]:
    """Yield the sub-ranges of `[ctxt_low, ctxt_high]` which are executed in user mode.

    The range is walked from one change of the `cs` register to the next; a piece is yielded when the context at
    its beginning is not in kernel mode. Nothing is yielded when `ctxt_high` is None or when no instruction is
    found from `ctxt_low` on.
    """
    if ctxt_high is None:
        return

    start = ignore_non_instructions(ctxt_low.transition_after(), sco.trace)
    if start is None:
        return

    cursor = start.context_before()
    while cursor < ctxt_high:
        cs_change = cursor.find_register_change(x64.cs, is_forward=True)
        in_user_mode = not RuntimeHelper.is_kernel_mode(cursor)

        if in_user_mode:
            if cs_change is None or cs_change > ctxt_high:
                # The user mode piece runs up to (at least) the end of the range.
                yield (cursor, ctxt_high)
                return
            # Decreasing by 1 is safe since `cs_change` comes from a forward find_register_change.
            yield (cursor, cs_change - 1)

        if cs_change is None:
            return

        cursor = cs_change
# Get user mode only execution contexts of a given process
def find_process_usermode_ranges(trace: RuntimeHelper, pid: int, first_ctxt: Context, last_ctxt: Context) \
        -> Iterator[Tuple[Context, Context]]:
    """Yield the user mode context ranges of process `pid` between `first_ctxt` and `last_ctxt`."""
    for (process_low, process_high) in find_process_ranges(trace, pid, first_ctxt, last_ctxt):
        yield from find_usermode_ranges(trace, process_low, process_high)
def build_ordered_api_calls(sco: RuntimeHelper, binary: Optional[str],
                            ctxt_low: Context, ctxt_high: Context, apis: List[str]) -> Iterator[Tuple[str, Context]]:
    """Yield `(api_name, context)` for the calls to `apis` found in `[ctxt_low, ctxt_high]`, ordered by context.

    APIs whose symbol is unknown or was not resolved (None in `sco.symbols`) are ignored. When `binary` is given,
    only the calls whose previous context is located in that binary (compared with its name, filename and path)
    are kept.
    """
    # The search is left unbounded (None) when the upper bound is the last context of the trace.
    search_end = None if ctxt_high == sco.trace.last_context else ctxt_high

    def gen(api):
        return (
            (api, ctxt)
            for ctxt in sco.search_symbol(sco.symbols[api], from_context=ctxt_low, to_context=search_end)
        )

    api_contexts = (
        gen(api)
        # A symbol may be recorded as None when it (or its binary) was not found: there is nothing to search then.
        for api in apis if sco.symbols.get(api) is not None
    )
    ordered_calls = collate(api_contexts, key=lambda name_ctxt: name_ctxt[1])

    if binary is None:
        yield from ordered_calls
        return

    for (api, ctxt) in ordered_calls:
        try:
            caller_ctxt = ctxt - 1
        except IndexError:
            continue

        caller_location = caller_ctxt.ossi.location()
        if caller_location is None:
            continue

        caller_binary = caller_location.binary
        if caller_binary is None:
            continue

        if binary in [caller_binary.name, caller_binary.filename, caller_binary.path]:
            yield (api, ctxt)
def get_return_value(sco: RuntimeHelper, ctxt: Context) -> Optional[int]:
    """Return the value of `rax` after stepping out of the function entered at `ctxt`.

    None is returned when the debugger interface is not available, when there is no transition after `ctxt`, or
    when stepping out gives no transition.
    """
    returned = None
    if sco.has_debugger:
        try:
            entry_trans = ctxt.transition_after()
        except IndexError:
            entry_trans = None

        if entry_trans is not None:
            exit_trans = entry_trans.step_out()
            if exit_trans is not None:
                returned = exit_trans.context_after().read(x64.rax)
    return returned
class SynchronizationAction(Enum):
    """Kind of operation performed on a critical section by an API call."""

    # The critical section is acquired.
    LOCK = 1
    # The critical section is released.
    UNLOCK = 0
def get_locks_unlocks(sco: RuntimeHelper, binary: Optional[str], ranges: List[Tuple[Context, Context]]) \
        -> List[Tuple[SynchronizationAction, Context]]:
    """Collect the lock/unlock actions performed by critical section API calls inside the context ranges.

    `RtlEnterCriticalSection` counts as a lock and `RtlLeaveCriticalSection` as an unlock.
    `RtlTryEnterCriticalSection` counts as a lock only when its return value is known and nonzero; a warning is
    printed and the call is dropped when the return value cannot be obtained.
    """
    # None marks an API whose action depends on its return value.
    action_of_api = {
        'RtlEnterCriticalSection': SynchronizationAction.LOCK,
        'RtlLeaveCriticalSection': SynchronizationAction.UNLOCK,
        'RtlTryEnterCriticalSection': None,
    }

    actions = []
    for range_low, range_high in ranges:
        api_calls = build_ordered_api_calls(sco, binary, range_low, range_high, list(action_of_api))
        for name, ctxt in api_calls:
            action = action_of_api[name]
            if action is None:
                returned = get_return_value(sco, ctxt)
                if returned is None:
                    print(f'Warning: failed to get the return value, {name} is omitted')
                    continue
                if returned == 0:
                    # The thread did not enter the critical section.
                    continue
                # RtlTryEnterCriticalSection: a nonzero return value means the critical section was entered.
                action = SynchronizationAction.LOCK

            actions.append((action, ctxt))
    return actions
# Get locks which are effective at a transition
def get_live_locks(sco: RuntimeHelper,
                   locks_unlocks: List[Tuple[SynchronizationAction, Context]], trans: Transition) -> List[Context]:
    """Return the contexts of the lock calls which are not yet unlocked when `trans` is about to execute.

    The actions are replayed in list order up to the context before `trans`: a lock is pushed, an unlock removes
    the latest pushed lock having the same handle (if any).
    """
    held: List[Context] = []
    limit = trans.context_before()

    for (action, action_ctxt) in locks_unlocks:
        if action_ctxt > limit:
            break

        if action == SynchronizationAction.LOCK:
            held.append(action_ctxt)
            continue

        released_handle = RuntimeHelper.get_lock_handle(action_ctxt)
        # look for the latest corresponding lock
        for position in range(len(held) - 1, -1, -1):
            if RuntimeHelper.get_lock_handle(held[position]) == released_handle:
                del held[position]
                break

    return held
# %%
# Type of the items stored in the access list of a `MemorySegmentAccess`.
AccessType = TypeVar("AccessType")
@dataclass
class MemorySegment:
    """A contiguous memory range described by its start address and its size."""
    # Address of the beginning of the range.
    address: LogicalAddress
    # Extent of the range, in the unit of `MemoryAccess.size` (presumably bytes - TODO confirm).
    size: int
@dataclass
class MemorySegmentAccess(Generic[AccessType]):
    """A memory segment together with the accesses recorded on it."""
    # The accessed memory range.
    segment: MemorySegment
    # The accesses on the segment; the item type is chosen by the user of the class.
    accesses: List[AccessType]
# Insert a segment into a current list of segments, start at position; if the position is None then start from the
# head of the list.
# Precondition: `new_segment.segment.size > 0`
# Return the position where the segment is inserted.
# The complexity of the insertion is about O(n).
def insert_memory_segment_access(new_segment: MemorySegmentAccess, segments: List[MemorySegmentAccess],
                                 position: Optional[int]) -> int:
    """Insert `new_segment` into `segments` (modified in place), splitting segments wherever they overlap.

    The scan starts at index `position` (0 when None). Each existing segment met by the scan is compared with the
    part of the new segment which remains to be inserted:
    - the parts of the new segment which overlap no existing segment are inserted with the new accesses only;
    - the parts shared with an existing segment get the existing accesses followed by the new ones;
    - the parts of an existing segment not covered by the new one keep their accesses.

    Precondition: `new_segment.segment.size > 0`.

    Returns:
        The index of the first entry created or augmented with the new accesses during the scan.
    """
    new_address = new_segment.segment.address
    new_size = new_segment.segment.size
    new_accesses = new_segment.accesses

    # Empty list: the new segment becomes its only element.
    if not segments:
        segments.append(MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
        return 0

    if position is None:
        position = 0

    index = position
    # Index of the first entry which received the new accesses, once known.
    first_found_index = None
    while index < len(segments):
        # Loop invariant `new_size > 0`
        #  - True at the first iteration
        #  - In the subsequent iterations, either
        #     - the invariant is always kept through cases (2), (5), (6), (7)
        #     - the loop returns directly in cases (1), (3), (4)

        # In the diagrams below, `cur` is the existing segment at `index` and `new` is the remaining part of the
        # segment being inserted.
        address = segments[index].segment.address
        size = segments[index].segment.size
        accesses = segments[index].accesses

        if new_address < address:
            # Case 1: the new segment ends before (or exactly where) the current one starts
            # cur:                    |--------------|
            # new:  |--------------|
            if new_address + new_size <= address:
                segments.insert(index,
                                MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
                return first_found_index if first_found_index is not None else index
            # Case 2: the new segment starts before the current one and overlaps it
            # cur:          |--------------|
            # new:  |--------------|
            else:
                # Insert the different part of the new segment
                segments.insert(index,
                                MemorySegmentAccess(
                                    MemorySegment(new_address, address.offset - new_address.offset), new_accesses))
                if first_found_index is None:
                    first_found_index = index

                # The common part will be handled in the next iteration by either case (3), (4), or (5)
                #
                # Since
                #  - `new_address + new_size > address` (else condition of case 2), so
                #  - `new_size > address.offset - new_address.offset`
                # then invariant `new_size > 0`
                new_size -= (address.offset - new_address.offset)
                new_address = address
                # Skip the entry just inserted: `index` designates the current segment again.
                index += 1
        elif new_address == address:
            # case 3: same start, the new segment is shorter
            # cur:  |--------------|
            # new:  |--------|
            if new_address + new_size < address + size:
                # The common part comes first, followed by the remainder of the current segment.
                segments.insert(index,
                                MemorySegmentAccess(MemorySegment(new_address, new_size), accesses + new_accesses))
                segments[index + 1] = MemorySegmentAccess(
                    MemorySegment(new_address + new_size, size - new_size), accesses)
                return first_found_index if first_found_index is not None else index
            # case 4: both segments are identical
            # cur:  |--------------|
            # new:  |--------------|
            elif new_address + new_size == address + size:
                segments[index] = MemorySegmentAccess(MemorySegment(new_address, new_size), accesses + new_accesses)
                return first_found_index if first_found_index is not None else index
            # case 5: same start, the new segment is longer
            # cur:  |--------------|
            # new:  |------------------|
            # new_address + new_size > address + size
            else:
                # Current segment's accesses are augmented by new segment's accesses
                segments[index] = MemorySegmentAccess(MemorySegment(address, size), accesses + new_accesses)
                if first_found_index is None:
                    first_found_index = index

                # The different part of the new segment will be handled in the next iteration
                #
                # Since:
                #  - `new_address == address` and `new_address + new_size > address + size`, so
                #  - `new_size > size`
                # then invariant `new_size > 0`
                new_address = address + size
                new_size = new_size - size
                index += 1
        # new_address > address
        else:
            # case 6: the new segment starts at or after the end of the current one
            # cur:  |--------------|
            # new:                    |-----------|
            if new_address >= address + size:
                index += 1
            # case 7: the new segment starts inside the current one
            # cur:  |--------------|
            # new:       |-----------|
            else:
                # Split the current segment into:
                #  - the different part
                segments[index] = MemorySegmentAccess(
                    MemorySegment(address, new_address.offset - address.offset), accesses)
                #  - the common part
                segments.insert(index + 1, MemorySegmentAccess(
                    MemorySegment(new_address, address.offset + size - new_address.offset), accesses))
                # The `new_segment` will be handled in the next iteration by either case (3), (4) or (5)
                index += 1

    # The scan reached the end of the list: what remains of the new segment is appended.
    segments.append(MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
    index = len(segments) - 1
    return first_found_index if first_found_index is not None else index
# Identifier of a thread, as returned by `RuntimeHelper.thread_id`.
ThreadId = int
# Get the common memory accesses of two threads of the same process.
# Preconditions on both parameters:
# 1. The list of MemorySegmentAccess is sorted by address and size
# 2. For each MemorySegmentAccess access, `access.segment.size > 0`
# Each thread has a sorted (by address and size) memory segment access list; segments are not empty (i.e `size > 0`).
def get_common_memory_accesses(
    first_thread_segment_accesses: Tuple[ThreadId,
                                         List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]],
    second_thread_segment_accesses: Tuple[ThreadId,
                                          List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]]) \
        -> List[Tuple[MemorySegment,
                      Tuple[List[Tuple[Transition, MemoryAccessOperation]],
                            List[Tuple[Transition, MemoryAccessOperation]]]]]:
    """Compute the memory segments accessed by both threads.

    Each parameter is a pair `(thread id, segment accesses)` whose list is sorted by address and size and contains
    only non-empty segments. The result gives, for every segment touched by the two threads, the accesses of the
    first thread and the accesses of the second thread on that segment.
    """
    (first_thread_id, first_segment_accesses) = first_thread_segment_accesses
    (second_thread_id, second_segment_accesses) = second_thread_segment_accesses

    def tagged(thread_id, segment_access):
        # Prefix each access of the segment with the id of the thread performing it.
        return [(thread_id, trans, mem_operation) for (trans, mem_operation) in segment_access.accesses]

    def sort_key(segment_access):
        return (segment_access.segment.address, segment_access.segment.size)

    # Merge both sorted lists into a single sorted list.
    merged_accesses: List[MemorySegmentAccess[Tuple[ThreadId, Transition, MemoryAccessOperation]]] = []
    i = j = 0
    first_count = len(first_segment_accesses)
    second_count = len(second_segment_accesses)
    while i < first_count and j < second_count:
        first_item = first_segment_accesses[i]
        second_item = second_segment_accesses[j]
        first_key = sort_key(first_item)
        second_key = sort_key(second_item)

        if first_key < second_key:
            merged_accesses.append(MemorySegmentAccess(first_item.segment, tagged(first_thread_id, first_item)))
            i += 1
        elif first_key > second_key:
            merged_accesses.append(MemorySegmentAccess(second_item.segment, tagged(second_thread_id, second_item)))
            j += 1
        else:
            # Same segment in both threads: a single entry carries the accesses of the two threads.
            merged_accesses.append(MemorySegmentAccess(
                first_item.segment,
                tagged(first_thread_id, first_item) + tagged(second_thread_id, second_item)))
            i += 1
            j += 1

    # At most one of the lists still has items left.
    merged_accesses.extend(
        MemorySegmentAccess(item.segment, tagged(first_thread_id, item))
        for item in first_segment_accesses[i:]
    )
    merged_accesses.extend(
        MemorySegmentAccess(item.segment, tagged(second_thread_id, item))
        for item in second_segment_accesses[j:]
    )

    # Segments of different threads may still overlap in the merged list, so it is segmented again by inserting
    # its items one by one into an initially empty list.
    #
    # The merged list is sorted, then a segment is never inserted before the previous one: the position returned
    # by an insertion is used as the starting point of the next insertion. The insertions happen mostly at the end
    # of the list, which keeps the loop about linear.
    refined_accesses: List[MemorySegmentAccess[Tuple[ThreadId, Transition, MemoryAccessOperation]]] = []
    insert_hint = None
    for merged_access in merged_accesses:
        insert_hint = insert_memory_segment_access(merged_access, refined_accesses, insert_hint)

    # Keep the segments accessed by both threads, with the accesses separated by thread.
    common_accesses = []
    for refined_access in refined_accesses:
        first_thread_rws = [
            (transition, operation)
            for (thread_id, transition, operation) in refined_access.accesses if thread_id == first_thread_id
        ]
        second_thread_rws = [
            (transition, operation)
            for (thread_id, transition, operation) in refined_access.accesses if thread_id != first_thread_id
        ]
        if first_thread_rws and second_thread_rws:
            common_accesses.append((refined_access.segment, (first_thread_rws, second_thread_rws)))

    return common_accesses
# Return True if the transition should be excluded by the caller for data race detection.
def transition_excluded_by_heuristics(sco: RuntimeHelper, trans: Transition, blacklist: Set[int]) -> bool:
    """Tell whether the accesses of `trans` must be ignored by the data race detection.

    A transition is excluded when it is located in one of the critical section APIs, or when one of the frames of
    its stack is located in a critical section API or in a thread creation/shutdown function. The ids of the
    excluded transitions are added to `blacklist`, which also serves as a cache of the positive answers.
    """
    if trans.id in blacklist:
        return True

    # A transition without a context before it (the first of the trace) is never excluded.
    try:
        ctxt = trans.context_before()
    except IndexError:
        return False

    location = ctxt.ossi.location()
    if location is None or location.binary is None:
        return False

    core_dlls = [
        'c:/windows/system32/ntdll.dll', 'c:/windows/system32/csrsrv.dll', 'c:/windows/system32/basesrv.dll'
    ]
    if location.binary.path not in core_dlls or location.symbol is None:
        return False

    def resolved(names):
        # Symbols of the given names, without the ones which could not be resolved.
        return [sco.symbols[name] for name in names if sco.symbols[name] is not None]

    # Accesses of the synchronization APIs themselves are not counted.
    sync_apis = resolved([
        'RtlInitializeCriticalSection', 'RtlInitializeCriticalSectionEx',
        'RtlInitializeCriticalSectionAndSpinCount',
        'RtlEnterCriticalSection', 'RtlTryEnterCriticalSection', 'RtlLeaveCriticalSection'
    ])
    if location.symbol in sync_apis:
        blacklist.add(trans.id)
        return True

    frames = ctxt.stack.frames()

    # Accesses of thread create/shutdown are not counted
    create_apis = resolved(['LdrpInitializeThread', 'BaseSrvCreateThread', 'CsrCreateThread'])
    shutdown_apis = resolved(['CsrDereferenceThread', 'CsrThreadRefcountZero', 'LdrShutdownThread'])
    excluded_apis = create_apis + shutdown_apis + sync_apis

    for frame in frames:
        frame_location = frame.first_context.ossi.location()
        if frame_location is None or frame_location.binary is None:
            continue
        if frame_location.binary.path not in core_dlls:
            continue
        if frame_location.symbol is None:
            continue

        if frame_location.symbol in excluded_apis:
            blacklist.add(trans.id)
            return True

    return False
def get_threads_locks_unlocks(trace: RuntimeHelper, binary: Optional[str],
                              process_ranges: List[Tuple[Context, Context]]) \
        -> Dict[ThreadId, List[Tuple[SynchronizationAction, Context]]]:
    """Group the context ranges by thread, then collect the lock/unlock actions of each thread.

    The thread of a range is the one identified at the lower bound of the range.
    """
    ranges_by_thread: Dict[int, List[Tuple[Context, Context]]] = {}
    for (range_low, range_high) in process_ranges:
        thread = RuntimeHelper.thread_id(range_low)
        ranges_by_thread.setdefault(thread, []).append((range_low, range_high))

    return {
        thread: get_locks_unlocks(trace, binary, thread_ranges)
        for thread, thread_ranges in ranges_by_thread.items()
    }
# Get all segmented memory accesses of each thread given a list of context ranges.
# Return a map: thread_id -> list(((mem_addr, mem_size), list((transition, mem_operation))))
def get_threads_segmented_memory_accesses(trace: RuntimeHelper, process_ranges: List[Tuple[Context, Context]]) \
        -> Dict[ThreadId, List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]]:
    """Build, for each thread, the list of the non-overlapping memory segments it accessed.

    The thread of a range is the one identified at the lower bound of the range. Each segment of the result
    carries the `(transition, operation)` pairs of the accesses touching it.
    """
    # Step 1: gather the raw accesses of each thread, one segment per access.
    raw_by_thread: Dict[
        ThreadId,
        List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]
    ] = {}
    for (range_low, range_high) in process_ranges:
        thread_accesses = raw_by_thread.setdefault(RuntimeHelper.thread_id(range_low), [])
        thread_accesses.extend(
            MemorySegmentAccess(
                MemorySegment(access.virtual_address, access.size),
                [(access.transition, access.operation)])
            for access in trace.get_memory_accesses(range_low, range_high)
        )

    # Step 2: sort the segments of each thread by address and size, in O(n * logn). The sorted segments may still
    # overlap, but the order makes the insertions of the next step cheap.
    for thread_accesses in raw_by_thread.values():
        thread_accesses.sort(key=lambda item: (item.segment.address, item.segment.size))

    # Step 3: build the non-overlapping list of each thread by inserting its sorted segments one by one into an
    # empty list; each insertion starts from the position returned by the previous one. Since the input is sorted,
    # this step is about O(n), so that the whole construction is O(n * logn).
    segmented_by_thread: Dict[
        ThreadId,
        List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]
    ] = {}
    for thread, thread_accesses in raw_by_thread.items():
        flattened: List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]] = []
        segmented_by_thread[thread] = flattened
        insert_hint = None
        for segment_access in thread_accesses:
            insert_hint = insert_memory_segment_access(segment_access, flattened, insert_hint)

    return segmented_by_thread
# Id of a transition of the trace.
TransitionId = int
# Value of the instruction pointer register (rip).
InstructionPointer = int
# Handle identifying a critical section, as returned by `RuntimeHelper.get_lock_handle`.
LockHandle = int
def detect_data_race(trace: RuntimeHelper, pid: int, binary: Optional[str],
                     begin_transition_id: Optional[TransitionId], end_transition_id: Optional[TransitionId]):
    """Print the possible data races between the threads of process `pid`.

    For every pair of threads, the memory segments accessed by both threads are computed. A pair of accesses
    (one per thread) on such a segment is reported as a possible data race when the accesses are performed by
    instructions at different addresses, are of different kinds (one read and one write), share no live critical
    section, and are not both excluded by `transition_excluded_by_heuristics`. A pair of instruction addresses is
    reported at most once per segment.

    The report is printed on the standard output; it is tuned by the module-level parameters
    `hide_synchronized_accesses` and `suppress_unknown_primitives`.

    Args:
        trace: the helper giving access to the analyzed scenario.
        pid: the pid of the analyzed process.
        binary: when not None, only the locks/unlocks called directly from this binary are counted.
        begin_transition_id: optional id of the transition where the analysis starts.
        end_transition_id: optional id of the transition where the analysis stops.
    """
    def handlers_to_string(handlers: Iterable[int]):
        return ', '.join(f'{handle:#x}' for handle in handlers)

    def get_live_lock_handles(sco: RuntimeHelper, locks_unlocks: List[Tuple[SynchronizationAction, Context]],
                              trans: Transition) -> Tuple[List[Context], Set[LockHandle]]:
        locks = get_live_locks(sco, locks_unlocks, trans)
        return (locks, {RuntimeHelper.get_lock_handle(ctxt) for ctxt in locks})

    (first_context, last_context) = find_begin_end_context(trace, pid, binary, begin_transition_id, end_transition_id)
    process_ranges = list(find_process_usermode_ranges(trace, pid, first_context, last_context))

    threads_memory_accesses = get_threads_segmented_memory_accesses(trace, process_ranges)
    threads_locks_unlocks = get_threads_locks_unlocks(trace, binary, process_ranges)

    thread_ids: Set[ThreadId] = set()

    # map from a transition (of a given thread) to a tuple whose
    #  - the first is a list of contexts of calls (e.g. RtlEnterCriticalSection) to lock
    #  - the second is a set of handles used by these calls
    cached_live_locks_handles: Dict[Tuple[ThreadId, TransitionId], Tuple[List[Context], Set[LockHandle]]] = {}

    cached_pcs: Dict[TransitionId, InstructionPointer] = {}

    filtered_trans_ids: Set[TransitionId] = set()

    def program_counter(transition: Transition) -> InstructionPointer:
        # Value of rip before the transition; read once per transition, whichever thread it belongs to.
        pc = cached_pcs.get(transition.id)
        if pc is None:
            pc = transition.context_before().read(x64.rip)
            cached_pcs[transition.id] = pc
        return pc

    def live_handles(tid: ThreadId, transition: Transition) -> Set[LockHandle]:
        # Handles of the critical sections held by the thread at the transition; computed once per transition.
        key = (tid, transition.id)
        if key not in cached_live_locks_handles:
            cached_live_locks_handles[key] = get_live_lock_handles(trace, threads_locks_unlocks[tid], transition)
        return cached_live_locks_handles[key][1]

    for first_tid, first_thread_accesses in threads_memory_accesses.items():
        # Threads already used as the first of a pair are skipped below, so each pair is examined once.
        thread_ids.add(first_tid)
        for second_tid, second_thread_accesses in threads_memory_accesses.items():
            if second_tid in thread_ids:
                continue

            common_accesses = get_common_memory_accesses(
                (first_tid, first_thread_accesses), (second_tid, second_thread_accesses)
            )

            for (segment, (first_accesses, second_accesses)) in common_accesses:
                (segment_address, segment_size) = (segment.address, segment.size)

                race_pcs: Set[Tuple[InstructionPointer, InstructionPointer]] = set()

                for (first_transition, first_rw) in first_accesses:
                    first_pc = program_counter(first_transition)

                    for (second_transition, second_rw) in second_accesses:
                        second_pc = program_counter(second_transition)

                        # the execution of an instruction is considered atomic
                        if first_pc == second_pc:
                            continue

                        # Skip if both are the same operation (i.e. read/read or write/write)
                        # NOTE(review): write/write pairs are conflicting accesses too; they are skipped here on
                        # purpose by the original design - confirm that this is still wanted.
                        if first_rw == second_rw:
                            continue

                        # Duplicated
                        if (first_pc, second_pc) in race_pcs:
                            continue
                        race_pcs.add((first_pc, second_pc))

                        first_opr_str = 'write' if first_rw == MemoryAccessOperation.Write else 'read'
                        second_opr_str = 'write' if second_rw == MemoryAccessOperation.Write else 'read'

                        # Order the two accesses by their position in the trace.
                        if first_transition < second_transition:
                            (before_mem_opr_str, before_trans, before_tid) = \
                                (first_opr_str, first_transition, first_tid)
                            (after_mem_opr_str, after_trans, after_tid) = \
                                (second_opr_str, second_transition, second_tid)
                        else:
                            (before_mem_opr_str, before_trans, before_tid) = \
                                (second_opr_str, second_transition, second_tid)
                            (after_mem_opr_str, after_trans, after_tid) = \
                                (first_opr_str, first_transition, first_tid)

                        # Description of the pair of accesses, shared by all the messages below.
                        accesses_str = (
                            f'during {after_mem_opr_str} at '
                            f'[{segment_address.offset:#x}, {segment_size}] (transition #{after_trans.id}) '
                            f'by thread {after_tid}\n'
                            f'with a previous {before_mem_opr_str} (transition #{before_trans.id}) '
                            f'by thread {before_tid}'
                        )

                        # The live locks are computed only for the pairs which passed the cheap filters above.
                        live_critical_sections: Dict[ThreadId, Set[LockHandle]] = {
                            first_tid: live_handles(first_tid, first_transition),
                            second_tid: live_handles(second_tid, second_transition)
                        }

                        # ===== No data race
                        message = f'No data race {accesses_str}'

                        # There is a common critical section synchronizing the accesses
                        common_critical_sections = live_critical_sections[first_tid].intersection(
                            live_critical_sections[second_tid]
                        )
                        if common_critical_sections:
                            if not hide_synchronized_accesses:
                                handlers_str = handlers_to_string(common_critical_sections)
                                print(f'{message}: synchronized by critical section(s) ({handlers_str}).\n')
                            continue

                        # Accesses are excluded since they are in thread create/shutdown
                        are_excluded = transition_excluded_by_heuristics(
                            trace, first_transition, filtered_trans_ids
                        ) and transition_excluded_by_heuristics(trace, second_transition, filtered_trans_ids)
                        if are_excluded:
                            if not hide_synchronized_accesses:
                                print(f'{message}: shared accesses in create/shutdown threads')
                            continue

                        # ===== Data race
                        # The accesses are not synchronized; they are reported in one of the following forms:
                        #  1. at least one of them is protected by one or several locks
                        #  2. both of them are not protected by any lock
                        message = f'Possible data race {accesses_str}'
                        if live_critical_sections[first_tid] or live_critical_sections[second_tid]:
                            print(f'{message}.')
                            for tid in {first_tid, second_tid}:
                                if live_critical_sections[tid]:
                                    handlers_str = handlers_to_string(live_critical_sections[tid])
                                    print(f'\tCritical section handles(s) used by {tid}: {handlers_str}\n')
                        elif not suppress_unknown_primitives:
                            print(f'{message}: no critical section used.\n')
# %%
# Connect to the scenario given in the parameters, then print the report of the detection.
trace = RuntimeHelper(host, port)
detect_data_race(trace, pid, binary, begin_trans_id, end_trans_id)