{ "cells": [ { "cell_type": "markdown", "id": "c46a6957", "metadata": {}, "source": [ "#### Importing important libraries" ] }, { "cell_type": "code", "execution_count": 1, "id": "8e218c1e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting skops\n", " Obtaining dependency information for skops from https://files.pythonhosted.org/packages/fd/fd/8ee9d18fa13118f4230766cc31fe66846928eca1713b1907ffd61fa86ed3/skops-0.9.0-py3-none-any.whl.metadata\n", " Downloading skops-0.9.0-py3-none-any.whl.metadata (5.9 kB)\n", "Requirement already satisfied: scikit-learn>=0.24 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from skops) (1.3.0)\n", "Requirement already satisfied: huggingface-hub>=0.17.0 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from skops) (0.20.2)\n", "Requirement already satisfied: tabulate>=0.8.8 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from skops) (0.8.10)\n", "Requirement already satisfied: packaging>=17.0 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from skops) (23.1)\n", "Requirement already satisfied: filelock in c:\\users\\dell\\anaconda3\\lib\\site-packages (from huggingface-hub>=0.17.0->skops) (3.9.0)\n", "Requirement already satisfied: fsspec>=2023.5.0 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from huggingface-hub>=0.17.0->skops) (2023.12.2)\n", "Requirement already satisfied: requests in c:\\users\\dell\\anaconda3\\lib\\site-packages (from huggingface-hub>=0.17.0->skops) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.42.1 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from huggingface-hub>=0.17.0->skops) (4.65.0)\n", "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from huggingface-hub>=0.17.0->skops) (6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from huggingface-hub>=0.17.0->skops) (4.9.0)\n", "Requirement already satisfied: numpy>=1.17.3 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from scikit-learn>=0.24->skops) (1.24.3)\n", "Requirement already satisfied: scipy>=1.5.0 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from scikit-learn>=0.24->skops) (1.11.1)\n", "Requirement already satisfied: joblib>=1.1.1 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from scikit-learn>=0.24->skops) (1.2.0)\n", "Requirement already satisfied: threadpoolctl>=2.0.0 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from scikit-learn>=0.24->skops) (2.2.0)\n", "Requirement already satisfied: colorama in c:\\users\\dell\\anaconda3\\lib\\site-packages (from tqdm>=4.42.1->huggingface-hub>=0.17.0->skops) (0.4.6)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from requests->huggingface-hub>=0.17.0->skops) (2.0.4)\n", "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from requests->huggingface-hub>=0.17.0->skops) (3.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from requests->huggingface-hub>=0.17.0->skops) (1.26.16)\n", "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\dell\\anaconda3\\lib\\site-packages (from requests->huggingface-hub>=0.17.0->skops) (2023.7.22)\n", "Downloading skops-0.9.0-py3-none-any.whl (120 kB)\n", " ---------------------------------------- 0.0/120.7 kB ? eta -:--:--\n", " --- ------------------------------------ 10.2/120.7 kB ? eta -:--:--\n", " ------------- ------------------------- 41.0/120.7 kB 991.0 kB/s eta 0:00:01\n", " ---------------------------------------- 120.7/120.7 kB 1.4 MB/s eta 0:00:00\n", "Installing collected packages: skops\n", "Successfully installed skops-0.9.0\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "pip install skops" ] }, { "cell_type": "code", "execution_count": 2, "id": "57b72b5c", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import os\n", "import librosa\n", "import skops.io as sio\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "id": "b8e1283b", "metadata": {}, "source": [ "#### Importing the dataset" ] }, { "cell_type": "code", "execution_count": 2, "id": "bfde13e5", "metadata": {}, "outputs": [], "source": [ "df_valid_train = pd.read_csv('cv-valid-train.csv')\n", "df_valid_test = pd.read_csv('cv-valid-test.csv')\n", "df_valid_dev = pd.read_csv('cv-valid-dev.csv')\n", "\n", "df_other_train = pd.read_csv('cv-valid-train.csv')\n", "df_other_test = pd.read_csv('cv-valid-test.csv')\n", "df_other_dev = pd.read_csv('cv-valid-dev.csv')\n", "\n", "df_invalid = pd.read_csv('cv-invalid.csv')" ] }, { "cell_type": "code", "execution_count": 3, "id": "830b20a3", "metadata": {}, "outputs": [], "source": [ "### Concatenating the dataframes\n", "df = pd.concat([df_valid_train,df_valid_test,df_valid_dev,df_other_train,df_other_test,df_other_dev,df_invalid])" ] }, { "cell_type": "code", "execution_count": 4, "id": "96d2684b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "### Checking for the missing values\n", "plt.figure(figsize=(12,6))\n", "sns.heatmap(df.isnull())" ] }, { "cell_type": "code", "execution_count": 5, "id": "e77a1c36", "metadata": {}, "outputs": [], "source": [ "### dropping the column because of the missing values\n", "df = df.drop('duration',axis=1)" ] }, { "cell_type": "code", "execution_count": 6, "id": "7903426d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(433097, 7)\n", "(136525, 7)\n" ] } ], "source": [ "print(df.shape)\n", "df = df.dropna()\n", "print(df.shape)" ] }, { "cell_type": "code", "execution_count": 7, "id": "13b907b6", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
filenametextup_votesdown_votesagegenderaccent
5cv-valid-train/sample-000005.mp3a shepherd may like to travel but he should ne...10twentiesfemaleus
8cv-valid-train/sample-000008.mp3put jackie right on the staff30seventiesmaleus
13cv-valid-train/sample-000013.mp3but he had found a guide and didn't want to mi...10thirtiesfemaleus
14cv-valid-train/sample-000014.mp3as they began to decorate the hallway a silhou...10sixtiesmaleengland
19cv-valid-train/sample-000019.mp3then they got ahold of some dough and went goofy10fiftiesmaleaustralia
\n", "
" ], "text/plain": [ " filename \\\n", "5 cv-valid-train/sample-000005.mp3 \n", "8 cv-valid-train/sample-000008.mp3 \n", "13 cv-valid-train/sample-000013.mp3 \n", "14 cv-valid-train/sample-000014.mp3 \n", "19 cv-valid-train/sample-000019.mp3 \n", "\n", " text up_votes down_votes \\\n", "5 a shepherd may like to travel but he should ne... 1 0 \n", "8 put jackie right on the staff 3 0 \n", "13 but he had found a guide and didn't want to mi... 1 0 \n", "14 as they began to decorate the hallway a silhou... 1 0 \n", "19 then they got ahold of some dough and went goofy 1 0 \n", "\n", " age gender accent \n", "5 twenties female us \n", "8 seventies male us \n", "13 thirties female us \n", "14 sixties male england \n", "19 fifties male australia " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 8, "id": "da574ff3", "metadata": {}, "outputs": [], "source": [ "### dropping the columns\n", "df = df.drop(['text','up_votes','down_votes'],axis=1)" ] }, { "cell_type": "code", "execution_count": 9, "id": "087bbb10", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
filenameagegenderaccent
5cv-valid-train/sample-000005.mp3twentiesfemaleus
8cv-valid-train/sample-000008.mp3seventiesmaleus
13cv-valid-train/sample-000013.mp3thirtiesfemaleus
14cv-valid-train/sample-000014.mp3sixtiesmaleengland
19cv-valid-train/sample-000019.mp3fiftiesmaleaustralia
\n", "
" ], "text/plain": [ " filename age gender accent\n", "5 cv-valid-train/sample-000005.mp3 twenties female us\n", "8 cv-valid-train/sample-000008.mp3 seventies male us\n", "13 cv-valid-train/sample-000013.mp3 thirties female us\n", "14 cv-valid-train/sample-000014.mp3 sixties male england\n", "19 cv-valid-train/sample-000019.mp3 fifties male australia" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 10, "id": "835148a8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "male 99194\n", "female 35746\n", "other 1585\n", "Name: gender, dtype: int64" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['gender'].value_counts()" ] }, { "cell_type": "code", "execution_count": 13, "id": "34db866b", "metadata": {}, "outputs": [], "source": [ "### Removing the observations with other gender due to the underbalance \n", "index = list(df[df['gender']=='other'].index)\n", "\n", "df = df.drop(index)\n", "### resetting the index\n", "df = df.reset_index()\n", "df = df.drop('index', axis = 1)" ] }, { "cell_type": "code", "execution_count": 19, "id": "6d31096c", "metadata": { "scrolled": false }, "outputs": [], "source": [ "### Function to change the filename according to the path\n", "def func(fileName):\n", " \n", " fileName = fileName.replace('/','\\\\')\n", " fileName = 'C:\\\\Users\\\\Dell\\\\Desktop\\\\Audio Recognition Project\\\\' + fileName\n", " return fileName" ] }, { "cell_type": "code", "execution_count": 20, "id": "e4a07ec5", "metadata": {}, "outputs": [], "source": [ "df['filename'] = df['filename'].apply(lambda x:func(x))" ] }, { "cell_type": "code", "execution_count": 21, "id": "1d770fe7", "metadata": {}, "outputs": [], "source": [ "df = df.drop_duplicates(subset=['filename'])" ] }, { "cell_type": "code", "execution_count": 22, "id": "40891fa5", "metadata": {}, "outputs": [], "source": [ "### moving all the files relevant to us to the same folder\n", "def moving(fileName):\n", " \n", " source = fileName\n", " index_m = source.index('m')\n", " index_v = source.index('v')\n", " fileName = fileName[index_v-1:index_m-3] + '_' + fileName[index_m-2:]\n", " destination = 'C:\\\\Users\\\\Dell\\\\Desktop\\\\Audio Recognition Project\\\\Final Destination\\\\' + fileName\n", " os.rename(source,destination)" ] }, { "cell_type": "code", "execution_count": 27, "id": "942b4842", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "for i in range(len(df_invalid)):\n", " moving(df['filename'][i])" ] }, { "cell_type": "code", "execution_count": 60, "id": "34d18ca5", "metadata": {}, "outputs": [], "source": [ "### Saving the dataset\n", "df.to_csv('metadata.csv')" ] }, { "cell_type": "code", "execution_count": 2, "id": "0f07952d", "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv('metadata.csv')" ] }, { "cell_type": "markdown", "id": "984e68d9", "metadata": {}, "source": [ "##### converting the audio files to array's" ] }, { "cell_type": "code", "execution_count": 28, "id": "ab9bfd1d", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
filenameagegenderaccent
0C:\\Users\\Dell\\Desktop\\Audio Recognition Projec...twentiesfemaleus
1C:\\Users\\Dell\\Desktop\\Audio Recognition Projec...seventiesmaleus
2C:\\Users\\Dell\\Desktop\\Audio Recognition Projec...thirtiesfemaleus
3C:\\Users\\Dell\\Desktop\\Audio Recognition Projec...sixtiesmaleengland
4C:\\Users\\Dell\\Desktop\\Audio Recognition Projec...fiftiesmaleaustralia
\n", "
" ], "text/plain": [ " filename age gender \\\n", "0 C:\\Users\\Dell\\Desktop\\Audio Recognition Projec... twenties female \n", "1 C:\\Users\\Dell\\Desktop\\Audio Recognition Projec... seventies male \n", "2 C:\\Users\\Dell\\Desktop\\Audio Recognition Projec... thirties female \n", "3 C:\\Users\\Dell\\Desktop\\Audio Recognition Projec... sixties male \n", "4 C:\\Users\\Dell\\Desktop\\Audio Recognition Projec... fifties male \n", "\n", " accent \n", "0 us \n", "1 us \n", "2 us \n", "3 england \n", "4 australia " ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 29, "id": "e970b484", "metadata": {}, "outputs": [], "source": [ "### Changing the filenames to filepaths with all the files in the same folder\n", "def filename_change(fileName):\n", " \n", " index = fileName.index('v')\n", " name = fileName[index-1:].replace('\\\\','_')[:-3] + 'wav'\n", " filename = 'C:\\\\Users\\\\Dell\\\\Desktop\\\\Audio Recognition Project\\\\Common Voice WAV\\\\' + name\n", " return filename" ] }, { "cell_type": "code", "execution_count": 30, "id": "13380b8e", "metadata": {}, "outputs": [], "source": [ "df['filename'] = df['filename'].apply(lambda x:filename_change(x))" ] }, { "cell_type": "code", "execution_count": 35, "id": "0b5ea468", "metadata": {}, "outputs": [], "source": [ "df = df.reset_index()\n", "df = df.drop('index',axis=1)" ] }, { "cell_type": "code", "execution_count": 36, "id": "4f2019e9", "metadata": {}, "outputs": [], "source": [ "df.to_csv('metadata_final.csv',index=False)" ] }, { "cell_type": "markdown", "id": "d3bd2e8f", "metadata": {}, "source": [ "#### Feature Extraction" ] }, { "cell_type": "code", "execution_count": 4, "id": "cf8fc8c3", "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv('metadata_final.csv')" ] }, { "cell_type": "code", "execution_count": 18, "id": "c865bef6", "metadata": {}, "outputs": [], "source": [ "### extract the features from the audio files using mfcc\n", "def feature_extracter(fileName):\n", " audio,sample_rate = librosa.load(fileName,res_type='kaiser_fast')\n", " mfcc_features = librosa.feature.mfcc(y=audio,sr=sample_rate,n_mfcc=30)\n", " mfccs_scaled_features = np.mean(mfcc_features.T, axis=0)\n", " \n", " return list(mfccs_scaled_features)" ] }, { "cell_type": "code", "execution_count": 19, "id": "33225037", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "128" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(feature_extracter('C:\\\\Users\\\\Dell\\\\Desktop\\\\Audio Recognition Project\\\\Common Voice WAV\\\\cv-valid-train_sample-000005.wav'))" ] }, { "cell_type": "code", "execution_count": 7, "id": "a9d13e6c", "metadata": {}, "outputs": [], "source": [ "feature_extraction_dataset = []" ] }, { "cell_type": "code", "execution_count": null, "id": "06543e76", "metadata": {}, "outputs": [], "source": [ "for i in range(len(df)):\n", " data = feature_extracter(df['filename'][i])\n", " data.append(df['age'][i])\n", " data.append(df['gender'][i])\n", " data.append(df['accent'][i])\n", " feature_extraction_dataset.append(data)\n", " #print(\"{} files completed\".format(i))" ] }, { "cell_type": "code", "execution_count": 9, "id": "0faa3c6a", "metadata": {}, "outputs": [], "source": [ "### converting the list to dataframe\n", "feature_extraction_dataset = pd.DataFrame(feature_extraction_dataset)" ] }, { "cell_type": "code", "execution_count": 11, "id": "a7f242fa", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
0123456789...121122123124125126127128129130
0-583.109741100.3704530.94325448.706886-10.5407710.412921-18.057110-8.326262-1.521947-12.522870...-0.192631-0.3155640.343298-0.4024660.258311-0.2967830.058532twentiesfemaleus
1-394.04379399.290733-38.35684636.20208713.018510-0.332277-30.521049-24.611736-20.024382-8.018195...-0.181873-0.617222-0.2033000.0660650.179560-0.0358150.082535seventiesmaleus
2-329.92706389.804878-93.25531853.524906-14.073633-17.782843-30.973644-11.1152971.704327-15.689195...-0.064750-0.1546700.1440960.3408310.3978980.183917-0.035574thirtiesfemaleus
3-544.49835280.18797311.45608926.1494988.8674084.068777-5.959950-0.991863-5.118737-1.111087...-0.293615-0.154340-0.152312-0.136096-0.248668-0.185580-0.021339sixtiesmaleengland
4-265.772156105.1039359.27538827.4940171.30829328.853981-11.55351419.866306-9.07240516.467325...-0.064412-0.548103-0.691253-0.090602-0.078694-0.036180-0.491893fiftiesmaleaustralia
..................................................................
69900-376.485321119.770882-0.66136810.84285411.8847364.8079470.22187714.571591-7.1060678.881387...-0.514412-0.4191680.3862080.013728-0.990892-0.474937-0.042218teensmaleus
69901-1131.3709720.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000...0.0000000.0000000.0000000.0000000.0000000.0000000.000000sixtiesmaleus
69902-265.763916114.194046-22.39389023.325445-6.32488214.880446-5.5503218.608303-9.4894394.931038...-1.028494-0.561055-0.638143-0.282054-0.419412-0.1933800.354584twentiesmaleindian
69903-494.44854756.549568-10.6774886.895159-0.0684892.941504-14.5314702.231544-4.246867-1.574609...0.2368070.8816890.132002-0.2148250.1432240.034405-0.521862sixtiesfemaleus
69904-461.47467045.163513-45.25056137.988201-2.061493-11.344541-11.8620062.977398-10.133149-1.859840...0.3415810.2535530.186000-0.0788740.1106010.0187720.275380fourtiesmaleaustralia
\n", "

69905 rows × 131 columns

\n", "
" ], "text/plain": [ " 0 1 2 3 4 5 \\\n", "0 -583.109741 100.370453 0.943254 48.706886 -10.540771 0.412921 \n", "1 -394.043793 99.290733 -38.356846 36.202087 13.018510 -0.332277 \n", "2 -329.927063 89.804878 -93.255318 53.524906 -14.073633 -17.782843 \n", "3 -544.498352 80.187973 11.456089 26.149498 8.867408 4.068777 \n", "4 -265.772156 105.103935 9.275388 27.494017 1.308293 28.853981 \n", "... ... ... ... ... ... ... \n", "69900 -376.485321 119.770882 -0.661368 10.842854 11.884736 4.807947 \n", "69901 -1131.370972 0.000000 0.000000 0.000000 0.000000 0.000000 \n", "69902 -265.763916 114.194046 -22.393890 23.325445 -6.324882 14.880446 \n", "69903 -494.448547 56.549568 -10.677488 6.895159 -0.068489 2.941504 \n", "69904 -461.474670 45.163513 -45.250561 37.988201 -2.061493 -11.344541 \n", "\n", " 6 7 8 9 ... 121 122 \\\n", "0 -18.057110 -8.326262 -1.521947 -12.522870 ... -0.192631 -0.315564 \n", "1 -30.521049 -24.611736 -20.024382 -8.018195 ... -0.181873 -0.617222 \n", "2 -30.973644 -11.115297 1.704327 -15.689195 ... -0.064750 -0.154670 \n", "3 -5.959950 -0.991863 -5.118737 -1.111087 ... -0.293615 -0.154340 \n", "4 -11.553514 19.866306 -9.072405 16.467325 ... -0.064412 -0.548103 \n", "... ... ... ... ... ... ... ... \n", "69900 0.221877 14.571591 -7.106067 8.881387 ... -0.514412 -0.419168 \n", "69901 0.000000 0.000000 0.000000 0.000000 ... 0.000000 0.000000 \n", "69902 -5.550321 8.608303 -9.489439 4.931038 ... -1.028494 -0.561055 \n", "69903 -14.531470 2.231544 -4.246867 -1.574609 ... 0.236807 0.881689 \n", "69904 -11.862006 2.977398 -10.133149 -1.859840 ... 0.341581 0.253553 \n", "\n", " 123 124 125 126 127 128 129 \\\n", "0 0.343298 -0.402466 0.258311 -0.296783 0.058532 twenties female \n", "1 -0.203300 0.066065 0.179560 -0.035815 0.082535 seventies male \n", "2 0.144096 0.340831 0.397898 0.183917 -0.035574 thirties female \n", "3 -0.152312 -0.136096 -0.248668 -0.185580 -0.021339 sixties male \n", "4 -0.691253 -0.090602 -0.078694 -0.036180 -0.491893 fifties male \n", "... ... ... ... ... ... ... ... \n", "69900 0.386208 0.013728 -0.990892 -0.474937 -0.042218 teens male \n", "69901 0.000000 0.000000 0.000000 0.000000 0.000000 sixties male \n", "69902 -0.638143 -0.282054 -0.419412 -0.193380 0.354584 twenties male \n", "69903 0.132002 -0.214825 0.143224 0.034405 -0.521862 sixties female \n", "69904 0.186000 -0.078874 0.110601 0.018772 0.275380 fourties male \n", "\n", " 130 \n", "0 us \n", "1 us \n", "2 us \n", "3 england \n", "4 australia \n", "... ... \n", "69900 us \n", "69901 us \n", "69902 indian \n", "69903 us \n", "69904 australia \n", "\n", "[69905 rows x 131 columns]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "feature_extraction_dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "84b1b787", "metadata": {}, "outputs": [], "source": [ "### naming the features for convenience\n", "col_name = []\n", "for i in range(1,31):\n", " col_name.append('Feature_'+str(i))\n", "col_name = col_name + ['age','gender','accent']\n", "feature_extraction_dataset.columns = col_name" ] }, { "cell_type": "code", "execution_count": 9, "id": "a64049e4", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Feature_1Feature_2Feature_3Feature_4Feature_5Feature_6Feature_7Feature_8Feature_9Feature_10...Feature_24Feature_25Feature_26Feature_27Feature_28Feature_29Feature_30agegenderaccent
0-583.109741100.3704530.94325548.706882-10.5407710.412921-18.057110-8.326262-1.521946-12.522870...-2.081630-4.345325-5.5432710.867665-2.325722-4.993744-2.170289twentiesfemaleus
1-394.04379399.290733-38.35684636.20208713.018513-0.332277-30.521049-24.611736-20.024382-8.018195...-5.490204-5.979488-5.525753-3.199488-12.2292312.510893-1.527913seventiesmaleus
2-329.92706389.804886-93.25531853.524906-14.073632-17.782843-30.973644-11.1152981.704327-15.689195...-2.942057-3.8995095.903738-2.4959304.2841842.987215-1.986013thirtiesfemaleus
3-544.49829180.18797311.45608926.1494988.8674084.068777-5.959950-0.991863-5.118737-1.111087...-3.014753-1.035969-4.900939-1.521831-3.017688-1.5105142.014324sixtiesmaleengland
4-265.772156105.1039359.27538727.4940171.30829328.853981-11.55351219.866306-9.07240516.467325...1.660018-2.5508891.886928-5.2480731.555064-0.886034-0.406267fiftiesmaleaustralia
\n", "

5 rows × 33 columns

\n", "
" ], "text/plain": [ " Feature_1 Feature_2 Feature_3 Feature_4 Feature_5 Feature_6 \\\n", "0 -583.109741 100.370453 0.943255 48.706882 -10.540771 0.412921 \n", "1 -394.043793 99.290733 -38.356846 36.202087 13.018513 -0.332277 \n", "2 -329.927063 89.804886 -93.255318 53.524906 -14.073632 -17.782843 \n", "3 -544.498291 80.187973 11.456089 26.149498 8.867408 4.068777 \n", "4 -265.772156 105.103935 9.275387 27.494017 1.308293 28.853981 \n", "\n", " Feature_7 Feature_8 Feature_9 Feature_10 ... Feature_24 Feature_25 \\\n", "0 -18.057110 -8.326262 -1.521946 -12.522870 ... -2.081630 -4.345325 \n", "1 -30.521049 -24.611736 -20.024382 -8.018195 ... -5.490204 -5.979488 \n", "2 -30.973644 -11.115298 1.704327 -15.689195 ... -2.942057 -3.899509 \n", "3 -5.959950 -0.991863 -5.118737 -1.111087 ... -3.014753 -1.035969 \n", "4 -11.553512 19.866306 -9.072405 16.467325 ... 1.660018 -2.550889 \n", "\n", " Feature_26 Feature_27 Feature_28 Feature_29 Feature_30 age \\\n", "0 -5.543271 0.867665 -2.325722 -4.993744 -2.170289 twenties \n", "1 -5.525753 -3.199488 -12.229231 2.510893 -1.527913 seventies \n", "2 5.903738 -2.495930 4.284184 2.987215 -1.986013 thirties \n", "3 -4.900939 -1.521831 -3.017688 -1.510514 2.014324 sixties \n", "4 1.886928 -5.248073 1.555064 -0.886034 -0.406267 fifties \n", "\n", " gender accent \n", "0 female us \n", "1 male us \n", "2 female us \n", "3 male england \n", "4 male australia \n", "\n", "[5 rows x 33 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "feature_extraction_dataset.head()" ] }, { "cell_type": "code", "execution_count": 10, "id": "825f9b04", "metadata": {}, "outputs": [], "source": [ "feature_extraction_dataset.to_csv('features_data.csv',index=False)" ] }, { "cell_type": "markdown", "id": "17c41ee7", "metadata": {}, "source": [ "#### Exploratory Data Analysis" ] }, { "cell_type": "code", "execution_count": 24, "id": "d31cc42b", "metadata": {}, "outputs": [], "source": [ "df_features = pd.read_csv('features_data.csv')" ] }, { "cell_type": "code", "execution_count": 25, "id": "21c0fc54", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1 51284\n", "0 18621\n", "Name: male, dtype: int64" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_features['male'].value_counts()" ] }, { "cell_type": "code", "execution_count": 30, "id": "df429429", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Feature_1Feature_2Feature_3Feature_4Feature_5Feature_6Feature_7Feature_8Feature_9Feature_10...Feature_30maleage_labelfiftiesfourtiesseventiessixtiesteensthirtiestwenties
0-583.10974100.3704500.94325548.706882-10.5407710.412921-18.057110-8.326262-1.521946-12.522870...-2.170289070000001
1-394.0438099.290730-38.35684636.20208713.018513-0.332277-30.521050-24.611736-20.024382-8.018195...-1.527913130010000
2-329.9270689.804886-93.25532053.524906-14.073632-17.782843-30.973644-11.1152981.704327-15.689195...-1.986013060000010
3-544.4983080.18797011.45608926.1494988.8674084.068777-5.959950-0.991863-5.118737-1.111087...2.014324140001000
4-265.77216105.1039359.27538727.4940171.30829328.853981-11.55351219.866306-9.07240516.467325...-0.406267111000000
\n", "

5 rows × 39 columns

\n", "
" ], "text/plain": [ " Feature_1 Feature_2 Feature_3 Feature_4 Feature_5 Feature_6 \\\n", "0 -583.10974 100.370450 0.943255 48.706882 -10.540771 0.412921 \n", "1 -394.04380 99.290730 -38.356846 36.202087 13.018513 -0.332277 \n", "2 -329.92706 89.804886 -93.255320 53.524906 -14.073632 -17.782843 \n", "3 -544.49830 80.187970 11.456089 26.149498 8.867408 4.068777 \n", "4 -265.77216 105.103935 9.275387 27.494017 1.308293 28.853981 \n", "\n", " Feature_7 Feature_8 Feature_9 Feature_10 ... Feature_30 male \\\n", "0 -18.057110 -8.326262 -1.521946 -12.522870 ... -2.170289 0 \n", "1 -30.521050 -24.611736 -20.024382 -8.018195 ... -1.527913 1 \n", "2 -30.973644 -11.115298 1.704327 -15.689195 ... -1.986013 0 \n", "3 -5.959950 -0.991863 -5.118737 -1.111087 ... 2.014324 1 \n", "4 -11.553512 19.866306 -9.072405 16.467325 ... -0.406267 1 \n", "\n", " age_label fifties fourties seventies sixties teens thirties twenties \n", "0 7 0 0 0 0 0 0 1 \n", "1 3 0 0 1 0 0 0 0 \n", "2 6 0 0 0 0 0 1 0 \n", "3 4 0 0 0 1 0 0 0 \n", "4 1 1 0 0 0 0 0 0 \n", "\n", "[5 rows x 39 columns]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_features.head()" ] }, { "cell_type": "code", "execution_count": 29, "id": "d7c32d49", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(8,8))\n", "sns.histplot(x='age_label',data=df_features,hue='male')" ] }, { "cell_type": "code", "execution_count": 14, "id": "a4442733", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],\n", " [Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, ''),\n", " Text(0, 0, '')])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(12,8))\n", "ax = sns.histplot(x='accent',data=df_features,hue='gender')\n", "plt.xticks(rotation=45)" ] }, { "cell_type": "code", "execution_count": 15, "id": "b4dc03c3", "metadata": {}, "outputs": [], "source": [ "### dropping the accent column because of the imbalanced dataset\n", "df_features = df_features.drop('accent',axis=1)" ] }, { "cell_type": "markdown", "id": "dcf7e980", "metadata": {}, "source": [ "Label 0: female || 1: male" ] }, { "cell_type": "code", "execution_count": 16, "id": "dd00f6d3", "metadata": {}, "outputs": [], "source": [ "### hot encoding the gender attribute\n", "gender = pd.get_dummies(df_features['gender'],drop_first=True)\n", "df_features = df_features.drop('gender',axis=1)\n", "df_features = pd.concat([df_features,gender],axis=1)" ] }, { "cell_type": "markdown", "id": "ae8c6038", "metadata": {}, "source": [ "Label 7: twenties ||\n", "6: thirties ||\n", "2: fourties ||\n", "1: fifties ||\n", "4: sixties ||\n", "5: teens ||\n", "3: seventies ||\n", "0: eighties " ] }, { "cell_type": "code", "execution_count": 17, "id": "83473b61", "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import LabelEncoder" ] }, { "cell_type": "code", "execution_count": 22, "id": "30fb7c14", "metadata": {}, "outputs": [], "source": [ "### label encoding the age attribute\n", "encoding = LabelEncoder()\n", "encoding.fit(df_features['age'])\n", "age = encoding.transform(df_features['age'])\n", "age = pd.DataFrame(age,columns=['age_label'])\n", "df_features = pd.concat([df_features,age],axis=1)" ] }, { "cell_type": "code", "execution_count": 24, "id": "410f6638", "metadata": {}, "outputs": [], "source": [ "### hot encoding the age attribute\n", "age = pd.get_dummies(df_features['age'],drop_first=True)\n", "df_features = df_features.drop('age',axis=1)\n", "df_features = pd.concat([df_features,age],axis=1)" ] }, { "cell_type": "code", "execution_count": 25, "id": "d933fbfa", "metadata": {}, "outputs": [], "source": [ "df_features.to_csv('features_data.csv',index=False)" ] }, { "cell_type": "code", "execution_count": 3, "id": "d68b3a01", "metadata": {}, "outputs": [], "source": [ "df_features = pd.read_csv('features_data.csv')" ] }, { "cell_type": "code", "execution_count": 4, "id": "18ef8a1a", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Feature_1Feature_2Feature_3Feature_4Feature_5Feature_6Feature_7Feature_8Feature_9Feature_10...Feature_30maleage_labelfiftiesfourtiesseventiessixtiesteensthirtiestwenties
0-583.10974100.3704500.94325548.706882-10.5407710.412921-18.057110-8.326262-1.521946-12.522870...-2.170289070000001
1-394.0438099.290730-38.35684636.20208713.018513-0.332277-30.521050-24.611736-20.024382-8.018195...-1.527913130010000
2-329.9270689.804886-93.25532053.524906-14.073632-17.782843-30.973644-11.1152981.704327-15.689195...-1.986013060000010
3-544.4983080.18797011.45608926.1494988.8674084.068777-5.959950-0.991863-5.118737-1.111087...2.014324140001000
4-265.77216105.1039359.27538727.4940171.30829328.853981-11.55351219.866306-9.07240516.467325...-0.406267111000000
\n", "

5 rows × 39 columns

\n", "
" ], "text/plain": [ " Feature_1 Feature_2 Feature_3 Feature_4 Feature_5 Feature_6 \\\n", "0 -583.10974 100.370450 0.943255 48.706882 -10.540771 0.412921 \n", "1 -394.04380 99.290730 -38.356846 36.202087 13.018513 -0.332277 \n", "2 -329.92706 89.804886 -93.255320 53.524906 -14.073632 -17.782843 \n", "3 -544.49830 80.187970 11.456089 26.149498 8.867408 4.068777 \n", "4 -265.77216 105.103935 9.275387 27.494017 1.308293 28.853981 \n", "\n", " Feature_7 Feature_8 Feature_9 Feature_10 ... Feature_30 male \\\n", "0 -18.057110 -8.326262 -1.521946 -12.522870 ... -2.170289 0 \n", "1 -30.521050 -24.611736 -20.024382 -8.018195 ... -1.527913 1 \n", "2 -30.973644 -11.115298 1.704327 -15.689195 ... -1.986013 0 \n", "3 -5.959950 -0.991863 -5.118737 -1.111087 ... 2.014324 1 \n", "4 -11.553512 19.866306 -9.072405 16.467325 ... -0.406267 1 \n", "\n", " age_label fifties fourties seventies sixties teens thirties twenties \n", "0 7 0 0 0 0 0 0 1 \n", "1 3 0 0 1 0 0 0 0 \n", "2 6 0 0 0 0 0 1 0 \n", "3 4 0 0 0 1 0 0 0 \n", "4 1 1 0 0 0 0 0 0 \n", "\n", "[5 rows x 39 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_features.head()" ] }, { "cell_type": "markdown", "id": "985af0f9", "metadata": {}, "source": [ "#### Train Test Split" ] }, { "cell_type": "code", "execution_count": 5, "id": "cd68791b", "metadata": {}, "outputs": [], "source": [ "df_features = df_features.drop(['fifties','fourties','seventies','sixties','teens','thirties','twenties'], axis = 1)" ] }, { "cell_type": "code", "execution_count": 6, "id": "d098b6e2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Feature_1Feature_2Feature_3Feature_4Feature_5Feature_6Feature_7Feature_8Feature_9Feature_10...Feature_23Feature_24Feature_25Feature_26Feature_27Feature_28Feature_29Feature_30maleage_label
0-583.10974100.3704500.94325548.706882-10.5407710.412921-18.057110-8.326262-1.521946-12.522870...-10.564928-2.081630-4.345325-5.5432710.867665-2.325722-4.993744-2.17028907
1-394.0438099.290730-38.35684636.20208713.018513-0.332277-30.521050-24.611736-20.024382-8.018195...-1.825251-5.490204-5.979488-5.525754-3.199488-12.2292312.510893-1.52791313
2-329.9270689.804886-93.25532053.524906-14.073632-17.782843-30.973644-11.1152981.704327-15.689195...-2.550232-2.942057-3.8995095.903738-2.4959304.2841842.987215-1.98601306
3-544.4983080.18797011.45608926.1494988.8674084.068777-5.959950-0.991863-5.118737-1.111087...-4.689697-3.014753-1.035969-4.900939-1.521831-3.017688-1.5105142.01432414
4-265.77216105.1039359.27538727.4940171.30829328.853981-11.55351219.866306-9.07240516.467325...-0.5866171.660018-2.5508891.886928-5.2480731.555064-0.886034-0.40626711
\n", "

5 rows × 32 columns

\n", "
" ], "text/plain": [ " Feature_1 Feature_2 Feature_3 Feature_4 Feature_5 Feature_6 \\\n", "0 -583.10974 100.370450 0.943255 48.706882 -10.540771 0.412921 \n", "1 -394.04380 99.290730 -38.356846 36.202087 13.018513 -0.332277 \n", "2 -329.92706 89.804886 -93.255320 53.524906 -14.073632 -17.782843 \n", "3 -544.49830 80.187970 11.456089 26.149498 8.867408 4.068777 \n", "4 -265.77216 105.103935 9.275387 27.494017 1.308293 28.853981 \n", "\n", " Feature_7 Feature_8 Feature_9 Feature_10 ... Feature_23 Feature_24 \\\n", "0 -18.057110 -8.326262 -1.521946 -12.522870 ... -10.564928 -2.081630 \n", "1 -30.521050 -24.611736 -20.024382 -8.018195 ... -1.825251 -5.490204 \n", "2 -30.973644 -11.115298 1.704327 -15.689195 ... -2.550232 -2.942057 \n", "3 -5.959950 -0.991863 -5.118737 -1.111087 ... -4.689697 -3.014753 \n", "4 -11.553512 19.866306 -9.072405 16.467325 ... -0.586617 1.660018 \n", "\n", " Feature_25 Feature_26 Feature_27 Feature_28 Feature_29 Feature_30 \\\n", "0 -4.345325 -5.543271 0.867665 -2.325722 -4.993744 -2.170289 \n", "1 -5.979488 -5.525754 -3.199488 -12.229231 2.510893 -1.527913 \n", "2 -3.899509 5.903738 -2.495930 4.284184 2.987215 -1.986013 \n", "3 -1.035969 -4.900939 -1.521831 -3.017688 -1.510514 2.014324 \n", "4 -2.550889 1.886928 -5.248073 1.555064 -0.886034 -0.406267 \n", "\n", " male age_label \n", "0 0 7 \n", "1 1 3 \n", "2 0 6 \n", "3 1 4 \n", "4 1 1 \n", "\n", "[5 rows x 32 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_features.head()" ] }, { "cell_type": "code", "execution_count": 7, "id": "74ee0331", "metadata": {}, "outputs": [], "source": [ "X = df_features.drop(['male','age_label'],axis=1)\n", "y = df_features[['male','age_label']]" ] }, { "cell_type": "code", "execution_count": 8, "id": "740eda60", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 9, "id": "2c4da3cc", "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.1,random_state=101)" ] }, { "cell_type": "code", "execution_count": 10, "id": "57608c06", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(62914, 30)\n", "(6991, 30)\n", "(62914, 2)\n", "(6991, 2)\n" ] } ], "source": [ "print(X_train.shape)\n", "print(X_test.shape)\n", "print(y_train.shape)\n", "print(y_test.shape)" ] }, { "cell_type": "code", "execution_count": 11, "id": "276fbaeb", "metadata": {}, "outputs": [], "source": [ "col_name = list(X_train.columns)" ] }, { "cell_type": "markdown", "id": "1e45d947", "metadata": {}, "source": [ "#### Feature Normalization" ] }, { "cell_type": "code", "execution_count": 12, "id": "91ca8369", "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler" ] }, { "cell_type": "code", "execution_count": 13, "id": "731977d2", "metadata": {}, "outputs": [], "source": [ "scaler = StandardScaler()\n", "X_train = scaler.fit_transform(X_train)\n", "X_test = scaler.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 17, "id": "d7db8407", "metadata": {}, "outputs": [], "source": [ "sio.dump(scaler, 'scaler')" ] }, { "cell_type": "code", "execution_count": 18, "id": "3e0a9560", "metadata": {}, "outputs": [], "source": [ "X_train = pd.DataFrame(X_train,columns=col_name)\n", "X_test = pd.DataFrame(X_test,columns=col_name)" ] }, { "cell_type": "code", "execution_count": 19, "id": "248ae287", "metadata": {}, "outputs": [], "source": [ "y_train = y_train.reset_index()\n", "y_train = y_train.drop('index',axis=1)\n", "\n", "y_test = y_test.reset_index()\n", "y_test = y_test.drop('index',axis=1)" ] }, { "cell_type": "code", "execution_count": 41, "id": "6c90475b", "metadata": {}, "outputs": [], "source": [ "X_train.to_csv('X_train.csv',index=False)\n", "X_test.to_csv('X_test.csv',index=False)\n", "y_train.to_csv('y_train.csv',index=False)\n", "y_test.to_csv('y_test.csv',index=False)" ] }, { "cell_type": "markdown", "id": "73bcfbbb", "metadata": {}, "source": [ "### Gender Classification" ] }, { "cell_type": "code", "execution_count": 2, "id": "76c8537e", "metadata": {}, "outputs": [], "source": [ "X_train = pd.read_csv('X_train.csv')\n", "X_test = pd.read_csv('X_test.csv')\n", "y_train = pd.read_csv('y_train.csv')\n", "y_test = pd.read_csv('y_test.csv')" ] }, { "cell_type": "code", "execution_count": 20, "id": "3b2587ac", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Feature_1Feature_2Feature_3Feature_4Feature_5Feature_6Feature_7Feature_8Feature_9Feature_10...Feature_21Feature_22Feature_23Feature_24Feature_25Feature_26Feature_27Feature_28Feature_29Feature_30
00.5643950.353829-0.2789280.244259-0.181786-0.359068-0.4349220.421213-2.619306-0.233284...-0.752095-0.047534-1.596541-2.185107-0.475254-2.322643-0.664499-1.974565-0.8528450.497673
1-1.267875-2.791355-0.0870761.215719-0.609884-0.713474-0.228118-1.0441300.946386-0.418563...1.0120541.1523400.9626491.0602071.5687581.9717191.2155562.2483570.7644810.759086
21.255162-0.1926090.540949-1.7154350.2744402.556369-1.5192880.040003-0.1161920.200872...0.452080-0.5119770.175765-0.709609-1.761413-0.222903-0.364405-1.2242400.260271-0.052288
3-1.142260-0.0584410.0731800.767581-0.7071620.221958-0.0941280.026847-0.7135160.806025...-1.7074730.4210670.2106890.256840-0.117659-0.285450-1.195636-0.097772-0.528391-0.231970
4-0.242462-0.681530-0.062332-0.463858-1.4583970.806520-0.6852240.200540-0.0581700.478191...0.8206140.000589-0.1489460.2076310.275857-0.189223-0.384763-0.2252400.261362-0.031957
\n", "

5 rows × 30 columns

\n", "
" ], "text/plain": [ " Feature_1 Feature_2 Feature_3 Feature_4 Feature_5 Feature_6 \\\n", "0 0.564395 0.353829 -0.278928 0.244259 -0.181786 -0.359068 \n", "1 -1.267875 -2.791355 -0.087076 1.215719 -0.609884 -0.713474 \n", "2 1.255162 -0.192609 0.540949 -1.715435 0.274440 2.556369 \n", "3 -1.142260 -0.058441 0.073180 0.767581 -0.707162 0.221958 \n", "4 -0.242462 -0.681530 -0.062332 -0.463858 -1.458397 0.806520 \n", "\n", " Feature_7 Feature_8 Feature_9 Feature_10 ... Feature_21 Feature_22 \\\n", "0 -0.434922 0.421213 -2.619306 -0.233284 ... -0.752095 -0.047534 \n", "1 -0.228118 -1.044130 0.946386 -0.418563 ... 1.012054 1.152340 \n", "2 -1.519288 0.040003 -0.116192 0.200872 ... 0.452080 -0.511977 \n", "3 -0.094128 0.026847 -0.713516 0.806025 ... -1.707473 0.421067 \n", "4 -0.685224 0.200540 -0.058170 0.478191 ... 0.820614 0.000589 \n", "\n", " Feature_23 Feature_24 Feature_25 Feature_26 Feature_27 Feature_28 \\\n", "0 -1.596541 -2.185107 -0.475254 -2.322643 -0.664499 -1.974565 \n", "1 0.962649 1.060207 1.568758 1.971719 1.215556 2.248357 \n", "2 0.175765 -0.709609 -1.761413 -0.222903 -0.364405 -1.224240 \n", "3 0.210689 0.256840 -0.117659 -0.285450 -1.195636 -0.097772 \n", "4 -0.148946 0.207631 0.275857 -0.189223 -0.384763 -0.225240 \n", "\n", " Feature_29 Feature_30 \n", "0 -0.852845 0.497673 \n", "1 0.764481 0.759086 \n", "2 0.260271 -0.052288 \n", "3 -0.528391 -0.231970 \n", "4 0.261362 -0.031957 \n", "\n", "[5 rows x 30 columns]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.head()" ] }, { "cell_type": "code", "execution_count": 21, "id": "20b7a0f2", "metadata": {}, "outputs": [], "source": [ "y_gender_train = y_train['male']\n", "y_gender_test = y_test['male']" ] }, { "cell_type": "markdown", "id": "16c5bd19", "metadata": {}, "source": [ "##### Logistic Regression" ] }, { "cell_type": "code", "execution_count": 45, "id": "4fec5404", "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression" ] }, { "cell_type": "code", "execution_count": 46, "id": "1c2f33c0", "metadata": {}, "outputs": [], "source": [ "model = LogisticRegression()\n", "model.fit(X_train,y_gender_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 47, "id": "b003003e", "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import f1_score, accuracy_score, classification_report, confusion_matrix" ] }, { "cell_type": "code", "execution_count": 48, "id": "37f18f04", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8915752278222058\n", "\n", "0.8332141324560148\n", "\n", "[[1031 317]\n", " [ 849 4794]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.55 0.76 0.64 1348\n", " 1 0.94 0.85 0.89 5643\n", "\n", " accuracy 0.83 6991\n", " macro avg 0.74 0.81 0.77 6991\n", "weighted avg 0.86 0.83 0.84 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_gender_test))\n", "print()\n", "print(accuracy_score(pred,y_gender_test))\n", "print()\n", "print(confusion_matrix(pred,y_gender_test))\n", "print()\n", "print(classification_report(pred,y_gender_test))" ] }, { "cell_type": "markdown", "id": "fecb4655", "metadata": {}, "source": [ "##### KNN" ] }, { "cell_type": "code", "execution_count": 22, "id": "1d87fa51", "metadata": {}, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsClassifier" ] }, { "cell_type": "code", "execution_count": 24, "id": "0e338994", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Dell\\anaconda3\\Lib\\site-packages\\sklearn\\base.py:464: UserWarning: X does not have valid feature names, but KNeighborsClassifier was fitted with feature names\n", " warnings.warn(\n" ] } ], "source": [ "model = KNeighborsClassifier()\n", "model.fit(X_train,y_gender_train)\n", "pred = model.predict(X_test.values)" ] }, { "cell_type": "code", "execution_count": 25, "id": "8a0c5634", "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import f1_score, accuracy_score, confusion_matrix, classification_report" ] }, { "cell_type": "code", "execution_count": 26, "id": "57beade9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9548720896809427\n", "\n", "0.9326276641396081\n", "\n", "[[1537 128]\n", " [ 343 4983]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.82 0.92 0.87 1665\n", " 1 0.97 0.94 0.95 5326\n", "\n", " accuracy 0.93 6991\n", " macro avg 0.90 0.93 0.91 6991\n", "weighted avg 0.94 0.93 0.93 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_gender_test))\n", "print()\n", "print(accuracy_score(pred,y_gender_test))\n", "print()\n", "print(confusion_matrix(pred,y_gender_test))\n", "print()\n", "print(classification_report(pred,y_gender_test))" ] }, { "cell_type": "code", "execution_count": 27, "id": "45b755ad", "metadata": {}, "outputs": [], "source": [ "sio.dump(model,'KNN_gender_detection')" ] }, { "cell_type": "markdown", "id": "0a0e6e65", "metadata": {}, "source": [ "##### SVM" ] }, { "cell_type": "code", "execution_count": 52, "id": "73d7c8f0", "metadata": {}, "outputs": [], "source": [ "from sklearn.svm import SVC" ] }, { "cell_type": "code", "execution_count": null, "id": "b4154e5b", "metadata": {}, "outputs": [], "source": [ "model = SVC(C=100)\n", "model.fit(X_train,y_gender_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 8, "id": "28e55925", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9528063701689649\n", "\n", "0.9304820483478758\n", "\n", "[[1599 205]\n", " [ 281 4906]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.85 0.89 0.87 1804\n", " 1 0.96 0.95 0.95 5187\n", "\n", " accuracy 0.93 6991\n", " macro avg 0.91 0.92 0.91 6991\n", "weighted avg 0.93 0.93 0.93 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_gender_test))\n", "print()\n", "print(accuracy_score(pred,y_gender_test))\n", "print()\n", "print(confusion_matrix(pred,y_gender_test))\n", "print()\n", "print(classification_report(pred,y_gender_test))" ] }, { "cell_type": "markdown", "id": "cbca922a", "metadata": {}, "source": [ "##### Decision Tree" ] }, { "cell_type": "code", "execution_count": 9, "id": "3d22031e", "metadata": {}, "outputs": [], "source": [ "from sklearn.tree import DecisionTreeClassifier" ] }, { "cell_type": "code", "execution_count": 10, "id": "b86e6146", "metadata": {}, "outputs": [], "source": [ "model = DecisionTreeClassifier()\n", "model.fit(X_train,y_gender_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 11, "id": "c658bb49", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.8603625673689368\n", "\n", "0.7961664997854384\n", "\n", "[[1176 721]\n", " [ 704 4390]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.63 0.62 0.62 1897\n", " 1 0.86 0.86 0.86 5094\n", "\n", " accuracy 0.80 6991\n", " macro avg 0.74 0.74 0.74 6991\n", "weighted avg 0.80 0.80 0.80 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_gender_test))\n", "print()\n", "print(accuracy_score(pred,y_gender_test))\n", "print()\n", "print(confusion_matrix(pred,y_gender_test))\n", "print()\n", "print(classification_report(pred,y_gender_test))" ] }, { "cell_type": "markdown", "id": "2a2c39ae", "metadata": {}, "source": [ "##### Bagging Decision Tree" ] }, { "cell_type": "code", "execution_count": 12, "id": "899a56eb", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import BaggingClassifier\n", "from sklearn.tree import DecisionTreeClassifier" ] }, { "cell_type": "code", "execution_count": 13, "id": "db548555", "metadata": {}, "outputs": [], "source": [ "model = BaggingClassifier(DecisionTreeClassifier(),max_samples=0.5,max_features=1.0,n_estimators=10)\n", "model.fit(X_train, y_gender_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 197, "id": "b36a613c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9014679075122326\n", "\n", "0.8530968387927336\n", "\n", "[[1266 413]\n", " [ 614 4698]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.67 0.75 0.71 1679\n", " 1 0.92 0.88 0.90 5312\n", "\n", " accuracy 0.85 6991\n", " macro avg 0.80 0.82 0.81 6991\n", "weighted avg 0.86 0.85 0.86 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_gender_test))\n", "print()\n", "print(accuracy_score(pred,y_gender_test))\n", "print()\n", "print(confusion_matrix(pred,y_gender_test))\n", "print()\n", "print(classification_report(pred,y_gender_test))" ] }, { "cell_type": "markdown", "id": "7b1086e6", "metadata": {}, "source": [ "##### Boosting Decision Tree" ] }, { "cell_type": "code", "execution_count": 15, "id": "d4382dcb", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import AdaBoostClassifier\n", "from sklearn.tree import DecisionTreeClassifier" ] }, { "cell_type": "code", "execution_count": 16, "id": "6bac0b1c", "metadata": {}, "outputs": [], "source": [ "model = AdaBoostClassifier(DecisionTreeClassifier(min_samples_split=10,max_depth=4),n_estimators=10,learning_rate=0.6)\n", "model.fit(X_train, y_gender_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 17, "id": "b4c24ab0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.89351376574895\n", "\n", "0.8367901587755686\n", "\n", "[[1063 324]\n", " [ 817 4787]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.57 0.77 0.65 1387\n", " 1 0.94 0.85 0.89 5604\n", "\n", " accuracy 0.84 6991\n", " macro avg 0.75 0.81 0.77 6991\n", "weighted avg 0.86 0.84 0.85 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_gender_test))\n", "print()\n", "print(accuracy_score(pred,y_gender_test))\n", "print()\n", "print(confusion_matrix(pred,y_gender_test))\n", "print()\n", "print(classification_report(pred,y_gender_test))" ] }, { "cell_type": "markdown", "id": "69d13096", "metadata": {}, "source": [ "##### Random Forest Classifier" ] }, { "cell_type": "code", "execution_count": 18, "id": "76dbae17", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier" ] }, { "cell_type": "code", "execution_count": 19, "id": "ca837583", "metadata": {}, "outputs": [], "source": [ "model = RandomForestClassifier()\n", "model.fit(X_train,y_gender_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 203, "id": "7f914a52", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9218778994247543\n", "\n", "0.8795594335574309\n", "\n", "[[1181 143]\n", " [ 699 4968]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.63 0.89 0.74 1324\n", " 1 0.97 0.88 0.92 5667\n", "\n", " accuracy 0.88 6991\n", " macro avg 0.80 0.88 0.83 6991\n", "weighted avg 0.91 0.88 0.89 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_gender_test))\n", "print()\n", "print(accuracy_score(pred,y_gender_test))\n", "print()\n", "print(confusion_matrix(pred,y_gender_test))\n", "print()\n", "print(classification_report(pred,y_gender_test))" ] }, { "cell_type": "markdown", "id": "fb222ca8", "metadata": {}, "source": [ "##### Neural Network " ] }, { "cell_type": "code", "execution_count": 5, "id": "351ddbd2", "metadata": {}, "outputs": [], "source": [ "from keras import layers\n", "from keras import models\n", "from keras import optimizers\n", "from keras import losses\n", "from keras import metrics\n", "from tensorflow.keras import regularizers\n", "import tensorflow as tf\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Dense,Dropout\n", "from tensorflow.keras.callbacks import EarlyStopping" ] }, { "cell_type": "code", "execution_count": 118, "id": "32f861bd", "metadata": {}, "outputs": [], "source": [ "# CODE HERE\n", "model = Sequential()\n", "\n", "model.add(Dense(30,activation='relu'))\n", "#model.add(Dropout(0.2))\n", "#model.add(Dense(16,activation='relu',kernel_regularizer = regularizers.l2(0.01)))\n", "#model.add(Dropout(0.2))\n", "#model.add(Dense(7,activation='relu'))\n", "#model.add(Dropout(0.2))\n", "#model.add(Dense(18,activation='relu'))\n", "#model.add(Dropout(0.2))\n", "#model.add(Dense(9,activation='relu'))\n", "#model.add(Dropout(0.2))\n", "#model.add(Dense(4,activation='relu'))\n", "#model.add(Dropout(0.2))\n", "#model.add(Dense(2,activation='relu'))\n", "#model.add(Dropout(0.2))\n", "\n", "model.add(Dense(1,activation='sigmoid'))" ] }, { "cell_type": "code", "execution_count": 119, "id": "380ed644", "metadata": {}, "outputs": [], "source": [ "model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 120, "id": "625a9efa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/90\n", "246/246 [==============================] - 3s 6ms/step - loss: 0.4370 - accuracy: 0.8056 - val_loss: 0.3802 - val_accuracy: 0.8431\n", "Epoch 2/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.3585 - accuracy: 0.8514 - val_loss: 0.3550 - val_accuracy: 0.8517\n", "Epoch 3/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.3408 - accuracy: 0.8582 - val_loss: 0.3432 - val_accuracy: 0.8582\n", "Epoch 4/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.3298 - accuracy: 0.8633 - val_loss: 0.3345 - val_accuracy: 0.8600\n", "Epoch 5/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.3218 - accuracy: 0.8672 - val_loss: 0.3273 - val_accuracy: 0.8640\n", "Epoch 6/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.3159 - accuracy: 0.8699 - val_loss: 0.3221 - val_accuracy: 0.8667\n", "Epoch 7/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.3111 - accuracy: 0.8722 - val_loss: 0.3189 - val_accuracy: 0.8654\n", "Epoch 8/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.3075 - accuracy: 0.8732 - val_loss: 0.3158 - val_accuracy: 0.8705\n", "Epoch 9/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.3045 - accuracy: 0.8740 - val_loss: 0.3122 - val_accuracy: 0.8701\n", "Epoch 10/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.3015 - accuracy: 0.8756 - val_loss: 0.3102 - val_accuracy: 0.8697\n", "Epoch 11/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2991 - accuracy: 0.8765 - val_loss: 0.3094 - val_accuracy: 0.8697\n", "Epoch 12/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2968 - accuracy: 0.8778 - val_loss: 0.3080 - val_accuracy: 0.8730\n", "Epoch 13/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2948 - accuracy: 0.8782 - val_loss: 0.3050 - val_accuracy: 0.8727\n", "Epoch 14/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2929 - accuracy: 0.8791 - val_loss: 0.3051 - val_accuracy: 0.8718\n", "Epoch 15/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2909 - accuracy: 0.8797 - val_loss: 0.3021 - val_accuracy: 0.8757\n", "Epoch 16/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2894 - accuracy: 0.8803 - val_loss: 0.3025 - val_accuracy: 0.8728\n", "Epoch 17/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2880 - accuracy: 0.8813 - val_loss: 0.3001 - val_accuracy: 0.8751\n", "Epoch 18/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2865 - accuracy: 0.8818 - val_loss: 0.2988 - val_accuracy: 0.8750\n", "Epoch 19/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2852 - accuracy: 0.8826 - val_loss: 0.2998 - val_accuracy: 0.8753\n", "Epoch 20/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2841 - accuracy: 0.8829 - val_loss: 0.2970 - val_accuracy: 0.8763\n", "Epoch 21/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2828 - accuracy: 0.8840 - val_loss: 0.2972 - val_accuracy: 0.8761\n", "Epoch 22/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2819 - accuracy: 0.8843 - val_loss: 0.2956 - val_accuracy: 0.8773\n", "Epoch 23/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2807 - accuracy: 0.8845 - val_loss: 0.2952 - val_accuracy: 0.8757\n", "Epoch 24/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2797 - accuracy: 0.8849 - val_loss: 0.2944 - val_accuracy: 0.8786\n", "Epoch 25/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2790 - accuracy: 0.8851 - val_loss: 0.2937 - val_accuracy: 0.8767\n", "Epoch 26/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2780 - accuracy: 0.8858 - val_loss: 0.2941 - val_accuracy: 0.8774\n", "Epoch 27/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2773 - accuracy: 0.8853 - val_loss: 0.2921 - val_accuracy: 0.8780\n", "Epoch 28/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2763 - accuracy: 0.8859 - val_loss: 0.2971 - val_accuracy: 0.8756\n", "Epoch 29/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2758 - accuracy: 0.8859 - val_loss: 0.2935 - val_accuracy: 0.8787\n", "Epoch 30/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2749 - accuracy: 0.8866 - val_loss: 0.2925 - val_accuracy: 0.8788\n", "Epoch 31/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2742 - accuracy: 0.8877 - val_loss: 0.2923 - val_accuracy: 0.8784\n", "Epoch 32/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2735 - accuracy: 0.8874 - val_loss: 0.2907 - val_accuracy: 0.8776\n", "Epoch 33/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2727 - accuracy: 0.8874 - val_loss: 0.2911 - val_accuracy: 0.8784\n", "Epoch 34/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2721 - accuracy: 0.8878 - val_loss: 0.2900 - val_accuracy: 0.8804\n", "Epoch 35/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2715 - accuracy: 0.8878 - val_loss: 0.2906 - val_accuracy: 0.8797\n", "Epoch 36/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2709 - accuracy: 0.8881 - val_loss: 0.2891 - val_accuracy: 0.8793\n", "Epoch 37/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2703 - accuracy: 0.8885 - val_loss: 0.2895 - val_accuracy: 0.8788\n", "Epoch 38/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2698 - accuracy: 0.8885 - val_loss: 0.2883 - val_accuracy: 0.8808\n", "Epoch 39/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2690 - accuracy: 0.8894 - val_loss: 0.2879 - val_accuracy: 0.8814\n", "Epoch 40/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2684 - accuracy: 0.8890 - val_loss: 0.2871 - val_accuracy: 0.8810\n", "Epoch 41/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2682 - accuracy: 0.8893 - val_loss: 0.2861 - val_accuracy: 0.8808\n", "Epoch 42/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2674 - accuracy: 0.8907 - val_loss: 0.2879 - val_accuracy: 0.8807\n", "Epoch 43/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2673 - accuracy: 0.8894 - val_loss: 0.2862 - val_accuracy: 0.8811\n", "Epoch 44/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2665 - accuracy: 0.8901 - val_loss: 0.2860 - val_accuracy: 0.8830\n", "Epoch 45/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2663 - accuracy: 0.8900 - val_loss: 0.2865 - val_accuracy: 0.8813\n", "Epoch 46/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2658 - accuracy: 0.8905 - val_loss: 0.2865 - val_accuracy: 0.8823\n", "Epoch 47/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2654 - accuracy: 0.8906 - val_loss: 0.2859 - val_accuracy: 0.8811\n", "Epoch 48/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2650 - accuracy: 0.8906 - val_loss: 0.2847 - val_accuracy: 0.8839\n", "Epoch 49/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2646 - accuracy: 0.8905 - val_loss: 0.2847 - val_accuracy: 0.8823\n", "Epoch 50/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2643 - accuracy: 0.8915 - val_loss: 0.2849 - val_accuracy: 0.8831\n", "Epoch 51/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2639 - accuracy: 0.8907 - val_loss: 0.2856 - val_accuracy: 0.8816\n", "Epoch 52/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2635 - accuracy: 0.8910 - val_loss: 0.2844 - val_accuracy: 0.8841\n", "Epoch 53/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2630 - accuracy: 0.8919 - val_loss: 0.2837 - val_accuracy: 0.8834\n", "Epoch 54/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2629 - accuracy: 0.8908 - val_loss: 0.2846 - val_accuracy: 0.8841\n", "Epoch 55/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2624 - accuracy: 0.8915 - val_loss: 0.2838 - val_accuracy: 0.8830\n", "Epoch 56/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2621 - accuracy: 0.8919 - val_loss: 0.2829 - val_accuracy: 0.8839\n", "Epoch 57/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2618 - accuracy: 0.8919 - val_loss: 0.2825 - val_accuracy: 0.8833\n", "Epoch 58/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2615 - accuracy: 0.8920 - val_loss: 0.2854 - val_accuracy: 0.8818\n", "Epoch 59/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2614 - accuracy: 0.8929 - val_loss: 0.2833 - val_accuracy: 0.8839\n", "Epoch 60/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2611 - accuracy: 0.8924 - val_loss: 0.2822 - val_accuracy: 0.8817\n", "Epoch 61/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2608 - accuracy: 0.8919 - val_loss: 0.2816 - val_accuracy: 0.8823\n", "Epoch 62/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2604 - accuracy: 0.8925 - val_loss: 0.2827 - val_accuracy: 0.8841\n", "Epoch 63/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2600 - accuracy: 0.8921 - val_loss: 0.2830 - val_accuracy: 0.8860\n", "Epoch 64/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2598 - accuracy: 0.8930 - val_loss: 0.2824 - val_accuracy: 0.8828\n", "Epoch 65/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2595 - accuracy: 0.8929 - val_loss: 0.2816 - val_accuracy: 0.8846\n", "Epoch 66/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2594 - accuracy: 0.8937 - val_loss: 0.2824 - val_accuracy: 0.8810\n", "Epoch 67/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2590 - accuracy: 0.8935 - val_loss: 0.2824 - val_accuracy: 0.8839\n", "Epoch 68/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2588 - accuracy: 0.8940 - val_loss: 0.2836 - val_accuracy: 0.8830\n", "Epoch 69/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2586 - accuracy: 0.8936 - val_loss: 0.2815 - val_accuracy: 0.8836\n", "Epoch 70/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2585 - accuracy: 0.8940 - val_loss: 0.2805 - val_accuracy: 0.8861\n", "Epoch 71/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2581 - accuracy: 0.8937 - val_loss: 0.2834 - val_accuracy: 0.8814\n", "Epoch 72/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2579 - accuracy: 0.8946 - val_loss: 0.2843 - val_accuracy: 0.8826\n", "Epoch 73/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2578 - accuracy: 0.8937 - val_loss: 0.2817 - val_accuracy: 0.8840\n", "Epoch 74/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2575 - accuracy: 0.8941 - val_loss: 0.2813 - val_accuracy: 0.8840\n", "Epoch 75/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2573 - accuracy: 0.8941 - val_loss: 0.2814 - val_accuracy: 0.8843\n", "Epoch 76/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2572 - accuracy: 0.8937 - val_loss: 0.2828 - val_accuracy: 0.8836\n", "Epoch 77/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2571 - accuracy: 0.8938 - val_loss: 0.2822 - val_accuracy: 0.8849\n", "Epoch 78/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2568 - accuracy: 0.8943 - val_loss: 0.2822 - val_accuracy: 0.8820\n", "Epoch 79/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2566 - accuracy: 0.8948 - val_loss: 0.2826 - val_accuracy: 0.8831\n", "Epoch 80/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2564 - accuracy: 0.8949 - val_loss: 0.2821 - val_accuracy: 0.8830\n", "Epoch 81/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2565 - accuracy: 0.8946 - val_loss: 0.2810 - val_accuracy: 0.8856\n", "Epoch 82/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2561 - accuracy: 0.8949 - val_loss: 0.2823 - val_accuracy: 0.8833\n", "Epoch 83/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2562 - accuracy: 0.8944 - val_loss: 0.2811 - val_accuracy: 0.8836\n", "Epoch 84/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2558 - accuracy: 0.8947 - val_loss: 0.2800 - val_accuracy: 0.8866\n", "Epoch 85/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2559 - accuracy: 0.8944 - val_loss: 0.2806 - val_accuracy: 0.8834\n", "Epoch 86/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2556 - accuracy: 0.8950 - val_loss: 0.2822 - val_accuracy: 0.8854\n", "Epoch 87/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2552 - accuracy: 0.8955 - val_loss: 0.2820 - val_accuracy: 0.8837\n", "Epoch 88/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2553 - accuracy: 0.8949 - val_loss: 0.2795 - val_accuracy: 0.8857\n", "Epoch 89/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2552 - accuracy: 0.8955 - val_loss: 0.2814 - val_accuracy: 0.8837\n", "Epoch 90/90\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2548 - accuracy: 0.8952 - val_loss: 0.2796 - val_accuracy: 0.8849\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 120, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(X_train,y_gender_train,epochs=90,batch_size=256,validation_data=(X_test,y_gender_test))" ] }, { "cell_type": "code", "execution_count": 121, "id": "b9f53792", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_15\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " dense_35 (Dense) (None, 30) 930 \n", " \n", " dense_36 (Dense) (None, 1) 31 \n", " \n", "=================================================================\n", "Total params: 961\n", "Trainable params: 961\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "print(model.summary())" ] }, { "cell_type": "code", "execution_count": 122, "id": "26f06127", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 122, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# CODE HERE\n", "losses = pd.DataFrame(model.history.history)\n", "losses[['val_loss','loss']].plot()" ] }, { "cell_type": "code", "execution_count": 123, "id": "50dfa1e8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "219/219 [==============================] - 1s 3ms/step\n" ] } ], "source": [ "pred = (model.predict(X_test) > 0.5).astype(\"int32\")" ] }, { "cell_type": "code", "execution_count": 124, "id": "29d417e3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.9228705566733736\n", "\n", "0.8848519525103705\n", "\n", "[[1370 295]\n", " [ 510 4816]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.73 0.82 0.77 1665\n", " 1 0.94 0.90 0.92 5326\n", "\n", " accuracy 0.88 6991\n", " macro avg 0.84 0.86 0.85 6991\n", "weighted avg 0.89 0.88 0.89 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_gender_test))\n", "print()\n", "print(accuracy_score(pred,y_gender_test))\n", "print()\n", "print(confusion_matrix(pred,y_gender_test))\n", "print()\n", "print(classification_report(pred,y_gender_test))" ] }, { "cell_type": "markdown", "id": "96f88b41", "metadata": {}, "source": [ "### Age Category Classification" ] }, { "cell_type": "code", "execution_count": 28, "id": "2232d328", "metadata": {}, "outputs": [], "source": [ "y_age_train = y_train['age_label']\n", "y_age_test = y_test['age_label']" ] }, { "cell_type": "markdown", "id": "5e60cf3c", "metadata": {}, "source": [ "##### Logistic Regression" ] }, { "cell_type": "code", "execution_count": 103, "id": "eeec6244", "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression" ] }, { "cell_type": "code", "execution_count": 104, "id": "80fecbf6", "metadata": {}, "outputs": [], "source": [ "model = LogisticRegression(max_iter=10000000000000)\n", "model.fit(X_train,y_age_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 105, "id": "033034f2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.37658227933088945\n", "\n", "0.32770705192390215\n", "\n", "[[ 0 1 1 0 0 0 0 1]\n", " [ 1 87 60 5 48 21 43 70]\n", " [ 1 73 134 14 22 24 134 96]\n", " [ 0 1 0 3 0 2 0 2]\n", " [ 0 14 15 2 25 4 12 12]\n", " [ 0 5 7 0 0 28 9 4]\n", " [ 8 386 319 76 194 121 805 566]\n", " [ 17 429 517 83 198 295 787 1209]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.00 0.00 0.00 3\n", " 1 0.09 0.26 0.13 335\n", " 2 0.13 0.27 0.17 498\n", " 3 0.02 0.38 0.03 8\n", " 4 0.05 0.30 0.09 84\n", " 5 0.06 0.53 0.10 53\n", " 6 0.45 0.33 0.38 2475\n", " 7 0.62 0.34 0.44 3535\n", "\n", " accuracy 0.33 6991\n", " macro avg 0.18 0.30 0.17 6991\n", "weighted avg 0.49 0.33 0.38 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_age_test,average='weighted'))\n", "print()\n", "print(accuracy_score(pred,y_age_test))\n", "print()\n", "print(confusion_matrix(pred,y_age_test))\n", "print()\n", "print(classification_report(pred,y_age_test))" ] }, { "cell_type": "markdown", "id": "b5f28d37", "metadata": {}, "source": [ "##### KNN" ] }, { "cell_type": "code", "execution_count": 29, "id": "9bd27414", "metadata": {}, "outputs": [], "source": [ "from sklearn.neighbors import KNeighborsClassifier" ] }, { "cell_type": "code", "execution_count": 31, "id": "6f01ba8e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Dell\\anaconda3\\Lib\\site-packages\\sklearn\\base.py:464: UserWarning: X does not have valid feature names, but KNeighborsClassifier was fitted with feature names\n", " warnings.warn(\n" ] } ], "source": [ "model = KNeighborsClassifier()\n", "model.fit(X_train,y_age_train)\n", "pred = model.predict(X_test.values)" ] }, { "cell_type": "code", "execution_count": 32, "id": "cd6f8cef", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.7962944851183236\n", "\n", "0.795737376627092\n", "\n", "[[ 18 2 0 0 2 1 3 5]\n", " [ 0 852 41 7 18 33 79 94]\n", " [ 2 19 861 3 10 24 77 96]\n", " [ 0 1 3 154 4 4 14 10]\n", " [ 0 15 21 2 413 12 29 43]\n", " [ 0 5 2 1 2 325 21 21]\n", " [ 6 60 64 8 28 48 1478 229]\n", " [ 1 42 61 8 10 48 89 1462]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.67 0.58 0.62 31\n", " 1 0.86 0.76 0.80 1124\n", " 2 0.82 0.79 0.80 1092\n", " 3 0.84 0.81 0.83 190\n", " 4 0.85 0.77 0.81 535\n", " 5 0.66 0.86 0.75 377\n", " 6 0.83 0.77 0.80 1921\n", " 7 0.75 0.85 0.79 1721\n", "\n", " accuracy 0.80 6991\n", " macro avg 0.78 0.77 0.77 6991\n", "weighted avg 0.80 0.80 0.80 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_age_test,average='weighted'))\n", "print()\n", "print(accuracy_score(pred,y_age_test))\n", "print()\n", "print(confusion_matrix(pred,y_age_test))\n", "print()\n", "print(classification_report(pred,y_age_test))" ] }, { "cell_type": "code", "execution_count": 110, "id": "3ca50f81", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['KNN_age_model.pkl']" ] }, "execution_count": 110, "metadata": {}, "output_type": "execute_result" } ], "source": [ "joblib.dump(model, 'KNN_age_model')" ] }, { "cell_type": "markdown", "id": "fe64a202", "metadata": {}, "source": [ "##### SVM" ] }, { "cell_type": "code", "execution_count": 111, "id": "834b2032", "metadata": {}, "outputs": [], "source": [ "from sklearn.svm import SVC" ] }, { "cell_type": "code", "execution_count": null, "id": "0ff44c82", "metadata": {}, "outputs": [], "source": [ "model = SVC(C=1)\n", "model.fit(X_train,y_age_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 152, "id": "4d7abfac", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.7614100194796487\n", "\n", "0.7608353597482478\n", "\n", "[[ 20 0 1 0 1 0 1 2]\n", " [ 1 798 25 6 16 18 68 92]\n", " [ 1 22 790 5 12 32 55 95]\n", " [ 0 1 2 135 1 1 6 9]\n", " [ 0 3 11 1 373 5 20 26]\n", " [ 0 8 10 1 8 289 22 32]\n", " [ 4 79 98 11 37 62 1413 203]\n", " [ 1 85 116 24 39 88 205 1501]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.74 0.80 0.77 25\n", " 1 0.80 0.78 0.79 1024\n", " 2 0.75 0.78 0.77 1012\n", " 3 0.74 0.87 0.80 155\n", " 4 0.77 0.85 0.81 439\n", " 5 0.58 0.78 0.67 370\n", " 6 0.79 0.74 0.76 1907\n", " 7 0.77 0.73 0.75 2059\n", "\n", " accuracy 0.76 6991\n", " macro avg 0.74 0.79 0.76 6991\n", "weighted avg 0.76 0.76 0.76 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_age_test,average='weighted'))\n", "print()\n", "print(accuracy_score(pred,y_age_test))\n", "print()\n", "print(confusion_matrix(pred,y_age_test))\n", "print()\n", "print(classification_report(pred,y_age_test))" ] }, { "cell_type": "markdown", "id": "99ed2d9f", "metadata": {}, "source": [ "##### Decision Tree" ] }, { "cell_type": "code", "execution_count": 4, "id": "fa87ec77", "metadata": {}, "outputs": [], "source": [ "from sklearn.tree import DecisionTreeClassifier" ] }, { "cell_type": "code", "execution_count": 5, "id": "f80d5f38", "metadata": {}, "outputs": [], "source": [ "model = DecisionTreeClassifier()\n", "model.fit(X_train,y_age_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 8, "id": "59fd3ee0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.40225615103109547\n", "\n", "0.40208839937061935\n", "\n", "[[ 6 0 1 1 0 1 4 7]\n", " [ 3 409 95 28 48 40 204 195]\n", " [ 4 107 399 23 52 61 175 216]\n", " [ 0 20 15 37 4 6 50 42]\n", " [ 2 46 59 11 201 19 85 104]\n", " [ 4 50 49 11 22 126 106 126]\n", " [ 4 175 204 37 83 96 758 395]\n", " [ 4 189 231 35 77 146 408 875]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.22 0.30 0.26 20\n", " 1 0.41 0.40 0.41 1022\n", " 2 0.38 0.38 0.38 1037\n", " 3 0.20 0.21 0.21 174\n", " 4 0.41 0.38 0.40 527\n", " 5 0.25 0.26 0.25 494\n", " 6 0.42 0.43 0.43 1752\n", " 7 0.45 0.45 0.45 1965\n", "\n", " accuracy 0.40 6991\n", " macro avg 0.34 0.35 0.35 6991\n", "weighted avg 0.40 0.40 0.40 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_age_test,average='weighted'))\n", "print()\n", "print(accuracy_score(pred,y_age_test))\n", "print()\n", "print(confusion_matrix(pred,y_age_test))\n", "print()\n", "print(classification_report(pred,y_age_test))" ] }, { "cell_type": "markdown", "id": "ca09c701", "metadata": {}, "source": [ "##### Bagging Decision Tree" ] }, { "cell_type": "code", "execution_count": 9, "id": "d16e092c", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import BaggingClassifier\n", "from sklearn.tree import DecisionTreeClassifier" ] }, { "cell_type": "code", "execution_count": 10, "id": "0450e792", "metadata": {}, "outputs": [], "source": [ "model = BaggingClassifier(DecisionTreeClassifier(),max_samples=0.5,max_features=1.0,n_estimators=10)\n", "model.fit(X_train, y_age_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 11, "id": "ff37709a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.5177393450250413\n", "\n", "0.512516092118438\n", "\n", "[[ 3 0 0 0 0 0 0 0]\n", " [ 0 553 77 19 48 52 157 169]\n", " [ 6 87 525 23 43 54 180 179]\n", " [ 0 0 6 57 1 1 4 3]\n", " [ 1 15 16 3 197 9 22 30]\n", " [ 2 13 12 2 6 122 29 39]\n", " [ 7 164 196 41 97 114 1023 437]\n", " [ 8 164 221 38 95 143 375 1103]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.11 1.00 0.20 3\n", " 1 0.56 0.51 0.53 1075\n", " 2 0.50 0.48 0.49 1097\n", " 3 0.31 0.79 0.45 72\n", " 4 0.40 0.67 0.51 293\n", " 5 0.25 0.54 0.34 225\n", " 6 0.57 0.49 0.53 2079\n", " 7 0.56 0.51 0.54 2147\n", "\n", " accuracy 0.51 6991\n", " macro avg 0.41 0.63 0.45 6991\n", "weighted avg 0.53 0.51 0.52 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_age_test,average='weighted'))\n", "print()\n", "print(accuracy_score(pred,y_age_test))\n", "print()\n", "print(confusion_matrix(pred,y_age_test))\n", "print()\n", "print(classification_report(pred,y_age_test))" ] }, { "cell_type": "markdown", "id": "d137fdd4", "metadata": {}, "source": [ "##### Boosting Decision Tree" ] }, { "cell_type": "code", "execution_count": 12, "id": "0d0e06c8", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import AdaBoostClassifier\n", "from sklearn.tree import DecisionTreeClassifier" ] }, { "cell_type": "code", "execution_count": 13, "id": "3e7fae64", "metadata": {}, "outputs": [], "source": [ "model = AdaBoostClassifier(DecisionTreeClassifier(min_samples_split=10,max_depth=4),n_estimators=10,learning_rate=0.6)\n", "model.fit(X_train, y_age_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 14, "id": "ae285b01", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.395014876392853\n", "\n", "0.3428694035188099\n", "\n", "[[ 2 0 0 0 0 0 0 0]\n", " [ 1 131 41 2 26 8 43 43]\n", " [ 1 40 125 7 13 20 70 69]\n", " [ 0 0 2 2 0 1 0 1]\n", " [ 0 5 2 3 28 1 5 8]\n", " [ 0 6 7 0 1 19 6 9]\n", " [ 8 324 310 94 161 130 763 503]\n", " [ 15 490 566 75 258 316 903 1327]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.07 1.00 0.14 2\n", " 1 0.13 0.44 0.20 295\n", " 2 0.12 0.36 0.18 345\n", " 3 0.01 0.33 0.02 6\n", " 4 0.06 0.54 0.10 52\n", " 5 0.04 0.40 0.07 48\n", " 6 0.43 0.33 0.37 2293\n", " 7 0.68 0.34 0.45 3950\n", "\n", " accuracy 0.34 6991\n", " macro avg 0.19 0.47 0.19 6991\n", "weighted avg 0.53 0.34 0.40 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_age_test,average='weighted'))\n", "print()\n", "print(accuracy_score(pred,y_age_test))\n", "print()\n", "print(confusion_matrix(pred,y_age_test))\n", "print()\n", "print(classification_report(pred,y_age_test))" ] }, { "cell_type": "markdown", "id": "e9c2133a", "metadata": {}, "source": [ "##### Random Forest Classifier" ] }, { "cell_type": "code", "execution_count": 15, "id": "9d694f01", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier" ] }, { "cell_type": "code", "execution_count": 16, "id": "aa11f8fa", "metadata": {}, "outputs": [], "source": [ "model = RandomForestClassifier()\n", "model.fit(X_train,y_age_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 17, "id": "b389f76b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.6567027435146164\n", "\n", "0.6501215848948648\n", "\n", "[[ 6 0 0 0 0 0 0 1]\n", " [ 0 602 16 2 15 16 22 30]\n", " [ 1 20 587 8 10 11 39 50]\n", " [ 0 0 0 71 0 0 0 0]\n", " [ 0 1 4 1 251 0 0 3]\n", " [ 0 1 1 0 0 123 0 2]\n", " [ 12 155 181 48 106 110 1343 312]\n", " [ 8 217 264 53 105 235 386 1562]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.22 0.86 0.35 7\n", " 1 0.60 0.86 0.71 703\n", " 2 0.56 0.81 0.66 726\n", " 3 0.39 1.00 0.56 71\n", " 4 0.52 0.97 0.67 260\n", " 5 0.25 0.97 0.40 127\n", " 6 0.75 0.59 0.66 2267\n", " 7 0.80 0.55 0.65 2830\n", "\n", " accuracy 0.65 6991\n", " macro avg 0.51 0.83 0.58 6991\n", "weighted avg 0.71 0.65 0.66 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_age_test,average='weighted'))\n", "print()\n", "print(accuracy_score(pred,y_age_test))\n", "print()\n", "print(confusion_matrix(pred,y_age_test))\n", "print()\n", "print(classification_report(pred,y_age_test))" ] }, { "cell_type": "markdown", "id": "0b7aee24", "metadata": {}, "source": [ "##### XG Boost" ] }, { "cell_type": "code", "execution_count": 4, "id": "0db9fd04", "metadata": {}, "outputs": [], "source": [ "import xgboost as xgb" ] }, { "cell_type": "code", "execution_count": 5, "id": "9639ab56", "metadata": {}, "outputs": [], "source": [ "model = xgb.XGBClassifier()\n", "model.fit(X_train, y_age_train)\n", "pred = model.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 7, "id": "0276d893", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.5942879832278003\n", "\n", "0.5903304248319268\n", "\n", "[[ 10 0 0 0 0 0 0 1]\n", " [ 1 590 46 10 23 15 69 75]\n", " [ 4 44 511 6 21 24 88 109]\n", " [ 0 1 3 89 2 0 4 7]\n", " [ 0 12 17 2 242 4 15 14]\n", " [ 0 7 12 1 2 148 16 25]\n", " [ 5 149 198 44 115 107 1149 341]\n", " [ 7 193 266 31 82 197 449 1388]]\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.37 0.91 0.53 11\n", " 1 0.59 0.71 0.65 829\n", " 2 0.49 0.63 0.55 807\n", " 3 0.49 0.84 0.62 106\n", " 4 0.50 0.79 0.61 306\n", " 5 0.30 0.70 0.42 211\n", " 6 0.64 0.55 0.59 2108\n", " 7 0.71 0.53 0.61 2613\n", "\n", " accuracy 0.59 6991\n", " macro avg 0.51 0.71 0.57 6991\n", "weighted avg 0.62 0.59 0.59 6991\n", "\n" ] } ], "source": [ "print(f1_score(pred,y_age_test,average='weighted'))\n", "print()\n", "print(accuracy_score(pred,y_age_test))\n", "print()\n", "print(confusion_matrix(pred,y_age_test))\n", "print()\n", "print(classification_report(pred,y_age_test))" ] }, { "cell_type": "markdown", "id": "f4d90192", "metadata": {}, "source": [ "#### Neural Network" ] }, { "cell_type": "markdown", "id": "ebeda6af", "metadata": {}, "source": [ "##### Deep Learning" ] }, { "cell_type": "code", "execution_count": 5, "id": "2a026990", "metadata": {}, "outputs": [], "source": [ "y_age_train = y_train.drop(['male','age_label'],axis=1)\n", "y_age_test = y_test.drop(['male','age_label'],axis=1)" ] }, { "cell_type": "code", "execution_count": 4, "id": "7eab5467", "metadata": {}, "outputs": [], "source": [ "from keras import layers, models, optimizers, losses, metrics\n", "from tensorflow.keras import regularizers\n", "import tensorflow as tf\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Dense,Dropout, LSTM\n", "from tensorflow.keras.callbacks import EarlyStopping" ] }, { "cell_type": "code", "execution_count": 138, "id": "e1d2fe35", "metadata": {}, "outputs": [], "source": [ "# CODE HERE\n", "model = Sequential()\n", "\n", "model.add(InputLayer(input_shape=30))\n", "model.add(Dense(200,activation='tanh'))\n", "#model.add(Dropout(0.2))\n", "model.add(Dense(150,activation='tanh'))\n", "#model.add(Dropout(0.2))\n", "model.add(Dense(112,activation='tanh'))\n", "#model.add(Dropout(0.2))\n", "model.add(Dense(84,activation='tanh'))\n", "#model.add(Dropout(0.2))\n", "model.add(Dense(63,activation='tanh'))\n", "#model.add(Dropout(0.2))\n", "model.add(Dense(48,activation='tanh'))\n", "#model.add(Dropout(0.2))\n", "model.add(Dense(36,activation='tanh'))\n", "#model.add(Dropout(0.2))\n", "model.add(Dense(27,activation='relu'))\n", "#model.add(Dropout(0.2))\n", "model.add(Dense(21,activation='relu'))\n", "#model.add(Dropout(0.2))\n", "model.add(Dense(15,activation='relu'))\n", "#model.add(Dropout(0.2))\n", "#model.add(Dense(2,activation='relu'))\n", "#model.add(Dropout(0.2))\n", "\n", "model.add(Dense(7,activation='softmax'))" ] }, { "cell_type": "code", "execution_count": 139, "id": "665635aa", "metadata": {}, "outputs": [], "source": [ "earlystop = EarlyStopping(monitor='val_loss',patience=3)" ] }, { "cell_type": "code", "execution_count": 140, "id": "6036ca66", "metadata": {}, "outputs": [], "source": [ "losses = {'output_1':'binary_crossentropy'}" ] }, { "cell_type": "code", "execution_count": 141, "id": "a037084b", "metadata": {}, "outputs": [], "source": [ "model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])" ] }, { "cell_type": "code", "execution_count": 142, "id": "4d654802", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/190\n", "246/246 [==============================] - 2s 5ms/step - loss: 0.4140 - accuracy: 0.2937 - val_loss: 0.3690 - val_accuracy: 0.3084\n", "Epoch 2/190\n", "246/246 [==============================] - 1s 6ms/step - loss: 0.3592 - accuracy: 0.3581 - val_loss: 0.3444 - val_accuracy: 0.4048\n", "Epoch 3/190\n", "246/246 [==============================] - 1s 4ms/step - loss: 0.3289 - accuracy: 0.4509 - val_loss: 0.3136 - val_accuracy: 0.4765\n", "Epoch 4/190\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2974 - accuracy: 0.5149 - val_loss: 0.2925 - val_accuracy: 0.5207\n", "Epoch 5/190\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2715 - accuracy: 0.5697 - val_loss: 0.2727 - val_accuracy: 0.5726\n", "Epoch 6/190\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2508 - accuracy: 0.6116 - val_loss: 0.2630 - val_accuracy: 0.5870\n", "Epoch 7/190\n", "246/246 [==============================] - 1s 4ms/step - loss: 0.2356 - accuracy: 0.6410 - val_loss: 0.2547 - val_accuracy: 0.6068\n", "Epoch 8/190\n", "246/246 [==============================] - 1s 4ms/step - loss: 0.2222 - accuracy: 0.6623 - val_loss: 0.2529 - val_accuracy: 0.6085\n", "Epoch 9/190\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.2118 - accuracy: 0.6803 - val_loss: 0.2430 - val_accuracy: 0.6261\n", "Epoch 10/190\n", "246/246 [==============================] - 1s 6ms/step - loss: 0.2024 - accuracy: 0.6962 - val_loss: 0.2380 - val_accuracy: 0.6370\n", "Epoch 11/190\n", "246/246 [==============================] - 1s 4ms/step - loss: 0.1943 - accuracy: 0.7084 - val_loss: 0.2382 - val_accuracy: 0.6385\n", "Epoch 12/190\n", "246/246 [==============================] - 1s 5ms/step - loss: 0.1876 - accuracy: 0.7177 - val_loss: 0.2321 - val_accuracy: 0.6498\n", "Epoch 13/190\n", "246/246 [==============================] - 1s 4ms/step - loss: 0.1811 - accuracy: 0.7301 - val_loss: 0.2368 - val_accuracy: 0.6488\n", "Epoch 14/190\n", "246/246 [==============================] - 1s 4ms/step - loss: 0.1754 - accuracy: 0.7380 - val_loss: 0.2340 - val_accuracy: 0.6495\n", "Epoch 15/190\n", "246/246 [==============================] - 1s 4ms/step - loss: 0.1701 - accuracy: 0.7464 - val_loss: 0.2360 - val_accuracy: 0.6514\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 142, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(X_train,y_age_train,epochs=190,batch_size=256,validation_data=(X_test,y_age_test),callbacks=earlystop)" ] }, { "cell_type": "code", "execution_count": 143, "id": "4dcd8aaf", "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential_17\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " dense_251 (Dense) (None, 200) 6200 \n", " \n", " dense_252 (Dense) (None, 150) 30150 \n", " \n", " dense_253 (Dense) (None, 112) 16912 \n", " \n", " dense_254 (Dense) (None, 84) 9492 \n", " \n", " dense_255 (Dense) (None, 63) 5355 \n", " \n", " dense_256 (Dense) (None, 48) 3072 \n", " \n", " dense_257 (Dense) (None, 36) 1764 \n", " \n", " dense_258 (Dense) (None, 27) 999 \n", " \n", " dense_259 (Dense) (None, 21) 588 \n", " \n", " dense_260 (Dense) (None, 15) 330 \n", " \n", " dense_261 (Dense) (None, 7) 112 \n", " \n", "=================================================================\n", "Total params: 74,974\n", "Trainable params: 74,974\n", "Non-trainable params: 0\n", "_________________________________________________________________\n", "None\n" ] } ], "source": [ "print(model.summary())" ] }, { "cell_type": "code", "execution_count": 144, "id": "ba65e49f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 144, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# CODE HERE\n", "losses = pd.DataFrame(model.history.history)\n", "losses[['val_loss','loss']].plot()" ] }, { "cell_type": "code", "execution_count": 145, "id": "645f48ce", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "219/219 [==============================] - 1s 3ms/step\n" ] } ], "source": [ "pred = (model.predict(X_test) > 0.5).astype(\"int32\")" ] }, { "cell_type": "code", "execution_count": 146, "id": "09812cc9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.6595140112999031\n", "\n", "0.6076383922185667\n", "\n", "\n", " precision recall f1-score support\n", "\n", " 0 0.65 0.79 0.71 821\n", " 1 0.58 0.75 0.66 810\n", " 2 0.47 0.80 0.59 107\n", " 3 0.62 0.75 0.68 403\n", " 4 0.35 0.73 0.48 237\n", " 5 0.60 0.71 0.65 1514\n", " 6 0.69 0.65 0.67 2075\n", "\n", " micro avg 0.61 0.71 0.66 5967\n", " macro avg 0.57 0.74 0.63 5967\n", "weighted avg 0.62 0.71 0.66 5967\n", " samples avg 0.61 0.61 0.61 5967\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\Dell\\anaconda3\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1248: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in samples with no predicted labels. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n", "C:\\Users\\Dell\\anaconda3\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1248: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in samples with no true labels. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n" ] } ], "source": [ "print(f1_score(pred,y_age_test,average='weighted'))\n", "print()\n", "print(accuracy_score(pred,y_age_test))\n", "print()\n", "#print(confusion_matrix(pred,y_age_test))\n", "print()\n", "print(classification_report(pred,y_age_test))" ] }, { "cell_type": "markdown", "id": "877cde03", "metadata": {}, "source": [ "##### Real Time Model Testing" ] }, { "cell_type": "code", "execution_count": 11, "id": "aa6da0fb", "metadata": {}, "outputs": [], "source": [ "import joblib" ] }, { "cell_type": "code", "execution_count": 12, "id": "64f732d3", "metadata": {}, "outputs": [], "source": [ "## Voice Data Feature Extraction\n", "\n", "### extract the features from the audio files using mfcc\n", "def feature_extracter(fileName):\n", " audio,sample_rate = librosa.load(fileName,res_type='kaiser_fast')\n", " mfcc_features = librosa.feature.mfcc(y=audio,sr=sample_rate,n_mfcc=30)\n", " mfccs_scaled_features = np.mean(mfcc_features.T, axis=0)\n", " \n", " return list(mfccs_scaled_features)" ] }, { "cell_type": "code", "execution_count": 37, "id": "81ff0caa", "metadata": {}, "outputs": [], "source": [ "fileName = 'C:\\\\Users\\\\Dell\\\\Downloads\\\\narendra-modi-walking-out-of-karan-thapar-interview-tuberippercom_R89R3fPq.wav'\n", "col_name = ['Feature_1', 'Feature_2', 'Feature_3', 'Feature_4', 'Feature_5','Feature_6', 'Feature_7', 'Feature_8', 'Feature_9', 'Feature_10','Feature_11', 'Feature_12', 'Feature_13', 'Feature_14', 'Feature_15','Feature_16', 'Feature_17', 'Feature_18', 'Feature_19', 'Feature_20','Feature_21', 'Feature_22', 'Feature_23', 'Feature_24', 'Feature_25','Feature_26', 'Feature_27', 'Feature_28', 'Feature_29', 'Feature_30']" ] }, { "cell_type": "code", "execution_count": 38, "id": "0c037e5f", "metadata": {}, "outputs": [], "source": [ "observation = [feature_extracter(fileName)]\n", "observation = pd.DataFrame(observation, columns = col_name)" ] }, { "cell_type": "code", "execution_count": 39, "id": "b4b0ba94", "metadata": {}, "outputs": [], "source": [ "## scaling the observation\n", "scaler = joblib.load('scaler.pkl')\n", "scaled_observation = scaler.transform(observation)\n", "scaled_observation = pd.DataFrame(scaled_observation, columns = col_name)" ] }, { "cell_type": "code", "execution_count": 40, "id": "38b86647", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Male\n" ] } ], "source": [ "### Gender classification model\n", "gender_model = joblib.load('KNN_gender_detection.pkl')\n", "gender_predict = gender_model.predict(scaled_observation)\n", "## considering the labels 1 = male 0 = female\n", "if gender_predict[0] == 1:\n", " print('Male')\n", "else:\n", " print('Female')" ] }, { "cell_type": "code", "execution_count": 41, "id": "5fc843c1", "metadata": {}, "outputs": [], "source": [ "def age_reverse_labelling(label):\n", " if label == 0:\n", " return 'Eighties'\n", " if label == 1:\n", " return 'Fifties'\n", " if label == 2:\n", " return 'Fourties'\n", " if label == 3:\n", " return 'Seventies'\n", " if label == 4:\n", " return 'Sixties'\n", " if label == 5:\n", " return 'Teens'\n", " if label == 6:\n", " return 'Thirties'\n", " if label == 7:\n", " return 'Twenties'\n", " else:\n", " return 'Cannot be predicted'" ] }, { "cell_type": "code", "execution_count": 42, "id": "4cb31b52", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Twenties\n" ] } ], "source": [ "### Age classification model\n", "age_model = joblib.load('KNN_age_model.pkl')\n", "age_predict = age_model.predict(scaled_observation)\n", "## considering the labels 1 = male 0 = female\n", "print(age_reverse_labelling(age_predict[0]))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }