Notice
Recent Posts
Recent Comments
Link
«   2024/11   »
1 2
3 4 5 6 7 8 9
10 11 12 13 14 15 16
17 18 19 20 21 22 23
24 25 26 27 28 29 30
Archives
Today
Total
관리 메뉴

잡동사니 블로그

[Python] Torchvision의 ImageFolder 본문

Python

[Python] Torchvision의 ImageFolder

코딩부대찌개 2024. 9. 10. 23:21

ImageFolder는 PyTorch의 torchvision.datasets 모듈에 속한 클래스로, 이미지 데이터를 폴더 구조에 기반해 자동으로 라벨을 할당하고 데이터셋을 생성해주는 클래스

 

폴더 이름순으로 자동으로 Labeling Abnormal = 0 , Normal = 1

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt

#Image PreProcessing
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(num_output_channels=1), 
    transforms.ToTensor(),
])

fail_path = '/'
dataset = datasets.ImageFolder(fail_path, transform=transform)

# Train = 0.7, Valid = 0.15, Test = 0.15
train_size = int(0.7 * len(dataset))
valid_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - valid_size

# Split Data
train_dataset, valid_dataset, test_dataset = random_split(dataset, [train_size, valid_size, test_size])

indices = torch.randperm(len(train_dataset))[:30]
plt.figure(figsize=(15, 10))
for i, idx in enumerate(indices):
  image, label = train_dataset[idx]
  plt.subplot(6, 5, i + 1)
  plt.imshow(image.permute(1, 2, 0))
  plt.title(f"Label: {label}")
  plt.axis('off')
plt.tight_layout()
plt.show()

 

이후에 Dataloader를 사용하여 학습 진행