k1lib.selector module
This module lets you select parts of a network (subnetworks) using CSS selectors, so that you can do special things to them. Check out the tutorial section for a walkthrough. The module is exposed automatically with:
from k1lib.imports import *
selector.select # exposed
- class k1lib.selector.ModuleSelector(parent: ModuleSelector, name: str, nn: Module)
Bases: object
- nn: torch.nn.Module
The associated torch.nn.Module of this ModuleSelector.
- props: List[str]
Properties of this ModuleSelector.
- idx: int
Unique id of this ModuleSelector in the entire script. May be useful for module recognition.
- property deepestDepth
Depth of the deepest level of the tree below this ModuleSelector. If it has no children, the depth is 0.
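To make the depth convention concrete, here is a minimal sketch on a hypothetical stand-in tree. The `Node` class and `deepest_depth` function are illustrative and not part of k1lib. A node with no children has depth 0, and each level of nesting adds 1.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for a ModuleSelector: a name plus children."""
    name: str
    children: List["Node"] = field(default_factory=list)


def deepest_depth(node: Node) -> int:
    # A leaf has depth 0; otherwise it is 1 + the deepest child's depth
    if not node.children:
        return 0
    return 1 + max(deepest_depth(c) for c in node.children)


# root -> lin1 -> lin, plus a leaf "relu" directly under root
tree = Node("root", [Node("lin1", [Node("lin")]), Node("relu")])
print(deepest_depth(tree))  # 2
```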
- __call__(*args, **kwargs)
Calls the internal torch.nn.Module.
- __contains__(prop: Optional[str] = None) bool
Whether this ModuleSelector has a specific prop. Example:
# returns True
"b" in nn.Linear(3, 4).select("*:b")
# returns False
"h" in nn.Linear(3, 4).select("*:b")
# returns True, "*" here means the ModuleSelector has any properties at all
"*" in nn.Linear(3, 4).select("*:b")
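The membership rule above can be sketched as a plain function over a list of props. This is not k1lib's code. `has_prop` is a hypothetical name, and the assumption that `"*"` is false for an empty prop list follows from the phrase "has any properties at all".

```python
from typing import List


def has_prop(props: List[str], prop: str) -> bool:
    """Sketch of the rule described above (hypothetical, not k1lib's source).

    "*" asks whether there are any props at all; any other string asks
    whether that exact prop is present.
    """
    if prop == "*":
        return len(props) > 0
    return prop in props


props = ["b"]  # assumed to be what select("*:b") attaches, per the example above
print(has_prop(props, "b"))  # True
print(has_prop(props, "h"))  # False
print(has_prop(props, "*"))  # True
print(has_prop([], "*"))     # False
```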
- named_children(prop: Optional[str] = None) Iterator[Tuple[str, ModuleSelector]]
Gets all named direct children.
- Parameters
prop – Filter property. See also: __contains__()
- children(prop: Optional[str] = None) Iterator[ModuleSelector]
Gets all direct children.
- Parameters
prop – Filter property. See also: __contains__()
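A minimal sketch of direct-child iteration with a prop filter, on a hypothetical `Node` class (not k1lib's). It assumes that `prop=None` means "no filter" and that `"*"` keeps only children with at least one prop, mirroring `__contains__()`. Grandchildren are never yielded.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for a ModuleSelector (not k1lib's class)."""
    name: str
    props: List[str] = field(default_factory=list)
    kids: List["Node"] = field(default_factory=list)

    def has(self, prop: str) -> bool:
        # "*" means "has any props at all", mirroring __contains__()
        return len(self.props) > 0 if prop == "*" else prop in self.props

    def named_children(self, prop: Optional[str] = None) -> Iterator[Tuple[str, "Node"]]:
        # Direct children only; prop=None is assumed to mean "no filter"
        for kid in self.kids:
            if prop is None or kid.has(prop):
                yield kid.name, kid

    def children(self, prop: Optional[str] = None) -> Iterator["Node"]:
        # Same filter as named_children(), with the names dropped
        return (kid for _, kid in self.named_children(prop))


root = Node("root", kids=[Node("0", ["a"]), Node("1"), Node("2", ["a", "b"], [Node("x", ["a"])])])
print([name for name, _ in root.named_children()])     # ['0', '1', '2']
print([name for name, _ in root.named_children("a")])  # ['0', '2']
print([kid.name for kid in root.children("*")])         # ['0', '2']
```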
- named_modules(prop: Optional[str] = None) Iterator[Tuple[str, ModuleSelector]]
Gets all named children recursively. Example:
modules = list(nn.Sequential(nn.Linear(3, 4), nn.ReLU()).select().named_modules())
# returns 3
len(modules)
# returns tuple ('0', <ModuleSelector of Linear>)
modules[1]
- Parameters
prop – Filter property. See also: __contains__()
- modules(prop: Optional[str] = None) Iterator[ModuleSelector]
Gets all children recursively.
- Parameters
prop – Filter property. See also: __contains__()
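The `named_modules()` example above (3 entries, with `('0', Linear)` second) is consistent with a pre-order walk that yields the root first. Here is a minimal sketch of such a walk on a hypothetical `Node` class. The root name `"root"` is borrowed from the `"#root:propA"` selector in `select()`, and the rest is an assumption, not k1lib's implementation.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Optional, Tuple


@dataclass
class Node:
    """Hypothetical stand-in for a ModuleSelector wrapping one module."""
    name: str
    kind: str  # e.g. the wrapped module's class name
    props: List[str] = field(default_factory=list)
    kids: List["Node"] = field(default_factory=list)

    def named_modules(self, prop: Optional[str] = None) -> Iterator[Tuple[str, "Node"]]:
        # Pre-order: yield self first (if it passes the filter), then recurse
        if prop is None or (len(self.props) > 0 if prop == "*" else prop in self.props):
            yield self.name, self
        for kid in self.kids:
            yield from kid.named_modules(prop)

    def modules(self, prop: Optional[str] = None) -> Iterator["Node"]:
        return (m for _, m in self.named_modules(prop))


# Mirrors the shape of nn.Sequential(nn.Linear(3, 4), nn.ReLU())
seq = Node("root", "Sequential", kids=[Node("0", "Linear"), Node("1", "ReLU")])
mods = list(seq.named_modules())
print(len(mods))                    # 3
print(mods[1][0], mods[1][1].kind)  # 0 Linear
```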
- parse(selectors: Union[List[str], str]) ModuleSelector
Parses extra selectors. Clears all old selectors, but retains the props. Returns self. Example:
mS = selector.ModuleSelector.sample().parse("Conv2d:propA")
# returns True
"propA" in mS[1][0]
- Parameters
selectors – can be either the preprocessed list or the unprocessed CSS string
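For intuition about how `"Selector:prop"` lines attach props to matching nodes, here is a deliberately simplified sketch on a small made-up tree. It is not k1lib's parser. It handles only three simple selector forms: `*`, `#name`, and a class name. It ignores combinators such as `>`. Like `parse()`, it adds new props without clearing existing ones. `Node`, `matches` and `apply_lines` are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for a ModuleSelector."""
    name: str
    kind: str
    props: List[str] = field(default_factory=list)
    kids: List["Node"] = field(default_factory=list)


def matches(node: Node, simple: str) -> bool:
    # Only "*" (anything), "#name" (by name) and "Kind" (by class name)
    if simple == "*":
        return True
    if simple.startswith("#"):
        return node.name == simple[1:]
    return node.kind == simple


def apply_lines(node: Node, lines: List[str]) -> None:
    # Each line looks like "selector:prop1,prop2" (preprocess() output format)
    for line in lines:
        sel, props = line.split(":", 1)
        if matches(node, sel):
            for p in props.split(","):
                if p not in node.props:  # keep existing props, avoid duplicates
                    node.props.append(p)
    for kid in node.kids:
        apply_lines(kid, lines)


tree = Node("root", "Sequential", kids=[
    Node("0", "Sequential", kids=[Node("0", "Conv2d"), Node("1", "ReLU")]),
    Node("1", "Linear"),
])
apply_lines(tree, ["Conv2d:propA", "#1:propB"])
print(tree.kids[0].kids[0].props, tree.kids[1].props)  # ['propA'] ['propB']
```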
- apply(f: Callable[[ModuleSelector], None])
Applies a function to self and all child ModuleSelectors.
- clearProps() ModuleSelector
Clears all existing props of this and all descendant ModuleSelectors. Example:
# returns False
"b" in nn.Linear(3, 4).select("*:b").clearProps()
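A minimal sketch of how a recursive `apply` and a prop-clearing method fit together. It uses a hypothetical `Node` class and a `clear_props` name that are not k1lib's. It assumes a pre-order visit (self first) and returns self so calls can be chained, as in the `clearProps()` example above.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Node:
    """Hypothetical stand-in for a ModuleSelector."""
    name: str
    props: List[str] = field(default_factory=list)
    kids: List["Node"] = field(default_factory=list)

    def apply(self, f: Callable[["Node"], None]) -> None:
        # Call f on this node first, then on every descendant (pre-order)
        f(self)
        for kid in self.kids:
            kid.apply(f)

    def clear_props(self) -> "Node":
        # Empty every node's props, then return self for chaining
        self.apply(lambda n: n.props.clear())
        return self


tree = Node("root", ["b"], [Node("lin1", ["b"], [Node("lin", ["a", "b"])])])
visited = []
tree.apply(lambda n: visited.append(n.name))
print(visited)                          # ['root', 'lin1', 'lin']
print("b" in tree.clear_props().props)  # False
```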
- property displayF
Function to display each ModuleSelector’s lines. Default is just:
lambda mS: ", ".join(mS.props)
- static sample() ModuleSelector
Creates a new example ModuleSelector that has a bit of hierarchy to it, with no CSS applied.
- hookF(f: Callable[[ModuleSelector, Module, Tuple[Tensor], Tensor], None] = None, prop: str = '*')
Context manager for applying forward hooks. Example:
def f(mS, i, o):
    print(i, o)
m = nn.Linear(3, 4)
with m.select().hookF(f):
    m(torch.randn(2, 3))
- Parameters
f – hook callback, should accept the ModuleSelector, inputs and output. If not specified, then it will print out input and output tensor shapes.
prop – filter property of module to hook onto.
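The hook context managers follow a common pattern: attach the callback on entry and detach it on exit, even if the body raises. For comparison, PyTorch's `nn.Module.register_forward_hook` returns a handle whose `remove()` method detaches the hook. The sketch below shows the pattern with a hypothetical `ToyModule` and `hook_forward`, so it runs without torch. Unlike k1lib's `hookF`, it passes the toy module itself rather than a ModuleSelector to the callback.

```python
from contextlib import contextmanager
from typing import Callable, Iterator, List


class ToyModule:
    """Hypothetical module that runs its forward hooks after every call."""

    def __init__(self, fn: Callable):
        self.fn = fn
        self.hooks: List[Callable] = []

    def __call__(self, x):
        out = self.fn(x)
        for hook in list(self.hooks):
            hook(self, x, out)  # (module, input, output), like a forward hook
        return out


@contextmanager
def hook_forward(modules: List[ToyModule], f: Callable) -> Iterator[None]:
    # Attach f to every selected module on entry...
    for m in modules:
        m.hooks.append(f)
    try:
        yield
    finally:
        # ...and detach it on exit, even if the body raised
        for m in modules:
            m.hooks.remove(f)


seen = []
double = ToyModule(lambda x: 2 * x)
with hook_forward([double], lambda m, i, o: seen.append((i, o))):
    double(3)
double(5)  # called after the block: the hook is gone, nothing is recorded
print(seen)  # [(3, 6)]
```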
- hookFp(f=None, prop: str = '*')
Context manager for applying forward pre hooks. Example:
def f(mS, i):
    print(i)
m = nn.Linear(3, 4)
with m.select().hookFp(f):
    m(torch.randn(2, 3))
- Parameters
f – hook callback, should accept the ModuleSelector and inputs. If not specified, then it will print out input tensor shapes.
prop – filter property of module to hook onto.
- hookB(f=None, prop: str = '*')
Context manager for applying backward hooks. Example:
def f(mS, i, o):
    print(i, o)
m = nn.Linear(3, 4)
with m.select().hookB(f):
    m(torch.randn(2, 3)).sum().backward()
- Parameters
f – hook callback, should accept the ModuleSelector, grad inputs and outputs. If not specified, then it will print out input tensor shapes.
prop – filter property of module to hook onto.
- freeze(prop: str = '*')
Returns a context manager that freezes (sets requires_grad to False) parts of the network. Example:
l = k1lib.Learner.sample()
w = l.model.lin1.lin.weight.clone() # weights before
with l.model.select("#lin1").freeze():
    l.run(1)
# returns True
(l.model.lin1.lin.weight == w).all()
- unfreeze(prop: str = '*')
Returns a context manager that unfreezes (sets requires_grad to True) parts of the network. Example:
l = k1lib.Learner.sample()
w = l.model.lin1.lin.weight.clone() # weights before
with l.model.select("#lin1").freeze():
    with l.model.select("#lin1 > #lin").unfreeze():
        l.run(1)
# returns False
(l.model.lin1.lin.weight == w).all()
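The nested `freeze()`/`unfreeze()` example relies on each context manager overriding `requires_grad` and undoing its change on exit. This is a minimal sketch of that behaviour with a hypothetical `Param` class and `set_grad` helper, not k1lib's code. Restoring the *saved* flags, rather than forcing them back to True, is an assumption chosen so that nested blocks unwind correctly. The parameter names are made up.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterable, Iterator


@dataclass
class Param:
    """Hypothetical stand-in for a tensor parameter."""
    name: str
    requires_grad: bool = True


@contextmanager
def set_grad(params: Iterable[Param], value: bool) -> Iterator[None]:
    # Remember each flag, override it, and restore the saved flags on exit
    params = list(params)
    saved = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad = value
    try:
        yield
    finally:
        for p, old in zip(params, saved):
            p.requires_grad = old


a = Param("lin1.a")           # hypothetical: some other parameter inside lin1
b = Param("lin1.lin.weight")
with set_grad([a, b], False):      # like select("#lin1").freeze()
    with set_grad([b], True):      # like select("#lin1 > #lin").unfreeze()
        inside = (a.requires_grad, b.requires_grad)
print(inside)                              # (False, True)
print((a.requires_grad, b.requires_grad))  # (True, True)
```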
- cutOff() Module
Creates a new network that returns the selected layer’s output. Example:
xb = torch.randn(10, 2)
m = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 4), nn.Linear(4, 6))
m0 = m.select("#0").cutOff(); m1 = m.select("#1").cutOff()
# returns torch.Size([10, 6])
m(xb).shape
# returns torch.Size([10, 5])
m0(xb).shape
# returns torch.Size([10, 4])
m1(xb).shape
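For a purely sequential model, cutting off at a layer amounts to running the prefix of layers up to and including that layer. This sketch shows the idea with plain functions standing in for layers. It is a hypothetical helper, not k1lib's implementation, and it only covers a flat chain, not arbitrary nested networks.

```python
from functools import reduce
from typing import Callable, List


def cut_off(layers: List[Callable], k: int) -> Callable:
    # Return a function that runs layers[0..k] in order and stops there,
    # so its output is layer k's output
    kept = layers[: k + 1]
    return lambda x: reduce(lambda acc, layer: layer(acc), kept, x)


layers = [lambda x: x + 1, lambda x: x * 10, lambda x: x - 3]
full = cut_off(layers, len(layers) - 1)
m0, m1 = cut_off(layers, 0), cut_off(layers, 1)
print(full(2), m0(2), m1(2))  # 27 3 30
```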
- intercept()
Returns a context manager that intercepts forward and backward signals to parts of the network. Example:
l = k1lib.Learner.sample()
with l.model.select("#lin1").intercept() as d:
    l.run(2)
# returns (1, 2, 600, 2, 1, 32, 1), or (#selected modules, [forward, backward], #steps, [input, output], actual data)
d | shape()
- k1lib.selector.preprocess(selectors: str, defaultProp='*') List[str]
Removes all quirky features allowed by the CSS language and outputs clean lines, one "selector:props" statement each. Example:
# returns ["a:f", "a:g,h", "b:g,h", "t:*"]
selector.preprocess("a:f; a, b: g,h; t")
- Parameters
selectors – single CSS selector string. Statements are separated by “\n” or “;”
defaultProp – default property, used if a statement doesn’t have one
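The normalisation in the example above can be reproduced by a short sketch. It splits statements on newlines or semicolons, fills in the default prop, and expands comma-separated selectors into one line each. This is a hypothetical re-implementation for illustration and is not k1lib's source. It does not handle other CSS quirks the real function may cover.

```python
import re
from typing import List


def preprocess(selectors: str, default_prop: str = "*") -> List[str]:
    """Sketch of the normalisation described above (not k1lib's source)."""
    out = []
    # Statements are separated by newlines or semicolons
    for stmt in re.split(r"[;\n]", selectors):
        stmt = stmt.strip()
        if not stmt:
            continue
        # Split off the props part; fall back to the default prop if missing
        sel_part, _, prop_part = stmt.partition(":")
        props = ",".join(p.strip() for p in prop_part.split(",") if p.strip()) or default_prop
        # "a, b: g,h" expands into one line per comma-separated selector
        for sel in sel_part.split(","):
            if sel.strip():
                out.append(f"{sel.strip()}:{props}")
    return out


print(preprocess("a:f; a, b: g,h; t"))  # ['a:f', 'a:g,h', 'b:g,h', 't:*']
```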
- k1lib.selector.select(model: Module, css: str = '*') ModuleSelector
Creates a new ModuleSelector, in sync with a model. Example:
mS = selector.select(nn.Linear(3, 4), "#root:propA")
Or, you can do it the more direct way:
mS = nn.Linear(3, 4).select("#root:propA")
- Parameters
model – the torch.nn.Module object to select from
css – the CSS selectors