Find all memory accesses performed by a given symbol

Purpose

This notebook and script search a REVEN trace for all memory accesses performed during the calls to a given symbol. Results can be filtered by process, thread, ring, context range and memory access operation. The script can generate two kinds of results:

  • for each call of the symbol: the process, binary and symbol call information, followed by all the memory accesses of that call.
  • grouped by symbol: the symbol, then each of its calls with the memory accesses that occurred in that call. Note that this option can take a long time to start showing results.

Note that this script can include or exclude the memory accesses that occurred in children symbol calls of each symbol call.
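
To illustrate what excluding children symbol calls means, here is a minimal sketch of the idea, not the script's exact logic. It assumes we already have the sorted transition ids visited while stepping over child calls (the ids below are hypothetical), and groups them into contiguous ranges; a gap between two ids stands for a child call whose transitions are skipped. The helper name `contiguous_ranges` is an assumption made for this example.

```python
from typing import Iterable, Iterator, Tuple


def contiguous_ranges(transition_ids: Iterable[int]) -> Iterator[Tuple[int, int]]:
    """Group sorted transition ids into half-open [first, last + 1) ranges.

    A gap between two consecutive ids stands for a child call that was
    stepped over, so its transitions are left out of every range.
    """
    first = last = None
    for tid in transition_ids:
        if first is None:
            first = last = tid
        elif tid == last + 1:
            last = tid
        else:
            yield (first, last + 1)
            first = last = tid
    if first is not None:
        yield (first, last + 1)


# Hypothetical ids: the call starts at 10, and stepping over a child call
# jumps from transition 12 directly to transition 20.
print(list(contiguous_ranges([10, 11, 12, 20, 21])))
```

Only the memory accesses of transitions inside the returned ranges would then be attributed to the symbol call itself.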

How to use

Results can be generated from this notebook or from the command line. The script can also be imported as a module for use from your own script or notebook.

From the notebook

  1. Upload the memory_ranges_accessed_by_a_symbol.ipynb file in Jupyter.
  2. Fill out the parameters cell of this notebook according to your scenario and desired output.
  3. Run the full notebook.

From the command line

  1. Make sure that you are in an environment that can run REVEN scripts.
  2. Run python memory_ranges_accessed_by_a_symbol.py --help to get a tour of available arguments.
  3. Run python memory_ranges_accessed_by_a_symbol.py --host <your_host> --port <your_port> [<other_option>] with your arguments of choice.

Imported in your own script or notebook

  1. Make sure that you are in an environment that can run REVEN scripts.
  2. Make sure that memory_ranges_accessed_by_a_symbol.py is in the same directory as your script or notebook.
  3. Add import memory_ranges_accessed_by_a_symbol to your script or notebook. You can access the various functions and classes exposed by the module from the memory_ranges_accessed_by_a_symbol namespace.
  4. Refer to the Argument parsing cell for an example of use in a script, and to the Parameters cell and below for an example of use in a notebook (you just need to prepend memory_ranges_accessed_by_a_symbol. to the names of the functions and classes from the script).
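
The steps above can be sketched as follows. This is a minimal example, assuming the keyword arguments of the module's `memory_ranges_accessed_by_a_symbol` function shown in the source below; the wrapper name `run_symbol_query` is hypothetical, and a stand-in object replaces the real module so that the sketch runs without a REVEN server.

```python
from types import SimpleNamespace


def run_symbol_query(module, server, symbol_name, first_id=None, last_id=None):
    """Call the module's main function for one symbol and one range of transition ids."""
    return module.memory_ranges_accessed_by_a_symbol(
        server=server,
        symbol=symbol_name,
        with_children_symbols=False,
        from_context=first_id,
        to_context=last_id,
    )


# Stand-in for the imported module: it only records the arguments it receives.
calls = []
fake_module = SimpleNamespace(
    memory_ranges_accessed_by_a_symbol=lambda **kwargs: calls.append(kwargs)
)

run_symbol_query(fake_module, server=None, symbol_name="ExCompareExchangeCallBack", first_id=0, last_id=1000)
print(calls[0]["symbol"])
```

In a real script, `fake_module` would be replaced by the module obtained with `import memory_ranges_accessed_by_a_symbol`, and `server` by a `reven2.RevenServer` connected to your host and port.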

Known limitations

When using the "table", "csv" or "html" output format, this script might require a large quantity of RAM because the data is retained in memory. If you notice high RAM usage, you can try the following:

  • Restart with the "raw" format
  • Split the results using the from_context and to_context parameters
  • Use the provided filters (ring, processes, threads) to reduce the number of results
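
One possible way to split the results is to cut the range of interest into fixed-size windows and run the script once per window. The sketch below assumes that `from_context` and `to_context` are given as integer transition ids, as in the main function of the source; the helper name `context_windows` and the window size are hypothetical.

```python
from typing import Iterator, Tuple


def context_windows(first_id: int, last_id: int, window: int) -> Iterator[Tuple[int, int]]:
    """Yield consecutive (from_context, to_context) pairs covering [first_id, last_id)."""
    if window <= 0:
        raise ValueError("window must be a positive number of transitions")
    start = first_id
    while start < last_id:
        end = min(start + window, last_id)
        yield (start, end)
        start = end


# Each pair can be passed as from_context / to_context to one run of the script.
print(list(context_windows(0, 2500, 1000)))
```

A symbol call that straddles a window boundary may be truncated at that boundary, so choose windows large enough for the calls you are interested in.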

Supported versions

REVEN 2.10+

Supported perimeter

Any REVEN scenario.

Dependencies

The script requires that the target REVEN scenario have:

  • The OSSI feature replayed.
  • The memory history feature replayed.

The script also requires the pandas Python module.
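
Before launching a long run, you may want to check that pandas can be imported in your environment. The sketch below uses the standard `importlib.util.find_spec` function; the helpers `has_module` and `choose_output_format`, and the idea of falling back to the "raw" format, are assumptions made for this example and are not part of the script.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module with this name can be found."""
    return importlib.util.find_spec(name) is not None


def choose_output_format(requested: str, pandas_available: bool) -> str:
    """Fall back to "raw" when a pandas-based format is requested without pandas."""
    if requested in ("table", "csv", "html") and not pandas_available:
        return "raw"
    return requested


print(choose_output_format("csv", pandas_available=has_module("pandas")))
```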

Source

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#   kernelspec:
#     display_name: reven
#     language: python
#     name: reven-python3
# ---

# %% [markdown]
# # Find all memory accesses performed by a given symbol
#
# ## Purpose
#
# This notebook and script search a REVEN trace for all memory accesses performed during the calls
# to a given symbol.
# Results can be filtered by process, thread, ring, context range and memory access operation.
#
# The script can generate two kinds of results:
# - for each call of the symbol: the process, binary and symbol call information, followed by all the memory
#   accesses of that call.
# - grouped by symbol: the symbol, then each of its calls with the memory accesses that occurred in that call.
#   Note that this option can take a long time to start showing results.
#
# Note that this script can include or exclude the memory accesses that occurred in children symbol calls
# of each symbol call.
#
#
#
# ## How to use
#
# Results can be generated from this notebook or from the command line.
# The script can also be imported as a module for use from your own script or notebook.
#
#
# ### From the notebook
#
# 1. Upload the `memory_ranges_accessed_by_a_symbol.ipynb` file in Jupyter.
# 2. Fill out the [parameters](#Parameters) cell of this notebook according to your scenario and desired output.
# 3. Run the full notebook.
#
#
# ### From the command line
#
# 1. Make sure that you are in an
#    [environment](http://doc.tetrane.com/professional/latest/Python-API/Installation.html#on-the-reven-server)
#    that can run REVEN scripts.
# 2. Run `python memory_ranges_accessed_by_a_symbol.py --help` to get a tour of available arguments.
# 3. Run `python memory_ranges_accessed_by_a_symbol.py --host <your_host> --port <your_port> [<other_option>]`
#    with your arguments of choice.
#
# ### Imported in your own script or notebook
#
# 1. Make sure that you are in an
#    [environment](http://doc.tetrane.com/professional/latest/Python-API/Installation.html#on-the-reven-server)
#    that can run REVEN scripts.
# 2. Make sure that `memory_ranges_accessed_by_a_symbol.py` is in the same directory as your script or notebook.
# 3. Add `import memory_ranges_accessed_by_a_symbol` to your script or notebook. You can access the various functions
#     and classes exposed by the module from the `memory_ranges_accessed_by_a_symbol` namespace.
# 4. Refer to the [Argument parsing](#Argument-parsing) cell for an example of use in a script, and to the
#    [Parameters](#Parameters) cell and below for an example of use in a notebook (you just need to prepend
#    `memory_ranges_accessed_by_a_symbol.` to the names of the functions and classes from the script).
#
# ## Known limitations
#
# When using the "table", "csv" or "html" output format, this script might require a large quantity of RAM
# because the data is retained in memory. If you notice high RAM usage, you can try the following:
#
# - Restart with the "raw" format
# - Split the results using the `from_context` and `to_context` parameters
# - Use the provided filters (ring, processes, threads) to reduce the number of results
#
# ## Supported versions
#
# REVEN 2.10+
#
# ## Supported perimeter
#
# Any REVEN scenario.
#
# ## Dependencies
#
# The script requires that the target REVEN scenario have:
#
# * The OSSI feature replayed.
# * The memory history feature replayed.
#
# The script also requires the pandas Python module.

# %% [markdown]
# ### Package imports

# %%
import argparse
import re
import sys
from enum import Enum
from typing import Callable as _Callable, Iterable as _Iterable, Optional as _Optional

from IPython.core.display import display  # type: ignore

import pandas  # type: ignore

import reven2.arch as _arch
from reven2.filter import RingPolicy
from reven2.memhist import MemoryAccess, MemoryAccessOperation
from reven2.ossi import Binary, Process, Symbol
from reven2.ossi.thread import Thread
from reven2.prelude import RevenServer
from reven2.stack import Stack
from reven2.trace import Context, Trace, Transition

# %% [markdown]
# ### Utility functions


# %%
# Detect if we are currently running a Jupyter notebook.
#
# This is used e.g. to display rendered results inline in Jupyter when we are executing in the context of a Jupyter
# notebook, or to display raw results on the standard output when we are executing in the context of a script.
def in_notebook():
    try:
        from IPython import get_ipython  # type: ignore

        if get_ipython() is None or ("IPKernelApp" not in get_ipython().config):
            return False
    except ImportError:
        return False
    return True


# %% [markdown]
# ### Helper classes for results

# %%
class CallSymbol:
    r"""
    CallSymbol is a helper class used to represent a symbol with its start and end context
    """

    def __init__(self, symbol: _Optional[Symbol], start: Context, end: _Optional[Context] = None) -> None:
        self._symbol = symbol
        self._start = start
        self._end = end

    @property
    def symbol(self) -> _Optional[Symbol]:
        r"""
        B{Property:} The symbol of the call symbol. None if the symbol is unknown.
        """
        return self._symbol

    @property
    def start_context(self) -> Context:
        r"""
        B{Property:} The start context of the call symbol.
        """
        return self._start

    @property
    def end_context(self) -> _Optional[Context]:
        r"""
        B{Property:} The end context (excluded) of the call symbol. None if the end context isn't in the trace.
        """
        return self._end

    def __eq__(self, other: "CallSymbol") -> bool:  # type: ignore
        return self._symbol == other._symbol and self._start == other._start and self._end == other._end

    def __ne__(self, other: "CallSymbol") -> bool:  # type: ignore
        return not (self == other)


class HtmlStr:
    r"""
    Helper class used with notebook special `display` function to
    consider the HTML string as HTML.
    """

    def __init__(self, html: str) -> None:
        self._html = html

    def _repr_html_(self):
        return self._html


class MemoryRangeSymbolResult:
    r"""
    MemoryRangeSymbolResult is a helper class that represents one result of the search.
    """

    def __init__(
        self,
        call_symbol: CallSymbol,
        memory_accesses: _Iterable[MemoryAccess],
        ring: int,
        process: _Optional[Process],
        thread: _Optional[Thread],
        binary: _Optional[Binary],
    ) -> None:
        self._call_symbol = call_symbol
        self._memory_accesses = memory_accesses
        self._ring = ring
        self._process = process
        self._thread = thread
        self._binary = binary

    @property
    def call_symbol(self) -> CallSymbol:
        r"""
        B{Property:} The call symbol of the result.
        """
        return self._call_symbol

    @property
    def memory_accesses(self) -> _Iterable[MemoryAccess]:
        r"""
        B{Property:} The memory accesses of the result.

        Iterating over the returned value consumes the underlying generator, so it can only be traversed once.
        """
        return self._memory_accesses

    @property
    def ring(self) -> int:
        r"""
        B{Property:} The ring of the result.
        """
        return self._ring

    @property
    def process(self) -> _Optional[Process]:
        r"""
        B{Property:} The process of the result.
        """
        return self._process

    @property
    def binary(self) -> _Optional[Binary]:
        r"""
        B{Property:} The binary of the result, None if unknown.
        """
        return self._binary

    @property
    def thread(self) -> _Optional[Thread]:
        r"""
        B{Property:} The thread of the result.
        """
        return self._thread

    def __eq__(self, other: "MemoryRangeSymbolResult") -> bool:  # type: ignore
        return (
            self._ring == other._ring
            and self._process is not None
            and other._process is not None
            and self._process.name == other._process.name
            and self._process.pid == other._process.pid
            and self._process.ppid == other._process.ppid
            and self._thread is not None
            and other._thread is not None
            and self._thread.id == other._thread.id
            and self._thread.owner_process_id == other._thread.owner_process_id
            and (
                (self._binary is None and other._binary is None)
                or (self._binary is not None and other._binary is not None and self._binary.path == other._binary.path)
            )
            and self._call_symbol == other._call_symbol
        )

    def __ne__(self, other: "MemoryRangeSymbolResult") -> bool:  # type: ignore
        return not (self == other)

    def output(self, print_func: _Callable, is_in_notebook: bool):
        r"""
        Output this result using the `print_func` or call `output_in_notebook` if `is_in_notebook`
        argument is true.
        """
        if is_in_notebook:
            self.output_in_notebook()
        else:
            print_func(
                f"ring: {self._ring}, process: {self._process}, "
                f"thread: {self._thread}, binary: {self._binary}, "
                f"symbol: {self._call_symbol.symbol}[{self._call_symbol.start_context}, "
                f"{self._call_symbol.end_context}["
            )
            print_func("\nmemory accesses:")
            for m in self._memory_accesses:
                print_func(f"\n\t{m}, ")
            print_func("\n")

    def output_in_notebook(self) -> None:
        r"""
        Output this result using the special notebook `display` function
        """
        end_context = self._call_symbol.end_context
        display(
            HtmlStr(
                f"<p>ring: {self._ring}, process: {self._process if self._process is not None else 'unknown'}, "
                f"thread: {self._thread if self._thread is not None else 'unknown'}, "
                f"binary: {self._binary if self._binary is not None else 'unknown'}, "
                f"symbol: {self._call_symbol.symbol}[{self._call_symbol.start_context.format_as_html()}, "
                f"{None if end_context is None else end_context.format_as_html()}[</p>"
            )
        )
        display(HtmlStr("<p>memory accesses:</p>"))
        for m in self._memory_accesses:
            display(HtmlStr(f'<p style="text-indent: 2em;"> &#9673; {m.format_as_html()}</p>'))


class GroupedMemoryRangeSymbolResult:
    r"""
    GroupedMemoryRangeSymbolResult is a helper class that represents results of the search grouped by symbol.
    """

    def __init__(
        self,
        symbol: Symbol,
        memory_range_symbol_result: _Iterable[MemoryRangeSymbolResult],
    ) -> None:
        self._symbol = symbol
        self._memory_range_symbol_result = memory_range_symbol_result

    def output(self, print_func: _Callable, is_in_notebook: bool) -> None:
        r"""
        Output these results using the `print_func` or call `output_in_notebook` if `is_in_notebook`
        argument is true.
        """
        if is_in_notebook:
            self.output_in_notebook()
        else:
            print_func(f"{self._symbol}\n\nCalls")
            for res in self._memory_range_symbol_result:
                print_func(
                    # plain-text output: use the string representation of the contexts, not their HTML form
                    f"\n\t[{res.call_symbol.start_context}, "
                    f"{res.call_symbol.end_context}]"
                    f"\n\tring: {res.ring}, process: {res.process if res.process is not None else 'unknown'}, "
                    f" thread: {res.thread if res.thread is not None else 'unknown'}, binary: {res.binary}, "
                )
                print_func("\n\tmemory accesses:")
                for m in res.memory_accesses:
                    print_func(f"\n\t\t{m}, ")
                print_func("\n")

    def output_in_notebook(self) -> None:
        r"""
        Output these results using the special notebook `display` function
        """
        display(HtmlStr(f'<p>{self._symbol}</p><p style="font-weight: bolder;">Calls</p>'))
        for res in self._memory_range_symbol_result:
            end_context = res.call_symbol.end_context
            display(
                HtmlStr(
                    f'<p style="text-indent: 2em;"> &#9678; [{res.call_symbol.start_context.format_as_html()}, '
                    f"{None if end_context is None else end_context.format_as_html()}]</p>"
                    f'<p style="text-indent: 3em;">ring: {res.ring}, process: '
                    f"{res.process if res.process is not None else 'unknown'}, "
                    f" thread: {res.thread if res.thread is not None else 'unknown'}, binary: {res.binary}</p>"
                )
            )
            display(HtmlStr('<p style="text-indent: 3em; font-weight: bolder;">memory accesses:</p>'))
            for m in res.memory_accesses:
                display(HtmlStr(f'<p style="text-indent: 4em;"> &#9673; {m.format_as_html()}</p>'))


# %% [markdown]
# ### SymbolMemoryAccessesFinder
#
# This class represents the main logic of this script

# %%
class SymbolMemoryAccessesFinder(object):
    r"""
        This class is a helper class to search for all memory accesses performed by a given symbol.
        Results can be filtered by processes, ring, threads, memory access operation and a context range.

        The memory accesses performed during each call of the given symbol are returned.

        Examples
        ========
        >>> # search all memory accesses performed during the calls of the `ExCompareExchangeCallBack`
        >>> # symbol
        >>> import reven2
        >>> server = reven2.RevenServer('localhost', 46445)
        >>> symbol = next(server.ossi.symbols(pattern="ExCompareExchangeCallBack"))
        >>> finder = SymbolMemoryAccessesFinder(server.trace, symbol, with_children_symbols=False)
        >>> for res in finder.query():
        ...     res.output(print, False)
        ring: 0, process: System (4), thread: 932, binary: c:/windows/system32/ntoskrnl.exe,
        symbol: ntoskrnl!ExCompareExchangeCallBack[Context before #492474770, Context before #492474830[
    memory accesses:
            [#492474770 mov qword ptr ss:[rsp+0x8], rbx]Write access at
            @phy:0x7297a8c0 (virtual address: lin:0xfffff880045ca8c0) of size 8,
            [#492474771 mov qword ptr ss:[rsp+0x10], rbp]Write access at
            @phy:0x7297a8c8 (virtual address: lin:0xfffff880045ca8c8) of size 8,
            [#492474772 mov qword ptr ss:[rsp+0x18], rsi]Write access at
            @phy:0x7297a8d0 (virtual address: lin:0xfffff880045ca8d0) of size 8,
            ...
    """

    def __init__(
        self,
        trace: Trace,
        symbol: Symbol,
        with_children_symbols: bool = True,
        from_context: _Optional[Context] = None,
        to_context: _Optional[Context] = None,
        ring_policy: RingPolicy = RingPolicy.All,
        processes: _Optional[_Iterable[Process]] = None,
        threads: _Optional[_Iterable[int]] = None,
        memory_access_operation: _Optional[MemoryAccessOperation] = None,
    ) -> None:
        r"""
        Initialize a C{SymbolMemoryAccessesFinder}

        Information
        ===========

        @param trace: the trace where memory accesses will be looked for.
        @param symbol: the symbol whose memory accesses will be returned.
        @param with_children_symbols: if True, include the memory accesses that occurred in children symbol calls.
        @param from_context: the context where the search starts.
        @param to_context: the context where the search ends.
        @param ring_policy: ring policy to search for.
        @param processes: processes to limit the search to. If None, results are not filtered by process.
        @param threads: thread ids to limit the search to. If None, results are not filtered by thread.
        @param memory_access_operation: limit results to accesses performing the specified operation.


        @raises TypeError: if trace is not a C{reven2.trace.Trace}.
        """

        if not isinstance(trace, Trace):
            raise TypeError("You must provide a valid trace")
        self._trace = trace
        self._symbol = symbol
        self._with_children_symbols = with_children_symbols
        self._from_context = from_context
        self._to_context = to_context
        self._ring_policy = ring_policy
        self._memory_access_operation = memory_access_operation
        self._processes = None if processes is None else [process for process in processes]
        self._threads = None if threads is None else [thread for thread in threads]

    def _without_children_accesses(
        self, first_transition: Transition, last_transition: _Optional[Transition]
    ) -> _Iterable[MemoryAccess]:
        if last_transition is None:
            last_transition = self._trace.last_transition
        sub_first_transition = sub_last_transition = next_transition = first_transition
        while sub_last_transition is not None and sub_last_transition < last_transition:
            next_transition_opt = next_transition.step_over()
            if next_transition_opt is None:
                return
            next_transition = next_transition_opt
            if next_transition.id - sub_last_transition.id > 1:
                # request memory accesses for the range [sub_first_transition, sub_last_transition + 1[
                for mem_access in self._trace.memory_accesses(
                    from_transition=sub_first_transition, to_transition=sub_last_transition + 1
                ):
                    yield mem_access

                sub_last_transition = sub_first_transition = next_transition
            else:
                sub_last_transition = next_transition

    def _with_children_accesses(
        self, first_transition: Transition, last_transition: _Optional[Transition]
    ) -> _Iterable[MemoryAccess]:
        # get memory access
        for mem_access in self._trace.memory_accesses(from_transition=first_transition, to_transition=last_transition):
            yield mem_access

    def _is_the_same_stack(self, stack1: Stack, stack2: Stack) -> bool:
        # we assume that two stacks are the same if the first contexts of their first frames are the same

        frame1 = next(stack1.frames())
        frame2 = next(stack2.frames())

        return frame1.first_context == frame2.first_context

    def _query_mem_accesses(
        self,
        first_transition: Transition,
        last_transition: _Optional[Transition],
        with_children_symbols: bool,
        memory_access_operation: _Optional[MemoryAccessOperation],
    ) -> _Iterable[MemoryAccess]:
        mem_query = (
            self._with_children_accesses(first_transition, last_transition)
            if with_children_symbols
            else self._without_children_accesses(first_transition, last_transition)
        )

        # store last handled stack to use it if we are in the same stack
        last_stack: _Optional[Stack] = None
        last_stack_is_taken: bool = True

        for mem_access in mem_query:
            if memory_access_operation is not None and mem_access.operation != memory_access_operation:
                continue
            # ignore accesses without linear address
            if mem_access.virtual_address is None:
                continue

            stack = mem_access.transition.context_before().stack
            # if the current stack is the same as the last stack we can yield the access
            if last_stack is not None and self._is_the_same_stack(last_stack, stack):
                if last_stack_is_taken:
                    yield mem_access
                continue

            last_stack = stack

            # we need to be sure that we are always in the current symbol
            # a simple method is to test if the current symbol is in the backtrace of
            # this access's context.
            take_access = False
            for frame in stack.frames():
                loc = frame.first_context.ossi.location()
                if loc is not None and self._symbol == loc.symbol:
                    take_access = True
                    break

            if not take_access:
                last_stack_is_taken = False
                continue
            last_stack_is_taken = True
            yield mem_access

    def filter_by_threads(self, threads: _Iterable[int]) -> "SymbolMemoryAccessesFinder":
        r"""
        Extend the list of threads to limit the search in, and return the self object.

        Information
        ===========

        @param threads: threads to limit the search in.
        @returns : self object
        """
        if self._threads is None:
            self._threads = []
        self._threads += [thread for thread in threads]

        return self

    def filter_by_processes(self, processes: _Iterable[Process]) -> "SymbolMemoryAccessesFinder":
        r"""
        Extend the list of processes to limit the search in, and return the self object.

        Information
        ===========

        @param processes: processes to limit the search in.
        @returns : self object
        """
        if self._processes is None:
            self._processes = []
        self._processes += [process for process in processes]

        return self

    def filter_by_ring(self, ring_policy: RingPolicy) -> "SymbolMemoryAccessesFinder":
        r"""
        Update the ring policy to search for and return the `self` object.

        Information
        ===========

        @param ring_policy: ring policy to search for.
        @returns : self object
        """
        self._ring_policy = ring_policy
        return self

    def from_context(self, context: Context) -> "SymbolMemoryAccessesFinder":
        r"""
        Update the context where the search will be started and return the `self` object.

        Information
        ===========

        @param context: context where the search will be started.
        @returns : self object
        """
        self._from_context = context
        return self

    def to_context(self, context: Context) -> "SymbolMemoryAccessesFinder":
        r"""
        Update the context where the search will be ended and return the `self` object.

        Information
        ===========

        @param context: context where the search will be ended.
        @returns : self object
        """
        self._to_context = context
        return self

    def filter_by_memory_access_operation(
        self, operation: _Optional[MemoryAccessOperation] = None
    ) -> "SymbolMemoryAccessesFinder":
        r"""
        Update the memory access operation to limit results to accesses performing this
        operation and return the `self` object.

        Information
        ===========

        @param operation: limit results to accesses performing the specified operation.
        @returns : self object
        """
        self._memory_access_operation = operation
        return self

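    def query(self) -> _Iterable[MemoryRangeSymbolResult]:
        r"""
        Run the search and yield one C{MemoryRangeSymbolResult} per call of the symbol found in the filtered
        context ranges. The memory accesses of each result are computed lazily, when they are iterated.
        """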
        with_children_symbols = self._with_children_symbols
        memory_access_operation = self._memory_access_operation
        thread_ids = None if self._threads is None else self._threads.copy()
        # filter by process
        for context_range in self._trace.filter(
            processes=self._processes,
            ring_policy=self._ring_policy,
            from_context=self._from_context,
            to_context=self._to_context,
        ):

            last_transition: _Optional[Transition] = None
            # search symbol call
            for context in self._trace.search.symbol(self._symbol, context_range.begin, context_range.end):
                # ignore results that aren't in the list of threads
                if thread_ids is not None:
                    thread = context.ossi.thread()
                    if thread is None or thread.id not in thread_ids:
                        continue
                # we need also to ignore symbols that are recursively called.
                if last_transition is not None and last_transition.context_before() > context:
                    continue
                # if with_children_symbols is true, that means we need to consider
                # the memory accesses in children symbols.
                # use step_out to go out of the symbol
                first_transition = context.transition_after()
                last_transition = first_transition.step_out()
                last_context = None if last_transition is None else last_transition.context_before()
                end_range_transition = (
                    self._trace.last_transition
                    if context_range.end is None or context_range.end == self._trace.last_context
                    else context_range.end.transition_after()
                )
                last_transition = (
                    last_transition
                    if last_transition is None or last_transition < end_range_transition
                    else end_range_transition
                )

                # get the ring of the symbol
                curr_ring = context.read(_arch.x64.cs) & 0x3
                curr_process = context.ossi.process()
                curr_thread = context.ossi.thread()
                curr_location = context.ossi.location()
                curr_binary = None if curr_location is None else curr_location.binary

                yield MemoryRangeSymbolResult(
                    call_symbol=CallSymbol(self._symbol, context, last_context),
                    memory_accesses=self._query_mem_accesses(
                        first_transition,
                        last_transition,
                        with_children_symbols,
                        memory_access_operation,
                    ),
                    ring=curr_ring,
                    process=curr_process,
                    thread=curr_thread,
                    binary=curr_binary,
                )


# %% [markdown]
#
# ### OutputFormat


# %%
class OutputFormat(Enum):
    r"""
    Enum describing the various possible output formats of the results
     - RAW: The results will be output using their string representation.
     - TABLE: The results will be output using pandas table format.
     - CSV: The results will be output as csv.
     - HTML: The results will be output as html table.
    """
    RAW = 0
    TABLE = 1
    CSV = 2
    HTML = 3


# %% [markdown]
# ### Main function
#
# This function is called with parameters from the [Parameters](#Parameters) cell in the notebook context,
# or with parameters from the command line in the script context.


# %%
def memory_ranges_accessed_by_a_symbol(
    server: RevenServer,
    symbol: str,
    binary_hint: _Optional[str] = None,
    with_children_symbols: bool = True,
    from_context: _Optional[int] = None,
    to_context: _Optional[int] = None,
    ring_policy: RingPolicy = RingPolicy.All,
    processes: _Optional[_Iterable[str]] = None,
    threads: _Optional[_Iterable[int]] = None,
    operation: _Optional[MemoryAccessOperation] = None,
    grouped_by_symbol: bool = False,
    output_format: OutputFormat = OutputFormat.RAW,
    output_file: _Optional[str] = None,
) -> None:
    # get the symbol from the OSSI server and raise if it doesn't exist
    trace_symbol = None
    symbol_count = 0
    for sym in server.ossi.symbols(pattern=re.escape(symbol), binary_hint=binary_hint):
        if sym.name == symbol:
            symbol_count += 1
            if trace_symbol is None:
                trace_symbol = sym
            if symbol_count == 2:
                print(trace_symbol, file=sys.stderr)
            if symbol_count >= 2:
                print(sym, file=sys.stderr)

    if trace_symbol is None:
        raise ValueError(f"The requested symbol '{symbol}' could not be found")

    if symbol_count > 1:
        sys.exit(
            "Several symbols match the provided symbol name. Please provide the binary name of one of the "
            "symbols listed above."
        )

    # declare memory accesses finder.
    symbol_memory_ranges_finder = SymbolMemoryAccessesFinder(
        trace=server.trace,
        symbol=trace_symbol,
        with_children_symbols=with_children_symbols,
        from_context=(None if from_context is None else server.trace.context_before(from_context)),
        to_context=(None if to_context is None else server.trace.context_before(to_context)),
        threads=threads,
        ring_policy=ring_policy,
        memory_access_operation=operation,
    )

    # filter by processes
    if processes is not None:
        for process in processes:
            symbol_memory_ranges_finder.filter_by_processes(server.ossi.executed_processes(process))

    if output_format == OutputFormat.RAW:
        is_in_notebook = output_file is None and in_notebook()

        def std_print_func(s: str) -> None:
            print(s)

        print_func = std_print_func
        if output_file is not None:
            file = open(output_file, "w")

            def fprint_func(s: str) -> None:
                file.write(s)

            print_func = fprint_func
        if grouped_by_symbol:
            grouped_result = GroupedMemoryRangeSymbolResult(trace_symbol, symbol_memory_ranges_finder.query())

            grouped_result.output(print_func, is_in_notebook)
        else:
            for result in symbol_memory_ranges_finder.query():
                result.output(print_func, is_in_notebook)

        if output_file is not None:
            file.close()
    else:
        column_headers = [
            "Ring",
            "Process",
            "Thread",
            "Binary",
            "Symbol",
            "Start context",
            "Access transition",
            "Access operation",
            "Access physical",
            "Access linear",
            "Access size",
        ]

        def data_generator():
            for result in symbol_memory_ranges_finder.query():
                for mem_access in result.memory_accesses:
                    yield (
                        result.ring,
                        str(result.process) if result.process is not None else "unknown",
                        str(result.thread) if result.thread is not None else "unknown",
                        result.binary.name,
                        result.call_symbol.symbol.name,
                        str(result.call_symbol.start_context),
                        mem_access.transition.id,
                        mem_access.operation.name,
                        mem_access.physical_address,
                        mem_access.virtual_address,
                        mem_access.size,
                    )

        df = pandas.DataFrame(data=data_generator(), columns=column_headers)
        if output_format == OutputFormat.TABLE:
            if output_file is not None:
                with open(output_file, "w") as file:
                    file.write(str(df))
            else:
                print(df)
        elif output_format == OutputFormat.CSV:
            # print to the standard output unless an output file was requested
            if output_file is None:
                print(df.to_csv())
            else:
                df.to_csv(output_file)
        elif output_format == OutputFormat.HTML:
            if output_file is None:
                print(df.to_html())
            else:
                df.to_html(output_file)


# %% [markdown]
# ### Argument parsing
#
# Argument parsing function for use in the script context.
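#
# Several of the optional filters rely on standard `argparse` behaviour. As a minimal, self-contained sketch (the
# `demo_parser` name is hypothetical and not part of this script), this is how `nargs="*"` options such as
# `--threads` and `--processes` are parsed:
#
# ```python
# import argparse
#
# # Hypothetical stand-alone parser that mirrors two of the options declared in `script_main`.
# demo_parser = argparse.ArgumentParser()
# demo_parser.add_argument("--threads", type=int, nargs="*")
# demo_parser.add_argument("--processes", nargs="*")
#
# # Each value is converted with `type`; options that are not given keep their default, `None`.
# args = demo_parser.parse_args(["--threads", "4", "8"])
# print(args.threads, args.processes)
#
# # An option given without any value yields an empty list rather than `None`.
# print(demo_parser.parse_args(["--processes"]).processes)
# ```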

# %%
def get_memory_access_operation(operation: _Optional[str]) -> _Optional[MemoryAccessOperation]:
    if operation is None:
        return None
    if operation.lower() == "read":
        return MemoryAccessOperation.Read
    if operation.lower() == "write":
        return MemoryAccessOperation.Write
    raise ValueError(f"'operation' value should be 'read' or 'write'. Received '{operation}'.")


def get_ring_policy(ring: _Optional[int]) -> RingPolicy:
    if ring is None:
        return RingPolicy.All
    if ring == 0:
        return RingPolicy.R0Only
    if ring == 3:
        return RingPolicy.R3Only
    raise ValueError(f"'ring' value should be '0' or '3'. Received '{ring}'.")


def get_output_format(format: str) -> OutputFormat:
    if format.lower() == "raw":
        return OutputFormat.RAW
    if format.lower() == "table":
        return OutputFormat.TABLE
    if format.lower() == "html":
        return OutputFormat.HTML
    if format.lower() == "csv":
        return OutputFormat.CSV
    raise ValueError(f"'output format' value should be 'raw', 'table', 'html', or 'csv'. Received '{format}'.")


def script_main():
    parser = argparse.ArgumentParser(description="Find all memory accesses made by a given symbol")
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        required=False,
        help='REVEN host, as a string (default: "localhost")',
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=13370,
        required=False,
        help="REVEN port, as an int (default: 13370)",
    )
    parser.add_argument(
        "-s",
        "--symbol",
        type=str,
        required=True,
        help="The symbol whose accesses are looked for (e.g. WriteFile)",
    )
    parser.add_argument(
        "-b",
        "--binary-hint",
        type=str,
        required=False,
        help="The symbol's binary name hint (e.g. ntoskrnl)",
    )
    parser.add_argument(
        "--with-children-symbols",
        action="store_true",
        required=False,
        default=False,
        help="Show accesses from children calls",
    )
    parser.add_argument(
        "--from-context",
        type=int,
        required=False,
        help="The context from where the search starts",
    )
    parser.add_argument(
        "--to-context",
        type=int,
        required=False,
        help="The context (not included) at which the search stops",
    )
    parser.add_argument(
        "--ring",
        type=int,
        required=False,
        help="Only show the symbol's accesses made in this ring (0=ring0, 3=ring3)",
    )
    parser.add_argument(
        "--processes",
        required=False,
        nargs="*",
        help="Only show the symbol's accesses made in these processes",
    )
    parser.add_argument(
        "--threads",
        type=int,
        required=False,
        nargs="*",
        help="Only show the symbol's accesses made in these threads",
    )
    parser.add_argument(
        "--memory-access-operation",
        choices=["read", "write"],
        required=False,
        help="Only show the symbol's memory accesses that use this operation",
    )
    parser.add_argument(
        "--grouped-by-symbol",
        action="store_true",
        required=False,
        default=False,
        help="Group results by symbol",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        type=str,
        required=False,
        help="The target file of the results. If absent, the results will be printed on the standard output",
    )
    parser.add_argument(
        "--output-format",
        choices=["raw", "table", "csv", "html"],
        required=False,
        default="raw",
        help="Output format of the results",
    )

    args = parser.parse_args()

    try:
        server = RevenServer(args.host, args.port)
    except RuntimeError:
        raise RuntimeError(f"Could not connect to the server on {args.host}:{args.port}.")

    memory_ranges_accessed_by_a_symbol(
        server=server,
        symbol=args.symbol,
        binary_hint=args.binary_hint,
        with_children_symbols=args.with_children_symbols,
        from_context=args.from_context,
        to_context=args.to_context,
        ring_policy=get_ring_policy(args.ring),
        processes=args.processes,
        threads=args.threads,
        operation=get_memory_access_operation(args.memory_access_operation),
        grouped_by_symbol=args.grouped_by_symbol,
        output_format=get_output_format(args.output_format),
        output_file=args.output_file,
    )


# %% [markdown]
# ## Parameters
#
# These parameters have to be filled out to use in the notebook context.

# %%
# Server connection
#
host = "localhost"
port = 37103

# Input data

symbol = "xxx"  # symbol name

binary_hint = None  # symbol's binary name hint

with_children_symbols = True

# Output filter

from_context = None
# from_context = 10


to_context = None
# to_context = 10


ring_policy = RingPolicy.All
# ring_policy = RingPolicy.R0Only
# ring_policy = RingPolicy.R3Only

processes = None  # display result for all processes in the trace
# processes = ["xxx",]

threads = None  # display result for all threads in the trace
# threads = [thread_id,]

memory_access_operation = None
# memory_access_operation = MemoryAccessOperation.Write
# memory_access_operation = MemoryAccessOperation.Read

# Output target
#
output_file = None  # display results inline
# output_file = "res.csv"  # write the results, formatted according to `output_format`, to "res.csv"


# Output control
#
# group results by symbol
grouped_by_symbol = False
# output format: OutputFormat.RAW, OutputFormat.TABLE, OutputFormat.CSV or OutputFormat.HTML
output_format: OutputFormat = OutputFormat.RAW


# %% [markdown]
# ### Pandas module
#
# This cell checks whether the pandas module is installed and installs it if needed.
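#
# When only an availability check is needed, the standard library can answer without importing the module. A minimal
# sketch, assuming a hypothetical helper named `is_module_available` that is not part of this script:
#
# ```python
# import importlib.util
#
#
# def is_module_available(name: str) -> bool:
#     """Return True if the top-level module `name` can be found, without importing it."""
#     return importlib.util.find_spec(name) is not None
#
#
# print(is_module_available("pandas"))
# ```
#
# The cell below keeps the `try`/`except ImportError` form because it also needs the imported module afterwards.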


# %%
if in_notebook():
    try:
        import pandas  # noqa

        print("pandas already installed")
    except ImportError:
        print("Could not find pandas, attempting to install it from pip")
        import subprocess

        command = [f"{sys.executable}", "-m", "pip", "install", "pandas"]
        p = subprocess.run(command)

        if int(p.returncode) != 0:
            raise RuntimeError("Error installing pandas")
        import pandas  # noqa

        print("Successfully installed pandas")
else:
    import pandas  # noqa


# %% [markdown]
# ### Execution cell
#
# This cell executes according to the [parameters](#Parameters) when in notebook context, or according to the
# [parsed arguments](#Argument-parsing) when in script context.
#
# When in notebook context, if the `output_file` parameter is `None`, then the results will be displayed in the last
# cell of the notebook.

# %%
if __name__ == "__main__":
    if in_notebook():
        try:
            server = RevenServer(host, port)
        except RuntimeError:
            raise RuntimeError(f"Could not connect to the server on {host}:{port}.")
        memory_ranges_accessed_by_a_symbol(
            server=server,
            symbol=symbol,
            binary_hint=binary_hint,
            with_children_symbols=with_children_symbols,
            from_context=from_context,
            to_context=to_context,
            ring_policy=ring_policy,
            processes=processes,
            threads=threads,
            operation=memory_access_operation,
            grouped_by_symbol=grouped_by_symbol,
            output_format=output_format,
            output_file=output_file,
        )
    else:
        script_main()
# %%