PyTorch is awesome, and it provides a very effective way to execute ML code fast. What it lacks is the surrounding infrastructure to make the general debugging and discovery process better. Other, more official wrapper frameworks sort of don't make sense to me, so this is an attempt at recreating a robust suite of tools that does make sense.
Let's see an example:
from k1lib.imports import *
k1lib.imports is just a file that imports lots of common utilities, so that importing stuff is easier and quicker.
class SkipBlock(nn.Module):
    def __init__(self, hiddenDim=10):
        super().__init__()
        def gen(): return nn.Linear(hiddenDim, hiddenDim), nn.LeakyReLU()
        self.seq = nn.Sequential(*gen(), *gen(), *gen())
    def forward(self, x):
        return self.seq(x) + x
class Network(nn.Module):
    def __init__(self, hiddenDim=10, blocks=3, block=SkipBlock):
        super().__init__()
        layers = [nn.Linear(1, hiddenDim), nn.LeakyReLU()]
        layers += [block(hiddenDim) for _ in range(blocks)]
        layers += [nn.Linear(hiddenDim, 1)]
        self.bulk = nn.Sequential(*layers)
    def forward(self, x):
        return self.bulk(x)
Here is our network. Just a normal feed-forward network, with skip blocks in the middle.
def dataF(bs=32, epochs=200):
    return torch.linspace(-5, 5, 1000) | apply(op().item()) | apply(lambda x: (x, math.exp(x))) | randomize(None) | splitW() |\
        (repeatFrom() | batched(bs) | (transpose() | ((unsqueeze(1) | toTensor()) + toTensor())).all()).all() | stagger.tv(epochs) | toList()
def newL(*args, **kwargs):
    l = k1lib.Learner()
    l.model = Network(*args, **kwargs)
    l.data = dataF(64, 200)
    l.opt = optim.Adam(l.model.parameters(), lr=1e-2)
    l.lossF = lambda x, y: ((x.squeeze() - y)**2).mean()
    l.cbs.add(Cbs.ModifyBatch(lambda x, y: (x[:, None], y)))
    l.cbs.add(Cbs.DType(torch.float32))
    l.cbs.add(Cbs.CancelOnLowLoss(1, epochMode=True))
    l.css = """SkipBlock #0: HookParam
SkipBlock: HookModule"""
    def evaluate(self):
        xbs, ybs, ys = self.Recorder.record(1, 3)
        xbs = torch.vstack(xbs).squeeze()
        ybs = torch.vstack([yb[:, None] for yb in ybs]).squeeze()
        ys = torch.vstack(ys).squeeze()
        plt.plot(xbs, ys.detach(), ".")
    l.evaluate = partial(evaluate, l)
    return l
l = newL()
l.run(10);
Progress: 30%, epoch: 3/10, batch: 0/200, elapsed: 1.49s, loss: 0.018542366102337837 Run cancelled: Low loss 1 ([10.633015524595976, 2.2217107348144056, 0.0817870583734475] actual) achieved!.
Here is where things get a little more interesting. k1lib.Learner is the main wrapper where training takes place. It has 4 basic parameters that must be set before training: the model, data loader, optimizer, and loss function (l.model, l.data, l.opt and l.lossF in newL() above).
Tip: docs are tailored for each object, so you can do print(obj), or just obj, in a code cell to see them.
l.cbs
Callbacks:
- CoreNormal
- Profiler
- ProgressBar
- DontTrainValid
- HookModule
- HookParam
- LossF
- DType
- ModifyBatch
- Recorder
- Loss
- Accuracy
- ParamFinder
- CancelOnExplosion
- CancelOnLowLoss
Use...
- cbs.add(cb[, name]): to add a callback with a name
- cbs("startRun"): to trigger a specific checkpoint, this case "startRun"
- cbs.Loss: to get a specific callback by name, this case "Loss"
- cbs[i]: to get specific callback by index
- cbs.timings: to get callback execution times
- cbs.checkpointGraph(): to graph checkpoint calling orders
- cbs.context(): context manager that will detach all Callbacks attached inside the context
- cbs.suspend("Loss", "Cuda"): context manager to temporarily prevent triggering checkpoints
There are lots of Callbacks. What they are will be discussed later, but here's a tour of a few of them:
l = newL(); l.ParamFinder.plot(samples=1000)[:0.99]
Progress: 0%, epoch: 1/1000, batch: 40/200, elapsed: 0.53s, loss: 2318.115478515625 Run cancelled: Loss increases significantly. Suggested param: 4.100944749601106e-05
Sliceable plot. Can... - p[a:b]: to focus on a specific range of the plot - p.yscale("log"): to perform operation as if you're using plt Reminder: slice range here is actually [0, 1], because it's kinda hard to slice the normal way
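As advertised, this callback searches for a good value of a parameter for the network. It trains while trying out different values, cancels the run once the loss increases significantly, and prints its suggestion (about 4.1e-05 in the output above).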
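The general idea behind this kind of search is the classic "learning-rate finder": try increasingly large values on a geometric scale, and stop as soon as the loss blows up. Below is a plain-Python sketch of that idea on a made-up one-parameter problem. It is not k1lib's ParamFinder implementation: the toy loss, the search range, the 4x blow-up threshold and the rule of suggesting the lowest-loss value are all assumptions, and for simplicity each candidate trains from scratch instead of continuing a single run.

```python
def toy_loss_after_training(lr, steps=20):
    """Gradient descent on the toy loss (w - 3)**2, starting from w = 0."""
    w = 0.0
    for _ in range(steps):
        w -= lr * 2 * (w - 3)  # gradient of (w - 3)**2 is 2 * (w - 3)
    return (w - 3) ** 2

def find_param(low=1e-4, high=10.0, samples=50, blowup=4.0):
    """Try geometrically spaced values; stop once the loss blows up."""
    ratio = (high / low) ** (1 / (samples - 1))
    best_param, best_loss, tried = None, float("inf"), []
    for i in range(samples):
        param = low * ratio ** i
        loss = toy_loss_after_training(param)
        tried.append((param, loss))
        if loss < best_loss:
            best_param, best_loss = param, loss
        elif loss > blowup * best_loss:
            break  # loss increased significantly: cancel the search
    return best_param, tried

suggested, tried = find_param()
print(f"suggested param: {suggested:.3g}, candidates tried: {len(tried)}")
```

On this toy problem the loss shrinks fastest near lr = 0.5, so the search settles just below that and stops shortly after, well before reaching the top of the range.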
l = newL(); l.run(10); l.Loss
Progress: 20%, epoch: 2/10, batch: 0/200, elapsed: 1.05s, loss: 0.33111751079559326 Run cancelled: Low loss 1 ([9.34933760613203, 0.24827303597703576] actual) achieved!.
Callback `Loss`, use...
- cb.train: for all training losses over all epochs and batches (#epochs * #batches)
- cb.valid: for all validation losses over all epochs and batches (#epochs * #batches)
- cb.plot(): to plot the 2 above
- cb.epoch: for average losses of each epochs
- cb.Landscape: for loss-landscape-plotting Callback
- cb.something: to get specific attribute "something" from learner if not available
- cb.withCheckpoint(checkpoint, f): to quickly insert an event handler
- cb.detach(): to remove itself from its parent Callbacks
l.Loss.plot()
Sliceable plot. Can... - p[a:b]: to focus on a specific range of the plot - p.yscale("log"): to perform operation as if you're using plt Reminder: the actual slice you put in is for the training plot. The valid loss's plot will update automatically to be in the same time frame
The returned object is a k1lib.viz.SliceablePlot, so you can zoom into a specific range of the plot, like this:
l.Loss.plot()[120:]
Sliceable plot. Can... - p[a:b]: to focus on a specific range of the plot - p.yscale("log"): to perform operation as if you're using plt Reminder: the actual slice you put in is for the training plot. The valid loss's plot will update automatically to be in the same time frame
Notice how the original train range is [0, 250], and the valid range is [0, 60]. When sliced with [120:], the train plot is sliced as requested, from about the middle to the end, and the valid plot adapts and is also sliced from the middle to the end ([30:]).
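One simple way to keep two plots in step is to map the train slice onto the valid plot proportionally, assuming both cover the same stretch of training. The sketch below uses a hypothetical helper, align_slice, that is not part of k1lib. It gives 28 rather than the 30 quoted above, so k1lib's exact rule evidently differs; treat it only as an illustration of the idea.

```python
def align_slice(start, train_len, valid_len):
    """Hypothetical helper: map a train-plot start index onto the valid plot,
    assuming both plots span the same stretch of training time."""
    return int(start * valid_len / train_len)

print(align_slice(120, 250, 60))  # 28
```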
l.Loss.Landscape.plot()
Progress: 100%, 4s 8/8 Finished [-2.818, 2.818] range Run cancelled: Landscape finished.
Oh and yeah, this callback can give you a quick view of what the loss landscape looks like. The center point (0, 0) is the lowest portion of the landscape here, so that tells us the network has learned stuff.
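Loss-landscape plots are usually made by nudging the parameters along two random directions and evaluating the loss on a grid around them. Here is a framework-free sketch of that technique on a hypothetical 3-parameter quadratic loss whose minimum is placed exactly at the "trained" parameters. The [-2.818, 2.818] extent is borrowed from the output above; the grid size, seed and loss are assumptions, and nothing here is k1lib's actual Landscape code.

```python
import random

TARGET = [1.0, -2.0, 0.5]  # hypothetical minimum of the toy loss

def toy_loss(params):
    # Quadratic bowl with its lowest point exactly at TARGET
    return sum((p - t) ** 2 for p, t in zip(params, TARGET))

def landscape(params, loss, extent=2.818, steps=9, seed=0):
    """Evaluate `loss` on a steps x steps grid spanned by two random directions."""
    rng = random.Random(seed)
    d1 = [rng.gauss(0, 1) for _ in params]  # first random direction
    d2 = [rng.gauss(0, 1) for _ in params]  # second random direction
    coords = [-extent + 2 * extent * i / (steps - 1) for i in range(steps)]
    return [[loss([p + a * u + b * v for p, u, v in zip(params, d1, d2)])
             for b in coords]
            for a in coords]

trained = list(TARGET)  # pretend training landed exactly on the minimum
grid = landscape(trained, toy_loss)
center = grid[4][4]     # offset (0, 0), i.e. the trained parameters themselves
print(center, min(min(row) for row in grid))
```

Because the toy minimum sits at the trained parameters, the center cell is the lowest point of the grid, which is the same pattern the plot above shows.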
l.HookParam
Callback `HookParam`: 6 params, 134 means and stds each:
0. bulk.2.seq.0.weight
1. bulk.2.seq.0.bias
2. bulk.3.seq.0.weight
3. bulk.3.seq.0.bias
4. bulk.4.seq.0.weight
5. bulk.4.seq.0.bias
Use...
- p.plot(): to quickly look at everything
- p[i]: to view a single param
- p[a:b]: to get a new HookParam with selected params
- p.css("..."): to select a specific subset of modules only
- cb.something: to get specific attribute "something" from learner if not available
- cb.withCheckpoint(checkpoint, f): to quickly insert an event handler
- cb.detach(): to remove itself from its parent Callbacks
l.HookParam.plot()
Sliceable plot. Can... - p[a:b]: to focus on a specific range of the plot - p.yscale("log"): to perform operation as if you're using plt
This tracks the parameters' means, stds, mins and maxs while training. You can also display only a subset of the parameters; [::2] keeps every other one, which here means just the three weights:
l.HookParam[::2].plot()[50:]
Sliceable plot. Can... - p[a:b]: to focus on a specific range of the plot - p.yscale("log"): to perform operation as if you're using plt
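Conceptually, HookParam records a small summary of each tracked parameter at every step. The sketch below shows the kind of statistics involved using Python's statistics module. The parameter values and the shrinking "update" are made up, and param_stats is a hypothetical helper, not HookParam's internals.

```python
import statistics

def param_stats(values):
    """Summary a parameter-tracking hook might record at each step (hypothetical helper)."""
    return {
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),
        "min": min(values),
        "max": max(values),
    }

weights = [0.5, -0.2, 0.1, 0.4]  # made-up flattened parameter values
history = []
for _ in range(3):
    weights = [w * 0.9 for w in weights]  # stand-in for an optimizer update
    history.append(param_stats(weights))  # one record per training step
print(history[-1])
```

Plotting each field of history over time gives the same kind of picture as l.HookParam.plot().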
l.HookModule.plot()
Sliceable plot. Can... - p[a:b]: to focus on a specific range of the plot - p.yscale("log"): to perform operation as if you're using plt
Pretty much the same thing as before. This callback hooks into selected modules and captures the forward and backward passes. Both HookParam and HookModule will only hook into selected modules (by default, everything is selected; newL() set l.css, so here only the SkipBlocks and their first Linear layers are):
l.selector
ModuleSelector:
root: Network
    bulk: Sequential
        0: Linear
        1: LeakyReLU
        2: SkipBlock    HookModule
            seq: Sequential
                0: Linear    HookParam
                1: LeakyReLU
                2: Linear
                3: LeakyReLU
                4: Linear
                5: LeakyReLU
        3: SkipBlock    HookModule
            seq: Sequential
                0: Linear    HookParam
                1: LeakyReLU
                2: Linear
                3: LeakyReLU
                4: Linear
                5: LeakyReLU
        4: SkipBlock    HookModule
            seq: Sequential
                0: Linear    HookParam
                1: LeakyReLU
                2: Linear
                3: LeakyReLU
                4: Linear
                5: LeakyReLU
        5: Linear
Can...
- mS.deepestDepth: get deepest depth possible
- mS.nn: get the underlying nn.Module object
- mS.apply(f): apply to self and all descendants
- "HookModule" in mS: whether this module has a specified prop
- mS.highlight(prop): highlights all modules with specified prop
- mS.parse([..., ...]): parses extra css
- mS.directParams: get Dict[str, nn.Parameter] that are directly under this module
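PyTorch itself exposes hooking through nn.Module.register_forward_hook and nn.Module.register_full_backward_hook. To picture what "hooking into a module" means without any framework, here is a sketch that wraps a layer-like function and records its inputs and outputs on every forward call. Hooked and leaky_relu are hypothetical names, and only the forward pass is modeled, since capturing backward passes needs autograd.

```python
class Hooked:
    """Hypothetical wrapper that records every forward call of a layer-like function."""
    def __init__(self, layer):
        self.layer = layer
        self.records = []

    def __call__(self, x):
        y = self.layer(x)
        self.records.append({"input": x, "output": y})  # capture this forward pass
        return y

def leaky_relu(xs, slope=0.01):
    """Element-wise LeakyReLU on a plain list, same default slope as nn.LeakyReLU."""
    return [x if x > 0 else slope * x for x in xs]

layer = Hooked(leaky_relu)
layer([1.0, -2.0])
layer([-1.0, 3.0])
print(len(layer.records), layer.records[0]["output"])
```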
You can select specific modules by setting l.css = ..., kinda like this:
l = newL()
l.css = """
#bulk > Linear: a
#bulk > #1: b
SkipBlock Sequential: c
SkipBlock LeakyReLU
"""
l.selector
ModuleSelector:
root: Network
    bulk: Sequential
        0: Linear    a
        1: LeakyReLU    b
        2: SkipBlock
            seq: Sequential    c
                0: Linear
                1: LeakyReLU    *
                2: Linear
                3: LeakyReLU    *
                4: Linear
                5: LeakyReLU    *
        3: SkipBlock
            seq: Sequential    c
                0: Linear
                1: LeakyReLU    *
                2: Linear
                3: LeakyReLU    *
                4: Linear
                5: LeakyReLU    *
        4: SkipBlock
            seq: Sequential    c
                0: Linear
                1: LeakyReLU    *
                2: Linear
                3: LeakyReLU    *
                4: Linear
                5: LeakyReLU    *
        5: Linear    a
Can...
- mS.deepestDepth: get deepest depth possible
- mS.nn: get the underlying nn.Module object
- mS.apply(f): apply to self and all descendants
- "HookModule" in mS: whether this module has a specified prop
- mS.highlight(prop): highlights all modules with specified prop
- mS.parse([..., ...]): parses extra css
- mS.directParams: get Dict[str, nn.Parameter] that are directly under this module
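Essentially, each line of l.css is a selector, optionally followed by a colon and the props to attach to whatever it matches. Judging from the output above, you can:
- select modules by type, like Linear or SkipBlock
- select modules by name with "#", like #bulk, or #1 for the child named "1"
- select direct children with ">", like #bulk > Linear, which only matched the two Linear layers directly inside bulk
- select descendants with a space, like SkipBlock Sequential, which matched the seq inside every SkipBlock
- leave out the props, like SkipBlock LeakyReLU, in which case the matched modules are marked with "*"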
Different callbacks will recognize certain props. HookModule will hook all modules with props "all" or "HookModule". Likewise, HookParam will hook all parameters with props "all" or "HookParam".
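The matching rule above boils down to a set-membership test. Here is a minimal sketch, assuming each module's props are stored as a set keyed by module name; hooked_by and the selection dictionary are hypothetical (loosely mirroring the css set in newL()), not k1lib's internals.

```python
def hooked_by(selection, callback_name):
    """Names of modules whose props include "all" or the callback's own name."""
    return [name for name, props in selection.items()
            if "all" in props or callback_name in props]

# Hypothetical props per module, loosely mirroring the css set in newL()
selection = {
    "bulk.0": {"all"},             # made-up module carrying the default "all" prop
    "bulk.2": {"HookModule"},
    "bulk.2.seq.0": {"HookParam"},
    "bulk.3": {"HookModule"},
    "bulk.3.seq.0": {"HookParam"},
    "bulk.5": set(),               # no props, so no callback picks it up
}
print(hooked_by(selection, "HookModule"))  # ['bulk.0', 'bulk.2', 'bulk.3']
print(hooked_by(selection, "HookParam"))   # ['bulk.0', 'bulk.2.seq.0', 'bulk.3.seq.0']
```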
l.data
[<k1lib.cli.modifier.StaggeredStream at 0x7f8ab54c6ca0>, <k1lib.cli.modifier.StaggeredStream at 0x7f8ab54fbb50>]
for xb, yb in l.data[0]:
print(xb.shape, yb.shape)
break
torch.Size([32]) torch.Size([32])
It's simple, really! l.data contains a train data loader and a valid data loader, and each "dispenses" a batch as usual.
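For intuition about what the dataF() pipeline produces, here is a plain-Python stand-in: 1000 evenly spaced points on [-5, 5] paired with exp(x), shuffled, split into train and valid sets, and cut into batches of (xs, ys). The 80/20 split ratio is an assumption about splitW()'s default, and the repeating and staggering done by repeatFrom() and stagger.tv() are left out; make_loaders is a hypothetical helper.

```python
import math
import random

def linspace(start, stop, n):
    return [start + (stop - start) * i / (n - 1) for i in range(n)]

def make_loaders(bs=32, train_frac=0.8, seed=0):
    """Plain-Python stand-in for dataF(): returns (train_batches, valid_batches)."""
    pairs = [(x, math.exp(x)) for x in linspace(-5, 5, 1000)]
    random.Random(seed).shuffle(pairs)
    cut = int(len(pairs) * train_frac)  # assumed 80/20 split

    def batches(samples):
        for i in range(0, len(samples), bs):
            xb, yb = zip(*samples[i:i + bs])  # "transpose" (x, y) pairs into xs and ys
            yield list(xb), list(yb)

    return list(batches(pairs[:cut])), list(batches(pairs[cut:]))

train, valid = make_loaders()
for xb, yb in train:
    print(len(xb), len(yb))  # 32 32
    break
```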
Let's look at l again:
l
l.model: Network(
  (bulk): Sequential(
    (0): Linear(in_features=1, out_features=10, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): SkipBlock(
      (seq): Sequential(
        (0): Linear(in_features=10, out_features=10, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): Linear(in_features=10, out_features=10, bias=True)
        (3): LeakyReLU(negative_slope=0.01)
.....
l.opt: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.01
    weight_decay: 0
)
l.cbs: Callbacks:
- CoreNormal
- Profiler
- ProgressBar
- DontTrainValid
- HookModule
- HookParam
- LossF
- DType
- ModifyBatch
.....
Use...
- l.model = ...: to specify a nn.Module object
- l.data = ...: to specify data object
- l.opt = ...: to specify an optimizer
- l.lossF = ...: to specify a loss function
- l.css = ...: to select modules using CSS. "#root" for root model
- l.cbs = ...: to use a custom `Callbacks` object
- l.selector: to get the modules selected by `l.css`
- l.run(epochs): to run the network
- l.Loss: to get a specific callback, this case "Loss"
l.model and l.opt are simple enough: they're just PyTorch primitives. Most of the magic lies in l.cbs, an object of type k1lib.Callbacks, which is a container of k1lib.Callback objects. Notice the final "s" in the name.
A callback is pretty simple. While training, you may want to insert functionality here and there. Let's say you want the program to print out a progress bar as training goes. You can edit the learning loop directly, with some internal variables to keep track of the current epoch and batch, like this:
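import time
def getData(): pass      # placeholder: fetch the next batch
def train(data): pass    # placeholder: run one training step
epochs, batches = 10, 200
startTime = time.time()
for epoch in range(epochs):
    for batch in range(batches):
        # do training
        data = getData()
        train(data)
        # calculate progress
        elapsedTime = time.time() - startTime
        progress = round((batch / batches + epoch) / epochs * 100)
        print(f"\rProgress: {progress}%, elapsed: {round(elapsedTime, 2)}s ", end="")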
But this means that when you don't want that functionality anymore, you have to know which internal variables belong to the progress bar, and delete them. With callbacks, things work a little bit differently:
class ProgressBar(k1lib.Callback):
    def startRun(self):
        pass
    def startBatch(self):
        self.progress = round((self.batch / self.batches + self.epoch) / self.epochs * 100)
        a = f"Progress: {self.progress}%"
        b = f"epoch: {self.epoch}/{self.epochs}"
        c = f"batch: {self.batch}/{self.batches}"
        print(f"{a}, {b}, {c}")
class Learner:
    def run(self):
        self.epochs = 1; self.batches = 10
        self.cbs = k1lib.Callbacks()
        self.cbs.append(ProgressBar())
        self.cbs("startRun")
        for self.epoch in range(self.epochs):
            self.cbs("startEpoch")
            for self.batch in range(self.batches):
                self.xb, self.yb = getData()  # placeholder: fetch the next batch
                self.cbs("startBatch")
                # do training
                self.y = self.model(self.xb); self.cbs("endPass")
                self.loss = self.lossF(self.y, self.yb); self.cbs("endLoss")
                if self.cbs("startBackward"): self.loss.backward()
                self.cbs("endBatch")
            self.cbs("endEpoch")
        self.cbs("endRun")
This is a stripped-down version of k1lib.Learner, to get the idea across. The point is, whenever you do self.cbs("startRun"), it will run through all the k1lib.Callback objects it has (ProgressBar in this example), check whether each one implements startRun, and if so, execute it.
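To make the dispatch rule concrete, here is a tiny framework-free container that calls a checkpoint method only on the callbacks that define it. These Callback/Callbacks classes are a hypothetical sketch, not k1lib's source, and the return value that the stripped-down Learner uses for startBackward is not modeled. The learner state is a stand-in using the values from the first run's output (epoch 3 of 10, batch 0 of 200), which the progress formula turns into the 30% shown there.

```python
class Callback:
    """Base class; subclasses define methods named after checkpoints."""
    learner = None  # set by the container when the callback is added

class Callbacks:
    """Hypothetical minimal container, not k1lib's actual implementation."""
    def __init__(self, learner):
        self.learner, self.cbs = learner, []

    def append(self, cb):
        cb.learner = self.learner  # give the callback access to training state
        self.cbs.append(cb)

    def __call__(self, checkpoint):
        for cb in self.cbs:
            handler = getattr(cb, checkpoint, None)
            if callable(handler):  # skip callbacks that don't implement this checkpoint
                handler()

class ProgressBar(Callback):
    def startBatch(self):
        state = self.learner
        self.progress = round((state.batch / state.batches + state.epoch) / state.epochs * 100)

class ToyLearner:
    """Stand-in for a learner's training state (values from the first run's output)."""
    epoch, epochs, batch, batches = 3, 10, 0, 200

cbs = Callbacks(ToyLearner())
bar = ProgressBar()
cbs.append(bar)
cbs("startRun")      # ProgressBar has no startRun, so nothing runs
cbs("startBatch")    # runs ProgressBar.startBatch
print(bar.progress)  # 30
```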
Inside ProgressBar's startBatch, you can access the learner's current epoch by doing self.learner.epoch. But you can also do self.epoch alone: if the attribute is not defined on the callback, it will automatically be looked up inside self.learner.
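The "look it up on the learner if it's missing" behavior maps naturally onto Python's __getattr__, which is only called after normal attribute lookup fails. Here is a minimal sketch, assuming the callback holds a learner reference; it is a hypothetical illustration, not k1lib's actual code.

```python
class Callback:
    """Hypothetical sketch of the attribute fallback, not k1lib's actual code."""
    def __init__(self, learner):
        self.learner = learner

    def __getattr__(self, name):
        # Python only calls this after normal lookup on the callback has failed
        if name == "learner":
            raise AttributeError(name)  # avoids infinite recursion if learner isn't set yet
        return getattr(self.learner, name)

class ToyLearner:
    def __init__(self):
        self.epoch, self.epochs = 2, 10

class ProgressBar(Callback):
    def describe(self):
        # self.epoch and self.epochs aren't defined here, so they come from self.learner
        return f"epoch: {self.epoch}/{self.epochs}"

bar = ProgressBar(ToyLearner())
print(bar.describe())  # epoch: 2/10
```

Attributes set directly on the callback still win, because __getattr__ is never consulted when normal lookup succeeds.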
As you can see, if you want to get rid of the progress bar without using k1lib.Callbacks, you have to delete the startTime line and the lines that calculate progress. This requires you to remember which lines belong to which functionality. If you use the k1lib.Callbacks mechanism instead, you can just comment out self.cbs.append(ProgressBar()), and that's it. This makes swapping out components extremely easy, repeatable, and beautiful.
Other use cases include intercepting at startBatch to push all the training data to the GPU, or to reshape the data however you want. You can insert different loss mechanisms (endLoss) in addition to lossF, or quickly inspect the model output. You can also change learning rates while training (startEpoch) according to some schedule. The possibilities are literally endless.
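As an example of the last use case, here is a sketch of a step schedule that halves the learning rate every 2 epochs at startEpoch. It writes to param_groups[i]["lr"], which is where torch.optim optimizers store their learning rates (PyTorch also ships ready-made schedulers in torch.optim.lr_scheduler). FakeOptimizer and StepSchedule are hypothetical stand-ins so the snippet runs without PyTorch, and the step size and decay factor are arbitrary choices.

```python
class FakeOptimizer:
    """Stand-in exposing `param_groups` the way torch.optim optimizers do."""
    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]

class StepSchedule:
    """Hypothetical callback: at startEpoch, set lr = base_lr * gamma ** (epoch // step)."""
    def __init__(self, opt, base_lr, step=2, gamma=0.5):
        self.opt, self.base_lr, self.step, self.gamma = opt, base_lr, step, gamma
        self.epoch = 0

    def startEpoch(self):
        lr = self.base_lr * self.gamma ** (self.epoch // self.step)
        for group in self.opt.param_groups:
            group["lr"] = lr  # the field a torch.optim optimizer reads on its next step()
        self.epoch += 1

opt = FakeOptimizer(lr=1e-2)
schedule = StepSchedule(opt, base_lr=1e-2)
lrs = []
for _ in range(6):
    schedule.startEpoch()  # what a Callbacks container would trigger each epoch
    lrs.append(opt.param_groups[0]["lr"])
print(lrs)  # [0.01, 0.01, 0.005, 0.005, 0.0025, 0.0025]
```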