分析で利用するテーブルデータが少ないことがあります。
もう少し増やせないだろうか、と夢見ることも少なくないでしょう。
ここ最近、色々な生成AIが登場してきました。
そこで使われている技術の1つにGAN(敵対的生成ネットワーク)というものがあります。あらかじめ準備されたデータをもとに、擬似的なデータを生成することができます。
テーブルデータを生成するGANもあります。CTGAN(Conditional General Adversarial Networks)です。
前回、Pythonを使いテーブルデータ生成AI CTGANで、簡単な例で使い方を説明しました。
CTGAN(Conditional General Adversarial Networks)で、手軽にテーブルデータを生成することができますが、無邪気に生成すると、ちょっと変なデータが生成されることがあります。
例えば、データセットの主キーです。
主キーは、データベースのテーブルの行を一意に識別するために使用されるキーで、重複する値を持たず、常に一意である必要があります。
無邪気に生成すると、主キーが一意でないデータセットが生成される危険性があります。
今回は、主キーを指定しテーブルデータ生成AI CTGANで、テーブルデータを生成する方法を説明します。
PythonのSDVパッケージを使いますので、インストールされていないかたは、前回の記事を参考にインストールしてください。
Contents
利用するデータセット
今回利用する データセットは、大学の学生の就職状況に関するデータ(Campus Recruitment)です。
以下からもダウンロードできます。
Placement_Data_Full_Class.csv
https://www.salesanalytics.co.jp/su1r
このデータセットは、学生の学業成績、科目、仕事の経験、専門分野など、学生の就職に影響を与える要因や、学生の就職を予測するモデルを構築するために使用されたりします。
以下の、変数sl_noを主キーとするデータセットです。
- sl_no: 学生番号 ※主キー
- gender: 学生の性別
- ssc_p: 義務教育10年目の試験の得点率
- ssc_b: 義務教育10年目の試験を実施する委員会
- hsc_p:義務教育12年目の試験の得点率
- hsc_b: 義務教育12年目の試験を実施する委員会
- hsc_s:義務教育12年目の試験の専攻
- degree_p: 工学学位の試験で得た得点の割合
- degree_t: 工学学位の専攻
- workex: 学生の職務経験の有無
- etest_p: 入学試験における学生の得点率
- specialisation: 学生の専門分野
- mba_p: MBA試験における学生の得点率
- status: 学生が就職したかどうか(Placed:内定した、Not Placed:内定しなかった)
- salary: 学生に提示された給与
このデータセットには、次の2つの制約もしくはルールがあります。
- sl_no(学生番号)は、レコードごとに異なり重複しない
- salary(提示された給与)は、statusがPlaced(内定)のときのみデータがある
今回は、sl_no(学生番号)が重複しないように、テーブルデータ生成することを考えます。もう一方は、次回扱います。
必要なモジュールの読み込み
必要なモジュールを読み込みます。
以下、コードです。
# 基本モジュール import pandas as pd import numpy as np # SVD from sdv.single_table import CTGANSynthesizer from sdv.metadata import SingleTableMetadata import warnings warnings.simplefilter('ignore')
データセット読み込み
データセットを読み込みます。
以下、コードです。
dataset = 'Placement_Data_Full_Class.csv' real_data = pd.read_csv(dataset) print(real_data)
以下、実行結果です。
sl_no gender ssc_p ssc_b hsc_p hsc_b hsc_s degree_p \ 0 1 M 67.00 Others 91.00 Others Commerce 58.00 1 2 M 79.33 Central 78.33 Others Science 77.48 2 3 M 65.00 Central 68.00 Central Arts 64.00 3 4 M 56.00 Central 52.00 Central Science 52.00 4 5 M 85.80 Central 73.60 Central Commerce 73.30 .. ... ... ... ... ... ... ... ... 210 211 M 80.60 Others 82.00 Others Commerce 77.60 211 212 M 58.00 Others 60.00 Others Science 72.00 212 213 M 67.00 Others 67.00 Others Commerce 73.00 213 214 F 74.00 Others 66.00 Others Commerce 58.00 214 215 M 62.00 Central 58.00 Others Science 53.00 degree_t workex etest_p specialisation mba_p status salary 0 Sci&Tech No 55.0 Mkt&HR 58.80 Placed 270000.0 1 Sci&Tech Yes 86.5 Mkt&Fin 66.28 Placed 200000.0 2 Comm&Mgmt No 75.0 Mkt&Fin 57.80 Placed 250000.0 3 Sci&Tech No 66.0 Mkt&HR 59.43 Not Placed NaN 4 Comm&Mgmt No 96.8 Mkt&Fin 55.50 Placed 425000.0 .. ... ... ... ... ... ... ... 210 Comm&Mgmt No 91.0 Mkt&Fin 74.49 Placed 400000.0 211 Sci&Tech No 74.0 Mkt&Fin 53.62 Placed 275000.0 212 Comm&Mgmt Yes 59.0 Mkt&Fin 69.72 Placed 295000.0 213 Comm&Mgmt No 70.0 Mkt&HR 60.23 Placed 204000.0 214 Comm&Mgmt No 89.0 Mkt&HR 60.22 Not Placed NaN [215 rows x 15 columns]
レコード数は215で、変数の数は15です。
データセットの特徴把握
読み込んだデータセットの情報を見てみます。
以下、コードです。
real_data.info()
以下、実行結果です。
<class 'pandas.core.frame.DataFrame'> RangeIndex: 215 entries, 0 to 214 Data columns (total 15 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 sl_no 215 non-null int64 1 gender 215 non-null object 2 ssc_p 215 non-null float64 3 ssc_b 215 non-null object 4 hsc_p 215 non-null float64 5 hsc_b 215 non-null object 6 hsc_s 215 non-null object 7 degree_p 215 non-null float64 8 degree_t 215 non-null object 9 workex 215 non-null object 10 etest_p 215 non-null float64 11 specialisation 215 non-null object 12 mba_p 215 non-null float64 13 status 215 non-null object 14 salary 148 non-null float64 dtypes: float64(6), int64(1), object(8) memory usage: 25.3+ KB
salaryを見ていただくと分かりますが、N-n-Null Count(欠測していないデータの数)が148と、215よりも少なく欠測値がある(67名の方が内定をもらっていない)ことが分かります。
sl_noが重複していいないかどうかを確認します。
以下、コードです。
real_data['sl_no'].duplicated()
以下、実行結果です。重複している行がTrueになっています。
0 False 1 False 2 False 3 False 4 False ... 210 False 211 False 212 False 213 False 214 False Name: sl_no, Length: 215, dtype: bool
Trueの数を数えます。
以下、コードです。
real_data['sl_no'].duplicated().sum()
以下、実行結果です。
0
0です。
当然ですが、学習で利用するこのデータセットは、sl_noが重複していません。
メタデータの取得
テーブルデータを生成するために、先程読み込んだデータセットからメタデータを取得します。
以下、コードです。
# データフレームからメタデータを自動抽出 metadata = SingleTableMetadata() metadata.detect_from_dataframe(real_data)
念のため、メタデータを見てみます。
以下、コードです。
metadata
以下、実行結果です。各変数の型が定義されています。
{ "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1", "columns": { "sl_no": { "sdtype": "numerical" }, "gender": { "sdtype": "categorical" }, "ssc_p": { "sdtype": "numerical" }, "ssc_b": { "sdtype": "categorical" }, "hsc_p": { "sdtype": "numerical" }, "hsc_b": { "sdtype": "categorical" }, "hsc_s": { "sdtype": "categorical" }, "degree_p": { "sdtype": "numerical" }, "degree_t": { "sdtype": "categorical" }, "workex": { "sdtype": "categorical" }, "etest_p": { "sdtype": "numerical" }, "specialisation": { "sdtype": "categorical" }, "mba_p": { "sdtype": "numerical" }, "status": { "sdtype": "categorical" }, "salary": { "sdtype": "numerical" } } }
データ生成その1(主Key設定なし)
主キーは、先程取得したメタデータを修正し指定します。
先ずは、メタデータを修正することなく、テーブルデータを生成していきます。
以下、コードです。
# インスタンス生成 ctgan = CTGANSynthesizer(metadata,epochs=10) # 学習 ctgan.fit(real_data) # データ生成 synthetic_data = ctgan.sample(20000) # 生成したデータを確認 print(synthetic_data)
以下、実行結果です。
sl_no gender ssc_p ssc_b hsc_p hsc_b hsc_s degree_p \ 0 215 M 86.91 Others 67.75 Central Commerce 50.00 1 168 M 62.49 Others 94.87 Central Science 51.76 2 215 F 89.40 Others 61.27 Central Commerce 59.31 3 155 F 86.97 Others 96.04 Others Science 78.99 4 151 M 85.66 Others 75.40 Others Commerce 65.19 ... ... ... ... ... ... ... ... ... 19995 164 F 88.70 Central 63.86 Others Commerce 54.74 19996 215 M 78.41 Central 63.03 Central Commerce 63.29 19997 198 M 76.26 Central 92.06 Central Science 50.00 19998 13 M 69.93 Central 86.26 Others Science 50.00 19999 215 M 58.20 Others 81.18 Others Commerce 55.25 degree_t workex etest_p specialisation mba_p status salary 0 Comm&Mgmt Yes 75.68 Mkt&Fin 56.84 Not Placed NaN 1 Comm&Mgmt No 70.20 Mkt&HR 59.94 Not Placed 200000.0 2 Others No 73.44 Mkt&Fin 52.34 Placed NaN 3 Others No 74.43 Mkt&Fin 55.45 Placed NaN 4 Sci&Tech Yes 64.69 Mkt&Fin 52.47 Not Placed NaN ... ... ... ... ... ... ... ... 19995 Comm&Mgmt Yes 93.29 Mkt&HR 54.85 Placed 306046.0 19996 Comm&Mgmt Yes 55.07 Mkt&Fin 68.71 Not Placed 317092.0 19997 Comm&Mgmt No 98.00 Mkt&HR 51.21 Not Placed 485863.0 19998 Sci&Tech Yes 64.21 Mkt&HR 51.21 Placed 268246.0 19999 Comm&Mgmt No 51.46 Mkt&HR 57.42 Placed 399696.0 [20000 rows x 15 columns]
生成したテーブルデータのsl_noが重複していいないかどうかを確認します。
以下、コードです。
synthetic_data['sl_no'].duplicated().sum()
以下、実行結果です。
19785
19785です。かなりsl_noが重複しています。
メタデータの修正(主キー設定)
メタデータを修正し主キーを設定します。
主キーに設定する変数は、例えば「id」型である必要があるため、先ず型変換を行い、次に主キー設定をします。
以下、コードです。
# 変数の型の変更 metadata.update_column( column_name='sl_no', sdtype='id') # 主Key設定 metadata.set_primary_key(column_name="sl_no")
メターデータを確認してみます。
以下、コードです。
metadata
以下、実行結果です。
{ "primary_key": "sl_no", "METADATA_SPEC_VERSION": "SINGLE_TABLE_V1", "columns": { "sl_no": { "sdtype": "id" }, "gender": { "sdtype": "categorical" }, "ssc_p": { "sdtype": "numerical" }, "ssc_b": { "sdtype": "categorical" }, "hsc_p": { "sdtype": "numerical" }, "hsc_b": { "sdtype": "categorical" }, "hsc_s": { "sdtype": "categorical" }, "degree_p": { "sdtype": "numerical" }, "degree_t": { "sdtype": "categorical" }, "workex": { "sdtype": "categorical" }, "etest_p": { "sdtype": "numerical" }, "specialisation": { "sdtype": "categorical" }, "mba_p": { "sdtype": "numerical" }, "status": { "sdtype": "categorical" }, "salary": { "sdtype": "numerical" } } }
primary_keyが主キーを表します。sl_noになっているのが分かるかと思います。
sl_noの型(sdtype)がidになっているのも分かるかと思います。
この状態で、テーブルデータを生成していきます。
データ生成その2(主キー設定あり)
sl_noを主キーに設定したメタデータを使い、テーブルデータを生成します。
以下、コードです。
# インスタンス生成 ctgan = CTGANSynthesizer(metadata,epochs=10) # 学習 ctgan.fit(real_data) # データ生成 synthetic_data = ctgan.sample(20000) # 生成したデータを確認 print(synthetic_data)
以下、実行結果です。
sl_no gender ssc_p ssc_b hsc_p hsc_b hsc_s degree_p \ 0 0 M 63.32 Central 51.18 Others Arts 78.93 1 1 F 65.35 Others 66.10 Central Science 67.23 2 2 F 40.89 Central 47.76 Central Science 67.20 3 3 F 40.89 Others 60.37 Others Commerce 89.05 4 4 F 56.19 Central 39.60 Central Commerce 63.17 ... ... ... ... ... ... ... ... ... 19995 19995 M 64.25 Central 43.54 Others Arts 77.72 19996 19996 F 76.01 Central 70.35 Others Commerce 71.96 19997 19997 F 72.59 Central 97.70 Others Arts 67.86 19998 19998 M 51.43 Others 56.94 Central Science 59.97 19999 19999 M 56.75 Central 52.63 Central Commerce 50.00 degree_t workex etest_p specialisation mba_p status salary 0 Sci&Tech Yes 65.37 Mkt&Fin 65.79 Placed 282434.0 1 Sci&Tech Yes 84.75 Mkt&HR 57.03 Placed 374237.0 2 Sci&Tech No 58.12 Mkt&Fin 51.21 Placed NaN 3 Comm&Mgmt Yes 82.79 Mkt&Fin 52.39 Placed 200000.0 4 Sci&Tech No 71.98 Mkt&HR 53.52 Not Placed NaN ... ... ... ... ... ... ... ... 19995 Sci&Tech Yes 92.09 Mkt&Fin 53.48 Placed NaN 19996 Comm&Mgmt Yes 82.98 Mkt&HR 55.71 Not Placed NaN 19997 Sci&Tech Yes 59.22 Mkt&HR 62.57 Placed 457765.0 19998 Sci&Tech Yes 75.93 Mkt&HR 51.21 Placed 293694.0 19999 Others No 50.00 Mkt&Fin 51.21 Placed 394105.0 [20000 rows x 15 columns]
生成したテーブルデータのsl_noが重複していいないかどうかを確認します。
以下、コードです。
synthetic_data['sl_no'].duplicated().sum()
以下、実行結果です。
0
0です。sl_noが重複していません。
ここで、もう1つの制約というかルールであった「salary(提示された給与)は、statusがPlaced(内定)のときのみデータがある」が保持されているかどうか見てみます。
先ず、statusとsalaryの変数だけ抜き出して見てみます。
以下、コードです。
print(synthetic_data.loc[:,['status','salary']])
以下、実行結果です。
status salary 0 Placed 282434.0 1 Placed 374237.0 2 Placed NaN 3 Placed 200000.0 4 Not Placed NaN ... ... ... 19995 Placed NaN 19996 Not Placed NaN 19997 Placed 457765.0 19998 Placed 293694.0 19999 Placed 394105.0 [20000 rows x 2 columns]
status変数がPlaced(内定)なのにsalary変数が欠測しているなど、上手く行っていない様子が分かります。
Placed(内定)かNot Placedどうかで、salary変数の基本統計量がどうなっているか見てみます。
以下、コードです。
print(synthetic_data.loc[:,['status','salary']].groupby('status').describe())
以下、実行結果です。
salary \ count mean std min 25% status Not Placed 6134.0 397787.475383 158847.556333 200000.0 296821.75 Placed 7679.0 399903.842427 158153.818051 200000.0 297113.00 50% 75% max status Not Placed 363193.0 452395.5 940000.0 Placed 368274.0 457405.0 940000.0
Not Placed(内定がでていない)なのにsalary変数に値のある(給与が提示された)レコードがあることが分かります。
要するに、「salary(提示された給与)は、statusがPlaced(内定)のときのみデータがある」が保持されていません。
まとめ
今回は、主キーを指定しテーブルデータ生成AI CTGANで、テーブルデータを生成する方法を説明しました。
メタデータの設定を変えることで簡単に対応できます。
テーブルデータには、主キーにも色々な制約やルールが課されていることがあります。
次回は、今回と同じデータセットを使い、変数間の関係性に関する制約・ルールを定義し、テーブルデータを生成する方法について説明します。
要は、「salary(提示された給与)は、statusがPlaced(内定)のときのみデータがある」を制約・ルールを課すということです。