Controlled Branching Processes: A Case Study in Structured Sampling

Background & Problem Statement

Earlier this week, we were developing a random AST sampler for a toy language, which is a small subset of a folklore language, IMP. It has a really simple syntax:

syntax AExp ::= Int | Id
              | AExp "/" AExp
              > AExp "+" AExp
              | "(" AExp ")"
syntax Program ::= AExp

In short, a program consists of a single arithmetic expression, denoted AExp. An AExp node in this tiny AST can be an integer Int, an identifier Id, a division, an addition, or a bracketed expression. Because the operator productions are recursive, the ASTs can grow in both width and depth.

In our implementation, we create a concrete node type and a corresponding mask node type for each production rule in the syntax. This representation allows us to hide (partial) semantic information from a real-world or synthesized program and thus plays an important role in the forward/reverse process of a graph edit neural network.

To synthesize random samples for the language, we employ a top-down generation strategy: we always start with a program whose body is an [AExpMask]. We then apply two types of actions to the nodes to grow the AST into a semantically complete program without any masks:

  1. Mask-Down. As we have described above, the node types are defined as a hierarchical structure, where the abstract type AExp is inherited by other expression types. Hence, an [AExpMask] node can be replaced with an [IntMask], an [IdMask], a [DivExpMask], an [AddExpMask], or a [BrAExpMask].

  2. Unmask. A mask for a concrete expression type can be replaced with an instance of the concrete expression node. For terminal types in the language definition, i.e., Int and Id, we replace the mask with a node holding a randomly sampled numeric value or identifier name. For non-terminal types, we create an empty node of the concrete type, with each operand, i.e., each nested expression, set to an [AExpMask].

We apply the actions to the masks in the AST until none remain. In fact, the two actions can also be viewed as one, which iteratively converts an [AExpMask] into a concrete expression node and applies the production rules to grow the tree. Of the two actions, Mask-Down is the one that controls the depth and width of the tree, because it decides whether a node becomes a terminal or a non-terminal. To keep the sampled programs at a reasonable scale, the sampler chooses the Mask-Down target according to a pre-defined set of weights.

To simplify the problem, we only Mask-Down an [AExpMask] to an [IntMask], an [IdMask], a [DivExpMask] or an [AddExpMask], where the former two are terminal nodes in an AST and have no offspring, and the latter two are non-terminal nodes with exactly 2 offspring. After an AST has been generated, we review its structure and add brackets to certain subtrees if necessary. Now we wonder: how should we set the weights for the Mask-Down action so that the sampled ASTs are neither trivially shallow nor unboundedly deep? In other words, we would like to figure out how the distribution of concrete nodes affects the scale of the resulting ASTs.
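The simplified sampling loop described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual implementation: the `Node` class, the function names, the value ranges for Int and Id, and the weights in `MASK_DOWN_WEIGHTS` are all hypothetical.

```python
import random
from dataclasses import dataclass
from typing import Optional

# Hypothetical Mask-Down weights (assumed for illustration only).
# Terminals get total weight 0.6, so a branch stops with probability 0.6.
MASK_DOWN_WEIGHTS = {"Int": 0.3, "Id": 0.3, "Div": 0.2, "Add": 0.2}


@dataclass
class Node:
    kind: str                       # e.g. "AExpMask", "IntMask", "Int", "Add"
    value: Optional[str] = None     # payload for Int / Id nodes
    left: Optional["Node"] = None   # operands for Div / Add nodes
    right: Optional["Node"] = None


def mask_down(node: Node, rng: random.Random) -> None:
    """Replace an [AExpMask] with the mask of a concrete expression type."""
    assert node.kind == "AExpMask"
    kinds = list(MASK_DOWN_WEIGHTS)
    weights = list(MASK_DOWN_WEIGHTS.values())
    node.kind = rng.choices(kinds, weights=weights, k=1)[0] + "Mask"


def unmask(node: Node, rng: random.Random) -> None:
    """Replace a concrete mask with a concrete node."""
    kind = node.kind[: -len("Mask")]
    node.kind = kind
    if kind == "Int":
        node.value = str(rng.randint(0, 9))
    elif kind == "Id":
        node.value = rng.choice("xyz")
    else:
        # Non-terminal: both operands start out as fresh [AExpMask] nodes.
        node.left, node.right = Node("AExpMask"), Node("AExpMask")


def sample_ast(rng: random.Random) -> Node:
    """Grow an AST top-down until no mask remains."""
    root = Node("AExpMask")
    worklist = [root]
    while worklist:
        node = worklist.pop()
        mask_down(node, rng)
        unmask(node, rng)
        if node.left is not None:
            worklist += [node.left, node.right]
    return root


def depth(root: Node) -> int:
    """Number of levels in the tree (a single leaf has depth 1)."""
    deepest, stack = 0, [(root, 1)]
    while stack:
        node, level = stack.pop()
        deepest = max(deepest, level)
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, level + 1))
    return deepest


def has_mask(root: Node) -> bool:
    """True if any node in the tree is still a mask."""
    stack = [root]
    while stack:
        node = stack.pop()
        if node.kind.endswith("Mask"):
            return True
        stack.extend(c for c in (node.left, node.right) if c is not None)
    return False
```

Calling `depth(sample_ast(random.Random(0)))` repeatedly with different seeds gives an empirical view of the depth distribution. The bracket-insertion pass mentioned above is omitted from this sketch.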

With the above simplification, we can formalize the problem as:

A binary tree contains two types of nodes: NT and T, where each NT node has exactly two offspring and each T node is a leaf node. Given the probabilities with which NT and T occur, how can we compute the probability distribution of the resulting trees’ depths?

From the Branching Process Perspective

The above problem is actually a Galton-Watson branching process, which can be expressed as follows:

  1. The process starts with a single ancestor in the initial generation (the root node of the binary tree):

$$Z_0=1$$

  2. The number of offspring for any given individual, denoted by the random variable $X$, follows the probability mass function \(p_k=P(X=k)\):

$$\begin{cases} p_0 = P(X=0) = p, \\ p_2 = P(X=2) = 1-p, \\ p_k = P(X=k) = 0, \text{for} \ k \in \{1, 3, 4, \dots\} \end{cases}$$

The offspring counts are independent and identically distributed (i.i.d.) across all individuals in all generations.

  3. The size of the \((n+1)\)-th generation, \(Z_{n+1}\), is the sum of the offspring produced by all \(Z_n\) individuals in the $n$-th generation. Formally, if we let \(X_{n+1,i}\) be the number of offspring produced by the $i$-th individual in the $n$-th generation, then we have

$$Z_{n+1}=\sum_{i=1}^{Z_n}X_{n+1,i},$$

where \(\{X_{n+1,i}\}\) is a sequence of i.i.d. random variables with the common offspring distribution \(P(X=k)\) defined above, and they are independent of \(Z_0, Z_1, \dots, Z_n\).

  4. The mean number of offspring per individual, \(\mu\), is a crucial parameter for determining the long-term behavior of the process, where we have

$$\mu=E[X]=\sum_{k=0}^\infty{k \cdot p_k} = 0 \cdot p + 2 \cdot (1 - p) = 2 (1 - p).$$

This process is a discrete-time Markov chain on the state space of non-negative integers \(\mathbb{Z}_{\geq 0}\), with the state \(Z_n=0\) being an absorbing state (extinction).
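The recurrence for \(Z_{n+1}\) above can be simulated directly. The following is a minimal sketch; the function name and the two safety caps (`max_generations`, `max_population`) are assumptions added here so that supercritical runs do not grow without bound.

```python
import random


def simulate_generations(p, rng, max_generations=1000, max_population=10_000):
    """Return [Z_0, Z_1, ...] for the 0-or-2 offspring process.

    The run stops at extinction (Z_n == 0), or when one of the hypothetical
    safety caps is hit.
    """
    sizes = [1]  # Z_0 = 1: the root of the tree
    while 0 < sizes[-1] <= max_population and len(sizes) <= max_generations:
        # Each individual independently leaves 0 offspring with probability p,
        # and 2 offspring otherwise.
        next_size = sum(0 if rng.random() < p else 2 for _ in range(sizes[-1]))
        sizes.append(next_size)
    return sizes
```

If the returned list ends in 0, its last index is the extinction time of that run. Averaging `sizes[1]` over many runs should approach \(\mu = 2(1-p)\).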

Extinction Time aka Depth of Trees

Now we compute the extinction time of the branching process, which corresponds to the depth of the binary tree counted in levels (a tree consisting of a single leaf has \(Z_1=0\) and hence extinction time 1). The extinction time of the Galton-Watson process is defined as the first generation $N$ at which the population size drops to zero. Formally, we have that

$$N=\inf\{n \mid Z_n = 0\},$$

where \(\inf\emptyset = \infty\). If the population never becomes extinct, i.e., for all $n$ we have \(Z_n>0\), then \(N=\infty\).

The overall probability of ultimate extinction, \(\pi\), is the probability that the extinction time is finite:

$$\pi = P(N<\infty) = P(\exists n \geq 1 \ . Z_n = 0).$$

For the specific branching process under our settings, the value of \(\pi\) is the smallest non-negative root of the equation

$$s=G(s),$$

where $G(s)$ is the probability generating function (PGF) of the offspring distribution. More specifically, our PGF for the offspring is

$$G(s)=E[s^X]=\sum_{k=0}^\infty{s^kp_k}=s^0 \cdot p + s^2 \cdot (1 - p) = p + (1-p)s^2.$$

The extinction probability \(\pi\) is the smallest non-negative solution to

$$\pi = p + (1-p) \pi^2.$$
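For this particular offspring distribution the fixed-point equation is a quadratic that factors by hand, which gives a closed form for \(\pi\). As a short worked step (assuming \(0<p<1\)):

$$(1-p)\pi^2 - \pi + p = 0 \iff (\pi - 1)\big((1-p)\pi - p\big) = 0,$$

so the two roots are \(\pi=1\) and \(\pi=\frac{p}{1-p}\), and the smallest non-negative root is

$$\pi = \min\left(1, \frac{p}{1-p}\right).$$

For example, \(p=0.3\) gives \(\pi = 0.3/0.7 \approx 0.43\), while any \(p \geq 0.5\) gives \(\pi = 1\).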

In general, calculating \(P(N=k)\) requires iterating the PGF. Let \(G_k(s)\) be the PGF for \(Z_k\), and we have

$$G_k(s)=E[s^{Z_k}]=G(G_{k-1}(s)),$$

for each \(k \geq 1\). For \(k=0\), we have that \(Z_0=1\), and hence \(G_0(s)=E[s^1]=s\).

With the definition of the PGFs for each generation, we have the probability of extinction by generation $k$:

$$P(N \leq k) = P(Z_k=0) = G_k(0).$$

And the probability of extinction at generation $k$ should be:

$$P(N=k)=P(Z_k=0) - P(Z_{k-1}=0) = G_k(0) - G_{k-1}(0).$$
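Because only the values \(G_k(0)\) are needed, the PGF iteration reduces to the scalar recurrence \(q_k = p + (1-p)q_{k-1}^2\) with \(q_0 = 0\). The following is a minimal sketch of that computation; the function name is hypothetical.

```python
def homogeneous_extinction_time(p, max_k):
    """Return (pmf, cdf) with pmf[k] = P(N = k) and cdf[k] = P(N <= k).

    Both lists are indexed by k = 0..max_k; pmf[0] is 0 because Z_0 = 1.
    """
    cdf = [0.0]  # G_0(0) = 0
    for _ in range(max_k):
        q = cdf[-1]
        cdf.append(p + (1.0 - p) * q * q)  # G_k(0) = G(G_{k-1}(0))
    pmf = [0.0] + [cdf[k] - cdf[k - 1] for k in range(1, max_k + 1)]
    return pmf, cdf
```

For \(p=0.5\) this gives \(P(N=1)=0.5\) and \(P(N=2)=0.125\), which can be checked by hand from \(G(0)=0.5\) and \(G(0.5)=0.625\).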

We calculated \(P(N=k)\) for each \(k \in [1, 40]\) under different distributions of $X$, where $p$ is set to one of the values in \(\{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9\}\). The figure below shows the distribution of \(P(N=k)\) and that of \(P(N \leq k)\), as well as the ultimate extinction probability \(\pi\).

From the figure, we can see that when $p<0.5$, i.e., \(\mu=2(1-p)>1\), there is a positive probability that the population never becomes extinct. This corresponds to the supercritical case of a branching process. Hence, to sample an AST of finite depth with probability 1, we need \(p \geq 0.5\).

The graph of the probability mass function \(P(N=k)\) shows that for all considered values of $p$, the distribution is heavily concentrated towards the left, resulting in a rapidly decaying curve. Specifically, a very high proportion of the extinction events occur within the first few generations. This rapid decay means that the extinction time $N$ is typically too short, i.e., the sampled ASTs are too shallow. An ideal distribution would spread the probability mass more evenly over a moderate range of generations to achieve an intermediate extinction time, but the inherent tendency of the current process is either to go extinct almost immediately or to survive indefinitely. In summary, the distribution of $N$ is too steep, and we need a process with a slower decay of \(P(N=k)\) that shifts the probability mass to higher generation numbers.

Going Controlled: A Non-Homogeneous Alternative

So how should we control the branching process? Ideally, the tree (population) should expand in width during the first few generations, and the probability of further expansion should decrease as it grows. To achieve this behavior, we introduce a non-homogeneous Galton-Watson branching process. Instead of a fixed offspring distribution \(P(X=k)\) for all generations, the probability that an individual leaves zero offspring changes over time. In our binary branching model, where an individual produces either 0 or 2 offspring, we define the probability of leaving zero offspring, \(p_k=P(X_k=0)\), as a smooth function of the generation number $k$. (Note that \(p_k\) now denotes this generation-dependent termination probability, not the probability mass \(P(X=k)\) used earlier.) We propose the following functional form for the termination probability:

$$p_k=\tanh(t \cdot k),$$

where $k$ is the current generation index and $t$ is a positive tuning parameter that controls the rate at which \(p_k\) increases. In the recursion below, \(G_k\) governs the individuals whose offspring form generation $k$, so the root uses \(p_1=\tanh(t)\). The offspring PGF for generation $k$ then becomes

$$G_k(s)=p_k + (1 - p_k)s^2.$$

Since we are now dealing with a non-homogeneous process, the PGF of the population size \(Z_n\) at the $n$-th generation is no longer a repeated composition of one and the same PGF. Instead, \(G^{(n)}(s)=E[s^{Z_n}]\) and \(G^{(n-1)}(s)=E[s^{Z_{n-1}}]\) are related as follows:

$$\begin{align} G^{(n)}(s) &= E[s^{Z_n}] = E[E[s^{Z_n}\mid Z_{n-1}]] = E[E[s^{\sum_{i=1}^{Z_{n-1}}{X_{n,i}}} \mid Z_{n-1}]] \\ &= E[(E[s^{X_n}])^{Z_{n-1}}] = E[(G_n(s))^{Z_{n-1}}] = G^{(n-1)}(G_n(s)). \end{align}$$

As in the homogeneous branching process, we have \(G^{(0)}(s)=s\), since \(Z_0=1\). Unrolling the recursion, for each \(n\geq1\) we have that

$$G^{(n)}(s) = G_1(G_2(G_3(\dots G_n(s)\dots))).$$

With the definition of \(G^{(k)}(s)\), we have the probability of extinction by generation $k$

$$P(N \leq k) = P(Z_k = 0) = G^{(k)}(0),$$

and the probability of extinction at generation $k$

$$P(N=k)=P(Z_k=0)-P(Z_{k-1}=0) = G^{(k)}(0) - G^{(k-1)}(0).$$
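These quantities can be evaluated numerically by computing the nested composition \(G_1(G_2(\dots G_k(0)\dots))\) from the inside out. The following is a minimal sketch under the \(\tanh\) schedule defined above; the function names are hypothetical.

```python
import math


def extinction_by_generation(t, k):
    """P(N <= k) = G_1(G_2(...G_k(0)...)) with p_j = tanh(t * j)."""
    s = 0.0
    # Evaluate the innermost PGF G_k first, then work outwards to G_1.
    for j in range(k, 0, -1):
        p_j = math.tanh(t * j)
        s = p_j + (1.0 - p_j) * s * s
    return s


def nonhomogeneous_extinction_time(t, max_k):
    """Return (pmf, cdf) with pmf[k] = P(N = k) and cdf[k] = P(N <= k)."""
    cdf = [extinction_by_generation(t, k) for k in range(max_k + 1)]
    pmf = [0.0] + [cdf[k] - cdf[k - 1] for k in range(1, max_k + 1)]
    return pmf, cdf
```

Unlike the homogeneous case, \(G^{(k)}(0)\) cannot be obtained from \(G^{(k-1)}(0)\) with a single extra step, because the new PGF \(G_k\) enters at the innermost position. This sketch therefore recomputes the composition for every $k$, which costs quadratic time in `max_k` but is negligible for the ranges considered here.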

With the constant parameter $t$ selected from \(\{0.1, 0.2, 0.3, 0.4, 0.5, 1.0\}\), we compute the probability distribution of the extinction time $N$. See the following figure.

From the left graph, it can be seen that the distribution is more concentrated around a moderate number of generations when $t<0.5$. From the right graph, we observe that for every tested value of $t$, the ultimate extinction probability is $1$. Hence, using \(\tanh\) in the offspring distribution is an appropriate choice for the binary branching process in our case.
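One way to see why the extinction probability is $1$ for any \(t>0\), and not only for the tested values, is a first-moment bound. This is a sketch of a standard argument added for illustration, not part of the numerical analysis above. Let \(\mu_k = 2(1-p_k) = 2(1-\tanh(t \cdot k))\) be the mean offspring count for generation $k$. Since \(E[Z_n \mid Z_{n-1}] = \mu_n Z_{n-1}\), we have

$$E[Z_n] = \prod_{k=1}^{n}\mu_k = \prod_{k=1}^{n} 2\,(1-\tanh(t \cdot k)),$$

and, because \(Z_n\) takes non-negative integer values, Markov's inequality gives

$$P(Z_n > 0) = P(Z_n \geq 1) \leq E[Z_n].$$

As \(\tanh(t \cdot k) \to 1\), the factors \(\mu_k\) tend to $0$, so the product tends to $0$ as \(n \to \infty\). Hence \(P(N<\infty) = \lim_{n\to\infty} P(Z_n = 0) = 1\).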

Summary

In this blog, the problem of sampling random ASTs for a custom language was modeled as a branching process to control the resulting tree depth. The AST generation uses two actions: Mask-Down, which chooses the concrete type of a node (a terminal T with 0 offspring or a non-terminal NT with 2 offspring) and therefore controls the depth and width of the tree, and Unmask, which instantiates the chosen concrete node.

The initial attempt used a homogeneous Galton-Watson branching process where the probability of terminating a branch, $p$, was constant across all generations. The offspring distribution was \(P(X=0)=p\) and \(P(X=2)=1-p\). The analysis showed that to obtain a finite-depth AST with probability 1, the mean number of offspring \(\mu = 2(1-p)\) must be \(\leq 1\), requiring \(p \geq 0.5\). However, this homogeneous process resulted in a tree depth distribution \(P(N=k)\) that decayed too rapidly, meaning most sampled ASTs were very shallow (small depth $N$).

To achieve a more desirable, moderate AST depth, a non-homogeneous Galton-Watson branching process was introduced. In this model, the termination probability is no longer constant but changes with the generation index $k$: \(p_k = P(X_k=0) = \tanh(t \cdot k)\), where $t$ is a tuning parameter. This function ensures that in early generations (small $k$), \(p_k\) is low, encouraging the tree to grow wide, and as the tree gets deeper (large $k$), \(p_k\) increases, making termination more likely. The new process uses the iterative PGF composition \(G^{(k)}(s) = G^{(k-1)}(G_k(s))\) to calculate \(P(N=k) = G^{(k)}(0) - G^{(k-1)}(0)\). The results showed that the non-homogeneous process, particularly for small values of $t$, e.g., $t<0.5$, spreads the probability mass over a larger range of depths, thus sampling ASTs with more moderate depths. Furthermore, with the \(\tanh\) schedule the ultimate extinction probability is $1$ for every tested $t$, meaning the generated ASTs have finite depth with probability 1. Via this approach, we arrive at a controlled branching process.