Skip to content

Commit 789c5ce

Browse files
add wip browser example
1 parent 2bbf7b0 commit 789c5ce

File tree

7 files changed

+364
-0
lines changed

7 files changed

+364
-0
lines changed

web/package.json

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
{
2+
"name": "online-lid",
3+
"version": "0.1.0",
4+
"description": "",
5+
"main": "index.js",
6+
"dependencies": {
7+
"@tensorflow/tfjs": "^2.7.0",
8+
"@tensorflow/tfjs-backend-wasm": "^2.7.0",
9+
"clean-webpack-plugin": "^3.0.0",
10+
"copy-webpack-plugin": "^6.3.2",
11+
"css-loader": "^5.0.1",
12+
"html-webpack-plugin": "^4.5.0",
13+
"style-loader": "^2.0.0",
14+
"webpack": "^5.8.0",
15+
"webpack-cli": "^4.2.0",
16+
"webpack-dev-server": "^3.11.0",
17+
"webpack-hot-middleware": "^2.25.0"
18+
},
19+
"devDependencies": {
20+
"ts-loader": "^8.0.11",
21+
"typescript": "^4.1.2"
22+
},
23+
"scripts": {
24+
"serve": "webpack serve",
25+
"build": "webpack",
26+
"tfjs-convert": "python3 ./src/feat.py ./static"
27+
},
28+
"author": "",
29+
"license": "MIT"
30+
}

web/src/feat.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import argparse
2+
import os
3+
import tempfile
4+
5+
import tensorflow as tf
6+
import tensorflowjs as tfjs
7+
8+
from lidbox.features import audio, cmvn, feature_scaling
9+
10+
11+
@tf.function(input_signature=[
    tf.TensorSpec([None, None, None], tf.float32),
    tf.TensorSpec([], tf.int32),
    tf.TensorSpec([], tf.int32)])
def convertBrowserFFT(spec, sample_rate, num_mel_bins):
    """Convert a dB-scaled spectrogram (as produced by the browser's WebAudio
    FFT via tf.data.microphone) into a CMVN-normalized log-mel spectrogram.

    Args:
        spec: float32 dB-scaled spectrogram; rank 3, presumably
            [channels, time, freq] given the transpose in the web client --
            TODO confirm.
        sample_rate: scalar int32, sample rate in Hz.
        num_mel_bins: scalar int32, number of mel filterbank bins.

    Returns:
        float32 tensor of normalized log-mel features.
    """
    # Undo the dB scaling to recover a linear power spectrogram.
    S = audio.db_to_power(spec)
    # S = tf.math.abs(tf.signal.stft(signals, 400, 160, 512))
    # S = audio.spectrograms(signals, sample_rate)
    # Warp linear frequencies onto num_mel_bins mel bands up to Nyquist.
    S = audio.linear_to_mel(S, sample_rate, num_mel_bins=num_mel_bins, fmax=tf.cast(sample_rate/2, tf.float32))
    # Log-compress; the 1e-6 floor avoids log(0).
    S = tf.math.log(1e-6 + S)
    # Mean-variance normalization over the time axis.
    S = cmvn(S, axis=1)
    return S
23+
24+
@tf.function(input_signature=[
    tf.TensorSpec([None, None], tf.float32),
    tf.TensorSpec([], tf.int32),
    tf.TensorSpec([], tf.int32)])
def signals2logmel(signals, sample_rate, num_mel_bins):
    """Compute CMVN-normalized log-mel spectrograms from raw waveforms.

    Args:
        signals: float32 [batch, samples] waveforms.
        sample_rate: scalar int32, sample rate in Hz of the input signals.
        num_mel_bins: scalar int32, number of mel filterbank bins.

    Returns:
        float32 tensor of normalized log-mel features.
    """
    # Decimate by keeping every 3rd sample (e.g. 48 kHz -> 16 kHz).
    # NOTE(review): plain slicing does no anti-alias low-pass filtering,
    # so content above the new Nyquist will alias -- confirm acceptable.
    signals, sample_rate = signals[:,::3], sample_rate // 3
    # 25 ms frame length and 10 ms frame step, converted to sample counts.
    flen = audio.ms_to_frames(sample_rate, 25)
    fstep = audio.ms_to_frames(sample_rate, 10)
    # Power spectrogram via STFT.
    S = tf.math.square(tf.math.abs(tf.signal.stft(signals, flen, fstep, fft_length=512)))
    # S = audio.spectrograms(signals, sample_rate)
    # NOTE(review): fmax here is the full sample_rate, while convertBrowserFFT
    # uses sample_rate/2 (Nyquist) -- verify whether this is intentional.
    S = audio.linear_to_mel(S, sample_rate, num_mel_bins=num_mel_bins, fmax=tf.cast(sample_rate, tf.float32))
    S = tf.math.log(1e-6 + S)
    S = cmvn(S, axis=1)
    return S
38+
39+
40+
if __name__ == "__main__":
    # Export both feature-extraction tf.functions as TensorFlow.js graph
    # models under <out_dir>/tfjs/<name>, for use by the browser client.
    parser = argparse.ArgumentParser()
    parser.add_argument("out_dir")
    out_dir = parser.parse_args().out_dir

    export_list = [
        ("spec2logmel", convertBrowserFFT),
        ("signals2logmel", signals2logmel),
    ]

    for name, fn in export_list:
        # Round-trip through a temporary SavedModel, since the tfjs
        # converter consumes SavedModel directories.
        with tempfile.TemporaryDirectory() as tfmodel_path:
            m = tf.Module()
            # Attach the tf.function as the module's call so it is traced
            # and exported with its input signature.
            m.__call__ = fn
            tf.saved_model.save(m, tfmodel_path)
            tfjs.converters.convert_tf_saved_model(
                tfmodel_path, os.path.join(out_dir, "tfjs", name))

web/src/index.ts

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import './style.css';
2+
import * as tf from '@tensorflow/tfjs';
3+
// require('@tensorflow/tfjs-backend-wasm');
4+
5+
6+
interface AppState {
7+
TFBackend: string,
8+
microphoneConfig: tf.data.MicrophoneConfig;
9+
microphoneCaptureIntervalMs: number;
10+
signalCanvas: HTMLCanvasElement;
11+
// dBSpecCanvas: HTMLCanvasElement;
12+
logMelSpecCanvas: HTMLCanvasElement;
13+
predictionLabel: HTMLElement;
14+
running: boolean;
15+
renderIntervalID: number;
16+
spec2logmel: tf.GraphModel;
17+
signals2logmel: tf.GraphModel;
18+
xvector: tf.LayersModel;
19+
int2label: tf.Tensor1D;
20+
}
21+
22+
23+
// App-wide mutable singleton. Null fields are populated by main().
export const state: AppState = {
  TFBackend: "webgl",
  microphoneConfig: {
    // 1024-point FFT -> 513 frequency bins per frame.
    fftSize: 1024,
    numFramesPerSpectrogram: 198,
    sampleRateHz: 44100,
    includeSpectrogram: true,
    includeWaveform: true,
  },
  microphoneCaptureIntervalMs: 200,
  signalCanvas: null,
  // dBSpecCanvas: null,
  logMelSpecCanvas: null,
  predictionLabel: null,
  running: false,
  renderIntervalID: 0,
  spec2logmel: null,
  signals2logmel: null,
  xvector: null,
  int2label: null,
}
44+
45+
46+
// Log an unrecoverable error and shut the whole app down.
function fatalError(error: Error): void {
  console.error(error)
  console.error("cannot recover from error")
  stopApp()
}
51+
52+
export function stopApp(): void {
53+
console.info("app stopping")
54+
state.running = false
55+
if (state.renderIntervalID > 0) {
56+
window.clearInterval(state.renderIntervalID)
57+
state.renderIntervalID = 0
58+
}
59+
}
60+
61+
62+
function createElement(id, tag): HTMLElement {
63+
const e: HTMLElement = document.createElement(tag)
64+
e.id = id
65+
document.body.appendChild(e)
66+
return e
67+
}
68+
69+
70+
function spectrogramToCanvas(spec: tf.Tensor3D, canvas: HTMLCanvasElement): void {
71+
// Scale all values between 0 and 1
72+
const min: tf.Tensor3D = spec.min([0], true)
73+
const max: tf.Tensor3D = spec.max([0], true)
74+
let image: tf.Tensor3D = tf.divNoNan(spec.sub(min), max.sub(min))
75+
image = image.transpose([1, 0, 2]).reverse(0) as tf.Tensor3D
76+
77+
// Render to canvas
78+
tf.browser.toPixels(image, canvas).catch(fatalError)
79+
}
80+
81+
82+
// Reusable named-input bundle for the spec2logmel graph model. The `spec`
// tensor is a placeholder that is overwritten with live microphone data on
// every capture; the scalar parameters stay constant.
let spec2logmelInput = {
  spec: tf.zeros([1, 1, 1]),
  sample_rate: tf.scalar(16000, "int32"),
  num_mel_bins: tf.scalar(40, "int32"),
}

// Input bundle for the waveform-based signals2logmel model (currently
// disabled -- see the commented-out code in handleMicrophoneInput).
let signals2logmelInput = {
  signals: tf.zeros([1, 1]),
  sample_rate: tf.scalar(state.microphoneConfig.sampleRateHz, "int32"),
  num_mel_bins: tf.scalar(40, "int32"),
}
93+
94+
// Map predicted class indexes to their language labels and display them.
// NOTE(review): state.int2label is declared tf.Tensor1D in AppState but is
// assigned parsed JSON and indexed with [] here, i.e. used as an array --
// the declared type looks wrong.
function updatePredictionLabel(predictedIndexes: Int32Array): void {
  const labels: string[] = Array.from(predictedIndexes, i => state.int2label[i])
  state.predictionLabel.innerText = "prediction: " + labels.join(", ")
}
98+
99+
// Process one microphone capture: convert the browser FFT spectrogram into
// a log-mel spectrogram, draw it, and run the x-vector language classifier
// on it.
function handleMicrophoneInput(data: any): void {
  if (!state.running) {
    console.warn("app not running, ignoring microphone input data")
    return
  }
  // tf.tidy(() => spectrogramToCanvas(data.spectrogram.clipByValue(-200, 0), state.dBSpecCanvas))

  tf.tidy(() => {
    // Reorder the mic spectrogram axes for the converted graph model --
    // presumably [time, freq, channels] -> [channels, time, freq]; confirm
    // against the tf.data.microphone output shape.
    spec2logmelInput.spec = data.spectrogram.transpose([2, 0, 1])
    data.spectrogram.dispose()
    const logmel = state.spec2logmel.execute(spec2logmelInput)
    // Clip to [-1, 1] and move back to image orientation for display.
    const imgInput = (logmel as tf.Tensor).clipByValue(-1, 1).transpose([1, 2, 0]) as tf.Tensor3D
    spectrogramToCanvas(imgInput, state.logMelSpecCanvas)

    const prediction: tf.Tensor1D = state.xvector.predict(logmel) as tf.Tensor1D
    // NOTE(review): data() is asynchronous but the argMax tensor is created
    // inside tf.tidy, which disposes it when this callback returns --
    // confirm the download completes before disposal.
    prediction.argMax(1).data().then(updatePredictionLabel)
  })

  // signal2logmelInput.signals = data.waveform.transpose([1, 0])
  // data.waveform.dispose()

  // state.signals2logmel.executeAsync(signal2logmelInput)
  // .then(logmel => {
  // tf.tidy(() => {
  // const imgInput = (logmel as tf.Tensor).clipByValue(-1, 1).transpose([1, 2, 0]) as tf.Tensor3D
  // spectrogramToCanvas(imgInput, state.logMelSpecCanvas)
  // signal2logmelInput.signals.dispose();
  // (logmel as tf.Tensor).dispose()
  // })
  // })
  // .catch(fatalError)

}
132+
133+
134+
function startListenLoop(mic: any): void {
135+
state.renderIntervalID = window.setInterval(
136+
() => {
137+
mic.capture()
138+
.then(micData => handleMicrophoneInput(micData))
139+
.catch(fatalError)
140+
},
141+
state.microphoneCaptureIntervalMs)
142+
}
143+
144+
145+
// App entry point: build the UI, initialize the tf.js backend, request the
// microphone, load the converted models and the label mapping, then start
// the capture loop.
async function main() {
  // state.signalCanvas = createCanvas("signal-canvas")
  // state.dBSpecCanvas = createCanvas("decibel-spectrogram-canvas")
  state.logMelSpecCanvas = createElement("logscale-melspectrogram-canvas", "canvas") as HTMLCanvasElement
  state.predictionLabel = createElement("prediction-label", "h2")

  // Truncate FFT columns so the effective frequency resolution matches a
  // 16 kHz signal (the rate the features were designed for).
  state.microphoneConfig.columnTruncateLength = Math.round(
    (state.microphoneConfig.fftSize / 2 + 1)
    / (state.microphoneConfig.sampleRateHz/16000))

  await tf.setBackend(state.TFBackend)
  console.log("initialized tensorflow.js backend:", tf.getBackend())

  console.log("requesting access to an input device")
  // Prompts the user for microphone permission.
  const mic = await tf.data.microphone(state.microphoneConfig)
  console.log("got permission to use input device", (mic as any).stream.id)

  // Feature-extraction graph exported by src/feat.py.
  const graph1 = await tf.loadGraphModel("./static/tfjs/spec2logmel/model.json")
  console.log("tf graph1 loaded")
  state.spec2logmel = graph1

  // const graph2 = await tf.loadGraphModel("./static/tfjs/signals2logmel/model.json")
  // console.log("tf graph2 loaded")
  // state.signals2logmel = graph2

  // Converted Keras x-vector classifier (needs the custom layers registered
  // in src/layers.js).
  const graph3 = await tf.loadLayersModel("./static/tfjs/xvector_mv/model.json")
  console.log("tf graph3 loaded")
  state.xvector = graph3
  state.xvector.summary()

  // Class-index to language-label mapping.
  const int2label = await tf.util.fetch("./static/tfjs/xvector_mv/int2label.json")
  state.int2label = await int2label.json()

  console.log("starting app")
  state.running = true
  startListenLoop(mic)

}
183+
184+
185+
// Boot the app once the DOM is ready; any startup failure is fatal.
document.addEventListener("DOMContentLoaded", () => main().catch(fatalError))

web/src/layers.js

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import * as tf from '@tensorflow/tfjs';
2+
3+
/**
 * Custom layer pooling over the time axis by concatenating the per-feature
 * mean and standard deviation: [batch, time, features] -> [batch, 2*features].
 * Registered so the converted x-vector model can be deserialized in the
 * browser.
 */
class GlobalMeanStddevPooling1D extends tf.layers.Layer {
  static get className() {
    // Must match the layer's class name in the exported Keras model.
    return 'GlobalMeanStddevPooling1D';
  }
  constructor(config) {
    super(config || {name: "stats_pooling"});
  }
  computeOutputShape(inputShape) {
    // Time axis is reduced away; feature axis doubles (mean ++ stddev).
    return [inputShape[0], 2 * inputShape[2]];
  }
  call(inputs) {
    // Layer.call receives a list of input tensors.
    const input = inputs[0];
    const timeAxis = 1;
    // Fix: the original subtracted a rank-2 mean (no keepDims) from the
    // rank-3 input, which only broadcasts correctly when batch size is 1.
    // tf.moments centers with a keep-dims mean internally, so mean and
    // variance are correct for any batch size.
    const {mean, variance} = tf.moments(input, timeAxis);
    return tf.concat([mean, tf.sqrt(variance)], 1);
  }
};
tf.serialization.registerClass(GlobalMeanStddevPooling1D);
22+
23+
/**
 * Custom layer applying log-softmax to its input, registered so the
 * converted model's final activation deserializes in the browser.
 */
class logSoftmaxV2 extends tf.layers.Layer {
  static get className() {
    // Must match the layer's class name in the exported Keras model.
    return 'logSoftmaxV2';
  }
  constructor(config) {
    super(config || {name: "log_softmax"});
  }
  call(inputs) {
    // Fix: Layer.call receives a list of input tensors (the sibling
    // GlobalMeanStddevPooling1D unwraps inputs[0]); the original passed the
    // raw list straight to tf.logSoftmax. Unwrap before applying the op.
    const logits = Array.isArray(inputs) ? inputs[0] : inputs;
    return tf.logSoftmax(logits);
  }
};
tf.serialization.registerClass(logSoftmaxV2);

web/src/style.css

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
/* Keep the visualization canvases responsive in width but capped in size. */
canvas {
  width: 100%;
  max-width: 800px;
  max-height: 200px;
}

web/tsconfig.json

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"compilerOptions": {
3+
"outDir": "./built",
4+
"allowJs": true
5+
},
6+
"include": ["./src/**/*"]
7+
}

web/webpack.config.js

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
const path = require('path');
2+
const HtmlWebpackPlugin = require('html-webpack-plugin');
3+
const { CleanWebpackPlugin } = require('clean-webpack-plugin');
4+
const CopyPlugin = require("copy-webpack-plugin");
5+
6+
module.exports = {
7+
devtool: 'inline-source-map',
8+
devServer: {
9+
contentBase: './dist',
10+
hot: true,
11+
},
12+
plugins: [
13+
new CleanWebpackPlugin(),
14+
new HtmlWebpackPlugin({
15+
title: 'dev title',
16+
}),
17+
new CopyPlugin({
18+
patterns: [
19+
{from: "static", to: "static"},
20+
{from: "node_modules/@tensorflow/tfjs-backend-wasm/wasm-out/tfjs-backend-wasm.wasm", to: "."},
21+
],
22+
}),
23+
],
24+
entry: ['./src/index.ts', '/src/layers.js'],
25+
module: {
26+
rules: [
27+
{
28+
test: /\.tsx?$/,
29+
use: ['ts-loader'],
30+
exclude: /node_modules/,
31+
},
32+
{
33+
test: /\.css$/i,
34+
use: ['style-loader', 'css-loader'],
35+
exclude: /node_modules/,
36+
},
37+
],
38+
},
39+
resolve: {
40+
extensions: [ '.tsx', '.ts', '.js' ],
41+
},
42+
output: {
43+
filename: '[name].bundle.js',
44+
path: path.resolve(__dirname, 'dist'),
45+
},
46+
};
47+

0 commit comments

Comments
 (0)