#!/usr/bin/env python3
"""
This loosely mimics the behaviour of asking ChatGPT the following:

Can you help me find words in your embedding space? I want to give you a basic
arithmetic expression involving words to find relationships between words in
your embedding model. For example king minus man plus woman should probably be
something like queen. Please give me 10 options each time. Are you ready?
"""

import cmd
import re
import os

from gensim import downloader
from thefuzz import process

# Token table for the expression language. Order matters: NUMBER must come
# before WORD, since r'\w+' would also match digits.
EMBEDDING_TOKENS = [
    ('NUMBER', r'\d+(\.\d*)?'),  # an integer or decimal number
    ('WORD', r'\w+'),            # a word
    ('PAREN', r'[()]'),          # a parenthesis
    ('OP', r'[+\-*/~]'),         # an arithmetic operator
    ('COMMA', r','),             # a comma
    ('WS', r'\s+'),              # whitespace
    ('ERROR', r'.'),             # anything else
]

EMBEDDING_TOKENIZATION_RE = re.compile('|'.join(
    f'(?P<{name}>{pattern})' for (name, pattern) in EMBEDDING_TOKENS
))


def tokenize_embedding_expr(expr):
    """ Generates (token_kind, token) for each token in expr. """
    for mo in EMBEDDING_TOKENIZATION_RE.finditer(expr):
        yield (mo.lastgroup, mo.group())
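
# Illustration (not part of the original script): tokenizing a simple
# expression yields (kind, text) pairs; whitespace comes through as WS
# tokens that the later stages simply ignore:
#
#     >>> list(tokenize_embedding_expr('king - man'))
#     [('WORD', 'king'), ('WS', ' '), ('OP', '-'), ('WS', ' '), ('WORD', 'man')]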


def token_precedence(token):
    """
    Returns the precedence of the token.

    Negative precedences are right-associative (no current operator uses
    this, but _goes_first supports it via abs()).
    """
    if token in {'+', '-', '~'}:
        return 1

    if token in {'*', '/'}:
        return 2

    return 0


def _goes_first(a, b):
    """ Returns True if pending operator a should be applied before b. """
    ap = token_precedence(a)
    bp = token_precedence(b)
    aap = abs(ap)
    abp = abs(bp)

    if aap > abp:
        return True

    if aap == abp and bp > 0:
        return True

    return False
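
# Illustration (not in the original): equal-precedence operators are
# left-associative, so an operator already on the stack wins a tie:
#
#     >>> _goes_first('-', '+')
#     True
#     >>> _goes_first('+', '*')
#     False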


def shunt_embedding_tokens(tokens):
    """
    Shunting-yard pass: converts the infix token stream to postfix.

    Yields (kind, value) where kind is:

    w - word to be looked up in model and converted to embedding vector
    s - scalar value
    o - operator
    """
    stack = []  # operator stack, just the op itself!

    for (kind, tok) in tokens:
        if kind == 'WORD':
            yield ('w', tok)

        elif kind == 'NUMBER':
            yield ('s', tok)

        elif kind == 'OP':
            while stack and stack[-1] != '(' and _goes_first(stack[-1], tok):
                yield ('o', stack.pop())
            stack.append(tok)

        elif kind == 'PAREN':
            if tok == '(':
                stack.append(tok)
            else:
                while stack and stack[-1] != '(':
                    yield ('o', stack.pop())

                if stack:
                    stack.pop()  # remove the '('

        # WS, COMMA and ERROR tokens are dropped silently.

    while stack:
        yield ('o', stack.pop())
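
# Illustration (not in the original): 'king - man + woman' shunts to
# postfix order, ready for stack evaluation:
#
#     >>> expr = 'king - man + woman'
#     >>> list(shunt_embedding_tokens(tokenize_embedding_expr(expr)))
#     [('w', 'king'), ('w', 'man'), ('o', '-'), ('w', 'woman'), ('o', '+')]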


def evaluate_embedding_shunt(shunt, model):
    """ Evaluates shunt using model. """
    stack = []

    for (kind, x) in shunt:
        if kind == 'w':
            # A leading underscore negates the word's vector.
            negate = x.startswith('_')
            if negate:
                x = x[1:]

            if x in model:
                vec = model[x]
            else:
                # Unknown word: fall back to the closest vocabulary key
                # by fuzzy string match.
                most_similar = process.extractOne(x, model.key_to_index.keys())[0]
                vec = model[most_similar]

            stack.append(-vec if negate else vec)

        elif kind == 's':
            stack.append(float(x))

        elif kind == 'o':
            if x == '+':
                a = stack.pop()
                b = stack.pop()
                stack.append(a + b)

            elif x == '-':
                a = stack.pop()
                b = stack.pop()
                stack.append(b - a)

            elif x == '*':
                a = stack.pop()
                b = stack.pop()
                stack.append(a * b)

            elif x == '/':
                a = stack.pop()
                b = stack.pop()
                stack.append(b / a)

            elif x == '~':
                # '~' averages its operands rather than adding them.
                a = stack.pop()
                b = stack.pop()
                stack.append((a + b) / 2)

    return stack[-1]
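
# Illustrative usage (not in the original; assumes a loaded gensim
# KeyedVectors model):
#
#     >>> vec = evaluate_embedding_shunt(
#     ...     shunt_embedding_tokens(tokenize_embedding_expr('king - man + woman')),
#     ...     model)
#     >>> model.most_similar(vec)[0][0]   # likely 'queen' for GloVe vectors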


class EmbeddingShell(cmd.Cmd):
    """ Actual embedding shell wrapper. """
    intro = 'Welcome to the embedding shell. Enter words in an equation to see similar embeddings. Type :help for more information.'
    prompt = '(Ʃ) '

    def __init__(self, *args, model='glove-wiki-gigaword-300', **kwargs):
        super().__init__(completekey='tab', stdin=None, stdout=None, *args, **kwargs)
        print('Loading model...', end='', flush=True)
        self._model = downloader.load(model)
        self._keys = self._model.key_to_index.keys()
        print(' DONE')

    def do_exec(self, arg):
        """ Evaluates an expression and prints the most similar words. """
        try:
            result = evaluate_embedding_shunt(shunt_embedding_tokens(tokenize_embedding_expr(arg)), self._model)

            for (word, sim) in self._model.most_similar(result, restrict_vocab=10000):
                # Draw a crude similarity bar scaled to the terminal width.
                (w, _) = os.get_terminal_size()
                bar = '-' * int((w - 20) * sim)
                print(f'{word:10} {bar}')
        except Exception as e:
            print('Could not evaluate expression:', e)

    def do_shunt(self, arg):
        """ Prints the shunted (postfix) form of an expression. """
        for x in shunt_embedding_tokens(tokenize_embedding_expr(arg)):
            print(x)

    def do_quit(self, arg):
        """ Exit the embedding shell. """
        return True

    def precmd(self, line):
        # Bare input is evaluated as an expression; lines starting with
        # ':' are dispatched as shell commands (e.g. ':quit', ':help').
        if not line:
            return line
        if line[0] == ':':
            return line[1:]
        return 'exec ' + line
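
# A sample session (illustrative; exact output depends on the model and
# terminal width):
#
#     (Ʃ) king - man + woman
#     king       ----------------------------------------
#     queen      -----------------------------
#     ...
#     (Ʃ) :quit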


if __name__ == '__main__':
    EmbeddingShell().cmdloop()