Goal: Recreate the DiT paper from “scratch” (only the PyTorch API for now), using only the research paper + other papers online. No online code/GitHub implementations or LLMs (including Copilot) allowed.
My approach for research papers has always been: 1) get an understanding of the paper at a high level, 2) read the code implementing the paper to get a deep understanding (i.e. if I don’t understand a concept, look at the code + keep going lower into the hierarchy until I understand it).
A problem that’s come up recently is that it’s actually harder to find open-source implementations of certain papers as things get more closed source + there’s just something fun about doing it yourself! I haven’t had any experience building diffusion models from scratch (unlike language models, graph neural nets, or RL), so I’m taking this as a challenge to break that abstraction layer & get more mathy!
I’m also planning on implementing this using WGSL on the web, so this whole process should help me understand the fundamentals a lot better.
Check it out above! It’s a ~130-line implementation of DiTs where the code is (hopefully) really readable, and a good learning resource for people! It took a couple of after-work sessions to get everything working, but I think I came out understanding a lot more about diffusion than I did before.
I think I could probably get the line count down a lot, but readability is more important to me.
TODOS:
Playing around with noise and generation actually gave me what I think is an interesting way to think about diffusion: treating diffusion models as “guides” in latent space, instead of next-noise predictors.
This way you can actually manipulate what the model outputs at sampling time, using it as a vibe measure of what the real image should look like (i.e. you can use negative controls, or even think of it as an equation that slowly traverses the gradients to give you the correct answer!).
Once you think of it that way, you can cross-apply a lot of the ideas from backpropagation & navigating gradients here (learning rates, how much to update images, etc.). The sample above uses a mini technique where I tried “batch estimating” gradients—I’m experimenting a lot here!
Okay awesome, after reading through the paper it seems like there are mainly two components I should implement to get it working: the VAE (to get images into a latent space) and the diffusion model itself.
I’ve actually had a lot of experience with VAEs, so I tried to keep this one fast and as barebones as possible. A couple of interesting tidbits I ran into here—
```python
import torch as T

eps = 1e-7

def cross_entropy(x, y):
    # binary cross-entropy between the target x and the reconstruction y
    # I found out that there is a decently high probability of y being exactly
    # 1 or 0, so I had to add an epsilon to not get NaNs
    # this also assumes y has a range of [0, 1]
    return -(x * T.log(y + eps) + (1 - x) * T.log(1 - y + eps)).mean()
```
I already knew the intuition for why KL divergence was needed (to pull everything together in the latent space), but implementing it was slightly weirder. The way I ended up getting the KL divergence intuition down for the equation is:

KL(A ‖ B) = Σ log( P(A) / P(B) ) · P(A)
Where B is your prior (before) and A is your posterior (after). The reason you divide A/B is to get a ratio of before vs. after on your probabilities. The log here is really useful mostly for one reason—it treats ratios symmetrically. For example ln(1/4) is exactly the negative of ln(4), therefore not skewing our “divergence” too high upward.
Lastly, the P(A) at the end can actually be removed (you get similar outputs), but it weights the equation so that higher-probability parts of the posterior count for more.
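A tiny worked example of that intuition in PyTorch (the distributions here are made up, just to show the mechanics):

```python
import torch as T

# B = prior ("before"), A = posterior ("after"); the numbers are made up
prior = T.tensor([0.25, 0.25, 0.25, 0.25])
posterior = T.tensor([0.70, 0.10, 0.10, 0.10])

# divide A/B, take the log so reciprocal ratios land symmetrically around 0,
# then weight by P(A) so higher-probability parts of the posterior count for more
kl = (posterior * T.log(posterior / prior)).sum()
print(kl)  # ~0.45 nats
```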
Then I started out with the definition of a Gaussian distribution’s probability density function (PDF, fancy for “integrates to 1 and describes probabilities”):

p(x) = 1 / (σ·√(2π)) · exp( −(x − μ)² / (2σ²) )
This looks super super scary (I really recommend the 3b1b video), but then as soon as you take the integral of the original KLD function, replacing everything with that Gaussian PDF—it all simplifies to a really nice equation. (I originally forgot about the 1/2 but it turned out it didn’t matter haha)
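For reference, the nice equation it simplifies to is the standard closed-form KL between N(μ, σ²) and N(0, 1); a minimal PyTorch sketch (assuming the encoder outputs mu and logvar with shape (batch, latent_dim)):

```python
import torch as T

def kl_divergence(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dims, averaged over the batch
    # (this is where the 1/2 I originally forgot lives)
    return (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)).mean()
```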
Added beta scheduling to make sure the KL term ramps up gradually over training, etc.
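The schedule itself is tiny; a minimal sketch of the idea (the linear ramp and the warm-up length here are just illustrative, not necessarily the exact values I used):

```python
def kl_beta(step, warmup_steps=2000):
    # weight on the KL term, ramping linearly from 0 up to 1 over the warm-up
    return min(1.0, step / warmup_steps)

# inside the training loop this gets used roughly like:
#   loss = recon_loss + kl_beta(step) * kl_loss
```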
I started with only a vague idea of diffusion models, so this was actually really interesting in terms of concretely putting code to shaky ideas. One thing I found really helpful was just plotting the Y distributions against the predicted distributions—
```python
# my first, naive forward process: just add the noise straight onto the image
img += noise
```
it looked good BUT my model wasn’t training well :/
This actually stumped me for some time during the debugging process, because the original paper mostly glossed over the basics of diffusion models. I ended up reading a good deal into the more foundational papers (and having a couple of good friends help me realize where I was going wrong—thanks so much to @danielmend_), and going deeper into them I came upon this formula:

x_t = √(ᾱ_t) · x_0 + √(1 − ᾱ_t) · ε
then it suddenly clicked…
DISTRIBUTIONSSSSSSSS I was just adding noise to an image without actually rescaling the image or adjusting the probability distribution. It actually looks a lot like binary cross-entropy here, mostly because it’s also derived from KL divergence and integrals of Gaussians.
The math above this area is really dense and I’m still not 100% sure how they got there (specifically the step from the KL divergences in the proof to the final equations), but it intuitively makes sense.
Alpha bar (ᾱ_t) is your cumulative noise schedule (a nice closed-form way to jump t steps down the line), sitting between (0, 1), and you scale x_0 down by a certain amount while scaling the noise up proportionally. This way the distribution literally loses data instead of just having noise added on top of it, and your equation asymptotically becomes pure noise! If all the equations are scary, I would highly suggest looking at my notebook. This ends up being the final noise schedule!
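To make it concrete, here's roughly what the closed-form noising looks like in code (the linear beta range is an assumption on my part; the alphas / alphas_prod names match the sampling snippets below):

```python
import torch as T

timesteps = 20
betas = T.linspace(1e-4, 0.2, timesteps)   # per-step noise amounts (illustrative values)
alphas = 1 - betas
alphas_prod = T.cumprod(alphas, dim=0)     # alpha bar: cumulative product of the alphas

def noise_image(x0, t):
    # q(x_t | x_0): scale the image down and scale the noise up, so the result stays
    # a proper distribution and asymptotically becomes pure noise as t grows
    eps = T.randn_like(x0)
    return T.sqrt(alphas_prod[t]) * x0 + T.sqrt(1 - alphas_prod[t]) * eps, eps
```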
I tried this out and it didn’t work very well…
im -= sample(im, i, 2)
…mostly because it doesn’t treat the noise with the correct constants + assumes each timestep is equally important.
yay!! this one works:
im -= (sample(im, i, 2) - sample(im, i, 10)*0.9)
I actually stumbled upon classifier-free guidance by accident through experimenting. Initially I ended up doing a “poor man’s CFG”, which entailed predicting all 10 number directions, averaging them, then taking the difference between the mean & the chosen number! This actually worked quite well, but when I showed it to a friend they pointed me to the CFG paper, which I then implemented a lot faster.
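For reference, my version looked roughly like this (a sketch, assuming sample(im, t, c) returns the predicted noise for image im at timestep t conditioned on class c, like in the other snippets):

```python
import torch as T

def poor_mans_cfg(sample, im, t, chosen, num_classes=10):
    # predict a noise "direction" for every class, then move toward the chosen
    # class relative to the average over all of them
    preds = T.stack([sample(im, t, c) for c in range(num_classes)])
    return preds[chosen] - preds.mean(dim=0)
```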
The equations look scary again, but you’re functionally applying the same noise shift as at sampling time—here, taking some amount off the predicted noise and adding it back into the main image, with the amount being timestep-dependent.
```python
# reverse step: strip the predicted noise off the image (scaled by the schedule),
# then add a tiny bit of fresh noise back in
im = (1 / math.sqrt(alphas[t])) * (im - noise * (1 - alphas[t]) / math.sqrt(1 - alphas_prod[t])) \
     + random_noise_like(im) * 0.001
```
Trying to take “mini-steps” in the right direction before taking a large step. This seems to work better in practice and produces a lot cleaner images! I liked this because it really aligns with the idea of diffusion models as guides on an “image gradient”.
```python
for i in range(steps):
    t = int((i / steps) * 20)
    noises = []
    for _ in range(substeps):
        # mini-step: CFG-style difference between two class predictions
        noise = sample(im, t, 3) - sample(im, t, 10) * 0.4
        noises.append(noise)
        im -= noise * 0.11                      # small nudge in the right direction
    # large step: average the mini-step noise estimates, then do the full scaled update
    noise = T.vstack(noises).mean(dim=0)
    im = (1 / math.sqrt(alphas[t])) * (im - noise * (1 - alphas[t]) / math.sqrt(1 - alphas_prod[t]))
```
That’s it! If you got here, thanks so much for reading 🎉. If you notice anything weird, feel free to email me at neel.redkar@gmail.com or DM me on Twitter!