
Using newly registered models, loss functions, etc. in fairseq

公良弘毅
2023-12-01

Create a directory my_dir under the fairseq root:

/fairseq/my_dir/
├── __init__.py
├── models
│   └── simple_lstm.py
└── criterions
    └── cross_entropy.py

The model and an architecture (tutorial_simple_lstm) are already registered in simple_lstm.py.
For how to register models, architectures, criterions, etc. in fairseq, see the official fairseq documentation.

import torch
from torch import nn
from fairseq import utils
from fairseq.models import FairseqEncoder, FairseqDecoder, FairseqEncoderDecoderModel
from fairseq.models import register_model, register_model_architecture

class SimpleLSTMEncoder(FairseqEncoder):
    def __init__(self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1):
        super().__init__(dictionary)
        self.args = args

        self.embedding = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )
        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
            batch_first=True,
        )

    def forward(self, src_tokens, src_lengths):
        # Convert left-padding to right-padding (pack_padded_sequence expects right-padded batches)
        if self.args.left_pad_source:
            src_tokens = utils.convert_padding_direction(
                src_tokens,
                padding_idx=self.dictionary.pad(),
                left_to_right=True,
            )
        
        x = self.embedding(src_tokens)

        x = self.dropout(x)
        # Pack the padded batch into a PackedSequence before feeding it to the LSTM
        x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.cpu(), batch_first=True)

        outputs, (final_hidden, final_cell) = self.lstm(x)

        return {
            'final_hidden': final_hidden.squeeze(0)
        }
    
    def reorder_encoder_out(self, encoder_out, new_order):
        """
        `encoder_out` is the dict returned by `forward`;
        `new_order` (LongTensor) is the desired batch order.
        """
        final_hidden = encoder_out['final_hidden']

        return {
            'final_hidden':final_hidden.index_select(0, new_order)
        }
    

class SimpleLSTMDecoder(FairseqDecoder):
    def __init__(self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128, dropout=0.1):
        super().__init__(dictionary)
        self.embedding = nn.Embedding(
            num_embeddings=len(dictionary),
            embedding_dim=embed_dim,
            padding_idx=dictionary.pad(),
        )

        self.dropout = nn.Dropout(dropout)

        self.lstm = nn.LSTM(
            input_size=encoder_hidden_dim + embed_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=False,
        )

        self.out_project = nn.Linear(hidden_dim, len(dictionary))

    def forward(self, prev_output_tokens, encoder_out):

        bsz, tgt_len = prev_output_tokens.size()

        final_encoder_hidden = encoder_out['final_hidden']

        x = self.embedding(prev_output_tokens)

        x = self.dropout(x)

        x = torch.cat(
            [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
            dim=2,
        )

        initial_state = (
            final_encoder_hidden.unsqueeze(0),  # hidden
            torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
        )

        # This LSTM uses the default (seq_len, batch, dim) layout (batch_first=False),
        # so transpose before and after.
        output, _ = self.lstm(
            x.transpose(0, 1),
            initial_state,
        )

        x = output.transpose(0, 1)
        x = self.out_project(x)

        return x, None
    
# Register the model
@register_model('simple_lstm')
class SimpleLSTMModel(FairseqEncoderDecoderModel):
    @staticmethod
    def add_args(parser):
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='dim of the encoder embeddings')
        parser.add_argument('--encoder-hidden-dim', type=int, metavar='N',
                            help='dim of the encoder hidden state')
        parser.add_argument('--encoder-dropout', type=float, default=0.1,
                            help='encoder dropout probability')
        
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N', 
                            help='dim of the decoder embeddings')
        parser.add_argument('--decoder-hidden-dim', type=int, metavar='N', 
                            help='dim of the decoder hidden state')
        parser.add_argument('--decoder-dropout', type=float, default=0.1,
                            help='decoder dropout probability')
        
    @classmethod
    def build_model(cls, args, task):
        encoder = SimpleLSTMEncoder(
            args=args,
            dictionary=task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            hidden_dim=args.encoder_hidden_dim,
            dropout=args.encoder_dropout,
        )
        decoder = SimpleLSTMDecoder(
            dictionary=task.target_dictionary,
            encoder_hidden_dim=args.encoder_hidden_dim,
            embed_dim=args.decoder_embed_dim,
            hidden_dim=args.decoder_hidden_dim,
            dropout=args.decoder_dropout,
        )

        model = SimpleLSTMModel(encoder, decoder)
        print(model)

        return model
    
    # The default forward is shown below; override it if you need different behavior.
    # def forward(self, src_tokens, src_lengths, prev_output_tokens):
    #     encoder_out = self.encoder(src_tokens, src_lengths)
    #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
    #     return decoder_out


# Register an architecture: (model name, architecture name)
@register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
def tutorial_simple_lstm(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
    args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
    args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
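
As a quick sanity check (not part of the original tutorial), the encoder and decoder above can be exercised with a toy dictionary and random tokens; the vocabulary and dimensions below are arbitrary assumptions made only for illustration.

import argparse
import torch
from fairseq.data import Dictionary

# Toy dictionary; Dictionary() already contains the special symbols (pad/eos/...).
d = Dictionary()
for w in ["hello", "world"]:
    d.add_symbol(w)

args = argparse.Namespace(left_pad_source=False)
encoder = SimpleLSTMEncoder(args, d, embed_dim=16, hidden_dim=16)
decoder = SimpleLSTMDecoder(d, encoder_hidden_dim=16, embed_dim=16, hidden_dim=16)

src_tokens = torch.randint(d.nspecial, len(d), (2, 5))  # (batch, src_len)
src_lengths = torch.tensor([5, 5])
encoder_out = encoder(src_tokens, src_lengths)
logits, _ = decoder(src_tokens, encoder_out)  # reuse src_tokens as a fake target prefix
print(logits.shape)  # (batch, tgt_len, vocab_size)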

The loss function in cross_entropy.py, registered under the name cross_entrop (a name that differs from fairseq's built-in cross_entropy criterion, since registering a duplicate name would raise an error):

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class CrossEntropyCriterionConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")


@register_criterion("cross_entrop", dataclass=CrossEntropyCriterionConfig)
class CrossEntropyCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(
            lprobs,
            target,
            ignore_index=self.padding_idx,
            reduction="sum" if reduce else "none",
        )
        return loss, loss

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True

In my_dir/__init__.py:

from .models import simple_lstm
from .criterions import cross_entropy
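
Before launching training, it can help to verify that --user-dir actually picks up the new registrations. The sketch below imports the user directory the same way fairseq-train does; the registry names and the my_dir path are assumptions that may differ across fairseq versions.

import argparse
from fairseq import utils
from fairseq.models import ARCH_MODEL_REGISTRY, MODEL_REGISTRY
from fairseq.criterions import CRITERION_REGISTRY

# Import fairseq/my_dir/ the same way --user-dir does (path is an assumption).
utils.import_user_module(argparse.Namespace(user_dir="fairseq/my_dir"))

print("simple_lstm" in MODEL_REGISTRY)                 # expected: True
print("tutorial_simple_lstm" in ARCH_MODEL_REGISTRY)   # expected: True
print("cross_entrop" in CRITERION_REGISTRY)            # expected: True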

Training

fairseq-train data-bin/iwslt14.tokenized.de-en/ --arch tutorial_simple_lstm \
--encoder-dropout 0.2 --decoder-dropout 0.2 --optimizer adam --lr 0.005 \
--lr-shrink 0.5 --max-tokens 12000 --criterion cross_entrop \
--user-dir fairseq/my_dir/
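
For inference, pass the same --user-dir so that fairseq can rebuild the custom architecture when loading the checkpoint; the checkpoint path below assumes the default checkpoints/ save directory.

fairseq-generate data-bin/iwslt14.tokenized.de-en/ --path checkpoints/checkpoint_best.pt \
--user-dir fairseq/my_dir/ --beam 5 --remove-bpe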