k1lib.knn module

Some nice utils to complement torch.nn. This module is exposed automatically with:

from k1lib.imports import *
knn.Lambda # exposed
class k1lib.knn.Lambda(f: Callable[[Any], Any])

Bases: Module

__init__(f: Callable[[Any], Any])

Creates a simple module with a specified forward function.
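A minimal sketch of what such a wrapper might look like, assuming it simply stores f and calls it in forward. LambdaSketch is a hypothetical name used for illustration, not the k1lib class itself:

```python
import torch
from torch import nn
from typing import Any, Callable

class LambdaSketch(nn.Module):
    """Hypothetical stand-in for knn.Lambda: wraps an arbitrary callable as a module."""
    def __init__(self, f: Callable[[Any], Any]):
        super().__init__()
        self.f = f  # the function to run in forward()

    def forward(self, x):
        return self.f(x)

# Usage: drop a plain function into nn.Sequential, here scaling activations between layers.
model = nn.Sequential(
    nn.Linear(4, 6),
    LambdaSketch(lambda x: x * 2),
    nn.Linear(6, 3),
)
out = model(torch.randn(5, 4))
print(out.shape)  # torch.Size([5, 3])
```

Wrapping a function this way lets one-off tensor operations live inside nn.Sequential without writing a dedicated Module subclass for each.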

forward(x)

Applies the wrapped function f to x and returns the result.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
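To illustrate the note with plain torch.nn (nothing k1lib-specific here): a forward hook fires when the module instance is called, but not when forward is called directly.

```python
import torch
from torch import nn

calls = []
layer = nn.Linear(3, 2)
# Record every time a forward hook fires.
layer.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.randn(4, 3)
layer(x)          # __call__ runs forward() and the registered hooks
layer.forward(x)  # calling forward() directly bypasses the hooks
print(len(calls))  # 1
```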

training: bool
class k1lib.knn.Identity

Bases: Lambda

Creates a module whose forward function returns its input unchanged.
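Because the forward pass just hands back its argument, the behavior matches torch.nn.Identity. A minimal sketch, with IdentitySketch as a hypothetical stand-in rather than the k1lib class:

```python
import torch
from torch import nn

class IdentitySketch(nn.Module):
    """Hypothetical equivalent of knn.Identity: forward returns its input unchanged."""
    def forward(self, x):
        return x

x = torch.randn(2, 3)
ident = IdentitySketch()
print(ident(x) is x)  # True: the very same tensor object comes back
# torch.nn.Identity behaves the same way and is a drop-in alternative.
print(nn.Identity()(x) is x)  # True
```

A typical use is as a placeholder, e.g. swapping out a layer of a pretrained network while keeping the surrounding structure intact.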

training: bool
class k1lib.knn.LinBlock(inC, outC)

Bases: Module

__init__(inC, outC)

Linear layer (inC input features, outC output features) followed by a ReLU activation.
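A minimal sketch of an equivalent block, assuming it is just nn.Linear(inC, outC) followed by nn.ReLU(); LinBlockSketch and its attribute names are hypothetical, so the real module's internals may differ:

```python
import torch
from torch import nn

class LinBlockSketch(nn.Module):
    """Hypothetical re-creation of knn.LinBlock: nn.Linear followed by ReLU."""
    def __init__(self, inC, outC):
        super().__init__()
        self.lin = nn.Linear(inC, outC)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.lin(x))

# Stack blocks into a small MLP; the widths here are arbitrary.
mlp = nn.Sequential(LinBlockSketch(10, 32), LinBlockSketch(32, 16), nn.Linear(16, 2))
y = mlp(torch.randn(8, 10))
print(y.shape)  # torch.Size([8, 2])
```

The last layer is a bare nn.Linear so the network's outputs are not clipped at zero by a final ReLU.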

forward(x)

Applies the linear layer to x, then ReLU, and returns the result.

training: bool
class k1lib.knn.MultiheadAttention(qdim, kdim, vdim, embed, head=4, outdim=None)

Bases: Module

__init__(qdim, kdim, vdim, embed, head=4, outdim=None)

Kinda like torch.nn.MultiheadAttention, just simpler, shorter, and clearer. Probably not as fast as the official version, and doesn’t have masks and whatnot, but easy to read! Example:

import torch
from k1lib.knn import MultiheadAttention

xb = torch.randn(14, 32, 35) # (S, N, E): sequence length 14, batch size 32, feature size 35
# returns torch.Size([14, 32, 50])
MultiheadAttention(35, 35, 35, 9, 4, 50)(xb).shape

Although you can use this right away without modification, I really encourage you to copy the source code and adapt it to your needs.

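Parameters
  • qdim, kdim, vdim – Basic query, key and value dimensions

  • embed – embedding dimension of each head. This is a little different from torch.nn.MultiheadAttention, whose embed_dim is the total before splitting into heads; here it is the size after splitting

  • head – number of attention heads

  • outdim – if not specified, defaults to embed * head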
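A minimal sketch of how these parameters might fit together, assuming standard scaled dot-product attention and inputs shaped (S, N, feature) as in the example above. MHASketch is a hypothetical name; this is an illustration of the parameter conventions, not the k1lib source:

```python
import math
import torch
from torch import nn

class MHASketch(nn.Module):
    """Hypothetical, simplified multi-head attention following the parameter
    conventions above (embed is per head, outdim defaults to embed * head)."""
    def __init__(self, qdim, kdim, vdim, embed, head=4, outdim=None):
        super().__init__()
        self.embed, self.head = embed, head
        inner = embed * head                            # total width across all heads
        outdim = inner if outdim is None else outdim    # default: embed * head
        self.q = nn.Linear(qdim, inner)
        self.k = nn.Linear(kdim, inner)
        self.v = nn.Linear(vdim, inner)
        self.out = nn.Linear(inner, outdim)

    def forward(self, query, key=None, value=None):
        key = query if key is None else key
        value = query if value is None else value
        S, N, _ = query.shape
        T = key.shape[0]
        # Project, then split into heads: (seq, N, head, embed) -> (N, head, seq, embed)
        q = self.q(query).view(S, N, self.head, self.embed).permute(1, 2, 0, 3)
        k = self.k(key).view(T, N, self.head, self.embed).permute(1, 2, 0, 3)
        v = self.v(value).view(T, N, self.head, self.embed).permute(1, 2, 0, 3)
        # Scaled dot-product attention per head, softmax over key positions: (N, head, S, T)
        att = (q @ k.transpose(-1, -2) / math.sqrt(self.embed)).softmax(-1)
        ctx = att @ v  # (N, head, S, embed)
        # Merge heads back: (N, head, S, embed) -> (S, N, head * embed)
        ctx = ctx.permute(2, 0, 1, 3).reshape(S, N, self.head * self.embed)
        return self.out(ctx)

xb = torch.randn(14, 32, 35)
print(MHASketch(35, 35, 35, 9, 4, 50)(xb).shape)  # torch.Size([14, 32, 50])
print(MHASketch(35, 35, 35, 9, 4)(xb).shape)      # torch.Size([14, 32, 36]), since 9 * 4 = 36
# Cross-attention: queries of width 35 attend over a length-20 memory of width 16.
mem = torch.randn(20, 32, 16)
print(MHASketch(35, 16, 16, 9)(xb, mem, mem).shape)  # torch.Size([14, 32, 36])
```

Keeping embed per head means the total projected width grows with head, which is why outdim falls back to embed * head when left unspecified.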

forward(query, key=None, value=None)

If key or value is not specified, it defaults to query.
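For comparison, torch.nn.MultiheadAttention requires key and value explicitly. A hypothetical wrapper (DefaultingMHA is an illustrative name, not part of k1lib or PyTorch) can reproduce the same defaulting convention around the official class:

```python
import torch
from torch import nn

class DefaultingMHA(nn.Module):
    """Hypothetical wrapper: gives torch.nn.MultiheadAttention the same
    'key/value default to query' convention described above."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, query, key=None, value=None):
        key = query if key is None else key
        value = query if value is None else value
        out, _weights = self.attn(query, key, value)  # inputs are (S, N, E) by default
        return out

m = DefaultingMHA(36, 4)
xb = torch.randn(14, 32, 36)
print(m(xb).shape)  # torch.Size([14, 32, 36])
```

Note that here embed_dim (36) is the total across all 4 heads, in contrast to the per-head embed used by knn.MultiheadAttention.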

training: bool