Detect data race

This notebook checks for possible data races that may occur when using critical sections as the synchronization primitive.

Prerequisites

  • Supported versions:
  • REVEN 2.9+
  • This notebook should be run in a jupyter notebook server equipped with a REVEN 2 Python kernel. REVEN comes with a Jupyter notebook server accessible with the Open Python button in the Analyze page of any scenario.
  • The following resources are needed for the analyzed scenario:
  • Trace
  • OSSI
  • Memory History
  • Backtrace (Stack Events included)
  • Fast Search

Perimeter:

  • Windows 10 64-bit

Limits

  • Only supports the following critical section APIs as the lock/unlock operations.
  • RtlEnterCriticalSection, RtlTryEnterCriticalSection
  • RtlLeaveCriticalSection

Running

Fill out the parameters in the Parameters cell below, then run all the cells of this notebook.

Source

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Detect data race
# This notebook checks for possible data races that may occur when using critical sections as the synchronization
# primitive.
#
# Prerequisites: REVEN 2.9+, run in a Jupyter server with a REVEN 2 Python kernel; the scenario needs the
# Trace, OSSI, Memory History, Backtrace (Stack Events included) and Fast Search resources.
# Perimeter: Windows 10 64-bit.
# Limits: only RtlEnterCriticalSection/RtlTryEnterCriticalSection (lock) and RtlLeaveCriticalSection (unlock)
# are supported as synchronization operations.

# %%
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union

# Reven specific
import reven2
from reven2.address import LinearAddress, LogicalAddress
from reven2.arch import x64
from reven2.memhist import MemoryAccess, MemoryAccessOperation
from reven2.ossi.ossi import Symbol
from reven2.trace import Context, Trace, Transition
from reven2.util import collate

# %% [markdown]
# # Parameters

# %%
# Host and port of the server running the scenario.
host = '127.0.0.1'
port = 41309

# The PID and the name of the binary of interest (optional): if the binary name is given (i.e. not None), then only
# locks and unlocks called directly from the binary are counted.
pid = 2460
binary = None

# Optional transition range restricting the analysis (transition ids, or None for the whole trace).
begin_trans_id = None
end_trans_id = None

# Do not show common memory accesses (from different threads) which are synchronized (i.e. mutually excluded) by some
# critical section(s).
hide_synchronized_accesses = True

# Do not show common memory accesses (from different threads) which are dynamically free from critical section(s).
suppress_unknown_primitives = False


# %%
# Helper class which wraps Reven's runtime objects and give methods helping get information about calls to
# RtlEnterCriticalSection and RtlLeaveCriticalSection
class RuntimeHelper:
    def __init__(self, host: str, port: int):
        """Connect to the scenario at host:port and resolve the symbols the analysis relies on.

        Raises RuntimeError if the connection fails or if a mandatory symbol
        (KiSwapContext and the three lock/unlock APIs) cannot be resolved.
        """
        try:
            server = reven2.RevenServer(host, port)
        except RuntimeError:
            raise RuntimeError(f'Cannot connect to the scenario at {host}:{port}')

        self.trace = server.trace
        self.ossi = server.ossi
        self.search_symbol = self.trace.search.symbol
        self.search_binary = self.trace.search.binary

        # basic OSSI: symbols of interest, grouped by the binary that should provide them
        bin_symbol_names = {
            'c:/windows/system32/ntoskrnl.exe': {
                'KiSwapContext'
            },
            'c:/windows/system32/ntdll.dll': {
                # critical section
                'RtlInitializeCriticalSection',
                'RtlInitializeCriticalSectionEx',
                'RtlInitializeCriticalSectionAndSpinCount',
                'RtlEnterCriticalSection',
                'RtlTryEnterCriticalSection',
                'RtlLeaveCriticalSection',
                # initialize/shutdown thread
                'LdrpInitializeThread',
                'LdrShutdownThread'
            },
            'c:/windows/system32/basesrv.dll': {
                'BaseSrvCreateThread'
            },
            'c:/windows/system32/csrsrv.dll': {
                'CsrCreateThread',
                'CsrThreadRefcountZero',
                'CsrDereferenceThread'
            }
        }

        # Map: symbol name -> resolved Symbol, or None when the symbol is optional and absent.
        self.symbols: Dict[str, Optional[Symbol]] = {}
        for (bin, symbol_names) in bin_symbol_names.items():
            try:
                exec_bin = next(self.ossi.executed_binaries(f'^{bin}$'))
            except StopIteration:
                # The kernel image is mandatory; other binaries may legitimately not appear in the trace.
                if bin == 'c:/windows/system32/ntoskrnl.exe':
                    raise RuntimeError(f'{bin} not found')
                exec_bin = None

            for name in symbol_names:
                if exec_bin is None:
                    self.symbols[name] = None
                    continue
                try:
                    sym = next(exec_bin.symbols(f'^{name}$'))
                    self.symbols[name] = sym
                except StopIteration:
                    msg = f'{name} not found in {bin}'
                    # These symbols are required for the core of the analysis.
                    if name in {
                        'KiSwapContext', 'RtlEnterCriticalSection', 'RtlTryEnterCriticalSection',
                        'RtlLeaveCriticalSection'
                    }:
                        raise RuntimeError(msg)
                    else:
                        self.symbols[name] = None
                        print(f'Warning: {msg}')

        # Probe the debugger interface; without it RtlTryEnterCriticalSection return values
        # cannot be read (see get_return_value).
        self.has_debugger = True
        try:
            self.trace.first_transition.step_over()
        except RuntimeError:
            print('Warning: the debugger interface is not available, so the script cannot determine \
function return values.\nMake sure the stack events and PC range resources are replayed for this scenario.')
            self.has_debugger = False

    def get_memory_accesses(self, from_context: Context, to_context: Context) -> Iterator[MemoryAccess]:
        """Yield the memory accesses between two contexts, skipping accesses that are
        irrelevant for the race detection (no virtual address, no instruction, or
        `lock`-prefixed instructions which are atomic)."""
        try:
            from_trans = from_context.transition_after()
            to_trans = to_context.transition_before()
        except IndexError:
            return

        accesses = self.trace.memory_accesses(from_transition=from_trans, to_transition=to_trans)
        for access in accesses:
            # Skip the access without virtual address
            if access.virtual_address is None:
                continue
            # Skip the access without instruction
            if access.transition.instruction is None:
                continue
            # Skip the access of `lock` prefixed instruction
            ins_bytes = access.transition.instruction.raw
            if ins_bytes[0] == 0xf0:
                continue
            yield access

    @staticmethod
    def get_lock_handle(ctxt: Context) -> int:
        # x64 calling convention: the critical section pointer is the first argument (rcx).
        return ctxt.read(x64.rcx)

    @staticmethod
    def thread_id(ctxt: Context) -> int:
        # TEB at gs, current thread id at offset 0x48 (Windows 10 64-bit layout).
        return ctxt.read(LogicalAddress(0x48, x64.gs), 4)

    @staticmethod
    def is_kernel_mode(ctxt: Context) -> bool:
        # Ring 0 iff the RPL bits of cs are 0.
        return ctxt.read(x64.cs) & 0x3 == 0


# %%
# Look for a possible first execution context of a binary
# Find the lower/upper bound of contexts on which the deadlock detection processes
def find_begin_end_context(sco: RuntimeHelper, pid: int, binary: Optional[str], begin_id: Optional[int],
                           end_id: Optional[int]) -> Tuple[Context, Context]:
    """Return the (begin, end) contexts bounding the analysis.

    Falls back to the first execution of `binary` in process `pid` (when given),
    then to the first/last context of the trace. Raises RuntimeError when the
    resulting range is empty.
    """
    begin_ctxt = None
    if begin_id is not None:
        try:
            begin_trans = sco.trace.transition(begin_id)
            begin_ctxt = begin_trans.context_after()
        except IndexError:
            begin_ctxt = None

    if begin_ctxt is None:
        if binary is not None:
            for name in sco.ossi.executed_binaries(binary):
                for ctxt in sco.search_binary(name):
                    ctx_process = ctxt.ossi.process()
                    assert ctx_process is not None
                    if ctx_process.pid == pid:
                        begin_ctxt = ctxt
                        break
                if begin_ctxt is not None:
                    break

    if begin_ctxt is None:
        begin_ctxt = sco.trace.first_context

    end_ctxt = None
    if end_id is not None:
        try:
            end_trans = sco.trace.transition(end_id)
            end_ctxt = end_trans.context_before()
        except IndexError:
            end_ctxt = None

    if end_ctxt is None:
        end_ctxt = sco.trace.last_context

    if (end_ctxt <= begin_ctxt):
        raise RuntimeError("The begin transition must be smaller than the end.")

    return (begin_ctxt, end_ctxt)


# Get all execution contexts of a given process
def find_process_ranges(sco: RuntimeHelper, pid: int, first_context: Context, last_context: Context) \
        -> Iterator[Tuple[Context, Optional[Context]]]:
    """Yield (low, high) context ranges during which process `pid` is scheduled,
    using KiSwapContext occurrences as scheduling boundaries."""
    if last_context <= first_context:
        return iter(())

    ctxt_low = first_context
    assert sco.symbols['KiSwapContext'] is not None
    ki_swap_context = sco.symbols['KiSwapContext']
    for ctxt in sco.search_symbol(ki_swap_context, from_context=first_context,
                                  to_context=None if last_context == sco.trace.last_context else last_context):
        ctx_process = ctxt_low.ossi.process()
        assert ctx_process is not None
        if ctx_process.pid == pid:
            yield (ctxt_low, ctxt)
        ctxt_low = ctxt

    ctx_process = ctxt_low.ossi.process()
    assert ctx_process is not None
    if ctx_process.pid == pid:
        if ctxt_low < last_context:
            yield (ctxt_low, last_context)
        else:
            # So ctxt_low == last_context, using None for the upper bound.
            # This happens only when last_context is in the process, and is also the last context of the trace.
            yield (ctxt_low, None)


# Start from a transition, return the first transition that is not a non-instruction, or None if there isn't one.
def ignore_non_instructions(trans: Transition, trace: Trace) -> Optional[Transition]:
    """Skip forward over non-instruction transitions; return the first transition
    carrying an instruction, or None when the end of the trace is reached."""
    while trans.instruction is None:
        if trans == trace.last_transition:
            return None
        trans = trans + 1
    return trans


# Extract user mode only context ranges from a context range (which may include also kernel mode ranges)
def find_usermode_ranges(sco: RuntimeHelper, ctxt_low: Context, ctxt_high: Optional[Context]) \
        -> Iterator[Tuple[Context, Context]]:
    if ctxt_high is None:
        return

    trans = ignore_non_instructions(ctxt_low.transition_after(), sco.trace)
    if trans is None:
        return
    ctxt_current = trans.context_before()

    while ctxt_current < ctxt_high:
        # The next cs change marks the next user/kernel mode switch.
        ctxt_next = ctxt_current.find_register_change(x64.cs, is_forward=True)

        if not RuntimeHelper.is_kernel_mode(ctxt_current):
            if ctxt_next is None or ctxt_next > ctxt_high:
                yield (ctxt_current, ctxt_high)
                break
            else:
                # It's safe to decrease ctxt_next by 1 because it was obtained from a forward find_register_change
                yield (ctxt_current, ctxt_next - 1)

        if ctxt_next is None:
            break
        ctxt_current = ctxt_next


# Get user mode only execution contexts of a given process
def find_process_usermode_ranges(trace: RuntimeHelper, pid: int, first_ctxt: Context, last_ctxt: Context) \
        -> Iterator[Tuple[Context, Context]]:
    for (ctxt_low, ctxt_high) in find_process_ranges(trace, pid, first_ctxt, last_ctxt):
        usermode_ranges = find_usermode_ranges(trace, ctxt_low, ctxt_high)
        for usermode_range in usermode_ranges:
            yield usermode_range


def build_ordered_api_calls(sco: RuntimeHelper, binary: Optional[str], ctxt_low: Context, ctxt_high: Context,
                            apis: List[str]) -> Iterator[Tuple[str, Context]]:
    """Yield (api_name, context) for calls to any of `apis` within [ctxt_low, ctxt_high],
    ordered by context; when `binary` is given, keep only calls made directly from it."""
    def gen(api):
        return (
            (api, ctxt)
            for ctxt in sco.search_symbol(sco.symbols[api], from_context=ctxt_low,
                                          to_context=None if ctxt_high == sco.trace.last_context else ctxt_high)
        )

    api_contexts = (gen(api) for api in apis if api in sco.symbols)

    if binary is None:
        for api_ctxt in collate(api_contexts, key=lambda name_ctxt: name_ctxt[1]):
            yield api_ctxt
    else:
        for (api, ctxt) in collate(api_contexts, key=lambda name_ctxt: name_ctxt[1]):
            try:
                # The caller is the context right before entering the API.
                caller_ctxt = ctxt - 1
            except IndexError:
                continue

            caller_location = caller_ctxt.ossi.location()
            if caller_location is None:
                continue

            caller_binary = caller_location.binary
            if caller_binary is None:
                continue

            if binary in [caller_binary.name, caller_binary.filename, caller_binary.path]:
                yield (api, ctxt)


def get_return_value(sco: RuntimeHelper, ctxt: Context) -> Optional[int]:
    """Return rax after the function entered at `ctxt` returns, or None when it
    cannot be determined (no debugger interface, or no return found)."""
    if not sco.has_debugger:
        return None

    try:
        trans_after = ctxt.transition_after()
    except IndexError:
        return None
    if trans_after is None:
        return None

    trans_ret = trans_after.step_out()
    if trans_ret is None:
        return None

    ctxt_ret = trans_ret.context_after()
    return ctxt_ret.read(x64.rax)


class SynchronizationAction(Enum):
    LOCK = 1
    UNLOCK = 0


def get_locks_unlocks(sco: RuntimeHelper, binary: Optional[str], ranges: List[Tuple[Context, Context]]) \
        -> List[Tuple[SynchronizationAction, Context]]:
    """Collect the ordered lock/unlock actions performed within `ranges`.

    RtlTryEnterCriticalSection only counts as a LOCK when its return value is
    nonzero; calls whose return value cannot be read are skipped with a warning.
    """
    lock_unlock_apis = {
        'RtlEnterCriticalSection': SynchronizationAction.LOCK,
        'RtlLeaveCriticalSection': SynchronizationAction.UNLOCK,
        'RtlTryEnterCriticalSection': None,
    }

    critical_section_actions = []
    for ctxt_low, ctxt_high in ranges:
        for name, ctxt in build_ordered_api_calls(sco, binary, ctxt_low, ctxt_high, list(lock_unlock_apis.keys())):
            # either lock, unlock, or none
            action = lock_unlock_apis[name]
            if action is None:
                # need to check the return value of the API to get the action
                api_ret_val = get_return_value(sco, ctxt)
                if api_ret_val is None:
                    print(f'Warning: failed to get the return value, {name} is omitted')
                    continue
                else:
                    # RtlTryEnterCriticalSection: the return value is nonzero if the thread is success
                    # to enter the critical section
                    if api_ret_val != 0:
                        action = SynchronizationAction.LOCK
                    else:
                        continue
            critical_section_actions.append((action, ctxt))
    return critical_section_actions


# Get locks which are effective at a transition
def get_live_locks(sco: RuntimeHelper, locks_unlocks: List[Tuple[SynchronizationAction, Context]],
                   trans: Transition) -> List[Context]:
    protection_locks: List[Context] = []

    trans_ctxt = trans.context_before()
    for (action, ctxt) in locks_unlocks:
        if ctxt > trans_ctxt:
            return protection_locks

        if action == SynchronizationAction.LOCK:
            protection_locks.append(ctxt)
            continue

        unlock_handle = RuntimeHelper.get_lock_handle(ctxt)
        # look for the latest corresponding lock
        for (idx, lock_ctxt) in reversed(list(enumerate(protection_locks))):
            lock_handle = RuntimeHelper.get_lock_handle(lock_ctxt)
            if lock_handle == unlock_handle:
                del protection_locks[idx]
                break

    return protection_locks


# %%
AccessType = TypeVar("AccessType")


@dataclass
class MemorySegment:
    # A contiguous span of virtual memory: [address, address + size).
    address: Union[LinearAddress, LogicalAddress]
    size: int


@dataclass
class MemorySegmentAccess(Generic[AccessType]):
    # A memory segment together with the accesses observed on it.
    segment: MemorySegment
    accesses: List[AccessType]


# Insert a segment into a current list of segments, start at position; if the position is None then start from the
# head of the list.
# Precondition: `new_segment.segment.size > 0`
# Return the position where the segment is inserted.
# The complexity of the insertion is about O(n).
def insert_memory_segment_access(new_segment: MemorySegmentAccess, segments: List[MemorySegmentAccess],
                                 position: Optional[int]) -> int:
    """Insert `new_segment` into the sorted, non-overlapping list `segments`,
    splitting/merging segments so the list stays non-overlapping, and merging
    the access lists of any common parts."""
    new_address = new_segment.segment.address
    new_size = new_segment.segment.size
    new_accesses = new_segment.accesses

    if not segments:
        segments.append(MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
        return 0

    if position is None:
        position = 0
    index = position
    first_found_index = None
    while index < len(segments):
        # Loop invariant `new_size > 0`
        #  - True at the first iteration
        #  - In the subsequent iterations, either
        #    - the invariant is always kept through cases (2), (5), (6), (7)
        #    - the loop returns directly in cases (1), (3), (4)
        address = segments[index].segment.address
        size = segments[index].segment.size
        accesses = segments[index].accesses

        if new_address < address:
            # Case 1
            #        |--------------|
            # |--------------|
            if new_address + new_size <= address:
                segments.insert(index, MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
                return first_found_index if first_found_index is not None else index
            # Case 2
            #     |--------------|
            # |--------------|
            else:
                # Insert the different part of the new segment
                segments.insert(index, MemorySegmentAccess(
                    MemorySegment(new_address, address.offset - new_address.offset), new_accesses))
                if first_found_index is None:
                    first_found_index = index
                # The common part will be handled in the next iteration by either case (3), (4), or (5)
                #
                # Since
                #  - `new_address + new_size > address` (else condition of case 2), so
                #  - `new_size > address.offset - new_address.offset`
                # then invariant `new_size > 0`
                new_size -= (address.offset - new_address.offset)
                new_address = address
                index += 1
        elif new_address == address:
            # case 3
            # |--------------|
            # |--------|
            if new_address + new_size < address + size:
                segments.insert(index, MemorySegmentAccess(MemorySegment(new_address, new_size),
                                                           accesses + new_accesses))
                segments[index + 1] = MemorySegmentAccess(
                    MemorySegment(new_address + new_size, size - new_size), accesses)
                return first_found_index if first_found_index is not None else index
            # case 4
            # |--------------|
            # |--------------|
            elif new_address + new_size == address + size:
                segments[index] = MemorySegmentAccess(MemorySegment(new_address, new_size), accesses + new_accesses)
                return first_found_index if first_found_index is not None else index
            # case 5
            # |--------------|
            # |------------------|
            # new_address + new_size > address + size
            else:
                # Current segment's accesses are augmented by new segment's accesses
                segments[index] = MemorySegmentAccess(MemorySegment(address, size), accesses + new_accesses)
                if first_found_index is None:
                    first_found_index = index
                # The different part of the new segment will be handled in the next iteration
                #
                # Since:
                #  - `new_address == address` and `new_address + new_size > address + size`, so
                #  - `new_size > size`
                # then invariant `new_size > 0`
                new_address = address + size
                new_size = new_size - size
                index += 1
        # new_address > address
        else:
            # case 6
            # |--------------|
            #                  |-----------|
            if new_address >= address + size:
                index += 1
            # case 7
            # |--------------|
            #      |-----------|
            else:
                # Split the current segment into:
                #  - the different part
                segments[index] = MemorySegmentAccess(
                    MemorySegment(address, new_address.offset - address.offset), accesses)
                #  - the common part
                segments.insert(index + 1, MemorySegmentAccess(
                    MemorySegment(new_address, address.offset + size - new_address.offset), accesses))
                # The `new_segment` will be handled in the next iteration by either case (3), (4) or (5)
                index += 1

    segments.append(MemorySegmentAccess(MemorySegment(new_address, new_size), new_accesses))
    index = len(segments) - 1
    return first_found_index if first_found_index is not None else index


ThreadId = int


# Get the common memory accesses of two threads of the same process.
# Preconditions on both parameters:
#  1. The list of MemorySegmentAccess is sorted by address and size
#  2. For each MemorySegmentAccess access, `access.segment.size > 0`
# Each thread has a sorted (by address and size) memory segment access list; segments are not empty (i.e `size > 0`).
def get_common_memory_accesses(
    first_thread_segment_accesses: Tuple[ThreadId,
                                         List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]],
    second_thread_segment_accesses: Tuple[ThreadId,
                                          List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]]) \
        -> List[Tuple[MemorySegment, Tuple[List[Tuple[Transition, MemoryAccessOperation]],
                                           List[Tuple[Transition, MemoryAccessOperation]]]]]:
    """Return the memory segments accessed by BOTH threads, each with the two
    threads' (transition, operation) access lists.

    Both inputs are (thread_id, sorted segment access list); the lists are merged,
    re-segmented into non-overlapping segments, and only segments touched by the
    two threads are kept.
    """
    (first_thread_id, first_segment_accesses) = first_thread_segment_accesses
    (second_thread_id, second_segment_accesses) = second_thread_segment_accesses

    # Merge these lists into a new list (classic sorted-merge, tagging each access with its thread id)
    merged_accesses: List[MemorySegmentAccess[Tuple[ThreadId, Transition, MemoryAccessOperation]]] = []
    i = j = 0
    while i < len(first_segment_accesses) and j < len(second_segment_accesses):
        (first_segment, first_accesses) = (first_segment_accesses[i].segment, first_segment_accesses[i].accesses)
        (second_segment, second_accesses) = (second_segment_accesses[j].segment, second_segment_accesses[j].accesses)

        first_threaded_accesses = [
            (first_thread_id, trans, mem_operation) for (trans, mem_operation) in first_accesses]
        second_threaded_accesses = [
            (second_thread_id, trans, mem_operation) for (trans, mem_operation) in second_accesses]

        (first_address, first_size) = (first_segment.address, first_segment.size)
        (second_address, second_size) = (second_segment.address, second_segment.size)

        if (first_address, first_size) < (second_address, second_size):
            merged_accesses.append(MemorySegmentAccess(first_segment, first_threaded_accesses))
            i += 1
        elif (first_address, first_size) > (second_address, second_size):
            merged_accesses.append(MemorySegmentAccess(second_segment, second_threaded_accesses))
            j += 1
        else:
            merged_accesses.append(MemorySegmentAccess(
                first_segment, first_threaded_accesses + second_threaded_accesses))
            i += 1
            j += 1

    while i < len(first_segment_accesses):
        (first_segment, first_accesses) = (first_segment_accesses[i].segment, first_segment_accesses[i].accesses)
        first_threaded_accesses = [
            (first_thread_id, trans, mem_operation) for (trans, mem_operation) in first_accesses]
        merged_accesses.append(MemorySegmentAccess(first_segment, first_threaded_accesses))
        i += 1

    while j < len(second_segment_accesses):
        (second_segment, second_accesses) = (second_segment_accesses[j].segment, second_segment_accesses[j].accesses)
        second_threaded_accesses = [
            (second_thread_id, trans, mem_operation) for (trans, mem_operation) in second_accesses]
        merged_accesses.append(MemorySegmentAccess(second_segment, second_threaded_accesses))
        j += 1

    # The merged list needs to be segmented again to handle the case of overlapping segments coming from different
    # threads.
    # We start from an empty list `refined_accesses`, and gradually insert new segment into it.
    #
    # The list merged_accesses is sorted already, then the position where a segment is inserted will be always larger
    # or equal the inserting position of the previous segment. We can use the inserting position of an segment as
    # the starting position when looking for the position to insert the next segment.
    #
    # Though the complexity of `insert_memory_segment_access` is O(n) in average case, the insertion in the loop
    # below happens mostly at the end of the list (with the complexity O(1)). The complexity of the loop is still O(n).
    refined_accesses: List[MemorySegmentAccess[Tuple[ThreadId, Transition, MemoryAccessOperation]]] = []
    last_inserted_index = None
    for refined_access in merged_accesses:
        last_inserted_index = insert_memory_segment_access(refined_access, refined_accesses, last_inserted_index)

    # Keep only the segments accessed by both threads, splitting each segment's
    # accesses back into per-thread lists.
    common_accesses = []
    for refined_access in refined_accesses:
        first_thread_rws = []
        second_thread_rws = []
        for (thread_id, transition, operation) in refined_access.accesses:
            if thread_id == first_thread_id:
                first_thread_rws.append((transition, operation))
            else:
                second_thread_rws.append((transition, operation))
        if first_thread_rws and second_thread_rws:
            common_accesses.append((refined_access.segment, (first_thread_rws, second_thread_rws)))

    return common_accesses


# Return True if the transition should be excluded by the caller for data race detection.
def transition_excluded_by_heuristics(sco: RuntimeHelper, trans: Transition, blacklist: Set[int]) -> bool:
    """Heuristically decide whether `trans` belongs to the synchronization APIs
    themselves or to thread create/shutdown code, in which case its accesses are
    not counted; matched transition ids are memoized in `blacklist`."""
    if trans.id in blacklist:
        return True

    # Memory accesses of the first transition are considered non-protected
    try:
        trans_ctxt = trans.context_before()
    except IndexError:
        return False

    trans_location = trans_ctxt.ossi.location()
    if trans_location is None:
        return False

    if trans_location.binary is None:
        return False
    trans_binary_path = trans_location.binary.path

    process_thread_core_dlls = [
        'c:/windows/system32/ntdll.dll',
        'c:/windows/system32/csrsrv.dll',
        'c:/windows/system32/basesrv.dll'
    ]
    if trans_binary_path not in process_thread_core_dlls:
        return False

    if trans_location.symbol is None:
        return False

    symbols = sco.symbols

    # Accesses of the synchronization APIs themselves are not counted.
    thread_sync_apis = [
        symbols[name] for name in [
            'RtlInitializeCriticalSection',
            'RtlInitializeCriticalSectionEx',
            'RtlInitializeCriticalSectionAndSpinCount',
            'RtlEnterCriticalSection',
            'RtlTryEnterCriticalSection',
            'RtlLeaveCriticalSection'
        ] if symbols[name] is not None
    ]
    if trans_location.symbol in thread_sync_apis:
        blacklist.add(trans.id)
        return True

    trans_frames = trans_ctxt.stack.frames()

    # Accesses of thread create/shutdown are not counted
    thread_create_apis = [
        api for api in [
            symbols['LdrpInitializeThread'],
            symbols['BaseSrvCreateThread'],
            symbols['CsrCreateThread']
        ] if api is not None
    ]
    thread_shutdown_apis = [
        api for api in [
            symbols['CsrDereferenceThread'],
            symbols['CsrThreadRefcountZero'],
            symbols['LdrShutdownThread']
        ] if api is not None
    ]

    for frame in trans_frames:
        frame_ctxt = frame.first_context
        frame_location = frame_ctxt.ossi.location()
        if frame_location is None:
            continue
        if frame_location.binary is None:
            continue
        if frame_location.binary.path not in process_thread_core_dlls:
            continue
        if frame_location.symbol is None:
            continue
        if (
            frame_location.symbol in thread_create_apis
            or frame_location.symbol in thread_shutdown_apis
            or frame_location.symbol in thread_sync_apis
        ):
            blacklist.add(trans.id)
            return True

    return False


def get_threads_locks_unlocks(trace: RuntimeHelper, binary: Optional[str],
                              process_ranges: List[Tuple[Context, Context]]) \
        -> Dict[ThreadId, List[Tuple[SynchronizationAction, Context]]]:
    """Group `process_ranges` by thread id and collect each thread's ordered
    lock/unlock actions."""
    threads_ranges: Dict[int, List[Tuple[Context, Context]]] = {}
    for (ctxt_lo, ctxt_hi) in process_ranges:
        tid = RuntimeHelper.thread_id(ctxt_lo)
        if tid not in threads_ranges:
            threads_ranges[tid] = []
        threads_ranges[tid].append((ctxt_lo, ctxt_hi))

    threads_locks_unlocks: Dict[ThreadId, List[Tuple[SynchronizationAction, Context]]] = {}
    for tid, ranges in threads_ranges.items():
        threads_locks_unlocks[tid] = get_locks_unlocks(trace, binary, ranges)
    return threads_locks_unlocks


# Get all segmented memory accesses of each thread given a list of context ranges.
# Return a map: thread_id -> list(((mem_addr, mem_size), list((thread_id, transition, mem_operation))))
def get_threads_segmented_memory_accesses(trace: RuntimeHelper, process_ranges: List[Tuple[Context, Context]]) \
        -> Dict[ThreadId, List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]]:
    """Collect, per thread, the non-overlapping memory segment accesses performed
    within `process_ranges`."""
    sorted_threads_segmented_memory_accesses: Dict[
        ThreadId, List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]
    ] = {}
    for (ctxt_lo, ctxt_hi) in process_ranges:
        tid = RuntimeHelper.thread_id(ctxt_lo)
        if tid not in sorted_threads_segmented_memory_accesses:
            sorted_threads_segmented_memory_accesses[tid] = []

        for access in trace.get_memory_accesses(ctxt_lo, ctxt_hi):
            access_virtual_address = access.virtual_address
            if access_virtual_address is None:
                # DMA access
                continue
            sorted_threads_segmented_memory_accesses[tid].append(MemorySegmentAccess(
                MemorySegment(access_virtual_address, access.size), [(access.transition, access.operation)]))

    # The memory segment accesses of each thread are sorted by address and size of segments, that will help to improve
    # the performance of looking for common accesses of two threads.
    # The complexity of sorting is O(n * logn)
    # Note that segments can be still overlapped.
    for tid in sorted_threads_segmented_memory_accesses.keys():
        sorted_threads_segmented_memory_accesses[tid].sort(key=lambda x: (x.segment.address, x.segment.size))

    threads_segmented_memory_accesses: Dict[
        ThreadId, List[MemorySegmentAccess[Tuple[Transition, MemoryAccessOperation]]]
    ] = {}

    # The non-overlapped memory segment access list of each thread is built by:
    #  - start from an empty list
    #  - gradually insert segment into the list using `insert_memory_segment_access`
    # Since the original segment lists are sorted, the complexity of the build is O(n). The total complexity of
    # constructing the non-overlapped memory segment accesses of each thread is O(n * logn) then.
    for tid in sorted_threads_segmented_memory_accesses.keys():
        threads_segmented_memory_accesses[tid] = []
        last_mem_acc_index = None
        for seg_mem_acc in sorted_threads_segmented_memory_accesses[tid]:
            last_mem_acc_index = insert_memory_segment_access(
                seg_mem_acc, threads_segmented_memory_accesses[tid], last_mem_acc_index
            )

    return threads_segmented_memory_accesses


TransitionId = int
InstructionPointer = int
LockHandle = int


def detect_data_race(trace: RuntimeHelper, pid: int, binary: Optional[str],
                     begin_transition_id: Optional[TransitionId], end_transition_id: Optional[TransitionId]):
    """Detect and print possible data races of process `pid` within the given
    transition range: common memory accesses of two threads where one access is a
    write, the accesses are not protected by a common critical section, and the
    exclusion heuristics do not apply."""
    def handlers_to_string(handlers: Iterable[int]):
        # Render lock handles as a comma-separated list of hex values.
        msgs = [f'{handle:#x}' for handle in handlers]
        message = ', '.join(msgs)
        return message

    def get_live_lock_handles(
        sco: RuntimeHelper, locks_unlocks: List[Tuple[SynchronizationAction, Context]], trans: Transition
    ) -> Tuple[List[Context], Set[LockHandle]]:
        # Locks held at `trans`, both as lock-call contexts and as the set of handles.
        locks = get_live_locks(sco, locks_unlocks, trans)
        return (locks, {RuntimeHelper.get_lock_handle(ctxt) for ctxt in locks})

    (first_context, last_context) = find_begin_end_context(trace, pid, binary, begin_transition_id,
                                                           end_transition_id)
    process_ranges = list(find_process_usermode_ranges(trace, pid, first_context, last_context))

    threads_memory_accesses = get_threads_segmented_memory_accesses(trace, process_ranges)
    threads_locks_unlocks = get_threads_locks_unlocks(trace, binary, process_ranges)

    thread_ids: Set[ThreadId] = set()

    # map from a transition (of a given thread) to a tuple whose
    #  - the first is a list of contexts of calls (e.g. RtlEnterCriticalSection) to lock
    #  - the second is a set of handles used by these calls
    cached_live_locks_handles: Dict[Tuple[ThreadId, TransitionId], Tuple[List[Context], Set[LockHandle]]] = {}
    cached_pcs: Dict[TransitionId, InstructionPointer] = {}
    filtered_trans_ids: Set[TransitionId] = set()

    for first_tid, first_thread_accesses in threads_memory_accesses.items():
        thread_ids.add(first_tid)
        for second_tid, second_thread_accesses in threads_memory_accesses.items():
            # Examine each unordered pair of distinct threads exactly once.
            if second_tid in thread_ids:
                continue

            common_accesses = get_common_memory_accesses(
                (first_tid, first_thread_accesses), (second_tid, second_thread_accesses)
            )
            for (segment, (first_accesses, second_accesses)) in common_accesses:
                (segment_address, segment_size) = (segment.address, segment.size)
                race_pcs: Set[Tuple[InstructionPointer, InstructionPointer]] = set()
                for (first_transition, first_rw) in first_accesses:
                    if first_transition.id in cached_pcs:
                        first_pc = cached_pcs[first_transition.id]
                    else:
                        first_pc = first_transition.context_before().read(x64.rip)
                        # Fix: the original never stored this PC, so the cache was never hit
                        # for accesses of the first thread.
                        cached_pcs[first_transition.id] = first_pc

                    if (first_tid, first_transition.id) not in cached_live_locks_handles:
                        critical_section_locks_unlocks = threads_locks_unlocks[first_tid]
                        cached_live_locks_handles[(first_tid, first_transition.id)] = get_live_lock_handles(
                            trace, critical_section_locks_unlocks, first_transition)

                    for (second_transition, second_rw) in second_accesses:
                        if second_transition.id in cached_pcs:
                            second_pc = cached_pcs[second_transition.id]
                        else:
                            second_pc = second_transition.context_before().read(x64.rip)
                            cached_pcs[second_transition.id] = second_pc

                        # the execution of an instruction is considered atomic
                        if first_pc == second_pc:
                            continue

                        if (second_tid, second_transition.id) not in cached_live_locks_handles:
                            critical_section_locks_unlocks = threads_locks_unlocks[second_tid]
                            cached_live_locks_handles[(second_tid, second_transition.id)] = get_live_lock_handles(
                                trace, critical_section_locks_unlocks, second_transition)

                        # Skip if both are the same operation (i.e. read/read or write/write)
                        if first_rw == second_rw:
                            continue

                        first_is_write = first_rw == MemoryAccessOperation.Write
                        second_is_write = second_rw == MemoryAccessOperation.Write

                        # Order the pair of accesses chronologically for reporting.
                        if first_transition < second_transition:
                            before_mem_opr_str = 'write' if first_is_write else 'read'
                            after_mem_opr_str = 'write' if second_is_write else 'read'
                            before_trans = first_transition
                            after_trans = second_transition
                            before_tid = first_tid
                            after_tid = second_tid
                        else:
                            before_mem_opr_str = 'write' if second_is_write else 'read'
                            after_mem_opr_str = 'write' if first_is_write else 'read'
                            before_trans = second_transition
                            after_trans = first_transition
                            before_tid = second_tid
                            after_tid = first_tid

                        # Duplicated
                        if (first_pc, second_pc) in race_pcs:
                            continue
                        race_pcs.add((first_pc, second_pc))

                        # ===== No data race
                        message = f'No data race during {after_mem_opr_str} at \
[{segment_address.offset:#x}, {segment_size}] (transition #{after_trans.id}) by thread {after_tid}\n\
with a previous {before_mem_opr_str} (transition #{before_trans.id}) by thread {before_tid}'

                        # There is a common critical section synchronizing the accesses
                        live_critical_sections: Dict[ThreadId, Set[LockHandle]] = {
                            first_tid: cached_live_locks_handles[(first_tid, first_transition.id)][1],
                            second_tid: cached_live_locks_handles[(second_tid, second_transition.id)][1]
                        }
                        common_critical_sections = live_critical_sections[first_tid].intersection(
                            live_critical_sections[second_tid]
                        )
                        if common_critical_sections:
                            if not hide_synchronized_accesses:
                                handlers_str = handlers_to_string(common_critical_sections)
                                print(f'{message}: synchronized by critical section(s) ({handlers_str}).\n')
                            continue

                        # Accesses are excluded since they are in thread create/shutdown
                        are_excluded = transition_excluded_by_heuristics(
                            trace, first_transition, filtered_trans_ids
                        ) and transition_excluded_by_heuristics(trace, second_transition, filtered_trans_ids)
                        if are_excluded:
                            if not hide_synchronized_accesses:
                                print(f'{message}: shared accesses in create/shutdown threads')
                            continue

                        # ===== Data race
                        # If there is no data race, then show the un-synchronized accesses in the following order
                        #  1. one of them are protected by one or several locks
                        #  2. both of them are not protected by any lock
                        message = f'Possible data race during {after_mem_opr_str} at \
[{segment_address.offset:#x}, {segment_size}] (transition #{after_trans.id}) by thread {after_tid}\n\
with a previous {before_mem_opr_str} (transition #{before_trans.id}) by thread {before_tid}'

                        if live_critical_sections[first_tid] or live_critical_sections[second_tid]:
                            print(f'{message}.')
                            for tid in {first_tid, second_tid}:
                                if live_critical_sections[tid]:
                                    handlers_str = handlers_to_string(live_critical_sections[tid])
                                    # Fix: user-facing typo "handles(s)" -> "handle(s)".
                                    print(f'\tCritical section handle(s) used by {tid}: {handlers_str}\n')
                        elif not suppress_unknown_primitives:
                            print(f'{message}: no critical section used.\n')


# %%
trace = RuntimeHelper(host, port)
detect_data_race(trace, pid, binary, begin_trans_id, end_trans_id)