Gibbs in the Shader

I've got to confess that the name is a knock-off reference to Ghost in the Shell (攻殻機動隊 [Kōkaku Kidōtai], 1995). During a particularly strange moment, the term crossed my mind and I thought, hey, why not. And so here we are.

Ghost in the Shell (1995) movie poster

Probabilistic Inference, Gibbs Sampling and the Task of Denoising

Probabilistic inference is about making educated guesses with incomplete information. Gibbs sampling is one such method: it refines a guess incrementally, resampling one variable at a time from its conditional distribution given the current values of all the others, which stay fixed during that step.
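To make the idea concrete before moving to images, here is a minimal sketch in JavaScript of Gibbs sampling from a standard bivariate normal with correlation rho. The toy target is my choice for illustration, not something from the rest of this post; it is handy because both full conditionals are known exactly. The helper names (`mulberry32`, a small seeded PRNG, plus `gaussian` and `gibbsBivariateNormal`) are illustrative.

```javascript
// Gibbs sampling from a standard bivariate normal with correlation rho.
// Both full conditionals are normal:
//   x | y ~ N(rho * y, 1 - rho^2)   and   y | x ~ N(rho * x, 1 - rho^2)

// Small seeded PRNG (mulberry32) so runs are reproducible; returns values in [0, 1).
function mulberry32(seed) {
  let a = seed >>> 0;
  return function () {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Standard normal draw via the Box-Muller transform.
function gaussian(rng) {
  const u1 = 1 - rng(); // in (0, 1], so log(u1) is finite
  const u2 = rng();
  return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
}

function gibbsBivariateNormal(rho, nSamples, burnIn = 500, seed = 1) {
  const rng = mulberry32(seed);
  const sd = Math.sqrt(1 - rho * rho);
  let x = 0, y = 0;
  const samples = [];
  for (let i = 0; i < burnIn + nSamples; i++) {
    x = rho * y + sd * gaussian(rng); // resample x with y held fixed
    y = rho * x + sd * gaussian(rng); // resample y with the new x held fixed
    if (i >= burnIn) samples.push([x, y]);
  }
  return samples;
}

// Usage: the empirical correlation of the samples should land close to rho.
const samples = gibbsBivariateNormal(0.8, 20000);
const n = samples.length;
const mx = samples.reduce((s, [x]) => s + x, 0) / n;
const my = samples.reduce((s, [, y]) => s + y, 0) / n;
let sxy = 0, sxx = 0, syy = 0;
for (const [x, y] of samples) {
  sxy += (x - mx) * (y - my);
  sxx += (x - mx) ** 2;
  syy += (y - my) ** 2;
}
console.log("empirical correlation:", (sxy / Math.sqrt(sxx * syy)).toFixed(3));
```

Each step changes only one variable while the other is held fixed, yet the chain as a whole explores the joint distribution. The closer rho gets to 1, the smaller each move becomes and the slower the chain mixes.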

Now let's consider image denoising: when an image contains noise, missing values, or other imperfections, our goal is to infer the original clean image.
We begin with a fundamental assumption: each pixel's true value depends on both its noisy observation and its neighboring pixels. We can therefore treat each pixel as a random variable whose distribution is shaped by its noisy observation and its surrounding context. By holding the neighboring pixels fixed, we can draw a plausible value for the target pixel.
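To spell out what "depends on both its noisy observation and its neighbors" can mean, here is the per-pixel conditional under an assumed Gaussian model. The model is my assumption for illustration; the post does not fix one: additive noise with variance σ_d², and a smoothness prior pulling x_i toward the mean of its neighbors with variance σ_s². Bayes' rule plus the Markov property give a normal conditional whose mean is a weighted blend of the two sources of information.

```latex
% Hypothetical model: y_i = x_i + \varepsilon_i,\ \varepsilon_i \sim \mathcal{N}(0, \sigma_d^2),
% and x_i \mid x_{N(i)} \sim \mathcal{N}(\bar{x}_{N(i)}, \sigma_s^2), where N(i) is the set of the
% 8 neighbors of pixel i, \bar{x}_{N(i)} is their mean, and x_{-i} is every pixel except i.
p(x_i \mid x_{-i}, y)
  \;\propto\; p(y_i \mid x_i)\, p(x_i \mid x_{N(i)})
  \;\propto\; \exp\!\left(-\frac{(x_i - y_i)^2}{2\sigma_d^2} - \frac{(x_i - \bar{x}_{N(i)})^2}{2\sigma_s^2}\right)
\;\Longrightarrow\;
x_i \mid x_{-i}, y \;\sim\; \mathcal{N}\!\left(\mu_i, \tau^2\right),
\quad
\mu_i = w\,\bar{x}_{N(i)} + (1 - w)\,y_i,
\quad
w = \frac{\sigma_d^2}{\sigma_d^2 + \sigma_s^2},
\quad
\tau^2 = \frac{\sigma_d^2\,\sigma_s^2}{\sigma_d^2 + \sigma_s^2}.
```

Under this assumed model, the 0.7/0.3 blend used in the simplified shader in the GPU section corresponds to w = 0.7, i.e. σ_d² ≈ 2.3 σ_s². A strict Gibbs step would draw x_i from N(μ_i, τ²); the shader approximates this with a small uniform jitter.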

Sweeping this update over every pixel, over and over, is what Gibbs sampling amounts to here.
After enough iterations, the image settles toward a denoised version. The final result is our estimate: a plausible approximation of the true underlying image.
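Here is a minimal CPU sketch of that loop in JavaScript, assuming a grayscale image stored row-major in a Float32Array with values in [0, 1]. The update rule is borrowed from the simplified shader in the GPU section (0.7 × neighbor mean + 0.3 × observation, plus uniform jitter of ±0.05), so treat it as a heuristic stand-in for an exact conditional. `neighborMean`, `denoiseSequential`, and the seeded `mulberry32` PRNG are illustrative names.

```javascript
// Sequential (one pixel at a time) Gibbs-style denoising of a row-major grayscale image.
// Assumption: the update mirrors the simplified shader in this post; it is a heuristic,
// not the exact conditional of a fully specified model.

// Small seeded PRNG (mulberry32); returns values in [0, 1).
function mulberry32(seed) {
  let a = seed >>> 0;
  return function () {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Mean of the in-bounds pixels among the eight neighbors of (x, y).
function neighborMean(state, cols, rows, x, y) {
  let sum = 0, count = 0;
  for (let dy = -1; dy <= 1; dy++) {
    for (let dx = -1; dx <= 1; dx++) {
      if (dx === 0 && dy === 0) continue;
      const nx = x + dx, ny = y + dy;
      if (nx < 0 || ny < 0 || nx >= cols || ny >= rows) continue;
      sum += state[ny * cols + nx];
      count++;
    }
  }
  return sum / count;
}

function denoiseSequential(observed, cols, rows, sweeps = 50, seed = 1) {
  const rng = mulberry32(seed);
  const state = Float32Array.from(observed); // start from the noisy image
  for (let s = 0; s < sweeps; s++) {
    for (let y = 0; y < rows; y++) {
      for (let x = 0; x < cols; x++) {
        // Each update sees neighbors already updated earlier in this same sweep.
        const i = y * cols + x;
        const estimate = 0.7 * neighborMean(state, cols, rows, x, y) + 0.3 * observed[i];
        const noise = (rng() * 2 - 1) * 0.05;
        state[i] = Math.min(1, Math.max(0, estimate + noise));
      }
    }
  }
  return state;
}

// Usage: a flat gray 32×32 image corrupted by uniform noise of ±0.3.
const cols = 32, rows = 32;
const noiseRng = mulberry32(42);
const clean = new Float32Array(cols * rows).fill(0.5);
const noisy = clean.map((v) => v + (noiseRng() * 2 - 1) * 0.3);
const denoised = denoiseSequential(noisy, cols, rows);
const mae = (a, b) => a.reduce((s, v, i) => s + Math.abs(v - b[i]), 0) / a.length;
console.log("MAE noisy:", mae(noisy, clean).toFixed(3), "denoised:", mae(denoised, clean).toFixed(3));
```

On this flat test image the mean absolute error after denoising should come out noticeably lower than that of the noisy input.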



From Sequential to Parallel Updates

In its vanilla form, Gibbs sampling updates one pixel at a time. A single sweep therefore takes cols × rows sequential updates, and since convergence needs many sweeps, the cost quickly becomes prohibitive as images grow.

// Sequential update
for (let i = 0; i < totalPixels; i++) {
  updatePixel(i); // Reads neighbors that earlier iterations of this loop may have just changed
}

For parallel processing, we need a strategy where multiple pixels can be updated simultaneously without changing what the sampler computes. The key insight is that a pixel's conditional distribution depends only on its Markov blanket, which here is its eight immediate neighbors. Two pixels that are not in each other's Markov blanket can therefore be updated in parallel: neither update reads the other's value, so doing them at the same time gives the same result as doing them one after the other.
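The Markov-blanket argument can be checked by brute force. The sketch below, assuming the 8-pixel neighborhood used throughout this post, labels every pixel with a phase from a k×k pattern (the 3×3 scheme described next corresponds to k = 3) and verifies that no two neighbors share a phase, which is exactly the condition for updating a whole phase in one parallel pass. `phase` and `isConflictFree` are hypothetical helper names.

```javascript
// Phase of pixel (x, y) in a k×k repeating pattern: (y mod k) * k + (x mod k).
function phase(x, y, k) {
  return (y % k) * k + (x % k);
}

// True if no pixel shares its phase with any of its (up to) eight neighbors on a cols×rows grid.
function isConflictFree(k, cols, rows) {
  for (let y = 0; y < rows; y++) {
    for (let x = 0; x < cols; x++) {
      for (let dy = -1; dy <= 1; dy++) {
        for (let dx = -1; dx <= 1; dx++) {
          if (dx === 0 && dy === 0) continue;
          const nx = x + dx, ny = y + dy;
          if (nx < 0 || ny < 0 || nx >= cols || ny >= rows) continue;
          // Two neighbors with the same phase would be updated in the same parallel pass.
          if (phase(x, y, k) === phase(nx, ny, k)) return false;
        }
      }
    }
  }
  return true;
}

console.log(isConflictFree(3, 16, 16)); // true: the 3×3 pattern
console.log(isConflictFree(2, 16, 16)); // true: a 2×2 pattern also avoids conflicts
console.log(isConflictFree(1, 16, 16)); // false: updating every pixel at once
```

A 2×2 pattern (four passes per iteration) is already conflict-free for an 8-neighborhood, and a 4-neighborhood would only need a two-color checkerboard; the 3×3 pattern works too, at the cost of more passes.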

Our implementation divides the image into 3×3 blocks and processes all pixels at the same relative position within these blocks in parallel. For each iteration, we:

  1. Select one of the nine positions in the 3×3 pattern
  2. Update all pixels at this position across the image simultaneously
  3. Repeat for all nine positions to complete one full iteration
// Parallel pattern selection
const gridStep = step % 9;  // Cycle through 9 positions (0-8) in 3×3 grid
const gridX = gridStep % 3; // X position within 3×3 grid (0, 1, or 2)
const gridY = Math.floor(gridStep / 3); // Y position within 3×3 grid (0, 1, or 2)

// Process all pixels at position (gridX, gridY) in their respective 3×3 blocks
for (let blockY = 0; blockY < rows; blockY += 3) {
    for (let blockX = 0; blockX < cols; blockX += 3) {
        const x = blockX + gridX;
        const y = blockY + gridY;
        
        // Skip if outside image bounds
        if (x >= cols || y >= rows) continue;
        
        // Update this pixel based on its neighbors
        updatePixel(x, y); // All these positions can be updated in parallel
    }
}
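To see that the parallel schedule really is equivalent to updating pixels one at a time, the sketch below emulates the passes above on the CPU in two ways: in "parallel" mode every pixel of a pass reads from a snapshot taken before the pass, as a GPU pass would; in "in-place" mode the same pixels are updated one after another. It assumes the same update rule as the simplified shader in the next section and, so both modes see identical random numbers, uses a hypothetical `hashNoise(x, y, pass)` instead of a stateful RNG. `runPass` follows the loop above; `updated` and `neighborMean` are illustrative helpers.

```javascript
// Hypothetical deterministic jitter in [-0.05, 0.05), keyed on (x, y, pass).
function hashNoise(x, y, pass) {
  let h = Math.imul(x, 73856093) ^ Math.imul(y, 19349663) ^ Math.imul(pass, 83492791);
  h = Math.imul(h ^ (h >>> 13), 0x5bd1e995);
  h ^= h >>> 15;
  return ((h >>> 0) / 4294967296) * 0.1 - 0.05;
}

// Mean of the in-bounds pixels among the eight neighbors of (x, y).
function neighborMean(state, cols, rows, x, y) {
  let sum = 0, count = 0;
  for (let dy = -1; dy <= 1; dy++) {
    for (let dx = -1; dx <= 1; dx++) {
      if (dx === 0 && dy === 0) continue;
      const nx = x + dx, ny = y + dy;
      if (nx < 0 || ny < 0 || nx >= cols || ny >= rows) continue;
      sum += state[ny * cols + nx];
      count++;
    }
  }
  return sum / count;
}

// New value for (x, y), assumed to mirror the simplified shader's update rule.
function updated(src, observed, cols, rows, x, y, pass) {
  const v = 0.7 * neighborMean(src, cols, rows, x, y) + 0.3 * observed[y * cols + x] + hashNoise(x, y, pass);
  return Math.min(1, Math.max(0, v));
}

// One pass for grid position step % 9. inPlace = false reads from a pre-pass snapshot
// (GPU-style); inPlace = true updates pixels one after another in the live array.
function runPass(state, observed, cols, rows, step, inPlace) {
  const gridStep = step % 9;
  const gridX = gridStep % 3;
  const gridY = Math.floor(gridStep / 3);
  const src = inPlace ? state : Float32Array.from(state);
  for (let blockY = 0; blockY < rows; blockY += 3) {
    for (let blockX = 0; blockX < cols; blockX += 3) {
      const x = blockX + gridX, y = blockY + gridY;
      if (x >= cols || y >= rows) continue;
      state[y * cols + x] = updated(src, observed, cols, rows, x, y, step);
    }
  }
}

// Usage: 5 full iterations (45 passes) with both strategies on a 20×17 test pattern.
const cols = 20, rows = 17;
const observed = Float32Array.from({ length: cols * rows }, (_, i) => ((i * 37) % 101) / 100);
const parallelResult = Float32Array.from(observed);
const sequentialResult = Float32Array.from(observed);
for (let step = 0; step < 45; step++) {
  runPass(parallelResult, observed, cols, rows, step, false);
  runPass(sequentialResult, observed, cols, rows, step, true);
}
console.log(parallelResult.every((v, i) => v === sequentialResult[i])); // true
```

The two results match bit for bit because, within a pass, each update reads only pixels of other phases, and none of those change during that pass.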
N₁  N₂  N₃
N₄  P   N₅
N₆  N₇  N₈

P: current pixel being sampled. N₁-N₈: neighboring pixels influencing P.

The Gibbs sampling neighborhood pattern. Each pixel (P) is updated based on the values of its eight neighboring pixels (N₁-N₈), maintaining spatial correlation in the image.

Parallel update pattern: In each step, all pixels with the same position in their respective 3×3 blocks are updated simultaneously. Different shades represent the nine update positions processed in sequence.

GPU Implementation

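WebGL is a natural fit for this algorithm. A fragment shader runs once per output pixel, in parallel, so each of the nine passes renders the whole image: pixels at the current grid position get a new value, and all others are copied through unchanged. Because WebGL does not allow sampling a texture while rendering into it, the state lives in two textures that swap roles (ping-pong) after every pass.

The fragment shader handles the core computation: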

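// Fragment shader (simplified, WebGL 1 / GLSL ES 1.00)
precision highp float;
uniform sampler2D u_state;     // current estimate (the texture not being rendered into)
uniform sampler2D u_original;  // noisy observation
uniform vec2 u_resolution;     // image size in pixels
uniform int u_step;            // grid position updated in this pass (0-8)
uniform float u_seed;          // per-pass seed so the jitter changes between passes

float rand(vec2 co) { // cheap hash-based pseudo-random value in [0, 1)
    return fract(sin(dot(co, vec2(12.9898, 78.233))) * 43758.5453);
}

float calculateNeighborAverage(vec2 uv) { // mean of N1-N8: 3x3 window sum minus the center
    vec2 texel = 1.0 / u_resolution;
    float sum = -texture2D(u_state, uv).r;
    for (int dy = -1; dy <= 1; dy++)
        for (int dx = -1; dx <= 1; dx++)
            sum += texture2D(u_state, uv + vec2(float(dx), float(dy)) * texel).r;
    return sum / 8.0;
}

void main() {
    vec2 uv = gl_FragCoord.xy / u_resolution;
    // Determine if this pixel should be updated
    vec2 gridPos = mod(gl_FragCoord.xy, 3.0);
    int currentPos = int(gridPos.y) * 3 + int(gridPos.x);
    if (currentPos != u_step) {
        gl_FragColor = texture2D(u_state, uv); // copy through unchanged
        return;
    }
    // Sample neighbors and compute new value
    float meanNeighbors = calculateNeighborAverage(uv);
    float original = texture2D(u_original, uv).r;
    float estimate = 0.7 * meanNeighbors + 0.3 * original;
    float noise = (rand(gl_FragCoord.xy + u_seed) * 2.0 - 1.0) * 0.05;
    gl_FragColor = vec4(vec3(clamp(estimate + noise, 0.0, 1.0)), 1.0);
}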
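One subtle detail in the shader is that `gl_FragCoord.xy` holds pixel centers (x + 0.5, y + 0.5), so `mod(gl_FragCoord.xy, 3.0)` returns 0.5, 1.5 or 2.5, which `int()` truncates to 0, 1 or 2. The sketch below replays that arithmetic in JavaScript for a 640×480 image (an arbitrary size) and checks that, across `u_step` = 0..8, every pixel is selected exactly once per iteration. `glslMod` reimplements GLSL's `mod(x, y) = x - y * floor(x / y)`; the other names are illustrative.

```javascript
// GLSL mod(): x - y * floor(x / y).
function glslMod(a, b) {
  return a - b * Math.floor(a / b);
}

// Replays the shader's selection test for pixel (x, y) in the pass with the given u_step.
function selectedInPass(x, y, uStep) {
  const fragX = x + 0.5, fragY = y + 0.5; // gl_FragCoord.xy sits at pixel centers
  const gridPosX = glslMod(fragX, 3.0), gridPosY = glslMod(fragY, 3.0);
  const currentPos = Math.trunc(gridPosY) * 3 + Math.trunc(gridPosX); // int() truncates
  return currentPos === uStep;
}

// How many of the nine passes select each pixel.
function updateCounts(cols, rows) {
  const counts = new Uint8Array(cols * rows);
  for (let uStep = 0; uStep < 9; uStep++) {
    for (let y = 0; y < rows; y++) {
      for (let x = 0; x < cols; x++) {
        if (selectedInPass(x, y, uStep)) counts[y * cols + x]++;
      }
    }
  }
  return counts;
}

const counts = updateCounts(640, 480);
console.log(counts.every((c) => c === 1)); // true: each pixel is updated once per iteration
```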

This shows how a probabilistic algorithm can be parallelized on the GPU: one full iteration is nine full-screen passes no matter how large the image is, and each pass processes all of its pixels at once. That puts interactive, real-time denoising in the browser within reach.

Conclusion

We've shown how to convert sequential Gibbs sampling to a parallel implementation. By identifying which pixels can be updated simultaneously without affecting statistical correctness, we transformed a sequential algorithm into one that works efficiently on parallel hardware.

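This is probably a more general lesson than Gibbs sampling alone: whenever each update depends only on a local neighborhood, splitting the grid into groups in which no two cells are neighbors lets each group be processed in parallel. The same trick shows up elsewhere, for example in red-black Gauss-Seidel solvers and in checkerboard updates for Ising-model simulations.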