k1lib.selector module
This module selects subnetworks of a model using CSS selectors, so that you can do special things to them. Check out the tutorial section for a walkthrough. It is exposed automatically with:
from k1lib.imports import *
selector.select # exposed
- class k1lib.selector.ModuleSelector(parent: ModuleSelector, name: str, nn: Module)
- Bases: object
 - nn: torch.nn.Module
- The associated torch.nn.Module of this ModuleSelector
 - props: List[str]
- Properties of this ModuleSelector
 - idx: int
- Unique id of this ModuleSelector in the entire script. May be useful for module recognition.
 - property deepestDepth
- Deepest depth of the tree. If self doesn't have any children, then the depth is 0.
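The recursive definition behind deepestDepth can be sketched in plain Python. The Node class below is a hypothetical stand-in for a ModuleSelector (it is not part of k1lib) and only assumes that each node holds a list of children.

```python
class Node:
    """Hypothetical stand-in for a ModuleSelector: a name plus child nodes."""

    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    @property
    def deepestDepth(self):
        # A node without children has depth 0; otherwise add 1 to the deepest child.
        if not self.children:
            return 0
        return 1 + max(c.deepestDepth for c in self.children)


leaf = Node("lin")
tree = Node("root", [Node("0", [leaf]), Node("1")])
print(leaf.deepestDepth)  # 0
print(tree.deepestDepth)  # 2
```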
 - __call__(*args, **kwargs)
- Calls the internal torch.nn.Module
 - __contains__(prop: str = None) → bool
- Whether this ModuleSelector has a specific prop. Example:
    # returns True
    "b" in nn.Linear(3, 4).select("*:b")
    # returns False
    "h" in nn.Linear(3, 4).select("*:b")
    # returns True, "*" here means the ModuleSelector has any properties at all
    "*" in nn.Linear(3, 4).select("*:b")
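The membership rule, including the "*" wildcard, can be illustrated with a minimal sketch. Selected is a hypothetical class holding a plain list of props; it is not how k1lib stores them, and the prop=None default of the real method is not modelled.

```python
class Selected:
    """Hypothetical holder of the props a CSS selector assigned to one module."""

    def __init__(self, props):
        self.props = list(props)

    def __contains__(self, prop):
        # "*" asks whether any prop at all has been assigned.
        if prop == "*":
            return len(self.props) > 0
        return prop in self.props


mS = Selected(["b"])
print("b" in mS)            # True
print("h" in mS)            # False
print("*" in mS)            # True
print("*" in Selected([]))  # False
```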
 - named_children(prop: str = None) → Iterator[Tuple[str, ModuleSelector]]
- Get all named direct children.
- Parameters:
- prop – Filter property. See also: __contains__()
 
 - children(prop: str = None) → Iterator[ModuleSelector]
- Get all direct children.
- Parameters:
- prop – Filter property. See also: __contains__()
 
 - named_modules(prop: str = None) → Iterator[Tuple[str, ModuleSelector]]
- Get all named children recursively. Example:
    modules = list(nn.Sequential(nn.Linear(3, 4), nn.ReLU()).select().named_modules())
    # returns 3
    len(modules)
    # returns the tuple ('0', <ModuleSelector of Linear>)
    modules[1]
- Parameters:
- prop – Filter property. See also: __contains__()
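The difference between the direct-child iterators and the recursive ones can be sketched with a hypothetical Node tree (not part of k1lib). The traversal is pre-order, yielding the node itself first, which is consistent with the three-element result of the named_modules() example; the names used for the root and for nested modules are assumptions of this sketch.

```python
class Node:
    """Hypothetical stand-in for a ModuleSelector: name, module class name, children."""

    def __init__(self, name, kind, children=()):
        self.name, self.kind, self.children = name, kind, list(children)

    def named_children(self):
        # Direct children only.
        for c in self.children:
            yield c.name, c

    def named_modules(self):
        # Pre-order traversal: the node itself, then every descendant.
        yield self.name, self
        for c in self.children:
            yield from c.named_modules()


root = Node("root", "Sequential", [Node("0", "Linear"), Node("1", "ReLU")])
modules = list(root.named_modules())
print(len(modules))                           # 3
print(modules[1][0], modules[1][1].kind)      # 0 Linear
print([n for n, _ in root.named_children()])  # ['0', '1']
```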
 
 - modules(prop: str = None) → Iterator[ModuleSelector]
- Get all children recursively.
- Parameters:
- prop – Filter property. See also: __contains__()
 
 - parse(selectors: List[str] | str) → ModuleSelector
- Parses extra selectors. Clears all old selectors, but retains the props. Returns self. Example:
    mS = selector.ModuleSelector.sample().parse("Conv2d:propA")
    # returns True
    "propA" in mS[1][0]
- Parameters:
- selectors – can be the preprocessed list, or the unprocessed CSS string
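To show how a rule such as "Conv2d:propA" attaches props to matching modules, here is a deliberately simplified matcher. It is a hypothetical sketch, not k1lib's parser: it only understands single-token selectors ("*", "#name" or a class name), ignores combinators such as ">", and does not model clearing old selectors.

```python
class Node:
    """Hypothetical stand-in for a ModuleSelector: name, class name, props, children."""

    def __init__(self, name, kind, children=()):
        self.name, self.kind = name, kind
        self.children = list(children)
        self.props = []

    def __getitem__(self, i):
        # Index into children, like the mS[1][0] indexing in the parse() example.
        return self.children[i]

    def walk(self):
        yield self
        for c in self.children:
            yield from c.walk()


def matches(node, sel):
    # Only single-token selectors: "*", "#name" or a module class name.
    if sel == "*":
        return True
    if sel.startswith("#"):
        return node.name == sel[1:]
    return node.kind == sel


def parse(root, rules):
    # Each rule is "selector:prop1,prop2"; props are added to every matching node.
    for rule in rules:
        sel, props = rule.split(":", 1)
        for node in root.walk():
            if matches(node, sel.strip()):
                for p in props.split(","):
                    if p not in node.props:
                        node.props.append(p)
    return root


mS = Node("root", "Sequential", [
    Node("0", "Conv2d"),
    Node("1", "Sequential", [Node("0", "Conv2d"), Node("1", "ReLU")]),
])
parse(mS, ["Conv2d:propA", "#root:propB"])
print("propA" in mS[1][0].props)  # True
print(mS.props)                   # ['propB']
print(mS[1][1].props)             # []
```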
 
 - apply(f: Callable[[ModuleSelector], None])
- Applies a function to self and all child ModuleSelectors
 - clearProps() → ModuleSelector
- Clears all existing props of this and all descendant ModuleSelectors. Example:
    # returns False
    "b" in nn.Linear(3, 4).select("*:b").clearProps()
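clearProps() is naturally expressed in terms of apply(). The sketch below uses a hypothetical Node class (not k1lib's) to show the recursion and the return-self convention that lets a call like .select("*:b").clearProps() be chained.

```python
class Node:
    """Hypothetical stand-in for a ModuleSelector: props plus child nodes."""

    def __init__(self, props, children=()):
        self.props = list(props)
        self.children = list(children)

    def apply(self, f):
        # Call f on this node, then recurse into every child.
        f(self)
        for c in self.children:
            c.apply(f)

    def clearProps(self):
        def clear(node):
            node.props = []

        self.apply(clear)
        return self  # returning self allows chaining


tree = Node(["b"], [Node(["b"]), Node(["c"], [Node(["b"])])])
remaining = []
tree.clearProps().apply(lambda n: remaining.extend(n.props))
print(remaining)  # []
```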
 - property displayF
- Function to display each ModuleSelector's lines. Default is just:
    lambda mS: ", ".join(mS.props)
 - static sample() → ModuleSelector
- Creates a new example ModuleSelector that has a bit of hierarchy to it, with no CSS.
 - hookF(f: Callable[[ModuleSelector, Module, Tuple[Tensor], Tensor], None] = None, prop: str = '*')
- Context manager for applying forward hooks. Example:
    def f(mS, i, o):
        print(i, o)
    m = nn.Linear(3, 4)
    with m.select().hookF(f):
        m(torch.randn(2, 3))
- Parameters:
- f – hook callback, should accept a ModuleSelector, the inputs and the output. If not specified, it prints out the input and output tensor shapes.
- prop – filter property of the modules to hook onto
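The register-then-remove pattern behind a hook context manager can be sketched without PyTorch. Layer is a hypothetical callable with a plain list of forward hooks (loosely modelled on torch.nn.Module.register_forward_hook, whose returned handle is detached with .remove()), and this hookF helper takes an explicit list of layers instead of a prop filter. The try/finally ensures the hooks are detached even if the forward pass raises.

```python
from contextlib import contextmanager


class Layer:
    """Hypothetical callable layer holding forward hooks (not torch.nn.Module)."""

    def __init__(self, name, fn):
        self.name, self.fn = name, fn
        self.hooks = []

    def __call__(self, x):
        out = self.fn(x)
        # Each hook sees the layer, its input and its output.
        for h in list(self.hooks):
            h(self, x, out)
        return out


@contextmanager
def hookF(layers, f):
    for layer in layers:
        layer.hooks.append(f)
    try:
        yield
    finally:
        # Detach the hooks even if the forward pass raised.
        for layer in layers:
            layer.hooks.remove(f)


double = Layer("double", lambda x: 2 * x)
inc = Layer("inc", lambda x: x + 1)
log = []
with hookF([double, inc], lambda layer, i, o: log.append((layer.name, i, o))):
    inc(double(3))
print(log)             # [('double', 3, 6), ('inc', 6, 7)]
print(inc(double(1)))  # 3, and no new log entries since the hooks are gone
```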
 
 
 - hookFp(f=None, prop: str = '*')
- Context manager for applying forward pre hooks. Example:
    def f(mS, i):
        print(i)
    m = nn.Linear(3, 4)
    with m.select().hookFp(f):
        m(torch.randn(2, 3))
- Parameters:
- f – hook callback, should accept a ModuleSelector and the inputs. If not specified, it prints out the input tensor shapes.
- prop – filter property of the modules to hook onto
 
 
 - hookB(f=None, prop: str = '*')
- Context manager for applying backward hooks. Example:
    def f(mS, i, o):
        print(i, o)
    m = nn.Linear(3, 4)
    with m.select().hookB(f):
        m(torch.randn(2, 3)).sum().backward()
- Parameters:
- f – hook callback, should accept a ModuleSelector, the grad inputs and the grad outputs. If not specified, it prints out the input tensor shapes.
- prop – filter property of the modules to hook onto
 
 
 - freeze(prop: str = '*')
- Returns a context manager that freezes (sets requires_grad to False) parts of the network. Example:
    l = k1lib.Learner.sample()
    w = l.model.lin1.lin.weight.clone()  # weights before
    with l.model.select("#lin1").freeze():
        l.run(1)
    # returns True
    (l.model.lin1.lin.weight == w).all()
 - unfreeze(prop: str = '*')
- Returns a context manager that unfreezes (sets requires_grad to True) parts of the network. Example:
    l = k1lib.Learner.sample()
    w = l.model.lin1.lin.weight.clone()  # weights before
    with l.model.select("#lin1").freeze():
        with l.model.select("#lin1 > #lin").unfreeze():
            l.run(1)
    # returns False
    (l.model.lin1.lin.weight == w).all()
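Nesting unfreeze() inside freeze(), as in the example above, works if each context manager remembers the flags it overwrote and restores them on exit. A minimal sketch of that save-and-restore idea, using a hypothetical Param class in place of real tensors (in PyTorch the flag would be the parameter's requires_grad attribute); this is not k1lib's implementation.

```python
from contextlib import contextmanager


class Param:
    """Hypothetical stand-in for a parameter tensor with a requires_grad flag."""

    def __init__(self, name):
        self.name = name
        self.requires_grad = True


@contextmanager
def setRequiresGrad(params, value):
    # Save the current flags so that exiting restores exactly what was there.
    saved = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad = value
    try:
        yield
    finally:
        for p, old in zip(params, saved):
            p.requires_grad = old


def freeze(params):
    return setRequiresGrad(params, False)


def unfreeze(params):
    return setRequiresGrad(params, True)


lin = [Param("lin1.lin.weight"), Param("lin1.lin.bias")]
rest = [Param("lin1.other.weight")]  # hypothetical extra parameter inside lin1


def flags():
    return [p.requires_grad for p in lin + rest]


with freeze(lin + rest):
    with unfreeze(lin):
        inner = flags()
    outer = flags()
after = flags()
print(inner)  # [True, True, False]
print(outer)  # [False, False, False]
print(after)  # [True, True, True]
```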
 - cutOff() → Module
- Creates a new network that returns the selected layer's output. Example:
    xb = torch.randn(10, 2)
    m = nn.Sequential(nn.Linear(2, 5), nn.Linear(5, 4), nn.Linear(4, 6))
    m0 = m.select("#0").cutOff(); m1 = m.select("#1").cutOff()
    # returns torch.Size([10, 6])
    m(xb).shape
    # returns torch.Size([10, 5])
    m0(xb).shape
    # returns torch.Size([10, 4])
    m1(xb).shape
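For a purely sequential model, cutting off at a layer amounts to keeping every layer up to and including it. The sketch below shows that idea with a hypothetical Sequential of plain Python functions; it is a simplification, and how k1lib handles non-sequential networks is not covered here.

```python
class Sequential:
    """Hypothetical pipeline of callables, loosely modelled on torch.nn.Sequential."""

    def __init__(self, *layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def cutOff(seq, idx):
    # Keep layers 0..idx, so the new network returns layer idx's output.
    return Sequential(*seq.layers[:idx + 1])


m = Sequential(lambda x: x + 1, lambda x: x * 10, lambda x: x - 3)
m0, m1 = cutOff(m, 0), cutOff(m, 1)
print(m(2))   # 27
print(m0(2))  # 3
print(m1(2))  # 30
```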
 - intercept()
- Returns a context manager that intercepts forward and backward signals to parts of the network. Example:
    l = k1lib.Learner.sample()
    with l.model.select("#lin1").intercept() as d:
        l.run(2)
    # returns (1, 2, 600, 2, 1, 32, 1), or (#selected modules, [forward, backward], #steps, [input, output], actual data)
    d | shape()
 
- k1lib.selector.preprocess(selectors: str, defaultProp='*') → List[str]
- Removes all the quirky features allowed by the CSS language and outputs clean, normalized lines. Example:
    # returns ["a:f", "a:g,h", "b:g,h", "t:*"]
    selector.preprocess("a:f; a, b: g,h; t")
- Parameters:
- selectors – a single CSS selector string. Statements are separated by "\n" or ";"
- defaultProp – default property, used if a statement doesn't have one
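The rules visible in the preprocess() example (statements split on newlines or semicolons, comma-separated selectors expanded into one line each, and a missing prop replaced by defaultProp) can be sketched in plain Python. This is a hypothetical re-implementation covering only those rules, not k1lib's code; any other CSS quirks the real function handles are not modelled.

```python
def preprocess(selectors, defaultProp="*"):
    """Hypothetical sketch of the documented normalization rules."""
    lines = []
    # Statements are separated by newlines or semicolons.
    for stmt in selectors.replace("\n", ";").split(";"):
        stmt = stmt.strip()
        if not stmt:
            continue
        if ":" in stmt:
            sels, props = stmt.split(":", 1)
        else:
            sels, props = stmt, defaultProp
        # Normalize the prop list by dropping spaces around commas.
        props = ",".join(p.strip() for p in props.split(","))
        # "a, b: ..." expands into one line per comma-separated selector.
        for sel in sels.split(","):
            lines.append(f"{sel.strip()}:{props}")
    return lines


print(preprocess("a:f; a, b: g,h; t"))  # ['a:f', 'a:g,h', 'b:g,h', 't:*']
```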
 
 
- k1lib.selector.select(model: Module, css: str = '*') → ModuleSelector
- Creates a new ModuleSelector, in sync with a model. Example:
    mS = selector.select(nn.Linear(3, 4), "#root:propA")
- Or, you can do it the more direct way:
    mS = nn.Linear(3, 4).select("#root:propA")
- Parameters:
- model – the torch.nn.Module object to select from
- css – the CSS selectors