Source code for concepts.benchmark.logic_induction.boolean_normal_form_dataset
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : boolean_normal_form_dataset.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/18/2018
#
# This file is part of Project Concepts.
# Distributed under terms of the MIT license.
import numpy as np
from torch.utils.data.dataset import Dataset
[docs]
class TruthTableDataset(Dataset):
    """Learning a truth table.

    Sample ``i`` pairs the bit pattern of ``i`` (least-significant bit first,
    one float entry per variable) with ``float(table[i])`` as the label.
    """

    def __init__(self, nr_variables, table):
        """
        Args:
            nr_variables: number of boolean input variables (at most 16).
            table: indexable of length ``2 ** nr_variables``; entry ``i`` is
                the output for the assignment encoded by the bits of ``i``.
        """
        assert nr_variables <= 16
        assert 1 << nr_variables == len(table)
        self.nr_variables, self.table = nr_variables, table

    def __getitem__(self, item):
        # Decompose the index first, then look up the table entry.
        bits = _binary_decomposition(item, self.nr_variables)
        output = self.table[item]
        return {'input': np.array(bits, dtype=np.float32), 'label': float(output)}

    def __len__(self):
        return len(self.table)
[docs]
class ParityDataset(Dataset):
    """Learning the parity function.

    Sample ``i`` pairs the bit pattern of ``i`` (least-significant bit first,
    one float entry per variable) with the parity of that pattern, i.e. the
    number of set bits modulo 2, as a float label.
    """

    def __init__(self, nr_variables):
        """
        Args:
            nr_variables: number of boolean input variables (at most 16).
        """
        assert nr_variables <= 16
        self.nr_variables = nr_variables

    def __getitem__(self, item):
        """Return the ``item``-th assignment and its parity label.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``item`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        index = item + size if item < 0 else item
        # The bit decomposition silently drops bits above ``nr_variables``, so
        # without this check every index would alias a valid sample and
        # iteration (which stops on IndexError) would never terminate.
        if not 0 <= index < size:
            raise IndexError('ParityDataset index out of range: {}'.format(item))
        assigns = _binary_decomposition(index, self.nr_variables)
        result = sum(assigns) % 2
        return dict(input=np.array(assigns, dtype=np.float32), label=float(result))

    def __len__(self):
        return 1 << self.nr_variables
def _binary_decomposition(v, n):
return [bool(v & (1 << i)) for i in range(n)]