Machine Learning in JavaScript with TensorFlow.js

在JavaScript中训练和运行机器学习模型!学习机器学习基础知识,并使用TensorFlow.js,仅用几行代码就能借助预训练模型添加图像识别功能。利用网络摄像头的输入来训练自定义模型。学习如何优化和提高模型准确性,并在物体识别、手势识别和音频识别方面发挥创意。

0-introduction

课程介绍

  • 讲师:Charlie Gerard,Socket 公司的资深工程师
  • 自称创意技术专家,专注于探索 JavaScript 的边界应用
  • 从事前端开发 10 年,业余时间探索 JavaScript 在机器学习领域的应用

TensorFlow.js 应用案例展示

实用项目案例

暗模式拍手扩展

  • 使用机器学习模型训练拍手声音识别
  • 通过两次拍手切换网站的暗模式和亮模式
  • 在 Netlify 公司工作期间开发的定制扩展

垃圾分类识别应用

  • 手机端优化的网站应用
  • 拍摄物品照片,识别应投放的垃圾桶类型
  • 结合图像检测和城市公共资源数据
  • 根据所在地区的法规提供分类建议

创意项目案例

体感街霸游戏

  • 使用硬件设备构建自定义机器学习模型
  • 通过身体动作实时训练模型
  • 挥拳动作控制游戏角色
  • 虽为原型但展示了实时数据采集和模型训练的可能性

体感水果忍者游戏

  • 结合 Three.js 构建 3D 环境
  • 纯客户端 JavaScript 实现
  • 通过摄像头追踪身体动作
  • 用手势在 3D 世界中切水果

Figma 插件 UI 控制

  • 通过 TensorFlow.js 追踪身体动作
  • 将动作映射到 Figma 图层操作
  • 探索键盘鼠标之外的交互方式
  • 在浏览器外环境使用 JavaScript

课程技术要求

  • Git 版本控制
  • Node.js 环境
  • 代码编辑器(推荐 VS Code)
  • 浏览器(推荐 Chrome)
  • 通过 npm install 安装依赖即可开始

课程结构

三个核心主题

预训练模型使用

  • 直接使用已训练好的模型
  • 快速实现机器学习功能

迁移学习

  • 使用预训练模型添加自定义输入
  • 将通用模型适配特定需求
  • 类似于 RAG(检索增强生成)概念

自定义模型构建

  • 从零开始创建图像检测模型
  • 数据收集和处理
  • 训练集和测试集划分
  • 模型层结构设计
  • 最终实现浏览器端绘图识别应用

1-machine-learning-overview

机器学习基础概念

核心定义

机器学习是让计算机在不被明确编程的情况下执行任务的能力。与传统编程的区别在于:

  • 传统编程:编写明确指令,计算机严格按照指令执行
  • 机器学习:通过寻找模式和规律来预测结果,无需明确告知具体操作

技术分类层次

人工智能 (AI)
├── 机器学习 (Machine Learning)
│   └── 深度学习 (Deep Learning) - 使用神经网络
│       ├── 计算机视觉 (Computer Vision)
│       └── 自然语言处理 (Natural Language Processing)

机器学习核心术语

模型 (Model)

  • 本质:接收输入、执行计算、产生概率输出的函数
  • 功能:将数据转换为预测结果

机器学习算法类型

卷积神经网络 (CNN)

  • 主要用途:图像检测和识别
  • 特点:专门处理图像数据的算法结构

朴素贝叶斯 (Naive Bayes)

  • 适用场景:处理数值型数据表格

长短期记忆网络 (LSTM)

  • 适用场景:序列数据处理

权重 (Weights)

  • 定义:衡量数据集中每个特征对模型准确性的重要程度
  • 作用:确定不同特征在预测中的影响力
  • 示例:花朵识别中,颜色、花瓣长度、花瓣质感等特征的重要性各不相同

过拟合 (Overfitting)

  • 现象:模型过度适应训练数据集
  • 表现:训练准确率达到 1.0(完美),但对新数据预测能力差
  • 原因:模型记住了训练数据的具体内容,缺乏泛化能力

类别/标签 (Classes/Labels)

  • 定义:对单个数据的描述性标识
  • 格式:通常为单个词汇
  • 作用:建立数据与标签的关联关系,供模型学习

机器学习类型

监督学习 (Supervised Learning)

  • 特点:提供带标签的数据集
  • 过程:算法学习数据与标签的对应关系
  • 应用:本课程主要采用此方法
  • 示例:图像-标签配对的数据集

无监督学习 (Unsupervised Learning)

  • 特点:仅提供数据,不提供标签
  • 目的:发现数据中的潜在模式和聚类
  • 应用场景:情感分析、客户反馈分类
  • 结果:将数据分为不同类别或情感倾向

半监督学习 (Semi-supervised Learning)

  • 特点:部分数据有标签,部分无标签
  • 优势:可能实现更高的准确性和效率
  • 状态:在 TensorFlow.js 中的实现方式不明确

强化学习 (Reinforcement Learning)

  • 机制:基于奖励和惩罚的训练方法
  • 过程:通过试错学习期望行为
  • 应用:主要用于游戏 AI 和智能体开发
  • 特点:模型通过反馈不断优化决策策略

JavaScript 机器学习工具

TensorFlow.js

  • 地位:本课程主要使用的框架
  • 特点:Google 开发的 JavaScript 机器学习库

其他工具选择

ML5.js

  • 特点:基于 TensorFlow.js 的更简化抽象层
  • 适用:快速原型开发

专用库

  • Brain.js:通用神经网络库
  • ConvNet.js:专门用于卷积神经网络
  • Keras.js:Keras 的 JavaScript 版本实现

2-pre-trained-models-overview

预训练模型概念

基本定义

预训练模型是已经完成训练和测试、可直接使用的机器学习模型。模型的创建过程包括:

  • 数据收集和处理
  • 数据标注
  • 模型训练
  • 结构优化
  • 准确性测试
  • 开源发布或内部共享

使用方式

作为使用者,只需:

  1. 提供新的输入数据
  2. 获得模型预测结果

预训练模型类型

图像识别模型

  • 功能:识别和分类图像内容
  • 应用:物体检测、场景理解

文本分类模型

  • 功能:对文本内容进行分类
  • 代表:大语言模型(LLM)
  • 机制:基于大量文本训练,通过关联分析回答问题

情感分析模型

  • 训练方式:监督学习,使用带有情感标签的数据
  • 功能:分析文本的情感倾向(快乐、悲伤等)
  • 输出:情感类别标签

模型获取资源

Kaggle 平台

  • 特点:可按 TensorFlow.js 筛选模型
  • 内容:包含多种框架的模型,部分专门适配 Python
  • 使用:通过筛选功能找到 JavaScript 兼容模型

GitHub 官方仓库

  • 位置:TensorFlow.js 专用模型仓库
  • 内容:姿态检测、人脸检测等多种模型
  • 特点:官方维护,兼容性良好

Hugging Face 平台

  • 定位:流行的模型分享平台
  • 兼容性:部分模型支持 TensorFlow.js,部分需要其他工具
  • 扩展:掌握机器学习基础后可探索不同库的使用

模型选择标准

核心考虑因素

应用场景匹配

  • 原则:选择与应用需求匹配的模型类型
  • 示例:图像应用选择图像训练模型,避免文本训练模型

训练数据质量

  • 重要性:决定模型在特定场景的表现
  • 匹配度:模型训练内容应与应用目标相符
  • 示例:动物识别应用避免使用汽车图像训练的模型

数据量考虑

  • 原则:数据量通常越大越好
  • 平衡:需同时考虑数量和质量
  • 风险:大量低质量或错误标注数据无助于提升效果

类别标签获取

  • 价值:了解模型训练的具体类别列表
  • 来源:Google 开源模型通常提供类别清单
  • 验证:检查目标预测类别是否在训练范围内
  • 替代:无法获取时通过实际测试验证效果

实践建议

  • 优先选择有详细文档的模型
  • 测试模型在具体应用场景的表现
  • 关注模型的维护状态和社区支持
  • 考虑模型大小对应用性能的影响

3-object-detection-with-tensorflow-js

项目环境搭建

仓库获取

  • GitHub 地址:github.com/charliegerald/fem-ml-workshop
  • 项目结构:Exercises/Project 1
  • 启动命令:npm watch

代码文件结构

utils.js 工具文件

  • 功能:包含与 TensorFlow.js 无关的 JavaScript 辅助函数
  • showResults():将模型输出显示在屏幕上而非仅在控制台
  • handleFilePicker():处理文件选择功能

UI 交互设计

  • 界面:简单的文件选择按钮
  • 流程:用户选择本地图像 → 模型预测 → 显示结果
  • 建议:准备猫、狗、汽车等测试图像

文件选择器实现

// handleFilePicker函数核心逻辑
- 获取DOM元素
- onChange事件处理
- 图像格式验证
- FileReader API使用
- 图像元素创建
- 回调函数调用predict()

TensorFlow.js 实现

基础导入

import * as tf from "tensorflow";
import * as cocoSsd from "coco-ssd";

模型加载和初始化

let model;

async function init() {
  // 加载预训练模型
  model = await cocoSsd.load();

  // 启用文件选择器
  handleFilePicker(predict);
}

预测函数实现

async function predict(imageElement) {
  // 执行检测
  const predictions = await model.detect(imageElement);

  // 显示结果
  console.log(predictions);
  showResult(predictions); // 可选的UI显示
}

// 启动应用
init();

CocoSSD 模型特性

模型基本信息

  • 类型:物体检测预训练模型
  • 大小:约 7KB(gzip 压缩后 2.8KB)
  • 优化:专为浏览器环境优化

输出格式

  • 预测结果:包含类别和概率的数组
  • 概率范围:0-1 之间的数值
  • 示例:{class: "cat", probability: 0.98} 表示 98%的置信度识别为猫

模型局限性

训练数据限制

  • 问题:无法获取完整的训练类别列表
  • 影响:可能无法识别训练集外的物体
  • 示例:意大利烩饭被识别为卡车

识别准确性因素

  • 物体在图像中的比例
  • 图像质量和清晰度
  • 物体与训练数据的相似程度

实际应用场景

垃圾分类应用

  • 功能:识别物品并建议投放垃圾桶类型
  • 实现:图像检测 + 本地法规数据库
  • 注意:需考虑不同地区的分类标准差异

自动 Alt 标签生成

  • 应用:为网页图像自动生成替代文本
  • 流程:模型初步识别 + 人工审核优化
  • 价值:提升网页可访问性

内容审核工具

  • 示例:nsfw.js(不适宜工作场所内容检测)
  • 功能:自动标记不当图像内容
  • 应用场景:用户上传内容的初步筛选
  • 输出:内容类型标签(中性、不当等)

开发技巧

简化实现方案

// 最简实现:无需文件选择器
const imageElement = document.querySelector("img");
const predictions = await model.detect(imageElement);

性能考虑

  • 模型大小:适合网页应用的轻量级设计
  • 加载时间:相对较快的初始化过程
  • 运行环境:纯客户端 JavaScript 实现

4-using-webcam-as-input

项目升级概述

改进目标

  • 保持使用 CocoSSD 模型
  • 将输入源从本地文件改为摄像头实时画面
  • 提供更动态的用户交互体验

代码文件结构

  • 新文件:part2.js
  • HTML 修改:将 part1.js 改为 part2.js
  • 保留:大部分核心逻辑代码

摄像头集成实现

核心导入保持不变

import * as tf from "tensorflow";
import * as cocoSsd from "coco-ssd";

摄像头工具函数

utils.js 中的辅助函数

  • startWebcam():启动摄像头功能
  • takePicture():从视频流捕获静态图像

getUserMedia API 实现

navigator.mediaDevices.getUserMedia({
  audio: false,
  video: { width: xxx, height: xxx },
});

UI 元素获取

const webcamButton = document.getElementById("webcam");
const captureButton = document.getElementById("pause"); // 实际为捕获功能
const video = document.querySelector("video");

模型初始化

let model;

async function init() {
  model = await cocoSsd.load();

  // 绑定摄像头启动事件
  webcamButton.onclick = () => {
    startWebcam(video);
  };

  // 绑定图像捕获事件
  captureButton.onclick = () => {
    takePicture(video, predict);
  };
}

图像捕获机制

技术原理

  • 视频流本质:连续的图像序列
  • 捕获方法:从视频流中提取单帧图像
  • 处理流程:视频元素 → Canvas 画布 → 静态图像

takePicture 函数工作流程

  1. 获取当前视频画面
  2. 绘制到 Canvas 元素
  3. 从 Canvas 获取图像数据
  4. 调用预测回调函数

预测函数实现

async function predict(imageElement) {
  const predictions = await model.detect(imageElement);
  console.log(predictions);
  showResult(predictions);
}

用户交互流程

操作步骤

  1. 点击"Webcam"按钮启动摄像头
  2. 浏览器请求摄像头权限(用户需允许)
  3. 实时视频流显示
  4. 点击"Capture"按钮捕获当前画面
  5. 点击"Predict"按钮进行物体识别

实际测试结果

  • 人物识别:预测"person",置信度 0.91(91%)
  • 手机识别:预测"remote",置信度 0.92(92%)
  • 识别逻辑:手机外观与遥控器相似,模型给出合理推测

实际应用案例

垃圾分类识别应用

  • 功能实现:开启摄像头 → 拍摄物品 → 识别类别 → 查询分类规则
  • 材质识别:结合物品识别和材质判断
  • 地域适配:根据当地法规提供分类建议
  • 用户体验:移动端优化的实时识别

技术架构优势

  • 客户端处理:无需服务器上传图像
  • 实时反馈:即时获得识别结果
  • 隐私保护:图像数据不离开设备
  • 离线可用:加载模型后可离线使用

开发注意事项

权限处理

  • 摄像头权限:用户必须主动允许
  • 错误处理:权限拒绝时的用户提示
  • 隐私考虑:明确告知用户数据使用方式

性能优化

  • 图像尺寸:模型可能对输入图像进行 resize
  • 显示效果:UI 中的比例可能不协调,但不影响识别精度
  • 响应时间:不同设备的处理速度差异

5-using-a-face-detection-model

项目演进概述

变更重点

  • 保持摄像头输入方式
  • 更换为人脸检测专用模型
  • 演示同一输入源适配不同模型的灵活性

技术栈变更

  • HTML 修改:part2.js → part3.js
  • 保留:摄像头交互逻辑
  • 新增:MediaPipe 人脸检测模型

MediaPipe 模型集成

导入配置

import * as tf from "@tensorflow/tensorflow.js";
import * as faceDetection from "@tensorflow-models/face-detection";
// 可能需要额外的后端支持
// import '@tensorflow/tfjs-backend-webgl';

MediaPipe 工具套件

  • 开发商:Google
  • 特点:专注人体部位检测
  • 模型类型:人脸、姿态、手部动作检测
  • 应用案例:Figma 插件中使用手部检测

UI 元素复用

const webcamButton = document.getElementById("webcam");
const captureButton = document.getElementById("pause");
const video = document.querySelector("video");

模型初始化差异

MediaPipe 语法特点

let detector;

async function init() {
  detector = await faceDetection.createDetector(
    faceDetection.SupportedModels.MediaPipeFaceDetector,
    {
      runtime: "tfjs",
    }
  );
}

语法对比

  • CocoSSD:cocoSsd.load()
  • MediaPipe:createDetector(model, config)
  • 原因:MediaPipe 模型需要更多配置参数

人脸检测实现

预测函数

async function predict(photo) {
  const faces = await detector.estimateFaces(photo, {
    flipHorizontal: false,
  });
  console.log(faces);
}

方法命名差异

  • 通用模型:predict()detect()
  • 人脸专用:estimateFaces()
  • flipHorizontal 参数:控制图像水平翻转

完整初始化流程

async function init() {
  detector = await faceDetection.createDetector(
    faceDetection.SupportedModels.MediaPipeFaceDetector,
    { runtime: "tfjs" }
  );

  webcamButton.onclick = () => startWebcam(video);
  captureButton.onclick = () => takePicture(video, predict);
}

init();

检测结果分析

输出数据结构

边界框信息(box)

  • height:人脸区域高度
  • width:人脸区域宽度
  • xMax, yMax:边界坐标
  • 用途:在图像上绘制人脸框

关键点数据(keypoints)

  • rightEye:右眼坐标
  • leftEye:左眼坐标
  • 其他面部特征点的精确位置
  • 应用:精确定位面部特定区域

多人脸支持

  • 返回值:人脸数组
  • 单人结果:数组包含一个人脸对象
  • 多人潜力:理论支持多人同时检测(未测试)

技术细节处理

图像尺寸适配

  • 模型预处理:自动调整输入图像尺寸
  • UI 显示:可能出现比例失调,但不影响检测精度
  • 优化考虑:模型训练时可能使用方形图像

性能特点

  • 处理时间:相比物体检测可能略慢
  • 精度优势:专门针对人脸特征优化
  • 实时性:支持连续检测应用

扩展应用前景

下一步开发方向

  • 人脸框绘制:使用边界框坐标在图像上绘制检测框
  • 特征点应用:基于关键点实现面部滤镜或装饰
  • 多人场景:测试和优化多人脸同时检测
  • 实时追踪:结合视频流实现连续人脸跟踪

6-detecting-parts-of-the-face.txt

人脸检测框绘制实现

本节课程介绍如何在检测到人脸后,在图像上绘制检测框来标识人脸位置。

Canvas 绘制基础

使用 Canvas API 在前端绘制检测框:

  • 创建 Canvas 元素,获取图像尺寸
  • 使用从人脸检测模型返回的坐标数据
  • 通过 Canvas 的绘图方法在人脸位置绘制矩形框

检测框绘制流程

绘制检测框的关键步骤:

  1. 调用beginPath()开始绘制路径
  2. 设置strokeStyle定义线条颜色(示例中使用红色)
  3. 使用检测结果中的坐标信息绘制矩形
  4. 从 faces 对象中获取第一个检测结果(faces[0])
  5. 提取 xMin、yMin 作为左上角坐标
  6. 使用 width 和 height 定义矩形尺寸

坐标系统说明

人脸检测返回的坐标信息包含:

  • xMin, yMin: 检测框左上角坐标
  • xMax, yMax: 检测框右下角坐标
  • width, height: 检测框的宽度和高度

代码实现

通过导入drawFaceBox函数,在预测完成后调用:

  • 第一个参数传入图像数据(photo)
  • 第二个参数传入检测结果(faces)
  • 函数会自动在图像上绘制红色检测框

应用场景扩展

这种检测框技术可以应用于:

  • 车辆计数系统
  • 人流统计应用
  • 艺术装置中的交互检测
  • 各种需要物体定位的计算机视觉项目

通过这种方式,可以直观地看到模型检测到的对象位置,为后续的应用开发奠定基础。

7-face-detection-demos-and-q-a.txt

人脸关键点检测演示

本节展示了更高级的人脸检测技术和实际应用案例。

MediaPipe 人脸关键点模型

使用 MediaPipe 模型进行更精确的人脸分析:

  • 提供人脸网格数据,包含大量关键点
  • 每个关键点都有对应的 x、y 坐标信息
  • 可以检测眼部运动和其他精细面部特征

眼部追踪技术

通过分析眼部关键点实现眼动追踪:

  • 识别眼睛的白色部分和瞳孔位置
  • 计算当前位置与下一个位置的变化量
  • 判断眼睛移动方向(左右移动)

免手编程系统演示

基于眼动追踪开发的编程辅助工具:

  • 通过眼睛移动选择代码片段
  • 将代码选项分为左右两列布局
  • 支持变量声明、函数创建、控制台输出等常用代码模板
  • 通过眼部方向选择对应的代码类型

界面设计策略

采用二分法布局提高选择效率:

  • 将选项分为左右两个区域
  • 根据视线方向过滤显示内容
  • 类似二进制搜索的交互方式
  • 大幅减少选择时间和操作步骤

凝视控制键盘

复制 Google 无障碍项目的浏览器版本:

  • 为行动不便用户设计的文本输入方案
  • 通过眼部运动选择字母
  • 采用分层选择机制提高输入效率
  • 无需下载应用,直接在浏览器中使用

游戏应用演示

将眼动追踪应用到 Chrome 恐龙游戏:

  • 通过向上看触发跳跃动作
  • 反应速度快,延迟很低
  • 展示了在浏览器中进行实时眼动检测的可行性

技术性能评估

JavaScript 和摄像头结合的实时检测表现出色:

  • 响应速度超出预期
  • 浏览器环境下的模型运行效率很高
  • 为更多创新应用奠定了技术基础

这些演示展现了人脸检测技术在交互设计和无障碍功能开发中的巨大潜力。

8-transfer-learning-overview.txt

迁移学习概述

本节介绍迁移学习的基本概念和在 TensorFlow.js 中的应用。

迁移学习基本概念

迁移学习是在预训练模型基础上添加自定义样本来执行新任务的方法:

  • 利用现有模型的训练成果
  • 添加特定领域的训练数据
  • 快速获得针对特定任务的识别能力

实际应用场景

以意大利烩饭识别为例:

  • 通用图像识别模型无法识别烩饭
  • 通过添加烩饭图片样本和标签
  • 训练模型学会识别烩饭
  • 可以扩展到任何自定义物体识别

浏览器端实现

在浏览器中直接进行迁移学习的优势:

  • 用户可以添加自己的图像样本
  • 实时训练和测试模型效果
  • 无需服务器端处理
  • 快速构建定制化识别应用

关键术语解释

训练周期(Epochs)

  • 训练过程中的迭代次数
  • 模型需要多次处理数据来优化预测
  • 通过多次迭代逐步提升准确性

激活函数(Activation Function)

  • 添加到神经网络中的数学函数
  • 帮助网络学习复杂的数据模式
  • 不同类型的激活函数适用于不同场景

批量大小(Batch Size)

  • 每次训练使用的样本数量
  • 可以将数据集分割成小批次处理
  • 影响训练效率和内存使用

超参数调优

学习率(Learning Rate)

  • 控制模型权重更新幅度的超参数
  • 需要通过实验找到最优值
  • 影响模型收敛速度和最终准确性

权重更新机制

  • 决定各个特征在预测中的重要程度
  • 学习率控制权重调整的频率和幅度

数据结构概念

张量(Tensor)

  • 描述物理属性的数学对象
  • 机器学习中的专用数据类型
  • TensorFlow 框架的核心数据结构

优化器(Optimizer)

  • 调整神经网络权重和学习率的函数
  • 自动优化模型训练过程
  • 提高训练效率和模型性能

这些概念为后续的实际编程实践奠定了理论基础,有助于理解迁移学习的工作原理和参数设置。

9-training-a-model-with-teachable-machine.txt

使用 Teachable Machine 训练模型

本节演示如何使用 Google 的 Teachable Machine 工具快速创建自定义图像识别模型。

Teachable Machine 简介

访问 teachablemachine.withgoogle.com 开始模型训练:

  • 支持图像、音频、姿势三种项目类型
  • 选择标准图像模型进行浏览器端训练
  • 嵌入式图像模型适用于 Arduino 等微控制器

项目设置

创建图像分类项目的基本步骤:

  • 选择图像项目类型
  • 使用标准图像模型(适合浏览器环境)
  • 设置两个不同的分类标签进行训练

数据采集配置

摄像头设置参数说明:

  • 帧率:控制图像捕获频率
  • 录制模式:可选择按住录制或连续录制
  • 延迟时间:默认每 2 秒拍摄一张照片
  • 录制时长:最长录制持续时间限制

样本收集策略

高质量训练数据的收集要点:

  • 每个标签收集约 30 个样本图像
  • 样本数量不需要完全相等,但应该相近
  • 避免一个标签样本过多而另一个过少
  • 确保动作差异明显(如头部左倾 vs 右倾)

训练数据平衡

样本数量对模型准确性的影响:

  • 数据不平衡会影响预测准确性
  • 类似人类学习,重复次数影响记忆效果
  • 如果两个动作差异细微,需要更多样本
  • 动作差异越大,需要的训练样本越少

模型训练过程

训练操作和注意事项:

  • 点击训练按钮开始模型训练
  • 训练期间不要关闭浏览器标签
  • 训练在浏览器中实时进行
  • 训练完成后立即可以测试效果

模型测试

验证训练效果的方法:

  • 实时测试不同姿势的识别准确性
  • 观察置信度百分比变化
  • 测试边界情况(如中间位置)
  • 确认模型能够区分不同类别

模型导出

将训练好的模型保存到本地:

  • 选择下载选项而非云端上传
  • 获得包含模型文件的 ZIP 压缩包
  • 将模型文件夹重命名为易于识别的名称
  • 为后续代码集成做好文件路径准备

通过这种方式,无需深入的机器学习知识就能快速创建自定义识别模型,为后续的应用开发提供了便利的起点。

10-using-a-custom-trained-model.txt

使用自定义训练模型

本节演示如何在 Web 应用中集成和使用从 Teachable Machine 导出的自定义模型。

项目结构设置

项目文件组织方式:

  • 使用 HTML 模板文件作为起点
  • 无 utils 工具文件,所有功能从零开始编写
  • 包含第二部分项目,将实现完整的训练界面
  • 基本 HTML 结构包含启动按钮、摄像头容器和标签容器

库文件导入

通过 CDN 引入必要的 JavaScript 库:

  • TensorFlow.js 核心库(压缩版)
  • Teachable Machine 图像识别库
  • 使用 script 标签方式而非 ES6 模块导入
  • 展示两种不同的库引入方法

模型文件结构分析

导出的模型包含三个关键文件:

model.json

  • 描述神经网络的层次结构
  • 包含序列模型(Sequential Model)定义
  • 每层的参数配置,如 Conv2D 卷积层
  • 指定神经网络的完整架构

metadata.json

  • 记录 TensorFlow.js 版本信息
  • 包含训练时使用的标签(如"right"和"left")
  • 指定训练图像尺寸(224 像素)
  • 存储模型预测时需要的元数据

weights 文件

  • 二进制格式存储的模型权重
  • 包含训练过程中优化的参数
  • 模型运行时的必需文件
  • 不可删除,否则模型无法正常工作

模型加载实现

设置模型路径和加载逻辑:

const path = "./my_model/";
const modelPath = path + "model.json";
const metadataPath = path + "metadata.json";

const model = await tmImage.load(modelPath, metadataPath);

模型信息获取

通过 API 方法获取模型基本信息:

  • getTotalClasses()方法获取分类总数
  • 了解模型可以识别的类别数量
  • 为后续预测结果处理做准备

摄像头集成

使用 Teachable Machine 库的摄像头功能:

  • 创建 webcam 对象并设置尺寸参数
  • 第三个参数控制摄像头画面是否翻转
  • 调用 setup()方法请求摄像头访问权限
  • 通过 play()方法开始摄像头数据流

实时预测循环

建立连续预测机制:

  • 使用 requestAnimationFrame 创建渲染循环
  • 将摄像头画布添加到页面容器中
  • 准备实现持续的图像识别功能

技术架构说明

权重文件的重要性:

  • 存储训练过程中调整的模型参数
  • model.json 引用 weights 文件进行预测
  • 即使是二进制格式也不能删除
  • 是模型智能的核心数据载体

这种架构为构建实时的自定义物体识别应用提供了完整的技术框架。

11-running-predictions-on-a-loop

创建循环预测函数

创建一个异步循环函数来持续运行预测:

const loop = async function () {
  // 更新摄像头帧
  webcam.update();
  // 运行预测
  await predict();
  // 继续循环
  window.requestAnimationFrame(loop);
};

实现预测函数

预测函数的基本结构:

const predict = async function () {
  const prediction = await model.predict(webcam.canvas);
  console.log(prediction);
};

解决常见问题

启动服务器问题

  • 如果使用 npm run start 遇到 "Failed to parse model JSON response" 错误
  • 改用 npm run watch 命令启动 Python 服务器在 localhost:1234

循环执行问题

在 predict 函数末尾需要再次调用 window.requestAnimationFrame(loop) 确保持续运行

处理预测结果

获取最高概率预测

从预测数组中提取概率最高的结果:

// 找到最高概率值
const topPrediction = Math.max(...predictions.map((p) => p.probability));

// 找到对应的索引
const topPredictionIndex = predictions.findIndex(
  (p) => p.probability === topPrediction
);

// 获取标签
console.log(predictions[topPredictionIndex].className);

摄像头翻转设置

如果预测结果与实际动作相反,可以调整摄像头翻转设置:

const webcam = await tmImage.webcam(200, 200, true); // 最后一个参数设为true

改进模型准确性

增加训练样本

  • 原始 30 个样本可能不够准确
  • 建议增加到 100 个样本提高准确性
  • 重新训练后导出新模型替换原有文件

训练注意事项

  • 确保在相同环境下录制样本
  • 保持一致的光照条件
  • 注意背景对预测的影响
  • 位置变化会影响预测准确性

完整工作流程

  1. 启动摄像头并获取权限
  2. 加载训练好的模型
  3. 循环更新摄像头帧
  4. 对每一帧进行预测
  5. 提取最高概率的预测结果
  6. 显示或使用预测标签

应用扩展思路

  • 头部控制滚动:通过上下头部动作控制页面滚动
  • 免触控交互:在不方便使用手的情况下进行操作
  • 自定义手势识别:根据具体应用需求训练不同手势

12-more-ways-to-train-models

训练样本数量的考虑因素

影响因素

  • 背景复杂度:背景越复杂,需要更多样本来区分前景和背景
  • 光照条件:光线不佳的环境需要更多样本来适应不同光照
  • 手势差异度:相似手势之间的区别越小,需要更多样本来准确区分
  • 环境一致性:如果使用环境与训练环境不同,需要在多种环境下录制样本

训练策略

  • 在目标使用环境中录制样本
  • 如果用户会在户外使用,应在户外环境下训练
  • 录制多种环境下的相同手势来提高泛化能力
  • 通过多样本让模型识别出相同手势的共同特征

Tiny Motion Trainer 工具

基本概念

  • 使用 Arduino 微控制器进行动作识别
  • 专门设计用于 TensorFlow.js 的轻量化模型
  • 使用 TensorFlow Lite 在微控制器上运行

支持的硬件

  • Arduino BLE Sense
  • Arduino IoT Sense
  • 内置加速度计和陀螺仪传感器

工作流程

  1. 通过 Web USB API 连接 Arduino 到浏览器
  2. 设置参数:运动阈值、样本数量、捕获延迟
  3. 录制不同手势的动作数据
  4. 训练模型并通过蓝牙上传到 Arduino
  5. Arduino 实时识别动作并返回标签

应用场景

  • 空中手势控制
  • 体感游戏交互
  • 无接触式设备控制
  • 运动检测和分析

Arduino 编程环境

传统方式

  • 使用 C 语言或 Arduino 专用语言
  • 直接在板上运行,性能最优
  • 适合时间敏感的应用

JavaScript 方式

  • 使用 Firmata 库作为桥接
  • JavaScript 代码转换为 Arduino 指令
  • 存在延迟,不适合实时性要求高的应用
  • 使用 Johnny Five 库进行开发

开发工具和库

  • Johnny Five:Arduino 的 JavaScript 开发库
  • Firmata:JavaScript 到 Arduino 的通信协议
  • Node.js:服务端 JavaScript 运行环境
  • WebSockets:实时数据传输

Web API 技术扩展

硬件交互 API

  • Web USB API:浏览器直接访问 USB 设备
  • Web Bluetooth API:无线设备连接和通信
  • 为创意实验项目提供更多可能性

技术组合应用

  • 机器学习 + 硬件传感器
  • Web 技术 + 物理计算
  • 浏览器 + IoT 设备集成

13-training-an-audio-model

音频模型训练基础

背景噪音采集

  • 需要至少 20 秒的背景噪音样本
  • 音频采样以 1 秒为单位进行
  • 安静环境下录制纯背景噪音作为基准

音频到图像的转换

  • 原始音频数据包含数千或数百万个数据点
  • 直接使用原始数据训练速度很慢
  • 将音频转换为频谱图(spectrogram)图像
  • 使用卷积神经网络处理这些小图像

频谱图特征

  • 背景噪音:频谱图基本为空,没有明显模式
  • 拍手声:短暂高频信号,在频谱图上呈现尖峰模式
  • 咳嗽声:包含多个频率成分,显示人声的复杂频率模式

音频样本录制

录制要求

  • 最少 8 个样本才能开始训练
  • 建议录制更多样本以提高准确性
  • 每个音频类别需要足够的变化样本

样本质量控制

  • 确保不同类别的音频在频谱图上有明显差异
  • 避免样本过于相似导致分类困难
  • 背景环境保持一致性

模型训练和预测

训练过程

  • 使用转换后的频谱图图像训练模型
  • 模型实际处理的是图像数据,非原始音频
  • 训练速度比直接使用音频数据快很多

预测参数调整

  • 重叠因子:控制预测频率和响应速度
  • 重叠因子越大,预测响应越快
  • 可以根据应用需求调整参数

技术原理

数据转换流程

  1. 录制原始音频数据
  2. 生成频谱图图像
  3. 将图像输入卷积神经网络
  4. 输出音频分类结果

可视化优势

  • 频谱图让我们能直观判断模型是否能区分不同音频
  • 如果不同类别的频谱图看起来太相似,预测准确性会下降
  • 为开发者和用户提供直观的训练反馈

14-sound-model-demos

洗手计时器项目

项目背景

  • 受 2020 年 Apple Watch 洗手功能启发
  • 在 WWDC 2020 发布后数小时内完成开发
  • 使用 JavaScript 和音频机器学习实现

实现方式

  • 训练识别水流声音的模型
  • 仅使用个人水龙头录制约 20 个样本
  • 检测到水声后启动 20 秒倒计时
  • 纯网页实现,可在手机上运行

开发优势

  • 无需购买 Apple Watch 即可获得类似功能
  • 个人开发者无需团队和资金投入
  • 展示了机器学习和 JavaScript 结合的可能性

声学活动识别

研究背景

  • 基于声学活动识别的学术论文
  • 通过环境声音改进 IoT 家居系统
  • 原论文非 JavaScript 实现,作者进行了技术转换

应用场景

  • 厨房场景:烹饪时免手操作 iPad 菜谱
  • 声音识别
    • 切菜声(刀具撞击砧板)
    • 搅拌机运行声
    • 冰箱门开关声
    • 微波炉完成提示音

技术优势

  • 替代购买多个智能设备的方案
  • 避免数据隐私泄露问题
  • 统一的智能家居控制系统
  • 基于环境音频的自动化触发

系统架构

  • 单一音频监听系统
  • 多种家庭活动声音识别
  • 智能设备状态推断和控制
  • 隐私保护的本地处理

项目意义

技术民主化

  • 复杂功能的简单实现方式
  • 个人开发者也能实现商业级功能
  • 开源技术降低创新门槛

实用性考量

  • 解决日常生活中的实际问题
  • 成本效益远超传统解决方案
  • 跨平台兼容性强

15-loading-the-model-webcam

HTML 结构调整

注释第一部分代码

注释掉 Part 1 相关的内容:

  • Teachable Machine 库的 script 标签
  • TensorFlow.js 的 script 标签
  • Part 1 的相关 HTML 元素

启用第二部分元素

取消注释 Part 2 的 HTML 结构:

  • 录制按钮(左右倾斜样本)
  • 训练按钮
  • 预测按钮
  • 样本数量显示区域
  • 将 script 引用改为 part2.js

JavaScript 库导入

基础库导入

import * as tf from "@tensorflow/tfjs";
import * as tfd from "@tensorflow/tfjs-data";

关键区别

  • 使用原生 TensorFlow.js 而非 Teachable Machine 库
  • 添加 TensorFlow.js 数据处理库
  • 需要手动实现迁移学习逻辑

DOM 元素获取

按钮和容器元素

const recordButtons = document.getElementsByClassName("record-button");
const buttonsContainer = document.getElementById("buttons-container");
const trainButton = document.getElementById("train");
const predictButton = document.getElementById("predict");
const statusElement = document.getElementById("status");

核心变量声明

模型相关变量

let webcam;
let initialModel; // 基础模型(MobileNet)
let newModel; // 训练后的新模型

训练参数设置

const learningRate = 0.0001; // 学习率
const batchSize = 0.4; // 批处理大小
const epochs = 30; // 训练轮数
const denseUnits = 100; // 密集层单元数

应用状态变量

let mouseDown = false; // 鼠标按下状态
let isTraining = false; // 训练状态
let isPredicting = false; // 预测状态
const totals = [0, 0]; // 样本计数
const labels = ["left", "right"]; // 标签定义

模型加载函数

MobileNet 模型获取

async function loadModel() {
  const mobilenet = await tf.loadLayersModel("MobileNet模型URL");
  const layer = mobilenet.getLayer("conv_pw_13_relu");
  const initialModel = tf.model({
    inputs: mobilenet.inputs,
    outputs: layer.output,
  });
  return initialModel;
}

关键技术点

  • 使用loadLayersModel而非loadModel
  • 提取 MobileNet 的特定层作为特征提取器
  • 创建新的模型作为迁移学习的基础

初始化函数

摄像头和模型初始化

async function init() {
  // 初始化摄像头
  webcam = await tfd.webcam(document.getElementById("webcam"));

  // 加载基础模型
  initialModel = await loadModel();

  // 更新UI状态
  statusElement.style.display = "none";
  document.getElementById("controller").style.display = "block";
}

UI 状态管理

  • 隐藏"加载模型"提示
  • 显示控制面板
  • 启用用户交互功能

技术架构说明

迁移学习实现方式

  • 基于预训练的 MobileNet 模型
  • 提取中间层特征
  • 在特征基础上训练自定义分类器

与 Teachable Machine 的区别

  • 更细粒度的参数控制
  • 完整的训练流程可见性
  • 更多自定义选项和灵活性

16-recording-examples

样本录制功能实现

鼠标事件处理

在 buttonsContainer 上添加 mousedown 事件监听器,避免用户需要重复点击录制按钮。通过事件对象判断是哪个按钮被按下:

  • 第一个按钮(左样本):调用 handleAddExample(0)
  • 第二个按钮(右样本):调用 handleAddExample(1)

标签系统通常使用数字索引(0, 1, 2...)而不是字符串。

handleAddExample 函数

const handleAddExample = async (labelIndex) => {
  mouseDown = true; // 设置鼠标按下状态

  // 获取UI元素用于更新总数显示
  const total = document.getElementById(labels[labelIndex] + "total");

  // 持续添加样本直到鼠标释放
  while (mouseDown) {
    await addExample(labelIndex);
    total.innerText = ++totals[labelIndex]; // 更新计数显示
    await tf.nextFrame(); // 等待下一帧,避免阻塞浏览器
  }
};

重要:必须添加 mouseup 事件监听器将 mouseDown 设置为 false,否则会陷入无限循环导致浏览器崩溃。

getImage 函数

const getImage = async () => {
  const image = await webcam.capture(); // 捕获摄像头图像

  // 使用tf.tidy进行内存管理
  const processedImage = tf.tidy(() => {
    return image
      .expandDims(0) // 扩展维度以匹配模型期望
      .toFloat() // 转换为浮点数
      .div(127.5) // 图像标准化
      .sub(1); // 进一步标准化
  });

  image.dispose(); // 释放原始图像内存
  return processedImage;
};

addExample 函数

const addExample = async (index) => {
  const image = await getImage();

  // 使用初始模型进行预测获得特征
  const example = initialModel.predict(image);

  // 创建one-hot编码标签
  const y = tf.tidy(() => {
    return tf.oneHot(tf.tensor1d([index]).toInt(), labels.length);
  });

  // 数据存储逻辑
  if (xs == null) {
    // 第一次添加样本
    xs = tf.keep(example);
    xy = tf.keep(y);
  } else {
    // 后续样本与现有数据连接
    const previousX = xs;
    xs = tf.keep(tf.concat([previousX, example], 0));

    const previousY = xy;
    xy = tf.keep(tf.concat([previousY, y], 0));

    // 释放中间变量内存
    previousX.dispose();
    previousY.dispose();
  }

  // 清理内存
  y.dispose();
  image.dispose();
};

关键概念

  • tf.tidy(): 自动内存管理,防止内存泄漏
  • tf.keep(): 保持张量在内存中用于后续操作
  • tf.concat(): 连接张量数据
  • tf.nextFrame(): 等待 requestAnimationFrame,保持 UI 响应性
  • dispose(): 手动释放张量内存

17-tensorflow-model-layers

模型训练与层结构

训练按钮事件处理

trainButton.onclick = async () => {
  await train();
  statusElement.innerHTML = "Training..."; // 更新UI状态
};

train 函数结构

const train = async () => {
  isTraining = true;

  // 错误检查
  if (xs == null) {
    throw new Error("请在训练前添加样本");
  }

  // 创建模型...
};

重要提醒:必须在 buttonsContainer 上添加 mouseup 事件监听器:

buttonsContainer.onmouseup = () => {
  mouseDown = false; // 防止无限循环
};

Sequential 模型创建

newModel = tf.sequential({
  layers: [
    // 第一层:展平层,处理MobileNet输出
    tf.layers.flatten({
      inputShape: initialModel.outputs[0].shape.slice(1),
    }),

    // 第二层:密集层
    tf.layers.dense({
      units: denseUnits, // 100个单元
      activation: "relu", // ReLU激活函数
      kernelInitializer: "varianceScaling",
      useBias: true,
    }),

    // 第三层:输出层
    tf.layers.dense({
      units: labels.length, // 输出单元数等于标签数
      activation: "softmax", // Softmax激活函数用于分类
      kernelInitializer: "varianceScaling",
      useBias: true,
    }),
  ],
});

层参数说明

  • units: 神经元数量,最后一层应等于分类数量
  • activation: 激活函数
    • relu: 修正线性单元,常用于隐藏层
    • softmax: 用于多分类输出层,输出概率分布
  • kernelInitializer: 权重初始化方法
    • varianceScaling: 方差缩放初始化
  • useBias: 是否使用偏置向量

模型架构理解

  1. 输入层: 处理 MobileNet 特征提取结果
  2. 隐藏层: 100 个神经元的全连接层,使用 ReLU 激活
  3. 输出层: 2 个神经元(对应左右两个类别),使用 Softmax 输出概率

参数调优建议

可以尝试修改不同参数观察效果:

  • 改变激活函数
  • 调整神经元数量
  • 修改初始化方法
  • 添加更多层

这是参数调优的基础,通过实验找到最佳配置。

18-optimizing-batching-training

模型编译与训练优化

优化器创建

const optimizer = tf.train.adam(learningRate);

Adam 优化器是一种自适应学习率优化算法。TensorFlow.js 还提供其他优化器:

  • sgd: 随机梯度下降
  • momentum: 动量梯度下降
  • rmsprop: RMSprop 优化器

学习率控制模型权重更新的频率和幅度。

模型编译

newModel.compile({
  optimizer: optimizer,
  loss: "categoricalCrossentropy",
});

损失函数说明:

  • categoricalCrossentropy: 分类交叉熵,用于多分类问题
  • 损失值越低表示模型准确性越高
  • 训练目标是最小化损失函数

批处理设置

const batchSize = Math.floor(xs.shape[0] * batchSizeFraction);

批处理将训练数据分成小批次:

  • 减少内存使用
  • 提高训练稳定性
  • 加快训练速度
  • batchSizeFraction 控制批次大小比例

模型训练

await newModel.fit(xs, xy, {
  batchSize: batchSize,
  epochs: epochs,
  callbacks: {
    onBatchEnd: async (batch, logs) => {
      statusElement.innerHTML = `Loss: ${logs.loss.toFixed(5)}`;
    },
  },
});

训练参数说明:

  • xs: 训练特征数据
  • xy: 训练标签数据
  • batchSize: 每批处理的样本数
  • epochs: 训练轮数,完整遍历数据集的次数
  • callbacks: 回调函数,监控训练进度

训练监控

onBatchEnd 回调函数在每个批次结束后执行:

  • 实时显示损失值
  • 监控训练进度
  • 损失值应该逐渐下降

训练完成处理

isTraining = false; // 标记训练结束

模型创建流程总结

  1. 定义架构: 选择模型类型和层结构
  2. 添加层: 配置各层参数
  3. 编译模型: 设置优化器和损失函数
  4. 数据准备: 分批处理训练数据
  5. 训练模型: 使用 fit 方法进行训练
  6. 监控进度: 通过回调函数跟踪训练状态

这个过程比预训练模型复杂,但提供了更多控制和定制能力。

19-predictions-from-live-webcam

实时预测功能实现

预测按钮事件处理

predictButton.onclick = async () => {
  isPredicting = true;

  while (isPredicting) {
    const image = await getImage(); // 获取摄像头图像

    // 两阶段预测过程
    const initialModelPrediction = initialModel.predict(image);
    const predictions = newModel.predict(initialModelPrediction);

    // 提取预测结果
    const predictedClass = predictions.as1d().argMax();
    const classId = await predictedClass.data();

    // 输出预测标签
    console.log(labels[classId[0]]);

    // 内存清理
    image.dispose();

    await tf.nextFrame(); // 等待下一帧
  }
};

两阶段预测流程

  1. 第一阶段: 使用 MobileNet 层提取特征
    • initialModel.predict(image) 返回特征嵌入(embeddings)
    • 不是最终分类标签,而是特征表示
  2. 第二阶段: 使用自定义模型分类
    • newModel.predict() 基于特征进行分类
    • 输出每个类别的概率分布

预测结果处理

const predictedClass = predictions.as1d().argMax();
const classId = await predictedClass.data();
console.log(labels[classId[0]]);

方法说明:

  • as1d(): 将预测结果转换为一维张量
  • argMax(): 找到概率最高的类别索引
  • data(): 异步获取张量数据
  • classId[0]: 获取预测类别的索引

实际演示效果

  • 左倾头部: 显示 "left"
  • 右倾头部: 显示 "right"
  • 实时响应动作变化
  • 损失值从 0.8 降至 0.00010,表明训练效果良好

内存管理

持续预测过程中必须进行内存清理:

image.dispose(); // 释放图像内存
await tf.nextFrame(); // 防止阻塞UI

模型保存功能

// 可以保存训练好的模型
await newModel.save("localstorage://my-model");

模型保存的优势:

  • 避免每次刷新重新训练
  • 保持用户创建的样本
  • 提升用户体验

与 Teachable Machine 的比较

这个实现本质上与 Teachable Machine 相同:

  • 都使用迁移学习
  • 都基于 MobileNet 特征提取
  • 区别在于 UI 包装和用户体验

注意事项

  • 刷新页面会丢失模型,需要重新训练
  • 可以通过模型保存功能解决持久化问题
  • 实时预测消耗较多计算资源

这个功能展示了完整的机器学习工作流:数据收集、模型训练、实时推理。

20-image-detector-project-setup

自定义图像分类器项目

项目目标与意义

创建完全自定义的机器学习模型,包括:

  1. 自建数据集: 体验数据集创建的复杂性和时间成本
  2. 从零训练: 不依赖预训练模型,完全自主训练
  3. 许可证考虑: 避免商业使用中的开源许可证问题

项目设置

导航到 exercises/project-three 文件夹:

npm run watch # 启动Python服务器,端口1234

访问 localhost:1234 选择 public 文件夹。

界面功能设计

  • Canvas 画布: 用于绘制形状
  • Clear 按钮: 清空画布内容
  • Predict 按钮: 使用训练好的模型进行预测
  • Download 按钮: 下载绘制的图像作为训练数据

工作流程

  1. 数据收集: 绘制不同形状并下载为图像
  2. 数据标注: 根据文件名进行标签分类
  3. 模型训练: 使用收集的数据训练分类器
  4. 模型测试: 在 UI 中测试预测效果

HTML 结构

<canvas id="canvas"></canvas>
<div id="prediction"></div>
<!-- 用于显示预测结果 -->
<button id="clear">Clear</button>
<button id="predict">Predict</button>
<button id="download">Download</button>

导入工具文件

在 index.html 中添加:

<script src="utils.js" type="module"></script>
<script src="index.js" type="module"></script>

Canvas 绘图功能(utils.js)

包含基础的 Canvas 绘图事件处理:

  • mousedown: 开始绘图
  • mousemove: 持续绘制
  • mouseup: 结束绘图
  • 坐标获取和重绘功能

数据集创建策略

  • 计划创建 40-50 个样本每类
  • 使用手绘图形而非照片
  • 避免版权和许可证问题
  • 体验真实的数据集构建过程

与前两个项目的区别

  1. 预训练模型: 直接使用现成模型
  2. 迁移学习: 在预训练模型基础上添加自定义层
  3. 完全自定义: 从零开始创建模型和数据

下一步计划

  • 实现 Canvas 清空功能
  • 添加图像下载功能
  • 收集训练数据
  • 构建和训练自定义模型
  • 集成预测功能

这个项目将完整展示机器学习的端到端流程,从数据收集到模型部署。

21-clearing-the-drawing

Canvas 清空功能实现

获取清空按钮元素

const clearButton = document.getElementById("clear-button");

添加点击事件监听器

clearButton.onclick = () => {
  // 清空功能逻辑
};

从 utils 文件导入清空函数

需要在文件顶部导入:

import { resetCanvas, clearRect } from "./utils.js";

注意文件扩展名必须是.js,否则导入会失败。

实现清空功能

clearButton.onclick = () => {
  resetCanvas(); // 重置画布状态

  // 清空预测结果显示
  const predictionParagraph = document.getElementsByClassName("prediction")[0];
  predictionParagraph.textContent = ""; // 清空预测文本

  clearRect(); // 清空画布绘制内容
};

预测结果清空

虽然目前还没有实现预测功能,但提前准备清空预测结果的逻辑:

  • 获取 class 为'prediction'的 p 标签元素
  • 使用getElementsByClassName()方法,取第一个元素[0]
  • textContent设置为空字符串

功能测试

测试步骤:

  1. 在画布上绘制图形(如正方形)
  2. 点击 Clear 按钮
  3. 画布应该完全清空
  4. 预测结果区域也应该被清空

代码组织建议

如果决定重写 utils 文件,可以考虑:

  • 将相关功能合并到一个函数中
  • 使用不同的函数命名方式
  • 将所有功能整合到一个文件中

这个功能为后续的图像下载和预测功能做好了准备,确保用户可以清空画布重新绘制。

22-collecting-shape-training-data

训练数据收集流程

数据收集计划

  • 选择两种不同的形状:三角形和圆形
  • 每种形状绘制 40 个样本
  • 总计 80 个训练样本
  • 预计耗时 10-15 分钟

文件夹结构设置

在项目根目录创建 data 文件夹:

project-root/
├── public/
└── data/          # 新建文件夹

样本命名规范

文件命名格式:序号-形状名.png

0-triangle.png
1-triangle.png
2-triangle.png
...
0-circle.png
1-circle.png
2-circle.png
...

数据收集操作流程

  1. 在画布上绘制形状
  2. 点击 Download 按钮下载图像
  3. 重命名文件(例如:0-triangle.png
  4. 将文件移动到 data 文件夹
  5. 清空画布,重复以上步骤

数据集分割

创建训练集和测试集子文件夹:

data/
├── train/         # 训练集文件夹
├── test/          # 测试集文件夹
├── 0-triangle.png
├── 1-triangle.png
└── ...

数据分割比例

推荐的数据分割策略:

  • 训练集: 80%的数据(每类 30 个样本)
  • 测试集: 20%的数据(每类 10 个样本)

具体操作:

  1. 选择 0-29 号文件复制到 train 文件夹
  2. 选择 30-39 号文件复制到 test 文件夹

数据质量注意事项

  • 绘制过程中保持形状的一致性
  • 避免疲劳导致的形状变形
  • 圆形不要画成正方形

数据增强建议

如果需要更多训练数据,可以使用 Python 脚本进行数据增强:

# 数据变换技术
- 旋转:每隔19度旋转图像
- 翻转:水平和垂直翻转
- 缩放:轻微的尺寸变化

通过数据增强,80 个原始样本可以扩展到几百个样本。

训练集和测试集的作用

  • 训练集: 用于模型学习和权重更新
  • 测试集: 用于评估模型在未见过数据上的性能
  • 测试集必须与训练过程完全独立

数据加载准备

文件组织完成后,数据结构应该是:

data/
├── train/
│   ├── 0-triangle.png
│   ├── 1-triangle.png
│   └── ...
└── test/
    ├── 30-triangle.png
    ├── 31-triangle.png
    └── ...

这种结构便于后续的数据加载和标签提取,文件名中的形状名称将用作自动标签生成。

23-building-training-testing-datasets

Node.js 数据加载实现

必需模块导入

const tf = require("@tensorflow/tfjs-node-gpu");
const fs = require("fs");
const path = require("path");

数据路径配置

const trainImagesDir = "./data/train";
const testImagesDir = "./data/test";

数据容器初始化

let trainData;
let testData;

图像加载核心函数

const loadImages = (dataDirectory) => {
  const images = [];
  const labels = [];

  // 读取目录中的所有文件
  let files = fs.readdirSync(dataDirectory);

  // 遍历每个文件
  for (let i = 0; i < files.length; i++) {
    const filePath = path.join(dataDirectory, files[i]);
    const buffer = fs.readFileSync(filePath);

    // 处理图像数据...
  }

  return [images, labels];
};

图像张量转换

const imageTensor = tf.node
  .decodeImage(buffer)
  .resizeNearestNeighbor([28, 28]) // 调整到28x28尺寸
  .expandDims(0); // 添加批次维度

images.push(imageTensor);

重要参数说明

  • 28x28: 目标图像尺寸,必须与模型期望尺寸一致
  • resizeNearestNeighbor: 最近邻插值调整大小方法
  • expandDims: 添加维度以匹配模型输入要求

标签提取和处理

// 检测文件类型
const circle = files[i].toLowerCase().endsWith("circle.png");
const triangle = files[i].toLowerCase().endsWith("triangle.png");

// 分配数值标签
if (circle) {
  labels.push(0); // 圆形标记为0
} else if (triangle) {
  labels.push(1); // 三角形标记为1
}

数据加载主函数

const loadData = () => {
  console.log("Loading images...");
  trainData = loadImages(trainImagesDir);
  testData = loadImages(testImagesDir);
  console.log("Images loaded successfully");
};

训练数据处理器

const getTrainData = () => {
  return {
    images: tf.concat(trainData[0]), // 连接所有图像张量
    labels: tf.oneHot(
      tf.tensor1d(trainData[1], "float32"), // 创建一维标签张量
      2 // 类别数量
    ),
  };
};

测试数据处理器

const getTestData = () => {
  return {
    images: tf.concat(testData[0]),
    labels: tf.oneHot(tf.tensor1d(testData[1], "float32"), 2),
  };
};

One-Hot 编码说明

// 原始标签: [0, 1, 0, 1]
// One-Hot编码后:
// 0 -> [1, 0]
// 1 -> [0, 1]

One-Hot 编码将类别标签转换为模型可以处理的数值格式。

模块导出

module.exports = {
  loadData,
  getTrainData,
  getTestData,
};

数据流水线总结

  1. 文件读取: 使用 fs.readFileSync 读取图像文件
  2. 图像解码: tf.node.decodeImage 解码图像缓冲区
  3. 尺寸调整: 统一调整到 28x28 像素
  4. 标签提取: 根据文件名自动生成标签
  5. 张量转换: 转换为 TensorFlow.js 张量格式
  6. 数据组织: 分离训练集和测试集

调试验证

node getData.js  # 查看缓冲区数据格式

这个数据加载流水线将手绘图像转换为机器学习模型可以使用的张量格式,为后续的模型训练做好准备。

24-image-model-layers

自定义 CNN 模型构建

模块导入和参数设置

const tf = require("@tensorflow/tfjs");

// 模型参数配置
const kernelSize = [3, 3]; // 卷积核大小,3x3窗口
const filters = 32; // 过滤器数量
const numClasses = 2; // 分类数量(圆形和三角形)

卷积核(Kernel)概念

卷积核是一个小的矩阵(如 3x3),在输入数据上滑动:

  • 作用: 特征提取和模式识别
  • 过程: 将大尺寸数据转换为更小的特征表示
  • 输出: 每个卷积核产生一个特征图

过滤器(Filters)参数

const filters = 32; // 32个不同的特征检测器
  • 每个过滤器学习识别不同的特征模式
  • 更多过滤器 = 更丰富的特征表示
  • 可以尝试不同数值(10, 20, 64 等)观察效果

Sequential 模型创建

const model = tf.sequential();

Sequential 模型特点:

  • 层按顺序连接
  • 数据从第一层流向最后一层
  • 适合大多数标准的深度学习架构

第一层:2D 卷积层

model.add(
  tf.layers.conv2d({
    inputShape: [28, 28, 4], // 输入形状:宽x高x通道
    filters: filters, // 过滤器数量
    kernelSize: kernelSize, // 卷积核大小
    activation: "relu", // ReLU激活函数
  })
);

输入形状说明

  • 28x28: 图像尺寸(与数据预处理一致)
  • 4: RGBA 通道(红、绿、蓝、透明度)

第二层:最大池化层

model.add(
  tf.layers.maxPooling2d({
    poolSize: [2, 2], // 2x2池化窗口
  })
);

最大池化作用:

  • 减少数据维度
  • 保留重要特征
  • 防止过拟合
  • 减少计算量

第三层:展平层

model.add(tf.layers.flatten());

展平层必要性:

  • 将 2D 特征图转换为 1D 向量
  • 为全连接层准备输入
  • 连接 CNN 和全连接网络部分

第四层:全连接层

model.add(
  tf.layers.dense({
    units: 10, // 神经元数量
    activation: "relu", // 非最后层使用ReLU
  })
);

第五层:输出层

model.add(
  tf.layers.dense({
    units: numClasses, // 输出单元数 = 分类数
    activation: "softmax", // 最后层使用Softmax
  })
);

激活函数选择原则

  • 隐藏层: ReLU(修正线性单元)
  • 输出层: Softmax(多分类概率分布)
  • 替代选择: Sigmoid(二分类)

模型编译配置

const optimizer = tf.train.adam(0.0001); // Adam优化器

model.compile({
  optimizer: optimizer,
  loss: "categoricalCrossentropy", // 分类交叉熵损失
  metrics: ["accuracy"], // 监控准确率指标
});

学习率选择

const learningRate = 0.0001; // 学习率参数
  • 较小值:学习稳定但缓慢
  • 较大值:学习快速但可能不稳定
  • 需要根据训练效果调整

模型导出

module.exports = model;

层设计策略讨论

关于层设计的方法:

  • 经验法则: 最后层单元数 = 分类数量
  • 试验方法: 尝试不同层数和参数组合
  • 性能导向: 根据准确率调整架构
  • 资源考虑: 平衡模型复杂度和计算需求

参数调优建议

可以尝试调整的参数:

  1. 过滤器数量: 16, 32, 64
  2. 卷积核大小: [3,3], [5,5]
  3. 全连接层单元: 5, 10, 20
  4. 学习率: 0.001, 0.0001, 0.00001
  5. 激活函数: relu, sigmoid, tanh

模型架构总结

输入(28x28x4)
    ↓
Conv2D(32个3x3卷积核) + ReLU
    ↓
MaxPooling2D(2x2)
    ↓
Flatten
    ↓
Dense(10单元) + ReLU
    ↓
Dense(2单元) + Softmax
    ↓
输出(2类概率)

这个 5 层 CNN 架构适合小规模图像分类任务,在有限数据集上应该能获得合理的分类性能。

25-training-the-model-with-image-data

模型训练实现

必需模块和函数导入

const tf = require("@tensorflow/tfjs-node-gpu");
const { loadData, getTrainData, getTestData } = require("./getData");
const model = require("./createModel");

训练主函数结构

const train = async () => {
  // 加载数据
  loadData();

  // 获取训练数据和标签
  const trainImages = getTrainData().images;
  const trainLabels = getTrainData().labels;

  // 模型训练...
};

模型训练配置

await model.fit(trainImages, trainLabels, {
  epochs: 10, // 训练轮数
  batchSize: 5, // 批次大小
  validationSplit: 0.2, // 验证集比例
});

训练参数说明

  • epochs: 完整遍历数据集的次数,更多轮次可能提高准确率但耗时更长
  • batchSize: 每次训练使用的样本数量,影响训练速度和内存使用
  • validationSplit: 从训练集中分出 20%作为验证集

验证集 vs 测试集区别

  • 验证集: 训练过程中使用,帮助调整参数和监控过拟合
  • 测试集: 训练完成后使用,评估模型在未见数据上的真实性能

模型评估实现

// 获取测试数据
const testImages = getTestData().images;
const testLabels = getTestData().labels;

// 评估模型性能
const evalOutput = model.evaluate(testImages, testLabels);

// 提取损失值和准确率
const loss = evalOutput[0].dataSync()[0].toFixed(3);
const accuracy = evalOutput[1].dataSync()[0].toFixed(3);

console.log(`Loss: ${loss}`);
console.log(`Accuracy: ${accuracy}`);

常见错误修复

数据类型错误修复:

// 在getData.js中修改one-hot编码
tf.tensor1d(trainData[1], "int32"); // 改为int32而非float32

训练结果分析

示例输出解读:

Epoch 1/10 - Accuracy: 0.490 - Loss: 0.693
Epoch 2/10 - Accuracy: 0.490 - Loss: 0.693
...

性能分析

  • 准确率 0.49 表示低于随机猜测水平
  • 损失值应该随训练下降
  • 如果数值不变可能存在学习问题

训练效果改进策略

  1. 增加数据量: 每类样本从 40 增加到更多
  2. 调整网络架构: 修改层数或神经元数量
  3. 优化超参数: 调整学习率、批次大小
  4. 数据增强: 旋转、翻转、缩放图像
  5. 多次训练: 由于随机初始化,结果可能不同

模型不确定性

机器学习模型的特点:

  • 每次训练结果可能不同(权重随机初始化)
  • 有时准确率突然提升到 0.9+
  • 需要多次实验找到最佳配置
  • 不是确定性函数,相同输入可能产生不同输出

这个训练过程展示了深度学习的实验性质,需要通过多次迭代和参数调优来获得满意的结果。

26-saving-the-model-data

模型保存与文件分析

模型保存实现

await model.save("file://./public/model");

保存路径说明

  • file://前缀表示保存到本地文件系统
  • 路径./public/model会自动创建 model 文件夹
  • 保存在 public 文件夹便于浏览器访问

生成的模型文件结构

public/
└── model/
    ├── model.json     # 模型架构和元数据
    └── weights.bin    # 模型权重(二进制文件)

model.json 文件内容分析

{
  "modelTopology": {
    "class_name": "Sequential",
    "config": {
      "layers": [
        {
          "class_name": "Conv2D",
          "config": {
            "filters": 32,
            "kernel_size": [3, 3],
            "activation": "relu",
            "input_shape": [28, 28, 4]
          }
        },
        {
          "class_name": "MaxPooling2D",
          "config": {
            "pool_size": [2, 2]
          }
        }
        // 其他层配置...
      ]
    }
  }
}

模型配置元素分析

  • 手动设置的参数: filters、kernel_size、activation 等
  • 默认参数: variance_scaling 初始化、strides 等
  • 层级结构: 完整的 Sequential 模型架构

权重文件重要性

weights.bin 文件包含:

  • 训练后的网络权重
  • 每层的学习参数
  • 模型的"记忆"和知识

权重 vs 架构

  • 架构定义了模型结构(如道路图)
  • 权重是训练得到的参数(如实际驾驶经验)
  • 仅有架构无法进行预测,需要训练后的权重

训练结果的随机性

观察到的现象:

第一次训练: Accuracy: 0.5
第二次训练: Accuracy: 0.9
第三次训练: Accuracy: 0.95

机器学习的"黑盒"特性

  • 难以解释模型决策过程
  • 随机权重初始化导致结果不确定
  • 有时需要多次训练获得满意结果

实际应用考虑

在生产环境中:

  • 需要能解释模型决策的方法
  • 要求一致和可预测的性能
  • 通常进行多次训练选择最佳模型

准确率提升观察

训练过程中准确率变化:

Epoch 1: 0.653 (起始值较高)
Epoch 2: 0.621 (轻微下降)
...
最终: 0.95+ (显著提升)

模型选择策略

  • 保存高准确率的模型版本
  • 创建多个模型文件夹对比性能
  • 在 UI 测试前选择最佳表现的模型

下一步计划

模型保存完成后:

  1. 在浏览器中加载模型
  2. 实现实时预测功能
  3. 根据实际表现调整参数
  4. 优化模型架构和训练过程

模型保存标志着训练阶段的完成,为部署到 Web 应用奠定了基础。

27-model-summary

模型架构可视化分析

恭喜创建第一个自定义模型

成就总结:

  • 完成了从零开始的机器学习模型构建
  • 使用自制数据集训练模型
  • 即使准确率为 0.5,这仍是重要的学习里程碑

数据量与准确率关系

当前限制:

  • 每类仅有 40 个样本
  • 工业级模型通常需要数千至数百万样本
  • 样本量不足是影响准确率的主要因素

模型总结功能

model.summary();

在 train-drawings.js 中添加此行,运行时会输出:

层级架构可视化表格

Layer (type)          Output Shape      Param #
=================================================
conv2d (Conv2D)      (null, 26, 26, 32)   1056
max_pooling2d        (null, 13, 13, 32)   0
flatten (Flatten)    (null, 5408)         0
dense (Dense)        (null, 10)           54090
dense_1 (Dense)      (null, 2)            22
=================================================

形状变换分析

第一层变换

  • 输入: (28, 28, 4)
  • 输出: (26, 26, 32)
  • 32 对应 filters 参数
  • 26×26 是 3×3 卷积核作用后的尺寸减小

池化层变换

  • 输入: (26, 26, 32)
  • 输出: (13, 13, 32)
  • 2×2 池化窗口将尺寸减半

展平层变换

  • 输入: (13, 13, 32)
  • 输出: (5408,)
  • 计算: 13 × 13 × 32 = 5408

全连接层变换

  • 第一个 Dense: 5408 → 10 (units 参数)
  • 最后 Dense: 10 → 2 (分类数量)

参数数量理解

  • Conv2D 层: 1056 个可训练参数
  • 池化层: 0 个参数(仅计算操作)
  • Dense 层: 54090 个参数(权重矩阵)
  • 输出层: 22 个参数

模型优化指导意义

summary 表格用途:

  1. 验证架构正确性: 确认输入输出形状匹配
  2. 识别瓶颈: 找出参数过多或过少的层
  3. 指导调优: 基于形状变化调整层配置
  4. 内存估算: 了解模型复杂度

数据流可视化

28×28×4 图像
    ↓ Conv2D + ReLU
26×26×32 特征图
    ↓ MaxPooling2D
13×13×32 池化特征
    ↓ Flatten
5408维向量
    ↓ Dense(10) + ReLU
10维特征
    ↓ Dense(2) + Softmax
2类概率分布

调试和优化价值

model.summary()帮助:

  • 理解数据在网络中的流动
  • 识别维度不匹配问题
  • 评估模型复杂度
  • 为架构改进提供依据

这个可视化工具让抽象的神经网络变得具体可见,是理解和优化模型的重要手段。

28-using-the-model-in-the-browser

浏览器端模型集成

模型文件准备

确保模型保存结构:

public/
└── model/
    ├── model.json
    └── weights.bin

前端模型加载设置

let model;
const modelPath = "./model/model.json";

const loadModel = async (path) => {
  model = await tf.loadLayersModel(path);
};

重要提醒: 必须将加载的模型赋值给 model 变量,否则 predict 时会出错。

预测按钮事件处理

const predictButton = document.getElementById("check-button");

predictButton.onclick = () => {
  const canvas = getCanvas(); // 从utils导入
  const drawing = canvas.toDataURL(); // 转换为数据URL

  // 创建图像元素用于预测
  const newImg = document.getElementsByClassName("imageToCheck")[0];
  newImg.src = drawing;

  // 图像加载完成后进行预测
  newImg.onload = () => {
    predict(newImg);
    resetCanvas(); // 清空画布准备下次绘制
  };
};

Canvas 到 Image 转换

数据流程:

  1. Canvas 绘制 → toDataURL() → 数据 URL 字符串
  2. 数据 URL → 设置为 img.src → 图像元素
  3. 图像加载完成 → onload 触发 → 开始预测

图像预处理流水线

const predict = async (img) => {
  // 1. 设置图像尺寸
  img.width = 200;
  img.height = 200;

  // 2. 转换为像素张量
  const processedImg = await tf.browser.fromPixelsAsync(img, 4); // RGBA通道

  // 3. 调整尺寸匹配训练时的28x28
  const resizedImg = tf.image.resizeNearestNeighbor(processedImg, [28, 28]);

  // 4. 类型转换为float32
  const updatedImg = tf.cast(resizedImg, "float32");

  // 5. 重塑形状添加批次维度
  const predictions = await model
    .predict(tf.reshape(updatedImg, [1, 28, 28, 4]))
    .data();

  // 6. 提取预测结果
  const label = predictions.indexOf(Math.max(...predictions));
  displayPrediction(label);
};

关键预处理步骤说明

  • fromPixelsAsync(img, 4): 提取 RGBA 四通道像素数据
  • resizeNearestNeighbor: 调整到 28×28 匹配训练尺寸
  • cast('float32'): 确保数据类型一致性
  • reshape([1, 28, 28, 4]): 添加批次维度,匹配模型输入要求

预测结果处理

// 找到概率最高的类别索引
const label = predictions.indexOf(Math.max(...predictions));

// 显示预测结果
displayPrediction(label); // utils函数,更新UI显示

常见问题修复

模型未定义错误

// 错误:忘记赋值
const loadModel = async (path) => {
  await tf.loadLayersModel(path); // model仍为undefined
};

// 正确:记得赋值
const loadModel = async (path) => {
  model = await tf.loadLayersModel(path);
};

实际测试效果

  • 95%准确率模型表现良好
  • 三角形和圆形识别基本正确
  • 偶尔出现误判(5%错误率范围内)
  • 未知形状会被归类到最相似的已知类别

模型限制认知

  • 只能识别训练过的形状(圆形、三角形)
  • 绘制其他形状(蜡烛等)会强制分类到已知类别
  • 扩展识别范围需要重新收集数据和训练

性能优化建议

如果准确率不满意:

  1. 增加训练样本数量
  2. 调整网络架构(更多层或神经元)
  3. 修改训练参数(epochs、batch size)
  4. 尝试数据增强技术

完整工作流程验证

成功实现了端到端的机器学习管道:

  1. 数据收集(手绘图形)
  2. 数据预处理(调整尺寸、格式转换)
  3. 模型训练(自定义 CNN 架构)
  4. 模型保存(本地文件系统)
  5. 浏览器加载(TensorFlow.js)
  6. 实时预测(Canvas 绘制识别)

这展示了完整的自定义机器学习项目开发流程。

29-wrapping-up

课程总结与扩展资源

项目完成总结

通过这个项目,你已经完整体验了自定义机器学习模型的开发流程:

  • 数据收集和预处理
  • 模型架构设计
  • 训练和优化过程
  • 浏览器端部署应用

核心概念回顾

即使使用外部数据集,仍需要:

  • 图像尺寸调整和标准化
  • 张量形状适配
  • 大量试验和错误调试过程

扩展应用案例分享

Air Street Fighter 项目

  • 使用相同的机器学习流程
  • 收集手势动作的实时数据而非图像
  • 实验不同张量维度(1D vs 2D)
  • 通过参数调优获得满意准确率

音频识别应用潜力

音频转图像识别方法:

  1. 使用 getUserMedia 捕获音频
  2. Web Audio API 处理音频数据
  3. 转换为频谱图(spectrogram)图像
  4. 应用本课程的图像分类技术
  5. 实现语音指令识别("hello"、"thanks"等)

脑机接口实验项目

使用脑电传感器的创新应用:

  • 收集大脑神经活动实时数据
  • 识别眨眼动作的脑电信号特征
  • 可视化大脑活动数据验证信号差异
  • 训练模型区分不同面部表情
  • 探索左眼/右眼眨眼的区别识别

生产环境应用实例

LinkedIn 的 TensorFlow.js 应用

  • 大型社交平台采用 TensorFlow.js
  • 主要用于性能优化相关功能
  • 证明了技术的企业级可行性

Adobe Photoshop Web 增强

  • Adobe 在浏览器版 Photoshop 中集成 TensorFlow.js
  • 提升 Web 端图像处理能力
  • 行业标杆工具的技术选择

预测性预加载技术

  • 利用 Google Analytics 导航数据
  • 训练用户行为预测模型
  • 提前预加载可能访问的页面资源
  • 显著改善网站性能体验

学习资源推荐

WebML 社区 Newsletter

  • Jason Mayes(Google TensorFlow.js 开发倡导团队)编写
  • 月度更新社区项目和技术进展
  • 保持与前沿发展的同步

YouTube 技术系列

虽然可能已停更,但仍有价值的内容:

  • 放射学图像分割技术
  • 强化学习应用
  • 社区开源项目展示
  • 高级应用案例分析

继续学习建议

基于 workshop 获得的知识基础:

  1. 实验不同数据类型: 音频、传感器数据、时间序列
  2. 探索高级架构: 更多层数、不同激活函数
  3. 参考官方文档: 尝试未使用过的层类型
  4. 社区项目学习: 研究开源实现案例
  5. 实际项目应用: 将技术应用到工作或个人项目中

技术发展方向

你现在已掌握 TensorFlow.js 的三大核心应用:

  1. 预训练模型使用: 快速集成现有能力
  2. 迁移学习: 在预训练基础上定制化
  3. 完全自定义: 从零构建专用解决方案

最终鼓励

希望你能:

  • 将所学知识应用到实际项目中
  • 持续实验和创新
  • 在实际应用中遇见你们的创新项目
  • 推动 Web 端机器学习技术的边界

这个 workshop 为你打开了 Web 端 AI 应用的大门,未来的创新完全取决于你的想象力和实践。