# -----------------------------------------------------------------------------
# BSD 3-Clause License
#
# Copyright (c) 2026, Science and Technology Facilities Council.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
# -----------------------------------------------------------------------------
# Authors: A. B. G. Chalk, STFC Daresbury Lab
'''This module contains the DataNodeToTempTrans class.'''
from psyclone.psyGen import Transformation
from psyclone.psyir.transformations import TransformationError
from psyclone.psyir.nodes import (
ArrayReference,
Assignment,
Call,
DataNode,
IfBlock,
IntrinsicCall,
Literal,
Loop,
Range,
Reference,
Routine,
Statement,
Schedule,
UnaryOperation,
)
from psyclone.psyir.symbols.datatypes import (
ArrayType,
UnresolvedType,
UnsupportedFortranType,
)
from psyclone.psyir.symbols.interfaces import (
UnresolvedInterface,
UnknownInterface
)
from psyclone.psyir.symbols import (
DataSymbol, ImportInterface, ContainerSymbol, Symbol)
from psyclone.utils import transformation_documentation_wrapper
@transformation_documentation_wrapper
class DataNodeToTempTrans(Transformation):
    """Provides a generic transformation for moving a datanode from a
    statement into a new standalone statement. For example:

    >>> from psyclone.psyir.frontend.fortran import FortranReader
    >>> from psyclone.psyir.backend.fortran import FortranWriter
    >>> from psyclone.psyir.nodes import Assignment
    >>> from psyclone.psyir.transformations import DataNodeToTempTrans
    >>>
    >>> psyir = FortranReader().psyir_from_source('''
    ...     subroutine my_subroutine()
    ...         integer :: i
    ...         integer :: j
    ...         i = j * 2
    ...     end subroutine
    ...     ''')
    >>> assign = psyir.walk(Assignment)[0]
    >>> DataNodeToTempTrans().apply(assign.rhs, storage_name="temp")
    >>> print(FortranWriter()(psyir))
    subroutine my_subroutine()
      integer :: i
      integer :: j
      integer :: temp
    <BLANKLINE>
      temp = j * 2
      i = temp
    <BLANKLINE>
    end subroutine my_subroutine
    <BLANKLINE>

    """

    @staticmethod
    def _array_bound_symbols(bounds):
        """Collect the symbols accessed by one dimension's bounds.

        :param bounds: a single entry of an ArrayType shape that has
            ``lower`` and ``upper`` attributes.

        :returns: the set of symbols accessed by whichever of the lower
            and upper bounds are DataNodes.
        """
        symbols = set()
        for bound in (bounds.lower, bounds.upper):
            if isinstance(bound, DataNode):
                symbols.update(bound.get_all_accessed_symbols())
        return symbols

    @staticmethod
    def _fail_validation(node, message, verbose):
        """Report a validation failure. This method never returns.

        :param node: the node supplied to the transformation.
        :param message: the text describing why validation failed.
        :param verbose: whether to also record the message as a comment
            preceding the Statement that contains ``node``.

        :raises TransformationError: always, with ``message``.
        """
        if verbose:
            # The node is not guaranteed to live inside a Statement at
            # this point, so only annotate the tree when one exists.
            statement = node.ancestor(Statement)
            if statement is not None:
                statement.append_preceding_comment(
                    f"PSyclone Warning: {message}"
                )
        raise TransformationError(message)

    def _validate_array_type(self, node, dtype, verbose):
        """Check that a temporary of array type ``dtype`` can be declared
        in the scope of ``node``.

        :param node: the node supplied to the transformation.
        :param dtype: the ArrayType computed for ``node``.
        :param verbose: whether failures are also recorded as comments.

        :raises TransformationError: if any dimension has an unknown
            extent.
        :raises TransformationError: if a symbol used in the bounds has
            a name clash with a different symbol in the current scope.
        :raises TransformationError: if a symbol used in the bounds is
            not in scope and is private.
        :raises TransformationError: if the name of the container of an
            imported symbol used in the bounds matches a symbol in scope
            that is not a ContainerSymbol.
        """
        for element in dtype.shape:
            if element in [ArrayType.Extent.DEFERRED,
                           ArrayType.Extent.ATTRIBUTE]:
                self._fail_validation(
                    node,
                    f"Input node's datatype is an array of unknown size, "
                    f"so the {self.name} cannot be applied. "
                    f"Input node was '{node.debug_string().strip()}'.",
                    verbose
                )
            # The shape must now be set by ArrayBounds, we need to
            # examine the symbols used to define those bounds and compare
            # them with the symbols already in the scope.
            scope_table = node.scope.symbol_table
            for sym in self._array_bound_symbols(element):
                scoped_name_sym = scope_table.lookup(
                    sym.name, otherwise=None
                )
                # If sym is not scoped_name_sym, then there is a
                # symbol collision from an imported symbol.
                if scoped_name_sym and sym is not scoped_name_sym:
                    # If the symbol in scope is imported from the same
                    # container then we can skip this.
                    if scoped_name_sym.interface == sym.interface:
                        continue
                    self._fail_validation(
                        node,
                        f"The type of the node supplied to {self.name} "
                        f"depends upon an imported symbol '{sym.name}' "
                        f"which has a name clash with a symbol in the "
                        f"current scope.",
                        verbose
                    )
                # If it's not in the current scope, and its visibility
                # is private then we can't import it.
                if (not scoped_name_sym and sym.visibility ==
                        Symbol.Visibility.PRIVATE):
                    self._fail_validation(
                        node,
                        f"The datatype of the node suppled to "
                        f"{self.name} depends upon an imported symbol "
                        f"'{sym.name}' that is declared as private in "
                        f"its containing module, so cannot be imported.",
                        verbose
                    )
                # If it's an imported symbol we need to check that the
                # name of its container isn't already used by something
                # that is not a container.
                if isinstance(sym.interface, ImportInterface):
                    scoped_name_sym = scope_table.lookup(
                        sym.interface.container_symbol.name,
                        otherwise=None
                    )
                    if scoped_name_sym and not isinstance(
                            scoped_name_sym, ContainerSymbol):
                        self._fail_validation(
                            node,
                            f"Input node contains an imported symbol "
                            f"'{sym.name}' whose containing module "
                            f"collides with an existing symbol. Colliding "
                            f"name is "
                            f"'{sym.interface.container_symbol.name}'.",
                            verbose
                        )

    def validate(self, node: DataNode, **kwargs) -> None:
        """Validity checks for input arguments

        :param node: The DataNode to be extracted.
        :param kwargs: the transformation options. They are checked with
            ``validate_options`` and the ``verbose`` option is read to
            decide whether failures are also written as comments.

        :raises TypeError: if the input arguments are the wrong types.
        :raises TransformationError: if the input node contains a call
            that isn't guaranteed to be pure.
        :raises TransformationError: if the input node's datatype is an
            array but any of the array's dimensions are unknown.
        :raises TransformationError: if a symbol used in the bounds of
            the node's array datatype clashes with a symbol in scope, is
            private and not in scope, or is imported from a container
            whose name collides with a non-container symbol in scope.
        :raises TransformationError: if the input node doesn't have an
            ancestor statement.
        :raises TransformationError: if the input node's datatype can't
            be resolved.
        """
        # Validate the input options and types.
        self.validate_options(**kwargs)
        verbose = self.get_option("verbose", **kwargs)

        if not isinstance(node, DataNode):
            raise TypeError(
                f"Input node to {self.name} should be a "
                f"DataNode but got '{type(node).__name__}'."
            )

        dtype = node.datatype

        for call in node.walk(Call):
            if not call.is_pure:
                self._fail_validation(
                    node,
                    f"Input node to {self.name} contains a call "
                    f"'{call.debug_string().strip()}' that is not "
                    f"guaranteed "
                    f"to be pure. Input node is "
                    f"'{node.debug_string().strip()}'.",
                    verbose
                )

        if isinstance(dtype, ArrayType):
            self._validate_array_type(node, dtype, verbose)

        if node.ancestor(Statement) is None:
            raise TransformationError(
                f"Input node to {self.name} has no ancestor "
                f"Statement node which is not supported."
            )

        unresolved = (UnresolvedType, UnsupportedFortranType)
        if (isinstance(dtype, unresolved)
                or (isinstance(dtype, ArrayType) and
                    isinstance(dtype.elemental_type, unresolved))):
            # Sort the names to get consistent results for tests.
            failing_symbols = sorted(
                sym.name for sym in node.get_all_accessed_symbols()
                if isinstance(sym.interface, (UnresolvedInterface,
                                              UnknownInterface))
            )
            message = (
                f"The datatype of the supplied node cannot be "
                f"computed, so the {self.name} cannot be applied. Input node "
                f"was '{node.debug_string().strip()}'."
            )
            if failing_symbols:
                message += (
                    f" The following symbols in the input node have not been "
                    f"resolved by PSyclone: '{failing_symbols}'. Setting "
                    f"RESOLVE_IMPORTS in the transformation script may "
                    f"enable resolution of these symbols."
                )
            self._fail_validation(node, message, verbose)

    def apply(self, node: DataNode, storage_name: str = "",
              verbose: bool = False, **kwargs) -> None:
        """Applies the DataNodeToTempTrans to the input arguments.

        :param node: The datanode to extract.
        :param storage_name: The base name of the temporary variable to store
            the result of the input node in. The default is tmp(_...)
            based on the rules defined in the SymbolTable class.
        :param verbose: Whether to add comments to the input node if
            the transformation fails.
        """
        # Call validate to check inputs are valid.
        self.validate(node, storage_name=storage_name, verbose=verbose,
                      **kwargs)

        # Find the datatype
        datatype = node.datatype
        # Only set when the temporary has to be an allocatable array, in
        # which case it holds the datatype with the explicit bounds that
        # the ALLOCATE statement needs.
        allocatable_datatype = None

        # Make sure the shape is all in the symbol table. We know that
        # all symbols we find can be safely added as otherwise validate will
        # fail.
        # Symbols occurring within the shape definition that are from imported
        # modules but that aren't currently in the symbol table will be placed
        # into the symbol table with a corresponding ImportInterface so the
        # resultant symbol will reference the original definition of the shape
        # in the containing module.
        if isinstance(datatype, ArrayType):
            scope_table = node.scope.symbol_table
            for element in datatype.shape:
                for sym in self._array_bound_symbols(element):
                    # Nothing to do if a symbol with the name exists.
                    if scope_table.lookup(sym.name, otherwise=None):
                        continue
                    sym_copy = sym.copy()
                    if isinstance(sym_copy.interface, ImportInterface):
                        # Check if the ContainerSymbol is already in the
                        # symbol table.
                        container = scope_table.lookup(
                            sym_copy.interface.container_symbol.name,
                            otherwise=None
                        )
                        if container is None:
                            # Add the container symbol to the symbol table
                            scope_table.add(
                                sym_copy.interface.container_symbol
                            )
                        else:
                            # Use the container that is already in scope.
                            sym_copy.interface.container_symbol = container
                    scope_table.add(sym_copy)

            # Now we've created the relevant symbols, we need to update
            # the datatype to use the in-scope symbols
            datatype.replace_symbols_using(scope_table)

            # If any of the bound information isn't static then we need
            # to create an allocatable array.
            has_static_bounds = all(
                isinstance(dim.lower, Literal)
                and isinstance(dim.upper, Literal)
                for dim in datatype.shape
            )
            if has_static_bounds:
                datatype = ArrayType(datatype.elemental_type,
                                     [x.copy() for x in datatype.shape])
            else:
                # We want to create an allocatable symbol for Array entities,
                # so create a new datatype for the symbol and keep the
                # datatype around for the ALLOCATE statement later.
                allocatable_datatype = datatype
                datatype = ArrayType(allocatable_datatype.elemental_type,
                                     [ArrayType.Extent.DEFERRED for _ in
                                      allocatable_datatype.shape])

        # Create a symbol of the relevant type, in the containing routine's
        # symbol table when there is one.
        containing_routine = node.ancestor(Routine)
        if containing_routine:
            sym_tab = containing_routine.symbol_table
        else:
            sym_tab = node.scope.symbol_table
        symbol = sym_tab.new_symbol(
            root_name=storage_name or "tmp",
            symbol_type=DataSymbol,
            datatype=datatype
        )

        # Create a Reference to the new symbol
        new_ref = Reference(symbol)

        # Find the containing schedule and position of the statement
        # containing the DataNode.
        schedule = node.ancestor(Schedule)
        pos = node.path_from(schedule)[0]

        # Replace the datanode with the new reference
        node.replace_with(new_ref)

        # Create an assignment to set the value of the new symbol and add
        # it into the tree before the statement that used the datanode.
        assign = Assignment.create(new_ref.copy(), node)
        schedule.addchild(assign, pos)

        if allocatable_datatype is None:
            # Scalars and statically-sized arrays need no allocation.
            return

        # The temporary is an allocatable array so it needs allocating
        # before the assignment if it's not already allocated.
        # Create an array reference to the symbol with the dimensions
        # of the datatype computed earlier.
        ref = ArrayReference.create(
            symbol,
            [Range.create(x.lower.copy(), x.upper.copy()) for x in
             allocatable_datatype.shape]
        )
        # Create the IntrinsicCall to ALLOCATE.
        intrinsic = IntrinsicCall.create(
            IntrinsicCall.Intrinsic.ALLOCATE,
            (ref,)
        )
        allocated = IntrinsicCall.create(
            IntrinsicCall.Intrinsic.ALLOCATED,
            (Reference(symbol),)
        )
        ifblock = IfBlock.create(
            UnaryOperation.create(
                UnaryOperation.Operator.NOT,
                allocated),
            [intrinsic]
        )

        # If the shape doesn't contain array references then we can hoist
        # the allocate statement outside of any ancestor loops.
        hoistable = not any(
            index.walk(ArrayReference) for index in ref.indices
        )

        # If we can hoist the allocate, find the highest level Loop
        # ancestor and set the schedule and position to place the
        # allocate before this loop.
        # TODO #1445: Use HoistTrans to do this if it's extended to support
        # more node types.
        if hoistable:
            outer_loop = None
            cursor = schedule.ancestor(Loop)
            while cursor:
                outer_loop = cursor
                cursor = cursor.ancestor(Loop)
            if outer_loop:
                outer_schedule = outer_loop.ancestor(Schedule)
                # Without a Schedule around the outermost loop there is
                # nowhere to hoist to, so leave the allocate next to its
                # use.
                if outer_schedule is not None:
                    pos = outer_loop.position
                    schedule = outer_schedule

        # Add the allocate statement and the containing ifblock into the
        # tree, either immediately before its use or before the outermost
        # enclosing loop.
        schedule.addchild(ifblock, pos)
__all__ = ["DataNodeToTempTrans"]