k1lib.knn module
Some nice utils to complement torch.nn. This is exposed automatically with:
from k1lib.imports import *
knn.Lambda # exposed
- class k1lib.knn.Lambda(f: Callable[[Any], Any])[source]
Bases: Module
- __init__(f: Callable[[Any], Any])[source]
Creates a simple module with a specified forward function (see the usage sketch after this class).
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
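For instance, here is a minimal sketch (not from the library's own docs, assuming only that Lambda applies f in its forward pass) of turning a plain function into a layer:
import torch, torch.nn as nn
from k1lib.imports import *
# wrap an arbitrary function as a module so it can live inside a Sequential
net = nn.Sequential(nn.Linear(10, 20), knn.Lambda(lambda x: x * 2))
net(torch.randn(3, 10)).shape # torch.Size([3, 20]), values doubled by the Lambda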
- class k1lib.knn.Identity[source]
Bases: Lambda
Creates a module that simply returns its input in the forward function.
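A quick sketch of that behavior (the tensor shape here is illustrative only):
import torch
from k1lib.imports import *
x = torch.randn(3, 5)
(knn.Identity()(x) == x).all() # tensor(True): the input passes through unchanged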
- class k1lib.knn.LinBlock(inC, outC)[source]
Bases: Module
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
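The constructor arguments suggest a block mapping inC input features to outC output features; a hedged usage sketch (the exact internals, such as any activation, are in the linked source and not assumed here):
import torch
from k1lib.imports import *
block = knn.LinBlock(35, 50) # 35 input features in, 50 output features out
block(torch.randn(32, 35)).shape # expected torch.Size([32, 50]) for a batch of 32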
- class k1lib.knn.MultiheadAttention(qdim, kdim, vdim, embed, head=4, outdim=None)[source]
Bases: Module
- __init__(qdim, kdim, vdim, embed, head=4, outdim=None)[source]
Kinda like torch.nn.MultiheadAttention, just simpler, shorter, and clearer. Probably not as fast as the official version, and doesn’t have masks and whatnot, but easy to read! Example:
xb = torch.randn(14, 32, 35) # (S, N, C), or sequence size 14, batch size 32, feature size 35
# returns torch.Size([14, 32, 50])
MultiheadAttention(35, 35, 35, 9, 4, 50)(xb).shape
Although you can use this right away with no mods, I really encourage you to copy and paste the source code of this and modify it to your needs.
- Parameters
  - qdim – Basic query, key and value dimensions
  - embed – a little different from torch.nn.MultiheadAttention, as this is the embedding size after splitting into heads
  - outdim – if not specified, defaults to embed * head
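As a follow-up sketch (assuming the same single-tensor call pattern as the example above), leaving outdim unspecified should give an output feature size of embed * head:
import torch
from k1lib.imports import *
xb = torch.randn(14, 32, 35)
# outdim omitted, so per the parameter docs it defaults to embed * head = 9 * 4 = 36
knn.MultiheadAttention(35, 35, 35, 9, 4)(xb).shape # expected torch.Size([14, 32, 36])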