GPTree
Module contents
A decision tree classifier employing LLMs for dynamic feature generation.
Each node uses language models to generate contextual questions and evaluate answers, enabling adaptive tree construction for classification tasks.
- class GPTree(qgen_llmc, critic_llmc, qgen_instr_llmc, qanswer_llmc=None, qgen_temperature=0.0, critic_temperature=0.0, qgen_instr_gen_temperature=0.0, qanswer_temperature=0.0, criterion='gini', max_depth=None, max_node_width=3, min_samples_leaf=1, llm_semaphore_limit=5, min_question_candidates=3, max_question_candidates=10, expert_advice=None, n_samples_as_context=30, class_ratio='balanced', use_critic=False, save_path=None, name=None, random_state=None)
Bases: object
LLM-based decision tree classifier.
Note that GPTree automatically saves the tree after each node is built.
- Parameters:
qgen_llmc (List[Union[AnthropicChoice, GoogleChoice, OpenAIChoice, XAIChoice, AnthropicChoiceDict, GoogleChoiceDict, OpenAIChoiceDict, XAIChoiceDict]]) – LLMs to use for question generation, in priority order.
critic_llmc (List[Union[AnthropicChoice, GoogleChoice, OpenAIChoice, XAIChoice, AnthropicChoiceDict, GoogleChoiceDict, OpenAIChoiceDict, XAIChoiceDict]]) – LLMs to use for question critique, in priority order.
qgen_instr_llmc (List[Union[AnthropicChoice, GoogleChoice, OpenAIChoice, XAIChoice, AnthropicChoiceDict, GoogleChoiceDict, OpenAIChoiceDict, XAIChoiceDict]]) – LLMs to use for generating the question generation instructions.
qanswer_llmc (Optional[List[Union[AnthropicChoice, GoogleChoice, OpenAIChoice, XAIChoice, AnthropicChoiceDict, GoogleChoiceDict, OpenAIChoiceDict, XAIChoiceDict]]]) – LLMs to use for answering questions, in priority order. If None, qgen_llmc is used.
qgen_temperature (float) – Sampling temperature for question generation.
critic_temperature (float) – Sampling temperature for critique.
qgen_instr_gen_temperature (float) – Sampling temperature for generating instructions.
qanswer_temperature (float) – Sampling temperature for answering questions.
criterion (Literal['gini']) – Splitting criterion. Currently only “gini” is supported.
max_depth (int | None) – Maximum tree depth. If None, the tree grows until its leaves are pure or min_samples_leaf prevents further splits.
max_node_width (int) – Maximum number of children per node.
min_samples_leaf (int) – Minimum number of samples per leaf.
llm_semaphore_limit (int) – Maximum number of concurrent LLM calls.
min_question_candidates (int) – Minimum number of candidate questions per node.
max_question_candidates (int) – Maximum number of candidate questions per node (at most 15).
expert_advice (str | None) – Human-provided hints for question generation.
n_samples_as_context (int) – Number of samples used as context during question generation.
class_ratio (Union[Dict[str, int], Literal['balanced']]) – Class sampling strategy: “balanced”, or a dict of per-class ratios.
use_critic (bool) – Whether to critique generated questions.
save_path (str | PathLike[str] | None) – Directory in which to save checkpoints and models.
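The criterion='gini' option refers to Gini impurity, 1 - Σ p_k², where p_k is the proportion of class k at a node. As a sketch (not GPTree's internal code), the impurity of a node and the weighted impurity of a multiway split, such as one with up to max_node_width children, can be computed as follows; gini and weighted_gini are hypothetical helper names.

```python
from collections import Counter
from typing import Hashable, Sequence


def gini(labels: Sequence[Hashable]) -> float:
    """Gini impurity: 1 - sum_k p_k^2 over the class proportions p_k."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def weighted_gini(children: Sequence[Sequence[Hashable]]) -> float:
    """Impurity of a split: each child's Gini weighted by its share of samples."""
    total = sum(len(child) for child in children)
    if total == 0:
        return 0.0
    return sum(len(child) / total * gini(child) for child in children)


parent = ["spam", "spam", "ham", "ham"]
print(gini(parent))  # 0.5
# A split that separates the classes perfectly has zero weighted impurity.
print(weighted_gini([["spam", "spam"], ["ham", "ham"]]))  # 0.0
```

A split is worth keeping when its weighted impurity is lower than the parent's; a leaf that is pure has impurity 0.0, which matches the default gini=0.0 on Node.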
- classmethod load(path)
Load a GPTree from saved state.
- advice(advice)
Set context or advice for question generation.
- async fit(X=None, y=None, *, copy_data=True, reset=False)
Train or resume tree construction as an async generator.
- Parameters:
- Yields:
Node – Updated nodes during tree construction.
- Raises:
ValueError – If data requirements aren’t met or reset is used incorrectly.
- Return type:
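Because fit is documented as an async generator that yields Node objects, it is consumed with async for rather than awaited directly. The sketch below uses a hypothetical stand-in, fake_fit, in place of a configured GPTree so that it runs on its own; with a real tree the loop body would be the same (async for node in tree.fit(X, y): ...).

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncGenerator, List


@dataclass
class FakeNode:
    """Hypothetical stand-in for GPTree's Node, with only id and label."""
    id: int
    label: str


async def fake_fit() -> AsyncGenerator[FakeNode, None]:
    """Hypothetical stand-in for GPTree.fit: yields nodes as they are built."""
    for node_id, label in enumerate(["root", "left", "right"]):
        await asyncio.sleep(0)  # simulate waiting on an LLM call
        yield FakeNode(node_id, label)


async def train() -> List[FakeNode]:
    built: List[FakeNode] = []
    # Same pattern as `async for node in tree.fit(X, y): ...`
    async for node in fake_fit():
        print(f"built node {node.id}: {node.label}")
        built.append(node)
    return built


nodes = asyncio.run(train())
```

Iterating lets the caller report progress or stop early while the tree is still being built.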
- async predict(samples)
Predict labels for samples with concurrent processing.
- Parameters:
samples (DataFrame) – DataFrame with a single column matching the training data format.
- Yields:
Tuple of (sample_index, question, answer, node_id, token_usage).
- Return type:
AsyncGenerator[Tuple[int, str, str, int, TokenCounter], None]
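predict processes samples concurrently, and the constructor's llm_semaphore_limit caps concurrent LLM calls. A common way to combine the two in asyncio is an asyncio.Semaphore plus asyncio.as_completed, yielding results as they finish. This is a generic sketch of that pattern, not GPTree's implementation; answer_question is a hypothetical placeholder for an LLM call. Because results arrive in completion order, each one is tagged with its sample index, much as predict's yielded tuples start with sample_index.

```python
import asyncio
from typing import AsyncGenerator, List, Tuple


async def answer_question(sample: str, sem: asyncio.Semaphore) -> str:
    """Hypothetical LLM call; the semaphore bounds how many run at once."""
    async with sem:
        await asyncio.sleep(0.01)  # stand-in for network latency
        return sample.upper()


async def predict_all(
    samples: List[str], limit: int = 5
) -> AsyncGenerator[Tuple[int, str], None]:
    sem = asyncio.Semaphore(limit)

    async def run(idx: int, sample: str) -> Tuple[int, str]:
        return idx, await answer_question(sample, sem)

    tasks = [asyncio.create_task(run(i, s)) for i, s in enumerate(samples)]
    # Yield in completion order, tagging each result with its sample index.
    for fut in asyncio.as_completed(tasks):
        yield await fut


async def main() -> List[Tuple[int, str]]:
    return [item async for item in predict_all(["a", "b", "c"], limit=2)]


results = sorted(asyncio.run(main()))
print(results)  # [(0, 'A'), (1, 'B'), (2, 'C')]
```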
- prune_tree(node_id)
Prune the tree from the node with the given ID.
- Parameters:
node_id (int) – The ID of the node to prune.
- Raises:
ValueError – If the node with the given ID is not found in the tree.
ValueError – If the node with the given ID is a leaf node.
- Return type:
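Pruning removes every descendant of a node and turns that node into a leaf. The sketch below models the tree as a hypothetical dict from node ID to child IDs (GPTree's own Node objects carry a children field instead) and mirrors the two documented ValueError cases.

```python
from typing import Dict, List


def prune(tree: Dict[int, List[int]], node_id: int) -> List[int]:
    """Remove all descendants of node_id; return the removed IDs."""
    if node_id not in tree:
        raise ValueError(f"Node {node_id} not found in the tree")
    if not tree[node_id]:
        raise ValueError(f"Node {node_id} is a leaf node")
    removed: List[int] = []
    stack = list(tree[node_id])
    while stack:  # iterative depth-first walk over the subtree
        child = stack.pop()
        removed.append(child)
        stack.extend(tree.pop(child))
    tree[node_id] = []  # node_id is now a leaf
    return removed


tree = {0: [1, 2], 1: [3, 4], 2: [], 3: [], 4: []}
removed = prune(tree, 1)
print(sorted(removed))  # [3, 4]
print(tree)             # {0: [1, 2], 1: [], 2: []}
```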
- async resume_fit(node_id)
Enqueue a node to resume (re)building its subtree from current data.
- Typical usage:
1. Call prune_tree(node_id) to clear the subtree.
2. Iterate over resume_fit(node_id) with async for to continue building from that node.
- Parameters:
node_id (int) – The ID of the node to resume building from.
- Yields:
Node – Updated nodes during tree construction.
- Raises:
ValueError – If the node with the given ID is not found in the tree.
ValueError – If the tree has no training data loaded.
- Return type:
- save(dir_path=None, for_production=False)
Save the model configuration as JSON and the dataframes as Parquet files in a directory.
If dir_path is None, the directory <self.save_path>/<self.name> is used. If for_production is True, the training dataframe is not saved.
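The default directory and the for_production switch can be illustrated with the standard library alone. This is a sketch under stated assumptions, not GPTree's save method: save_model is a hypothetical helper, the file names config.json and train.csv are hypothetical, and CSV stands in for Parquet so that no extra dependency (such as pyarrow) is needed.

```python
import csv
import json
import os
import tempfile
from pathlib import Path
from typing import Dict, List, Optional, Union

PathArg = Union[str, "os.PathLike[str]"]


def save_model(
    config: Dict,
    train_rows: List[Dict[str, str]],
    save_path: PathArg,
    name: str,
    dir_path: Optional[PathArg] = None,
    for_production: bool = False,
) -> Path:
    # Default location mirrors the documented <save_path>/<name> rule.
    out = Path(dir_path) if dir_path is not None else Path(save_path) / name
    out.mkdir(parents=True, exist_ok=True)
    (out / "config.json").write_text(json.dumps(config, indent=2))
    # Production saves skip the training data (CSV here instead of Parquet).
    if not for_production and train_rows:
        with open(out / "train.csv", "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(train_rows[0]))
            writer.writeheader()
            writer.writerows(train_rows)
    return out


with tempfile.TemporaryDirectory() as tmp:
    rows = [{"text": "free money", "label": "spam"}]
    dev = save_model({"max_depth": 3}, rows, tmp, "dev")
    prod = save_model({"max_depth": 3}, rows, tmp, "prod", for_production=True)
    dev_files = sorted(p.name for p in dev.iterdir())
    prod_files = sorted(p.name for p in prod.iterdir())
    dev_dir_name = dev.name

print(dev_files)   # ['config.json', 'train.csv']
print(prod_files)  # ['config.json']
```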
- async set_tasks(instructions_template=None, task_description=None)
Initialize the question generation instructions template.
This sets the task description for the tree: it either sets a custom template or uses an LLM to generate one from the task description. For most users, LLM generation is recommended over a custom template.
- Parameters:
- Return type:
- Returns:
The question generation instructions template.
- Raises:
ValueError – If the template is missing a required tag or generation fails.
AssertionError – If both parameters are None.
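The documented error behavior of set_tasks can be sketched as below. Several details are assumptions: the required tag "{task}" is hypothetical (the documentation says a required tag exists but does not name it here), the generate callback stands in for the real LLM call, and this sketch lets a custom template take precedence when both arguments are given.

```python
import asyncio
from typing import Awaitable, Callable, Optional

REQUIRED_TAG = "{task}"  # hypothetical placeholder name


async def set_tasks_sketch(
    generate: Callable[[str], Awaitable[str]],
    instructions_template: Optional[str] = None,
    task_description: Optional[str] = None,
) -> str:
    assert instructions_template is not None or task_description is not None, (
        "Provide instructions_template or task_description"
    )
    if instructions_template is None:
        try:
            instructions_template = await generate(task_description)
        except Exception as exc:  # surface LLM failures as ValueError
            raise ValueError("Template generation failed") from exc
    if REQUIRED_TAG not in instructions_template:
        raise ValueError(f"Template is missing required tag {REQUIRED_TAG}")
    return instructions_template


async def fake_llm(description: str) -> str:
    """Hypothetical LLM stand-in that drafts a template from the description."""
    return f"Task: {description}\nWrite questions for {REQUIRED_TAG}."


template = asyncio.run(set_tasks_sketch(fake_llm, task_description="spam detection"))
print(template.splitlines()[0])  # Task: spam detection
```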
- view_node(node_id, format='png', add_all_questions=False, truncate_length=140)
Render the subtree rooted at node_id as PNG or SVG bytes.
- Parameters:
node_id (int) – Root node ID for the subtree visualization.
format (Literal['png', 'svg']) – Output image format (‘png’ or ‘svg’).
add_all_questions (bool) – Include all generated questions in node display.
truncate_length (int | None) – Maximum text length before truncation. None disables truncation.
- Return type:
- Returns:
Rendered subtree image data as bytes.
- Raises:
ValueError – If node_id doesn’t exist in the tree.
ImportError – If graphviz package is not installed.
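truncate_length caps the text shown in each rendered node, and None disables truncation. The minimal sketch below illustrates that rule; the helper name truncate and the choice to end truncated text with an ellipsis are assumptions, not documented GPTree behavior.

```python
from typing import Optional


def truncate(text: str, truncate_length: Optional[int] = 140) -> str:
    """Shorten text for display; None leaves it untouched."""
    if truncate_length is None or len(text) <= truncate_length:
        return text
    if truncate_length < 1:
        return ""
    # Keep truncate_length - 1 characters and append an ellipsis,
    # so the result is exactly truncate_length characters long.
    return text[: truncate_length - 1] + "…"


question = "Does the message ask the recipient to send money?"
print(truncate(question, 20))    # Does the message as…
print(truncate(question, None))  # prints the full question
```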
- property question_gen_instructions_template: str | None
Get the question generation instructions template.
- property token_usage: TokenCounter
Get the token counter for the GPTree.
- class Node(id, label, question=None, questions=<factory>, cumulative_memory=None, split_ratios=None, gini=0.0, class_distribution=<factory>, children=<factory>, parent_id=None)
Bases: object
A Node represents a decision point in GPTree.
- Parameters:
- classmethod from_dict(d)
Convert a dictionary to a node.
- question: NodeQuestion | None
- questions: List[NodeQuestion]
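The <factory> defaults in the Node signature are how Python's dataclasses render fields declared with field(default_factory=...), so each node gets its own fresh list or dict rather than sharing one. The sketch below is a simplified, hypothetical mirror of such a class with a from_dict round trip; it omits cumulative_memory and split_ratios, and the field types (plain strings instead of NodeQuestion, integer IDs in children) are assumptions.

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class SketchNode:
    """Simplified, hypothetical mirror of GPTree's Node fields."""
    id: int
    label: str
    question: Optional[str] = None
    questions: List[str] = field(default_factory=list)
    gini: float = 0.0
    class_distribution: Dict[str, int] = field(default_factory=dict)
    children: List[int] = field(default_factory=list)
    parent_id: Optional[int] = None

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "SketchNode":
        # Keep only keys that are declared fields; unknown keys are ignored.
        known = {k: v for k, v in d.items() if k in cls.__dataclass_fields__}
        return cls(**known)


a = SketchNode(id=0, label="root")
b = SketchNode(id=1, label="leaf", parent_id=0)
a.children.append(1)
print(b.children)  # [] (each node has its own list)
print(SketchNode.from_dict(asdict(a)) == a)  # True
```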