잡동사니 블로그
SMOTE를 활용한 Over sampling(오버샘플링) 본문
내가 처음으로 기업의 데이터로 프로젝트를 하며 정상과 불량의 데이터 수 차이가 나는 불균형 데이터(imbalanced data)였기에 이 때 쓰는 방법인 오버샘플링(Over sampling)을 쓰며...
주로 분류(classification)에서 다수의 데이터와 소수의 데이터 차이가 많이나면 모델의 정확도가 떨어지기 때문에 이를 해결하기 위해 언더샘플링(Under sampling)과 오버샘플링(Over samplinig)이 있는데, 프로젝트를 진행하며 여러 논문을 본 결과 언더샘플링의 경우 데이터의 손실이 있어 정확성이 많이 떨어져 주로 오버샘플링을 이용한다고 한다.
즉 , 신용카드 사기와 같은 불균형 데이터에서 모델의 정확도를 올리기 위하여 오버샘플링을 한다.
데이터셋은 프로젝트에서 활용하였던 데이터셋으로 한번 해보았다.
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
import plotly.express as px
from catboost import CatBoostClassifier
from sklearn.preprocessing import RobustScaler
from imblearn.under_sampling import TomekLinks
from imblearn.over_sampling import *
from imblearn.combine import *
from imblearn.over_sampling import SVMSMOTE
from imblearn.over_sampling._smote.filter import BorderlineSMOTE
# 평가 함수 정의
def get_clf_eval(y_test, y_pred):
confusion = confusion_matrix(y_test, y_pred)
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
F1 = f1_score(y_test, y_pred)
AUC = roc_auc_score(y_test, y_pred)
print('/오차행렬:\n', confusion)
print('\n정확도: {:.4f}'.format(accuracy))
print('정밀도: {:.4f}'.format(precision))
print('재현율: {:.4f}'.format(recall))
print('F1: {:.4f}'.format(F1))
print('AUC: {:.4f}'.format(AUC))
X_train, X_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2, random_state=128, stratify = y_data)
print(X_train.shape, X_test.shape)
print(y_train.shape, y_test.shape)
#(7678, 357) (1920, 357)
#(7678, 1) (1920, 1)
print(y_data.value_counts())
#0 9088
#1 510
#dtype: int64
불량이 9088개 양품이 510개 였는데 분포를 살펴보면
df_smote = pd.concat([X_train,y_train],axis=1)
fig = px.scatter(df_smote , y='x10', x='x2', color='불량명')
fig.update_layout(
title=dict(
text='<b>After Oversampling</b>',
x=0.5,
font=dict(
family='NanumBarunGothic',
size=25,
color="#000000"
)
),
xaxis_title=dict(
text="X"
),
yaxis_title="Y",
font=dict(
family='NanumBarunGothic',
size=12,
color="#000000"
),
)
임의의 feature를 뽑아 산점도를 그려보니 정상 데이터에 비해 비정상이 월등히 많은걸 볼 수 있다.
파편화된 데이터이므로 이상치의 영향을 적게 받는 RobustScaler 사용
robust=RobustScaler().fit(X_train)
X_train = robust.transform(X_train)
X_test = robust.transform(X_test)
오버샘플링을 하지 않은채 Catboost모델을 돌렸더니
cat = CatBoostClassifier(
loss_function="Logloss", #손실함수
eval_metric="F1", #f1으로 성능측정
learning_rate = 0.02, #학습률
depth=8, #깊이
iterations = 150, #반복횟수+
task_type='GPU' # GPU 사용
)
evals = [(X_test, y_test)]
cat.fit(X_train, y_train, early_stopping_rounds=150, eval_set=evals, verbose=True)
pred = cat.predict(X_test)
pred_proba = cat.predict_proba(X_test)
pred_th = [ 1 if x > 0.5 else 0 for x in pred_proba[:,1]]
get_clf_eval(y_test,pred_th)
print('train_score :',cat.score(X_train,y_train),'test_score :',cat.score(X_test,y_test))
#오차행렬:
# [[1817 1]
# [ 80 22]]
#정확도: 0.9578
#정밀도: 0.9565
#재현율: 0.2157
#F1: 0.3520
#AUC: 0.6076
#train_score : 0.9622297473300339 test_score : 0.9578125
역시나 불량을 못맞춘다.
프로젝트를 하며 알게되었던 오버샘플링 방법들로 위와 같이 하이퍼파라미터를 그대로 냅둔채 돌려보았다.
산점도가 비슷해보이지만 막상 보면 X축 범위도 다르고 모양도 다름.
SMOTE
소수데이터를 중심으로 최근접 이웃(k-Nearest Neighbor)을 이용해 새로운 데이터를 생성하는 방법
smote = SMOTE(random_state=1, n_jobs=-1)
X_train, y_train = smote.fit_resample(X_train,y_train)
#오차행렬:
# [[1702 116]
# [ 21 81]]
#정확도: 0.9286
#정밀도: 0.4112
#재현율: 0.7941
#F1: 0.5418
#AUC: 0.8652
#train_score : 0.9629298486932599 test_score : 0.9286458333333333
SOMTEENN
오버 샘플링 방법인 SMOTE를 우선 적용하여, 소수 데이터 수를 다수 데이터 수로 데이터 균형을 이루고, 언더 샘플링인 ENN을 적용하는 방법
enn = SMOTEENN(random_state=1)
X_train, y_train = enn.fit_resample(X_train,y_train)
#/오차행렬:
# [[1641 177]
# [ 11 91]]
#정확도: 0.9021
#정밀도: 0.3396
#재현율: 0.8922
#F1: 0.4919
#AUC: 0.8974
#train_score : 0.9732961631575087 test_score : 0.9020833333333333
SOMTETomek
오버 샘플링 방법인 SMOTE를 우선 적용하여 소수 범주의 수를 다수 범주의 수로 데이터 균형을 이루고, 언더 샘플링인 Tomek link를 적용 하는 방법
smoteto = SMOTETomek(tomek=TomekLinks(sampling_strategy='auto'),random_state=1)
X_train, y_train = smoteto.fit_resample(X_train,y_train)
#/오차행렬:
# [[1702 116]
# [ 20 82]]
#정확도: 0.9292
#정밀도: 0.4141
#재현율: 0.8039
#F1: 0.5467
#AUC: 0.8701
#train_score : 0.9656026417171161 test_score : 0.9291666666666667
BorderlineSMOTE
SMOTE를 보완한 방법으로 소수 범주와 다수 범주를 구분하는 경계에 서 소수 범주의 데이터를 대상으로 k개의 가장 가까운 이웃 사이에 데이터를 생성하는 방법
bord = BorderlineSMOTE(random_state=1, n_jobs=-1)
X_train, y_train = bord.fit_resample(X_train,y_train)
#/오차행렬:
# [[1702 116]
# [ 17 85]]
#정확도: 0.9307
#정밀도: 0.4229
#재현율: 0.8333
#F1: 0.5611
#AUC: 0.8848
#train_score : 0.969050894085282 test_score : 0.9307291666666667
ADASYN
SMOTE를 발전시킨 방법으로 소수 범주의 데이터 경계에 있는 표본을 생성하고 소수 범주의 경계에 있는 데이터를 대상으로 가장 가까운 밀도 분포를 이용해 소수 범주의 데이터 수를 조금 더 생성해야 하는지를 조절하는 방법
ad = ADASYN(random_state=1)
X_train, y_train = ad.fit_resample(X_train,y_train)
#/오차행렬:
# [[1686 132]
# [ 20 82]]
#정확도: 0.9208
#정밀도: 0.3832
#재현율: 0.8039
#F1: 0.5190
#AUC: 0.8657
#train_score : 0.9636501065072494 test_score : 0.9208333333333333
SVMSMOTE
회귀나 분류에서 쓰이는 SVM(Support Vector Machine)을 이용하여 SMOTE에서 변형된 방법
svm = SVMSMOTE(random_state=1, n_jobs=-1)
X_train, y_train = svm.fit_resample(X_train,y_train)
#/오차행렬:
# [[1707 111]
# [ 19 83]]
#정확도: 0.9323
#정밀도: 0.4278
#재현율: 0.8137
#F1: 0.5608
#AUC: 0.8763
#train_score : 0.9716643741403026 test_score : 0.9322916666666666
결론 :
이번 데이터셋으로는 BorderlineSMOTE가 결과가 F1 score가 젤 잘나왔었지만 프로젝트 발표 하면서 데이터 행(rows)과 열(columns)이 바뀌게 되면서 발표를 할 땐 다른 오버샘플링의 파라미터를 건드려봐도 SMOTETomek의 결과가 젤 잘나와서 이걸 활용하여 발표하게 되었다.
내 데이터셋으론 SMOTEENN쪽은 거의 과적합이었다.
Scaler하고 나서 오버샘플링과 오버샘플링을 하고나서 Scaler의 결과가 다르니 유의할것.
틀린점이나 수정할점 있으면 댓글로 남겨주세요!
참조 :
최형규. "불균형 데이터 분류를 위한 오버 샘플링 및 언더 샘플링 조합 방법." 국내석사학위논문 숭실대학교 정보과학대학원, 2020. 서울
Catboost 하이퍼파라미터 참조 : https://catboost.ai/en/docs/concepts/loss-functions-classification
'Python' 카테고리의 다른 글
[Python] pytorch와 sklearn의 train_test_split 활용하여 데이터 셋 나누기와 간단한 CNN (0) | 2023.09.07 |
---|---|
[Python] Selenium과 bs4를 이용한 크롤링 + Pyautogui (0) | 2023.08.28 |
[Python] 시각화에 주로 쓰이는 라이브러리 3가지 (0) | 2023.08.26 |
T-SNE 차원 축소 시각화 (0) | 2022.09.08 |
[Kaggle] 신용카드 사기 분류(Credit Card Fraud Detection) (1) | 2022.09.02 |