# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Deduplicate make_jagged ops in the graph.
"""
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Set
from aitemplate.compiler.base import IntVar, JaggedIntVar, Operator, Tensor
from aitemplate.compiler.ops.common.view_ops import make_jagged
from aitemplate.compiler.transform.toposort import toposort
from aitemplate.compiler.transform.transform_utils import (
remove_dst_op_from_tensor,
replace_tensor,
replace_tensor_for_op,
sanitize_sorted_graph,
)
from aitemplate.utils.graph_utils import get_sorted_ops
_LOGGER = logging.getLogger(__name__)
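# Per-op record for one make_jagged op found in the graph; records are later
# grouped by the total_length dimension of the op's source Tensors.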
@dataclass
class MakeJaggedMetaData:
op: Operator
sources_list: List[Tensor]
offsets_list: List[Tensor]
outputs: List[Tensor]
jagged_int_var: JaggedIntVar
def _get_make_jagged_metadata(
sorted_graph: List[Tensor],
) -> Dict[IntVar, List[MakeJaggedMetaData]]:
"""Collect metadata about the existing make_jagged ops in the graph.
The MakeJaggedMetaData instances, one per make_jagged op, are grouped
by the total_length dimension in the source input Tensors of the ops.
In the case of multiple source inputs, the total_length dimension is the
same in every input. The metadata is then used to inform the transformation.
"""
metadata = {}
for op in get_sorted_ops(sorted_graph):
if op._attrs["op"] == "make_jagged":
outputs = op._attrs["outputs"]
jagged_int_var = outputs[0]._attrs["shape"][0]
total_length = jagged_int_var.total_length()
num_sources = op._attrs["num_sources"]
if total_length not in metadata:
metadata[total_length] = []
metadata[total_length].append(
MakeJaggedMetaData(
op=op,
sources_list=op._attrs["inputs"][:num_sources],
offsets_list=op._attrs["inputs"][num_sources:],
outputs=outputs,
jagged_int_var=jagged_int_var,
)
)
return metadata
def _remove_make_jagged_ops(
make_jagged_metadata: Dict[IntVar, List[MakeJaggedMetaData]],
graph_inputs: Set[Tensor],
graph_outputs: Set[Tensor],
):
"""Remove the make_jagged ops from the graph where possible.
The individual make_jagged ops scattered over the graph are removed, to be
replaced later by a single make_jagged op per total_length dimension,
applied at once to all graph inputs carrying that total_length dimension.
The ops are considered group by group, where a group is formed from
the ops with the same total_length dimension in their source Tensors.
The make_jagged ops in the group are not removed (and the respective
total_length key is popped from the make_jagged_metadata) if:
1. There is only one make_jagged op in the group.
2. There is a make_jagged op in the group connecting a
graph input to a graph output: can't be eliminated.
3. The total_length dimension representing the group is
not present in any of the graph inputs' shape.
Otherwise, all make_jagged ops in the group are removed from the graph
(and the respective total_length key is kept in the make_jagged_metadata).
"""
for total_length in list(make_jagged_metadata.keys()):
make_jagged_group = make_jagged_metadata[total_length]
assert len({d.jagged_int_var for d in make_jagged_group}) == 1, (
"All make_jagged ops applied to the sources with the "
"same total_length must produce the same jagged_int_var."
) # this includes offsets identity check internally
if len(make_jagged_group) == 1:
_LOGGER.debug(
"There is only one make_jagged op in the group "
f"with {total_length=}: skipping the group."
)
make_jagged_metadata.pop(total_length)
continue
has_input_to_output_op = False
for data in make_jagged_group:
if any(s in graph_inputs for s in data.sources_list) and any(
o in graph_outputs for o in data.outputs
):
has_input_to_output_op = True
break
if has_input_to_output_op:
_LOGGER.debug(
"There is a make_jagged op in the group with "
f"{total_length=} that maps a graph input to "
"a graph output: skipping the group."
)
make_jagged_metadata.pop(total_length)
continue
graph_input_with_total_length = False
for inp in graph_inputs:
shape = inp._attrs["shape"]
if shape and shape[0] == total_length:
graph_input_with_total_length = True
break
if not graph_input_with_total_length:
_LOGGER.debug(
"None of the graph inputs has the first dimension "
f"equal to {total_length=}: skipping the group."
)
make_jagged_metadata.pop(total_length)
continue
_LOGGER.debug(
f"Removing {len(make_jagged_group)} make_jagged ops "
f"in the group with {total_length=} from the graph."
)
for data in make_jagged_group:
for source, output in zip(data.sources_list, data.outputs):
replace_tensor(output, source)
remove_dst_op_from_tensor(source, data.op)
def _apply_make_jagged_to_inputs(
make_jagged_metadata: Dict[IntVar, List[MakeJaggedMetaData]],
sorted_graph: List[Tensor],
graph_inputs: Set[Tensor],
) -> Dict[IntVar, JaggedIntVar]:
"""Apply new make_jagged ops to the (bundled) input source Tensors.
For each group of make_jagged ops removed from the graph,
a new make_jagged op is applied to all graph inputs with the
corresponding total_length dimension. This way, the source Tensors
are converted to jagged Tensors right from the "beginning" of the
graph and can be used as jagged Tensors downstream.
Two points are worth mentioning:
1. Because the new make_jagged op is applied to *all* source
inputs with the total_length dimension, the offsets validation
performed by the make_jagged op's back-end is guaranteed to run
before any of the resulting jagged Tensors are used downstream.
2. Because a single make_jagged op is applied to multiple
graph inputs, the make_jagged op's back-end kernel will
be launched only once to validate the offsets (the latter
are the same for every source input). This optimizes out
redundant validation of the same offsets.
The mapping of each total_length to the new JaggedIntVar (produced
by the corresponding new make_jagged op) is returned.
"""
new_jagged_int_vars = {}
for total_length, make_jagged_group in make_jagged_metadata.items():
sources_list = []
for inp in graph_inputs:
shape = inp._attrs["shape"]
if shape and shape[0] == total_length:
sources_list.append(inp)
_LOGGER.debug(
"Adding a single make_jagged op for the source inputs "
f"{[source._attrs['name'] for source in sources_list]}."
)
data = make_jagged_group[0]
new_make_jagged_op = make_jagged(
batch_dim=data.jagged_int_var.batch_dim(),
jagged_dims=data.jagged_int_var.jagged_dims(),
check_sequence_lengths=all(
d.op._attrs["check_sequence_lengths"] for d in make_jagged_group
),
)
jagged_tensors = new_make_jagged_op(
source=sources_list,
offsets_list=data.offsets_list,
)
jagged_int_var = jagged_tensors[0]._attrs["shape"][0]
new_jagged_int_vars[total_length] = jagged_int_var
for source, jagged in zip(sources_list, jagged_tensors):
for op in source._attrs["dst_ops"]:
if op is not new_make_jagged_op:
replace_tensor_for_op(op, source, jagged)
sorted_graph.extend(jagged_tensors)
return new_jagged_int_vars
def _replace_total_length_with_jagged_int_var(
new_jagged_int_vars: Dict[IntVar, JaggedIntVar],
sorted_graph: List[Tensor],
graph_inputs: Set[Tensor],
):
"""Replace total_length dimensions by the new JaggedIntVars.
As we've removed the internal make_jagged ops from the graph and
replaced their jagged output Tensors with the corresponding source
Tensors, the JaggedIntVars those outputs carried are no longer present
in the graph. Here we replace the
total_length dimension in *every* non-input Tensor in the graph
by the corresponding new JaggedIntVar (produced by the new
make_jagged op applied to the bundled source inputs). This includes,
but is not limited to, the source Tensors of the make_jagged ops
removed from within the graph at the beginning of the pass.
"""
for total_length, new_jagged_int_var in new_jagged_int_vars.items():
for tensor in sorted_graph:
if tensor not in graph_inputs:
shape = tensor._attrs["shape"]
if shape and shape[0] == total_length:
shape[0] = new_jagged_int_var
def dedup_make_jagged_ops(
sorted_graph: List[Tensor],
workdir: Optional[str] = None,
) -> List[Tensor]:
"""Deduplicate make_jagged ops in the graph.
The rationale is to eliminate redundant offset validation as
well as make the implicit jagged Tensors (sources) in the graph
explicit, by replacing their total_length dimension with the
corresponding JaggedIntVar.
The pass is performed in the following steps:
1. Collect the metadata of the existing make_jagged ops.
2. Remove make_jagged ops from the graph where possible.
3. Apply new make_jagged ops to the (bundled) source inputs.
4. Replace total_length dimensions with new JaggedIntVars.
See the docstrings of the individual steps' helper functions
above for more details.
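A minimal usage sketch (output_tensors is illustrative; it stands for
the model's output Tensors from which the graph is traced):

    sorted_graph = toposort(output_tensors)
    sorted_graph = dedup_make_jagged_ops(sorted_graph)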
"""
make_jagged_metadata = _get_make_jagged_metadata(sorted_graph)
if not make_jagged_metadata:
_LOGGER.debug("No make_jagged ops in the graph: skipping.")
return sorted_graph
graph_inputs = {t for t in sorted_graph if t._attrs["is_input"]}
graph_outputs = {t for t in sorted_graph if t._attrs["is_output"]}
_remove_make_jagged_ops(
make_jagged_metadata,
graph_inputs,
graph_outputs,
)
if not make_jagged_metadata:
_LOGGER.debug(
"There are make_jagged ops in the graph, "
"but nothing to deduplicate: skipping."
)
return sorted_graph
# drop the removed make_jagged outputs
sorted_graph = sanitize_sorted_graph(sorted_graph)
new_jagged_int_vars = _apply_make_jagged_to_inputs(
make_jagged_metadata,
sorted_graph,
graph_inputs,
)
_replace_total_length_with_jagged_int_var(
new_jagged_int_vars,
sorted_graph,
graph_inputs,
)
# sort the new make_jagged outputs
sorted_graph = toposort(sorted_graph)
# name the new tensors + do sanity check
sorted_graph = sanitize_sorted_graph(sorted_graph)
return sorted_graph