# Source: OpenCompass/opencompass/datasets/musr/tree.py
# Retrieved: 2025-05-27 08:39:17 +00:00
#
# 748 lines
# 28 KiB
# Python

# flake8: noqa: E501
"""WARNING (or more like an aggressive note).
A lot of functionality was implemented here for earlier experiments, most of which is not used. We have left it here
for backwards compatibility with the current dataset as well as because why not.
ALSO NOTE:
This file was created to have no dependencies on anything in the repo for a reason. You can copy this file into your
own project and use the classes to parse/visualize/edit the logic trees in the dataset or create your own.
FINAL NOTE:
See examples of how to create LogicNodes and LogicTrees in the __main__ part of the file.
"""
import random
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List
import numpy as np
class LogicNodeOperatorType:
    """Operators describing how a deduction combines its child nodes.

    ``CHOOSE`` is a placeholder: ``LogicTree.populate`` replaces it with a
    randomly sampled ``AND`` or ``OR``.
    """

    # All children are combined conjunctively.
    AND = 'and'
    # Children are combined disjunctively.
    OR = 'or'
    # Resolved to AND/OR at populate time.
    CHOOSE = 'choose'
class LogicNodeFactType:
    """Whether a node is stated in the story or is left-unsaid commonsense."""

    # The fact is mentioned explicitly in the story.
    EXPLICIT = 'explicit'
    # The fact is commonsense knowledge that the story leaves unsaid.
    COMMONSENSE = 'commonsense'
class LogicNodeConstraints:
    """Constraints that can be attached to a node's children.

    Useful for things like children = ['X is the murderer', 'Y is the
    murderer', 'Z is the murderer']. This structure is no longer used.
    """

    # At most one of the children may hold.
    ONLY_ONE_CAN_BE_TRUE = 'Only one child can be true'
class LogicNodeDeductionType:
    """Kind of deduction a node represents (not used currently).

    ``CHOOSE`` is a placeholder that ``LogicTree.populate`` resolves by
    sampling from the tree's ``deduction_type_sample_rate``.
    """

    SYLLOGISM = 'syllogism'
    TEMPORAL = 'temporal'
    SPATIAL = 'spatial'
    # Resolved to a concrete deduction type (or None) at populate time.
    CHOOSE = 'choose'
class LogicNode:
    """A LogicNode is a tree primitive.

    It is either a deduction or a leaf fact. Leaf facts are the ones that we
    use in story generation (if they are explicit facts and not commonsense).
    """

    value: str
    children: List['LogicNode']
    fact_type: str
    operator: str
    constraints: List[str]
    deduction_type: str
    prunable: bool
    can_be_leaf: bool

    def __init__(
        self,
        value: str = '',
        children: List['LogicNode'] = None,
        operator: str = LogicNodeOperatorType.OR,
        fact_type: str = LogicNodeFactType.EXPLICIT,
        constraints: List[str] = (),
        deduction_type: str = None,
        prunable: bool = True,
        can_be_leaf: bool = False,
        frozen: bool = False,
    ):
        """
        :param value: Content for this specific node (also the deduction of the children).
        :param children: The children for this node.
        :param operator: Should the children be "And"ed or "Or"ed to create the deduction (the content of this node).
        :param fact_type: Explicit or commonsense
        :param constraints: Not used anymore (see LogicNodeConstraints)
        :param deduction_type: Not used anymore (see LogicNodeDeductionType)
        :param prunable: Can this node be removed from the tree (we don't prune in our datasets)
        :param can_be_leaf: Can this node be a leaf node (usually false for nodes that you are injecting manually)
        :param frozen: Should we add/prune children in the populate function (if frozen, no children will be added or removed, but the children may have children appended/pruned from them).
        """
        self.value = value
        # A fresh list per instance; the None default avoids sharing one list
        # between nodes.
        if children is None:
            children = []
        # Goes through the property setter, which also sets child.parent.
        self.children = children
        self.operator = operator
        self.fact_type = fact_type
        self.constraints = constraints
        self.deduction_type = deduction_type
        self.prunable = prunable
        self.can_be_leaf = can_be_leaf
        self.frozen = frozen
        self.parent = None

    @property
    def children(self):
        """The list of child nodes."""
        return self._children

    @children.setter
    def children(self, children: List['LogicNode']):
        """Replace the children and point each child's ``parent`` at this node.

        Mutating the returned list in place (``append``/``extend``) does not
        go through this setter, so ``parent`` is not updated in that case.
        """
        self._children = children
        for child in self._children:
            child.parent = self

    def __str__(self):
        """Return a one-line ``' | '``-separated summary of the node."""
        has_children = len(self.children) > 0
        parts = []
        if self.value:
            parts.append(self.value)
        # Deductions show their operator, leaves show their fact type.
        parts.append(self.operator if has_children else self.fact_type)
        if self.deduction_type:
            parts.append(self.deduction_type)
        if len(self.constraints) > 0:
            # Constraints are plain strings (see LogicNodeConstraints); fall
            # back to ``.value`` only for objects that expose it.
            parts.append(', '.join(
                str(getattr(c, 'value', c)) for c in self.constraints))
        if has_children:
            parts.append(f'children: {len(self.children)}')
        return ' | '.join(parts)

    def __repr__(self):
        return str(self)

    def to_json(self):
        """Return a dict representation of this node and its subtree.

        NOTE: ``frozen`` and ``parent`` are not part of the output.
        """
        return {
            'value': self.value,
            'children': [x.to_json() for x in self.children],
            'fact_type': self.fact_type,
            'operator': self.operator,
            'constraints': self.constraints,
            'deduction_type': self.deduction_type,
            'prunable': self.prunable,
            'can_be_leaf': self.can_be_leaf
        }

    @classmethod
    def from_json(cls, js):
        """Build a node (and its subtree) from a ``to_json``-style dict.

        The input dict is left untouched so the same payload can be parsed
        more than once.

        :param js: Mapping of constructor keyword arguments; ``children`` is a
            list of nested mappings of the same form.
        """
        kwargs = dict(js)
        kwargs['children'] = [
            cls.from_json(child) for child in (js.get('children') or [])
        ]
        return cls(**kwargs)
class LogicTree:
    """Main datastructure used when creating a MuSR example.

    It's basically a standard tree with some parameters controlling the shape.
    """

    nodes: List[LogicNode]
    chance_of_or: float
    chance_of_cs_fact: float
    depth: int
    chance_to_prune: float
    chance_to_prune_all: float
    bf_factor: Dict[int, float]
    deduction_type_sample_rate: Dict[LogicNodeDeductionType, float]
    root_structure: List[List[LogicNode]] = ()

    def __init__(self,
                 chance_of_or: float = 0.3,
                 chance_of_cs_fact: float = 0.1,
                 depth: int = 2,
                 chance_to_prune: float = 0.6,
                 chance_to_prune_all: float = 0.2,
                 bf_factor: Dict[int, float] = None,
                 deduction_type_sample_rate: Dict[LogicNodeDeductionType,
                                                  float] = None,
                 enforce_cs_fact_per_level: bool = False,
                 root_structure: List[Any] = (),
                 nodes: List[LogicNode] = (),
                 populate: bool = True,
                 prune: bool = True):
        """
        :param chance_of_or: (not used) how often should a node with children be an OR
        :param chance_of_cs_fact: (not used) how often should there be a commonsense node
        :param depth: How deep should a tree go
        :param chance_to_prune: Percentage chance of pruning a node
        :param chance_to_prune_all: Percentage chance of pruning all children from a node.
        :param bf_factor: Branching factor (dictionary of percentages {1: 0.33, 2:0.33, 3:0.33} for example.
        :param deduction_type_sample_rate: (not used, see bf_factor and LogicNodeDeductionType)
        :param enforce_cs_fact_per_level: Enforce 1 commonsense fact per level in the tree (we use this instead of chance_of_cs_fact)
        :param root_structure: List of LogicNodes to build off of.
        :param nodes: List of LogicNodes to define the LogicTree on (we will not populate/prune the tree if this is filled)
        :param populate: Should we populate children for the tree according to the other parameters?
        :param prune: Should we prune the children for the tree according to the other parameters?
        """
        self.chance_of_or = chance_of_or
        self.chance_of_cs_fact = chance_of_cs_fact
        self.depth = depth
        self.chance_to_prune = chance_to_prune
        self.chance_to_prune_all = chance_to_prune_all
        self.bf_factor = bf_factor
        self.enforce_cs_fact_per_level = enforce_cs_fact_per_level

        if not bf_factor:
            self.bf_factor = {2: 0.8, 3: 0.2}
        if not deduction_type_sample_rate:
            deduction_type_sample_rate = {
                LogicNodeDeductionType.SYLLOGISM: 1.0
            }
        self.deduction_type_sample_rate = deduction_type_sample_rate

        self.root_structure = root_structure

        if nodes is not None and len(nodes) > 0:
            # An explicit node list defines the tree as-is: it is neither
            # populated nor pruned (this is what from_json relies on).
            self.nodes = nodes
        else:
            if root_structure is not None and len(root_structure) > 0:
                self.nodes = root_structure
            else:
                self.nodes = [
                    LogicNode('root', operator=LogicNodeOperatorType.AND)
                ]

            # Populate every root before pruning any of them; this keeps the
            # order in which random numbers are consumed stable.
            if populate:
                for root in self.nodes:
                    self.populate(root, 1)
            if prune:
                for root in self.nodes:
                    self.prune(root, 1)

    def __str__(self):
        return self.print_tree()

    def get_facts(self,
                  include_cs: bool = False,
                  include_deductions_from_level: int = -1,
                  no_facts_after_depth: int = -1):
        """Get a list of LogicNodes from the tree. By default, you will get the
        explicit leaf nodes.

        The returned nodes are deep copies, in depth-first order (parent
        before its children).

        :param include_cs: Include the commonsense nodes from all levels.
        :param include_deductions_from_level: Include any intermediate
            deduction nodes from the specified level and deeper.
        :param no_facts_after_depth: Essentially treat the deductions at the
            specified depth as leaf nodes.
        """

        def recurse_facts(_node: LogicNode,
                          depth: int = 0) -> List[LogicNode]:
            # Work on a copy so that truncating children below never touches
            # the real tree.
            node = deepcopy(_node)
            if depth >= no_facts_after_depth and no_facts_after_depth > -1:
                node.children = []

            is_leaf = len(node.children) == 0
            facts = []
            if is_leaf and node.fact_type == LogicNodeFactType.EXPLICIT:
                facts.append(node)
            if is_leaf and include_cs and \
                    node.fact_type == LogicNodeFactType.COMMONSENSE:
                facts.append(node)
            if not is_leaf and -1 < include_deductions_from_level <= depth:
                facts.append(node)

            for child in node.children:
                facts.extend(recurse_facts(child, depth + 1))

            # De-duplicate with the same hash/eq semantics a set would use,
            # but keep a deterministic (insertion) order.
            return list(dict.fromkeys(facts))

        facts = []
        for n in self.nodes:
            facts.extend(recurse_facts(n))
        return facts

    def print_tree(self, node=None, level=0):
        """Deprecated (not used)"""
        if node is None:
            node = self.nodes[0]
        line = '-' * level * 4 + str(node) + (' | ' + str(node.operator) if
                                              len(node.children) > 0 else '')
        for child in node.children:
            line += '\n' + self.print_tree(child, level + 1)
        return line

    def print_for_gpt(self,
                      node=None,
                      level=0,
                      pad_char=' ',
                      pad_space=4,
                      print_forward=True,
                      print_conjection_types: bool = False,
                      print_reasoning_types: bool = False,
                      ignore_value_after_depth: int = -1,
                      print_only_nodes_with_value: bool = False):
        """Complex print function. We often use it as
        print_for_gpt(pad_space=1, pad_char='> ')

        However, more complex arguments can be used to control what is printed.
        This returns a string that must be printed (don't be confused by the
        method name.)

        :param node: Start at a specific node.
        :param level: Controls how much tabbing is done when printing the
            current node.
        :param pad_char: Char to use that specifies depth ('> ' at depth 3 will
            look like '> > > ' if you have pad_space equal to 1 for example)
        :param pad_space: How many spaces to include between pad_chars
        :param print_forward: Print the tree with parent nodes first.
        :param print_conjection_types: Print the Ands and Ors per deduction
            (not used)
        :param print_reasoning_types: Print the deduction types (not used)
        :param ignore_value_after_depth: Ignore content of the nodes once a
            depth is met
        :param print_only_nodes_with_value: Ignore nodes without content.
        """
        if node is None:
            node = self.nodes[0]

        def render_child(child):
            # Forward every formatting option so that the whole subtree is
            # rendered consistently with the root.
            return self.print_for_gpt(
                child,
                level + 1,
                pad_char=pad_char,
                pad_space=pad_space,
                print_forward=print_forward,
                print_conjection_types=print_conjection_types,
                print_reasoning_types=print_reasoning_types,
                ignore_value_after_depth=ignore_value_after_depth,
                print_only_nodes_with_value=print_only_nodes_with_value)

        line = ''

        # Children first when printing "backwards" (leaves before parents).
        if not print_forward:
            for child in node.children:
                v = render_child(child)
                if v != '':
                    line += v + '\n'

        ignore_val = ignore_value_after_depth > -1 and ignore_value_after_depth < level
        ignore_line = print_only_nodes_with_value and node.value == ''

        if ignore_line:
            line_val = ''
        else:
            line_val = (node.value + ' | ' if node.value != '' and not ignore_val else '') + (
                ('Fact From Story' if node.fact_type == LogicNodeFactType.EXPLICIT else 'Commonsense Knowledge') \
                if len(node.children) == 0 else 'Deduced Fact')

            # The root is always labelled as the conclusion and always shows
            # its value.
            if level == 0:
                line_val = (node.value + ' | ' if node.value != '' else
                            '') + 'Deduced Root Conclusion'

            if len(node.children) > 0 and (print_conjection_types
                                           or print_reasoning_types):
                if print_conjection_types:
                    line_val += f' ({node.operator}'
                else:
                    line_val += '('
                if node.deduction_type and print_reasoning_types:
                    line_val += f' | {node.deduction_type})'
                else:
                    line_val += ')'

            if len(node.constraints) > 0:
                cnsts = ', '.join([str(x) for x in node.constraints])
                line_val += f' constraints: [{cnsts}]'

            line += pad_char * level * pad_space + line_val

        if print_forward:
            for child in node.children:
                v = render_child(child)
                if v != '':
                    line += '\n' + v

        return line

    def _sample_deduction_type(self):
        """Draw one deduction type according to deduction_type_sample_rate."""
        return random.choices(list(self.deduction_type_sample_rate.keys()),
                              list(self.deduction_type_sample_rate.values()),
                              k=1)[0]

    def populate(self, node: LogicNode, current_depth: int = 1):
        """Randomly add children below ``node`` down to ``self.depth``.

        ``CHOOSE`` operators / deduction types on ``node`` are resolved first.
        Unless the node is frozen, a branching factor is sampled from
        ``bf_factor`` and the missing number of children is created. The
        method then recurses into every non-commonsense child.

        :param node: Node to populate (modified in place).
        :param current_depth: Depth of ``node`` (roots are at depth 1).
        """
        if node.operator == LogicNodeOperatorType.CHOOSE:
            node.operator = LogicNodeOperatorType.OR \
                if random.random() < self.chance_of_or else LogicNodeOperatorType.AND

        if node.deduction_type == LogicNodeDeductionType.CHOOSE:
            if node.operator != LogicNodeOperatorType.AND:
                node.deduction_type = None
            else:
                node.deduction_type = self._sample_deduction_type()

        if not node.frozen:
            # Number of children still missing to reach the sampled branching
            # factor (never negative).
            bf = max(
                0,
                random.choices(list(self.bf_factor.keys()),
                               list(self.bf_factor.values()),
                               k=1)[0] - len(node.children))

            if bf > 0:
                new_nodes = []
                one_fact_is_cs = False

                for _ in range(bf):
                    roll_for_or = random.random()
                    # At most one sampled commonsense fact per level.
                    fact_type = LogicNodeFactType.COMMONSENSE \
                        if random.random() < self.chance_of_cs_fact and not one_fact_is_cs else \
                        LogicNodeFactType.EXPLICIT

                    if roll_for_or > self.chance_of_or and \
                            current_depth < self.depth and \
                            not fact_type == LogicNodeFactType.COMMONSENSE:
                        new_nodes.append(
                            LogicNode(
                                '',
                                operator=LogicNodeOperatorType.AND,
                                fact_type=fact_type,
                                deduction_type=self._sample_deduction_type(),
                                prunable=True,
                                can_be_leaf=True,
                            ))
                    else:
                        new_nodes.append(
                            LogicNode('',
                                      operator=LogicNodeOperatorType.OR,
                                      fact_type=fact_type,
                                      prunable=True,
                                      can_be_leaf=True))

                    if fact_type == LogicNodeFactType.COMMONSENSE:
                        # A commonsense child turns the parent into an AND
                        # deduction.
                        node.operator = LogicNodeOperatorType.AND
                        if not node.deduction_type:
                            node.deduction_type = self._sample_deduction_type()
                        one_fact_is_cs = True

                if not one_fact_is_cs and self.enforce_cs_fact_per_level:
                    new_nodes.append(
                        LogicNode('',
                                  operator=LogicNodeOperatorType.OR,
                                  fact_type=LogicNodeFactType.COMMONSENSE,
                                  prunable=False,
                                  can_be_leaf=True))

                # Extend in place (callers may hold a reference to the list),
                # then set the parent links the property setter would have set.
                node.children.extend(new_nodes)
                for new_node in new_nodes:
                    new_node.parent = node

        if current_depth < self.depth:
            for child in node.children:
                # Commonsense facts always stay leaves.
                if child.fact_type == LogicNodeFactType.COMMONSENSE:
                    continue
                self.populate(child, current_depth + 1)

    def prune(self, node: LogicNode, current_depth: int = 1):
        """Randomly remove prunable children below ``node``.

        Below the root level, a node that ``can_be_leaf`` may lose all of its
        children (``chance_to_prune_all``). Otherwise an OR node keeps at
        least one prunable child and an AND node at least two; each other
        candidate is removed with probability ``chance_to_prune``.

        :param node: Node to prune (modified in place).
        :param current_depth: Depth of ``node`` (roots are at depth 1).
        """
        to_prune = []

        if current_depth > 1 and node.can_be_leaf:
            if random.random() < self.chance_to_prune_all:
                node.children = []
                return

        prunable = [x for x in node.children if x.prunable]
        if (len(prunable) > 1 and node.operator == LogicNodeOperatorType.OR or \
                len(prunable) > 2 and node.operator == LogicNodeOperatorType.AND) and \
                current_depth <= self.depth:

            if node.prunable:
                keep = 1 if node.operator == LogicNodeOperatorType.OR else 2
                for n in random.sample(prunable, len(prunable) - keep):
                    roll_to_prune = random.random()
                    if roll_to_prune < self.chance_to_prune:
                        to_prune.append(n)

        node.children = [x for x in node.children if x not in to_prune]
        for n in node.children:
            self.prune(n, current_depth + 1)

    def to_json(self):
        """Return a dict with the tree parameters and serialized nodes.

        NOTE: ``chance_of_cs_fact`` and ``enforce_cs_fact_per_level`` are not
        part of the output.
        """
        args = {
            'chance_of_or': self.chance_of_or,
            'depth': self.depth,
            'chance_to_prune': self.chance_to_prune,
            'chance_to_prune_all': self.chance_to_prune_all,
            'bf_factor': self.bf_factor,
            'deduction_type_sample_rate': self.deduction_type_sample_rate,
            'root_structure':
            [x.to_json() for x in (self.root_structure or ())],
            'nodes': [x.to_json() for x in self.nodes]
        }
        return args

    @classmethod
    def from_json(cls, _js):
        """Build a tree from a ``to_json``-style dict.

        The input is deep-copied first, so it is never modified.

        :param _js: Mapping of constructor keyword arguments; ``nodes`` and
            ``root_structure`` hold serialized LogicNodes.
        """
        js = deepcopy(_js)
        js['nodes'] = [LogicNode.from_json(x) for x in js.get('nodes', [])]
        js['root_structure'] = [
            LogicNode.from_json(x) for x in js.get('root_structure', [])
        ]
        return cls(**js)
if __name__ == '__main__':
    # EXAMPLE USES.
    # Each helper builds a hand-written root structure, lets LogicTree fill in
    # the remaining nodes and prints the rendered tree.

    def tv_scene_ex():
        scene_beats = [
            'Bob is sad.',
            'John now hates Bob.',
            'Bob bought a car.',
            'Bob wanted to be happy.',
        ]
        scene_root = LogicNode('A good drama tv scene',
                               operator=LogicNodeOperatorType.OR,
                               prunable=False,
                               can_be_leaf=False,
                               frozen=True)
        scene_root.children = [
            LogicNode(beat,
                      operator=LogicNodeOperatorType.CHOOSE,
                      prunable=True,
                      can_be_leaf=False) for beat in scene_beats
        ]

        tree = LogicTree(depth=4,
                         root_structure=[scene_root],
                         bf_factor={
                             1: 0.5,
                             2: 0.5
                         },
                         chance_of_or=0.0,
                         chance_of_cs_fact=0.0,
                         chance_to_prune_all=0.5,
                         chance_to_prune=0.5,
                         enforce_cs_fact_per_level=True)
        print(tree.print_for_gpt(pad_space=1, pad_char='- '))

    def eb_ex():

        def frozen_leaf(text):
            # A fixed fact that may be a leaf and is never pruned or extended.
            return LogicNode(text,
                             operator=LogicNodeOperatorType.CHOOSE,
                             prunable=False,
                             can_be_leaf=True,
                             frozen=True)

        blocks_sunlight = LogicNode('Eruptions block sunlight.',
                                    operator=LogicNodeOperatorType.CHOOSE,
                                    prunable=False,
                                    can_be_leaf=False,
                                    frozen=True)
        blocks_sunlight.children = [
            frozen_leaf('Eruptions produce ash clouds.'),
            frozen_leaf('Ash blocks sunlight.'),
        ]

        plants_die = LogicNode('Eruptions can cause plants to die.',
                               operator=LogicNodeOperatorType.CHOOSE,
                               prunable=True,
                               can_be_leaf=False,
                               frozen=True)
        plants_die.children = [
            blocks_sunlight,
            frozen_leaf('Producers will die without sunlight.'),
        ]

        wrapper = LogicNode('',
                            operator=LogicNodeOperatorType.AND,
                            prunable=False,
                            can_be_leaf=False)
        wrapper.children = [plants_die]

        top = LogicNode('',
                        operator=LogicNodeOperatorType.CHOOSE,
                        prunable=False,
                        can_be_leaf=False)
        top.children = [wrapper]

        tree = LogicTree(depth=5,
                         root_structure=[top],
                         bf_factor={
                             1: 0.3,
                             2: 0.7
                         },
                         chance_of_or=0.0,
                         chance_of_cs_fact=0.0,
                         chance_to_prune_all=0.0,
                         chance_to_prune=0.0,
                         enforce_cs_fact_per_level=True)
        print(tree.print_for_gpt(pad_space=1, pad_char='- '))

    def murder_mystery_ex():
        killer = LogicNode(
            'Killer',
            operator=LogicNodeOperatorType.OR,
            constraints=[LogicNodeConstraints.ONLY_ONE_CAN_BE_TRUE],
            prunable=False,
            can_be_leaf=False,
            frozen=True)

        suspects = []
        for idx in range(1):
            suspect = LogicNode(f'Murderer Suspect {idx + 1}',
                                operator=LogicNodeOperatorType.AND,
                                prunable=False,
                                can_be_leaf=False,
                                frozen=True)
            suspect.children = [
                LogicNode(claim,
                          operator=LogicNodeOperatorType.CHOOSE,
                          prunable=True,
                          can_be_leaf=False)
                for claim in ('Suspect has means', 'Suspect has motive',
                              'Suspect has opportunity')
            ]
            suspects.append(suspect)
        killer.children = suspects

        tree = LogicTree(depth=4,
                         root_structure=[killer],
                         bf_factor={
                             1: 0.5,
                             2: 0.5
                         },
                         chance_of_or=0.0,
                         chance_of_cs_fact=0.0,
                         chance_to_prune_all=0.5,
                         chance_to_prune=0.5,
                         enforce_cs_fact_per_level=True)
        print(tree.print_for_gpt(pad_space=1, pad_char='> '))

    def action_ex():
        decision = LogicNode('Take an action',
                             operator=LogicNodeOperatorType.OR,
                             prunable=False,
                             can_be_leaf=False,
                             frozen=True)
        decision.children = [
            LogicNode(action,
                      operator=LogicNodeOperatorType.CHOOSE,
                      prunable=False,
                      can_be_leaf=False,
                      frozen=True)
            for action in ('Run away', 'Fight back', 'Hide')
        ]

        for action_node in decision.children:
            # Each action gets between 2 and 4 pro/con pairs.
            pair_count = random.randint(2, 4)
            for _ in range(pair_count):
                pair = LogicNode('',
                                 operator=LogicNodeOperatorType.CHOOSE,
                                 prunable=False,
                                 can_be_leaf=False,
                                 frozen=True)
                pair.children = [
                    LogicNode('Pro (supporting the parent action)',
                              operator=LogicNodeOperatorType.CHOOSE,
                              prunable=True,
                              can_be_leaf=False,
                              frozen=False),
                    LogicNode('Con (counters the sibling Pro only)',
                              operator=LogicNodeOperatorType.CHOOSE,
                              prunable=True,
                              can_be_leaf=False,
                              frozen=False),
                ]
                action_node.children.append(pair)

        tree = LogicTree(depth=4,
                         root_structure=[decision],
                         bf_factor={
                             1: 0.25,
                             2: 0.5,
                             3: 0.25
                         },
                         chance_of_or=0.0,
                         chance_of_cs_fact=0.0,
                         chance_to_prune_all=0.5,
                         chance_to_prune=0.75,
                         enforce_cs_fact_per_level=True)
        print(tree.print_for_gpt(pad_space=1, pad_char='- '))

    tv_scene_ex()
    eb_ex()
    action_ex()