Tutorial 4.1: Relational Representations and Inductive Learning
[1]:
import jacinle
from concepts.benchmark.logic_induction.family import random_generate_family
[2]:
# Sample one random family tree. The relation matrices inspected below are 10 x 10,
# so the argument appears to be the number of people (confirm against random_generate_family).
family = random_generate_family(10)
[3]:
# Inspect the `father` relation: a (10, 10) float64 matrix, presumably indexed [person, person].
jacinle.stprint(family.father)
np.ndarray(shape=(10, 10), dtype=float64)
[4]:
# Inspect the `mother` relation: also a (10, 10) float64 matrix.
jacinle.stprint(family.mother)
np.ndarray(shape=(10, 10), dtype=float64)
[5]:
# Inspect the derived `parents` relation (same (10, 10) shape); presumably the union of
# the father and mother relations. Verify against get_parents.
jacinle.stprint(family.get_parents())
np.ndarray(shape=(10, 10), dtype=float64)
[6]:
from concepts.benchmark.logic_induction.graph_dataset import FamilyTreeDataset
[7]:
# Dataset of random family trees for the 'parents' task. Sampled families vary in size
# (5 to 10 people in the batches printed below), so 10 looks like the maximum size (confirm).
# With epoch_size=8192 the batch-32 training loader reports 256 batches per epoch.
dataset = FamilyTreeDataset(10, epoch_size=8192, task='parents')
[8]:
# Display the dataset object's default repr.
dataset
[8]:
<concepts.benchmark.inductive_reasoning.graph_dataset.FamilyTreeDataset at 0x15c75ebb0>
[9]:
# Inspect a single sample: a dict with the family size `n`, an (n, n, 4) stack of input
# relations, and the (n, n) 0/1 target matrix for the requested task.
jacinle.stprint(dataset[0])
dict{
n: 5
relations: np.ndarray(shape=(5, 5, 4), dtype=float64)
target: np.ndarray(shape=(5, 5), dtype=float64){[[0. 0. 0. 0. 0.]
[1. 0. 0. 0. 1.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0.]]}
}
[10]:
from jactorch.data.dataloader import JacDataLoader
[11]:
# batch_size=1: samples have different family sizes (see the shapes printed below), so they
# presumably cannot be stacked into one tensor without a padding collate function.
loader = JacDataLoader(dataset, batch_size=1, shuffle=True, drop_last=True)
[12]:
import itertools

# Peek at the first 16 batches. With batch_size=1, each batch keeps its own family size.
# islice stops after 16 items instead of fetching a 17th batch only to throw it away.
for i, batch in itertools.islice(enumerate(loader), 16):
    print(i, batch['relations'].shape, batch['target'].shape)
0 torch.Size([1, 6, 6, 4]) torch.Size([1, 6, 6])
1 torch.Size([1, 8, 8, 4]) torch.Size([1, 8, 8])
2 torch.Size([1, 9, 9, 4]) torch.Size([1, 9, 9])
3 torch.Size([1, 9, 9, 4]) torch.Size([1, 9, 9])
4 torch.Size([1, 5, 5, 4]) torch.Size([1, 5, 5])
5 torch.Size([1, 9, 9, 4]) torch.Size([1, 9, 9])
6 torch.Size([1, 8, 8, 4]) torch.Size([1, 8, 8])
7 torch.Size([1, 8, 8, 4]) torch.Size([1, 8, 8])
8 torch.Size([1, 10, 10, 4]) torch.Size([1, 10, 10])
9 torch.Size([1, 8, 8, 4]) torch.Size([1, 8, 8])
10 torch.Size([1, 8, 8, 4]) torch.Size([1, 8, 8])
11 torch.Size([1, 7, 7, 4]) torch.Size([1, 7, 7])
12 torch.Size([1, 10, 10, 4]) torch.Size([1, 10, 10])
13 torch.Size([1, 8, 8, 4]) torch.Size([1, 8, 8])
14 torch.Size([1, 8, 8, 4]) torch.Size([1, 8, 8])
15 torch.Size([1, 10, 10, 4]) torch.Size([1, 10, 10])
[13]:
from jactorch.data.collate.collate_v2 import VarLengthCollateV2
[14]:
# Collate variable-size samples by 2-D padding of both the relations and the target
# (presumably up to the largest family in each batch), so batches of 32 can be formed.
pad_spec = {
    'relations': 'pad2d',
    'target': 'pad2d',
}
loader = JacDataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    collate_fn=VarLengthCollateV2(pad_spec),
)
[15]:
import itertools

# Peek at the first 2 padded batches. islice avoids fetching a third batch just to discard it.
for i, batch in itertools.islice(enumerate(loader), 2):
    print(i, batch['relations'].shape, batch['target'].shape)
0 torch.Size([32, 10, 10, 4]) torch.Size([32, 10, 10])
1 torch.Size([32, 10, 10, 4]) torch.Size([32, 10, 10])
[16]:
import torch.nn as nn
import torch.optim as optim
from jactorch.nn.neural_logic.layer import NeuralLogicMachine
from jactorch.nn.losses.losses import PNBalancedBinaryCrossEntropyLossWithProbs
class Model(nn.Module):
    """Pairwise relation predictor built on a Neural Logic Machine.

    The input relations are fed into slot 2 of the NLM input list (the other slots are
    None). The slot-2 output features are mapped by a linear layer + sigmoid to one
    probability per (i, j) entry.
    """

    def __init__(self):
        super().__init__()
        # NLM configuration kept exactly as-is. The input-dims list [0, 0, 4, 0] appears to
        # declare 4 channels only in slot 2, matching the last dim of `relations`.
        self.nlm = NeuralLogicMachine(3, 3, [0, 0, 4, 0], [16, 16, 16, 16], 'mlp', logic_hidden_dim=[], io_residual=True)
        pair_feature_dim = self.nlm.output_dims[2]
        self.predict = nn.Linear(pair_feature_dim, 1)
        self.loss = PNBalancedBinaryCrossEntropyLossWithProbs()

    def forward(self, feed_dict):
        """Run the model on a batch dict.

        Args:
            feed_dict: dict with 'relations' (assumed (batch, n, n, 4), TODO confirm) and,
                in training mode, 'target' (same leading dims as the prediction).

        Returns:
            In training mode: (loss, {'pred': pred}). In eval mode: {'pred': pred}.
            `pred` holds probabilities (sigmoid outputs) with the last dim squeezed.
        """
        nlm_inputs = [None, None, feed_dict['relations'].float(), None]
        pair_features = self.nlm(nlm_inputs)[2]
        pred = self.predict(pair_features).squeeze(-1).sigmoid()
        outputs = {'pred': pred}
        if not self.training:
            return outputs
        # The loss class works on probabilities, hence sigmoid before the loss.
        loss = self.loss(pred, feed_dict['target'].float())
        return loss, outputs


model = Model()
[17]:
# A single pass over the padded training loader, logging the loss every 50 steps.
model.train()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
num_batches = len(loader)
for step, batch in enumerate(loader):
    loss, train_outputs = model(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 50 == 0:
        print(step, num_batches, loss.item())
0 256 0.7050229907035828
50 256 0.5954307913780212
100 256 0.18965472280979156
150 256 0.04814121127128601
200 256 0.023047110065817833
250 256 0.013977136462926865
[18]:
# Separately constructed evaluation set for the same task.
# Order is irrelevant for evaluation, so don't shuffle. Keep any final partial batch so
# every test sample is scored (1024 / 32 divides evenly, so the batch count is unchanged here).
test_dataset = FamilyTreeDataset(10, epoch_size=1024, task='parents')
test_dataloader = JacDataLoader(test_dataset, batch_size=32, shuffle=False, drop_last=False, collate_fn=VarLengthCollateV2({
    'relations': 'pad2d',
    'target': 'pad2d'
}))
[19]:
import torch
from jactorch.train.monitor import binary_classification_accuracy

model.eval()
batch_accuracies = []
# Evaluation needs no gradients; no_grad avoids building an autograd graph per batch.
with torch.no_grad():
    for i, batch in enumerate(test_dataloader):
        output_dict = model(batch)
        accuracy = binary_classification_accuracy(output_dict['pred'], batch['target'], saturation=False)
        batch_accuracies.append(float(accuracy['accuracy']))
        print(i, len(test_dataloader), accuracy)

# Unweighted mean over batches (a smaller final batch, if any, counts the same as the others).
if batch_accuracies:
    print('mean batch accuracy:', sum(batch_accuracies) / len(batch_accuracies))
0 32 {'accuracy': 1.0}
1 32 {'accuracy': 1.0}
2 32 {'accuracy': 1.0}
3 32 {'accuracy': 1.0}
4 32 {'accuracy': 1.0}
5 32 {'accuracy': 1.0}
6 32 {'accuracy': 1.0}
7 32 {'accuracy': 1.0}
8 32 {'accuracy': 1.0}
9 32 {'accuracy': 1.0}
10 32 {'accuracy': 1.0}
11 32 {'accuracy': 1.0}
12 32 {'accuracy': 1.0}
13 32 {'accuracy': 1.0}
14 32 {'accuracy': 1.0}
15 32 {'accuracy': 1.0}
16 32 {'accuracy': 1.0}
17 32 {'accuracy': 1.0}
18 32 {'accuracy': 1.0}
19 32 {'accuracy': 1.0}
20 32 {'accuracy': 1.0}
21 32 {'accuracy': 1.0}
22 32 {'accuracy': 1.0}
23 32 {'accuracy': 1.0}
24 32 {'accuracy': 1.0}
25 32 {'accuracy': 1.0}
26 32 {'accuracy': 1.0}
27 32 {'accuracy': 1.0}
28 32 {'accuracy': 1.0}
29 32 {'accuracy': 1.0}
30 32 {'accuracy': 1.0}
31 32 {'accuracy': 1.0}