람쥐썬더

[PYTHON] PYTORCH 이미지 DataLoader 구축 플로우 정리 본문

파이썬

[PYTHON] PYTORCH 이미지 DataLoader 구축 플로우 정리

람쥐썬더123 2023. 11. 20. 19:51

프로덕트에 신경쓴다고 반년가까이 놓아버려 레퍼런스 없이 하지 못해 복습 겸 정리

 

기준은 이미지 다중 분류모델이다.

 

데이터는 이미지와 이미지 정보를 담은 DataFrame 이며 데이터는 StratifiedKFold 를 이용해 Fold 되어있는 상태

 

Pandas DataFrame 예시)

image_path label fold
path1/path2/image1.jpg 1 1

 

 

Custom DataLoader

from torch.utils.data import Dataset, DataLoader
import cv2
import pandas as pd
import numpy as np

class CustomDataGenerator(Dataset):
    def __init__(self, data, mode, batch_size, image_size=224):
    	# Set Mode ref pandas data fold
        
        if mode == 'train': 
            self.data = data[(data['fold'] != 1) & (data['fold'] != 2)]
        elif mode == 'valid':
            self.data = data[data['fold'] == 1]
        elif mode == 'test':
            self.data = data[data['fold'] == 2]
        '''
        fold 의 1 은 validation 용으로
        fold 의 2 는 test 용으로 찢어두고
        필요한 경우train일 경우와 validation , test의 augmentation을 다르게 적용 한다
        '''
        
        self.batch_size = batch_size # Not Setting Default 
        self.image_size = image_size # Default 224 ( 이미지넷 기준 )
        self.on_epochs_end()
        

    def __getitem__(self, idx):
        # data = self.data[idx * self.batch_size : (idx + 1 ) * self.batch_size]

        start_idx = idx * self.batch_size
        end_idx = (idx + 1) * self.batch_size
        data = self.data[start_idx : end_idx]
        # 배치 사이즈 별 시작 , 끝 인덱스 정의

        X, y = self.get_data(data) # => split된 dataFrame 에서 배치 데이터 가져옴
        return torch.Tensor(X), torch.Tensor(y) # => 밖으로 나가는부분
        
    def get_data(self, data):
    	# input되는 data는 이미 배치사이즈에 맞게 split 되어있음
    	X = np.ndarray((len(data), 3, self.image_size, self.image_size))
    	# 배치사이즈 16기준 (16,3,224,224) 의 빈 array 정의

    	labels = np.array(data['label'].values)
        # dataframe label 부분 array로 잡아주기
        
        for idx , image_path in enumerate(data.image_path):
            temp_image = cv2.imread(image_path) # 이미지 array로
            temp_image = cv2.cvtColor(temp_image, cv2.COLOR_BGR2RGB) # 상황에맞게 
            temp_image = cv2.resize(temp_image, (self.image_size, self.image_size))
            # augmentation이 없을 경우를 가정한 resize

            # if temp_image.shape[-1] != 3:    
            #    temp_image = cv2.resize(temp_image, (self.image_size, self.image_size))[:,:,:3]
            # rgb 이미지가 아닐경우 depth 3 이후 버림
            X[idx][:][:][:] = temp_image # 정리 끝난 이미지 array 에 쌓아둠
   
 		return X, labels.astype('int64')
        # array 다 채운 후 label과 같이 반환, 이 경우 분류문제라 int64 형으로 변경

    def __len__(self):
        return math.ceil(len(self.data) / self.batch_size)
        # batch에 맞춰 split 될 수 있도록 __len__ 함수 정의
        
    def on_epochs_end(self):
        self.data = self.data.sample(frac=1).reset_index(drop=True)
        # 한 epoch가 끝나면 data 리셋

 

어디까지나 이미지 기준으로 cv2 가 아닌 matplotlib으로 imread가 가능하지만

cv2 가 좀 더 빠르고 안정적이었던걸로 기억

augmentation이 적용된 부분은 차후 추가하기로 하고

 

헤당 data를 각각 필요한 단에 사용하기 위해서 로더에 집어넣어준다

batch_size = 16
train_generator = CustomDataGenerator(data= data, mode='train', batch_size=batch_size)
valid_generator = CustomDataGenerator(data= data, mode='valid', batch_size=batch_size)
test_generator = CustomDataGenerator(data= data, mode='test', batch_size=batch_size)

# train, validation, test로 각각 모드 지정해 data leakage 방지


train_loader = DataLoader(train_gen, batch_size=None, shuffle=True)
valid_loader = DataLoader(valid_gen, batch_size=None, shuffle=True)
test_loader = DataLoader(test_gen, batch_size=None, shuffle=False)

# batch_size는 지정해주었으니 False로 잡고 학습, 검증은 shuffle 진행