2023-12-05

torchなどのライブラリを使いpre_trainingを行い文章生成させる際、文章生成の計算は基本的に自分で実装する必要がある。 huggingface用のmodelに変換しても良いが、おれおれアーキテクチャにした場合、変換も面倒…

そこで、huggingfaceにあるtop_kやtop_pなどの便利な実装を使えるようにするメモ。

文章生成についての手法は以下がわかりやすくておすすめ
https://huggingface.co/blog/how-to-generate

logits_processの実装

logits_processの処理は、以下にまとまっている。
https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/generation/logits_process.py

例えば、top_kは以下のような実装となっている。

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
    top_k = min(self.top_k, scores.size(-1))  # Safety check
    # Remove all tokens with a probability less than the last token of the top-k
    indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
    scores = scores.masked_fill(indices_to_remove, self.filter_value)
    return scores

よく使うのは、top_pやtop_k、NoRepeatNGramLogitsProcessor、RepetitionPenaltyLogitsProcessorあたり

NoRepeatNGramLogitsProcessor
https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/generation/logits_process.py#L718

RepetitionPenaltyLogitsProcessor
https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/generation/logits_process.py#L270

huggingfaceでの実装

よく使うのは、model.generate だが、generateからsampleが選択される部分はここ
https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/generation/utils.py#L1706

sample関数の実装
https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/generation/utils.py#L2612

huggingfaceのsampleでは、logits_processorは次のように使われている

next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)

...

probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

greedy searchの実装はここ
https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/generation/utils.py#L2353

argmaxを使っているのがわかる

torchのmodelの出力に対して使う

modelの作り方にも依るが、基本的にはlogitが出力されるように作るはず。

使い方は、LogitsProcessorListを定義して、ここまでの入力xとlogitを渡すだけ。 xはtokenizerされた単語のlist、例えば[20, 15, 1, 1120, ....]のようなtensorとなる。

用途によってtensorの次元を合わせてください。

from transformers.generation.utils import (
    LogitsProcessorList, 
    TopKLogitsWarper,
    TopPLogitsWarper,
    TemperatureLogitsWarper,
    RepetitionPenaltyLogitsProcessor
)

logits_processor = LogitsProcessorList([
    RepetitionPenaltyLogitsProcessor(repetition_penalty),
])

logits_wraper = LogitsProcessorList([
        TemperatureLogitsWarper(temperature),
        TopPLogitsWarper(top_p),
        TopKLogitsWarper(top_k),
])

## modelからlogitsを得る
logits = model(x, input_pos)

logits = logits[:, -1, :] ## sequenceの最後のlogit
next_token_scores = logits_processor(x, logits)
next_token_scores = logits_wraper(x, next_token_scores)
next_token_scores = next_token_scores.squeeze(0) ## 次元の調整

## next_token_scoresに対してsoftmaxで確率にして、次のtokenのindexを出力する
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_idx = torch.multinomial(probs, num_samples=1)

所感

top_k, temperatureくらいだと自分で実装しても良いが、RepetitionPenaltyはここまでのindexも見る必要がありめんどうだったので、huggingfaceの実装を使った。
他にも便利な処理もあるので、一度使えるようにしておくと便利。

See Also