Huggingface 模型修改
Huggingface 里面的模型封装的很好,想要直接修改代码并非容易的事,但是如果看文档,它有很多参数,能把你想到的大部分结果取出来,下面我就以一次经历来讲讲我如何在 T5 模型上面加一个 feature fusion 层。
查看文档
我使用的是生成自然语言的 T5, 想在encoder 输出加一个 fusion layer。首先查看文档,并把 源代码 复制下来查看,发现可以直接用 T5.encoder 对 input_ids 进行编码,然后把 encoder_outputs 直接输入 T5ForConditionalGeneration,就可以了。
预测时 T5 模型记成了 model.generate,这个是所有继承 PreTrainedModel 的 生成式model 都有的一个 function,可以方便的进行 greedy search, beam search。查看文档,可以直接把模型的输入参数输入 generate 函数里面。
model*kwargs — Additional model specific kwargs will be forwarded to the
forward
function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder**.
具体代码
class T5MultiCode(nn.Module):
def __init__(self, T5ForConditionalGenerationModel,config,args=None):
super().__init__()
self.t5model = T5ForConditionalGenerationModel
self.fusion_position = args.FusionPosition # encoder, decoder
self.fusion_layer = Attention.Attention_1(
768, with_ave=args.AttentionWithAve, mul=False, common_type=args.AttentionCommonType)
self.loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
self.args = args
pass
def _encoder_fusion(self, source_ids=None, source_masks=None,
target_ids=None, target_masks=None, code_lens=None):
if calculate_loss
code_encoder_output_seq = self.t5model.encoder(
input_ids=code_input, attention_mask=source_mask)
...
outputs = self.t5model(encoder_outputs=code_encoder_output_seq, attention_mask=source_mask, labels=target_id, decoder_attention_mask=target_mask)
losses = outputs.loss if losses is None else losses + outputs.loss
else:
# predict
pred = self.t5model.generate(encoder_outputs=code_encoder_output_seq,
attention_mask=source_mask,
use_cache=True,
num_beams=args.beam_size,
early_stopping=args.task == 'summarize',
max_length=args.max_target_length)[0]
preds.append(pred)
def forward(self, source_ids=None, source_mask=None, target_ids=None, target_mask=None, args=None,
code_lens=None):
return self._encoder_fusion(source_ids, source_mask, target_ids, target_mask, code_lens)
版权声明:本文为qq_44761480原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。