Gibbs in the Shader

Probabilistic Inference, Gibbs Sampling, and the Task of Denoising
Probabilistic inference is about making educated guesses with incomplete
information. Gibbs sampling is one such method: it refines a guess incrementally, updating one variable at a time based on its context while holding the other variables fixed.
Now let's consider image denoising: when an image contains noise, missing values, or other imperfections, our goal is to infer the original clean image.
We begin with a fundamental assumption: each pixel's true value depends on both its noisy observation and its neighboring pixels. We can treat each pixel as a random variable whose value is a function of its noisy observation and its surrounding context. By fixing the neighboring pixels, we can estimate the true value of the target pixel.
Applying this process iteratively to all pixels in the image constitutes Gibbs sampling.
After multiple iterations, we gradually converge toward a denoised image. The final result represents our estimate—a probable approximation of the true underlying image.
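The per-pixel update described above can be sketched in plain JavaScript. This is an illustrative sketch, not the exact implementation: the 0.7/0.3 weighting and the 0.05 noise scale are example choices, and the image is assumed to be grayscale, stored as a flat row-major array of values in [0, 1].

```javascript
// Illustrative sketch of one Gibbs update for a single pixel.
// `state` is the current estimate, `observed` the noisy image; both are
// flat row-major Float32Arrays of size cols * rows with values in [0, 1].
function updatePixel(state, observed, cols, rows, x, y, rng = Math.random) {
  // Average the 4-connected neighbors that lie inside the image
  let sum = 0;
  let count = 0;
  for (const [dx, dy] of [[1, 0], [-1, 0], [0, 1], [0, -1]]) {
    const nx = x + dx;
    const ny = y + dy;
    if (nx >= 0 && nx < cols && ny >= 0 && ny < rows) {
      sum += state[ny * cols + nx];
      count++;
    }
  }
  const meanNeighbors = count > 0 ? sum / count : 0;
  // Weighted combination of neighborhood smoothness and the noisy
  // observation, plus a small random perturbation — the "sampling"
  // part of Gibbs sampling (weights here are example values)
  const estimate = 0.7 * meanNeighbors + 0.3 * observed[y * cols + x];
  const noise = (rng() * 2 - 1) * 0.05;
  state[y * cols + x] = Math.min(1, Math.max(0, estimate + noise));
}
```

Holding the neighbors fixed while resampling one pixel is exactly the "fix other variables, update one" step described above.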
From Sequential to Parallel Updates
In its vanilla form, Gibbs sampling processes pixels sequentially, which is inefficient for images: the cost of each iteration scales with the number of pixels (cols × rows), which quickly becomes prohibitive.
// Sequential update
for (let i = 0; i < totalPixels; i++) {
  updatePixel(i); // Each update depends on the previous one's result
}
For parallel processing, we need a strategy where multiple pixels can be updated simultaneously without breaking the Markov property that underpins Gibbs sampling. The key insight is that pixels outside each other's Markov blankets can be updated in parallel, because each update depends only on the pixel's immediate neighbors.
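To see why spacing updates three pixels apart is safe, here is a small hypothetical sanity check, assuming a 4-connected grid MRF in which each pixel's Markov blanket is its four immediate neighbors (an 8-connected blanket would also be handled, since pixels three apart are not 8-neighbors either):

```javascript
// In a 4-connected MRF, a pixel's Markov blanket is its 4 immediate neighbors
function inMarkovBlanket(ax, ay, bx, by) {
  return Math.abs(ax - bx) + Math.abs(ay - by) === 1;
}

// Pixels at the same relative position in different 3×3 blocks differ by a
// multiple of 3 on each axis — so they can never be neighbors of each other
function sameRelativePosition(ax, ay, bx, by) {
  return (ax - bx) % 3 === 0 && (ay - by) % 3 === 0;
}
```

Exhaustively checking a small grid confirms that no two distinct pixels at the same relative position ever appear in each other's Markov blanket, which is what licenses updating them all at once.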
Our implementation divides the image into 3×3 blocks and processes all pixels at the same relative position within these blocks in parallel. For each iteration, we:
- Select one of the nine positions in the 3×3 pattern
- Update all pixels at this position across the image simultaneously
- Repeat for all nine positions to complete one full iteration
// Parallel pattern selection
const gridStep = step % 9; // Cycle through 9 positions (0-8) in the 3×3 grid
const gridX = gridStep % 3; // X position within the 3×3 grid (0, 1, or 2)
const gridY = Math.floor(gridStep / 3); // Y position within the 3×3 grid (0, 1, or 2)

// Process all pixels at position (gridX, gridY) in their respective 3×3 blocks
for (let blockY = 0; blockY < rows; blockY += 3) {
  for (let blockX = 0; blockX < cols; blockX += 3) {
    const x = blockX + gridX;
    const y = blockY + gridY;
    // Skip if outside image bounds
    if (x >= cols || y >= rows) continue;
    // Update this pixel based on its neighbors
    updatePixel(x, y); // All these positions can be updated in parallel
  }
}
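Putting the nine passes together, one full iteration can be driven by a loop like the following sketch, where updateAt stands in for whatever per-pixel update function is used (the function name is illustrative):

```javascript
// One full Gibbs iteration: nine passes, one per position in the 3×3 pattern.
// Within each pass, every call to updateAt is independent of the others,
// which is what makes the pass parallelizable.
function runIteration(cols, rows, updateAt) {
  for (let gridStep = 0; gridStep < 9; gridStep++) {
    const gridX = gridStep % 3;
    const gridY = Math.floor(gridStep / 3);
    for (let blockY = 0; blockY < rows; blockY += 3) {
      for (let blockX = 0; blockX < cols; blockX += 3) {
        const x = blockX + gridX;
        const y = blockY + gridY;
        if (x >= cols || y >= rows) continue; // outside image bounds
        updateAt(x, y);
      }
    }
  }
}
```

A useful property of this schedule is that each iteration visits every pixel exactly once, just like a sequential sweep, only in a different order.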
GPU Implementation
WebGL's parallel processing capabilities make it ideal for implementing this algorithm:
- The vertex shader creates a full-screen quad for processing
- The fragment shader executes in parallel for each pixel:
  - Each shader instance determines whether its pixel matches the current update position
  - If it matches, it samples the neighboring pixels from the current state texture
  - It then computes a weighted average of the neighbors and the original signal
  - Finally, it adds a small random component to preserve the stochastic nature of Gibbs sampling
The fragment shader code handles the core computation:
// Fragment shader (simplified)
precision highp float;

uniform sampler2D u_state;    // current image state
uniform sampler2D u_original; // noisy observation
uniform int u_step;           // which of the 9 grid positions to update
varying vec2 v_uv;            // texture coordinates from the vertex shader

// Cheap hash-based pseudo-random number in [0, 1)
float rand(vec2 co) {
  return fract(sin(dot(co, vec2(12.9898, 78.233))) * 43758.5453);
}

void main() {
  // Determine if this pixel should be updated in this pass
  vec2 gridPos = mod(gl_FragCoord.xy, 3.0);
  int currentPos = int(gridPos.y) * 3 + int(gridPos.x);
  if (currentPos != u_step) {
    gl_FragColor = texture2D(u_state, v_uv); // leave pixel unchanged
    return;
  }
  // Sample neighbors and compute the new value
  float meanNeighbors = calculateNeighborAverage();
  float original = texture2D(u_original, v_uv).r;
  float estimate = 0.7 * meanNeighbors + 0.3 * original;
  float noise = (rand(gl_FragCoord.xy) * 2.0 - 1.0) * 0.05;
  gl_FragColor = vec4(vec3(clamp(estimate + noise, 0.0, 1.0)), 1.0);
}
This implementation demonstrates how probabilistic algorithms can be efficiently parallelized on GPUs, making real-time Bayesian inference possible for image processing applications.
Conclusion
We've shown how to convert sequential Gibbs sampling to a parallel implementation. By identifying which pixels can be updated simultaneously without affecting statistical correctness, we transformed a sequential algorithm into one that works efficiently on parallel hardware.
This observation extends beyond Gibbs sampling: identifying conditionally independent updates and grouping them into parallel passes is a recurring trick for parallelizing sequential algorithms.