My research interests, and perhaps yours
Thank you for your interest in working with me! Projects are likely going to be in one of these areas:
- Understanding how planning algorithms emerge inside neural networks.
- Making NN interpretability more of an empirically grounded field of ML.
- What is interpretability good for? How can we tell whether an interpretability hypothesis about a model's internals is correct or useful?
- Automating interpretability.
I’m also happy to hear pitches for working on other topics relevant to alignment.
How to apply
You should spend at most 4 hours on the application-specific parts below.
- Do not spend time polishing your CV! I will spend very little time looking at it. Instead, focus on the next two steps.
- Please complete the questions in this document. First read the definition of serial cost below; if you are confused, read the introduction. Please write down the time you spend on each question, excluding breaks. Aim to spend 1 hour or less on this part of the assessment.
- Please complete the programming assessment, which is about improving an automatic interpretability baseline. You can spend the remaining time on this part of the assessment, up to 4 hours total.
Afterwards, submit your answers to the questions, how much time you spent on each part, and a Git bundle with your code.
Questions
Search usually gives better results, per unit of computation, if we search for plans one after the other: we can use the evaluation of previously considered plans to guide the next step of the search. This made me interested in the number of serial steps that various NN architectures use, at training and at deployment time.
For the purposes of the following questions, we consider all element-wise operations (scalar multiplication, addition, inverse square root, GeLU, ReLU) as elementary computation steps.
Some of the necessary assumptions may be missing from the questions. Please indicate clearly how you arrived at each answer, and note any extra assumptions you made. Try to make assumptions only when necessary.
Remember to get as close as you can to the minimum number of serial computational steps!
Please spend up to 1 hour on this task.
Definition: serial cost of a parallelizable problem
The serial cost $C(P)$ of a computational problem $P$ is the length of the longest chain of steps, each depending on information from previous steps, needed to compute $P$. In more formal terms: if $\text{solve}(P)$ is the set of algorithms that solve $P$, then the serial cost of $P$ is:
$$C(P) = \min_{A \in \text{solve}(P)} \text{max-dependent-step-chain}(A)$$
where $\text{max-dependent-step-chain}(A)$ is the length of the longest chain of steps in algorithm $A$ in which each step depends on information from previous steps.
In other words, if you had infinite memory and infinitely many CPUs with no communication costs, but each CPU step still took some time, how many steps would you have to wait for the answer?
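To make $\text{max-dependent-step-chain}$ concrete, here is a minimal sketch that computes it as the depth of the dependency DAG, for the three steps of $\mathrm{ReLU}(x \cdot w + b)$. The dictionary encoding of an algorithm's steps is a made-up representation, purely for illustration:

```python
from functools import lru_cache

# Hypothetical encoding of an algorithm: each step maps to the steps
# whose results it reads. These are the steps of relu(x * w + b).
deps = {
    "mul": [],         # x * w reads only the inputs
    "add": ["mul"],    # (x * w) + b reads the product
    "relu": ["add"],   # relu(.) reads the sum
}

@lru_cache(maxsize=None)
def chain_length(step: str) -> int:
    """Length of the longest dependent chain ending at `step`."""
    return 1 + max((chain_length(d) for d in deps[step]), default=0)

print(max(chain_length(s) for s in deps))  # -> 3 serial steps
```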
I’m going to ask you to come up with tight upper bounds for the $C(P)$ of some common ML operations.
Example question: serial cost of adding two vectors
You have two vectors $a,b$ of length $10$. What is the serial cost of $\sigma(a + b)$, where $\sigma$ is the sigmoid function applied elementwise? What is the total cost?
Example answer
The total cost of this operation is the cost of computing the addition $a_i + b_i$ for each element $i$, plus the cost of applying $\sigma$ to each result. Thus, the total cost is $2\cdot 10 = 20$ operations (for each $i$: one addition and one application of $\sigma$).
The serial cost, however, is just 2: one addition and one application of $\sigma$. Since the value of one element does not depend on the value of other elements, we can just compute them in parallel.
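For intuition, here is the same counting in code (a NumPy sketch; per the convention above, each application of $\sigma$ counts as one elementary step):

```python
import numpy as np

def sigmoid(x):
    # Counted as one elementary step per element, per the convention above.
    return 1.0 / (1.0 + np.exp(-x))

a = np.arange(10.0)
b = np.ones(10)

out = sigmoid(a + b)
# Total cost: 10 additions + 10 sigmoid applications = 20 operations.
# Serial cost: 2, since each out[i] depends only on a[i] and b[i],
# so all 10 (addition, sigmoid) chains run in parallel.
```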
Question 1: GRU RNN forward pass, serial cost
You have a sequence of 100 tokens, embedded into a 1024-dimensional space. You would like to process it with a 12-layer GRU RNN (see e.g. the PyTorch docs), with a 1024-dimensional hidden state and no dropout. The vocabulary size is 16000, and there is a softmax after the unembedding.
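For reference, the single-layer GRU update for token $t$, as given in the PyTorch docs, is:

$$\begin{aligned} r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}) \\ z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}) \\ n_t &= \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn})) \\ h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1} \end{aligned}$$

where $\odot$ is the element-wise product.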
What is the serial cost of computing a forward pass of the RNN for all tokens? Make sure to reduce this number as much as possible, while still correctly computing the RNN forward pass. Note down your assumptions and justify your answer. Please write down the actual number of steps and not an asymptotic expression.
Question 2: Gradient, serial cost
Consider the previous model. You want to train it to predict the next token, using the cross-entropy loss. This means you need the gradient of the loss with respect to the parameters.
What is the serial cost of computing the gradient of the cross-entropy loss with respect to the parameters? Make sure to reduce this number as much as possible, while still correctly computing the gradient. Note down your assumptions and justify your answer. It’s OK to be somewhat imprecise, but don’t go all the way down to asymptotics; analyze this like you would a practical algorithm-optimization problem.
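One standard identity that may help (a reminder, not a required approach): for logits $z$, target token $y$, and $p = \mathrm{softmax}(z)$, the cross-entropy loss $L = -\log p_y$ satisfies

$$\frac{\partial L}{\partial z} = p - e_y,$$

where $e_y$ is the one-hot vector for $y$.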
Question 3: total cost, and training vs. inference
You have trained your RNN, and it is time to do inference! What is the serial cost at test time, i.e. when you generate a sequence of tokens with the RNN?
Consider also the total cost (the usual notion of cost, equivalent to the serial cost when you have only 1 processor). For a given sequence length, what is the total cost of a forward pass at training time? What about at inference time?
Programming assessment: edges for subnetwork probing
One of the baselines in Automatic Circuit Discovery is Subnetwork Probing (SP). That baseline is at a disadvantage when competing with ACDC: it operates at the level of nodes instead of edges. That is to say, in a forward pass of the network, Subnetwork Probing either masks a head/MLP or does not. In contrast, ACDC masks a different set of heads/MLPs for each input to a head/MLP.
The task here is to modify SP so that it also operates at the level of edges. I think it is unlikely that someone unfamiliar with the code will finish all of this in 3 hours. So please submit your work in progress at the end of the 3 hours, and don’t feel bad if you did not finish!
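To illustrate the node-vs-edge distinction, here is a schematic sketch (tensor names and shapes are hypothetical, for illustration only; this is not the API of the SP/ACDC codebase):

```python
import torch

n_senders, n_receivers, d = 8, 8, 64
outputs = torch.randn(n_senders, d)  # upstream head/MLP outputs

# Node-level masking (current SP): one learnable mask per head/MLP, so
# every downstream consumer sees the same masked set of nodes.
node_logits = torch.zeros(n_senders, requires_grad=True)
node_mask = torch.sigmoid(node_logits)         # shape [n_senders]
masked_outputs = node_mask[:, None] * outputs  # identical for all receivers

# Edge-level masking (the task): one learnable mask per (receiver, sender)
# edge, so each receiver can see a different subset of upstream outputs.
edge_logits = torch.zeros(n_receivers, n_senders, requires_grad=True)
edge_mask = torch.sigmoid(edge_logits)         # shape [n_receivers, n_senders]
receiver_inputs = edge_mask @ outputs          # shape [n_receivers, d]
```

In the actual model, the receivers would be e.g. the Q/K/V inputs of each attention head and the MLP inputs; the core change is that the mask gains a receiver dimension.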