k1lib.selector module

This module is for selecting parts of a network (subnetworks) using CSS selectors, so that you can do special things to them, such as hooking, freezing or cutting them off. Check out the tutorial section for a walkthrough. This is exposed automatically with:

from k1lib.imports import *
selector.select # exposed

class k1lib.selector.ModuleSelector(parent: ModuleSelector, name: str, nn: Module)

Bases: object

nn: torch.nn.Module

The associated torch.nn.Module of this ModuleSelector

props: List[str]

Properties of this ModuleSelector

idx: int

Unique id of this ModuleSelector within the entire script. May be useful for identifying modules.

property deepestDepth

Depth of the deepest branch of the tree. If self doesn’t have any children, the depth is 0.
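To make the depth convention concrete, here is a minimal pure-Python sketch that computes the same quantity on a stand-in tree. The Node class and deepest_depth function are hypothetical and not part of k1lib; the sketch only assumes what the description states, namely that a module with no children has depth 0.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for a ModuleSelector: a name plus child nodes."""
    name: str
    children: List["Node"] = field(default_factory=list)


def deepest_depth(node: Node) -> int:
    """Length of the longest path from `node` down to a leaf (a leaf has depth 0)."""
    if not node.children:
        return 0
    return 1 + max(deepest_depth(child) for child in node.children)


# Shaped roughly like nn.Sequential(nn.Linear(3, 4), nn.Sequential(nn.ReLU()))
tree = Node("root", [Node("0"), Node("1", [Node("0")])])
print(deepest_depth(Node("leaf")))  # 0
print(deepest_depth(tree))          # 2
```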

highlight(prop: str)

Highlights the specified prop when displaying the object.

__call__(*args, **kwargs)

Calls the underlying torch.nn.Module, passing all arguments through.

__contains__(prop: Optional[str] = None) → bool

Whether this ModuleSelector has a specific prop. Example:

# returns True
"b" in nn.Linear(3, 4).select("*:b")
# returns False
"h" in nn.Linear(3, 4).select("*:b")
# returns True, "*" here means the ModuleSelector has any properties at all
"*" in nn.Linear(3, 4).select("*:b")

named_children(prop: Optional[str] = None) → Iterator[Tuple[str, ModuleSelector]]

Get all direct children, together with their names.

Parameters

prop – Filter property. See also: __contains__()

children(prop: Optional[str] = None) → Iterator[ModuleSelector]

Get all direct children.

Parameters

prop – Filter property. See also: __contains__()

named_modules(prop: Optional[str] = None) → Iterator[Tuple[str, ModuleSelector]]

Get all modules in the subtree recursively, including self, together with their names. Example:

modules = list(nn.Sequential(nn.Linear(3, 4), nn.ReLU()).select().named_modules())
# returns 3
len(modules)
# returns the tuple ('0', <ModuleSelector of Linear>)
modules[1]

Parameters

prop – Filter property. See also: __contains__()
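The recursive traversal behind named_modules() can be sketched in a few lines. This is a hypothetical pure-Python stand-in, not k1lib's code. It assumes the naming convention of torch.nn.Module.named_modules() (the root comes first under the empty name, nested children get dot-separated names), which matches the example above. For the prop filter it assumes that prop=None means "no filtering" and models "*" as "has any prop at all", following the __contains__() examples.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Tuple


@dataclass
class Sel:
    """Hypothetical stand-in for a ModuleSelector."""
    name: str
    props: List[str] = field(default_factory=list)
    children: List["Sel"] = field(default_factory=list)

    def has(self, prop: Optional[str]) -> bool:
        if prop is None:
            return True  # assumed: no filter, keep everything
        if prop == "*":
            return len(self.props) > 0  # "*" = has any prop at all
        return prop in self.props

    def named_modules(self, prop: Optional[str] = None,
                      prefix: str = "") -> Iterator[Tuple[str, "Sel"]]:
        # Yield self first (if it passes the filter), then recurse depth-first
        if self.has(prop):
            yield prefix, self
        for child in self.children:
            childPrefix = f"{prefix}.{child.name}" if prefix else child.name
            yield from child.named_modules(prop, childPrefix)


# Shaped like nn.Sequential(nn.Linear(3, 4), nn.ReLU()), with prop "b" on the Linear
root = Sel("", children=[Sel("0", props=["b"]), Sel("1")])
print([name for name, _ in root.named_modules()])     # ['', '0', '1']
print([name for name, _ in root.named_modules("b")])  # ['0']
```

Note that the filter only decides what is yielded; the traversal still descends into children of modules that do not match.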

modules(prop: Optional[str] = None) → Iterator[ModuleSelector]

Get all modules in the subtree recursively, like named_modules() but without the names.

Parameters

prop – Filter property. See also: __contains__()

property directParams: Dict[str, Parameter]

Dict of the parameters registered directly on this module (parameters of child modules are not included).

parse(selectors: Union[List[str], str]) → ModuleSelector

Parses extra selectors. Clears all old selectors, but retains the props. Returns self. Example:

mS = selector.ModuleSelector.sample().parse("Conv2d:propA")
# returns True
"propA" in mS[1][0]

Parameters

selectors – either the preprocessed list or the raw CSS string
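As a rough illustration of what applying a statement such as "Conv2d:propA" does, the sketch below applies already-preprocessed "selector:props" lines to a flat list of modules. It is hypothetical and deliberately simplified: matches() and apply_lines() are not k1lib functions, only the three simple selector forms used on this page are understood ("*", a class name such as Conv2d, and "#name"), combinators such as ">" are ignored, and new props are added on top of existing ones, mirroring the statement that props are retained.

```python
from typing import Dict, List, Tuple


def matches(selector: str, name: str, clsName: str) -> bool:
    """Hypothetical matcher for the simple selectors "*", "ClassName" and "#name"."""
    if selector == "*":
        return True  # matches every module
    if selector.startswith("#"):
        return selector[1:] == name  # "#lin1" matches the module named lin1
    return selector == clsName  # "Conv2d" matches every Conv2d module


def apply_lines(lines: List[str], modules: List[Tuple[str, str]],
                props: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Add props from preprocessed "selector:p1,p2" lines, keeping existing props."""
    for line in lines:
        selector, propStr = line.split(":", 1)
        for name, clsName in modules:
            if matches(selector, name, clsName):
                current = props.setdefault(name, [])
                for p in propStr.split(","):
                    if p not in current:
                        current.append(p)
    return props


# (name, class name) pairs standing in for a small model
modules = [("root", "Sequential"), ("0", "Conv2d"), ("1", "ReLU")]
print(apply_lines(["Conv2d:propA"], modules, {"1": ["old"]}))
# {'1': ['old'], '0': ['propA']}
```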

apply(f: Callable[[ModuleSelector], None])

Applies a function to self and all descendant ModuleSelectors.
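apply() is a natural building block for whole-subtree operations. For example, clearing props everywhere below a node, as clearProps() does, amounts to applying a "reset props" function to every node. The sketch below shows that relationship on a hypothetical stand-in class; it is not k1lib's source. Sel only mirrors the documented behavior (apply visits self and every descendant), and clearProps returning self is an assumption based on its return type and on the chained usage in its example.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Sel:
    """Hypothetical stand-in for a ModuleSelector: props plus child selectors."""
    props: List[str] = field(default_factory=list)
    children: List["Sel"] = field(default_factory=list)

    def apply(self, f: Callable[["Sel"], None]) -> None:
        f(self)  # visit self first...
        for child in self.children:
            child.apply(f)  # ...then every descendant, depth-first

    def clearProps(self) -> "Sel":
        def reset(s: "Sel") -> None:
            s.props = []
        self.apply(reset)
        return self  # assumed to return self, so calls can be chained


root = Sel(["b"], [Sel(["b"]), Sel(["b", "c"], [Sel(["d"])])])
print("b" in root.clearProps().props)  # False
```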

clearProps() → ModuleSelector

Clears all existing props of this ModuleSelector and all of its descendants. Example:

# returns False
"b" in nn.Linear(3, 4).select("*:b").clearProps()

property displayF

Function to display each ModuleSelector’s lines. Default is just:

lambda mS: ", ".join(mS.props) 

static sample() → ModuleSelector

Creates a new example ModuleSelector that has a bit of hierarchy, with no CSS applied.

hookF(f: Callable[[ModuleSelector, Module, Tuple[Tensor], Tensor], None] = None, prop: str = '*')

Context manager for applying forward hooks. Example:

def f(mS, i, o):
    print(i, o)

m = nn.Linear(3, 4)
with m.select().hookF(f):
    m(torch.randn(2, 3))

Parameters
  • f – hook callback, should accept the ModuleSelector, inputs and output. If not specified, it prints out the input and output tensor shapes.

  • prop – filter property of the modules to hook onto.
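A hook context manager like this one typically registers a hook on each selected module when the with block starts and removes every handle when it ends, so the hooks never outlive the block. The sketch below shows that pattern in pure Python. FakeModule, Handle and hook_forward are hypothetical stand-ins that imitate the shape of PyTorch's torch.nn.Module.register_forward_hook (which calls hook(module, inputs, output) and returns a handle with a remove() method); they are not k1lib's implementation. Also note that k1lib passes the ModuleSelector to your callback, while this sketch passes the module itself for simplicity.

```python
from contextlib import contextmanager
from typing import Callable, List


class Handle:
    """Removes one registered hook, like torch's RemovableHandle."""
    def __init__(self, hooks: List[Callable], hook: Callable):
        self.hooks, self.hook = hooks, hook

    def remove(self) -> None:
        self.hooks.remove(self.hook)


class FakeModule:
    """Hypothetical module: doubles its input, then calls its forward hooks."""
    def __init__(self):
        self.hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> Handle:
        self.hooks.append(hook)
        return Handle(self.hooks, hook)

    def __call__(self, x):
        out = x * 2
        for hook in list(self.hooks):
            hook(self, (x,), out)  # (module, inputs, output), as in torch
        return out


@contextmanager
def hook_forward(modules: List[FakeModule], f: Callable):
    handles = [m.register_forward_hook(f) for m in modules]
    try:
        yield
    finally:
        for h in handles:  # always detach, even if the body raises
            h.remove()


seen = []
m = FakeModule()
with hook_forward([m], lambda mod, i, o: seen.append((i, o))):
    m(3)
m(5)  # runs after the block, so the hook no longer fires
print(seen)  # [((3,), 6)]
```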

hookFp(f=None, prop: str = '*')

Context manager for applying forward pre hooks. Example:

def f(mS, i):
    print(i)

m = nn.Linear(3, 4)
with m.select().hookFp(f):
    m(torch.randn(2, 3))

Parameters
  • f – hook callback, should accept the ModuleSelector and inputs. If not specified, it prints out the input tensor shapes.

  • prop – filter property of the modules to hook onto.

hookB(f=None, prop: str = '*')

Context manager for applying backward hooks. Example:

def f(mS, i, o):
    print(i, o)

m = nn.Linear(3, 4)
with m.select().hookB(f):
    m(torch.randn(2, 3)).sum().backward()

Parameters
  • f – hook callback, should accept the ModuleSelector, grad inputs and grad outputs. If not specified, it prints out the input tensor shapes.

  • prop – filter property of the modules to hook onto.

freeze(prop: str = '*')

Returns a context manager that freezes parts of the network (sets requires_grad to False). Example:

l = k1lib.Learner.sample()
w = l.model.lin1.lin.weight.clone() # weights before
with l.model.select("#lin1").freeze():
    l.run(1)
# returns True
(l.model.lin1.lin.weight ==  w).all()

unfreeze(prop: str = '*')

Returns a context manager that unfreezes parts of the network (sets requires_grad to True). Example:

l = k1lib.Learner.sample()
w = l.model.lin1.lin.weight.clone() # weights before
with l.model.select("#lin1").freeze():
    with l.model.select("#lin1 > #lin").unfreeze():
        l.run(1)
# returns False
(l.model.lin1.lin.weight ==  w).all()
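Both context managers can be thought of as a save, set, restore pattern: remember each parameter's current requires_grad flag, overwrite it, and put the saved value back on exit. Restoring the saved flags, rather than resetting them to a fixed value, is what lets an unfreeze() nest inside a freeze() as in the example above. The sketch below is a hypothetical pure-Python illustration of that pattern and not k1lib's source: Param stands in for torch.nn.Parameter, and set_requires_grad is not a k1lib function.

```python
from contextlib import contextmanager
from typing import List


class Param:
    """Hypothetical stand-in for torch.nn.Parameter."""
    def __init__(self, requires_grad: bool = True):
        self.requires_grad = requires_grad


@contextmanager
def set_requires_grad(params: List[Param], value: bool):
    saved = [p.requires_grad for p in params]  # remember the original flags
    for p in params:
        p.requires_grad = value
    try:
        yield
    finally:
        for p, old in zip(params, saved):  # restore exactly what was there
            p.requires_grad = old


outer = [Param(), Param()]  # parameters matched by a broad selection
inner = outer[:1]           # a subset matched by a narrower selection
with set_requires_grad(outer, False):          # like freeze()
    with set_requires_grad(inner, True):       # like unfreeze()
        print([p.requires_grad for p in outer])  # [True, False]
    print([p.requires_grad for p in outer])      # [False, False]
print([p.requires_grad for p in outer])          # [True, True]
```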

cutOff() → Module

Creates a new network that returns the selected layer’s output. Example:

xb = torch.randn(10, 2)
m = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 4), nn.Linear(4, 6))
m0 = m.select("#0").cutOff(); m1 = m.select("#1").cutOff()
# returns (10, 6)
m(xb).shape
# returns (10, 5)
m0(xb).shape
# returns (10, 4)
m1(xb).shape
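For a purely sequential model, the network produced by cutOff() behaves like running only the layers up to and including the selected one. (For nn.Sequential specifically, slicing such as m[:2] also returns a truncated nn.Sequential.) The hypothetical sketch below shows the idea with plain Python callables standing in for the three nn.Linear layers above, where each "layer" just produces a vector of its output width. cut_off and resize are illustrative names, not k1lib functions, and the sketch makes no attempt to cover non-sequential models.

```python
from typing import Any, Callable, List

Layer = Callable[[Any], Any]


def resize(width: int) -> Layer:
    """Hypothetical layer: maps any input to a zero vector with `width` entries."""
    return lambda x: [0.0] * width


def cut_off(layers: List[Layer], idx: int) -> Layer:
    """Return a function that runs layers[0..idx] and returns layer idx's output."""
    def run(x: Any) -> Any:
        for layer in layers[: idx + 1]:
            x = layer(x)
        return x
    return run


layers = [resize(5), resize(4), resize(6)]  # output widths of the Linear layers above
x = [1.0, 2.0]                              # one sample with 2 features
print(len(cut_off(layers, 2)(x)))  # 6, the full network
print(len(cut_off(layers, 0)(x)))  # 5, like m.select("#0").cutOff()
print(len(cut_off(layers, 1)(x)))  # 4, like m.select("#1").cutOff()
```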

intercept()

Returns a context manager that intercepts forward and backward signals to parts of the network. Example:

l = k1lib.Learner.sample()
with l.model.select("#lin1").intercept() as d:
    l.run(2)
# returns (1, 2, 600, 2, 1, 32, 1), or (#selected modules, [forward, backward], #steps, [input, output], actual data)
d | shape()

See also: hookF(), hookFp(), hookB()

k1lib.selector.preprocess(selectors: str, defaultProp='*') → List[str]

Removes all the quirky features allowed by the CSS language and outputs clean, normalized lines. Example:

# returns ["a:f", "a:g,h", "b:g,h", "t:*"]
selector.preprocess("a:f; a, b: g,h; t")

Parameters
  • selectors – a single CSS selector string, with statements separated by “\n” or “;”

  • defaultProp – default property, used when a statement doesn’t specify one
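The normalization that preprocess() performs on the example above can be approximated in a few lines of pure Python. This is a hypothetical sketch (preprocess_sketch is not part of k1lib) that only handles what the example exercises: statements are split on newlines or semicolons, comma-separated selector groups are expanded into one line each, whitespace around selectors and props is stripped, and defaultProp is filled in when a statement has no props. Other CSS quirks are out of scope.

```python
from typing import List


def preprocess_sketch(selectors: str, defaultProp: str = "*") -> List[str]:
    """Hypothetical normalizer: one "selector:props" line per selector."""
    lines: List[str] = []
    # Newlines and semicolons both separate statements
    for stmt in selectors.replace("\n", ";").split(";"):
        stmt = stmt.strip()
        if not stmt:
            continue  # skip empty statements, e.g. from a trailing ";"
        if ":" in stmt:
            selPart, propPart = stmt.split(":", 1)
        else:
            selPart, propPart = stmt, defaultProp
        # Normalize " g, h" to "g,h"; fall back to defaultProp if nothing is left
        props = ",".join(p.strip() for p in propPart.split(",") if p.strip()) or defaultProp
        # "a, b: ..." expands into one line per selector
        for sel in selPart.split(","):
            sel = sel.strip()
            if sel:
                lines.append(f"{sel}:{props}")
    return lines


print(preprocess_sketch("a:f; a, b: g,h; t"))
# ['a:f', 'a:g,h', 'b:g,h', 't:*']
```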

k1lib.selector.select(model: Module, css: str = '*') → ModuleSelector

Creates a new ModuleSelector, in sync with a model. Example:

mS = selector.select(nn.Linear(3, 4), "#root:propA")

Or, you can do it the more direct way:

mS = nn.Linear(3, 4).select("#root:propA")

Parameters
  • model – the torch.nn.Module object to select from

  • css – the CSS selectors