shenikan/tools/esh.py

200 lines
5.3 KiB
Python
Executable file

#!/usr/bin/env python3
"""
This loosely mimics asking ChatGPT the following:
Can you help me find words in your embedding space? I want to give you a basic
arithmetic expression involving words to find relationships between words in
your embedding model. For example king minus man plus woman should probably be
something like queen. Please give me 10 options each time. Are you ready?
"""
import cmd
import os
import re
import shutil

from gensim import downloader
from thefuzz import process
# (kind, regex) alternatives, tried left to right at every position of the
# input.  Order matters: NUMBER must precede WORD because \w also matches
# digits, and ERROR is the single-character catch-all that must come last.
EMBEDDING_TOKENS = [
    ('NUMBER', r'\d+(\.\d*)?'),  # integer or decimal literal, e.g. 3 or 2.5
    ('WORD', r'\w+'),            # vocabulary word; a leading _ marks negation
    ('PAREN', r'[()]'),          # opening or closing parenthesis
    ('OP', r'[+\-*/~]'),         # binary operator (~ is "average")
    ('COMMA', r','),             # comma
    ('WS', r'\s+'),              # run of whitespace
    ('ERROR', r'.'),             # any other single character
]

# One alternation of named groups, so match.lastgroup names the token kind.
EMBEDDING_TOKENIZATION_RE = re.compile(
    '|'.join('(?P<%s>%s)' % pair for pair in EMBEDDING_TOKENS)
)
def tokenize_embedding_expr(expr):
    """
    Lazily split expr into (token_kind, token_text) pairs.

    token_kind is the name of the EMBEDDING_TOKENS alternative that matched.
    Whitespace and unrecognized characters are reported as 'WS' and 'ERROR'
    tokens rather than being dropped here.
    """
    for match in EMBEDDING_TOKENIZATION_RE.finditer(expr):
        kind = match.lastgroup
        yield kind, match[0]
def token_precedence(token):
    """
    Return how tightly an operator token binds; 0 for anything else.

    Larger values bind tighter ('*' and '/' over '+', '-' and '~').  The
    sign encodes associativity: a negative value would mark a
    right-associative operator.  Every operator listed here is
    left-associative.
    """
    levels = {'+': 1, '-': 1, '~': 1, '*': 2, '/': 2}
    return levels.get(token, 0)
def _goes_first(a, b):
    """
    Decide whether stacked operator a must be emitted before incoming
    operator b is pushed (the shunting-yard pop condition).

    True when a binds strictly tighter than b, or when both bind equally
    and b is left-associative (non-negative precedence).
    """
    strength_a = abs(token_precedence(a))
    precedence_b = token_precedence(b)
    strength_b = abs(precedence_b)
    if strength_a != strength_b:
        return strength_a > strength_b
    return precedence_b > 0
def shunt_embedding_tokens(tokens):
    """
    Reorder infix tokens into postfix (reverse Polish) order using the
    shunting-yard algorithm.

    tokens is an iterable of (kind, text) pairs as produced by
    tokenize_embedding_expr.  Yields (kind, value) pairs where kind is:
        w - word to be looked up in model and converted to embedding vector
        s - scalar value (still as text)
        o - operator
    WS, COMMA and ERROR tokens are skipped.  An unmatched ')' is ignored,
    and an unmatched '(' is emitted as an 'o' token at the end.
    """
    pending = []  # operators and '(' not yet emitted; pending[-1] is the top

    def unwind(incoming=None):
        # Emit operators from the top of `pending` until an '(' (left in
        # place) or the bottom is reached.  With an incoming operator, also
        # stop at the first stacked operator that need not go before it.
        while pending and pending[-1] != '(':
            if incoming is not None and not _goes_first(pending[-1], incoming):
                break
            yield ('o', pending.pop())

    for (kind, value) in tokens:
        if kind == 'WORD':
            yield ('w', value)
        elif kind == 'NUMBER':
            yield ('s', value)
        elif kind == 'OP':
            yield from unwind(value)
            pending.append(value)
        elif kind == 'PAREN' and value == '(':
            pending.append(value)
        elif kind == 'PAREN':
            yield from unwind()
            if pending:
                pending.pop()  # discard the matching '('

    # Flush everything still stacked, top first.
    while pending:
        yield ('o', pending.pop())
def _lookup_embedding(word, model):
    """Return model[word], or the entry of the fuzzily closest vocabulary key.

    Words missing from the model are matched with thefuzz's
    ``process.extractOne`` against every key in ``model.key_to_index``.
    """
    if word in model:
        return model[word]
    most_similar = process.extractOne(word, model.key_to_index.keys())[0]
    return model[most_similar]


# Binary operators of the expression language, applied as op(lhs, rhs).
_EMBEDDING_BINARY_OPS = {
    '+': lambda lhs, rhs: lhs + rhs,
    '-': lambda lhs, rhs: lhs - rhs,
    '*': lambda lhs, rhs: lhs * rhs,
    '/': lambda lhs, rhs: lhs / rhs,
    '~': lambda lhs, rhs: (lhs + rhs) / 2,  # average of the two operands
}


def evaluate_embedding_shunt(shunt, model):
    """
    Evaluate a postfix token stream (as yielded by shunt_embedding_tokens).

    shunt: iterable of (kind, value) pairs, where kind is 'w' (word), 's'
        (scalar text) or 'o' (operator).  Other kinds are ignored.
    model: supports ``word in model``, ``model[word]`` and, for fuzzy
        fallback, ``model.key_to_index``; presumably a gensim KeyedVectors.

    A word prefixed with '_' evaluates to the negation of the word's
    embedding.  Words not in the model are replaced by the fuzzily closest
    vocabulary word.

    Returns the single value the expression reduces to.

    Raises:
        ValueError: if the expression is empty, an operator lacks operands,
            operands are left without an operator joining them, or an
            unknown operator appears (e.g. from an unmatched '(').
    """
    stack = []
    for (kind, x) in shunt:
        if kind == 'w':
            # Previously a '_word' pushed its negation AND then fell through
            # to a second lookup of '_word', pushing an extra vector that
            # silently replaced the negated one in the result.
            if x.startswith('_'):
                stack.append(-_lookup_embedding(x[1:], model))
            else:
                stack.append(_lookup_embedding(x, model))
        elif kind == 's':
            stack.append(float(x))
        elif kind == 'o':
            try:
                apply = _EMBEDDING_BINARY_OPS[x]
            except KeyError:
                raise ValueError(
                    f'unexpected operator {x!r} (mismatched parentheses?)'
                ) from None
            if len(stack) < 2:
                raise ValueError(f'operator {x!r} needs two operands')
            rhs = stack.pop()
            lhs = stack.pop()
            stack.append(apply(lhs, rhs))
    if not stack:
        raise ValueError('empty expression')
    if len(stack) > 1:
        raise ValueError(
            f'expression leaves {len(stack)} values; is an operator missing?'
        )
    return stack[0]
class EmbeddingShell(cmd.Cmd):
    """Interactive shell for word-embedding arithmetic.

    Lines not starting with ':' are evaluated as expressions by do_exec.
    Lines starting with ':' run shell commands, e.g. ':help',
    ':shunt king - man' or ':quit'.
    """
    intro = 'Welcome to the embedding shell. Enter words in an equation to see similar embeddings. Type :help for more information'
    prompt = '(Ʃ) '

    def __init__(self, *args, model='glove-wiki-gigaword-300', **kwargs):
        """Load the gensim model named *model* via gensim's downloader.

        Other positional/keyword arguments are forwarded to cmd.Cmd
        (completekey, stdin, stdout).
        """
        # cmd.Cmd already defaults to completekey='tab', stdin=None and
        # stdout=None.  Passing those as keywords alongside *args made any
        # positional argument (or an explicit completekey=...) a TypeError.
        super().__init__(*args, **kwargs)
        print('Loading model...', end='', flush=True)
        self._model = downloader.load(model)
        self._keys = self._model.key_to_index.keys()
        print(' DONE')

    def do_exec(self, arg):
        """Evaluate an embedding expression and list the most similar words.

        Example: king - man + woman
        Operators: + - * / and ~ (average); prefix a word with _ to negate it.
        """
        # Interactive boundary: report any failure and keep the shell alive.
        try:
            result = evaluate_embedding_shunt(
                shunt_embedding_tokens(tokenize_embedding_expr(arg)),
                self._model,
            )
            neighbours = self._model.most_similar(result, restrict_vocab=10000)
            # Query once, not per row. shutil falls back to a default size
            # when stdout is not a terminal; os.get_terminal_size() raises.
            (width, _) = shutil.get_terminal_size()
            for (word, sim) in neighbours:
                bar = '-' * int((width - 20) * sim)
                print(f'{word:10} {bar}')
        except Exception as e:
            print("Could not evaluate expression:", e)

    def do_shunt(self, arg):
        """Show an expression in postfix (shunting-yard) order, one token per line."""
        for x in shunt_embedding_tokens(tokenize_embedding_expr(arg)):
            print(x)

    def do_quit(self, arg):
        """ Exit the embedding shell. """
        return True

    def precmd(self, line):
        """Route input: ':cmd' runs a shell command; anything else is evaluated.

        cmd.Cmd reports end of input (Ctrl-D, or exhausted piped stdin) as
        the line 'EOF'.  It is mapped to 'quit'; previously it was evaluated
        as a word, so Ctrl-D never exited and piped input looped forever.
        """
        if not line:
            return line
        if line == 'EOF':
            print()  # finish the prompt line before exiting
            return 'quit'
        if line[0] == ':':
            return line[1:]
        return 'exec ' + line
if __name__ == '__main__':
    # Script entry point: load the default model and start the REPL.
    shell = EmbeddingShell()
    shell.cmdloop()