#!/usr/bin/env python
## Display the WordNet hypernym paths that join two nouns through their lowest common subsumers, with Wu-Palmer scores

# standard imports
import string
import re
import glob
import sys
if sys.version_info < (3,):
    range = xrange
from ppretty import ppretty

# other imports
from nltk.corpus import wordnet as wn
import numpy as np
from itertools import product
import math
import networkx as nx

# plots
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
import matplotlib.pyplot as plt
from matplotlib.path import Path
import matplotlib.patches as patches


############### FUNCTIONS ################

# WordNet: recursive walk down the hyponym tree (despite the 'path2root' name, it does not climb toward the root)
def wn_path2root(graph, start, node=None, depth=0):
    """Recursively add the hyponym subtree below *node* to *graph*.

    This is the traversal from the NLTK book (p.170), adapted to pass the
    depth explicitly.  Despite its historical name it walks *down* through
    ``hyponyms()``.

    graph -- object with ``add_edge(a, b)`` (e.g. a networkx graph); its
             ``depth`` dict attribute is created if missing and receives
             ``depth[node.name] = depth`` for every visited synset.
    start -- synset the walk began from; passed through unchanged.
    node  -- synset being visited; defaults to *start*.
    depth -- number of hyponym steps from *start* to *node*.

    Nodes are keyed by ``synset.name`` exactly as written (not called), so
    under NLTK 3 the keys are bound methods; kept as-is so existing
    consumers of the graph still see the same keys.  There is no visited
    set: a synset reachable along several routes is visited once per route
    and keeps the depth written by the last visit.
    """
    if node is None:
        node = start
    if getattr(graph, 'depth', None) is None:
        graph.depth = {}
    graph.depth[node.name] = depth
    for child in node.hyponyms():
        graph.add_edge(node.name, child.name)
        wn_path2root(graph, start, child, depth + 1)


# wn_hyponym_graph invokes the traversal under this name with the full
# (graph, start, node, depth) signature.
wn_traverse = wn_path2root

# WordNet: hyponym graph (from NLTK book, p.170)
def wn_hyponym_graph(start):
    """Build an undirected networkx graph of the hyponyms below *start*.

    The returned graph carries an extra ``depth`` dict attribute, filled by
    the module-level ``wn_traverse(graph, start, node, depth)`` (looked up
    at call time) with each visited node's hyponym depth below *start*.
    Nodes use the same keys as the traversal: ``synset.name``.
    """
    graph = nx.Graph()
    graph.depth = {}
    # The traversal only creates nodes through add_edge, so a synset with no
    # hyponyms would otherwise give an empty graph even though its depth is
    # recorded; add the root explicitly (re-adding a node is a no-op).
    graph.add_node(start.name)
    wn_traverse(graph, start, start, 0)
    return graph

# WordNet: draw the graph (from NLTK book, p.170)
def wn_graph_draw(graph, savefilename):
    """Draw *graph*, save the figure to *savefilename*, then show it.

    Node size is 16 x degree and node colour comes from ``graph.depth``.
    Each label is the part of the node's synset name before the first '.',
    drawn at the node position shifted by a small offset.
    """
    def node_label(node):
        # Nodes may be keyed by the bound method ``synset.name`` (NLTK 3) or
        # by the name string itself; calling the former yields the string.
        name = node() if callable(node) else node
        return str(name).split(".")[0]

    plt.figure()
    pos = nx.spring_layout(graph)
    nodes = list(graph)
    nx.draw(graph, pos,
            node_size=[16 * graph.degree(n) for n in nodes],
            node_color=[graph.depth[n] for n in nodes], with_labels=False)
    # Key labels and their positions by node so every label lands on its own
    # node; pairing two independently ordered lists does not guarantee that.
    labels = {n: node_label(n) for n in nodes}
    # NOTE(review): the offset is 1% of an arbitrary node's position (kept
    # from the original); a fixed small offset may have been intended.
    if pos:
        offset = 0.01 * np.asarray(next(iter(pos.values())))
    else:
        offset = 0.0
    labelpos = {n: np.asarray(p) + offset for n, p in pos.items()}
    nx.draw_networkx_labels(graph, labelpos, labels, font_size=10)
    plt.axis('off')
    plt.savefig(savefilename)
    plt.show()

########################################

def _lemma(synset):
    """Return the lemma part of a synset's name, e.g. 'dog' for 'dog.n.01'."""
    return str(synset.name().split('.')[0])


def _first_noun_synset(raw):
    """Map a command-line argument to (word, first noun synset of word).

    Non-alphabetic characters are dropped first.  Exits with a message
    instead of a traceback when nothing alphabetic is left or WordNet has
    no noun sense for the word.
    """
    from nltk.corpus.reader.wordnet import WordNetError

    word = ''.join(ch for ch in raw if ch.isalpha())
    if not word:
        sys.exit('not a word: %r' % raw)
    try:
        return word, wn.synset(word + ".n.01")
    except WordNetError:
        sys.exit('no noun sense in WordNet for %r' % word)


if len(sys.argv) < 3:
    sys.exit('need two common words on cmd line')

# read words from cmd line and map each to its first noun sense
word1, w1 = _first_noun_synset(sys.argv[1])
word2, w2 = _first_noun_synset(sys.argv[2])
wnword1 = word1 + ".n.01"
wnword2 = word2 + ".n.01"

# all hypernym paths of each word; the slicing below assumes each path runs
# root first, word last
hp1 = w1.hypernym_paths()
hp2 = w2.hypernym_paths()
paths1 = [[_lemma(synset) for synset in path] for path in hp1]
paths2 = [[_lemma(synset) for synset in path] for path in hp2]

# lowest common subsumers; depth = node count of the shortest path from root
wnlcs = w1.lowest_common_hypernyms(w2)
lcsdepth = {_lemma(s): min(len(sp) for sp in s.hypernym_paths()) for s in wnlcs}
lcs = [_lemma(synset) for synset in wnlcs]

# for each subsumer, join every word1 path with every word2 path through it:
# word1 ... (node just below subsumer) + subsumer ... word2
paths = {}
for subsumer in wnlcs:
    joined = paths.setdefault(_lemma(subsumer), [])
    # Match on the synset itself rather than its lemma string, so another
    # synset that merely shares the lemma cannot be mistaken for it.
    for p in hp1:
        if subsumer not in p:
            continue
        up = [_lemma(x) for x in reversed(p[p.index(subsumer) + 1:])]
        for q in hp2:
            if subsumer in q:
                joined.append(up + [_lemma(x) for x in q[q.index(subsumer):]])

# output (sys.stdout.write keeps the same text on Python 2 and 3)
out = sys.stdout
for s in sorted(paths):
    d = lcsdepth[s]
    out.write("depth( %s ) = %s\n" % (s, d))
    for path in paths[s]:
        out.write(str(len(path)) + "".join(" -> " + word for word in path) + "\n")
        # Wu-Palmer similarity 2*d / (d1 + d2) with node-counted depths.
        # len(path) - 1 is the number of edges between the two words, so
        # d1 + d2 = (len(path) - 1) + 2*d; identical words score exactly 1.0.
        # NOTE(review): d is the subsumer's *shortest* root depth, whereas
        # NLTK's wup_similarity uses max_depth() + 1, so values may differ.
        wupsim = 2.0 * d / (len(path) - 1 + 2.0 * d)
        out.write("Wu-Palmer similarity = %s\n" % wupsim)

################### OBLIVION ######################
# Nothing below this line is executed.
sys.exit()

