By Brian Muhia
https://github.com/poppingtonic/semisupervised-rice-africa
Pesticide use in agriculture has historically been a constant source of illness as a secondary effect, one that early plant disease detection can reduce. This document by Our World In Data shows which countries are most affected by heavy pesticide use. Artificial intelligence should help enable more targeted, earlier interventions that let people avoid applying large amounts of pesticides. This project aims to make that intelligence interpretable, trustworthy, reproducible and generative. We first train, then explain, a family of self-supervised vision-only models, in order to study the value of rigorous interpretability methods and their potential for educating practitioners in their use - leading eventually to a world where practitioners routinely explore model capabilities and limits with more rigor than is the current norm. The goal is to improve food safety, and also AI safety, by exemplifying a norm of rigorous interpretability research.
<aside> 💡 The loss curves in the cover image can be reproduced by training an unsupervised encoder (both from ImageNet weights and from scratch) that is then fine-tuned for classification in two ways: MLP-only, and discriminatively. The notebooks are in the linked repo. The plot includes about 60 different runs, across different image scales and architectures.
</aside>
This article starts with a walk-through of my thoughts about the different training runs from this challenge to help Egypt, while my memory is still fresh. The first thing that stands out is that there seem to be two phases to training this model when fine-tuning from ImageNet weights.
Fine-tuning is recommended: you can do it in a few hours on a laptop as a proof of concept, as opposed to a few days of waiting for a from-scratch model to randomly wander its way down to low-loss land.
DALL·E 2 (latent guided diffusion model), prompt: “A residual network (50 layers), on its long random walk down the rolling hills of loss land, in search of the mythical global minimum. Unreal engine, artstation, afropunk, fps”
First, we use the ImageNet weights to carve a significant chunk out of the error before plateauing. The dark red trajectory (scroll up to the cover image) shows a Res2Next50 escaping, early on, what would have been several hours of training from scratch. Thankfully, fp16 enables a high data rate and a high batch size. In practice we need fast iteration time, and reducing image size is one way to get it; this lets us decide early to kill training runs that aren’t progressing as fast as we need. By carefully monitoring the speed of early training, I pick the best of these plateaued checkpoints, reload it, and run a hyper-parameter search using Learner.lr_find. The resulting learning rate should allow a gradual, stable descent of the loss.
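For concreteness, here is a minimal fastai sketch of this first phase. The dataset path, architecture (a plain ResNet-50 standing in for the Res2Next50), epoch count and checkpoint name are illustrative assumptions, not the exact settings from my runs:

```python
from fastai.vision.all import *

# A minimal sketch of the first phase, assuming a recent fastai and a
# standard folder-per-class layout. Path, epochs and checkpoint name
# are placeholders.
path = Path("data/rice")  # hypothetical dataset location
dls = ImageDataLoaders.from_folder(
    path, valid_pct=0.2,
    item_tfms=Resize(224),        # smaller images -> faster iteration
    batch_tfms=aug_transforms(),
    bs=64)

# Start from ImageNet weights; fp16 gives a high data rate and batch size.
learn = vision_learner(dls, resnet50, metrics=error_rate).to_fp16()

# Hyper-parameter search with Learner.lr_find, as described above.
lr = learn.lr_find().valley

learn.fine_tune(5, base_lr=lr)
learn.save("phase1-plateau")      # checkpoint to reload in the second phase
```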
In the second phase, we need to initialize the optimizer into a state where it can find a loss trajectory that is gradual and not too fast, and maintain this slow random walk for a long time. With some manufactured luck, i.e. many tries, we learn which hyperparameters are right. After some training time the model uncovers “something significant” about the data/task, and solves the problem much faster. We can and should do interpretability research to explain this “bump” in the loss curve, something similar to the individual contributions at https://transformer-circuits.pub, or Neel Nanda’s grokking work.
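A hedged sketch of what that second phase could look like in fastai, continuing from the hypothetical phase-one checkpoint above: reload the best plateaued checkpoint, re-run the learning-rate search from that state, and commit to a long, slow schedule:

```python
# Continue from the phase-one sketch; `learn` and the checkpoint name
# "phase1-plateau" are the illustrative assumptions introduced there.
learn = learn.load("phase1-plateau")
learn.unfreeze()                  # train all layers, not just the head

# Re-run the learning-rate search from the plateaued state.
lr = learn.lr_find().valley

# A long, slow schedule; the slice gives earlier layers a lower rate
# (discriminative fine-tuning), keeping the descent gradual.
learn.fit_one_cycle(40, lr_max=slice(lr / 100, lr))
learn.save("phase2-converged")
```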
We now have two “converged” models, and many more checkpoints saved at various stages of training. They are trained in a unique data setup that allows us to test interpretability tools and methods:
https://docs.google.com/document/d/1baKtI8zcZ33WswpWGzi-pCSx7fCOqxml-Lpv18SXpYY/edit?usp=sharing
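As one illustration of what this checkpoint family makes possible (a speculative sketch, not code from the repo): because checkpoints exist along the whole trajectory, we can pull intermediate activations from each stage and compare representations before and after the “bump”. Checkpoint names below are the hypothetical ones from the earlier sketches.

```python
import torch

# Hypothetical checkpoint names matching the sketches above; any
# intermediate saves along the trajectory would slot in the same way.
checkpoints = ["phase1-plateau", "phase2-converged"]
activations = {}

xb, yb = dls.one_batch()          # a fixed probe batch for fair comparison
for name in checkpoints:
    learn = learn.load(name)
    learn.model.eval()
    with torch.no_grad():
        # In a fastai vision_learner, model[0] is the convolutional body;
        # pool its features to one vector per image.
        feats = learn.model[0](xb)
        activations[name] = feats.mean(dim=(2, 3)).cpu()

# activations[...] can now be compared across training stages, e.g. with
# representation-similarity measures such as CKA, or via linear probes.
```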