Spoken commands example

This example uses an audio classifier model from a TensorFlow tutorial: https://www.tensorflow.org/tutorials/sequences/audio_recognition

N.B. This script downloads a large (2.3GB) speech commands dataset!

import sys
from pathlib import Path
import tarfile
import shutil
import pandas as pd
from scipy.io.wavfile import read, write
from sklearn.metrics import confusion_matrix
from dpemu.nodes.series import Series
from dpemu.nodes.tuple import Tuple
from dpemu.filters.sound import ClipWAV
from dpemu.filters.common import ApplyToTuple
from dpemu.plotting_utils import visualize_confusion_matrix

First we download the dataset unless it is already present. If you have downloaded and extracted the dataset into a different directory, change the data_dir variable accordingly.

data_url = "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz"
fname = "speech_commands_v0.02.tar.gz"
data_dir = Path.home() / "datasets/speech_data"

# testing_list.txt is part of the dataset and is read further below, so its
# presence is a better "already downloaded and extracted" check than the mere
# existence of data_dir (an interrupted download would leave the directory
# behind and the download would never be retried).
if not (data_dir / "testing_list.txt").exists():
    # Download in Python rather than via `!wget`: no external tool is needed and
    # a data_dir containing spaces cannot break shell argument splitting.
    import urllib.request

    data_dir.mkdir(parents=True, exist_ok=True)
    archive_path = data_dir / fname
    urllib.request.urlretrieve(data_url, archive_path)
    # The context manager closes the archive handle once extraction finishes.
    with tarfile.open(archive_path, "r:gz") as archive:
        archive.extractall(data_dir)
# The ten keywords the model was trained to recognise.
trained_categories = ["yes", "no", "up", "down", "left", "right", "on", "off", "stop", "go"]
# Label set for the confusion matrices: the special classes "_silence_" and
# "_unknown_" followed by the trained keywords, in the same order.
labels = ["_silence_", "_unknown_", *trained_categories]

# Read the test split in Python instead of `!cat` / `!cut | sort -u`: this needs
# no Unix tools, reads the file once, and does not break when data_dir contains
# spaces. Each entry is a path relative to data_dir whose first component is
# the category directory.
test_set_rel_paths = [
    line for line in (data_dir / "testing_list.txt").read_text().splitlines() if line.strip()
]
test_set_files = [data_dir / p for p in test_set_rel_paths]
# Distinct categories = first path component of each entry, sorted.
test_categories = sorted({p.split("/")[0] for p in test_set_rel_paths})

len(test_set_files), len(test_categories), len(trained_categories)
(11005, 35, 10)

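The model scripts (label_wav.py and label_wav_dir.py) and the trained model are located in the dpemu repository, so we set the variables dpemu_path and example_path to point to them. This assumes the notebook is run from a directory two levels below the dpemu root.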

# Root of the dpemu repository, assuming the notebook runs two directory levels
# below it; the model scripts and the trained model are under example_path.
dpemu_path = Path.cwd().parents[1]
example_path = dpemu_path.joinpath("examples", "speech_commands")

Choose a category in which to generate errors. Later on we will generate errors in all of the test set categories.

category = "stop"
data_subset_dir = data_dir / category

# List the directory once and read the clips from that same list, so fs[i] is
# guaranteed to name the file that wavs[i] (and later clipped[i]) came from.
fs = list(data_subset_dir.iterdir())
wavs = [read(f) for f in fs]

Create an error-generating tree and generate errors in the category chosen above.

# Build the error-generating tree. Each item of the Series is a (rate, samples)
# tuple as returned by scipy's read(); ClipWAV is applied to element 1 (the
# samples) and presumably takes its amount from err_params["dyn_range"].
wav_node = Tuple()
wav_node.addfilter(
    ApplyToTuple(ClipWAV("dyn_range"), 1)
)
root_node = Series(wav_node)

err_params = dict(dyn_range=0.2)
clipped = root_node.generate_error(wavs, err_params)

Now we arbitrarily choose a speech command example from the data subset. To try another audio clip, change the index.

example_index = 123
clipped_filename = data_dir / 'clipped.wav'
write(clipped_filename, 16000, clipped[example_index][1])
!aplay {fs[example_index]}
Playing WAVE '/home/jpssilve/datasets/speech_data/stop/3ec05c3d_nohash_0.wav' : Signed 16 bit Little Endian, Rate 16000 Hz, Mono
!aplay {clipped_filename}
Playing WAVE '/home/jpssilve/datasets/speech_data/clipped.wav' : Signed 16 bit Little Endian, Rate 16000 Hz, Mono

Define a function to filter out irrelevant output (e.g. Python deprecation warnings):

def filter_scores(output):
    """Keep only the result lines of a model run.

    A line survives if it mentions "score" (a "<label> (score = ...)" line) or
    ".wav" (a file-name line); anything else, such as warnings, is dropped.
    """
    wanted = ("score", ".wav")
    return [entry for entry in output if any(marker in entry for marker in wanted)]

Run the model on the clean clip selected above:

scores_clean = !python {example_path}/label_wav.py \
--graph={example_path}/trained_model/my_frozen_graph.pb \
--labels={example_path}/trained_model/conv_labels.txt \

['stop (score = 0.54378)',
 'off (score = 0.19993)',
 '_unknown_ (score = 0.07233)']

Run the model on the corresponding errorified clip:

scores_clipped = !python {example_path}/label_wav.py \
--graph={example_path}/trained_model/my_frozen_graph.pb \
--labels={example_path}/trained_model/conv_labels.txt \

['stop (score = 0.22963)',
 'down (score = 0.16858)',
 '_unknown_ (score = 0.11415)']

You can also run the model on an entire directory of .wav files in one go:

scores_clean_dir = !python {example_path}/label_wav_dir.py \
--graph={example_path}/trained_model/my_frozen_graph.pb \
--labels={example_path}/trained_model/conv_labels.txt \

 'stop (score = 0.84888)',
 'up (score = 0.10150)',
 '_unknown_ (score = 0.02897)',
 'stop (score = 0.83839)',
 'up (score = 0.10791)',
 'down (score = 0.01377)',
 'stop (score = 0.99616)',
 'down (score = 0.00215)',
 '_unknown_ (score = 0.00114)',
 'stop (score = 0.94652)',
 'up (score = 0.04828)',
 '_unknown_ (score = 0.00210)',
 'stop (score = 0.98153)',
 'down (score = 0.00989)',
 'up (score = 0.00290)',
 'stop (score = 0.95047)',
 'up (score = 0.02973)',
 '_unknown_ (score = 0.01149)',
 'stop (score = 0.51887)',
 'down (score = 0.33725)',
 '_unknown_ (score = 0.13488)',
 'stop (score = 0.83974)',
 'up (score = 0.15001)',
 '_unknown_ (score = 0.00472)',
 'stop (score = 0.99714)',
 '_unknown_ (score = 0.00153)',
 'up (score = 0.00119)',
 'stop (score = 0.92912)',
 '_unknown_ (score = 0.04616)',
 'go (score = 0.02025)',
 'stop (score = 0.26836)',
 'down (score = 0.19330)',
 '_unknown_ (score = 0.13893)',
 'stop (score = 0.69810)',
 'up (score = 0.15704)',
 '_unknown_ (score = 0.06251)',
 'stop (score = 0.99679)',
 'up (score = 0.00229)',
 'down (score = 0.00053)',
 'stop (score = 0.94688)',
 '_unknown_ (score = 0.02761)',
 'up (score = 0.02019)',
 'stop (score = 0.61116)',
 'up (score = 0.25302)',
 'off (score = 0.04122)',
 'stop (score = 0.71052)',
 'off (score = 0.07688)',
 'up (score = 0.07429)',
 'stop (score = 0.90859)',
 'up (score = 0.03769)',
 '_unknown_ (score = 0.01578)',
 'stop (score = 0.29116)',
 'up (score = 0.19151)',
 'off (score = 0.13839)',
 'stop (score = 0.86171)',
 'up (score = 0.10564)',
 'off (score = 0.01070)',
 'stop (score = 0.96517)',
 '_unknown_ (score = 0.03305)',
 'up (score = 0.00064)',
 'stop (score = 0.91730)',
 'up (score = 0.05711)',
 '_unknown_ (score = 0.02363)',
 'stop (score = 0.99641)',
 'up (score = 0.00274)',
 'off (score = 0.00026)',
 'stop (score = 0.69207)',
 'up (score = 0.10310)',
 'go (score = 0.07329)',
 'stop (score = 0.90009)',
 'off (score = 0.02894)',
 '_unknown_ (score = 0.02361)',
 'stop (score = 0.57909)',
 'up (score = 0.28567)',
 'down (score = 0.03399)',
 'stop (score = 0.92027)',
 'up (score = 0.06529)',
 '_unknown_ (score = 0.01240)',
 'stop (score = 0.56984)',
 'up (score = 0.14682)',
 '_unknown_ (score = 0.06884)',
 'stop (score = 0.98627)',
 '_unknown_ (score = 0.01248)',
 'down (score = 0.00076)',
 'stop (score = 0.78066)',
 'up (score = 0.07329)',
 '_unknown_ (score = 0.07164)',
 'stop (score = 0.99162)',
 'go (score = 0.00331)',
 'up (score = 0.00219)',
 'stop (score = 0.59300)',
 'up (score = 0.39114)',
 '_unknown_ (score = 0.00468)',
 'stop (score = 0.44148)',
 'up (score = 0.21584)',
 'go (score = 0.08655)',
 'stop (score = 0.86626)',
 'up (score = 0.10812)',
 'down (score = 0.00908)',
 'stop (score = 0.91559)',
 'up (score = 0.03920)',
 '_unknown_ (score = 0.01798)',
 'stop (score = 0.95504)',
 '_unknown_ (score = 0.04000)',
 'go (score = 0.00360)',
 'stop (score = 0.74314)',
 'up (score = 0.22082)',
 'off (score = 0.01753)',
 'stop (score = 0.98796)',
 'up (score = 0.01117)',
 '_unknown_ (score = 0.00050)',
 'up (score = 0.42242)',
 'stop (score = 0.14160)',
 'down (score = 0.09028)',
 'stop (score = 0.34194)',
 'up (score = 0.33417)',
 '_unknown_ (score = 0.13248)',
 'stop (score = 0.92179)',
 'up (score = 0.03660)',
 '_unknown_ (score = 0.01734)',
 'stop (score = 0.99787)',
 'up (score = 0.00141)',
 '_unknown_ (score = 0.00037)',
 'stop (score = 0.83214)',
 'up (score = 0.03625)',
 'down (score = 0.02931)',
 'stop (score = 0.97966)',
 'up (score = 0.01852)',
 '_unknown_ (score = 0.00124)',
 'stop (score = 0.89391)',
 'up (score = 0.07879)',
 'go (score = 0.01087)',
 'stop (score = 0.97557)',
 'up (score = 0.01867)',
 'off (score = 0.00205)',
 'stop (score = 0.70121)',
 'up (score = 0.14588)',
 '_unknown_ (score = 0.04814)',
 'stop (score = 0.98115)',
 '_unknown_ (score = 0.01325)',
 'up (score = 0.00284)',
 'stop (score = 0.98360)',
 'down (score = 0.00743)',
 '_unknown_ (score = 0.00367)',
 'stop (score = 0.98975)',
 '_unknown_ (score = 0.00548)',
 'up (score = 0.00266)',
 'stop (score = 0.99790)',
 '_unknown_ (score = 0.00121)',
 'up (score = 0.00055)',
 'stop (score = 0.98081)',
 'up (score = 0.00732)',
 '_unknown_ (score = 0.00548)',
 'stop (score = 0.99353)',
 '_unknown_ (score = 0.00320)',
 'up (score = 0.00177)',
 'stop (score = 0.89639)',
 '_unknown_ (score = 0.04231)',
 'up (score = 0.04003)',
 'stop (score = 0.65718)',
 'down (score = 0.19441)',
 '_unknown_ (score = 0.05682)',
 'stop (score = 0.94357)',
 'down (score = 0.02360)',
 'up (score = 0.02062)',
 'stop (score = 0.89233)',
 'up (score = 0.03672)',
 'down (score = 0.02436)',
 'stop (score = 0.51013)',
 'down (score = 0.17139)',
 'go (score = 0.08875)',
 'stop (score = 0.93051)',
 'up (score = 0.06421)',
 '_unknown_ (score = 0.00301)',
 'stop (score = 0.44148)',
 'up (score = 0.21584)',
 'go (score = 0.08655)',
 'stop (score = 0.99756)',
 'up (score = 0.00147)',
 '_unknown_ (score = 0.00069)',
 'stop (score = 0.99686)',
 '_unknown_ (score = 0.00174)',
 'up (score = 0.00112)',
 'stop (score = 0.86248)',
 'up (score = 0.06937)',
 '_unknown_ (score = 0.01894)',
 'stop (score = 0.84088)',
 'up (score = 0.14219)',
 '_unknown_ (score = 0.00756)',
 'stop (score = 0.72441)',
 'up (score = 0.08968)',
 'off (score = 0.05914)',
 'up (score = 0.35406)',
 'stop (score = 0.30395)',
 'off (score = 0.16752)',
 'stop (score = 0.99878)',
 '_unknown_ (score = 0.00053)',
 'up (score = 0.00030)',
 'stop (score = 0.99874)',
 '_unknown_ (score = 0.00081)',
 'up (score = 0.00021)',
 'down (score = 0.30055)',
 '_unknown_ (score = 0.19656)',
 'stop (score = 0.15799)',
 'stop (score = 0.41698)',
 'off (score = 0.14234)',
 'down (score = 0.12729)',
 'stop (score = 0.99386)',
 'up (score = 0.00474)',
 'down (score = 0.00070)',
 'stop (score = 0.99527)',
 '_unknown_ (score = 0.00271)',
 'up (score = 0.00107)',
 'stop (score = 0.45093)',
 '_unknown_ (score = 0.24628)',
 'go (score = 0.06381)',
 'stop (score = 0.85112)',
 'up (score = 0.06795)',
 'go (score = 0.03845)',
 'stop (score = 0.37752)',
 'up (score = 0.32371)',
 '_unknown_ (score = 0.05701)',
 'stop (score = 0.98600)',
 'up (score = 0.01023)',
 '_unknown_ (score = 0.00227)',
 'stop (score = 0.95925)',
 'up (score = 0.03561)',
 '_unknown_ (score = 0.00178)',
 'stop (score = 0.43057)',
 'up (score = 0.36197)',
 'off (score = 0.13706)',
 'stop (score = 0.99196)',
 'up (score = 0.00500)',
 'down (score = 0.00119)',
 'stop (score = 0.98538)',
 'go (score = 0.00724)',
 '_unknown_ (score = 0.00485)',
 'stop (score = 0.97974)',
 '_unknown_ (score = 0.01620)',
 'go (score = 0.00237)',
 'stop (score = 0.45099)',
 '_unknown_ (score = 0.12778)',
 'off (score = 0.09222)',
 'stop (score = 0.97331)',
 '_unknown_ (score = 0.01048)',
 'down (score = 0.00528)',
 'stop (score = 0.99472)',
 'down (score = 0.00212)',
 '_unknown_ (score = 0.00174)',
 'stop (score = 0.87186)',
 'down (score = 0.06694)',
 'go (score = 0.02199)',
 'stop (score = 0.99799)',
 'up (score = 0.00112)',
 '_unknown_ (score = 0.00045)',
 'stop (score = 0.99478)',
 'up (score = 0.00325)',
 'off (score = 0.00098)',
 'stop (score = 0.99936)',
 'up (score = 0.00028)',
 '_unknown_ (score = 0.00019)',
 'stop (score = 0.85205)',
 'down (score = 0.12575)',
 '_unknown_ (score = 0.00612)',
 'stop (score = 0.78082)',
 'up (score = 0.13418)',
 '_unknown_ (score = 0.03140)',
 'stop (score = 0.99090)',
 'up (score = 0.00493)',
 '_unknown_ (score = 0.00252)',
 'stop (score = 0.95521)',
 'up (score = 0.03692)',
 '_unknown_ (score = 0.00382)',
 'stop (score = 0.97289)',
 'up (score = 0.01633)',
 '_unknown_ (score = 0.00663)',
 'stop (score = 0.81987)',
 'down (score = 0.05813)',
 'up (score = 0.05399)',
 'stop (score = 0.39819)',
 '_unknown_ (score = 0.22448)',
 'up (score = 0.10589)',
 'stop (score = 0.98871)',
 'up (score = 0.00868)',
 '_unknown_ (score = 0.00104)',
 'stop (score = 0.87487)',
 'up (score = 0.05612)',
 'go (score = 0.02334)',
 'stop (score = 0.79478)',
 'off (score = 0.06393)',
 '_unknown_ (score = 0.04445)',
 'stop (score = 0.92717)',
 'off (score = 0.03468)',
 'up (score = 0.02418)',
 'stop (score = 0.23585)',
 'up (score = 0.15615)',
 '_unknown_ (score = 0.11477)',
 'stop (score = 0.98693)',
 'up (score = 0.00912)',
 'down (score = 0.00194)',
 'stop (score = 0.91138)',
 'up (score = 0.08291)',
 '_unknown_ (score = 0.00190)',
 'stop (score = 0.99362)',
 'up (score = 0.00550)',
 '_unknown_ (score = 0.00048)',
 'stop (score = 0.60951)',
 'down (score = 0.11640)',
 'go (score = 0.10724)',
 'stop (score = 0.99213)',
 'up (score = 0.00532)',
 '_unknown_ (score = 0.00173)',
 'stop (score = 0.99898)',
 'up (score = 0.00088)',
 '_unknown_ (score = 0.00011)',
 'stop (score = 0.87669)',
 'off (score = 0.03555)',
 '_unknown_ (score = 0.03319)',
 'stop (score = 0.72243)',
 '_unknown_ (score = 0.16989)',
 'down (score = 0.03805)',
 'stop (score = 0.93210)',
 'up (score = 0.03181)',
 'down (score = 0.01050)',
 'stop (score = 0.99897)',
 '_unknown_ (score = 0.00045)',
 'up (score = 0.00029)',
 'stop (score = 0.50612)',
 'up (score = 0.47685)',
 'off (score = 0.00776)',
 'stop (score = 0.71747)',
 'down (score = 0.13705)',
 '_unknown_ (score = 0.06468)',
 'stop (score = 0.66247)',
 'go (score = 0.17296)',
 '_unknown_ (score = 0.07923)',
 'stop (score = 0.32753)',
 'up (score = 0.17626)',
 'off (score = 0.11351)',
 'stop (score = 0.98290)',
 'up (score = 0.00738)',
 '_unknown_ (score = 0.00650)',
 'stop (score = 0.68633)',
 'up (score = 0.08010)',
 'off (score = 0.07813)',
 'stop (score = 0.81587)',
 'down (score = 0.05165)',
 'up (score = 0.04305)',
 'stop (score = 0.22536)',
 '_unknown_ (score = 0.18580)',
 'go (score = 0.12804)',
 'stop (score = 0.94474)',
 'up (score = 0.02064)',
 '_unknown_ (score = 0.01619)',
 'stop (score = 0.94834)',
 'up (score = 0.01997)',
 'down (score = 0.01052)',
 'stop (score = 0.99678)',
 '_unknown_ (score = 0.00310)',
 'down (score = 0.00006)',
 'stop (score = 0.95208)',
 'up (score = 0.04081)',
 '_unknown_ (score = 0.00239)',
 'stop (score = 0.26048)',
 'down (score = 0.16748)',
 'go (score = 0.13965)',
 'stop (score = 0.99921)',
 '_unknown_ (score = 0.00047)',
 'up (score = 0.00021)',
 'stop (score = 0.54378)',
 'off (score = 0.19993)',
 '_unknown_ (score = 0.07233)',
 'stop (score = 0.92796)',
 'down (score = 0.02070)',
 '_unknown_ (score = 0.01309)',
 'stop (score = 0.78939)',
 'up (score = 0.19312)',
 '_unknown_ (score = 0.00795)',
 'stop (score = 0.89709)',
 'down (score = 0.03361)',
 '_unknown_ (score = 0.01997)',
 'stop (score = 0.94139)',
 '_unknown_ (score = 0.04825)',
 'up (score = 0.00626)',
 'stop (score = 0.99623)',
 '_unknown_ (score = 0.00219)',
 'up (score = 0.00138)',
 'stop (score = 0.89436)',
 '_unknown_ (score = 0.05919)',
 'go (score = 0.02806)',
 'stop (score = 0.45233)',
 '_unknown_ (score = 0.16313)',
 'off (score = 0.11665)',
 'stop (score = 0.93726)',
 'down (score = 0.02824)',
 'up (score = 0.01093)',
 'stop (score = 0.99688)',
 'up (score = 0.00160)',
 '_unknown_ (score = 0.00101)',
 'stop (score = 0.80808)',
 '_unknown_ (score = 0.10039)',
 'up (score = 0.03336)',
 'stop (score = 0.24488)',
 '_unknown_ (score = 0.11815)',
 'left (score = 0.10235)',
 'stop (score = 0.99836)',
 'up (score = 0.00083)',
 '_unknown_ (score = 0.00038)',
 'stop (score = 0.97935)',
 '_unknown_ (score = 0.01288)',
 'up (score = 0.00481)',
 'stop (score = 0.98262)',
 'up (score = 0.01171)',
 '_unknown_ (score = 0.00243)',
 'stop (score = 0.64022)',
 'go (score = 0.16679)',
 'no (score = 0.06918)',
 'stop (score = 0.76328)',
 'off (score = 0.15193)',
 'up (score = 0.02800)',
 'stop (score = 0.92439)',
 'up (score = 0.05283)',
 '_unknown_ (score = 0.00782)',
 'stop (score = 0.79350)',
 'go (score = 0.05066)',
 'up (score = 0.04322)',
 'stop (score = 0.98467)',
 'down (score = 0.00778)',
 '_unknown_ (score = 0.00234)',
 'stop (score = 0.99034)',
 'up (score = 0.00893)',
 '_unknown_ (score = 0.00047)',
 'stop (score = 0.98414)',
 'up (score = 0.01113)',
 'down (score = 0.00169)',
 'stop (score = 0.67216)',
 'up (score = 0.13359)',
 '_unknown_ (score = 0.05075)',
 'stop (score = 0.78502)',
 'up (score = 0.10579)',
 'down (score = 0.05232)',
 'stop (score = 0.96869)',
 'up (score = 0.02890)',
 '_unknown_ (score = 0.00105)',
 'stop (score = 0.95722)',
 'go (score = 0.02449)',
 'down (score = 0.00590)',
 'up (score = 0.88006)',
 '_unknown_ (score = 0.03862)',
 'stop (score = 0.02416)',
 'stop (score = 0.44896)',
 'down (score = 0.21524)',
 'no (score = 0.15847)',
 'stop (score = 0.93532)',
 'up (score = 0.02308)',
 'off (score = 0.01062)',
 'stop (score = 0.65481)',
 'down (score = 0.10290)',
 '_unknown_ (score = 0.09691)',
 'stop (score = 0.83370)',
 'up (score = 0.15832)',
 '_unknown_ (score = 0.00376)',
 'stop (score = 0.68990)',
 'up (score = 0.12806)',
 'down (score = 0.05177)',
 'stop (score = 0.86563)',
 'off (score = 0.03857)',
 'up (score = 0.03350)',
 'stop (score = 0.71492)',
 'up (score = 0.16641)',
 'go (score = 0.04935)',
 'stop (score = 0.63599)',
 'go (score = 0.14249)',
 'up (score = 0.10391)',
 'stop (score = 0.79370)',
 '_unknown_ (score = 0.06695)',
 'down (score = 0.06271)',
 'stop (score = 0.93857)',
 'go (score = 0.02124)',
 'down (score = 0.01481)',
 'stop (score = 0.35819)',
 'up (score = 0.19114)',
 'down (score = 0.08839)',
 'up (score = 0.58615)',
 'stop (score = 0.25748)',
 'off (score = 0.13660)',
 'stop (score = 0.16377)',
 'up (score = 0.14177)',
 '_unknown_ (score = 0.14092)',
 'stop (score = 0.99972)',
 'up (score = 0.00016)',
 '_unknown_ (score = 0.00011)',
 'stop (score = 0.77719)',
 'up (score = 0.11667)',
 '_unknown_ (score = 0.03379)',
 'stop (score = 0.97408)',
 'up (score = 0.02085)',
 '_unknown_ (score = 0.00265)',
 'stop (score = 0.97451)',
 'go (score = 0.01067)',
 '_unknown_ (score = 0.00638)',
 'stop (score = 0.98016)',
 'down (score = 0.00753)',
 'go (score = 0.00371)',
 'stop (score = 0.99810)',
 'up (score = 0.00149)',
 '_unknown_ (score = 0.00032)',
 'stop (score = 0.85410)',
 '_unknown_ (score = 0.06047)',
 'up (score = 0.02200)',
 'stop (score = 0.99523)',
 '_unknown_ (score = 0.00262)',
 'left (score = 0.00094)',
 'stop (score = 0.99719)',
 '_unknown_ (score = 0.00182)',
 'up (score = 0.00048)',
 'stop (score = 0.97167)',
 'up (score = 0.01566)',
 '_unknown_ (score = 0.00619)',
 'stop (score = 0.52191)',
 'up (score = 0.22323)',
 'left (score = 0.05919)',
 'stop (score = 0.64668)',
 'down (score = 0.20922)',
 'no (score = 0.04130)',
 'stop (score = 0.61945)',
 'up (score = 0.12795)',
 'off (score = 0.09970)',
 'stop (score = 0.58292)',
 'up (score = 0.39636)',
 'off (score = 0.00788)',
 'stop (score = 0.96773)',
 '_unknown_ (score = 0.00939)',
 'up (score = 0.00694)',
 'stop (score = 0.99393)',
 '_unknown_ (score = 0.00234)',
 'up (score = 0.00195)',
 'stop (score = 0.89692)',
 'up (score = 0.04895)',
 'off (score = 0.02443)',
 'stop (score = 0.99402)',
 'up (score = 0.00317)',
 'down (score = 0.00159)',
 'stop (score = 0.18026)',
 'left (score = 0.13037)',
 'down (score = 0.12699)',
 'stop (score = 0.99877)',
 '_unknown_ (score = 0.00084)',
 'up (score = 0.00034)',
 'stop (score = 0.48925)',
 '_unknown_ (score = 0.12916)',
 'right (score = 0.11773)',
 'stop (score = 0.97986)',
 '_unknown_ (score = 0.01157)',
 'up (score = 0.00352)',
 'stop (score = 0.98379)',
 'up (score = 0.01037)',
 '_unknown_ (score = 0.00274)',
 'stop (score = 0.99270)',
 'up (score = 0.00675)',
 '_unknown_ (score = 0.00050)',
 'stop (score = 0.27674)',
 'up (score = 0.25178)',
 'no (score = 0.10935)',
 'stop (score = 0.86411)',
 'off (score = 0.06190)',
 '_unknown_ (score = 0.03669)',
 'stop (score = 0.88871)',
 'up (score = 0.09937)',
 '_unknown_ (score = 0.00315)',
 'stop (score = 0.48994)',
 'up (score = 0.12429)',
 'no (score = 0.10686)',
 'stop (score = 0.51262)',
 'up (score = 0.38220)',
 '_unknown_ (score = 0.05622)',
 'stop (score = 0.99406)',
 'up (score = 0.00474)',
 '_unknown_ (score = 0.00085)',
 'stop (score = 0.98721)',
 'up (score = 0.01022)',
 '_unknown_ (score = 0.00182)',
 'stop (score = 0.92152)',
 'down (score = 0.04298)',
 'no (score = 0.00939)',
 'stop (score = 0.98223)',
 'up (score = 0.01197)',
 'down (score = 0.00265)',
 'stop (score = 0.45783)',
 'up (score = 0.13047)',
 '_unknown_ (score = 0.08471)',
 'stop (score = 0.98573)',
 '_unknown_ (score = 0.00698)',
 'go (score = 0.00461)',
 'stop (score = 0.79965)',
 'off (score = 0.06937)',
 'up (score = 0.04585)',
 'go (score = 0.48640)',
 'up (score = 0.16947)',
 '_unknown_ (score = 0.15874)',
 'stop (score = 0.98158)',
 'up (score = 0.01708)',
 '_unknown_ (score = 0.00080)',
 'stop (score = 0.48661)',
 'up (score = 0.15354)',
 'down (score = 0.10605)',
 'stop (score = 0.33172)',
 'up (score = 0.12149)',
 '_unknown_ (score = 0.10527)',
 'stop (score = 0.98440)',
 'up (score = 0.00508)',
 '_unknown_ (score = 0.00465)',
 'stop (score = 0.95439)',
 'up (score = 0.03564)',
 '_unknown_ (score = 0.00394)',
 'stop (score = 0.99749)',
 'up (score = 0.00134)',
 '_unknown_ (score = 0.00095)',
 'stop (score = 0.70006)',
 'up (score = 0.20123)',
 'left (score = 0.03805)',
 'stop (score = 0.63192)',
 '_unknown_ (score = 0.15333)',
 'go (score = 0.12291)',
 'stop (score = 0.99821)',
 'up (score = 0.00156)',
 '_unknown_ (score = 0.00010)',
 'stop (score = 0.92353)',
 '_unknown_ (score = 0.02751)',
 'down (score = 0.02737)',
 'stop (score = 0.99818)',
 'up (score = 0.00158)',
 '_unknown_ (score = 0.00013)',
 'stop (score = 0.95619)',
 'up (score = 0.02845)',
 '_unknown_ (score = 0.00523)',
 'stop (score = 0.93942)',
 'up (score = 0.03724)',
 'go (score = 0.00980)',
 'stop (score = 0.99567)',
 'up (score = 0.00360)',
 '_unknown_ (score = 0.00044)',
 'stop (score = 0.75682)',
 '_unknown_ (score = 0.09401)',
 'down (score = 0.04084)',
 'stop (score = 0.46735)',
 'down (score = 0.11926)',
 '_unknown_ (score = 0.11014)',
 '_unknown_ (score = 0.35277)',
 'stop (score = 0.25628)',
 'up (score = 0.22787)',
 'down (score = 0.64915)',
 'stop (score = 0.13520)',
 'go (score = 0.06183)',
 'stop (score = 0.51185)',
 'up (score = 0.45758)',
 'off (score = 0.01466)',
 'stop (score = 0.98674)',
 '_unknown_ (score = 0.00840)',
 'up (score = 0.00273)',
 'stop (score = 0.99574)',
 'down (score = 0.00201)',
 '_unknown_ (score = 0.00110)',
 'stop (score = 0.78802)',
 'up (score = 0.11744)',
 'down (score = 0.04498)',
 'stop (score = 0.97416)',
 '_unknown_ (score = 0.01517)',
 'up (score = 0.00404)',
 'off (score = 0.16264)',
 '_silence_ (score = 0.11814)',
 'yes (score = 0.09154)',
 'stop (score = 0.99422)',
 'up (score = 0.00445)',
 '_unknown_ (score = 0.00075)',
 'stop (score = 0.98304)',
 '_unknown_ (score = 0.00850)',
 'down (score = 0.00307)',
 'stop (score = 0.91704)',
 'up (score = 0.07075)',
 'off (score = 0.00562)',
 'stop (score = 0.96610)',
 'up (score = 0.01413)',
 'down (score = 0.00977)',
 'stop (score = 0.91341)',
 'up (score = 0.04878)',
 'down (score = 0.01150)',
 'stop (score = 0.98021)',
 'up (score = 0.01911)',
 '_unknown_ (score = 0.00029)',
 'stop (score = 0.87034)',
 'up (score = 0.06751)',
 '_unknown_ (score = 0.04405)',
 'stop (score = 0.13540)',
 'up (score = 0.11850)',
 '_unknown_ (score = 0.10476)',
 'stop (score = 0.67461)',
 'up (score = 0.11249)',
 'down (score = 0.05729)',
 'stop (score = 0.98290)',
 'up (score = 0.00918)',
 '_unknown_ (score = 0.00575)',
 'stop (score = 0.49829)',
 'down (score = 0.13202)',
 '_unknown_ (score = 0.11668)',
 'stop (score = 0.98843)',
 '_unknown_ (score = 0.00946)',
 'go (score = 0.00120)',
 'stop (score = 0.42095)',
 '_unknown_ (score = 0.13196)',
 'up (score = 0.09977)',
 'stop (score = 0.97009)',
 '_unknown_ (score = 0.02019)',
 'go (score = 0.00626)',
 'stop (score = 0.98068)',
 'up (score = 0.00856)',
 '_unknown_ (score = 0.00304)',
 'stop (score = 0.99057)',
 'up (score = 0.00858)',
 'down (score = 0.00050)',
 'stop (score = 0.98696)',
 '_unknown_ (score = 0.00550)',
 'up (score = 0.00223)',
 'stop (score = 0.93676)',
 'up (score = 0.06078)',
 '_unknown_ (score = 0.00106)',
 'stop (score = 0.84753)',
 'up (score = 0.11528)',
 'down (score = 0.01605)',
 'stop (score = 0.86223)',
 '_unknown_ (score = 0.04718)',
 'go (score = 0.03105)',
 'stop (score = 0.99879)',
 'up (score = 0.00114)',
 '_unknown_ (score = 0.00003)',
 'stop (score = 0.91839)',
 'up (score = 0.03990)',
 '_unknown_ (score = 0.01542)',
 'no (score = 0.16717)',
 'go (score = 0.14407)',
 'down (score = 0.13581)',
 'stop (score = 0.98426)',
 'up (score = 0.01385)',
 '_unknown_ (score = 0.00076)',
 'stop (score = 0.43997)',
 '_unknown_ (score = 0.21405)',
 'go (score = 0.05974)',
 'stop (score = 0.99714)',
 'up (score = 0.00277)',
 '_unknown_ (score = 0.00007)',

That was not pretty! We’d better define some helper functions to extract the model’s guesses from that messy output:

def get_guesses(scores, num_labels=3):
    """Pair each scored file with the model's top guess.

    Args:
        scores: Output lines of label_wav_dir.py (raw or already filtered).
            After filter_scores(), every file must contribute one file-name
            line followed by num_labels "<label> (score = ...)" lines, the
            top guess first.
        num_labels: Number of score lines printed per file (default 3).

    Returns:
        A zip iterator of (file_name, top_label) tuples.

    Raises:
        ValueError: If num_labels < 1, or if the filtered output does not split
            into whole per-file groups.
    """
    if num_labels < 1:
        raise ValueError(f"num_labels must be at least 1, got {num_labels}")
    scores = filter_scores(scores)
    group_size = num_labels + 1  # one file-name line plus its score lines
    if len(scores) % group_size != 0:
        raise ValueError(f"Expected scores list to have a length divisible by {group_size} after filtering but got length {len(scores)}")
    fnames = scores[0::group_size]
    # The first score line after each file name is the top prediction; its first
    # space-separated token is the label.
    guesses = [line.split(' ')[0] for line in scores[1::group_size]]
    return zip(fnames, guesses)

def score_directory(directory):
    scores = !python {example_path}/label_wav_dir.py \
        --graph={example_path}/trained_model/my_frozen_graph.pb \
        --labels={example_path}/trained_model/conv_labels.txt \
    return filter_scores(scores)

Define a function to generate errors in all .wav files in a given directory and save the results to a sibling directory whose name has the suffix _err. If an inclusion list is provided, only the files on the list will be processed.

def errorify_directory(data_root_dir, dir_name, tree_root, err_params, inclusion_list=None):
    """Generate errors in .wav files of one directory and save them alongside it.

    Args:
        data_root_dir: Directory containing the directory to process.
        dir_name: Name of the directory (under data_root_dir) with the clean files.
        tree_root: Error generating tree; its generate_error() is called with a
            one-element list holding the (rate, samples) tuple returned by read()
            and must return a list whose first item is such a tuple.
        err_params: Parameters passed to tree_root.generate_error().
        inclusion_list: Optional iterable of file paths to process. If None,
            every .wav file directly inside the directory is processed; an empty
            list processes nothing.

    Returns:
        Path of the output directory, data_root_dir / (dir_name + "_err").

    Raises:
        ValueError: If data_root_dir / dir_name does not exist.
    """
    data_root_dir = Path(data_root_dir)
    clean_data_dir = data_root_dir / dir_name
    if not clean_data_dir.exists():
        raise ValueError(f"Directory {clean_data_dir} does not exist.")
    err_data_dir = data_root_dir / (dir_name + "_err")
    err_data_dir.mkdir(parents=True, exist_ok=True)
    if inclusion_list is None:
        # Select by suffix rather than substring, so ".wav" elsewhere in the
        # path (e.g. in a parent directory name) does not match every file.
        inclusion_list = [f for f in clean_data_dir.iterdir() if f.suffix == ".wav"]
    for wav_path in inclusion_list:
        wav_path = Path(wav_path)
        sample = read(wav_path)
        clipped = tree_root.generate_error([sample], err_params)[0]
        # clipped is (rate, samples); keep the original file name.
        write(err_data_dir / wav_path.name, clipped[0], clipped[1])
    return err_data_dir

Define a function to generate errors in all .wav files on a list. The function is needed when files from multiple categories are present on the list. To facilitate comparisons between clean and errorified data, the clean files on the list can be automatically copied to suitably named directories (with the suffix _clean). To do this, provide the parameter copy_clean=True.

def errorify_list(data_files, categories, tree_root, err_params, copy_clean=False):
    """Generate errors in the listed .wav files, category by category.

    Args:
        data_files: Paths of the form <data_root>/<category>/<clip>.wav; the data
            root is taken from the first entry.
        categories: Category (directory) names to process; listed files in
            other categories are ignored.
        tree_root: Error generating tree passed on to errorify_directory().
        err_params: Error parameters passed on to errorify_directory().
        copy_clean: If True, also copy each category's untouched files to
            <data_root>/<category>_clean for later comparison.
    """
    if not data_files:
        return  # nothing to process; also avoids indexing an empty list
    data_root_dir = Path(data_files[0]).parents[1]
    for cat in categories:
        # Compare the parent directory name exactly: a substring test such as
        # "on/" in str(path) would also match unrelated components ("/home/jon/").
        files_in_cat = [Path(f) for f in data_files if Path(f).parent.name == cat]
        print("category:", cat)
        errorify_directory(data_root_dir, cat, tree_root, err_params, inclusion_list=files_in_cat)
        if copy_clean:
            copy_dir = data_root_dir / (cat + "_clean")
            # Without the directory, shutil.copy would write a single *file*
            # named copy_dir and overwrite it for every clip.
            copy_dir.mkdir(exist_ok=True)
            for clean_file in files_in_cat:
                shutil.copy(clean_file, copy_dir)

Define a function to compare the model’s guesses on clean and errorified data. The results are returned in a Pandas dataframe.

def compare(data_root, category, clean_ext="_clean", err_ext="_err"):
    """Compare the model's guesses on clean vs. errorified clips of one category.

    Scores data_root/<category><clean_ext> and data_root/<category><err_ext>,
    inner-joins the two results on file name and tags every row with the
    category as its true label.

    Returns:
        pandas.DataFrame with columns file, clean_guess, err_guess, true_label.
    """
    def guesses_frame(suffix, column):
        # One row per scored file: (file name, top guess).
        output = score_directory(data_root / (category + suffix))
        return pd.DataFrame(get_guesses(output), columns=["file", column])

    clean_frame = guesses_frame(clean_ext, "clean_guess")
    err_frame = guesses_frame(err_ext, "err_guess")
    merged = clean_frame.merge(err_frame, on="file", how="inner")
    merged['true_label'] = category
    return merged

Generate errors in all test set audio clips.

# Clip every test-set file of the trained categories; copy_clean=True also
# copies the untouched originals to <category>_clean for the comparison below.
errorify_list(
    data_files=test_set_files,
    categories=trained_categories,
    tree_root=root_node,
    err_params=err_params,
    copy_clean=True,
)
category: yes
category: no
category: up
category: down
category: left
category: right
category: on
category: off
category: stop
category: go

Run the model on the clean and errorified data.

# One comparison frame per trained category, stacked row-wise into one table.
results = [compare(data_root=data_dir, category=label) for label in trained_categories]
df = pd.concat(results, axis=0)

Create confusion matrices for clean and errorified data, respectively.

# Same true labels and label order for both matrices; only the guess column differs.
cm_clean, cm_err = (
    confusion_matrix(df['true_label'], df[guess_column], labels=labels)
    for guess_column in ("clean_guess", "err_guess")
)

Visualize the confusion matrix for the clean data.

# Plot the confusion matrix of the guesses on clean data. The positional
# arguments follow dpemu.plotting_utils.visualize_confusion_matrix; 0 and
# "dyn_range" presumably identify the error-parameter value/name used for
# labelling the plot — TODO confirm against the dpemu signature.
visualize_confusion_matrix(df, cm_clean, 0, labels, "dyn_range", "true_label", "clean_guess")

Visualize the confusion matrix for the errorified data.

# Plot the confusion matrix of the guesses on errorified data; same arguments as
# the clean plot except the matrix and the guess column ("err_guess").
# NOTE(review): meaning of 0 and "dyn_range" depends on dpemu's
# visualize_confusion_matrix signature — confirm there.
visualize_confusion_matrix(df, cm_err, 0, labels, "dyn_range", "true_label", "err_guess")

The notebook for this case study can be found here.