flex_model.traverse.unflatten

flex_model.traverse.unflatten(root_node: InternalNode | LeafNode | Any, leaves: List[Tensor | None]) → Any

Repack a tree definition and a list of leaves into the original Python object.

Parameters:
  • root_node (Union[InternalNode, LeafNode, ScalarNode]) – Root node that defines the tree structure of the Python object.

  • leaves (List[Optional[Tensor]]) – List of leaf values (tensors or None) to place back into the tree.

Returns:

The reconstructed Python object.

Return type:

Any
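
Example:

The sketch below illustrates the general flatten/unflatten idea that this function's description suggests. It is not flex_model's implementation. The classes InternalNode, LeafNode and ScalarNode are simplified, hypothetical stand-ins that only borrow the names from the signature above. FakeTensor, flatten_sketch and unflatten_sketch are likewise invented for illustration, and the assumption that leaves are consumed in depth-first order is mine.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor so the sketch needs no dependencies."""
    data: list


@dataclass
class LeafNode:
    """Hypothetical: marks a slot filled by the next entry of `leaves`."""


@dataclass
class ScalarNode:
    """Hypothetical: stores a non-tensor value directly in the tree."""
    value: Any


@dataclass
class InternalNode:
    """Hypothetical: a list, tuple or dict container with child nodes."""
    kind: type
    children: list
    keys: Optional[list] = None  # only used when kind is dict


def flatten_sketch(obj: Any) -> Tuple[Any, List[Optional[FakeTensor]]]:
    """Split obj into a tree definition and a depth-first list of leaves."""
    leaves: List[Optional[FakeTensor]] = []

    def build(o: Any) -> Any:
        if isinstance(o, FakeTensor):
            leaves.append(o)
            return LeafNode()
        if isinstance(o, (list, tuple)):
            return InternalNode(type(o), [build(c) for c in o])
        if isinstance(o, dict):
            keys = list(o)
            return InternalNode(dict, [build(o[k]) for k in keys], keys)
        return ScalarNode(o)

    return build(obj), leaves


def unflatten_sketch(root_node: Any, leaves: List[Optional[FakeTensor]]) -> Any:
    """Rebuild the object by walking the tree and consuming leaves in order."""
    remaining = iter(leaves)

    def rebuild(node: Any) -> Any:
        if isinstance(node, LeafNode):
            return next(remaining)
        if isinstance(node, ScalarNode):
            return node.value
        children = [rebuild(c) for c in node.children]
        if node.kind is dict:
            return dict(zip(node.keys, children))
        return node.kind(children)  # list or tuple

    return rebuild(root_node)


obj = {"a": FakeTensor([1.0]), "b": [FakeTensor([2.0]), 3], "c": ("x", FakeTensor([4.0]))}
root, leaves = flatten_sketch(obj)
restored = unflatten_sketch(root, leaves)
```

In this sketch, restored compares equal to obj. Because each leaf slot simply takes the next list entry, passing None in place of a tensor puts None at the corresponding position of the rebuilt object.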