Files
brandly/src/utils/ChromaKeyShader.ts
T

250 lines
8.2 KiB
TypeScript

/**
* ChromaKeyShader — WebGL2-based chroma key processor
*
* Uses a GPU fragment shader for real-time background removal.
* Processes in YCbCr color space for more perceptually accurate
* color matching, and includes spill suppression + edge refinement.
*
* Falls back gracefully: caller should check `isSupported()` before use.
*/
const VERTEX_SHADER = `#version 300 es
in vec2 a_position;
in vec2 a_texCoord;
out vec2 v_texCoord;
void main() {
gl_Position = vec4(a_position, 0.0, 1.0);
v_texCoord = a_texCoord;
}
`;
const FRAGMENT_SHADER = `#version 300 es
precision mediump float;
in vec2 v_texCoord;
out vec4 outColor;
uniform sampler2D u_texture;
uniform vec3 u_keyColor; // Key color in RGB (0-1 range)
uniform float u_tolerance; // Inner tolerance radius (color distance)
uniform float u_softness; // Softness zone width
uniform float u_spillSuppress; // Spill suppression strength (0-1)
// Convert RGB to YCbCr for more perceptual color matching
vec3 rgbToYCbCr(vec3 rgb) {
float y = 0.299 * rgb.r + 0.587 * rgb.g + 0.114 * rgb.b;
float cb = -0.169 * rgb.r - 0.331 * rgb.g + 0.500 * rgb.b + 0.5;
float cr = 0.500 * rgb.r - 0.419 * rgb.g - 0.081 * rgb.b + 0.5;
return vec3(y, cb, cr);
}
void main() {
vec4 texColor = texture(u_texture, v_texCoord);
// Convert to YCbCr space — compare only chroma (Cb, Cr), ignore luminance (Y)
vec3 texYCbCr = rgbToYCbCr(texColor.rgb);
vec3 keyYCbCr = rgbToYCbCr(u_keyColor);
// Distance in chroma space only (ignore luminance for better results)
float chromaDist = distance(texYCbCr.yz, keyYCbCr.yz);
// Also factor in RGB distance for edge cases (pure white/black)
float rgbDist = distance(texColor.rgb, u_keyColor);
// Use the minimum of both distances for better detection of white/black
float dist = min(chromaDist * 2.0, rgbDist);
// Smoothstep for clean alpha transition
float alpha = smoothstep(u_tolerance, u_tolerance + u_softness, dist);
// Spill suppression — reduce key color influence on semi-transparent pixels
vec3 finalColor = texColor.rgb;
if (alpha > 0.0 && alpha < 1.0 && u_spillSuppress > 0.0) {
float spillMask = 1.0 - alpha;
finalColor = mix(texColor.rgb,
texColor.rgb + (texColor.rgb - u_keyColor) * spillMask * 0.5,
u_spillSuppress);
finalColor = clamp(finalColor, 0.0, 1.0);
}
outColor = vec4(finalColor, texColor.a * alpha);
}
`;
export interface ChromaKeyShaderParams {
keyColor: [number, number, number]; // RGB 0-255
tolerance: number; // 0-1 range (normalized distance)
softness: number; // 0-1 range
spillSuppress?: number; // 0-1 range, default 0.5
}
export class ChromaKeyShader {
private gl: WebGL2RenderingContext;
private program: WebGLProgram;
private texture: WebGLTexture;
private vao: WebGLVertexArrayObject;
// Uniform locations
private uTexture: WebGLUniformLocation;
private uKeyColor: WebGLUniformLocation;
private uTolerance: WebGLUniformLocation;
private uSoftness: WebGLUniformLocation;
private uSpillSuppress: WebGLUniformLocation;
constructor(canvas: HTMLCanvasElement) {
const gl = canvas.getContext('webgl2', {
alpha: true,
premultipliedAlpha: false,
preserveDrawingBuffer: true,
});
if (!gl) throw new Error('WebGL2 not supported');
this.gl = gl;
// Compile shaders
const vs = this.compileShader(gl.VERTEX_SHADER, VERTEX_SHADER);
const fs = this.compileShader(gl.FRAGMENT_SHADER, FRAGMENT_SHADER);
// Link program
this.program = gl.createProgram()!;
gl.attachShader(this.program, vs);
gl.attachShader(this.program, fs);
gl.linkProgram(this.program);
if (!gl.getProgramParameter(this.program, gl.LINK_STATUS)) {
throw new Error('Shader link failed: ' + gl.getProgramInfoLog(this.program));
}
// Get uniform locations
this.uTexture = gl.getUniformLocation(this.program, 'u_texture')!;
this.uKeyColor = gl.getUniformLocation(this.program, 'u_keyColor')!;
this.uTolerance = gl.getUniformLocation(this.program, 'u_tolerance')!;
this.uSoftness = gl.getUniformLocation(this.program, 'u_softness')!;
this.uSpillSuppress = gl.getUniformLocation(this.program, 'u_spillSuppress')!;
// Create texture
this.texture = gl.createTexture()!;
gl.bindTexture(gl.TEXTURE_2D, this.texture);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR);
// Create fullscreen quad VAO
this.vao = gl.createVertexArray()!;
gl.bindVertexArray(this.vao);
// Position buffer (clip space quad)
const posBuf = gl.createBuffer()!;
gl.bindBuffer(gl.ARRAY_BUFFER, posBuf);
gl.bufferData(gl.ARRAY_BUFFER, new Float32Array([
-1, -1, 1, -1, -1, 1,
1, -1, 1, 1, -1, 1,
]), gl.STATIC_DRAW);
const aPos = gl.getAttribLocation(this.program, 'a_position');
gl.enableVertexAttribArray(aPos);
gl.vertexAttribPointer(aPos, 2, gl.FLOAT, false, 0, 0);
// Texcoord buffer (flip Y for video frames)
const tcBuf = gl.createBuffer()!;
gl.bindBuffer(gl.ARRAY_BUFFER, tcBuf);
gl.bufferData(gl.ARRAY_BUFFER, new Float32Array([
0, 1, 1, 1, 0, 0,
1, 1, 1, 0, 0, 0,
]), gl.STATIC_DRAW);
const aTex = gl.getAttribLocation(this.program, 'a_texCoord');
gl.enableVertexAttribArray(aTex);
gl.vertexAttribPointer(aTex, 2, gl.FLOAT, false, 0, 0);
gl.bindVertexArray(null);
// Enable alpha blending
gl.enable(gl.BLEND);
gl.blendFunc(gl.SRC_ALPHA, gl.ONE_MINUS_SRC_ALPHA);
}
/**
* Render a source (video or image) with chroma key applied.
* Updates the canvas size to match the source dimensions.
*/
render(
source: HTMLVideoElement | HTMLImageElement | HTMLCanvasElement,
params: ChromaKeyShaderParams
): void {
const gl = this.gl;
const canvas = gl.canvas as HTMLCanvasElement;
// Get source dimensions
const srcW = source instanceof HTMLVideoElement ? source.videoWidth : source.width;
const srcH = source instanceof HTMLVideoElement ? source.videoHeight : source.height;
if (srcW === 0 || srcH === 0) return;
// Resize canvas if needed
if (canvas.width !== srcW || canvas.height !== srcH) {
canvas.width = srcW;
canvas.height = srcH;
gl.viewport(0, 0, srcW, srcH);
}
// Upload source to texture
gl.bindTexture(gl.TEXTURE_2D, this.texture);
gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, source);
// Set uniforms
gl.useProgram(this.program);
gl.uniform1i(this.uTexture, 0);
gl.uniform3f(
this.uKeyColor,
params.keyColor[0] / 255,
params.keyColor[1] / 255,
params.keyColor[2] / 255
);
gl.uniform1f(this.uTolerance, params.tolerance);
gl.uniform1f(this.uSoftness, params.softness);
gl.uniform1f(this.uSpillSuppress, params.spillSuppress ?? 0.5);
// Clear and draw
gl.clearColor(0, 0, 0, 0);
gl.clear(gl.COLOR_BUFFER_BIT);
gl.bindVertexArray(this.vao);
gl.drawArrays(gl.TRIANGLES, 0, 6);
gl.bindVertexArray(null);
}
/**
* Check if WebGL2 is available in the current browser.
*/
static isSupported(): boolean {
try {
const canvas = document.createElement('canvas');
const gl = canvas.getContext('webgl2');
canvas.remove();
return !!gl;
} catch {
return false;
}
}
/**
* Cleanup all WebGL resources.
*/
dispose(): void {
const gl = this.gl;
gl.deleteTexture(this.texture);
gl.deleteProgram(this.program);
gl.deleteVertexArray(this.vao);
}
private compileShader(type: number, source: string): WebGLShader {
const gl = this.gl;
const shader = gl.createShader(type)!;
gl.shaderSource(shader, source);
gl.compileShader(shader);
if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
const info = gl.getShaderInfoLog(shader);
gl.deleteShader(shader);
throw new Error('Shader compile failed: ' + info);
}
return shader;
}
}