# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Transform permute to reshape wherever applicable.
"""
from typing import List
from aitemplate.compiler.base import IntImm, Operator, Tensor
from aitemplate.compiler.ops import reshape
from aitemplate.compiler.transform import transform_utils
from aitemplate.compiler.transform.toposort import toposort
from aitemplate.utils import graph_utils
def _check_permute_to_reshape(op: Operator) -> bool:
    """Check whether a permute op can be replaced with a reshape.

    A permute is equivalent to a reshape when it leaves the memory layout of
    its input unchanged, i.e. when the relative order of every dimension that
    is not statically known to be 1 is preserved (only size-1 dimensions are
    moved around).

    Args:
        op (Operator): the candidate permute op.

    Returns:
        bool: False if ``op`` is not a permute op, if its input accessor is
        strided on one of the permuted dimensions, or if the permutation
        reorders dimensions that are not statically 1; True otherwise.

    Raises:
        NotImplementedError: if ``op`` is a permute variant whose
        permutation this function does not know how to derive.
    """
    op_type = op._attrs["op"]
    if not op_type.startswith("permute"):
        return False

    inputs = op._attrs["inputs"]
    assert len(inputs) == 1, (
        f"Permute operation {op_type} should have 1 input, "
        f"got {len(inputs)} instead"
    )

    # When the op carries input accessors, reason about the shape recorded
    # on the accessor rather than the input tensor's own shape.
    accessor = None
    if "input_accessors" in op._attrs:
        accessor = op._attrs["input_accessors"][0]
        input_shape = accessor.original_shapes
    else:
        input_shape = inputs[0].shape()
    n_dims = len(input_shape)

    # Derive the explicit axis permutation for each supported permute variant.
    if op_type == "permute":
        # Wrap negative dims so the sortedness check below compares real axis
        # positions. dims are presumably already normalized upstream; this is
        # defensive and leaves non-negative dims untouched.
        permutation = [
            dim + n_dims if dim < 0 else dim for dim in op._attrs["dims"]
        ]
    elif op_type == "permute021":
        # Swap the last two axes, keep the leading ones in place.
        permutation = list(range(n_dims - 2)) + [n_dims - 1, n_dims - 2]
    elif op_type == "permute102":
        permutation = [1, 0, 2]
    elif op_type == "permute210":
        permutation = [2, 1, 0]
    elif op_type == "permute0213":
        permutation = [0, 2, 1, 3]
    else:
        raise NotImplementedError(
            f"Not implemented for permute operation: {op_type}"
        )

    # Can't convert permute to reshape if one of the dimensions included
    # in permutation is strided.
    if (
        accessor is not None
        and accessor.is_from_strided_tensor
        and accessor.stride_dim in permutation
    ):
        return False

    # Dimensions statically known to be 1 can move freely without changing
    # the memory layout; only the order of the remaining dimensions matters.
    non_singleton_dims = [
        dim_idx
        for dim_idx in permutation
        if not isinstance(input_shape[dim_idx], IntImm)
        or input_shape[dim_idx].value() != 1
    ]
    return non_singleton_dims == sorted(non_singleton_dims)