Source code for psyclone.psyir.transformations.scalarisation_trans

# -----------------------------------------------------------------------------
# BSD 3-Clause License
#
# Copyright (c) 2024-2026, Science and Technology Facilities Council.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# -----------------------------------------------------------------------------
# Author: A. B. G. Chalk, STFC Daresbury Lab

'''This module provides the sclarization transformation class.'''

from psyclone.core import VariablesAccessMap, Signature, SymbolicMaths
from psyclone.psyGen import Kern
from psyclone.psyir.nodes import Call, CodeBlock, Literal, \
        IfBlock, Loop, Node, Range, Reference, Routine, StructureReference
from psyclone.psyir.nodes.array_mixin import ArrayMixin
from psyclone.psyir.symbols import DataSymbol, RoutineSymbol, ScalarType
from psyclone.psyir.transformations.loop_trans import LoopTrans
from psyclone.utils import transformation_documentation_wrapper


[docs] @transformation_documentation_wrapper class ScalarisationTrans(LoopTrans): '''This transformation takes a Loop and converts any array accesses to scalar if the results of the loop are unused, and the initial value is unused. For example in the following snippet the value of a(i) is only used inside the loop, so can be turned into a scalar, whereas the values of b(i) are used in the following loop so are kept as an array: >>> from psyclone.psyir.backend.fortran import FortranWriter >>> from psyclone.psyir.frontend.fortran import FortranReader >>> from psyclone.psyir.transformations import ScalarisationTrans >>> from psyclone.psyir.nodes import Loop >>> code = """program test ... integer :: i,j ... real :: a(100), b(100) ... do i = 1,100 ... a(i) = i ... b(i) = a(i) * a(i) ... end do ... do j = 1, 100 ... if(b(i) > 200) then ... print *, b(i) ... end if ... end do ... end program""" >>> psyir = FortranReader().psyir_from_source(code) >>> scalarise = ScalarisationTrans() >>> scalarise.apply(psyir.walk(Loop)[0]) >>> print(FortranWriter()(psyir)) program test integer :: i integer :: j real, dimension(100) :: a real, dimension(100) :: b real :: a_scalar <BLANKLINE> do i = 1, 100, 1 a_scalar = i b(i) = a_scalar * a_scalar enddo do j = 1, 100, 1 if (b(i) > 200) then ! PSyclone CodeBlock (unsupported code) reason: ! - Unsupported statement: Print_Stmt PRINT *, b(i) end if enddo <BLANKLINE> end program test <BLANKLINE> ''' @staticmethod def _is_local_array(signature: Signature, var_accesses: VariablesAccessMap) -> bool: ''' :param signature: The signature to check if it is a local array symbol or not. :param var_accesses: The VariableAccessesInfo object containing signature. :returns: whether the symbol corresponding to signature is a local array symbol or not. ''' if not var_accesses[signature].has_indices(): return False # If any of the accesses are to a CodeBlock then we stop. This can # happen if there is a string access inside a string concatenation, # e.g. NEMO4. for access in var_accesses[signature]: if isinstance(access.node, CodeBlock): return False base_symbol = var_accesses[signature][0].node.symbol if not base_symbol.is_automatic: return False # If its a derived type then we don't scalarise. if isinstance(var_accesses[signature][0].node, StructureReference): return False # Find the containing routine rout = var_accesses[signature][0].node.ancestor(Routine) # If the array is the return symbol then its not a local # array symbol if base_symbol is rout.return_symbol: return False return True @staticmethod def _have_same_unmodified_index( signature: Signature, var_accesses: VariablesAccessMap) -> bool: ''' :param signature: The signature to check. :param var_accesses: The VariableAccessesInfo object containing signature. :returns: whether all the array accesses to signature use the same index, and whether the index is unmodified in the code region. ''' array_indices = None scalarisable = True for access in var_accesses[signature]: if array_indices is None: array_indices = access.component_indices() elif array_indices != access.component_indices(): scalarisable = False break # For each index, we need to check they're not written to in # the loop. for component in array_indices: for index in component: # Index may not be a Reference, so we need to loop over the # References for ref in index.walk(Reference): # This Reference could be the symbol for a Call or # IntrinsicCall, which we don't allow to scalarise if isinstance(ref.symbol, RoutineSymbol): scalarisable = False break sig, _ = ref.get_signature_and_indices() if var_accesses[sig].is_written(): scalarisable = False break return scalarisable @staticmethod def _check_first_access_is_write(signature: Signature, loop: Loop, var_accesses: VariablesAccessMap) \ -> bool: ''' :param signature: The signature to check. :param loop: The Loop object being transformed. :param var_accesses: The VariableAccessesInfo object containing signature. :returns: whether the first access to signature is a write. ''' if not var_accesses[signature].is_written_first(): return False # Need to find the first access and check if its in a conditional. accesses = var_accesses[signature] first_node = accesses[0].node ifblock = first_node.ancestor(IfBlock) # If the depth of the ifblock is larger than loop then the write # is in a conditional if ifblock and ifblock.depth > loop.depth: return False return True @staticmethod def _get_index_values_from_indices( node: ArrayMixin, indices: list[Node]) -> tuple[bool, list[Node]]: ''' Compute a list of index values for a given node. Looks at loop bounds and range declarations to attempt to convert loop variables to an explicit range, i.e. an access like .. code-block:: fortran do i = 1, 100 array(i) = ... end do the returned list would contain a range object for [1:100]. If the computed indexes contains a non-unit stride, or an index is not a Range, Reference or Literal then this function will return True as the first element of the returned tuple, and the list of indices will be incomplete. :param node: The node to compute index values for. :param indices: the list of indexes to have values computed. :returns: a tuple containing a bool value set to True if any of the index values are not computed, and a list of the computed index values. ''' index_values = [] has_complex_index = False for index in indices: # If the index is an array or structure and there are any more # accesses to the signature we're trying to scalarise, then we # should not scalarise. if (type(index) is not Range and type(index) is not Reference and type(index) is not Literal): has_complex_index = True index_values.append(None) one_literal = Literal("1", ScalarType.integer_type()) ancestor_loop = node.ancestor(Loop) # For Range or Literal array indices this is easy. for i, index in enumerate(indices): if isinstance(index, (Range, Literal)): index_values[i] = index while ancestor_loop is not None and not has_complex_index: for i, index in enumerate(indices): # Skip over indices we already set. if index_values[i] is not None: continue if ancestor_loop.variable == index.symbol: start_val = ancestor_loop.start_expr stop_val = ancestor_loop.stop_expr step_val = ancestor_loop.step_expr # If the step value is not exactly 1 then we treat # this as a complex index, as we can't currently # do precise comparisons on non-unit stride accesses. if step_val != one_literal: has_complex_index = True # Create a range for this and add it to the index values. index_range = Range.create( start_val.copy(), stop_val.copy() ) index_values[i] = index_range ancestor_loop = ancestor_loop.ancestor(Loop) # If we couldn't work out any of the index_values, then we treat this # as a complex index for index in index_values: if index is None: has_complex_index = True return has_complex_index, index_values @staticmethod def _value_unused_after_loop(sig: Signature, loop: Loop, var_accesses: VariablesAccessMap) -> bool: ''' :param sig: The signature to check. :param loop: The loop the transformation is operating on. :param var_accesses: The VariableAccessesInfo object containing signature. :returns: whether the value computed in the loop containing sig is read from after the loop. ''' # Find the last access of the signature last_access = var_accesses[sig][-1].node # Compute the indices used in this loop. We know that all of the # indices used in this loop must be the same. indices = last_access.indices # Find the next accesses to this symbol next_accesses = last_access.next_accesses() # Compute the indices ranges. has_complex_index, index_values = \ ScalarisationTrans._get_index_values_from_indices( last_access, indices ) for next_access in next_accesses: # next_accesses looks backwards to the start of the loop, # but we don't care about those accesses here. if next_access.is_descendant_of(loop): continue # If we have a next_access outside of the loop and have a complex # index then we do not scalarise this at the moment. if has_complex_index: return False # If next access is a Call or CodeBlock or Kern then # we have to assume the value is used. These nodes don't # have the is_read property that Reference has, so we need # to be explicit. if isinstance(next_access, (CodeBlock, Call, Kern)): return False # If the access is a read, then return False if next_access.is_read: return False # If the next access is a Reference then we had a full range # access described without any range, which means a full # range access so we can skip the followup checks. if type(next_access) is Reference: continue # We need to ensure that the following write accesses the same # or more of the array. next_indices = next_access.indices next_complex_index, next_values = \ ScalarisationTrans._get_index_values_from_indices( next_access, next_indices ) # If we can't compute the indices of the next access then we # cannot scalarise if next_complex_index: return False # Check the indices of next_access are greater than or equal to # that of the potential scalarisation. for i in range(len(next_values)): # If the next index is a full range we can skip it as it must # cover the previous access if next_access.is_full_range(i): continue # Convert both to ranges if either was a literal next_index = next_values[i] orig_index = index_values[i] if not isinstance(next_index, Range): next_index = Range.create(next_index.copy(), next_index.copy()) if not isinstance(orig_index, Range): orig_index = Range.create(orig_index.copy(), orig_index.copy()) sm = SymbolicMaths.get() # Need to check that next_index stop point is >= orig_index. # If its not then this can't cover the full range so we can # return False to not Scalarise this. if not (sm.greater_than(next_index.stop, orig_index.stop) == SymbolicMaths.Fuzzy.TRUE or sm.equal(next_index.stop, orig_index.stop)): return False # Need to check the next_index start point is <= orig_index if not (sm.less_than(next_index.start, orig_index.start) == SymbolicMaths.Fuzzy.TRUE or sm.equal(next_index.start, orig_index.start)): return False return True def validate(self, node: Loop, **kwargs): ''' Validate the options provided to the ScalarisationTrans. :param node: the supplied loop to apply scalarisation to. ''' self.validate_options(**kwargs)
[docs] def apply(self, node: Loop, **kwargs) -> None: ''' Apply the scalarisation transformation to a loop. All of the array accesses that are identified as being able to be scalarised will be transformed by this transformation. An array access will be scalarised if: 1. All accesses to the array use the same indexing statement. 2. All References contained in the indexing statement are not modified inside of the loop (loop variables are ok). 3. The array symbol is either not accessed again or is written to as its next access. If the next access is inside a conditional that is not an ancestor of the input loop, then PSyclone will assume that we cannot scalarise that value instead of attempting to understand the control flow. 4. The array symbol is a local variable. :param node: the supplied loop to apply scalarisation to. ''' self.validate(node, **kwargs) # For each array reference in the Loop: # Find every access to the same symbol in the loop # They all have to be accessed with the same index statement, and # that index needs to not be written to inside the loop body. # For each symbol that meets this criteria, we then need to check the # first access is a write # Then, for each symbol still meeting this criteria, we need to find # the next access outside of this loop. If its inside an ifblock that # is not an ancestor of this loop then we refuse to scalarise for # simplicity. Otherwise if its a read we can't scalarise safely. # If its a write then this symbol can be scalarised. var_accesses = node.loop_body.reference_accesses() # Find all the arrays that are only accessed by a single index, and # that index is only read inside the loop. potential_targets = filter( lambda sig: ScalarisationTrans._is_local_array(sig, var_accesses), var_accesses) potential_targets = filter( lambda sig: ScalarisationTrans._have_same_unmodified_index(sig, var_accesses), potential_targets) # Now we need to check the first access is a write and remove those # that aren't. potential_targets = filter( lambda sig: ScalarisationTrans._check_first_access_is_write(sig, node, var_accesses), potential_targets) # Check the values written to these arrays are not used after this loop finalised_targets = filter( lambda sig: ScalarisationTrans._value_unused_after_loop(sig, node, var_accesses), potential_targets) routine_table = node.ancestor(Routine).symbol_table # For each finalised target we can replace them with a scalarised # symbol for target in finalised_targets: target_accesses = var_accesses[target] first_access = target_accesses[0].node symbol_type = first_access.symbol.datatype.elemental_type symbol_name = first_access.symbol.name scalar_symbol = routine_table.new_symbol( root_name=f"{symbol_name}_scalar", symbol_type=DataSymbol, datatype=symbol_type) ref_to_copy = Reference(scalar_symbol) for access in target_accesses: node = access.node node.replace_with(ref_to_copy.copy())