TIL PyTorch Dataset and DataLoader

PyTorch에서 데이터셋을 효율적으로 관리하고 모델 학습에 활용하기 위한 핵심 구성 요소인 Dataset과 DataLoader의 개념 및 사용법.

요약 표

구성요소역할주요 메서드/옵션
Dataset데이터셋 정의 및 개별 데이터 반환__len__, __getitem__
DataLoader데이터를 배치 단위로 로드batch_size, shuffle, num_workers, drop_last, pin_memory, collate_fn
torchvision.datasets미리 정의된 데이터셋 제공MNIST, CIFAR10, ImageFolder

1. Dataset 개념 및 활용

1.1 Dataset 정의

  • torch.utils.data.Dataset 추상 클래스 상속

  • len: 데이터셋 전체 크기 반환

  • getitem: 인덱스에 해당하는 데이터 샘플 반환

1.2 torchvision.datasets 활용

  • 이미지, 텍스트 등 미리 정의된 데이터셋 제공

  • MNIST, CIFAR10, ImageFolder 등

  • PyTorch Dataset 클래스 상속 및 직접 사용 가능

2. DataLoader 기능 및 옵션

2.1 DataLoader 역할

  • Dataset 객체로부터 데이터 효율적 로드

  • 학습 시 필요한 다양한 옵션 제공

  • 필수 인자: dataset 객체

2.2 주요 옵션

  • batch_size: 미니 배치 크기 설정

  • shuffle: 데이터 순서 무작위 섞기 (학습 시 True)

  • num_workers: 병렬 데이터 로딩을 위한 서브 프로세스 수

  • drop_last: 마지막 미니 배치 버릴지 여부

  • pin_memory: CUDA 고정 메모리 사용 여부 (GPU 학습 시 속도 향상)

  • collate_fn: 사용자 정의 미니 배치 구성 함수

3. DataLoader 활용 및 고려사항

3.1 데이터 로딩 예제

  • torchvision.datasets로 데이터셋 로드

  • DataLoader 객체 생성 및 옵션 설정

  • for 루프를 통한 배치 단위 데이터 접근

3.2 핵심 사용 사례

  • shuffle: 모델의 데이터 순서 의존성 방지

  • num_workers: CPU 코어 활용 병렬 처리로 로딩 속도 향상

  • drop_last: 배치 정규화 등 특정 레이어 사용 시 권장

  • pin_memory: GPU 학습 시 데이터 전송 속도 최적화

4. DataLoader 사용 예제

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 1. 데이터셋 정의
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)

# 2. DataLoader 정의
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    drop_last=True,
    pin_memory=True
)

# 3. DataLoader를 이용한 데이터 반복
for batch_idx, (data, target) in enumerate(dataloader):
    print(f"Batch {batch_idx} | Data Shape: {data.shape} | Target Shape: {target.shape}")
    break