Modern neural networks are incredibly good at fitting data but pretty bad at continuously learning. Train them on Task A, then on Task B, and they often overwrite what they learned earlier. This problem is known as catastrophic forgetting.
In this tutorial, we will build a small but powerful neural memory agent that can learn a sequence of tasks while still remembering previous ones. The idea is inspired by the brain’s episodic memory system and uses three key ingredients:
- Differentiable episodic memory – a learnable memory matrix the network can read and write to.
- Experience replay – a buffer that replays past examples to avoid forgetting.
- Meta-learning hooks – a way to adapt the model quickly on new tasks.
Why Continual Learning is Hard (and Why Memory Helps)
Standard deep learning assumes a fixed dataset: shuffle, train, converge, deploy. But real-world agents — robots, trading bots, adaptive scientific tools — encounter streams of data where the distribution keeps changing over time. In these settings, a naive training loop quickly overwrites old knowledge.
Biological brains handle this by combining several types of memory:
- Short-term / working memory in prefrontal cortex for current context.
- Episodic memory in the hippocampus to store rich “episodes” of experience.
- Slow structural learning in cortical networks over longer timescales.
Our neural agent will mimic a tiny part of this: an episodic memory module that can store compressed representations of what it has seen, and retrieve them later using similarity-based addressing.
Step 1: Setting Up the Memory Configuration
We start by defining a configuration object that controls the structure of our episodic memory. This is helpful if you later want to experiment with different memory sizes or head counts.
from dataclasses import dataclass
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


@dataclass
class EpisodicMemoryConfig:
    memory_slots: int = 128  # how many rows in the memory matrix
    slot_dim: int = 64       # dimensionality of each memory slot
    read_heads: int = 3      # how many parallel read operations
    write_heads: int = 1     # (kept for extensibility)
Think of this as a little key–value store inside your network: each row is a slot that can contain a vector embedding of some experience (e.g. a task, an input pattern, or a hidden state snapshot).
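As a small illustration of why a config object is convenient, the sketch below re-declares the dataclass (so that it runs on its own) and derives two variants with dataclasses.replace. The variant names and the read_width helper are hypothetical and only serve the example.

```python
from dataclasses import dataclass, replace


@dataclass
class EpisodicMemoryConfig:
    # Same fields and defaults as the config defined above.
    memory_slots: int = 128
    slot_dim: int = 64
    read_heads: int = 3
    write_heads: int = 1


def read_width(cfg: EpisodicMemoryConfig) -> int:
    """Size of the concatenated read vector (hypothetical helper)."""
    return cfg.read_heads * cfg.slot_dim


base = EpisodicMemoryConfig()
small = replace(base, memory_slots=32)            # fewer slots, same slot size
wide = replace(base, slot_dim=128, read_heads=4)  # larger slots, one more head

for name, cfg in [("base", base), ("small", small), ("wide", wide)]:
    # Print the number of stored floats and the read-vector width.
    print(name, cfg.memory_slots * cfg.slot_dim, read_width(cfg))
```

Because replace returns a new object, the base configuration stays untouched while you compare variants.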
Step 2: Implementing Differentiable Episodic Memory
Our episodic memory is simply a learnable matrix that the model can address by content: given a query vector (a “key”), it computes similarities to all memory slots and uses a softmax to turn these scores into attention weights.
class EpisodicMemory(nn.Module):
    def __init__(self, config: EpisodicMemoryConfig):
        super().__init__()
        self.config = config
        # memory: [memory_slots, slot_dim]
        # Small random init: with an all-zero memory every slot is equally
        # similar to any key, so all rows would receive identical writes
        # and never become distinguishable.
        self.register_buffer("memory",
                             0.01 * torch.randn(config.memory_slots, config.slot_dim))
        self.register_buffer("usage",
                             torch.zeros(config.memory_slots))

    def address(self, key, strength):
        # key: [slot_dim]
        key_norm = F.normalize(key, dim=-1)
        mem_norm = F.normalize(self.memory, dim=-1)
        similarity = torch.matmul(key_norm, mem_norm.t())  # [memory_slots]
        return F.softmax(strength * similarity, dim=-1)

    def read(self, keys, strengths):
        reads = []
        for i in range(self.config.read_heads):
            weights = self.address(keys[i], strengths[i])  # [memory_slots]
            read_vec = torch.matmul(weights, self.memory)  # [slot_dim]
            reads.append(read_vec)
        # concatenate all read heads
        return torch.cat(reads, dim=-1)  # [read_heads * slot_dim]

    def write(self, key, vector, erase, strength):
        weights = self.address(key, strength)              # [memory_slots]
        erase_mat = torch.outer(weights.squeeze(), erase)  # [slots, dim]
        add_mat = torch.outer(weights.squeeze(), vector)   # [slots, dim]
        self.memory = self.memory * (1 - erase_mat) + add_mat
        # usage is bookkeeping only, so keep it out of the autograd graph
        self.usage = 0.99 * self.usage + weights.squeeze().detach()
The usage vector is a simple heuristic that tracks which memory slots are being touched more often. In a more advanced version, you could use it to implement “least-used” or “age-based” replacement strategies.
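To see what content-based addressing does in isolation, here is a minimal standalone sketch of the same rule (cosine similarity followed by a softmax). The toy memory, the key, and the two strength values are made up for illustration; the point is only that a larger strength makes the attention weights sharper.

```python
import torch
import torch.nn.functional as F


def address(memory, key, strength):
    """Softmax over cosine similarities between `key` and every memory row."""
    sim = F.normalize(memory, dim=-1) @ F.normalize(key, dim=-1)  # [slots]
    return F.softmax(strength * sim, dim=-1)


# Toy memory: 4 slots, each pointing along a different axis.
memory = torch.eye(4)
key = torch.tensor([0.1, 0.9, 0.0, 0.1])  # closest to slot 1

soft = address(memory, key, strength=1.0)    # weights spread over all slots
sharp = address(memory, key, strength=20.0)  # weights concentrate on slot 1

print("soft :", [round(w, 3) for w in soft.tolist()])
print("sharp:", [round(w, 3) for w in sharp.tolist()])
```

Both calls pick slot 1 as the best match, but with a low strength the read is a blend of all slots, while with a high strength it is close to a hard lookup.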
Step 3: Designing the Memory Controller
The controller plays the role of a tiny “prefrontal cortex”. It receives the current input, processes it through an LSTM, and then decides:
- Which keys to use for reading from memory.
- What to write back into memory (and how much to erase).
- What output to emit for the current time step.
class MemoryController(nn.Module):
    def __init__(self, input_dim, hidden_dim, config: EpisodicMemoryConfig):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        total_read_dim = config.read_heads * config.slot_dim
        # Read parameters
        self.read_keys = nn.Linear(hidden_dim, total_read_dim)
        self.read_strength = nn.Linear(hidden_dim, config.read_heads)
        # Write parameters
        self.write_key = nn.Linear(hidden_dim, config.slot_dim)
        self.write_vector = nn.Linear(hidden_dim, config.slot_dim)
        self.erase_vector = nn.Linear(hidden_dim, config.slot_dim)
        self.write_strength = nn.Linear(hidden_dim, 1)
        # Final output on top of controller state + read vectors
        self.output = nn.Linear(hidden_dim + total_read_dim, input_dim)

    def forward(self, x, memory: EpisodicMemory, hidden=None):
        # x: [feat_dim] → add a time dim so the LSTM sees an unbatched
        # sequence of length 1
        lstm_out, hidden = self.lstm(x.unsqueeze(0), hidden)
        state = lstm_out.squeeze(0)  # [hidden_dim]
        # Compute read parameters
        read_keys = self.read_keys(state).view(memory.config.read_heads, -1)
        read_s = F.softplus(self.read_strength(state))
        # Compute write parameters
        write_k = self.write_key(state)
        write_v = torch.tanh(self.write_vector(state))
        erase_v = torch.sigmoid(self.erase_vector(state))
        write_s = F.softplus(self.write_strength(state))
        # Interact with memory (read first, then write)
        read_vecs = memory.read(read_keys, read_s)
        memory.write(write_k, write_v, erase_v, write_s)
        # Produce output
        combined = torch.cat([state, read_vecs], dim=-1)
        out = self.output(combined)
        return out, hidden
In practice, this controller could sit on top of sensory encoders (images, audio, text). For our demo, we’ll work with synthetic vector inputs so we can focus on the memory mechanics.
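The write path the controller drives (erase first, then add) is easier to follow on a toy memory. The sketch below applies the same update rule by hand; the numbers are invented, with the erase values standing in for a sigmoid output and the write vector for a tanh output.

```python
import torch

memory = torch.tensor([[1.0, 1.0, 1.0],
                       [2.0, 2.0, 2.0]])   # 2 slots, 3 dims
weights = torch.tensor([1.0, 0.0])         # write attention fully on slot 0
erase = torch.tensor([1.0, 0.5, 0.0])      # per-dimension erase gate in [0, 1]
vector = torch.tensor([0.3, -0.3, 0.3])    # new content in [-1, 1]

erase_mat = torch.outer(weights, erase)    # [slots, dim]
add_mat = torch.outer(weights, vector)     # [slots, dim]
updated = memory * (1 - erase_mat) + add_mat

print(updated)
```

Slot 0 becomes [0.3, 0.2, 1.3]: the first dimension is fully erased and replaced, the second is half erased, and the third keeps its old value with the new content added on top. Slot 1 receives zero attention and is left unchanged.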
Step 4: Experience Replay for Stability
If we simply train on tasks one after another, the controller will still forget. To mitigate this, we maintain a small replay buffer of past examples and periodically mix them into the current training batch.
class ReplayBuffer:
    def __init__(self, capacity=5000):
        self.buffer = deque(maxlen=capacity)
        self.priority = deque(maxlen=capacity)

    def push(self, experience, p=1.0):
        self.buffer.append(experience)
        self.priority.append(p)

    def __len__(self):
        return len(self.buffer)

    def sample(self, batch_size):
        # Normalize priorities into sampling probabilities
        probs = np.array(self.priority, dtype=np.float32)
        probs = probs / probs.sum()
        idx = np.random.choice(len(self.buffer), batch_size, p=probs)
        return [self.buffer[i] for i in idx]
Here we store a priority per sample (e.g. proportional to the loss). Harder or more surprising examples are replayed slightly more often, similar to prioritized experience replay in reinforcement learning.
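As a quick sanity check of how priorities turn into replay frequencies, the following sketch normalizes three made-up priorities and draws from them. It uses the standard library's random.choices rather than np.random.choice purely to stay dependency-free; the labels and priority values are hypothetical.

```python
import random
from collections import Counter

# Hypothetical per-sample priorities (for example, the loss at insertion time).
priorities = {"easy": 1.0, "medium": 2.0, "hard": 5.0}

# Normalizing gives the probability of each sample being replayed.
total = sum(priorities.values())
probs = {name: p / total for name, p in priorities.items()}
print(probs)

# Draw with replacement, weighted by priority.
rng = random.Random(0)
draws = rng.choices(list(priorities), weights=list(priorities.values()), k=8000)
counts = Counter(draws)
print({name: counts[name] / len(draws) for name in priorities})
```

The empirical frequencies should land close to the normalized priorities, so a sample with five times the priority is replayed roughly five times as often.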
Optional: A Simple Meta-Learning Hook
To keep the blog focused, we only sketch a minimal meta-learning function. The idea is to create a fast adaptation step that updates a copy of the parameters on a small support set.
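class MetaLearner(nn.Module):
    def __init__(self, model, memory):
        super().__init__()
        self.model = model    # the MemoryController
        self.memory = memory  # the EpisodicMemory it reads from and writes to

    def adapt(self, x, y, steps=3, lr=0.01):
        # Fast weights start as a copy of the current parameters.
        fast_params = {n: p.clone() for n, p in self.model.named_parameters()}
        for _ in range(steps):
            # Sketch only: the forward pass still runs with the base
            # parameters, so every step takes its gradient there. A full
            # inner loop would run the model with fast_params instead.
            y_pred, _ = self.model(x, self.memory)
            loss = F.mse_loss(y_pred, y)
            # On the first step the write heads have not influenced the
            # output yet, so some parameters receive no gradient.
            grads = torch.autograd.grad(loss, list(self.model.parameters()),
                                        create_graph=True, allow_unused=True)
            fast_params = {
                n: p if g is None else p - lr * g
                for (n, p), g in zip(fast_params.items(), grads)
            }
        return fast_params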
You can later extend this into a full MAML-style loop where the outer optimizer updates the base parameters based on post-adaptation performance.
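To make the inner/outer structure concrete, here is a hedged toy version of a MAML-style loop on plain linear regression. It deliberately swaps the memory controller for a single weight matrix so that the whole thing fits in a few lines; the helper names (mse, inner_adapt), the two toy tasks, and the learning rates are all assumptions for illustration, not part of the agent above.

```python
import torch

torch.manual_seed(0)


def mse(w, x, y):
    return ((x @ w - y) ** 2).mean()


def inner_adapt(w, x, y, lr=0.1):
    """One differentiable gradient step on the support set."""
    (grad,) = torch.autograd.grad(mse(w, x, y), w, create_graph=True)
    return w - lr * grad


# Two toy linear-regression tasks, each with a fixed support and query set.
tasks = []
for true_w in ([1.0, -2.0, 0.5], [1.5, -1.0, 0.0]):
    tw = torch.tensor(true_w).unsqueeze(1)  # [3, 1]
    xs, xq = torch.randn(20, 3), torch.randn(20, 3)
    tasks.append((xs, xs @ tw, xq, xq @ tw))

w = torch.zeros(3, 1, requires_grad=True)   # meta-parameters
outer_opt = torch.optim.SGD([w], lr=0.05)

losses = []
for step in range(100):
    outer_opt.zero_grad()
    # Outer objective: query loss measured after adapting on the support set.
    meta_loss = sum(mse(inner_adapt(w, xs, ys), xq, yq)
                    for xs, ys, xq, yq in tasks)
    meta_loss.backward()  # backpropagates through the inner gradient step
    outer_opt.step()
    losses.append(meta_loss.item())

print(f"meta-loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

The key detail is create_graph=True in the inner step: it keeps the gradient computation differentiable so the outer optimizer can update w based on post-adaptation performance.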
Step 5: Wrapping Everything into a Continual Learning Agent
Now we build a small ContinualAgent class that holds:
- One episodic memory bank.
- One memory controller.
- A replay buffer.
- An optimizer (Adam).
class ContinualAgent:
    def __init__(self, input_dim=64, hidden_dim=128):
        self.config = EpisodicMemoryConfig()
        self.memory = EpisodicMemory(self.config)
        self.controller = MemoryController(input_dim, hidden_dim, self.config)
        self.replay = ReplayBuffer()
        self.optimizer = torch.optim.Adam(self.controller.parameters(),
                                          lr=1e-3)

    def train_step(self, x, y, replay_ratio=0.3):
        self.optimizer.zero_grad()
        # 1) Current example
        pred, _ = self.controller(x, self.memory)
        loss = F.mse_loss(pred, y)
        self.replay.push((x.detach(), y.detach()), p=loss.item() + 1e-6)
        # 2) Mix in a few replay samples
        if len(self.replay) >= 16:
            replay_batch = self.replay.sample(batch_size=8)
            for rx, ry in replay_batch:
                rpred, _ = self.controller(rx, self.memory)
                rloss = F.mse_loss(rpred, ry)
                loss = loss + replay_ratio * rloss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.controller.parameters(), 1.0)
        self.optimizer.step()
        # Cut the memory off from this step's autograd graph. Without this,
        # the next backward pass would try to backpropagate through earlier
        # steps and fail.
        self.memory.memory = self.memory.memory.detach()
        self.memory.usage = self.memory.usage.detach()
        return loss.item()

    def evaluate(self, data):
        self.controller.eval()
        total = 0.0
        with torch.no_grad():
            for x, y in data:
                pred, _ = self.controller(x, self.memory)
                total += F.mse_loss(pred, y).item()
        self.controller.train()
        return total / len(data)
The key idea: while training on Task t, every step also replays a few samples from the buffer, which by then holds examples from tasks 1 ... t−1 alongside the current task. The memory bank, meanwhile, is meant to accumulate representations that are useful across tasks, rather than only for the most recent one.
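One thing to watch with a deque-based buffer: train_step pushes one entry per call, so in the demo later in this tutorial (20 epochs over 80 samples) each task adds 1,600 entries. Four tasks add 6,400, which exceeds the default capacity of 5,000, and a deque evicts the oldest entries first, that is, those from the first task. A common alternative is reservoir sampling, which keeps a uniform random subset of everything seen. The sketch below is a hypothetical ReservoirBuffer, not a drop-in replacement (it has no priorities), and it uses a capacity of 500 and bare task labels just to keep the output readable.

```python
import random
from collections import Counter


class ReservoirBuffer:
    """Keeps a uniform random sample of all items pushed so far (Algorithm R)."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.items = []
        self.seen = 0
        self.rng = random.Random(seed)

    def push(self, item):
        self.seen += 1
        if len(self.items) < self.capacity:
            self.items.append(item)
        else:
            # Keep the new item with probability capacity / seen,
            # overwriting a uniformly chosen stored item.
            j = self.rng.randrange(self.seen)
            if j < self.capacity:
                self.items[j] = item

    def sample(self, k):
        return self.rng.sample(self.items, min(k, len(self.items)))


buf = ReservoirBuffer(capacity=500)
for task_id in range(4):
    for _ in range(1600):   # pushes per task, as in the demo loop
        buf.push(task_id)   # store only the task label for illustration

print(Counter(buf.items))
```

With this policy every task should remain represented in roughly equal proportion, whereas a first-in-first-out buffer of the same size would contain only the most recent task.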
Step 6: Creating Synthetic Tasks for Continual Learning
To keep things lightweight, we will generate a few synthetic regression tasks. Each task maps a random input vector to a transformed version using a different nonlinearity.
def make_task(task_id: int, n_samples: int = 100):
    torch.manual_seed(task_id)
    x = torch.randn(n_samples, 64)
    if task_id == 0:
        # Task 1: sine of the mean
        y = torch.sin(x.mean(dim=1, keepdim=True).expand(-1, 64))
    elif task_id == 1:
        # Task 2: scaled cosine
        y = 0.5 * torch.cos(x.mean(dim=1, keepdim=True).expand(-1, 64))
    else:
        # Task 3+: tanh with shifting
        y = torch.tanh(x + task_id * 0.3)
    return [(x[i], y[i]) for i in range(n_samples)]
This setup gives us tasks that are related but not identical, which is exactly the regime where continual learning becomes interesting.
Step 7: Running the Continual Learning Demo
Let us put it all together and watch how the test error evolves as the agent encounters new tasks.
def run_demo(num_tasks: int = 4):
    print("Neural Memory Agent: Continual Learning Demo")
    print("=" * 60)
    agent = ContinualAgent()
    history = {"task": [], "error": []}
    for task_id in range(num_tasks):
        print(f"\nTraining on Task {task_id + 1}/{num_tasks}")
        # make_task seeds the RNG with task_id, so separate calls for train
        # and test data would draw from the same random stream. Generate one
        # set of 110 samples and hold out the last 30 for evaluation.
        train_data = make_task(task_id, n_samples=110)[:80]
        for epoch in range(20):
            total_loss = 0.0
            for x, y in train_data:
                total_loss += agent.train_step(x, y)
            if epoch % 5 == 0:
                print(f"  Epoch {epoch:02d} | Loss = {total_loss/len(train_data):.4f}")
        # Evaluate on the held-out samples of all tasks seen so far
        print("\nEvaluation on all seen tasks:")
        for eval_id in range(task_id + 1):
            eval_data = make_task(eval_id, n_samples=110)[80:]
            err = agent.evaluate(eval_data)
            print(f"  Task {eval_id + 1}: MSE = {err:.4f}")
            if eval_id == task_id:
                history["task"].append(eval_id + 1)
                history["error"].append(err)
    return agent, history
Visualizing Memory and Performance
After training, it is useful to look both at the final memory matrix and the evolution of test error across tasks.
def plot_results(agent, history):
    fig, axes = plt.subplots(1, 2, figsize=(13, 5))
    # Left: memory heatmap
    ax = axes[0]
    mem = agent.memory.memory.detach().cpu().numpy()
    im = ax.imshow(mem, aspect="auto", cmap="viridis")
    ax.set_title("Episodic Memory Bank")
    ax.set_xlabel("Memory Dimension")
    ax.set_ylabel("Memory Slot")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    # Right: error over tasks
    ax = axes[1]
    ax.plot(history["task"], history["error"], marker="o", linewidth=2)
    ax.set_title("Test Error on Newest Task")
    ax.set_xlabel("Task number")
    ax.set_ylabel("MSE")
    ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig("neural_memory_agent_results.png", dpi=150)
    plt.show()
    print("Saved plot to 'neural_memory_agent_results.png'")


if __name__ == "__main__":
    agent, hist = run_demo(num_tasks=4)
    plot_results(agent, hist)
Interpreting the Results
When you run the script, you should see:
- Reasonable training loss curves for each task, even as tasks become more complex.
- Test errors on earlier tasks that stay bounded rather than exploding.
- A non-trivial pattern in the memory heatmap, which may indicate that slots are starting to specialize to different regions of task space.
This is not a production-ready continual learning system, but it illustrates an important point: giving a network the ability to store and retrieve compressed experiences, plus a mechanism to revisit past data, can substantially reduce its tendency to catastrophically forget. To check this on your own run, compare the error printed for each earlier task with the value it had right after that task was learned.
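If you want a single number per task instead of reading the printed log, one simple option is to record the error on every seen task after each training phase and compute how much it grew. The sketch below shows the bookkeeping; the forgetting helper is hypothetical and the error values are invented for illustration, not results from the demo.

```python
def forgetting(errors):
    """errors[i][j] = MSE on task j measured after training on task i (j <= i).

    Returns, for every task except the last, the final error minus the
    error recorded right after that task was learned.
    """
    last = errors[-1]
    return [last[j] - errors[j][j] for j in range(len(errors) - 1)]


# Hypothetical numbers for illustration only.
errors = [
    [0.010],
    [0.030, 0.012],
    [0.045, 0.020, 0.015],
]

for j, f in enumerate(forgetting(errors), start=1):
    print(f"Task {j}: forgetting = {f:+.3f}")
```

Values near zero mean the task was retained; large positive values mean it was overwritten by later training.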
Where to Go Next
Here are a few directions you can explore on top of this tutorial:
- Richer tasks: swap the synthetic regression for RL environments or real datasets.
- Task descriptors: add an explicit “task embedding” input and store it in memory.
- Better meta-learning: implement full MAML or Reptile on top of the controller.
- Memory management: use the usage vector to implement explicit allocation and freeing.
- Visualization tools: project memory slots into 2D (PCA/UMAP) and color by task.
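For the memory-management direction, a minimal sketch of a "least-used" allocation policy could look like the following. It is a hard, non-differentiable write, unlike the soft write used in the tutorial, and the function name write_least_used as well as the decay value are assumptions for illustration.

```python
import torch


def write_least_used(memory, usage, vector, decay=0.99):
    """Overwrite the slot with the lowest usage score (hypothetical policy)."""
    slot = int(torch.argmin(usage))
    memory = memory.clone()   # leave the caller's tensor untouched
    usage = decay * usage     # age all slots a little
    memory[slot] = vector     # store the new content
    usage[slot] = 1.0         # mark the slot as freshly used
    return memory, usage, slot


memory = torch.zeros(4, 3)
usage = torch.tensor([0.9, 0.1, 0.5, 0.7])

memory, usage, slot = write_least_used(memory, usage, torch.ones(3))
print(slot, usage.tolist())

memory, usage, slot2 = write_least_used(memory, usage, 2 * torch.ones(3))
print(slot2, usage.tolist())
```

The first write lands in slot 1 (lowest usage) and the second in slot 2, since slot 1 has just been marked as used. Freeing a slot would simply amount to resetting its usage score to zero.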
If you work in science or engineering, you can imagine plugging this kind of agent into tools that need to adapt over days or weeks, for example, a lab analysis system that gradually learns from each new experiment while still remembering the patterns it saw months ago.
I’ll be experimenting with these ideas inside DataLens.Tools for building adaptive, AI-assisted analysis workflows for neuroscience and biology data.
Final Thoughts
Continual learning sits at the intersection of deep learning, neuroscience, and real-world deployment. By combining differentiable memory, episodic replay, and meta-learning, we can move a little closer to agents that behave less like static models and more like learning systems.
Feel free to adapt this code, plug in your own tasks, and extend the architecture. If you build something interesting with it, I would love to hear about it.