[혼자 공부하는 머신러닝+딥러닝] 5강. 정교한 결과 도출을 위한 데이터 전처리 알아보기 - 표준화 정리
* 해당 수업을 이해하는데 기반이 되는 표준 정규분포와 표준화에 대한 정리는 다른 포스팅에서 다루기로 함
- 사이킷런으로 데이터 나누기
- train_test_split을 사용하면 한쪽에 편중되어 있는 데이터들이라 하더라도 골고루 섞이게 만든 훈련 세트 테스트 세트들을 쉽게 만들 수 있다.
- train_test_split(fish_data, fish_target, stratify=fish_target, random_state=42)
- stratify는 데이터가 편중되어 있을 때, 골고루 섞이게 만들기 위한 기준을 설정하는 것이다
- random_state는 랜덤 시드와 같다
- train_test_split를 실행하면 train_input, test_input, train_target, test_target 이렇게 4개가 리턴된다
- 표준화를 하기 전 데이터를 실행하여 특정한 값을 입력하고 이에 영향을 준 이웃 데이터를 확인하면 다음과 같이 출력된다.
plt.scatter(train_input[:,0], train_input[:,1])
plt.scatter(25,150,marker='^')
plt.scatter(fish_data[indexes,0], fish_data[indexes,1],marker='D')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
- 입력 값 : kn.kneighbors([[25,150]])
- 분산표를 보면 이웃 데이터의 산출이 잘못된 것으로 보인다.
- 문제는 1000:40 이라는 스케일 때문에 정확한 판단을 하기 어렵다
- 스케일을 1000:1000으로 맞추어 확인해 보자
plt.scatter(train_input[:,0], train_input[:,1])
plt.scatter(25,150,marker='^')
plt.scatter(fish_data[indexes,0], fish_data[indexes,1],marker='D')
plt.xlim((0,1000))
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
- 입력 값에 큰 영향을 주는 것은 길이가 아니라 무게인 것을 확인할 수 있다.
- 현재 그래프로는 이웃된 데이터의 거리를 확인하기 어렵다. 표준화를 거쳐 보기 쉽게 만들자
- 표준화 구하기 : (특성 - 평균)/표준편차
* 중요한 사항으로 표준점수는 훈련세트의 데이터를 테스트 세트에 적용해야 한다 ( 평균, 표준편차 )
mean = np.mean(train_input, axis=0)
std = np.std(train_input, axis=0)
train_scaled = (train_input - mean) / std
new = ([25,150]-mean)/std
plt.scatter(train_scaled[:,0], train_scaled[:,1])
plt.scatter(new[0],new[1],marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
- 표준화를 통해 1.5:1.5 스케일의 분산표가 만들어 졌다
- 표준화된 훈련 데이터를 타겟 데이터를 이용해 fit으로 훈련을 다시 하고 테스트 데이터도 표준화를 한다
- 이후 입력 데이터(new)를 표준화 하고 이웃에 대한 데이터를 출력해 본다
kn.fit(train_scaled, train_target)
test_scaled = (test_input - mean) / std
kn.score(test_scaled, test_target)
distances, indexes = kn.kneighbors([new])
plt.scatter(train_scaled[:,0], train_scaled[:,1])
plt.scatter(new[0],new[1],marker='^')
plt.scatter(train_scaled[indexes,0],train_scaled[indexes,1],marker='D')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()
- reference :
https://speedanddirection.tistory.com/71
댓글