Source code for psyclone.psyir.nodes.extract_node

# -----------------------------------------------------------------------------
# BSD 3-Clause License
#
# Copyright (c) 2019-2026, Science and Technology Facilities Council
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# -----------------------------------------------------------------------------
# Author I. Kavcic, Met Office
# Modified A. R. Porter, S. Siso, R. W. Ford and N. Nobre, STFC Daresbury Lab
# Modified J. Henrichs, Bureau of Meteorology
# -----------------------------------------------------------------------------

'''
This module provides support for extraction of code within a specified
invoke. The extracted code may be a single kernel, multiple occurrences of a
kernel in an invoke, nodes in an invoke or the entire invoke (extraction
applied to all Nodes).

There is currently only one class in this module: ExtractNode (see below).

Another class which contains helper functions for code extraction, such as
wrapping up settings for generating driver for the extracted code, will
be added in Issue #298.
'''

from typing import cast, List, Tuple, TYPE_CHECKING

from psyclone.configuration import Config
from psyclone.core import AccessSequence, Signature
from psyclone.errors import InternalError
from psyclone.psyir.nodes.assignment import Assignment
from psyclone.psyir.nodes.call import Call
from psyclone.psyir.nodes.directive import Directive
from psyclone.psyir.nodes.loop import Loop
from psyclone.psyir.nodes.node import Node
from psyclone.psyir.nodes.psy_data_node import PSyDataNode
from psyclone.psyir.nodes.reference import Reference
from psyclone.psyir.nodes.routine import Routine
from psyclone.psyir.nodes.statement import Statement
from psyclone.psyir.nodes.structure_reference import StructureReference
from psyclone.psyir.symbols import (
    ArrayType, ContainerSymbol, DataSymbol, DataType, ImportInterface,
    ScalarType, Symbol, SymbolTable)

if TYPE_CHECKING:
    from psyclone.psyir.tools import ReadWriteInfo


[docs] class ExtractNode(PSyDataNode): ''' This class can be inserted into a Schedule to mark Nodes for code extraction using the ExtractRegionTrans transformation. By applying the transformation the Nodes marked for extraction become children of (the Schedule of) an ExtractNode. :param ast: reference into the fparser2 parse tree corresponding to this node. :type ast: sub-class of :py:class:`fparser.two.Fortran2003.Base` :param children: the PSyIR nodes that are children of this node. :type children: list of :py:class:`psyclone.psyir.nodes.Node` :param parent: the parent of this node in the PSyIR tree. :type parent: :py:class:`psyclone.psyir.nodes.Node` :param options: a dictionary with options provided via transformations. :type options: Optional[Dict[str, Any]] :param str options["prefix"]: a prefix to use for the PSyData module name (``prefix_psy_data_mod``) and the PSyDataType (``prefix_PSyDataType``) - a "_" will be added automatically. It defaults to "extract", which means the module name used will be ``extract_psy_data_mode``, and the data type ``extract_PSyDataType``. :param str options["post_var_postfix"]: a postfix to be used when creating names to store values of output variable. A variable 'a' would store its value as 'a', and its output values as 'a_post' with the default post_var_postfix of '_post'. :param options["read_write_info"]: information about variables that are read and/or written in the instrumented code. :type options["read_write_info"]: py:class:`psyclone.psyir.tools.ReadWriteInfo` ''' # Textual description of the node. _text_name = "Extract" _colour = "green" # The default prefix to add to the PSyData module name and PSyDataType _default_prefix = "extract" # This dictionary keeps track of region+module names that are already # used. For each key (which is module_name+"|"+region_name) it contains # how many regions with that name have been created. This number will # then be added as an index to create unique region identifiers. _used_kernel_names: dict[str, int] = {} def __init__(self, ast=None, children=None, parent=None, options=None): super().__init__(ast=ast, children=children, parent=parent, options=options) # Define a postfix that will be added to variable that are # modified to make sure the names can be distinguished between pre- # and post-variables (i.e. here input and output). A variable # "myvar" will be stored as "myvar" with its input value, and # "myvar_post" with its output value. It is the responsibility # of the transformation that inserts this node to make sure this # name is consistent with the name used when creating the driver # (otherwise the driver will not be able to read in the dumped # valued), and also to handle any potential name clashes (e.g. a # variable 'a' exists, which creates 'a_out' for the output variable, # which would clash with a variable 'a_out' used in the program unit). if options is None: options = {} self._post_name = options.get("post_var_postfix", "_post") # Keep a copy of the argument list: self._read_write_info = options.get("read_write_info") self._driver_creator = None def __eq__(self, other): ''' Checks whether two nodes are equal. Two ExtractNodes are equal if their extract_body members are equal. :param object other: the object to check equality to. :returns: whether other is equal to self. :rtype: bool ''' is_eq = super().__eq__(other) is_eq = is_eq and self.post_name == other.post_name return is_eq @property def extract_body(self) -> Node: ''' :returns: the Schedule associated with this ExtractNode. :rtype: :py:class:`psyclone.psyir.nodes.Schedule` ''' return super().psy_data_body @property def post_name(self) -> str: ''' :returns: the _post_name member of this ExtractNode. :rtype: str ''' return self._post_name
[docs] def get_ignored_variables(self) -> list[tuple[str, Signature]]: ''' This function is used to create a list of variables that should not be written to a kernel data file (or read in the driver). The current implementation removes all loop variables (as long as the variables are only used in loops). The main reason for this is that using OpenMP parallelism means that loop variables are undefined when exiting the loop, so comparing them results in errors. Detect all loop variables that do not need to be added to a kernel data file. This function tests that a loop variable is not used for anything else (a code might 're-use' a loop variable for some other reasons, in which case the variable should still be added), and also takes into account that a loop variable might be used in more than one loop. This is done by collecting all accesses to a loop variable under this ExtractNode, and then collecting all accesses to the same variable under any loop using this variable. If the union of the accesses under all loops is identical to all accesses to this variable, the variable is only used inside loop, and does not need to be stored. ''' ignore_list: list[tuple[str, Signature]] = [] all_accesses = self.reference_accesses() # First collect all accesses to loop variables from all loops # into a dictionary. We update the accesses if a loop variable # is used in more than one loop all_loop_var_accesses: dict[Signature, AccessSequence] = {} for loop in self.walk(Loop): loop_var_sig = Signature(loop.variable.name) accesses_in_loop = loop.reference_accesses() if loop_var_sig in all_loop_var_accesses: all_loop_var_accesses[loop_var_sig].update( accesses_in_loop[loop_var_sig]) else: all_loop_var_accesses[loop_var_sig] = ( accesses_in_loop[loop_var_sig]) # Now check all loop variables, and if all accesses to this variable # are from loops only (i.e. the final value of the loop variable is # not used outside of a loop), the variable does not need to be stored. # As a complication, any variable usage in a directive must be # discarded (e.g. a loop variable might be declared as omp private). # We do this by counting how many directive nodes are in the list, and # then subtracting this number from all accesses: for var_sig, accesses in all_loop_var_accesses.items(): directive_count = 0 for access in all_accesses[var_sig]: statement = access.node.ancestor(Statement, include_self=True) if isinstance(statement, Directive): directive_count += 1 if len(accesses) == len(all_accesses[var_sig]) - directive_count: ignore_list.append(('', var_sig)) return ignore_list
[docs] def lower_to_language_level(self): # pylint: disable=arguments-differ ''' Lowers this node (and all children) to language-level PSyIR. The PSyIR tree is modified in-place. :returns: the lowered version of this node. ''' # Avoid circular dependency # pylint: disable=import-outside-toplevel from psyclone.psyir.tools.call_tree_utils import CallTreeUtils self._populate_region_name() # get_non_local_read_write_info doesn't work with the lowered tree, # so we save a copy of the higher dsl tree copy_dsl_tree = self.copy() for child in self.children: child.lower_to_language_level() self.flatten_references() # Determine the variables to write: ctu = CallTreeUtils() read_write_info = ctu.get_in_out_parameters( self, include_non_data_accesses=False) vars_to_ignore = self.get_ignored_variables() # Use the copy of the dsl_tree to get the external symbols ctu.get_non_local_read_write_info(copy_dsl_tree.children, read_write_info) # TODO #3024: We could be more data efficient by better selecting # which don't need to be copied in (because the extraction region # will only write to them) if self._driver_creator: nodes = self.children region_name_tuple = self.get_unique_region_name(nodes) self.bring_external_symbols(read_write_info, self.ancestor(Routine).symbol_table) # Determine a unique postfix to be used for output variables # that avoid any name clashes postfix = self.determine_postfix(read_write_info, postfix="_post") # Remove the spurious "_" at the end of the prefix or use default prefix = self._prefix[:-1] if self._prefix else "extract" # Create and write the driver code self._driver_creator.write_driver(self.children, read_write_info, postfix=postfix, prefix=prefix, region_name=region_name_tuple, vars_to_ignore=vars_to_ignore) # Remove the variables to be ignored from the read_write_info object, # so they will not be extracted. for var_info in vars_to_ignore: read_write_info.remove(signature=var_info[1], container_name=var_info[0]) options = {'pre_var_list': read_write_info.all_used_vars_list, 'post_var_list': read_write_info.write_list, 'post_var_postfix': self._post_name} return super().lower_to_language_level(options)
# -------------------------------------------------------------------------
[docs] @staticmethod def determine_postfix(read_write_info: "ReadWriteInfo", postfix: str = "_post") -> str: ''' This function prevents any name clashes that can occur when adding the postfix to output variable names. For example, if there is an output variable 'a', the driver (and the output file) will contain two variables: 'a' and 'a_post'. But if there is also another variable called 'a_post', a name clash would occur (two identical keys in the output file, and two identical local variables in the driver). In order to avoid this, the suffix 'post' is changed (to 'post0', 'post1', ...) until any name clashes are avoided. This works for structured and non-structured types. :param read_write_info: information about all input and output parameters. :param postfix: the postfix to append to each output variable. :returns: a postfix that can be added to each output variable without generating a name clash. ''' suffix = "" # Create the a set of all input and output variables (to avoid # checking input+output variables more than once) all_vars = read_write_info.all_used_vars_list # The signatures in the input/output list need to be converted # back to strings to easily append the suffix. all_vars_string = [str(input_var) for _, input_var in all_vars] while any(str(out_sig)+postfix+str(suffix) in all_vars_string for out_sig in read_write_info.signatures_written): suffix = cast(int, suffix) if suffix == "": suffix = 0 else: suffix += 1 return postfix+str(suffix)
[docs] def get_unique_region_name(self, nodes: List[Node]) -> Tuple[str, str]: '''This function returns the region and module name. If they are specified in the user options, these names will just be returned (it is then up to the user to guarantee uniqueness). Otherwise a name based on the module and invoke will be created using indices to make sure the name is unique. :param nodes: a list of nodes. ''' # pylint: disable=import-outside-toplevel from psyclone.psyGen import InvokeSchedule invoke = nodes[0].ancestor(InvokeSchedule) if invoke: module_name = invoke.invoke.invokes.psy.name else: module_name = nodes[0].root.name return (module_name, self._region_name)
# ------------------------------------------------------------------------- @staticmethod def _flatten_signature(signature: Signature) -> str: '''Creates a 'flattened' string for a signature by using ``_`` to separate the parts of a signature. For example, in Fortran a reference to ``a%b`` would be flattened to be ``a_b``. :param signature: the signature to be flattened. :returns: a flattened string (all '%' replaced with '_'.) ''' return str(signature).replace("%", "_") # -------------------------------------------------------------------------
[docs] def flatten_references(self) -> None: '''Replace StructureReferencces with a simple Reference and a flattened name (replacing all % with _). ''' already_flattened: dict[str, Symbol] = {} # dict of name: symbol for structure_ref in self.walk(StructureReference)[:]: if isinstance(structure_ref.parent, Call): if structure_ref.position == 0: return # Method calls are fine signature, _ = structure_ref.get_signature_and_indices() flattened_name = self._flatten_signature(signature) try: symbol = already_flattened[flattened_name] except KeyError: symtab = structure_ref.ancestor(Routine).symbol_table symbol = symtab.new_symbol( flattened_name, symbol_type=DataSymbol, datatype=self._flatten_datatype(structure_ref)) already_flattened[flattened_name] = symbol # We also need two assignments to copy the initial and final # values to/from the flattened temporary self.parent.addchild(Assignment.create(Reference(symbol), structure_ref.copy()), index=self.position) self.parent.addchild(Assignment.create(structure_ref.copy(), Reference(symbol)), index=self.position+1) # Replace the structure access with the flattened reference structure_ref.replace_with(Reference(symbol))
@staticmethod def _flatten_datatype(structure_reference: StructureReference) -> DataType: ''' Ideally this should be replaced by structure_reference.datatype but until it works, this utility method provides hardcoded type information depending on the PSyKAL DSL and names involved. :returns: the datatype of the symbol with the flattened expression. ''' signature, _ = structure_reference.get_signature_and_indices() if Config.get().api == "gocean": api_config = Config.get().api_conf("gocean") grid_properties = api_config.grid_properties for prop_name in grid_properties: gocean_property = grid_properties[prop_name] property_name = gocean_property.fortran.split('%')[-1] # Search for a property with the same name as the signature # inner accessor if signature[-1] == property_name: break else: raise InternalError( f"Could not find type for reference " f"'{structure_reference.debug_string()}' " f"in the config file '{Config.get().filename}'.") if gocean_property.intrinsic_type == 'real': scalar_type = ScalarType.real8_type() else: scalar_type = ScalarType.integer_type() if gocean_property.type == "scalar": return scalar_type # Everything else is a 2D field return ArrayType(scalar_type, [ArrayType.Extent.DEFERRED, ArrayType.Extent.DEFERRED]) # Everything else defaults to integer return ScalarType.integer_type()
[docs] @staticmethod def bring_external_symbols(read_write_info: "ReadWriteInfo", symbol_table: SymbolTable) -> None: ''' Use the ModuleManager to explore external dependencies and bring symbols used in other modules into scope (with ImportInterface). The symbols will be tagged with a 'signature@module_name' tag. :param read_write_info: information about the symbols usage in the scope. :param symbol_table: the associated symbol table. ''' # Cyclic import # pylint: disable=import-outside-toplevel from psyclone.parse import ModuleManager mod_man = ModuleManager.get() for module_name, signature in read_write_info.all_used_vars_list: if not module_name: # Ignore local symbols, which will have been added above continue container = symbol_table.find_or_create( module_name, symbol_type=ContainerSymbol) # Any symbols imported from this ContainerSymbol must be added # to the same scope (table) in which it resides. actual_table = container.find_symbol_table(symbol_table.node) # Now look up the original symbol. While the variable could # be declared Unresolved here (i.e. just imported), we need the # type information for the output variables (VAR_post), which # are created later and which will query the original symbol for # its type. And since they are not imported, they need to be # explicitly declared. mod_info = mod_man.get_module_info(module_name) container_symbol = mod_info.get_symbol(signature[0]) if not container_symbol: # TODO #2120: This typically indicates a problem with parsing # a module: the psyir does not have the full tree structure. continue # It is possible that external symbol name (signature[0]) already # exist in the symbol table (the same name is used in the local # subroutine and in a module). In this case, the imported symbol # must be renamed: if signature[0] in symbol_table: interface = ImportInterface(container, orig_name=signature[0]) else: interface = ImportInterface(container) actual_table.find_or_create_tag( tag=f"{signature[0]}@{module_name}", root_name=signature[0], symbol_type=DataSymbol, interface=interface, datatype=container_symbol.datatype)
# For AutoAPI documentation generation __all__ = ['ExtractNode']