🤗 Using the Hugging Face Trainer Module

May 15, 2024, 3:01 p.m. · 7 min read

deep learning · NLP · implementation

ํ—ˆ๊น… ํŽ˜์ด์Šค๐Ÿค— transformers ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋Š” ๋‹ค์–‘ํ•œ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํ„ฐ์…‹์„ ๊ฐ„ํŽธํ•˜๊ฒŒ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์–ด ๋„๋ฆฌ ์‚ฌ์šฉ๋˜๋Š” ํŒจํ‚ค์ง€์ด๋‹ค. transformers ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์˜ ๋ฉ”์ธ์€ ๋ฐ”๋กœ ๋ชจ๋ธ ํ›ˆ๋ จ์„ ์œ„ํ•œ Trainer ํ•จ์ˆ˜๋ผ ํ•  ์ˆ˜ ์žˆ๋Š”๋ฐ, ๋ชจ๋ธ ํ›ˆ๋ จ์„ ์œ„ํ•œ ์ •๋ง ๋งŽ์€ ๊ธฐ๋Šฅ๋“ค์„ ์ œ๊ณตํ•˜๊ณ  ์žˆ๋‹ค. ๊ทธ๋Ÿฐ๋ฐ ์ด๋ ‡๊ฒŒ ๋งŽ์€ ๊ธฐ๋Šฅ๋“ค์„ ์ „๋ถ€ ์„ค์ •ํ•˜๋ ค๋ฉด ์—„์ฒญ๋‚˜๊ฒŒ ๋งŽ์€ ํŒŒ๋ผ๋ฏธํ„ฐ๋“ค์„ ์ธ์ž๋กœ ๋ฐ›์•„์•ผ ํ•œ๋‹ค. (ํ•„์ž๋ฅผ ํฌํ•จํ•ด) ๋งŽ์€ ์‚ฌ๋žŒ๋“ค์ด ํ—ˆ๊น…ํŽ˜์ด์Šค ์‚ฌ์šฉ์„ ์–ด๋ ค์›Œํ•˜๋Š” ์ด์œ ์ด๋‹ค.

์ด ๊ธ€์—์„œ๋Š” ๋จผ์ € ํ—ˆ๊น…ํŽ˜์ด์Šค Trainer ํด๋ž˜์Šค์˜ ์‚ฌ์šฉ๋ฒ•์„ ์•Œ์•„๋ณธ๋‹ค. ํŠนํžˆ ์ž์ฃผ ์‚ฌ์šฉํ•˜๋Š” ์ธ์ž๋“ค๋กœ ์–ด๋–ค ๊ฒƒ์ด ์žˆ๋Š”์ง€, ์–ด๋–ค ๋ฐฉ์‹์œผ๋กœ Trainer์— ๋„ฃ์–ด์ค˜์•ผ ํ•˜๋Š”์ง€๋ฅผ ์•Œ์•„๋ณธ๋‹ค. ๋˜, GPT-2 ๋ชจ๋ธ์„ IMDB ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ fine-tuning์‹œ์ผœ์„œ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„๋ฅ˜์— ์‹ค์ œ๋กœ ์ ์šฉํ•ด๋ณธ๋‹ค.

ํ—ˆ๊น…ํŽ˜์ด์Šค๋ž€?

ํ—ˆ๊น…ํŽ˜์ด์Šค(HuggingFace) ๐Ÿค—๋Š” ๋จธ์‹ ๋Ÿฌ๋‹/๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์„ ํ›ˆ๋ จ, ๊ณต์œ ํ•˜๊ณ  ๋ฐฐํฌํ•˜๊ธฐ ์œ„ํ•œ ์„œ๋น„์Šค๋“ค์„ ์ œ๊ณตํ•˜๋Š” ํšŒ์‚ฌ์ด์ž ์ปค๋ฎค๋‹ˆํ‹ฐ์ด๋‹ค. ํ—ˆ๊น…ํŽ˜์ด์Šค์—์„œ ์ œ๊ณตํ•˜๋Š” ์„œ๋น„์Šค๋Š” ํฌ๊ฒŒ ๋‘ ๊ฐœ๋กœ ๋‚˜๋‰  ์ˆ˜ ์žˆ๋‹ค.

์ด ๊ธ€์—์„œ ๋‹ค๋ฃจ๋Š” ๊ฒƒ์€ ์ด ์ค‘ ์ „์ž๋กœ, ํ—ˆ๊น…ํŽ˜์ด์Šค์—์„œ ์ œ๊ณตํ•˜๋Š” transformers ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์˜ ์‚ฌ์šฉ๋ฒ•์„ ์•Œ์•„๋ณด๊ฒ ๋‹ค.

Parameters of Trainer and TrainingArguments

ํ—ˆ๊น…ํŽ˜์ด์Šค Trainer API๋ฅผ ์ด์šฉํ•ด์„œ ๋ชจ๋ธ์„ ํ›ˆ๋ จํ•  ๋•Œ, ๋งค๊ฐœ๋ณ€์ˆ˜๋“ค ์ค‘ ์ผ๋ถ€๋Š” TraininingArguments๋กœ, ์ผ๋ถ€๋Š” Trainer์— ๋„ฃ์–ด์ฃผ์–ด์•ผ ํ•œ๋‹ค. ์˜ˆ์‹œ๋ฅผ ๋“ค์ž๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

training_arguments = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=3e-5,
    logging_strategy="epoch",
    load_best_model_at_end=True,
    save_strategy="epoch",
    metric_for_best_model="accuracy",
)

trainer = Trainer(
    model=AutoModelForSequenceClassification.from_pretrained("bert-base-uncased"),
    train_dataset=ds_train,               # training dataset (defined elsewhere)
    eval_dataset=ds_test,                 # evaluation dataset (defined elsewhere)
    args=training_arguments,
    compute_metrics=compute_metrics,      # metric function (defined elsewhere)
)

trainer.train()

As shown above, you create a TrainingArguments object and then pass it to Trainer through the args parameter to supply the information needed for training. Let's start with the parameters that TrainingArguments accepts.

TrainingArguments

๊ณต์‹ API ๋ฌธ์„œ์—์„œ TrainingArguments๋ฅผ ์ฐพ์•„๋ณด๋ฉด argument๋“ค์˜ ๋งค์šฐ ๊ธด ๋ชฉ๋ก์ด ๋‚˜์˜จ๋‹ค. ์ด ์ค‘ ์ž์ฃผ ์“ฐ์ด๋Š” ๋ช‡ ๊ฐœ๋ฅผ ์ •๋ฆฌํ•ด๋ณด์•˜๋‹ค.

  1. output_dir: str: The only required argument; it specifies where to save the trained model and checkpoints. It can be a path on the file system (e.g. './results') or the name of a repository on the Hugging Face Hub (e.g. 'bert-uncased-imdb-finetuned'). In the latter case, calling trainer.push_to_hub() after training uploads the model to the Hub automatically.

    • overwrite_output_dir: bool: If True, overwrite the contents of output_dir even if files already exist there.
  2. num_train_epochs: The number of epochs to train for.

  3. per_device_train_batch_size: The batch size used during training.
  4. per_device_eval_batch_size: The batch size used during evaluation.
  5. learning_rate: The learning rate.

    • lr_scheduler_type: Specifies a learning rate scheduler. The default is linear; other options include cosine, cosine_with_restarts, polynomial, constant, and constant_with_warmup.
    • Schedulers other than constant need additional parameters (see the sketch after this list).
  6. weight_decay: Sets the L2 weight decay.

  7. logging_strategy: Controls when logs are written. The default is steps; no disables logging, epoch logs at the end of every epoch, and steps logs every logging_steps steps.

    • In other words, when logging_strategy="steps" you also need to set logging_steps. logging_steps can be an integer, or a float between 0 and 1, in which case it is interpreted as a fraction of the total number of training steps.
  8. save_strategy: Controls when model checkpoints are saved. Again one of no, steps, or epoch; with steps you also need to set save_steps.

    • save_total_limit: The maximum number of checkpoints to keep.
  9. evaluation_strategy: Controls when evaluation is run. Again one of no, steps, or epoch; with steps you also need to set eval_steps.

  10. use_cpu: If True, run on the CPU even when a GPU is available.
  11. seed: Sets the random seed. In many cases 42 is used.
  12. fp16: Whether to use fp16 mixed-precision training (True/False).
  13. disable_tqdm: If False, training progress is displayed with a tqdm progress bar; set it to True to suppress it.
  14. load_best_model_at_end: If True, the best-performing model is loaded at the end of training.

    • metric_for_best_model: Specifies which metric defines "best-performing". Pass the name of one of the metrics returned by the compute_metrics function (described below) as a string (e.g. "accuracy").
    • greater_is_better: Tells the Trainer whether a higher or a lower value of that metric counts as better.
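
Putting items 5, 7, and 8 together, a minimal sketch (the values here are purely illustrative) of a TrainingArguments that uses a non-constant scheduler with warmup plus step-based logging and saving might look like this:

from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir="./results",
    learning_rate=3e-5,
    lr_scheduler_type="cosine",   # a non-constant scheduler
    warmup_ratio=0.1,             # warm up over the first 10% of training steps
    logging_strategy="steps",
    logging_steps=100,            # log every 100 steps
    save_strategy="steps",
    save_steps=500,               # checkpoint every 500 steps
    save_total_limit=2,           # keep at most 2 checkpoints
)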

Trainer

TrainingArguments์—์„œ ํ•™์Šต์„ ์‹œํ‚ฌ ๋•Œ ์•Œ๋ ค์ค˜์•ผ ํ•  ์„ธ๋ถ€์‚ฌํ•ญ๋“ค์„ ๋„ฃ์–ด์ฃผ์—ˆ๋‹ค๋ฉด, Trainer์—์„œ๋Š” ์ข€ ๋” ๊ธฐ๋ณธ์ ์ธ ๊ตต์งํ•œ ์ •๋ณด๋“ค์„ ์•Œ๋ ค์ฃผ์–ด์•ผ ํ•œ๋‹ค.

  1. model: A transformers model object, or a PyTorch nn.Module.

    • Instead of model, you can provide model_init, a function that creates and returns a fresh model object.
  2. args: The TrainingArguments object introduced above.

  3. data_collator: A function used to collate a list of elements from the train/eval dataset into a batch. If no tokenizer is given, default_data_collator() is used; if a tokenizer is given, an instance of DataCollatorWithPadding is used.
  4. train_dataset: Arguably the most important argument: the training dataset. It can be a Dataset from the datasets library or a PyTorch Dataset.
  5. eval_dataset: The evaluation dataset, in the same format as train_dataset.
  6. tokenizer: The tokenizer used to preprocess the data.
  7. compute_metrics: A function that computes metrics during evaluation. Its input and output must follow a specific format, explained later.
  8. optimizers: The optimizer and LR scheduler to use for training, as a tuple of a torch.optim.Optimizer and a torch.optim.lr_scheduler.LambdaLR. If nothing is passed, AdamW is used (see the sketch below).
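
As a rough sketch of item 8 (reusing model, ds_train, ds_test, and training_arguments from the earlier example; the step counts are illustrative), passing your own optimizer and scheduler through optimizers could look like this:

import torch
from transformers import Trainer, get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=2000,  # roughly (len(ds_train) // batch size) * epochs
)

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=ds_train,
    eval_dataset=ds_test,
    optimizers=(optimizer, scheduler),  # (Optimizer, LambdaLR) tuple
)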

๋ณต์žกํ•œ ํ˜•์‹์„ ๋”ฐ๋ผ์•ผ ํ•˜๋Š” ์ธ์ž๋“ค

์•ž์—์„œ ์„ค๋ช…ํ•œ TrainingArguments์™€ Trainer์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋“ค ์ค‘์—์„œ๋Š” ์ˆซ์ž๋‚˜ ๋ฌธ์ž์—ด์ด ์•„๋‹Œ, ํ•จ์ˆ˜๋‚˜ ํŠน์ • ํด๋ž˜์Šค์˜ ๊ฐ์ฒด๋ฅผ ์š”๊ตฌํ•˜๋Š” ๊ฒƒ๋“ค์ด ์žˆ๋‹ค. ์ด ๋•Œ ์ธ์ž๋กœ ์ฃผ์–ด์ง€๋Š” ํ•จ์ˆ˜๋Š” ๋‹น์—ฐํžˆ ํŠน์ •ํ•œ ์ž…๋ ฅ๊ณผ ์ถœ๋ ฅ ํ˜•์‹์„ ๋”ฐ๋ผ์•ผ๋งŒ ํ•  ๊ฒƒ์ด๊ณ , ๊ฐ์ฒด๋Š” ๋‹น์—ฐํžˆ ์ •ํ•ด์ ธ ์žˆ๋Š” ํŠน์ • ํด๋ž˜์Šค์˜ ๊ฐ์ฒด์—ฌ์•ผ๋งŒ ํ•  ๊ฒƒ์ด๋‹ค. ๊ทธ๋Ÿฌ์ง€ ์•Š์œผ๋ฉด ์—๋Ÿฌ๊ฐ€ ๋ฐœ์ƒํ•˜๊ฒŒ ๋œ๋‹ค. ์„ค๋ช…ํ•œ ๊ฒƒ ์ธ์ž๋“ค ์ค‘์—์„œ๋Š” data_collator์™€ compute_metrics๊ฐ€ ์ด๋Ÿฌํ•œ ๊ฒฝ์šฐ์— ํ•ด๋‹นํ•˜๋Š”๋ฐ, ๊ฐ๊ฐ ์–ด๋–ค ํ˜•์‹์„ ๋”ฐ๋ผ์•ผ ํ•˜๋Š”์ง€ ๊ฐ„๋žตํ•˜๊ฒŒ ์•Œ์•„๋ณด์ž.

data_collator

As explained earlier, data_collator turns data into batches in a form the model can consume, and it is expected to be an instance of a DataCollator class. There are several data collator classes, and which one to use depends on the task.
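
As a rough illustration (both classes below are provided by transformers), a classification task typically uses DataCollatorWithPadding, while masked language modeling uses DataCollatorForLanguageModeling:

from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    DataCollatorForLanguageModeling,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# For sequence classification: pad each batch to its longest sequence
classification_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# For masked language modeling: also creates randomly masked labels
mlm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)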

compute_metrics

compute_metrics, which computes the metrics used during evaluation, is a function that takes an EvalPrediction object as input and returns a dictionary. EvalPrediction is a kind of named tuple that always has two attributes, predictions and label_ids. As the names suggest, predictions holds the model's predictions and label_ids holds the ground-truth labels provided by the dataset. Comparing the two with one or more metrics is exactly the job of compute_metrics. Once the computation is done, it must return a dictionary whose keys are metric names and whose values are the computed values.

๋‹ค์Œ์€ compute_metrics๋ฅผ ์ž‘์„ฑํ•œ ์˜ˆ์‹œ์ด๋‹ค.

import numpy as np
from datasets import load_metric
from transformers import TrainingArguments, Trainer

# metric๋“ค์„ ๊ฐ€์ ธ์˜ค๊ธฐ
accuracy_metric = load_metric("accuracy")
f1_metric = load_metric("f1")

def compute_metrics(eval_pred):
    predictions, label_ids = eval_pred
    # equivalently: predictions, label_ids = eval_pred.predictions, eval_pred.label_ids

    preds = predictions.argmax(axis=1)

    accuracy = accuracy_metric.compute(predictions=preds, references=label_ids)
    f1 = f1_metric.compute(predictions=preds, references=label_ids, average="weighted")

    return {
        "accuracy": accuracy["accuracy"],
        "f1": f1["f1"],
    }

์ปค์Šคํ…œ ๋ชจ๋ธ ์‚ฌ์šฉํ•˜๊ธฐ

The Trainer API is optimized for training transformers models, but it can also be used to train a custom model implemented in PyTorch. The Trainer API documentation lists the points to keep in mind when using a custom model.
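
Roughly speaking, a custom model should accept the batch's keys as keyword arguments, compute the loss itself when labels are provided, and return the loss as the first element of a tuple (or of a ModelOutput). A minimal sketch of such a model (TinyClassifier is a made-up example, not something from the library) could look like this:

import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    """A minimal custom model that Trainer can train."""

    def __init__(self, vocab_size, num_labels=2, hidden_size=128):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        # attention_mask is accepted (and ignored here) because the collator provides it
        logits = self.classifier(self.embedding(input_ids))
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)
            return (loss, logits)  # the loss must come first
        return (logits,)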

GPT ํŒŒ์ธํŠœ๋‹ํ•˜๊ธฐ

์ด์ œ ํ—ˆ๊น…ํŽ˜์ด์Šค Trainer ์‚ฌ์šฉ๋ฒ•์€ ๋ฐฐ์› ์œผ๋‹ˆ, ์‹ค์ œ LLM์„ ํ›ˆ๋ จํ•ด๋ณด์ž. GPT-2๋ฅผ IMDb ์˜ํ™”๋ฆฌ๋ทฐ ๋ฐ์ดํ„ฐ์…‹์—์„œ fine-tuningํ•˜์—ฌ, ์ž‘์„ฑ๋œ ์˜ํ™”๋ฆฌ๋ทฐ๊ฐ€ ์˜ํ™”๋ฅผ ์ข‹๊ฒŒ ํ‰๊ฐ€ํ•˜๋Š”์ง€, ๋‚˜์˜๊ฒŒ ํ‰๊ฐ€ํ•˜๋Š”์ง€ ๋ถ„๋ฅ˜ํ•˜๋Š” ๊ฐ„๋‹จํ•œ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•˜๋„๋ก ํ•ด๋ณธ๋‹ค.
๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์„ค์น˜

%%capture
!pip install -U datasets transformers accelerate

ํ•„์š”ํ•œ ํŒจํ‚ค์ง€๋“ค์„ ์„ค์น˜ํ•ด์ค€๋‹ค. Colab ๊ธฐ์ค€์œผ๋กœ ์œ„ ํŒจํ‚ค์ง€๋“ค์€ ์ด๋ฏธ ์„ค์น˜๋˜์–ด ์žˆ์ง€๋งŒ, ํ˜„์žฌ(2024๋…„ 5์›” 15์ผ ๊ธฐ์ค€) ๋ฒ„์ „ ๋ฌธ์ œ์ธ์ง€ ์ด ์ž‘์—…์„ ํ•ด์ฃผ์ง€ ์•Š์œผ๋ฉด ํ›ˆ๋ จ์ด ๋˜์ง€ ์•Š๋Š”๋‹ค. ์ด์™ธ์—๋„ ์‚ฌ์šฉํ™˜๊ฒฝ์— ๋”ฐ๋ผ ์ฝ”๋“œ๋ฅผ ์‹คํ–‰ํ•˜๋ฉด์„œ ํŒจํ‚ค์ง€๊ฐ€ ์„ค์น˜๋˜์–ด ์žˆ์ง€ ์•Š๋‹ค๊ณ  ๋‚˜์˜ฌ ๊ฒฝ์šฐ pip์œผ๋กœ ์„ค์น˜ํ•ด์ฃผ๋ฉด ๋œ๋‹ค.

๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

from transformers import AutoTokenizer, AutoModelForSequenceClassification

model_ckpt = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)

ํ—ˆ๊น…ํŽ˜์ด์Šค ํ—ˆ๋ธŒ์—์„œ GPT-2๋ฅผ ์ฐพ์•„์„œ ์ž„ํฌํŠธํ•ด์ฃผ์—ˆ๋‹ค. ๋งํฌ์—์„œ Use in Transformers ๋ฒ„ํŠผ์„ ๋ˆŒ๋Ÿฌ ๊ฐ„ํŽธํ•˜๊ฒŒ ๋ถˆ๋Ÿฌ์˜ฌ ์ˆ˜ ์žˆ๋‹ค.

๋ฐ์ดํ„ฐ์…‹ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

from datasets import load_dataset
dataset = load_dataset("stanfordnlp/imdb")

ds_train = dataset['train'].shuffle().select(range(10000))
ds_test = dataset['test'].shuffle().select(range(2500))

IMDb ๋ฐ์ดํ„ฐ์…‹์„ ๊ฐ€์ ธ์˜จ๋‹ค. ๋ฐ์ดํ„ฐ์…‹ ๋˜ํ•œ ํ—ˆ๊น…ํŽ˜์ด์Šค ํ—ˆ๋ธŒ์—์„œ ์‰ฝ๊ฒŒ ๋ถˆ๋Ÿฌ์˜ฌ ์ˆ˜ ์žˆ๋„๋ก ์ œ๊ณตํ•˜๊ณ  ์žˆ๋‹ค. (๋งํฌ) ์‹ค์ œ ๋ฐ์ดํ„ฐ์…‹์€ train๊ณผ test set์ด ๊ฐ๊ฐ 25000๊ฐœ์˜ ๋ฐ์ดํ„ฐ๋กœ ์ด๋ฃจ์–ด์ ธ ์žˆ์ง€๋งŒ, ๋น ๋ฅธ ํ•™์Šต์„ ์œ„ํ•ด์„œ ๊ฐ๊ฐ 10000๊ฐœ์™€ 2500๊ฐœ๋งŒ ์‚ฌ์šฉํ•˜๊ฒ ๋‹ค.

Metric ์ •์˜ํ•˜๊ธฐ

import numpy as np
from datasets import load_metric
from transformers import TrainingArguments, Trainer

accuracy_metric = load_metric("accuracy")
f1_metric = load_metric("f1")

def compute_metrics(eval_pred):
    predictions, label_ids = eval_pred.predictions, eval_pred.label_ids
    predictions = predictions.argmax(axis=1)
    accuracy = accuracy_metric.compute(predictions=predictions, references=label_ids)
    f1 = f1_metric.compute(predictions=predictions, references=label_ids, average="weighted")

    return {
        "accuracy": accuracy["accuracy"],
        "f1": f1["f1"],
    }

We define the compute_metrics function to specify the metrics used during evaluation; here we use accuracy and the F1 score. Take care that the function follows the required format.

ํŒจ๋”ฉ ํ† ํฐ ์ง€์ •ํ•˜๊ธฐ, Data Collator ์ •์˜ํ•˜๊ธฐ

from transformers import DataCollatorWithPadding

model.config.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token_id = tokenizer.eos_token_id  

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, max_length=256)

To process several sequences of different lengths at once in a batch, they must be padded; we define a data_collator to do this. We also have to tell the model and the tokenizer which token to use for padding. Since GPT-2 has no dedicated padding token, it is commonly set equal to the EOS (end of sequence) token, as above.

Trainer ์ •์˜ํ•˜๊ธฐ
์ด์ œ ์ •์˜ํ•œ ๋ณ€์ˆ˜๋“ค์„ ๋ชจ๋‘ ๋ชจ์•„ Trainer๋ฅผ ์ •์˜ํ•ด์ค„ ์ฐจ๋ก€์ด๋‹ค.

from transformers import Trainer, TrainingArguments

training_arguments = TrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=3e-5,
    logging_strategy="epoch",
    load_best_model_at_end=True,
    save_strategy="epoch",
    metric_for_best_model="accuracy",
) 

trainer = Trainer(
    model=model,
    train_dataset=ds_train,
    data_collator=data_collator,
    eval_dataset=ds_test,
    args=training_arguments,
    compute_metrics=compute_metrics,
)

๋ฐฐ์› ๋˜๋Œ€๋กœ TrainingArguments์™€ Trainer๋ฅผ ์ฐจ๋ก€๋Œ€๋กœ ์ •์˜ํ•ด์ฃผ๊ณ , ํ•„์š”ํ•œ ๋งค๊ฐœ๋ณ€์ˆ˜๋“ค์„ ํ•˜๋‚˜์”ฉ ๋„ฃ์–ด์ฃผ์ž. ๊ผญ ์ด ๊ธ€์— ์žˆ๋Š”๋Œ€๋กœ ํ•  ํ•„์š” ์—†์ด, ์ธ์ž๋“ค์„ ํ•˜๋‚˜์”ฉ ๋ฐ”๊ฟ”๋ณด๊ฑฐ๋‚˜ ๋‹ค๋ฅธ ๋งค๊ฐœ๋ณ€์ˆ˜๋“ค์„ ๋„ฃ์–ด๋ณด๋Š” ์‹์œผ๋กœ ์ฝ”๋“œ๋ฅผ ๋ฐ”๊ฟ”๋ณด๋ฉด ์ดํ•ด์— ๋„์›€์ด ๋  ๊ฒƒ์ด๋‹ค.

Training

trainer.train()

์œ„์™€ ๊ฐ™์ด ํ›ˆ๋ จ์ด ์ž˜ ์ง„ํ–‰๋˜๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

References

โ€Hugging Faceโ€, Wikipedia
Trainer API documentation
# ๋ฐ์ดํ„ฐ ๊ณผํ•™์ž๋“ค์ด ์ˆซ์ž 42๋ฅผ ์ข‹์•„ํ•˜๋Š” ์ด์œ 
Github transformers EvalPrediction ์†Œ์Šค์ฝ”๋“œ
transformers.utils.ModelOutput documentation