JavaScript and Machine Learning: How TensorFlow.js Is Democratizing AI

Hello HaWkers, imagine training a Machine Learning model and running it directly in the browser, with no Python backend and no heavy servers. Sounds futuristic? With TensorFlow.js, it's already reality.

As a JavaScript developer, you now have access to AI capabilities that were once largely exclusive to data scientists. Let's explore how this works, look at real use cases, and write practical code you can run today.

Why Machine Learning in the Browser?

TensorFlow.js Advantages:

  1. Privacy: data can stay on the user's device
  2. Low latency: no network round trip to a server
  3. Cost: computation runs on users' devices, not on your servers
  4. Accessibility: any JavaScript developer can get started
  5. Cross-platform: runs in the browser, Node.js, React Native, and Electron

Real-world use cases in 2025:

  • Real-time camera filters (the kind popularized by Instagram and Snapchat)
  • Offline audio transcription (features like those in Zoom and Meet)
  • Fraud detection in payments
  • Personalized recommendations without sending user data to a server
  • Accessibility (real-time subtitles and translation)

Basic Setup: Your First Model

Let's build a text sentiment classifier that labels a sentence as positive or negative.

Installation:

npm install @tensorflow/tfjs
# Or via CDN:
# <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>

A simple model, trained from scratch:

import * as tf from '@tensorflow/tfjs';

// NOTE: textToVector is a hypothetical helper (not part of TF.js) that turns a
// string into a fixed-length array of 100 numbers. Any encoding works as long
// as its length matches inputShape below.

// 1. Create a simple sequential model
const model = tf.sequential({
  layers: [
    // Input: each text encoded as a vector of 100 numbers
    tf.layers.dense({ inputShape: [100], units: 16, activation: 'relu' }),

    // Hidden layer
    tf.layers.dense({ units: 8, activation: 'relu' }),

    // Output: 2 classes, in the order [positive, negative]
    tf.layers.dense({ units: 2, activation: 'softmax' })
  ]
});

// 2. Compile the model
model.compile({
  optimizer: tf.train.adam(0.001),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});

// 3. Train with labeled data
async function trainModel(reviews, labels) {
  // reviews: array of strings
  // labels: one-hot arrays, [1, 0] = positive, [0, 1] = negative

  // Tensors hold numbers, not strings, so encode each review first
  const xs = tf.tensor2d(reviews.map(textToVector));
  const ys = tf.tensor2d(labels);

  await model.fit(xs, ys, {
    epochs: 50,
    batchSize: 32,
    validationSplit: 0.2,
    callbacks: {
      onEpochEnd: (epoch, logs) => {
        console.log(`Epoch ${epoch + 1}: loss = ${logs.loss.toFixed(4)}`);
      }
    }
  });

  // Free the memory held by the training tensors
  xs.dispose();
  ys.dispose();

  console.log('✓ Model trained!');
}

// 4. Make predictions
function predict(text) {
  // tf.tidy() releases the intermediate tensors once the callback returns
  const scores = tf.tidy(() => {
    const input = tf.tensor2d([textToVector(text)]);
    return model.predict(input).dataSync();
  });

  return {
    positive: scores[0],
    negative: scores[1],
    sentiment: scores[0] > scores[1] ? 'Positive' : 'Negative'
  };
}

// Usage example (after awaiting trainModel(...) with your own labeled reviews)
const result = predict('Loved this product, very good!');
console.log(result);
// Illustrative output; the actual numbers depend on your training data:
// { positive: 0.89, negative: 0.11, sentiment: 'Positive' }
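The sentiment example relies on textToVector without defining it. One simple way to fill it in, offered as a hypothetical sketch rather than the author's encoding, is feature hashing: lowercase the text, split it into word tokens, hash each token into one of 100 buckets (matching inputShape: [100]), and count. The FNV-1a hash and the tokenizer regex are assumptions; a real project would more likely use a learned vocabulary or a pre-trained text encoder.

```javascript
// Hypothetical textToVector: hashed bag-of-words ("feature hashing").
// The 100-bucket size must match inputShape: [100] in the model.
const VECTOR_SIZE = 100;

// 32-bit FNV-1a hash over the token's UTF-16 code units
function hashToken(token) {
  let hash = 0x811c9dc5;
  for (let i = 0; i < token.length; i++) {
    hash ^= token.charCodeAt(i);
    hash = Math.imul(hash, 0x01000193);
  }
  return hash >>> 0; // force an unsigned 32-bit integer
}

function textToVector(text, size = VECTOR_SIZE) {
  const vector = new Array(size).fill(0);
  // Lowercase, then keep runs of letters, digits and apostrophes as tokens
  const tokens = text.toLowerCase().match(/[a-z0-9']+/g) || [];

  for (const token of tokens) {
    vector[hashToken(token) % size] += 1;
  }

  // Divide by the token count so long and short texts have similar scales
  const total = tokens.length || 1;
  return vector.map(count => count / total);
}

console.log(textToVector('Loved this product, very good!').length); // 100
```

With only 100 buckets, different words often collide in the same bucket, which is one reason this is a toy encoding rather than a production one.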

Practical Case: Real-Time Image Recognition

Let's use a pre-trained model (MobileNet) to classify images from a webcam.

import '@tensorflow/tfjs'; // loads the WebGL/CPU backends MobileNet runs on
import * as mobilenet from '@tensorflow-models/mobilenet';

export class ImageClassifier {
  constructor() {
    this.model = null;
    this.video = null;
  }

  async init() {
    // 1. Load the model (weights are downloaded automatically)
    console.log('Loading MobileNet...');
    this.model = await mobilenet.load();
    console.log('✓ Model loaded!');

    // 2. Set up the webcam (the browser asks the user for camera permission)
    this.video = document.getElementById('webcam');

    const stream = await navigator.mediaDevices.getUserMedia({
      video: { width: 640, height: 480 }
    });

    this.video.srcObject = stream;
    await new Promise(resolve => {
      this.video.onloadedmetadata = resolve;
    });

    await this.video.play();
  }

  async classify() {
    // Not initialized yet: return no predictions instead of undefined
    if (!this.model || !this.video) return [];

    // 3. Classify the current video frame (top 3 classes by default)
    const predictions = await this.model.classify(this.video);

    return predictions.map(p => ({
      class: p.className,
      probability: (p.probability * 100).toFixed(2) + '%'
    }));
  }

  classifyLoop() {
    const resultsDiv = document.getElementById('results');

    setInterval(async () => {
      const predictions = await this.classify();

      resultsDiv.innerHTML = predictions
        .map(p => `<p>${p.class}: ${p.probability}</p>`)
        .join('');
    }, 1000); // Classify every second
  }
}

// Corresponding HTML (assuming a bundler, with the class above saved as classifier.js)
/*
<video id="webcam" autoplay muted playsinline></video>
<div id="results"></div>

<script type="module">
  import { ImageClassifier } from './classifier.js';

  const classifier = new ImageClassifier();
  classifier.init().then(() => {
    classifier.classifyLoop();
  });
</script>
*/

Result: the application identifies objects in the camera feed without sending any images to a server!
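One caveat with classifyLoop: setInterval fires every second whether or not the previous classify() call has finished, so on a slow device calls can pile up. A common alternative, sketched here as a generic helper (startLoop is a hypothetical name, not a TF.js API), schedules the next run only after the current one completes.

```javascript
// Hypothetical helper: run an async task repeatedly, never overlapping.
// The next run is scheduled with setTimeout only after the current one settles.
function startLoop(task, delayMs) {
  let stopped = false;
  let timer = null;

  async function tick() {
    if (stopped) return;
    try {
      await task();
    } catch (err) {
      console.error('Loop task failed:', err);
    }
    if (!stopped) {
      timer = setTimeout(tick, delayMs);
    }
  }

  tick();

  // Calling the returned function stops the loop
  return function stop() {
    stopped = true;
    clearTimeout(timer);
  };
}
```

In the classifier, the setInterval call could be swapped for startLoop with the same async callback and a 1000 ms delay; a slow frame then simply delays the next one instead of overlapping it.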

Transfer Learning: Training Your Own Classifier

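What if you want to recognize your own categories? Use transfer learning: take a pre-trained model, reuse the features it has already learned, and train only a small new classifier on top for your task.

Example: an exercise pose classifier (squat, push-up, plank, jump)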

import * as tf from '@tensorflow/tfjs';

class PoseClassifier {
  constructor() {
    this.baseModel = null;
    this.model = null;
    this.classes = ['squat', 'pushup', 'plank', 'jump'];
  }

  async loadBaseModel() {
    // Load the full pre-trained MobileNet (TF.js Layers format)
    const fullMobilenet = await tf.loadLayersModel(
      'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json'
    );

    // Stop at an internal activation layer, dropping the generic
    // 1000-class head; what remains is a reusable feature extractor
    const layer = fullMobilenet.getLayer('conv_pw_13_relu');
    this.baseModel = tf.model({
      inputs: fullMobilenet.inputs,
      outputs: layer.output
    });
  }

  createCustomModel() {
    // New layers that map MobileNet features to your own classes
    const model = tf.sequential({
      layers: [
        tf.layers.flatten({
          inputShape: this.baseModel.outputs[0].shape.slice(1)
        }),
        tf.layers.dense({
          units: 128,
          activation: 'relu',
          kernelInitializer: 'varianceScaling'
        }),
        tf.layers.dropout({ rate: 0.5 }),
        tf.layers.dense({
          units: this.classes.length,
          activation: 'softmax'
        })
      ]
    });

    model.compile({
      optimizer: tf.train.adam(0.0001),
      loss: 'categoricalCrossentropy',
      metrics: ['accuracy']
    });

    this.model = model;
  }

  async train(images, labels) {
    // images: array of <img> elements
    // labels: class index per image (0 = squat, 1 = pushup, 2 = plank, 3 = jump)

    // Extract features with the base model (its weights are never updated)
    const features = tf.tidy(() => {
      const imageTensors = images.map(img =>
        tf.browser.fromPixels(img)
          .resizeBilinear([224, 224])
          .toFloat()
          .div(127.5) // scale pixels from [0, 255] to [-1, 1]
          .sub(1)
      );

      const batched = tf.stack(imageTensors);
      return this.baseModel.predict(batched);
    });

    // Convert labels to one-hot vectors (tidy frees the temporary tensor1d)
    const ys = tf.tidy(() =>
      tf.oneHot(tf.tensor1d(labels, 'int32'), this.classes.length)
    );

    // Train only the new layers. validationSplit uses the LAST 20% of
    // samples (taken before shuffling), so shuffle images/labels first.
    await this.model.fit(features, ys, {
      epochs: 20,
      batchSize: 32,
      validationSplit: 0.2,
      callbacks: {
        onEpochEnd: (epoch, logs) => {
          console.log(
            `Epoch ${epoch + 1}: ` +
            `loss = ${logs.loss.toFixed(4)}, ` +
            `accuracy = ${logs.acc.toFixed(4)}`
          );
        }
      }
    });

    features.dispose();
    ys.dispose();
  }

  async predict(imageElement) {
    const processed = tf.tidy(() => {
      const img = tf.browser.fromPixels(imageElement)
        .resizeBilinear([224, 224])
        .toFloat()
        .div(127.5)
        .sub(1)
        .expandDims(0); // add a batch dimension: [1, 224, 224, 3]

      const features = this.baseModel.predict(img);
      return this.model.predict(features);
    });

    const probabilities = await processed.data();
    processed.dispose();

    // Index of the most likely class
    const classIndex = probabilities.indexOf(Math.max(...probabilities));

    return {
      class: this.classes[classIndex],
      confidence: (probabilities[classIndex] * 100).toFixed(2) + '%',
      allProbabilities: this.classes.map((name, i) => ({
        name,
        probability: (probabilities[i] * 100).toFixed(2) + '%'
      }))
    };
  }
}

// Usage (inside an ES module, where top-level await is allowed)
const classifier = new PoseClassifier();

await classifier.loadBaseModel();
classifier.createCustomModel();

// Collect data: photos of each exercise as <img> elements (img1... are placeholders)
const squatImgs = [img1, img2, img3]; // 3 examples (real training needs many more)
const pushupImgs = [img4, img5, img6];
// ... more examples

// One array per class, in the same order as classifier.classes
const examplesByClass = [squatImgs, pushupImgs /* , plankImgs, jumpImgs */];
const allImages = examplesByClass.flat();
const labels = examplesByClass.flatMap((imgs, classIndex) => imgs.map(() => classIndex));

await classifier.train(allImages, labels);

// Make a prediction on a new image
const result = await classifier.predict(newImageElement);
console.log(result);
// Illustrative output:
// {
//   class: 'squat',
//   confidence: '94.32%',
//   allProbabilities: [...]
// }
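Because validationSplit takes the last 20% of samples before any shuffling, images grouped by class (all squats, then all push-ups, and so on) would leave only the last class in the validation set. A hypothetical helper like buildDataset, which is not part of TF.js, flattens one array of examples per class into parallel images/labels arrays and shuffles them together with a Fisher-Yates shuffle.

```javascript
// Hypothetical helper: turn [[class 0 examples], [class 1 examples], ...]
// into parallel `images` and `labels` arrays, shuffled together.
function buildDataset(examplesByClass) {
  // Pair every example with its class index
  const items = [];
  examplesByClass.forEach((examples, classIndex) => {
    for (const example of examples) {
      items.push({ example, label: classIndex });
    }
  });

  // Fisher-Yates shuffle: each example stays paired with its label
  for (let i = items.length - 1; i > 0; i--) {
    const j = Math.floor(Math.random() * (i + 1));
    [items[i], items[j]] = [items[j], items[i]];
  }

  return {
    images: items.map(item => item.example),
    labels: items.map(item => item.label)
  };
}
```

Its images and labels outputs can be passed directly to classifier.train(), in place of the unshuffled arrays.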

Performance: GPU Acceleration

In the browser, TensorFlow.js runs computations on the GPU through WebGL by default, which is crucial for larger models.

Important optimizations:

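// 1. Use tf.tidy() to manage memory
function processData(input) {
  return tf.tidy(() => {
    // Intermediate tensors created here are released automatically;
    // only the returned tensor survives (dispose it when you're done)
    const normalized = input.div(255);
    const reshaped = normalized.reshape([1, 224, 224, 3]);
    return model.predict(reshaped);
  });
}

// 2. Batch processing: one predict() call for many images
async function classifyMultiple(images) {
  const output = tf.tidy(() => {
    const tensors = images.map(img =>
      // Scale to [0, 1]; match whatever preprocessing your model expects
      tf.browser.fromPixels(img).resizeBilinear([224, 224]).toFloat().div(255)
    );
    return model.predict(tf.stack(tensors));
  });

  const predictions = await output.array();
  output.dispose();

  return predictions;
}

// 3. Quantization for smaller models
// Weights are quantized when the model is converted (for example with
// tensorflowjs_converter's --quantize_uint8 option); float32 -> uint8 makes
// the weight files roughly 4x smaller. Loading works like any graph model:
async function loadQuantizedModel() {
  // Placeholder URL: point this at your own converted model
  const model = await tf.loadGraphModel(
    'https://example.com/model_quantized/model.json'
  );

  return model;
}

// 4. Run inference in a Web Worker so the UI thread isn't blocked
// worker.js (classic worker script)
importScripts('https://cdn.jsdelivr.net/npm/@tensorflow/tfjs');

// Load the model once inside the worker (placeholder URL)
const modelPromise = tf.loadLayersModel('https://example.com/model/model.json');

self.onmessage = async (e) => {
  const { imageData } = e.data; // an ImageData object posted by the main thread
  const model = await modelPromise;

  const prediction = tf.tidy(() => {
    const input = tf.browser.fromPixels(imageData)
      .resizeBilinear([224, 224])
      .toFloat()
      .div(255)
      .expandDims(0); // add a batch dimension
    return model.predict(input);
  });

  self.postMessage({ prediction: await prediction.array() });
  prediction.dispose();
};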
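To see where the roughly 4x saving from uint8 quantization comes from, here is a simplified sketch of affine (min/max) quantization in plain JavaScript. It illustrates the general idea, mapping each float32 weight (4 bytes) to one of 256 evenly spaced levels stored in a single byte, plus a shared min and scale; it is not the converter's actual implementation.

```javascript
// Simplified affine uint8 quantization: each float32 weight (4 bytes)
// becomes one byte, plus a shared `min` and `scale` for the whole array.
function quantizeUint8(weights) {
  let min = Infinity;
  let max = -Infinity;
  for (const w of weights) {
    if (w < min) min = w;
    if (w > max) max = w;
  }

  // 256 levels span [min, max]; fall back to 1 if all weights are equal
  const scale = (max - min) / 255 || 1;
  const quantized = Uint8Array.from(weights, w => Math.round((w - min) / scale));
  return { quantized, min, scale };
}

// Recover approximate float32 weights (rounding error up to about scale / 2 each)
function dequantizeUint8({ quantized, min, scale }) {
  return Float32Array.from(quantized, q => q * scale + min);
}

const weights = Float32Array.from([-1, -0.25, 0, 0.5, 1]);
const { quantized, min, scale } = quantizeUint8(weights);
console.log(weights.byteLength, quantized.byteLength); // 20 5
```

The price of the smaller file is precision: every weight is snapped to the nearest of 256 levels, which is usually acceptable for inference.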

Advanced Use Cases

1. Pose Detection (PoseNet):

import * as posenet from '@tensorflow-models/posenet';

const net = await posenet.load({
  architecture: 'MobileNetV1',
  outputStride: 16,
  inputResolution: { width: 640, height: 480 },
  multiplier: 0.75
});

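// video: a playing <video> element, e.g. the webcam set up earlier
const pose = await net.estimateSinglePose(video);

// pose.keypoints contains 17 points (nose, eyes, shoulders, etc.),
// each with a part name, a confidence score, and an { x, y } position
const nose = pose.keypoints.find(kp => kp.part === 'nose');
console.log(`Nose at: x=${nose.position.x}, y=${nose.position.y}`);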
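Keypoints become useful once you compute something from them. As a hypothetical example (not a PoseNet API), the helpers below keep only confident keypoints, indexed by part name, and measure the angle at a joint, for instance the knee angle from hip, knee and ankle, which a squat counter could track. The 0.5 score threshold is an assumption.

```javascript
// Index PoseNet-style keypoints ({ part, score, position: { x, y } })
// by part name, skipping low-confidence detections.
function keypointsByPart(keypoints, minScore = 0.5) {
  const parts = {};
  for (const { part, score, position } of keypoints) {
    if (score >= minScore) parts[part] = position;
  }
  return parts;
}

// Angle at point b (in degrees) formed by the segments b->a and b->c
function jointAngle(a, b, c) {
  const v1 = { x: a.x - b.x, y: a.y - b.y };
  const v2 = { x: c.x - b.x, y: c.y - b.y };
  const dot = v1.x * v2.x + v1.y * v2.y;
  const lengths = Math.hypot(v1.x, v1.y) * Math.hypot(v2.x, v2.y);
  if (lengths === 0) return NaN; // points overlap; angle undefined
  // Clamp to [-1, 1] to guard against floating-point drift before acos
  const cos = Math.min(1, Math.max(-1, dot / lengths));
  return (Math.acos(cos) * 180) / Math.PI;
}

// Usage with a pose from net.estimateSinglePose(video):
// const p = keypointsByPart(pose.keypoints);
// if (p.leftHip && p.leftKnee && p.leftAnkle) {
//   console.log(`Left knee angle: ${jointAngle(p.leftHip, p.leftKnee, p.leftAnkle).toFixed(0)}°`);
// }
```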

2. Person Segmentation (BodyPix):

import * as bodyPix from '@tensorflow-models/body-pix';

const net = await bodyPix.load();

// video: a playing <video> element; canvas: a <canvas> to draw the result on
const segmentation = await net.segmentPerson(video);

// Apply background blur (Zoom style)
const backgroundBlurAmount = 15; // how strongly to blur the background
const edgeBlurAmount = 3;        // how much to soften the person's outline
const flipHorizontal = false;    // set to true to mirror a front-facing camera

// Draws the frame onto the canvas with everything except the person blurred
bodyPix.drawBokehEffect(
  canvas,
  video,
  segmentation,
  backgroundBlurAmount,
  edgeBlurAmount,
  flipHorizontal
);
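segmentPerson() returns a plain object whose data field is a flat Uint8Array with one entry per pixel (width * height entries): 1 where the model sees a person, 0 for background. That makes simple statistics easy. The hypothetical personCoverage helper below, which is not part of BodyPix, computes what fraction of the frame the person occupies, which could, for example, signal that someone has stepped out of view.

```javascript
// Hypothetical helper: fraction of pixels labeled as "person" in a
// BodyPix-style segmentation ({ width, height, data: Uint8Array of 0/1 }).
function personCoverage(segmentation) {
  const { width, height, data } = segmentation;
  const total = width * height;
  if (total === 0) return 0;

  let personPixels = 0;
  for (let i = 0; i < total; i++) {
    if (data[i] === 1) personPixels++;
  }
  return personPixels / total;
}

// Usage (the 0.05 threshold is an arbitrary example):
// const segmentation = await net.segmentPerson(video);
// if (personCoverage(segmentation) < 0.05) console.log('Nobody in frame?');
```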

3. Speech Command Recognition:

import * as speechCommands from '@tensorflow-models/speech-commands';

const recognizer = speechCommands.create('BROWSER_FFT');

await recognizer.ensureModelLoaded();

// Listen for commands (the browser asks the user for microphone permission)
await recognizer.listen(result => {
  const scores = result.scores;
  const maxScore = Math.max(...scores);
  const command = recognizer.wordLabels()[scores.indexOf(maxScore)];

  console.log(`Command detected: ${command} (${(maxScore * 100).toFixed(2)}%)`);
}, {
  includeSpectrogram: true,
  probabilityThreshold: 0.75 // ignore results whose top score is below 0.75
});
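The listener above keeps only the single best label. When you also want the runner-up guesses (for example, to show "did you mean...?"), a small generic helper can pair each score with its word label and sort. topK is a hypothetical name, and it assumes, as the code above does, that scores[i] corresponds to recognizer.wordLabels()[i].

```javascript
// Hypothetical helper: the k highest-scoring labels, best first.
// scores[i] is assumed to belong to labels[i] (as with wordLabels()).
function topK(scores, labels, k = 3) {
  return Array.from(scores, (score, i) => ({ label: labels[i], score }))
    .sort((a, b) => b.score - a.score)
    .slice(0, k);
}

// Inside the listen() callback:
// const best = topK(result.scores, recognizer.wordLabels(), 3);
// best.forEach(({ label, score }) => console.log(`${label}: ${(score * 100).toFixed(1)}%`));
```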

Limitations and Considerations

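What TensorFlow.js is NOT ideal for:

  • Training very large models (GPT- or DALL-E-scale)
  • Massive batch processing (millions of images)
  • Cutting-edge ML research

Ideal for:

  • Real-time inference on the client
  • Small to medium models (as a rule of thumb, under ~50 MB, since every user has to download them)
  • Quick prototypes
  • Applications that need privacy

Trade-offs:

  • Performance: often slower than native Python/C++ (roughly 2-5x in many cases), depending on the backend and the user's hardware
  • Size: models need to be light enough to download over the web
  • Compatibility: not every TensorFlow op or Python API has a TF.js equivalent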

The Future of ML in JavaScript

2025 Trends:

  • WebGPU for even better performance
  • Ever-smaller models, thanks to compression techniques such as quantization
  • Edge computing (ML on IoT devices with JS)
  • Framework integration (React, Vue components with ML)

To explore more of the new areas JavaScript is expanding into, see JavaScript and the IoT World: Integrating Web to Physical Environment.

Let's go! 🦅
