2024-11-14 20:42:12 +08:00
|
|
|
# flake8: noqa: E501
|
|
|
|
"""WARNING (or more like an aggressive note).
|
|
|
|
|
|
|
|
A lot of functionality was implemented here for earlier experiments, most of
which is no longer used. We have left it in place for backwards compatibility
with the current dataset (and, well, because why not).

ALSO NOTE:

This file was deliberately written to have no dependencies on anything else in
the repo. You can copy this file into your own project and use the classes to
parse/visualize/edit the logic trees in the dataset or create your own.

FINAL NOTE:

See examples of how to create LogicNodes and LogicTrees in the __main__ part of
the file.
|
|
|
|
"""
|
|
|
|
|
|
|
|
import random
|
|
|
|
from copy import deepcopy
|
|
|
|
from enum import Enum
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
class LogicNodeOperatorType:
    """String constants naming how a deduction combines its child nodes.

    ``CHOOSE`` is a placeholder meaning "sample and/or at random when populate
    is called".
    """

    AND = 'and'
    OR = 'or'
    CHOOSE = 'choose'
|
|
|
|
|
|
|
|
|
|
|
|
class LogicNodeFactType:
    """String constants for the kind of fact a node carries.

    ``EXPLICIT`` facts are mentioned in the story; ``COMMONSENSE`` facts are
    background knowledge that is left unsaid.
    """

    EXPLICIT = 'explicit'
    COMMONSENSE = 'commonsense'
|
|
|
|
|
|
|
|
|
|
|
|
class LogicNodeConstraints:
    """String constants for constraints that can be attached to a node.

    Useful for things like children = ['X is the murderer', 'Y is the
    murderer', 'Z is the murderer']; this structure is no longer used though.
    """

    ONLY_ONE_CAN_BE_TRUE = 'Only one child can be true'
|
|
|
|
|
|
|
|
|
|
|
|
class LogicNodeDeductionType:
    """String constants for the type of deduction a node represents.

    Not used currently. ``CHOOSE`` is a placeholder for "pick one at random".
    """

    SYLLOGISM = 'syllogism'
    TEMPORAL = 'temporal'
    SPATIAL = 'spatial'
    CHOOSE = 'choose'
|
|
|
|
|
|
|
|
|
|
|
|
class LogicNode:
    """A LogicNode is a tree primitive.

    It is either a deduction (it has children, combined via ``operator``) or a
    leaf fact (no children). Leaf facts are the ones that we use in story
    generation (if they are explicit facts and not commonsense).

    Assigning a list to ``children`` sets ``parent`` on every node in that
    list. Mutating the list in place (``append``/``extend``) does NOT go
    through the setter, so callers doing that must set ``parent`` themselves.
    """

    value: str
    children: List['LogicNode']
    fact_type: str
    operator: str
    constraints: List[str]
    deduction_type: str
    prunable: bool
    can_be_leaf: bool

    def __init__(
        self,
        value: str = '',
        children: List['LogicNode'] = None,
        operator: str = LogicNodeOperatorType.OR,
        fact_type: str = LogicNodeFactType.EXPLICIT,
        constraints: List[str] = (),
        deduction_type: str = None,
        prunable: bool = True,
        can_be_leaf: bool = False,
        frozen: bool = False,
    ):
        """
        :param value: Content for this specific node (also the deduction of the children).
        :param children: The children for this node.
        :param operator: Should the children be "And"ed or "Or"ed to create the deduction (the content of this node).
        :param fact_type: Explicit or commonsense
        :param constraints: Not used anymore (see LogicNodeConstraints)
        :param deduction_type: Not used anymore (see LogicNodeDeductionType)
        :param prunable: Can this node be removed from the tree (we don't prune in our datasets)
        :param can_be_leaf: Can this node be a leaf node (usually false for nodes that you are injecting manually)
        :param frozen: Should we add/prune children in the populate function (if frozen, no children will be added or removed, but the children may have children appended/pruned from them).
        """
        self.value = value
        # Goes through the property setter below, which re-parents children.
        self.children = [] if children is None else children
        self.operator = operator
        self.fact_type = fact_type
        # Tolerate an explicit None (e.g. a JSON null) so len()/iteration
        # elsewhere never fail on it.
        self.constraints = () if constraints is None else constraints
        self.deduction_type = deduction_type
        self.prunable = prunable
        self.can_be_leaf = can_be_leaf
        self.frozen = frozen
        self.parent = None

    @property
    def children(self):
        """The list of child nodes."""
        return self._children

    @children.setter
    def children(self, children: List['LogicNode']):
        """Replace the child list and point each child's ``parent`` at self."""
        self._children = children
        for child in self._children:
            child.parent = self

    def __str__(self):
        """Return a one-line ``' | '``-separated summary of this node."""
        parts = []

        if self.value:
            parts.append(self.value)

        # Deductions show their operator, leaves show their fact type.
        parts.append(self.operator if self.children else self.fact_type)

        if self.deduction_type:
            parts.append(self.deduction_type)

        if len(self.constraints) > 0:
            # Constraints are plain strings (see LogicNodeConstraints), so
            # format them directly instead of reading a `.value` attribute.
            parts.append(', '.join(str(c) for c in self.constraints))

        if self.children:
            parts.append(f'children: {len(self.children)}')

        return ' | '.join(parts)

    def __repr__(self):
        return str(self)

    def to_json(self):
        """Return this node and its subtree as a JSON-serialisable dict.

        ``frozen`` and ``parent`` are not part of the serialised form.
        """
        return {
            'value': self.value,
            'children': [child.to_json() for child in self.children],
            'fact_type': self.fact_type,
            'operator': self.operator,
            'constraints': self.constraints,
            'deduction_type': self.deduction_type,
            'prunable': self.prunable,
            'can_be_leaf': self.can_be_leaf
        }

    @classmethod
    def from_json(cls, js):
        """Build a node (and, recursively, its children) from ``to_json`` output.

        The input dict is left untouched, so the same dict can be parsed more
        than once.
        """
        kwargs = dict(js)
        kwargs['children'] = [
            cls.from_json(child) for child in (js.get('children') or [])
        ]
        return cls(**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class LogicTree:
    """Main datastructure used when creating a MuSR example.

    It's basically a standard tree (``nodes`` holds the root LogicNodes) with
    some parameters controlling the shape that ``populate`` grows and
    ``prune`` trims.
    """

    nodes: List[LogicNode]

    chance_of_or: float
    chance_of_cs_fact: float
    depth: int
    chance_to_prune: float
    chance_to_prune_all: float
    bf_factor: Dict[int, float]
    deduction_type_sample_rate: Dict[LogicNodeDeductionType, float]
    root_structure: List[List[LogicNode]] = ()

    def __init__(self,
                 chance_of_or: float = 0.3,
                 chance_of_cs_fact: float = 0.1,
                 depth: int = 2,
                 chance_to_prune: float = 0.6,
                 chance_to_prune_all: float = 0.2,
                 bf_factor: Dict[int, float] = None,
                 deduction_type_sample_rate: Dict[LogicNodeDeductionType,
                                                  float] = None,
                 enforce_cs_fact_per_level: bool = False,
                 root_structure: List[Any] = (),
                 nodes: List[LogicNode] = (),
                 populate: bool = True,
                 prune: bool = True):
        """
        :param chance_of_or: (not used) how often should a node with children be an OR
        :param chance_of_cs_fact: (not used) how often should there be a commonsense node
        :param depth: How deep should a tree go
        :param chance_to_prune: Percentage chance of pruning a node
        :param chance_to_prune_all: Percentage chance of pruning all children from a node.
        :param bf_factor: Branching factor (dictionary of percentages {1: 0.33, 2:0.33, 3:0.33} for example.
        :param deduction_type_sample_rate: (not used, see bf_factor and LogicNodeDeductionType)
        :param enforce_cs_fact_per_level: Enforce 1 commonsense fact per level in the tree (we use this instead of chance_of_cs_fact)
        :param root_structure: List of LogicNodes to build off of.
        :param nodes: List of LogicNodes to define the LogicTree on (we will not populate/prune the tree if this is filled)
        :param populate: Should we populate children for the tree according to the other parameters?
        :param prune: Should we prune the children for the tree according to the other parameters?
        """
        self.chance_of_or = chance_of_or
        self.chance_of_cs_fact = chance_of_cs_fact
        self.depth = depth
        self.chance_to_prune = chance_to_prune
        self.chance_to_prune_all = chance_to_prune_all
        self.bf_factor = bf_factor if bf_factor else {2: 0.8, 3: 0.2}
        self.enforce_cs_fact_per_level = enforce_cs_fact_per_level

        if not deduction_type_sample_rate:
            deduction_type_sample_rate = {
                LogicNodeDeductionType.SYLLOGISM: 1.0
            }
        self.deduction_type_sample_rate = deduction_type_sample_rate
        self.root_structure = root_structure

        # Truthiness (rather than len()) also tolerates nodes=None.
        if nodes:
            # A fully specified tree is taken as-is: no populate, no prune.
            self.nodes = nodes
            return

        if root_structure:
            self.nodes = root_structure
        else:
            self.nodes = [
                LogicNode('root', operator=LogicNodeOperatorType.AND)
            ]

        if populate:
            for root in self.nodes:
                self.populate(root, 1)
        if prune:
            for root in self.nodes:
                self.prune(root, 1)

    def __str__(self):
        return self.print_tree()

    def get_facts(self,
                  include_cs: bool = False,
                  include_deductions_from_level: int = -1,
                  no_facts_after_depth: int = -1):
        """Get a list of LogicNodes from the tree. By default, you will get the
        explicit leaf nodes.

        The returned nodes are copies, listed in depth-first (parent before
        child, left to right) order; the tree itself is never modified.

        :param include_cs: Include the commonsense nodes from all levels.
        :param include_deductions_from_level: Include any intermediate
            deduction nodes from the specified level and deeper.
        :param no_facts_after_depth: Essentially treat the deductions at the
            specified depth as leaf nodes.
        """
        truncate = no_facts_after_depth > -1
        want_deductions = include_deductions_from_level > -1

        def collect(node: LogicNode, depth: int,
                    out: List[LogicNode]) -> None:
            if truncate and depth >= no_facts_after_depth:
                # Safe: `node` belongs to the private copy made below.
                node.children = []

            if not node.children:
                is_explicit = node.fact_type == LogicNodeFactType.EXPLICIT
                is_cs = node.fact_type == LogicNodeFactType.COMMONSENSE
                if is_explicit or (is_cs and include_cs):
                    out.append(node)
            elif want_deductions and include_deductions_from_level <= depth:
                out.append(node)

            for child in node.children:
                collect(child, depth + 1, out)

        facts = []
        for root in self.nodes:
            # Copy each root's subtree once up front. Copying inside the
            # recursion would re-copy the whole tree at every node because
            # each node links back to its parent.
            collect(deepcopy(root), 0, facts)
        return facts

    def print_tree(self, node=None, level=0):
        """Deprecated (not used)"""
        if node is None:
            node = self.nodes[0]

        header = '-' * level * 4 + str(node)
        if node.children:
            header += ' | ' + str(node.operator)

        rows = [header]
        rows.extend(self.print_tree(child, level + 1)
                    for child in node.children)
        return '\n'.join(rows)

    def _describe_node(self, node, level, print_conjection_types,
                       print_reasoning_types, ignore_value_after_depth,
                       print_only_nodes_with_value):
        """Return the unpadded text for one node, or None if it is skipped."""
        if print_only_nodes_with_value and node.value == '':
            return None

        if level == 0:
            # The root always shows its value, whatever the depth cut-off.
            prefix = node.value + ' | ' if node.value != '' else ''
            text = prefix + 'Deduced Root Conclusion'
        else:
            hide_value = -1 < ignore_value_after_depth < level
            prefix = ''
            if node.value != '' and not hide_value:
                prefix = node.value + ' | '

            if node.children:
                kind = 'Deduced Fact'
            elif node.fact_type == LogicNodeFactType.EXPLICIT:
                kind = 'Fact From Story'
            else:
                kind = 'Commonsense Knowledge'
            text = prefix + kind

        if node.children:
            # Build the "(and | syllogism)" suffix from whichever parts were
            # requested, so no empty or lopsided parentheses are produced.
            tags = []
            if print_conjection_types:
                tags.append(str(node.operator))
            if print_reasoning_types and node.deduction_type:
                tags.append(str(node.deduction_type))
            if tags:
                text += ' (' + ' | '.join(tags) + ')'

        if len(node.constraints) > 0:
            cnsts = ', '.join(str(c) for c in node.constraints)
            text += f' constraints: [{cnsts}]'

        return text

    def print_for_gpt(self,
                      node=None,
                      level=0,
                      pad_char=' ',
                      pad_space=4,
                      print_forward=True,
                      print_conjection_types: bool = False,
                      print_reasoning_types: bool = False,
                      ignore_value_after_depth: int = -1,
                      print_only_nodes_with_value: bool = False):
        """Complex print function. We often use it as
        print_for_gpt(pad_space=1, pad_char='> ')

        However, more complex arguments can be used to control what is printed.

        This returns a string that must be printed (don't be confused by the
        method name.)

        Every option is forwarded to the recursive calls, so it applies to the
        whole subtree. Nodes skipped by ``print_only_nodes_with_value``
        contribute no line at all (their children are still rendered).

        :param node: Start at a specific node.
        :param level: Controls how much tabbing is done when printing the
            current node.
        :param pad_char: Char to use that specifies depth ('> ' at depth 3 will
            look like '> > > ' if you have pad_space equal to 1 for example)
        :param pad_space: How many spaces to include between pad_chars
        :param print_forward: Print the tree with parent nodes first.
        :param print_conjection_types: Print the Ands and Ors per deduction
            (not used)
        :param print_reasoning_types: Print the deduction types (not used)
        :param ignore_value_after_depth: Ignore content of the nodes once a
            depth is met
        :param print_only_nodes_with_value: Ignore nodes without content.
        """
        if node is None:
            node = self.nodes[0]

        child_rows = []
        for child in node.children:
            rendered = self.print_for_gpt(
                child,
                level + 1,
                pad_char=pad_char,
                pad_space=pad_space,
                print_forward=print_forward,
                print_conjection_types=print_conjection_types,
                print_reasoning_types=print_reasoning_types,
                ignore_value_after_depth=ignore_value_after_depth,
                print_only_nodes_with_value=print_only_nodes_with_value)
            if rendered != '':
                child_rows.append(rendered)

        text = self._describe_node(node, level, print_conjection_types,
                                   print_reasoning_types,
                                   ignore_value_after_depth,
                                   print_only_nodes_with_value)
        own_rows = []
        if text is not None:
            own_rows.append(pad_char * level * pad_space + text)

        rows = own_rows + child_rows if print_forward else child_rows + own_rows
        return '\n'.join(rows)

    def _sample_deduction_type(self):
        """Draw one deduction type weighted by deduction_type_sample_rate."""
        return random.choices(list(self.deduction_type_sample_rate.keys()),
                              list(self.deduction_type_sample_rate.values()),
                              k=1)[0]

    def populate(self, node: LogicNode, current_depth: int = 1):
        """Resolve CHOOSE placeholders on ``node`` and grow children under it.

        Unless the node is frozen, a branching factor is sampled from
        ``bf_factor`` and enough new children are added to reach it. The
        method then recurses into every non-commonsense child while
        ``current_depth`` is below ``self.depth``.

        :param node: The node to populate (modified in place).
        :param current_depth: Depth of ``node``; roots are at depth 1.
        """
        if node.operator == LogicNodeOperatorType.CHOOSE:
            if random.random() < self.chance_of_or:
                node.operator = LogicNodeOperatorType.OR
            else:
                node.operator = LogicNodeOperatorType.AND

        if node.deduction_type == LogicNodeDeductionType.CHOOSE:
            if node.operator != LogicNodeOperatorType.AND:
                node.deduction_type = None
            else:
                node.deduction_type = self._sample_deduction_type()

        if not node.frozen:
            sampled_bf = random.choices(list(self.bf_factor.keys()),
                                        list(self.bf_factor.values()),
                                        k=1)[0]
            missing = max(0, sampled_bf - len(node.children))

            if missing > 0:
                new_nodes = []
                one_fact_is_cs = False
                for _ in range(missing):
                    # Both rolls are always drawn, in this order, so seeded
                    # runs consume the random stream deterministically.
                    roll_for_or = random.random()
                    roll_for_cs = random.random()
                    if roll_for_cs < self.chance_of_cs_fact and not one_fact_is_cs:
                        fact_type = LogicNodeFactType.COMMONSENSE
                    else:
                        fact_type = LogicNodeFactType.EXPLICIT
                    is_cs = fact_type == LogicNodeFactType.COMMONSENSE

                    if (roll_for_or > self.chance_of_or
                            and current_depth < self.depth and not is_cs):
                        new_nodes.append(
                            LogicNode(
                                '',
                                operator=LogicNodeOperatorType.AND,
                                fact_type=fact_type,
                                deduction_type=self._sample_deduction_type(),
                                prunable=True,
                                can_be_leaf=True,
                            ))
                    else:
                        new_nodes.append(
                            LogicNode('',
                                      operator=LogicNodeOperatorType.OR,
                                      fact_type=fact_type,
                                      prunable=True,
                                      can_be_leaf=True))

                    if is_cs:
                        # A commonsense child forces the parent to be an AND.
                        node.operator = LogicNodeOperatorType.AND
                        if not node.deduction_type:
                            node.deduction_type = self._sample_deduction_type()
                        one_fact_is_cs = True

                if not one_fact_is_cs and self.enforce_cs_fact_per_level:
                    new_nodes.append(
                        LogicNode('',
                                  operator=LogicNodeOperatorType.OR,
                                  fact_type=LogicNodeFactType.COMMONSENSE,
                                  prunable=False,
                                  can_be_leaf=True))

                node.children.extend(new_nodes)
                # extend() bypasses the children setter, so link the new
                # nodes to their parent explicitly.
                for new_node in new_nodes:
                    new_node.parent = node

        if current_depth < self.depth:
            for child in node.children:
                if child.fact_type == LogicNodeFactType.COMMONSENSE:
                    continue
                self.populate(child, current_depth + 1)

    def prune(self, node: LogicNode, current_depth: int = 1):
        """Randomly remove prunable children below ``node``, recursively.

        Below the root level, a node that may be a leaf loses all of its
        children with probability ``chance_to_prune_all``. Otherwise an OR
        node always keeps at least one prunable child and an AND node at
        least two; each remaining candidate is dropped with probability
        ``chance_to_prune``.

        :param node: The node whose children are pruned (modified in place).
        :param current_depth: Depth of ``node``; roots are at depth 1.
        """
        if current_depth > 1 and node.can_be_leaf:
            if random.random() < self.chance_to_prune_all:
                node.children = []
                return

        to_prune = []
        prunable = [child for child in node.children if child.prunable]
        is_or = node.operator == LogicNodeOperatorType.OR
        is_and = node.operator == LogicNodeOperatorType.AND
        has_spare = (len(prunable) > 1 and is_or) or (len(prunable) > 2
                                                      and is_and)

        if has_spare and current_depth <= self.depth and node.prunable:
            must_keep = 1 if is_or else 2
            for candidate in random.sample(prunable,
                                           len(prunable) - must_keep):
                if random.random() < self.chance_to_prune:
                    to_prune.append(candidate)

        node.children = [
            child for child in node.children if child not in to_prune
        ]
        for child in node.children:
            self.prune(child, current_depth + 1)

    def to_json(self):
        """Return the tree parameters and nodes as a JSON-serialisable dict.

        ``chance_of_cs_fact`` and ``enforce_cs_fact_per_level`` are not part
        of the serialised form.
        """
        return {
            'chance_of_or': self.chance_of_or,
            'depth': self.depth,
            'chance_to_prune': self.chance_to_prune,
            'chance_to_prune_all': self.chance_to_prune_all,
            'bf_factor': self.bf_factor,
            'deduction_type_sample_rate': self.deduction_type_sample_rate,
            'root_structure': [x.to_json() for x in self.root_structure],
            'nodes': [x.to_json() for x in self.nodes]
        }

    @classmethod
    def from_json(cls, _js):
        """Build a LogicTree from ``to_json`` output without modifying ``_js``."""
        js = deepcopy(_js)
        js['nodes'] = [LogicNode.from_json(x) for x in (js.get('nodes') or [])]
        js['root_structure'] = [
            LogicNode.from_json(x) for x in (js.get('root_structure') or [])
        ]

        # JSON object keys are always strings; restore integer branching
        # factors so populate() can do arithmetic on them.
        bf_factor = js.get('bf_factor')
        if bf_factor:
            js['bf_factor'] = {
                (int(k) if isinstance(k, str) and k.lstrip('-').isdigit() else k): v
                for k, v in bf_factor.items()
            }
        return cls(**js)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # EXAMPLE USES.
    #
    # Each example builds a hand-written root structure, lets LogicTree
    # populate/prune beneath it, and prints the result with print_for_gpt.

    def tv_scene_ex():
        """A frozen OR root whose four plot points are expanded randomly."""
        root_structure = [
            LogicNode('A good drama tv scene',
                      operator=LogicNodeOperatorType.OR,
                      prunable=False,
                      can_be_leaf=False,
                      frozen=True)
        ]

        root_structure[0].children = [
            LogicNode(plot_point,
                      operator=LogicNodeOperatorType.CHOOSE,
                      prunable=True,
                      can_be_leaf=False)
            for plot_point in (
                'Bob is sad.',
                'John now hates Bob.',
                'Bob bought a car.',
                'Bob wanted to be happy.',
            )
        ]

        tree = LogicTree(depth=4,
                         root_structure=root_structure,
                         bf_factor={
                             1: 0.5,
                             2: 0.5
                         },
                         chance_of_or=0.0,
                         chance_of_cs_fact=0.0,
                         chance_to_prune_all=0.5,
                         chance_to_prune=0.5,
                         enforce_cs_fact_per_level=True)

        rep = tree.print_for_gpt(pad_space=1, pad_char='- ')
        print(rep)

    def eb_ex():
        """A mostly frozen, hand-written deduction chain under an open root."""
        root_structure = [
            LogicNode('',
                      operator=LogicNodeOperatorType.CHOOSE,
                      prunable=False,
                      can_be_leaf=False)
        ]

        blocks_sunlight = LogicNode('Eruptions block sunlight.',
                                    operator=LogicNodeOperatorType.CHOOSE,
                                    prunable=False,
                                    can_be_leaf=False,
                                    frozen=True)
        blocks_sunlight.children = [
            LogicNode('Eruptions produce ash clouds.',
                      operator=LogicNodeOperatorType.CHOOSE,
                      prunable=False,
                      can_be_leaf=True,
                      frozen=True),
            LogicNode('Ash blocks sunlight.',
                      operator=LogicNodeOperatorType.CHOOSE,
                      prunable=False,
                      can_be_leaf=True,
                      frozen=True),
        ]

        plants_die = LogicNode('Eruptions can cause plants to die.',
                               operator=LogicNodeOperatorType.CHOOSE,
                               prunable=True,
                               can_be_leaf=False,
                               frozen=True)
        plants_die.children = [
            blocks_sunlight,
            LogicNode('Producers will die without sunlight.',
                      operator=LogicNodeOperatorType.CHOOSE,
                      prunable=False,
                      can_be_leaf=True,
                      frozen=True)
        ]

        wrapper = LogicNode('',
                            operator=LogicNodeOperatorType.AND,
                            prunable=False,
                            can_be_leaf=False)
        wrapper.children = [plants_die]

        root_structure[0].children = [wrapper]

        tree = LogicTree(depth=5,
                         root_structure=root_structure,
                         bf_factor={
                             1: 0.3,
                             2: 0.7
                         },
                         chance_of_or=0.0,
                         chance_of_cs_fact=0.0,
                         chance_to_prune_all=0.0,
                         chance_to_prune=0.0,
                         enforce_cs_fact_per_level=True)

        rep = tree.print_for_gpt(pad_space=1, pad_char='- ')
        print(rep)

    def murder_mystery_ex():
        """A suspect with means/motive/opportunity (not invoked below)."""
        root_structure = [
            LogicNode('Killer',
                      operator=LogicNodeOperatorType.OR,
                      constraints=[LogicNodeConstraints.ONLY_ONE_CAN_BE_TRUE],
                      prunable=False,
                      can_be_leaf=False,
                      frozen=True)
        ]

        suspect_nodes = [
            LogicNode(f'Murderer Suspect {idx + 1}',
                      operator=LogicNodeOperatorType.AND,
                      prunable=False,
                      can_be_leaf=False,
                      frozen=True) for idx in range(1)
        ]
        for suspect in suspect_nodes:
            suspect.children = [
                LogicNode(requirement,
                          operator=LogicNodeOperatorType.CHOOSE,
                          prunable=True,
                          can_be_leaf=False)
                for requirement in (
                    'Suspect has means',
                    'Suspect has motive',
                    'Suspect has opportunity',
                )
            ]
        root_structure[0].children = suspect_nodes

        tree = LogicTree(depth=4,
                         root_structure=root_structure,
                         bf_factor={
                             1: 0.5,
                             2: 0.5
                         },
                         chance_of_or=0.0,
                         chance_of_cs_fact=0.0,
                         chance_to_prune_all=0.5,
                         chance_to_prune=0.5,
                         enforce_cs_fact_per_level=True)

        rep = tree.print_for_gpt(pad_space=1, pad_char='> ')
        print(rep)

    def action_ex():
        """Three frozen actions, each given 2-4 Pro/Con fact pairs."""
        root_structure = [
            LogicNode('Take an action',
                      operator=LogicNodeOperatorType.OR,
                      prunable=False,
                      can_be_leaf=False,
                      frozen=True)
        ]

        root_structure[0].children = [
            LogicNode(action_name,
                      operator=LogicNodeOperatorType.CHOOSE,
                      prunable=False,
                      can_be_leaf=False,
                      frozen=True)
            for action_name in ('Run away', 'Fight back', 'Hide')
        ]

        for action in root_structure[0].children:
            nfacts = random.randint(2, 4)

            facts = []
            for _ in range(nfacts):
                fact = LogicNode('',
                                 operator=LogicNodeOperatorType.CHOOSE,
                                 prunable=False,
                                 can_be_leaf=False,
                                 frozen=True)
                fact.children = [
                    LogicNode('Pro (supporting the parent action)',
                              operator=LogicNodeOperatorType.CHOOSE,
                              prunable=True,
                              can_be_leaf=False,
                              frozen=False),
                    LogicNode('Con (counters the sibling Pro only)',
                              operator=LogicNodeOperatorType.CHOOSE,
                              prunable=True,
                              can_be_leaf=False,
                              frozen=False)
                ]
                facts.append(fact)

            # Assign (rather than append) so the children setter runs and
            # every new fact gets its parent link.
            action.children = action.children + facts

        tree = LogicTree(depth=4,
                         root_structure=root_structure,
                         bf_factor={
                             1: 0.25,
                             2: 0.5,
                             3: 0.25
                         },
                         chance_of_or=0.0,
                         chance_of_cs_fact=0.0,
                         chance_to_prune_all=0.5,
                         chance_to_prune=0.75,
                         enforce_cs_fact_per_level=True)

        rep = tree.print_for_gpt(pad_space=1, pad_char='- ')
        print(rep)

    tv_scene_ex()
    eb_ex()
    action_ex()
|