Source code for aitemplate.frontend.nn.embedding

#  Copyright (c) Meta Platforms, Inc. and affiliates.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from aitemplate.compiler import ops
from aitemplate.frontend.nn.dropout import Dropout
from aitemplate.frontend.nn.layer_norm import LayerNorm
from aitemplate.frontend.nn.module import Module
from aitemplate.frontend.nn.parameter import Parameter


class Embedding(Module):
    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.

    This module is often used to store word embeddings and retrieve them
    using indices. The input to the module is a list of indices, and the
    output is the corresponding word embeddings.

    Args:
        shape (List[int]): denotes the shape of the embeddings, which is
            typically `[num_embeddings, embedding_dim]`, where
            `num_embeddings` is the size of the dictionary of embeddings and
            `embedding_dim` is the size of each embedding vector.
        dtype (string): denotes the data type
    """

    def __init__(
        self,
        shape,
        dtype,
    ):
        super().__init__()
        self.weight = Parameter(shape=shape, dtype=dtype)

    def tensor(self):
        return self.weight.tensor()
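
# Usage sketch (illustrative; not part of the original module). The sizes
# below are hypothetical BERT-base-like values; since Embedding itself only
# stores the weight, the example just reads the underlying tensor back.
def _embedding_usage_example():
    # A 30522-word vocabulary of 768-dim vectors (made-up sizes).
    word_emb = Embedding(shape=[30522, 768], dtype="float16")
    # Embedding defines no forward(); it only holds the weight Parameter,
    # and downstream graph ops consume the underlying compiler Tensor:
    return word_emb.tensor()  # same as word_emb.weight.tensor()
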
class BertEmbeddings(Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    USE_CUDA = None

    def __init__(
        self,
        hidden_size,
        vocab_size,
        max_position_embeddings,
        type_vocab_size,
        layer_norm_eps,
        hidden_dropout_prob,
        dtype="float16",
    ):
        super().__init__()
        assert (
            hidden_dropout_prob == 0.0
        ), "Dropout rate larger than 0 is not supported yet."
        self.word_embeddings = Embedding(shape=[vocab_size, hidden_size], dtype=dtype)
        self.position_embeddings = Embedding(
            shape=[max_position_embeddings, hidden_size],
            dtype=dtype,
        )
        self.token_type_embeddings = Embedding(
            shape=[type_vocab_size, hidden_size], dtype=dtype
        )
        self.LayerNorm = LayerNorm([hidden_size], layer_norm_eps, dtype)
        self.dropout = Dropout(hidden_dropout_prob)
    def forward(
        self,
        input_ids,  # [B, S]
        token_type_ids,  # [B, S]
        position_ids,  # [B, S]
    ):
        embeddings = ops.bert_embeddings()(
            input_ids,
            token_type_ids,
            position_ids,
            self.word_embeddings.weight.tensor(),
            self.token_type_embeddings.weight.tensor(),
            self.position_embeddings.weight.tensor(),
            self.LayerNorm.weight.tensor(),
            self.LayerNorm.bias.tensor(),
            self.LayerNorm.eps,
        )
        embeddings = self.dropout(embeddings)
        return embeddings
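
# Usage sketch (illustrative; not part of the original module). The
# hyperparameter values, input names, and the int64 index dtype below are
# assumptions made for this example only.
def _bert_embeddings_usage_example():
    from aitemplate.frontend import Tensor

    # BERT-base-like hyperparameters (hypothetical values).
    embeddings = BertEmbeddings(
        hidden_size=768,
        vocab_size=30522,
        max_position_embeddings=512,
        type_vocab_size=2,
        layer_norm_eps=1e-12,
        hidden_dropout_prob=0.0,  # only 0.0 is supported (see assert above)
    )

    batch, seq = 8, 128
    # Graph inputs holding token / segment / position indices.
    input_ids = Tensor([batch, seq], name="input_ids", dtype="int64", is_input=True)
    token_type_ids = Tensor(
        [batch, seq], name="token_type_ids", dtype="int64", is_input=True
    )
    position_ids = Tensor(
        [batch, seq], name="position_ids", dtype="int64", is_input=True
    )

    # Output is a [B, S, hidden_size] tensor of summed, layer-normed embeddings.
    return embeddings.forward(input_ids, token_type_ids, position_ids)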