인공지능

생선 분류하기

kchabin 2024. 10. 16. 11:40

Fish-market dataset

생선 중 도미 분류하기

  1. 30cm 이상 크다면 도미라고 판단하는 프로그램 작성, 30cm 이하 도미가 있다면 잘못 판단
  2. 머신러닝을 이용하면 가지고 있는 도미의 정보를 이용해서 도미를 판단. 다른 종류의 생선들과 섞여 있어도 판단 가능함.

도미 데이터 준비하기

  • 도미(bream)와 빙어(smelt) 구분하기 - 길이, 무게 feature 선정
  • class : 머신러닝 데이터셋의 종류
  • classification : 머신러닝에서 분류
  • binary classification : 이진 분류
## 도미 30, 빙어 14
bream_length = [25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0,
                31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0,
                35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0]
bream_weight = [242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0,
                500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0,
                700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0]
smelt_length = [9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2,
                12.4, 13.0, 14.3, 15.0]
smelt_weight = [6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8,
                12.2, 13.4, 12.2, 19.7, 19.9]

 

산점도 그래프(scatter plot)

  • x축, y축으로 점(관계표현)으로 표현하는 그래프Matplotlib
  • 과학계산용 그래프를 그려주는 파이썬 라이브러리
  • scatter() 함수 제공
import matplotlib.pyplot as plt

#도미 산점도 그래프 그리기
plt.scatter(bream_length, bream_weight)


plt.xlabel('length')
plt.ylabel('weight')
plt.show()

도미(좌), 빙어(우) 산점도 그래프

도미와 빙어 산점도 그래프 같이 그리기

import matplotlib.pyplot as plt

plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)


plt.xlabel('length')
plt.ylabel('weight')
plt.show()

오렌지 빙어, 블루 도미

k-최근접 이웃(k-nearest neighbors)

  • 대표적인 분류 알고리즘
  • KNeighborsClassifier
  • 비슷한 특징을 갖는 데이터는 비슷한 범주에 속함
    • 모델 생성: 훈련 데이터 셋 저장
    • 분류기준 : 근접한 요소와 거리 측정(유클리드 거리)을 통해 최단거리가 k개에 해당하는 요소로 분류
  • k : 항상 분류가 가능하도록 홀수로 설정(동률방지), 총 데이터 수의 제곱근 값을 사용
  • 장점 : 구현하기 쉬운 단순함, 훈련 없이 빠른 사용
  • 단점 : 특징과 클래스 간 관계를 이해하는데 제한적, 변수와 클래스 간의 관계에 의존적
  • 사용분야 예시 : 이미지 처리, 글자/얼굴 인식, 추천 알고리즘, 의료 분야 등
# 데이터셋 통합(도미 + 빙어)
length = bream_length + smelt_length
weight = bream_weight + smelt_weight

print(length)
print(weight)

 

도미와 빙어의 길이

[25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7, 31.0, 31.0, 31.5, 32.0, 32.0, 32.0, 33.0, 33.0, 33.5, 33.5, 34.0, 34.0, 34.5, 35.0, 35.0, 35.0, 35.0, 36.0, 36.0, 37.0, 38.5, 38.5, 39.5, 41.0, 41.0, 9.8, 10.5, 10.6, 11.0, 11.2, 11.3, 11.8, 11.8, 12.0, 12.2, 12.4, 13.0, 14.3, 15.0]

 

도미와 빙어의 무게

[242.0, 290.0, 340.0, 363.0, 430.0, 450.0, 500.0, 390.0, 450.0, 500.0, 475.0, 500.0, 500.0, 340.0, 600.0, 600.0, 700.0, 700.0, 610.0, 650.0, 575.0, 685.0, 620.0, 680.0, 700.0, 725.0, 720.0, 714.0, 850.0, 1000.0, 920.0, 955.0, 925.0, 975.0, 950.0, 6.7, 7.5, 7.0, 9.7, 9.8, 8.7, 10.0, 9.9, 9.8, 12.2, 13.4, 12.2, 19.7, 19.9]

 

#좌표 (x, y) 추출
fish_data = [[l, w] for l, w in zip(length, weight)]
print(fish_data)

x는 길이, y는 무게로 좌표를 추출한다.

#각 요소에 대해 도미(1)와 빙어(0)로 라벨링
fish_target = [1]*35 + [0]*14

print(fish_target)

 

도미는 1, 빙어는 0으로 라벨링한 fish_target 리스트를 만든다.

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
 

k-최근접 이웃 알고리즘 선언 및 기본사용

Scikit-learn

  • 데이터 분석 예측
  • 다양한 컨텍스트에서 재사용 가능
  • NumPy, SciPy, matplotlib 기반 구축
## k-최근접이웃 알고리즘 사용
from sklearn.neighbors import KNeighborsClassifier

kn = KNeighborsClassifier()
#훈련 메소드 : kn 모델에 요소(fish_data), 레이블(fish_target)을 지정하여 훈련
kn.fit(fish_data, fish_target)
# 평가 메소드 : kn 모델에 요소(fish_data), 레이블(fish_target)을 지정하여 평가함
kn.score(fish_data, fish_target)

 

출력되는 1.0은 accuracy를 의미한다.

fit() : 훈련 메소드

score() : 평가 메소드

 

평가 TEST

  • predict() : 새로운 정답 예측
## 산점도 그래프 출력 - 도미와 빙어 구별
plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)

## 특정요소 (30, 600) 출력
plt.scatter(30, 600, marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

## 값이 (30, 600)일 때 kn 모델의 예측
kn.predict([[30, 600]])

 

길이가 30이고 무게가 600일 때 kn 모델의 예측을 predict() 메소드를 사용해서 구할 수 있다.

array([1]) <- 도미(1)로 예측함.

그래프 상에서도 30, 600을 표시한 세모 마커가 도미 쪽 데이터에 가까이 위치하고 있는 것을 확인할 수 있다.

KNeighborsClassifier 속성

  • _fit_X : 모델의 첫번째 parameter, 학습된 데이터 값(입력된 물고기의 길이와 무게 데이터)
  • _y : 모델의 parameter 2, 학습된 레이블 값(도미(1), 빙어(0)으로 구분하는 이진 분류 데이터)
  • n_neighbors : k로 k개 만큼 가까운 데이터 참고, 기본값=5
## 모델에 학습된 데이터와 레이블 값
print(kn._fit_X)
print(kn._y)

## 모든 데이터셋 개수인 49개만큼 가까운 데이터를 참고하는 모델 생성
kn49 = KNeighborsClassifier(n_neighbors=49)

## 모델학습 및 평가(정확도 출력)
kn49.fit(fish_data, fish_target)
kn49.score(fish_data, fish_target)

## 도미 35개에 대해서만 정확히 맞추기 때문에 35/49와 같은 정확도를 보임
print(35/49)

 

새로운 데이터에 대해 예측할 때 전체 49개의 데이터를 모두 참고하여 그 중 다수의 레이블을 기준으로 예측 결과를 결정하게 된다. 

도미가 35개, 빙어가 15개인 상태 -> 49개의 이웃 중 도미가 더 많기 떄문에 도미로 예측될 가능성이 크다. => 데이터의 편향

도미 35개에 대해서만 정확하게 예측하여 35/49의 정확도를 보이게 된다.

[[  25.4  242. ]
 [  26.3  290. ]
 [  26.5  340. ]
 [  29.   363. ]
 [  29.   430. ]
 [  29.7  450. ]
 [  29.7  500. ]
 [  30.   390. ]
 [  30.   450. ]
 [  30.7  500. ]
 [  31.   475. ]
 [  31.   500. ]
 [  31.5  500. ]
 [  32.   340. ]
 [  32.   600. ]
 [  32.   600. ]
 [  33.   700. ]
 [  33.   700. ]
 [  33.5  610. ]
 [  33.5  650. ]
 [  34.   575. ]
 [  34.   685. ]
 [  34.5  620. ]
 [  35.   680. ]
 [  35.   700. ]
 [  35.   725. ]
 [  35.   720. ]
 [  36.   714. ]
 [  36.   850. ]
 [  37.  1000. ]
 [  38.5  920. ]
 [  38.5  955. ]
 [  39.5  925. ]
 [  41.   975. ]
 [  41.   950. ]
 [   9.8    6.7]
 [  10.5    7.5]
 [  10.6    7. ]
 [  11.     9.7]
 [  11.2    9.8]
 [  11.3    8.7]
 [  11.8   10. ]
 [  11.8    9.9]
 [  12.     9.8]
 [  12.2   12.2]
 [  12.4   13.4]
 [  13.    12.2]
 [  14.3   19.7]
 [  15.    19.9]]
[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0
 0 0 0 0 0 0 0 0 0 0 0 0]
0.7142857142857143

 

(15, 80) 예측

길이가 15, 무게가 80인 데이터를 예측해보자

plt.scatter(bream_length, bream_weight)
plt.scatter(smelt_length, smelt_weight)
plt.scatter(15, 80, marker='^')
plt.xlabel('length')
plt.ylabel('weight')
plt.show()

kn.predict([[15, 80]])

빙어(0)로 예측