Pythonとscikit-learnで機械学習に必要なデータを作成する

eyecatch
  • URLをコピーしました!

※本ページはアフィリエイト広告を利用しています

今回は機械学習ライブラリscikit-learnを使用して機械学習に必要なデータを作成します。

また、本記事は「Kerasでディープラーニング!Pythonで始める機械学習入」シリーズの手順を解説するページです。シリーズの一覧は以下をご覧ください。

\ 機械学習を学びたい人には自宅で学べるUdemyがおすすめ! /

講座単位で購入できます!

目次

データ作成の準備

これまで画像の収集から読み込み、正規化までの手順を解説してきました。
本記事の内容に入る前に以下の記事をあわせてご覧ください。

著:Sebastian Raschka, 著:Vahid Mirjalili, 著:株式会社クイープ, 著:福島 真太朗
¥3,960 (2025/01/07 07:22時点 | Amazon調べ)

学習データとテストデータの作成 train_test_split

スクレイピングで集めた画像データと作成したラベルをもとに、学習データとテストデータを作成します。データの作成にはscikit-learnのtrain_test_splitを使用します。

以下のように第1引数に画像データ、第2引数にラベル、第3引数にtest_sizeを指定します。

x_train, x_test, y_train, y_test = train_test_split(image_list, label_list, test_size=0.2)

test_size

test_sizeにはint型とfloat型の2通りの設定が可能です。
int型の場合はテストデータ数を指定します。float型の場合はテストデータと学習データの比率を指定します。

今回はfloat型で0.3と指定しましたので、全体の100枚の画像データのうち30枚がテストデータ、70枚が学習データとして分類されます。

random_state

random_state=1

「1」を設定した場合は毎回同じデータが選ばれます。

今回のように指定しない場合はランダムなデータが選ばれます。

ラベルをダミーデータに変換する

今回使用するラベルの情報は文字列で「stag_beetle」(クワガタ)と「mantis」(カマキリ)という値が保存されています。
ラベルの情報を機械学習で使用する場合、そのまま文字列データでは使用できないため、これらを数値で表すダミーデータに変換します。

今回ラベルで扱うデータは次の2種類。

          type
0  stag_beetle
1       mantis

それを以下のように変換します。

   type_mantis  type_stag_beetle
0            0                 1
1            1                 0

クワガタの場合は1行目の「0 1」、カマキリの場合は2行目の「1 0」というデータに置き換えていきます。

変換処理の実装

変換にはデータ解析ライブラリpandasを使用します。

columns = ['type'] #列名を指定
df_train = pd.DataFrame(y_train,  columns=columns)
df_test = pd.DataFrame(y_test, columns=columns)

y_train = pd.get_dummies(df_train)
y_test = pd.get_dummies(df_test)

変換前のラベルデータは以下となります。

           type
0        mantis
1        mantis
2   stag_beetle
3   stag_beetle
4        mantis
..          ...
75       mantis
76       mantis
77  stag_beetle
78       mantis
79  stag_beetle

[80 rows x 1 columns]

get_dummiesメソッドを実行してダミーデータに変更すると、先ほど説明したパターンの値に正しく変更されていることが確認できます。

    type_mantis  type_stag_beetle
0             1                 0
1             1                 0
2             0                 1
3             0                 1
4             1                 0
..          ...               ...
75            1                 0
76            1                 0
77            0                 1
78            1                 0
79            0                 1

[80 rows x 2 columns]

画像データを配列に変換する

読み込んだ画像データはlist形式となっています。そのままでは機械学習に使用できないため配列に変換します。
配列への変換はNumPyモジュールのarrayメソッドを使用します。

#リスト型を配列型に変換
x_train = np.array(x_train)
x_test = np.array(x_test)

x_train = x_train.reshape(80, 10800) #60×60×3(RGB)
x_test = x_test.reshape(20, 10800)

画像一枚のデータは60(pixel)×60(pixel)×3(RGB)のサイズとなります。
reshapeメソッドで形状を指定して変換します。

著:松尾 豊
¥891 (2023/09/11 22:12時点 | Amazon調べ)

作成したソースコード

今回作成した全体の処理は以下となります。

from icrawler.builtin import BingImageCrawler
import glob
import cv2
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

#クワガタの画像を収集
crawler = BingImageCrawler(storage={"root_dir":"images"}) #ダウンロード先のディレクトリを指定
crawler.crawl(keyword="クワガタ", max_num=50) #クロール実行

#カマキリの画像を収集
crawler = BingImageCrawler(storage={"root_dir":"images2"}) #ダウンロード先のディレクトリを指定
crawler.crawl(keyword="カマキリ", max_num=50) #クロール実行

#全ての画像ファイルのパスを取得する
files1 = glob.glob("images/*.jpg")
files2 = glob.glob("images2/*.jpg")

files1[len(files1):len(files2)] = files2

#画像データを格納するりすと
image_list = []

#ファイルパスから画像を読み込み
for imgpath in files1:
  image = cv2.imread(imgpath) #画像を読み込み
  #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) #RGBモードに変換
  image = cv2.resize(image, (60, 60)) #画像のサイズを変更

  #正規化
  image = image / 255 #[0~1]にスケーリング
  #image = (image - 127.5) / 127.5 #[-1~1]にスケーリング

  #画像をリストに追加
  image_list.append(image)

#ラベルを作成する
label_list = []
label_list2 = []

#ラベル情報を50ずつ作成
for i in range(50):
    label_list.append("stag_beetle")

for i in range(50):
    label_list2.append("mantis")

#ラベルを統合
label_list[len(label_list):len(label_list2)] = label_list2

#学習データとテストデータを作成
x_train, x_test, y_train, y_test = train_test_split(image_list, label_list, test_size=0.2) #20件をテストデータ、それ以外を学習データに分ける


#テーブルデータを作成
columns = ['type'] #列名を指定
df_train = pd.DataFrame(y_train,  columns=columns)
df_test = pd.DataFrame(y_test, columns=columns)
print("--TableData--")
print("train df = \n", df_train)
print("test df = \n", df_test)

# 文字列(カテゴリ変数)をダミー変数に変換
y_train = pd.get_dummies(df_train)
y_test = pd.get_dummies(df_test)
print("--dummies--")
print("train dummies = \n", y_train)
print("test dummies = \n", y_test)

#リスト型を配列型に変換
x_train = np.array(x_train)
x_test = np.array(x_test)

x_train = x_train.reshape(80, 10800) #60×60×3(RGB)
x_test = x_test.reshape(20, 10800)

print("xtrain array =\n", x_train)
print("xtest array =\n", x_test)

実行結果

実行結果は以下となります。

--TableData--
train df =
            type
0   stag_beetle
1        mantis
2   stag_beetle
3        mantis
4   stag_beetle
..          ...
75       mantis
76  stag_beetle
77  stag_beetle
78  stag_beetle
79       mantis

[80 rows x 1 columns]
test df =
            type
0        mantis
1   stag_beetle
2   stag_beetle
3        mantis
4        mantis
5        mantis
6   stag_beetle
7        mantis
8   stag_beetle
9   stag_beetle
10       mantis
11       mantis
12  stag_beetle
13       mantis
14  stag_beetle
15       mantis
16       mantis
17       mantis
18  stag_beetle
19       mantis
--dummies--
train dummies =
     type_mantis  type_stag_beetle
0             0                 1
1             1                 0
2             0                 1
3             1                 0
4             0                 1
..          ...               ...
75            1                 0
76            0                 1
77            0                 1
78            0                 1
79            1                 0

[80 rows x 2 columns]
test dummies =
     type_mantis  type_stag_beetle
0             1                 0
1             0                 1
2             0                 1
3             1                 0
4             1                 0
5             1                 0
6             0                 1
7             1                 0
8             0                 1
9             0                 1
10            1                 0
11            1                 0
12            0                 1
13            1                 0
14            0                 1
15            1                 0
16            1                 0
17            1                 0
18            0                 1
19            1                 0
xtrain array =
 [[0.59215686 0.6627451  0.75294118 ... 0.55686275 0.62352941 0.70588235]
 [0.         0.         0.         ... 0.         0.         0.        ]
 [0.6        0.5372549  0.42745098 ... 0.18823529 0.37647059 0.60784314]
 ...
 [0.43921569 0.55686275 0.57647059 ... 0.57647059 0.7254902  0.73333333]
 [0.51372549 0.59607843 0.65098039 ... 0.2745098  0.36078431 0.45490196]
 [0.40784314 0.49411765 0.51372549 ... 0.30980392 0.38039216 0.40784314]]
xtest array =
 [[0.10980392 0.34901961 0.29411765 ... 0.11372549 0.25882353 0.22745098]
 [0.16078431 0.25882353 0.3372549  ... 0.41176471 0.30980392 1.        ]
 [0.53333333 0.59607843 0.7372549  ... 0.74117647 0.7254902  0.76862745]
 ...
 [0.02352941 0.20784314 0.22352941 ... 0.44313725 0.69019608 0.65490196]
 [0.78823529 0.83529412 0.90588235 ... 0.75686275 0.83529412 0.90196078]
 [0.5372549  0.40392157 0.03137255 ... 0.56078431 0.44313725 0.11764706]]

機械学習に必要なデータとラベルが作成できました。

まとめ

今回は機械学習に必要なデータの前処理について解説しました。これで必要なデータが揃いましたので次回からは実際に学習、推論を実行するためのプログラムを実装していきたいと思います。

機械学習を効率よく学びたい方には、自分のペースで動画で学べるUdemyの以下の講座がおすすめです。数学的な理論からPythonでの実装までを習得できます。(私自身もこの講座を受講しています)

icon icon 【徹底的に解説!】人工知能・機械学習エンジニア養成講座(初級編~統計学から数字認識まで~) icon

また、以下の記事で効率的にPythonのプログラミングスキルを学べるプログラミングスクールの選び方について解説しています。最近ではほとんどのスクールがオンラインで授業を受けられるようになり、仕事をしながらでも自宅で自分のペースで学習できるようになりました。

スキルアップや副業にぜひ活用してみてください。

スクールではなく、自分でPythonを習得したい方には、いつでもどこでも学べる動画学習プラットフォームのUdemyがおすすめです。

講座単位で購入できるため、スクールに比べ非常に安価(セール時1200円程度~)に学ぶことができます。私も受講しているおすすめの講座を以下の記事でまとめていますので、ぜひ参考にしてみてください。

それでは、また次の記事でお会いしましょう。

著:須藤秋良, 監修:株式会社フレアリンク
¥3,300 (2023/09/18 22:18時点 | Amazon調べ)

参考

よかったらシェアしてね!
  • URLをコピーしました!

コメント

コメントする

CAPTCHA


目次