Tutorial 2.4: Learning Lexicon Weights of NCCGs
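
This tutorial shows how to learn the lexicon weights of a neural CCG (NCCG). We build a tiny arithmetic domain (the constants 0 through 9 plus add and minus), enumerate candidate lexicon entries that pair CCG syntactic types with executable meanings, and then learn per-word weights over those candidates for the sentence "one plus two" by gradient descent, supervising only on the grounded answer 3.
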

[1]:
import jacinle
from tabulate import tabulate
[2]:
from concepts.dsl.dsl_types import INT64
from concepts.dsl.dsl_functions import Function, FunctionTyping
from concepts.dsl.function_domain import FunctionDomain

math_domain = FunctionDomain()
math_domain.define_function(Function('add', FunctionTyping[INT64](INT64, INT64)))
math_domain.define_function(Function('minus', FunctionTyping[INT64](INT64, INT64)))
[2]:
Function<minus(#0: int64, #1: int64) -> int64>
[3]:
for i in range(10):
    math_domain.define_function(Function(f'int{i}', FunctionTyping[INT64]()))
[4]:
math_domain.print_summary()
TypeSystem: FunctionDomain
  Types:
  Constants:
  Functions:
    add(#0: int64, #1: int64) -> int64
    minus(#0: int64, #1: int64) -> int64
    int0() -> int64
    int1() -> int64
    int2() -> int64
    int3() -> int64
    int4() -> int64
    int5() -> int64
    int6() -> int64
    int7() -> int64
    int8() -> int64
    int9() -> int64
[5]:
import torch
from concepts.dsl.executors.function_domain_executor import FunctionDomainExecutor

class Executor(FunctionDomainExecutor):
    def int0(self): return torch.tensor(0, dtype=torch.float32)
    def int1(self): return torch.tensor(1, dtype=torch.float32)
    def int2(self): return torch.tensor(2, dtype=torch.float32)
    def int3(self): return torch.tensor(3, dtype=torch.float32)
    def int4(self): return torch.tensor(4, dtype=torch.float32)
    def int5(self): return torch.tensor(5, dtype=torch.float32)
    def int6(self): return torch.tensor(6, dtype=torch.float32)
    def int7(self): return torch.tensor(7, dtype=torch.float32)
    def int8(self): return torch.tensor(8, dtype=torch.float32)
    def int9(self): return torch.tensor(9, dtype=torch.float32)
    def add(self, x, y): return x + y
    def minus(self, x, y): return x - y

executor = Executor(math_domain)
11 22:40:35 Function add automatically registered.
11 22:40:35 Function minus automatically registered.
11 22:40:35 Function int0 automatically registered.
11 22:40:35 Function int1 automatically registered.
11 22:40:35 Function int2 automatically registered.
11 22:40:35 Function int3 automatically registered.
11 22:40:35 Function int4 automatically registered.
11 22:40:35 Function int5 automatically registered.
11 22:40:35 Function int6 automatically registered.
11 22:40:35 Function int7 automatically registered.
11 22:40:35 Function int8 automatically registered.
11 22:40:35 Function int9 automatically registered.
[6]:
executor.execute(math_domain.lam(lambda: math_domain.f_add(math_domain.f_int3(), math_domain.f_int4()))())
[6]:
V(7.0, dtype=int64)
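
The expression add(int3(), int4()) grounds to 7, as expected. Note that the dtype in the printed value is the DSL type int64, while the tensor returned by the executor is actually float32 (hence the 7.0). The same pattern should work for any other expression in the domain; for example, assuming the f_<name> accessor exists for minus just as it does for add above, the following sketch should evaluate to 5:

executor.execute(math_domain.lam(lambda: math_domain.f_minus(math_domain.f_int9(), math_domain.f_int4()))())
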
[7]:
from concepts.dsl.learning.function_domain_search import FunctionDomainExpressionEnumerativeSearcher

expression_searcher = FunctionDomainExpressionEnumerativeSearcher(math_domain)
candidate_expressions = expression_searcher.gen(max_depth=1)
candidate_expressions
[7]:
[FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)>, depth=1, nr_constant_arguments=0, nr_variable_arguments=2, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)>, depth=1, nr_constant_arguments=0, nr_variable_arguments=2, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int0()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int1()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int2()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int3()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int4()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int5()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int6()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int7()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int8()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0),
 FunctionDomainExpressionSearchResult(expression=Function<def __lambda__(): return int9()>, depth=1, nr_constant_arguments=0, nr_variable_arguments=0, nr_function_arguments=0)]
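
Each FunctionDomainExpressionSearchResult pairs a candidate lambda expression with its search depth and argument counts. With max_depth=1, the searcher enumerates a single application of each binary function over fresh variables, plus the ten nullary constants; a larger max_depth should additionally produce nested compositions.
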
[8]:
from concepts.language.neural_ccg.search import NeuralCCGLexiconEnumerativeSearcher

lexicon_searcher = NeuralCCGLexiconEnumerativeSearcher(candidate_expressions, executor)
candidate_lexicon_entries = lexicon_searcher.gen()
candidate_lexicon_entries_table = list()
for result in candidate_lexicon_entries[:20]:
    candidate_lexicon_entries_table.append((str(result.syntax), str(result.semantics)))

print(tabulate(candidate_lexicon_entries_table, headers=['syntax', 'semantics']))
print(f'In total: {len(candidate_lexicon_entries)} lexicon entries.')
syntax             semantics
-----------------  ----------------------------------------------------------------
int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
int64/int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
int64\int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
int64\int64\int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
int64/int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
int64\int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
int64\int64\int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
int64              def __lambda__(): return int0()
int64              def __lambda__(): return int1()
int64              def __lambda__(): return int2()
int64              def __lambda__(): return int3()
int64              def __lambda__(): return int4()
int64              def __lambda__(): return int5()
int64              def __lambda__(): return int6()
int64              def __lambda__(): return int7()
In total: 22 lexicon entries.
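
The 22 candidates break down into 12 entries for the binary functions (2 functions × 2 argument orders × 3 slash patterns for the syntactic type) and 10 entries for the constants int0 through int9, each with the atomic syntactic type int64.
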
[9]:
from concepts.language.neural_ccg.grammar import NeuralCCG
[10]:
ccg = NeuralCCG(math_domain, executor, candidate_lexicon_entries)
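
The grammar ties together the domain, the executor, and the candidate lexicon entries. During parsing, every word of a sentence may map to any of the 22 candidate entries; which entry is chosen is governed by a learnable weight matrix, constructed next.
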
[11]:
import torch.nn as nn
import torch.nn.functional as F
[12]:
# One row of weights per word in "one plus two" (3 words), one column per candidate lexicon entry (22).
lexicon_weights = nn.Parameter(torch.zeros((3, 22), dtype=torch.float32))
# Pin "one" (word 0) to entry 13 (int1) and "two" (word 2) to entry 14 (int2);
# "plus" (word 1) keeps a uniform distribution over all 22 candidate entries.
lexicon_weights.data[0, 13] = 1e9
lexicon_weights.data[2, 14] = 1e9
[13]:
results = ccg.parse("one plus two", F.log_softmax(lexicon_weights, dim=-1))
[14]:
result_table = list()
for result in results[:20]:
    result_table.append((str(result.syntax), str(result.semantics.value.execute()), str(result.execution_result), str(result.weight.item())))
[15]:
print(tabulate(result_table, headers=['syntax', 'semantics', 'grounded value', 'weight']))
print(f'In total: {len(results)} parsing trees.')
syntax    semantics              grounded value          weight
--------  ---------------------  --------------------  --------
int64     add(int1(), int2())    V(3.0, dtype=int64)   -3.09104
int64     add(int2(), int1())    V(3.0, dtype=int64)   -3.09104
int64     minus(int1(), int2())  V(-1.0, dtype=int64)  -3.09104
int64     minus(int2(), int1())  V(1.0, dtype=int64)   -3.09104
int64     add(int0(), int2())    V(2.0, dtype=int64)   -1e+09
int64     add(int2(), int0())    V(2.0, dtype=int64)   -1e+09
int64     minus(int0(), int2())  V(-2.0, dtype=int64)  -1e+09
int64     minus(int2(), int0())  V(2.0, dtype=int64)   -1e+09
int64     add(int1(), int0())    V(1.0, dtype=int64)   -1e+09
int64     add(int1(), int1())    V(2.0, dtype=int64)   -1e+09
int64     add(int1(), int3())    V(4.0, dtype=int64)   -1e+09
int64     add(int1(), int4())    V(5.0, dtype=int64)   -1e+09
int64     add(int1(), int5())    V(6.0, dtype=int64)   -1e+09
int64     add(int1(), int6())    V(7.0, dtype=int64)   -1e+09
int64     add(int1(), int7())    V(8.0, dtype=int64)   -1e+09
int64     add(int1(), int8())    V(9.0, dtype=int64)   -1e+09
int64     add(int1(), int9())    V(10.0, dtype=int64)  -1e+09
int64     add(int0(), int1())    V(1.0, dtype=int64)   -1e+09
int64     add(int1(), int1())    V(2.0, dtype=int64)   -1e+09
int64     add(int3(), int1())    V(4.0, dtype=int64)   -1e+09
In total: 1200 parsing trees.
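
The four parses with weight ≈ -3.09 ≈ -ln 22 are exactly those in which "one" and "two" keep their pinned meanings while "plus" takes one of its four infix-compatible entries (add or minus, with either argument order) under the still-uniform distribution over 22 candidates. Every other parse uses an entry suppressed by the initialization, hence the -1e9 weights. The results also appear to come back sorted by weight, so results[0] is the current best parse.
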
[16]:
print(ccg.format_lexicon_table_sentence("one plus two".split(), lexicon_weights))
        i    weight  w_grad    syntax             semantics
----  ---  --------  --------  -----------------  ----------------------------------------------------------------
one    13    1       g: None   int64              def __lambda__(): return int1()
        0    0       g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        1    0       g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        2    0       g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        3    0       g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
        4    0       g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
        5    0       g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
        6    0       g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
        7    0       g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
        8    0       g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
        9    0       g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
       10    0       g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
       11    0       g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
       12    0       g: None   int64              def __lambda__(): return int0()
       14    0       g: None   int64              def __lambda__(): return int2()
       15    0       g: None   int64              def __lambda__(): return int3()
       16    0       g: None   int64              def __lambda__(): return int4()
       17    0       g: None   int64              def __lambda__(): return int5()
       18    0       g: None   int64              def __lambda__(): return int6()
       19    0       g: None   int64              def __lambda__(): return int7()
       20    0       g: None   int64              def __lambda__(): return int8()
       21    0       g: None   int64              def __lambda__(): return int9()
plus    0    0.0455  g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        1    0.0455  g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        2    0.0455  g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        3    0.0455  g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
        4    0.0455  g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
        5    0.0455  g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
        6    0.0455  g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
        7    0.0455  g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
        8    0.0455  g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
        9    0.0455  g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
       10    0.0455  g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
       11    0.0455  g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
       12    0.0455  g: None   int64              def __lambda__(): return int0()
       13    0.0455  g: None   int64              def __lambda__(): return int1()
       14    0.0455  g: None   int64              def __lambda__(): return int2()
       15    0.0455  g: None   int64              def __lambda__(): return int3()
       16    0.0455  g: None   int64              def __lambda__(): return int4()
       17    0.0455  g: None   int64              def __lambda__(): return int5()
       18    0.0455  g: None   int64              def __lambda__(): return int6()
       19    0.0455  g: None   int64              def __lambda__(): return int7()
       20    0.0455  g: None   int64              def __lambda__(): return int8()
       21    0.0455  g: None   int64              def __lambda__(): return int9()
two    14    1       g: None   int64              def __lambda__(): return int2()
        0    0       g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        1    0       g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        2    0       g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        3    0       g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
        4    0       g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
        5    0       g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
        6    0       g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
        7    0       g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
        8    0       g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return minus(V::#1, V::#0)
        9    0       g: None   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
       10    0       g: None   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
       11    0       g: None   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return minus(V::#0, V::#1)
       12    0       g: None   int64              def __lambda__(): return int0()
       13    0       g: None   int64              def __lambda__(): return int1()
       15    0       g: None   int64              def __lambda__(): return int3()
       16    0       g: None   int64              def __lambda__(): return int4()
       17    0       g: None   int64              def __lambda__(): return int5()
       18    0       g: None   int64              def __lambda__(): return int6()
       19    0       g: None   int64              def __lambda__(): return int7()
       20    0       g: None   int64              def __lambda__(): return int8()
       21    0       g: None   int64              def __lambda__(): return int9()
[17]:
optimizer = torch.optim.Adam([lexicon_weights], lr=1)
[18]:
for epoch in range(20):
    results = ccg.parse("one plus two", F.log_softmax(lexicon_weights, dim=-1))
    # Normalize the parse weights into a (log-)distribution over all candidate parse trees.
    weights = F.log_softmax(torch.stack([node.weight for node in results], dim=0), dim=0)
    weights_softmax = F.softmax(weights, dim=0)

    # Despite the name, this is a loss: the expected negative log-probability of the parses
    # whose grounded value matches the target (one plus two = 3). The posterior weights are
    # detached, so gradients flow only through the log-probabilities.
    log_likelihood = 0
    for i, node in enumerate(results):
        if node.execution_result.value.item() == 3.0:
            log_likelihood -= weights_softmax[i].detach() * weights[i]

    print(f'log_likelihood: {log_likelihood.item()}')

    optimizer.zero_grad()
    log_likelihood.backward()
    optimizer.step()
log_likelihood: 0.6931471824645996
log_likelihood: 0.7223198413848877
log_likelihood: 0.6995475888252258
log_likelihood: 0.6946471333503723
log_likelihood: 0.6935825943946838
log_likelihood: 0.6932985782623291
log_likelihood: 0.6932079195976257
log_likelihood: 0.6931744813919067
log_likelihood: 0.6931606531143188
log_likelihood: 0.6931544542312622
log_likelihood: 0.6931512951850891
log_likelihood: 0.6931496858596802
log_likelihood: 0.6931487917900085
log_likelihood: 0.6931482553482056
log_likelihood: 0.6931478977203369
log_likelihood: 0.6931477189064026
log_likelihood: 0.6931475400924683
log_likelihood: 0.6931474804878235
log_likelihood: 0.6931474208831787
log_likelihood: 0.6931474208831787
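
The objective plateaus near ln 2 ≈ 0.693. Two lexicon entries for "plus" (the int64\int64/int64 entries for add, with the two argument orders) both produce the correct grounded value 3; since they start from identical weights they receive identical gradients, so they end up splitting the probability mass roughly 0.5/0.5, which is exactly what the lexicon table below shows.
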
[19]:
print(ccg.format_lexicon_table_sentence("one plus two".split(), lexicon_weights, max_entries=5))
        i    weight  w_grad      syntax             semantics
----  ---  --------  ----------  -----------------  --------------------------------------------------------------
one    13    1       g: 0.0000   int64              def __lambda__(): return int1()
        0    0       g: 0.0000   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        1    0       g: 0.0000   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        2    0       g: 0.0000   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        3    0       g: 0.0000   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
plus    1    0.497   g: -0.0000  int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        4    0.497   g: -0.0000  int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
        0    0.0003  g: 0.0000   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        2    0.0003  g: 0.0000   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        3    0.0003  g: 0.0000   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
two    14    1       g: 0.0000   int64              def __lambda__(): return int2()
        0    0       g: 0.0000   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        1    0       g: 0.0000   int64\int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        2    0       g: 0.0000   int64\int64\int64  def __lambda__(#0: int64, #1: int64): return add(V::#1, V::#0)
        3    0       g: 0.0000   int64/int64/int64  def __lambda__(#0: int64, #1: int64): return add(V::#0, V::#1)
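
To verify the learned lexicon, we can re-parse the sentence with the trained weights and inspect the top-scoring parse. A minimal sketch, reusing the result attributes shown in the cells above (.syntax, .semantics, .execution_result, .weight); it should recover an add parse that grounds to 3:

results = ccg.parse("one plus two", F.log_softmax(lexicon_weights, dim=-1))
best = max(results, key=lambda r: r.weight.item())  # highest log-weight parse
print(best.syntax, best.semantics.value.execute(), best.execution_result)
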