
scratche-dit pt.1 (DiT implementation in ~130 lines)

2024-08-15 · 12 min read

a writeup of a DiT I wrote from scratch!

Goal: Recreate the DiT paper from “scratch” (only the PyTorch API for now), using only the research paper + other papers online. No online code/GitHub implementations or LLMs (including Copilot) allowed.

…why would you want to do this?

My approach for research papers has always been 1) get an understanding of the paper at a high level, 2) read the code implementing the paper to get a deep understanding (i.e. if I don’t understand a concept, look at the code + keep going lower into the hierarchy until I understand it)

A problem that’s come up recently is that it’s actually harder to find open source implementations of certain papers as things get more closed source + there’s just something fun about doing it yourself! I haven’t had any experience building diffusion models from scratch (unlike language models, graph neural nets, or RL), so I’m taking this as a challenge to break that abstraction layer & get more mathy!

I’m also planning on implementing this using WGSL on the web, so this whole process should help me understand the fundamentals a lot better.

the result!

[image]

Check it out above! It’s a ~130 line implementation of DiTs where the code is (hopefully) really readable, and a good learning resource for people! It took a couple of after-work sessions to get everything working, but I think I came out understanding a lot more about diffusion than I did before.

I think I could probably get the line count down a lot, but readability is more important to me.

TODOS:

  • have sin time embeddings instead of encoding the timestep
  • have realtime noising instead of iteratively noising & saving everything
  • implement this all in WGSL and push it onto my website (30% of the way there, though it might take half a week to push it through)
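For the first TODO, the standard transformer-style sinusoidal embedding is probably what I’d reach for. A minimal sketch (the function name and signature here are illustrative, not my actual code):

```python
import math
import torch as T

def timestep_embedding(t, dim, max_period=10000):
    # half the channels are cosines, half are sines, over geometrically
    # spaced frequencies from 1 down to 1/max_period
    half = dim // 2
    freqs = T.exp(-math.log(max_period) * T.arange(half).float() / half)
    args = t[:, None].float() * freqs[None, :]
    return T.cat([T.cos(args), T.sin(args)], dim=-1)
```

This gives every timestep a unique, smoothly-varying vector instead of a raw integer, which is much friendlier for the model to condition on.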

quick aside on diffusion:

Playing around with noise and generation actually gave me what I think is an interesting way to think of diffusion. Basically thinking of diffusion models as “guides” in latent space, instead of next noise predictors.

This way you can actually manipulate what the model outputs at sampling time, just using it as a vibe measure of what the real image should look like. (ie. you can use negative controls, or even think of it as an equation that slowly traverses the gradients to give you the correct answer!)

Once you think of it that way, you can actually cross apply a lot of the ideas from backpropagation & navigating gradients here! (learning rates for how much to update images, etc.) The sample above is with a mini technique where I tried “batch estimating” gradients. I’m experimenting a lot here!

[image]

🔥
pt. 2 coming soon…

30% of the way done with porting all these layers to WebGPU WGSL shaders/kernels, planning on having hippo favicon diffusion on my website soon!!

my reasoning for pushing pt. 1/pt. 2 separately is that I wanted to give the research-y side its own post + me learning WGSL/shaders outside of WebGL would crowd out the rest

💡
Hey there! I’m trying to get a lot more serious about documenting my work/research and being a lot more open with my experimentation process (thanks to all my work at Notion AI, I’ve realized documentation is so awesome and everyone I admire does it)

here’s a running log of what I ran into/thoughts!

Okay awesome, after reading through the paper it seems like there are mainly two components that I should implement to get it working:

  1. a VAE (Variational Autoencoder)
  2. the actual DiT (Diffusion transformer)

the VAE

I’ve actually had a lot of experience with VAEs, so I tried to keep this one fast and as barebones as possible. A couple interesting tidbits I ran into here:

  • even if it’s with images, since VAEs output “probability maps”, using binary cross entropy as the loss function gives better results (*cue the Dora music* this is a tool that will help us later)
    import torch as T

    eps = 1e-7
    def cross_entropy(x, y):
        # found out there is a decently high probability of y being exactly
        # 1 or 0, so add an epsilon to avoid getting NaNs
        # this also assumes x and y have a range of [0, 1]
        return -(x * T.log(y + eps) + (1 - x) * T.log(1 - y + eps)).mean()
KL divergence isn’t super intuitive (slightly long rant below)

I already knew the intuition of why KL divergence is needed (to pull everything together in the latent space), but implementing it was slightly weirder. The way I ended up getting the KL divergence intuition down for the equation is:

KL(A ‖ B) = Σ log( P(A) / P(B) ) · P(A)

Where B is your prior (before) and A is your posterior (after). The reason you divide A/B is to get a ratio of before vs after on your probabilities. The log here is really useful mostly for one reason: it scales ratios symmetrically. For example ln(1/4) is exactly the negative of ln(4), therefore not skewing our “divergence” too far upward.

Lastly, the P(A) at the end can actually be removed (you get similar outputs), but it scales the equation so that events that are more likely under the posterior are weighted more.

Then I started out with the definition of a gaussian distribution probability density function (PDF, fancy for “integrates to 1 and describes probabilities”):

p(x) = 1/(σ√(2π)) · e^(−(x − μ)² / (2σ²))

This looks super super scary (I really recommend the 3b1b video), but as soon as you take the integral of the original KLD function, replacing everything with that, it all simplifies to a really nice equation against the standard normal: KL( N(μ, σ²) ‖ N(0, 1) ) = −½ Σ (1 + ln σ² − μ² − σ²). (I originally forgot about the ½ but it turned out it didn’t matter haha)

  • PREDICTING LOG STANDARD DEVIATION
    • I realized this was so important after my loss started skyrocketing with my variance going through the roof; I ended up having it predict a smaller value which I scaled up/clamped, then reread the paper and realized it predicts the log

added beta scheduling to make sure the KL ramps up etc.
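Putting the simplified KL together with the beta ramp, the loss could be sketched like this (names and the linear warm-up are my own guesses, the post doesn’t specify the exact schedule):

```python
import torch as T

def kl_to_standard_normal(mu, logvar):
    # closed-form KL( N(mu, sigma^2) || N(0, 1) ), with the network
    # predicting logvar directly: -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2)
    return -0.5 * (1 + logvar - mu**2 - logvar.exp()).sum(dim=-1).mean()

def vae_loss(x, recon, mu, logvar, step, warmup_steps=1000):
    # beta ramps linearly from 0 to 1 so reconstruction settles in
    # before the KL starts pulling the latents together
    beta = min(step / warmup_steps, 1.0)
    eps = 1e-7
    bce = -(x * T.log(recon + eps) + (1 - x) * T.log(1 - recon + eps)).mean()
    return bce + beta * kl_to_standard_normal(mu, logvar)
```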

[images]

the diffusion model!

I started with a vague idea of diffusion models, so this was actually really interesting in concretely putting code to shaky ideas. One thing I found really helpful was just plotting Y distributions against predicted distributions—

[image]
  • noise was really interesting here: I first tried linear noise (literally img += noise )
[image]
  • then added cosine scheduling!
[image]
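(The cosine schedule here is, I believe, the one from the Improved DDPM paper. A sketch of the cumulative ᾱ it produces, with my own function name:)

```python
import math
import torch as T

def cosine_alpha_bar(timesteps, s=0.008):
    # Improved-DDPM cosine schedule for the cumulative signal fraction:
    # f(t) = cos^2( (t/T + s) / (1 + s) * pi/2 ), normalized so alpha_bar(0) = 1
    t = T.linspace(0, 1, timesteps + 1)
    f = T.cos((t + s) / (1 + s) * math.pi / 2) ** 2
    return f / f[0]
```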

it looked good BUT my model wasn’t training well :/

This actually stumped me for some time during the debugging process, because the original paper mostly glossed over the basics of diffusion models. I ended up reading a good deal deeper into the more basic papers (and having a couple good friends help me realize where I was going wrong; thanks so much to @danielmend_) and came upon this formula!

[image]

then it suddenly clicked…

DISTRIBUTIONSSSSSSSS! I was just adding to an image without actually rescaling the image or adjusting the probability distribution. The formula here actually looks a lot like binary cross-entropy, mostly because it’s also derived from KL divergence and integrals of gaussians.

The math above this area is really dense and I’m still not 100% sure how they got there (specifically the step in the proof from the KL divergences ⇒ the final equations) but it intuitively makes sense.

alpha bar is your cumulative noise (a nice closed-form way to get t steps down the line) between (0,1): you scale down x_0 a certain amount and scale the noise up proportionally. This way the distribution literally loses data instead of just having noise added on top, and your equation asymptotically becomes pure noise! If all the equations are scary, I would highly suggest looking at my notebook. This ends up being the final noise schedule!
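In code, that rescaled forward process boils down to something like this sketch (`alphas_bar` assumed precomputed from your schedule; names are mine, not the actual repo’s):

```python
import torch as T

def noise_image(x0, t, alphas_bar):
    # q(x_t | x_0): scale the image DOWN by sqrt(alpha_bar_t) and mix in
    # gaussian noise scaled by sqrt(1 - alpha_bar_t), so the distribution
    # keeps roughly unit scale and tends to pure noise as alpha_bar -> 0
    eps = T.randn_like(x0)
    xt = alphas_bar[t].sqrt() * x0 + (1 - alphas_bar[t]).sqrt() * eps
    return xt, eps
```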

[image]

Sampling!!!

1. basic linear sampling

tried this out and it didn’t work very well... im -= sample(im, i, 2), mostly because it doesn’t treat the noise with the correct constants + assumes each timestep is equally important

[images]

2. classifier free guidance!!

yay!! im -= (sample(im, i, 2) - sample(im, i, 10)*0.9)

I actually stumbled upon classifier free guidance by accident through experimenting. Initially I ended up doing a “poor man’s CFG”, which entailed predicting all 10 number directions, averaging them, then taking the difference between the mean & the chosen number! This actually worked quite well, but I showed it to my friend and they pointed me to the CFG paper, which I then implemented a lot faster
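For reference, the textbook CFG formula from the paper is slightly different from my subtraction above: you start from the unconditional prediction and push toward the conditional one. A sketch, with `model` and `null_label` as hypothetical stand-ins:

```python
import torch as T

def cfg_noise(model, x, t, label, null_label, w=4.0):
    # standard classifier-free guidance:
    # eps = eps_uncond + w * (eps_cond - eps_uncond)
    eps_cond = model(x, t, label)
    eps_uncond = model(x, t, null_label)
    return eps_uncond + w * (eps_cond - eps_uncond)
```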

[images]

3. trying out the ddpm equations

the equations look scary again, but you’re functionally applying the same noise shift from training time in reverse: taking some amount off the noise and adding it back to the main image, with the constants being timestep dependent.

im = (1/math.sqrt(alphas[t])) * (im - noise*(1-alphas[t])/math.sqrt(1-alphas_prod[t])) + random_noise_like(im)*0.001

[images]

4. trying out averaging the samples

trying to take “mini-steps” in the right direction before taking a large step. this seems to work better in practice and produce a lot cleaner images! I liked this because it really aligns with the idea of diffusion models as guides on an “image gradient”.

[image]
    for i in range(steps):
        t = int((i / steps) * 20)
        noises = []

        # take several small guided "mini-steps" and remember each noise estimate
        for _ in range(substeps):
            noise = sample(im, t, 3) - sample(im, t, 10) * 0.4
            noises.append(noise)
            im -= noise * 0.11

        # average the estimates, then take the big DDPM-style step
        noise = T.vstack(noises).mean(dim=0)
        im = (1 / math.sqrt(alphas[t])) * (im - noise * (1 - alphas[t]) / math.sqrt(1 - alphas_prod[t]))

that’s it! if you got here, thanks so much for reading 🎉, and if you notice anything weird feel free to email me @ neel.redkar@gmail.com or DM me on twitter!

Created with ☕ by @neelr