Tutorial 2.4: Learning Lexicon Weights of NCCGs

This tutorial builds a small arithmetic function domain, enumerates candidate lexicon entries for a neural CCG, and then learns per-word lexicon weights from execution results alone.
[1]:
import jacinle
from tabulate import tabulate
[2]:
from concepts.dsl.dsl_types import INT64
from concepts.dsl.dsl_functions import Function, FunctionTyping
from concepts.dsl.function_domain import FunctionDomain

# A tiny arithmetic domain: two binary int64 functions; ten integer constants follow below.
math_domain = FunctionDomain()
math_domain.define_function(Function('add', FunctionTyping[INT64](INT64, INT64)))
math_domain.define_function(Function('minus', FunctionTyping[INT64](INT64, INT64)))
[2]:
Function<minus(#0: int64, #1: int64) -> int64>
[3]:
for i in range(10):
    math_domain.define_function(Function(f'int{i}', FunctionTyping[INT64]()))
[4]:
math_domain.print_summary()
TypeSystem: FunctionDomain
  Types:
  Constants:
  Functions:
    add(#0: int64, #1: int64) -> int64
    minus(#0: int64, #1: int64) -> int64
    int0() -> int64
    int1() -> int64
    int2() -> int64
    int3() -> int64
    int4() -> int64
    int5() -> int64
    int6() -> int64
    int7() -> int64
    int8() -> int64
    int9() -> int64
[5]:
import torch
from concepts.dsl.executors.function_domain_executor import FunctionDomainExecutor
class Executor(FunctionDomainExecutor):
    # Grounded implementations: each intN returns a constant float tensor so that
    # downstream arithmetic stays differentiable; add/minus operate on tensors.
    def int0(self): return torch.tensor(0, dtype=torch.float32)
    def int1(self): return torch.tensor(1, dtype=torch.float32)
    def int2(self): return torch.tensor(2, dtype=torch.float32)
    def int3(self): return torch.tensor(3, dtype=torch.float32)
    def int4(self): return torch.tensor(4, dtype=torch.float32)
    def int5(self): return torch.tensor(5, dtype=torch.float32)
    def int6(self): return torch.tensor(6, dtype=torch.float32)
    def int7(self): return torch.tensor(7, dtype=torch.float32)
    def int8(self): return torch.tensor(8, dtype=torch.float32)
    def int9(self): return torch.tensor(9, dtype=torch.float32)
    def add(self, x, y): return x + y
    def minus(self, x, y): return x - y

executor = Executor(math_domain)
11 22:40:35 Function add automatically registered.
11 22:40:35 Function minus automatically registered.
11 22:40:35 Function int0 automatically registered.
11 22:40:35 Function int1 automatically registered.
11 22:40:35 Function int2 automatically registered.
11 22:40:35 Function int3 automatically registered.
11 22:40:35 Function int4 automatically registered.
11 22:40:35 Function int5 automatically registered.
11 22:40:35 Function int6 automatically registered.
11 22:40:35 Function int7 automatically registered.
11 22:40:35 Function int8 automatically registered.
11 22:40:35 Function int9 automatically registered.
[6]:
# Build the expression add(int3(), int4()) with the domain's typed lambda helper, then execute it.
executor.execute(math_domain.lam(lambda: math_domain.f_add(math_domain.f_int3(), math_domain.f_int4()))())
[6]:
V(7.0, dtype=int64)
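Note that the printed dtype comes from the DSL's declared type (int64) even though the executor produces float32 tensors. The same pattern works for any function in the domain; a minimal sketch (assuming the f_<name> accessors follow the pattern of f_add above) that evaluates 9 - 5:

# Usage sketch: minus(int9(), int5()); should evaluate to V(4.0, dtype=int64).
executor.execute(math_domain.lam(lambda: math_domain.f_minus(math_domain.f_int9(), math_domain.f_int5()))())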
[7]:
from concepts.dsl.learning.function_domain_search import FunctionDomainExpressionEnumerativeSearcher
expression_searcher = FunctionDomainExpressionEnumerativeSearcher(math_domain)
candidate_expressions = expression_searcher.gen(max_depth=1)
candidate_expressions
[7]:
[FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)>, depth=1, nr_constant_arguments=0, nr_variable_arguments=2, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)>, depth=1, nr_constant_arguments=0, nr_variable_arguments=2, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int0()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int1()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int2()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int3()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int4()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int5()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int6()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int7()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int8()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int9()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0)]
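Each search result records the candidate expression together with its depth and argument counts, so candidates can be filtered before lexicon induction. A small sketch, using only the fields visible in the reprs above, that keeps just the binary functions:

# Keep only candidates that take two variable arguments (add and minus here).
binary_candidates = [r for r in candidate_expressions if r.nr_variable_arguments == 2]
for r in binary_candidates:
    print(r.expression)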
[8]:
from concepts.language.neural_ccg.search import NeuralCCGLexiconEnumerativeSearcher
lexicon_searcher = NeuralCCGLexiconEnumerativeSearcher(candidate_expressions, executor)
candidate_lexicon_entries = lexicon_searcher.gen()
candidate_lexicon_entries_table = list()
for result in candidate_lexicon_entries[:20]:
    candidate_lexicon_entries_table.append((str(result.syntax), str(result.semantics)))
print(tabulate(candidate_lexicon_entries_table, headers=['syntax', 'semantics']))
print(f'In total: {len(candidate_lexicon_entries)} lexicon entries.')
syntax semantics
----------------- ----------------------------------------------------------------
int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
int64/int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
int64\int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
int64\int64\int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
int64/int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
int64\int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
int64\int64\int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
int64 def __lambda__(): return int0()
int64 def __lambda__(): return int1()
int64 def __lambda__(): return int2()
int64 def __lambda__(): return int3()
int64 def __lambda__(): return int4()
int64 def __lambda__(): return int5()
int64 def __lambda__(): return int6()
int64 def __lambda__(): return int7()
In total: 22 lexicon entries.
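In CCG notation, a category X/Y consumes a Y argument to its right and X\Y consumes one to its left, so each binary function generates several slash-direction variants (one entry per direction combination and argument order), while every constant gets the plain int64 category; that is how 2 functions and 10 constants expand into 22 entries. A quick sketch to tally entries per syntactic category, using only the .syntax field shown above:

from collections import Counter
print(Counter(str(entry.syntax) for entry in candidate_lexicon_entries))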
[9]:
from concepts.language.neural_ccg.grammar import NeuralCCG
[10]:
ccg = NeuralCCG(math_domain, executor, candidate_lexicon_entries)
[11]:
import torch.nn as nn
import torch.nn.functional as F
[12]:
# One row of weights per word in "one plus two": 3 words x 22 candidate lexicon entries.
lexicon_weights = nn.Parameter(torch.zeros((3, 22), dtype=torch.float32))
lexicon_weights.data[0, 13] = 1e9  # pin "one" to entry 13: int64, def __lambda__(): return int1()
lexicon_weights.data[2, 14] = 1e9  # pin "two" to entry 14: int64, def __lambda__(): return int2()
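The 1e9 logits effectively pin those rows: after the softmax, essentially all probability mass for "one" sits on int1() and for "two" on int2(), while the middle row ("plus") stays uniform over all 22 entries (1/22 ≈ 0.0455 each, as the lexicon table below confirms). A quick sketch to inspect the induced distributions:

probs = F.softmax(lexicon_weights, dim=-1)
print(probs[0, 13].item(), probs[1].max().item(), probs[2, 14].item())  # expect roughly 1.0, 0.0455, 1.0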
[13]:
results = ccg.parse("one plus two", F.log_softmax(lexicon_weights, dim=-1))
[14]:
result_table = list()
for result in results[:20]:
    result_table.append((str(result.syntax), str(result.semantics.value.execute()), str(result.execution_result), str(result.weight.item())))
[15]:
print(tabulate(result_table, headers=['syntax', 'semantics', 'grounded value', 'weight']))
print(f'In total: {len(results)} parsing trees.')
syntax semantics grounded value weight
-------- --------------------- -------------------- --------
int64 add(int1(), int2()) V(3.0, dtype=int64) -3.09104
int64 add(int2(), int1()) V(3.0, dtype=int64) -3.09104
int64 minus(int1(), int2()) V(-1.0, dtype=int64) -3.09104
int64 minus(int2(), int1()) V(1.0, dtype=int64) -3.09104
int64 add(int0(), int2()) V(2.0, dtype=int64) -1e+09
int64 add(int2(), int0()) V(2.0, dtype=int64) -1e+09
int64 minus(int0(), int2()) V(-2.0, dtype=int64) -1e+09
int64 minus(int2(), int0()) V(2.0, dtype=int64) -1e+09
int64 add(int1(), int0()) V(1.0, dtype=int64) -1e+09
int64 add(int1(), int1()) V(2.0, dtype=int64) -1e+09
int64 add(int1(), int3()) V(4.0, dtype=int64) -1e+09
int64 add(int1(), int4()) V(5.0, dtype=int64) -1e+09
int64 add(int1(), int5()) V(6.0, dtype=int64) -1e+09
int64 add(int1(), int6()) V(7.0, dtype=int64) -1e+09
int64 add(int1(), int7()) V(8.0, dtype=int64) -1e+09
int64 add(int1(), int8()) V(9.0, dtype=int64) -1e+09
int64 add(int1(), int9()) V(10.0, dtype=int64) -1e+09
int64 add(int0(), int1()) V(1.0, dtype=int64) -1e+09
int64 add(int1(), int1()) V(2.0, dtype=int64) -1e+09
int64 add(int3(), int1()) V(4.0, dtype=int64) -1e+09
In total: 1200 parsing trees.
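Each parse carries its syntactic category, semantic program, grounded execution result, and log-weight; the four trees built from the pinned entries for "one" and "two" lead the table (at log(1/22) ≈ -3.091, the uniform weight of the "plus" row), while everything involving an unpinned entry is pushed down to -1e9. A minimal sketch for extracting the single highest-weight parse, using only the fields accessed above:

best = max(results, key=lambda node: node.weight.item())
print(best.syntax, best.semantics.value.execute(), best.execution_result)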
[16]:
print(ccg.format_lexicon_table_sentence("one plus two".split(), lexicon_weights))
i weight w_grad syntax semantics
---- --- -------- -------- ----------------- ----------------------------------------------------------------
one 13 1 g: None int64 def __lambda__(): return int1()
0 0 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
1 0 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
2 0 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
3 0 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
4 0 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
5 0 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
6 0 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
7 0 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
8 0 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
9 0 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
10 0 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
11 0 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
12 0 g: None int64 def __lambda__(): return int0()
14 0 g: None int64 def __lambda__(): return int2()
15 0 g: None int64 def __lambda__(): return int3()
16 0 g: None int64 def __lambda__(): return int4()
17 0 g: None int64 def __lambda__(): return int5()
18 0 g: None int64 def __lambda__(): return int6()
19 0 g: None int64 def __lambda__(): return int7()
20 0 g: None int64 def __lambda__(): return int8()
21 0 g: None int64 def __lambda__(): return int9()
plus 0 0.0455 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
1 0.0455 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
2 0.0455 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
3 0.0455 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
4 0.0455 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
5 0.0455 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
6 0.0455 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
7 0.0455 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
8 0.0455 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
9 0.0455 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
10 0.0455 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
11 0.0455 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
12 0.0455 g: None int64 def __lambda__(): return int0()
13 0.0455 g: None int64 def __lambda__(): return int1()
14 0.0455 g: None int64 def __lambda__(): return int2()
15 0.0455 g: None int64 def __lambda__(): return int3()
16 0.0455 g: None int64 def __lambda__(): return int4()
17 0.0455 g: None int64 def __lambda__(): return int5()
18 0.0455 g: None int64 def __lambda__(): return int6()
19 0.0455 g: None int64 def __lambda__(): return int7()
20 0.0455 g: None int64 def __lambda__(): return int8()
21 0.0455 g: None int64 def __lambda__(): return int9()
two 14 1 g: None int64 def __lambda__(): return int2()
0 0 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
1 0 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
2 0 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
3 0 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
4 0 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
5 0 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
6 0 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
7 0 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
8 0 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
9 0 g: None int64/int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
10 0 g: None int64\int64/int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
11 0 g: None int64\int64\int64 def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
12 0 g: None int64 def __lambda__(): return int0()
13 0 g: None int64 def __lambda__(): return int1()
15 0 g: None int64 def __lambda__(): return int3()
16 0 g: None int64 def __lambda__(): return int4()
17 0 g: None int64 def __lambda__(): return int5()
18 0 g: None int64 def __lambda__(): return int6()
19 0 g: None int64 def __lambda__(): return int7()
20 0 g: None int64 def __lambda__(): return int8()
21 0 g: None int64 def __lambda__(): return int9()
[17]:
optimizer = torch.optim.Adam([lexicon_weights], lr=1)  # a large learning rate is fine for this toy problem
[18]:
for epoch in range(20):
    results = ccg.parse("one plus two", F.log_softmax(lexicon_weights, dim=-1))
    # Normalize per-parse log-weights into a distribution over all parse trees.
    weights = F.log_softmax(torch.stack([node.weight for node in results], dim=0), dim=0)
    weights_softmax = F.softmax(weights, dim=0)
    # Negative expected log-likelihood of the parses whose grounded value equals 3
    # (the answer to "one plus two"); this is the quantity minimized below.
    log_likelihood = 0
    for i, node in enumerate(results):
        if node.execution_result.value.item() == 3.0:
            log_likelihood -= weights_softmax[i].detach() * weights[i]
    print(f'log_likelihood: {log_likelihood.item()}')
    optimizer.zero_grad()
    log_likelihood.backward()
    optimizer.step()
log_likelihood: 0.6931471824645996
log_likelihood: 0.7223198413848877
log_likelihood: 0.6995475888252258
log_likelihood: 0.6946471333503723
log_likelihood: 0.6935825943946838
log_likelihood: 0.6932985782623291
log_likelihood: 0.6932079195976257
log_likelihood: 0.6931744813919067
log_likelihood: 0.6931606531143188
log_likelihood: 0.6931544542312622
log_likelihood: 0.6931512951850891
log_likelihood: 0.6931496858596802
log_likelihood: 0.6931487917900085
log_likelihood: 0.6931482553482056
log_likelihood: 0.6931478977203369
log_likelihood: 0.6931477189064026
log_likelihood: 0.6931475400924683
log_likelihood: 0.6931474804878235
log_likelihood: 0.6931474208831787
log_likelihood: 0.6931474208831787
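The loss plateaus at about 0.6931 = ln 2 rather than 0, and that is expected here: addition is commutative, so two distinct entries for "plus" (entries 1 and 4 in the table below, which apply add with their arguments swapped) yield equally correct parses, and the symmetric objective settles on splitting their probability mass evenly, leaving an irreducible entropy of ln 2 over the two equivalent derivations:

import math
print(math.log(2))  # 0.6931471805599453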
[19]:
print(ccg.format_lexicon_table_sentence("one plus two".split(), lexicon_weights, max_entries=5))
i weight w_grad syntax semantics
---- --- -------- ---------- ----------------- --------------------------------------------------------------
one 13 1 g: 0.0000 int64 def __lambda__(): return int1()
0 0 g: 0.0000 int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
1 0 g: 0.0000 int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
2 0 g: 0.0000 int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
3 0 g: 0.0000 int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
plus 1 0.497 g: -0.0000 int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
4 0.497 g: -0.0000 int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
0 0.0003 g: 0.0000 int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
2 0.0003 g: 0.0000 int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
3 0.0003 g: 0.0000 int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
two 14 1 g: 0.0000 int64 def __lambda__(): return int2()
0 0 g: 0.0000 int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
1 0 g: 0.0000 int64\int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
2 0 g: 0.0000 int64\int64\int64 def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
3 0 g: 0.0000 int64/int64/int64 def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
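After training, the learned lexicon can be read off by taking each word's highest-weight entry. A minimal sketch (note that for "plus" the argmax is essentially a tie between the two commutative add entries, at weight 0.497 each above):

for word, row in zip("one plus two".split(), lexicon_weights):
    entry = candidate_lexicon_entries[row.argmax().item()]
    print(word, entry.syntax, entry.semantics)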