Prompting by Activation Maximization
Prompt synthesis with activation maximization achieves 95.9% on Yelp Review Polarity (a sentiment classification task) with Llama-3.2-1B-Instruct and a 4-token prompt. Code is on GitHub.
Motivation
I'm learning PyTorch, and I'm really into activation maximization. I see a lot of potential beyond tricking classifiers and mechanistic interpretability. I've got a project in mind, and I'm building up to it. This experiment is a step on that journey.
Activation Maximization?
The model is a function with adjustable weights (coefficients). We arrange a dataset of inputs paired with ideal outputs, and adjust the model's weights to minimize the difference between the function's actual and expected output. With even a handful of coefficients, the search space is tremendous, but backpropagation and gradient descent help find minima representing useful and productive functions.
Activation maximization inverts this. Given a trained model and a target output, you adjust the input instead of the weights. The result is an input which provokes the desired output from the model.
In PyTorch, this is easy. You initialize some subject, name it as the optimizer's target, freeze the model weights and reorganize the training loop. Loss converges, like any other training session.
Basic Experiment
For my first experiment, I used the MNIST dataset, which has 60,000 examples of handwritten digits. (You can find MNIST on Kaggle.) On this, I trained a basic stack of three convolutional layers feeding into a two-layer perceptron. This scored 99% accuracy on the 10,000-sample test set.
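The Classifier referenced in the code below isn't reproduced here; as a rough illustration, a stack of the kind described (three convolutional layers into a two-layer perceptron) might look something like this. The layer widths are my own guesses, not the experiment's exact architecture.

```python
# Illustrative only: one plausible shape for the MNIST classifier described
# above, not the exact architecture used in the experiment.
import torch
from torch import nn

class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128), nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        # Accept (N, 28, 28) or (N, 1, 28, 28) input.
        if x.dim() == 3:
            x = x.unsqueeze(1)
        return self.head(self.features(x))
```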

Then, having the model, I initialize an image with random noise, and name it as the optimizer's target.
```python
import torch
from torch import nn
import torch.nn.functional as F

from model import Classifier

device = 'cuda'

# Load the trained classifier and freeze its weights.
filename = 'classifier.202508042316.pt'
model = Classifier()
model.load_state_dict(
    torch.load(filename, weights_only=True))
model.to(device)
model.eval()
for p in model.parameters():
    p.requires_grad_(False)

# The subject: random noise, registered as the optimizer's only parameter.
subject = torch.rand(1, 28, 28)
subject = subject.to(device)
subject.requires_grad_(True)
optimizer = torch.optim.Adam([subject], lr=1e-3)

# One-hot target for the desired class.
target = torch.zeros(10)
target[7] = 1.0  # do a seven
target = target.unsqueeze(0).to(device)

step_count = 512
for step in range(step_count):
    optimizer.zero_grad()
    image = subject.sigmoid()   # squash into the [0, 1] pixel range
    logits = model(image)
    loss = F.cross_entropy(logits, target)
    loss.backward()             # gradients land on the subject, not the model
    optimizer.step()
    # save image here
```
We do the forward pass, compute the loss, backpropagate, and run the optimizer. But instead of amending the weights, we iteratively change the subject to better provoke the desired classification.


I'll confess, I thought this might look more like a three or a seven. I suspect that if I classified pixels instead of the image, I could get something like the melting dogs of Deep Dream.
Prompt Engineering
Great, so, we can provoke behavior in a model we possess. When do we ever do that?
Quite often, actually! It's called "prompt engineering".
In prompting, we take an off-the-shelf, pretrained language model and provide instructions in-band. A portion of the context window is occupied by these hand-written instructions.
Having a dataset, one can also fine-tune the model or train a LoRA. But these come with tradeoffs, e.g. catastrophic forgetting and loss of generalization. General models, once resident in memory, have the unique ability to handle many jobs, because instructions are given in-band. This is why small projects save money paying by the token: a hosted LLM occupies VRAM at cost even when it's idle, so shared hosts like OpenAI keep their LLMs supplied constantly with different users' workloads. You cannot do this with tuned models.
Prompting typically means writing long-form text with examples and explanations. If you want quality at scale, you can't avoid building a dataset; you need automated testing. But you're still searching the prompt space (which is tremendous) by speculating about what verbal manipulations might influence the model. These prompts can be quite large.
Classify the given text as either negative (`1`) or positive (`2`). Write nothing else.
One can test prompts faster with generative AI, but you're still shooting in the dark. That's because our best and most powerful search tool doesn't work on tokenized prompts.
Activation-Maxxing vs. The Prompt
Gradient descent is the search that works, but you can't activation-max a text prompt. LLMs are wrapped in a non-continuous interface: the tokenizer. For example, "Lorem ipsum" in the GPT-4o tokenizer encodes to [61495, 38714]. We've written these as numbers, but they're not numbers; they're surrogate keys. You can't gradient-descend your way to a surrogate key.
Thankfully, the keys are a lie.
The actual differentiable function at work receives the context window as a tensor, produced as a stack of embeddings.
Embeddings?
To train a differentiable function, we need a continuous space on both ends. But human language is surrogates all the way down. So, we reframe the text as continuous. On the input side of the LLM, we use embeddings; each token is assigned a vector, such that tokens' vectors are spatially grouped according to their observed statistical relationships to other tokens.
Before the LLM can be trained, each token in the palette must be embedded together in a vector space. This has two important consequences:
- Their identity now carries semantic value.
- They now reside in a continuous space.
But there's a catch for us to exploit.
Consider Llama 3.2, which I'm using for this experiment. This model features 128,256 distinct tokens, so a token's key fits comfortably in just 17 bits. But the vectors are 2048-dimensional: 65,536 bits at full precision, or 32,768 at half.
The overwhelming bulk of the vector space is unreferenced by the LUT. But in a perverse twist of Hyrum's Law, the entire vector space is semantically meaningful.
The model is a function. Nothing more, nothing less. And a function maps each element of its domain to exactly one element of its codomain.
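For concreteness, here's a quick way to inspect that look-up table and confirm the numbers above (a sketch assuming the Hugging Face transformers API and access to the Llama-3.2-1B-Instruct weights):

```python
# A quick look at the look-up table itself.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = 'meta-llama/Llama-3.2-1B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

table = model.get_input_embeddings().weight      # the look-up table
print(table.shape)                               # torch.Size([128256, 2048])

# Surrogate keys in, continuous vectors out.
ids = tokenizer('Lorem ipsum', add_special_tokens=False).input_ids
print(ids, table[ids].shape)
```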
The Plan
The context window is a 2D tensor, in which each row is populated by an embedding from the look-up table. When you write a prompt by hand, your text is tokenized, and the context window is populated accordingly.
We'll build it differently.
The model's input and output from the corpus will be tokenized as usual.
But the prompt will take the form of a 2D tensor, randomly initialized, which we'll concatenate with each input for the forward pass. We'll run the optimizer against the prompt, and in this way derive a prompt which provokes the desired behavior in the model.
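A minimal sketch of what that training step could look like, assuming the Hugging Face transformers API; the names, learning rate, and loss placement (next-token prediction at the final position) are illustrative rather than the exact script from the repository:

```python
# Sketch: optimize a soft prompt (a trainable 2D tensor) against a frozen LLM.
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

device = 'cuda'
model_id = 'meta-llama/Llama-3.2-1B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16).to(device)
model.eval()
for p in model.parameters():
    p.requires_grad_(False)

embed = model.get_input_embeddings()   # token id -> 2048-dim vector
hidden = model.config.hidden_size      # 2048 for Llama-3.2-1B

# The "prompt": a randomly initialized 2D tensor, prompt_len rows of hidden floats.
prompt_len = 4
prompt = torch.randn(prompt_len, hidden, device=device)
prompt.requires_grad_(True)
optimizer = torch.optim.Adam([prompt], lr=1e-3)

def train_step(review_text, label_text):
    optimizer.zero_grad()
    # The corpus is tokenized as usual; the label is "1" or "2".
    input_ids = tokenizer(review_text, return_tensors='pt').input_ids.to(device)
    label_id = tokenizer(label_text, add_special_tokens=False).input_ids[0]
    # Concatenate the trainable prompt in front of the review's embeddings.
    soft = prompt.to(torch.bfloat16).unsqueeze(0)
    inputs_embeds = torch.cat([soft, embed(input_ids)], dim=1)
    logits = model(inputs_embeds=inputs_embeds).logits
    # Cross-entropy against the expected label token at the final position.
    loss = F.cross_entropy(
        logits[:, -1, :].float(),
        torch.tensor([label_id], device=device))
    loss.backward()   # gradients flow into the prompt, not the model
    optimizer.step()
    return loss.item()
```

In practice you'd batch and pad the reviews, but the essential move is the same: only the prompt tensor receives gradients.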
The Task
For a simple task, I've chosen Yelp Review Polarity, a sentiment classification dataset. It offers a training set of 560,000 reviews and a test set of 38,000.
"2","I drove by yesterday to get a sneak peak. It re-opens on July 14th and I can't wait to take my kids. The new range looks amazing. The entire range appears to be turf, which may or many not help your game, but it looks really nice. The tee boxes look state of the art and the club house looks like something you'll see on a newer course. Can't wait to experience it!"
You can read more about this dataset on Yelp, in the paper that first used it, Character-level Convolutional Networks for Text Classification (Zhang, Zhao, LeCun), or download it on Kaggle.
The Model
For this exercise, I chose Llama-3.2-1B-Instruct, because it's just small enough for my europoor computer's Nvidia 3090.

Execution
First, I ran the test set with a hand-written prompt for reference.
| Prompt | Accuracy | Correct / Total |
|---|---|---|
| 40 tokens | 56% | 21550 / 38000 |
| 98 tokens | 57% | 21948 / 38000 |
If these figures surprise you, you're probably accustomed to large models. GPT-3 would crush this task, but GPT-3 has 175B parameters. Small, generatively pretrained models have deep knowledge about the world, but they aren't intelligent enough to reliably follow even simple instructions. Frankly, I'm surprised it even complied enough to write the right categories.
Let's see if we can activation-max a working prompt.
| Training samples | Prompt tokens | Epochs | Accuracy | Correct / Total |
|---|---|---|---|---|
| 5600 | 16 | 1 | 57% | 21916 / 38000 |
| 5600 | 16 | 3 | 94% | 2269 / 2400 |
| 5600 | 8 | 3 | 91% | 2197 / 2400 |
| 5600 | 4 | 3 | 86% | 2078 / 2400 |
| 5600 | 4 | 5 | 87.5% | 2100 / 2400 |
| 58000 | 4 | 5 | 95.9% | 36442 / 38000 |
It wipes the floor with the hand-written prompt.
I haven't touched or tuned the model proper. There's not even a LoRA. The 4-token primer is a 16,384-byte tensor (4 tokens × 2048 dimensions × 2 bytes in bfloat16), given in-band.
Useful?
I have no idea :)
In theory, the advantage here is about task switching. If the GPU is going to load one model, and perform one job, we might as well tune the model, or train a LoRA. However, if we'd like the model to rapidly accept different jobs, we can feed task-specific primers in with the context window, exactly as if we're prompting a strong, generalized model.
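Reusing the names from the training sketch above, and with hypothetical primer files, swapping jobs at inference time might look something like this:

```python
# Hypothetical: one frozen model resident in VRAM, task-specific primers
# loaded per request. File names and tasks are made up for illustration.
primers = {
    'sentiment': torch.load('primer_sentiment.pt'),  # (4, 2048) tensor
    'spam':      torch.load('primer_spam.pt'),
}

@torch.no_grad()
def classify(task, text):
    ids = tokenizer(text, return_tensors='pt').input_ids.to(device)
    soft = primers[task].to(device, torch.bfloat16).unsqueeze(0)
    inputs_embeds = torch.cat([soft, embed(ids)], dim=1)
    logits = model(inputs_embeds=inputs_embeds).logits
    return tokenizer.decode(logits[:, -1, :].argmax(dim=-1))
```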
Given the unending RAM shortage, there may well be a place for this technique.
Done Before?
Probably! For didactic reasons, I won't check until after I've completed this experiment and published this article.
Code
It's not an ergonomic use of a model from Hugging Face, but I've placed the code at github.com/JoeCooper/PromptByMax. Let me know if you find some value in it. Thank you for reading.