## Import models and libraries

In [2]:
import torch
import pandas as pd
from torch.utils.data import DataLoader
from collections import Counter, defaultdict
from tqdm import trange
from sklearn.manifold import TSNE
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr
from sklearn.preprocessing import StandardScaler
from collections import defaultdict

class SkipGramModel(nn.Module):
    """Skip-gram style model: map a word index to log-probabilities over the vocabulary.

    The architecture is an embedding lookup followed by a single linear projection
    to vocabulary logits and a log-softmax. There is no hidden nonlinearity.
    """

    def __init__(self, embedding_size, vocab_size):
        super().__init__()
        # Attribute names double as state-dict keys of saved checkpoints; keep them.
        self.embeddings_input = nn.Embedding(vocab_size, embedding_size)
        self.linear = nn.Linear(embedding_size, vocab_size)

    def forward(self, input_word):
        """Return vocabulary log-probabilities for each index in ``input_word``.

        ``input_word`` may be a scalar index or a batch of indices. The result has
        the same leading shape plus a trailing ``vocab_size`` axis.
        """
        embeds = self.embeddings_input(input_word)
        logits = self.linear(embeds)  # purely linear projection to vocabulary logits
        # Normalise over the last (vocabulary) axis so unbatched inputs work too.
        return F.log_softmax(logits, dim=-1)


class CBOWModel(nn.Module):
    """CBOW style model: sum context-word embeddings and predict a vocabulary distribution.

    The architecture is an embedding lookup, a sum over the context words, and a
    single linear projection to vocabulary logits followed by a log-softmax.
    """

    def __init__(self, embedding_size, vocab_size):
        super().__init__()
        # Attribute names double as state-dict keys of saved checkpoints; keep them.
        self.embeddings_input = nn.Embedding(vocab_size, embedding_size)
        self.linear = nn.Linear(embedding_size, vocab_size)

    def forward(self, input_word):
        """Return vocabulary log-probabilities for each context in ``input_word``.

        ``input_word`` holds context-word indices with the context words on the
        last axis: either ``(context_len,)`` for a single context or
        ``(batch, context_len)``. The result drops the context axis and gains a
        trailing ``vocab_size`` axis.
        """
        embeds = self.embeddings_input(input_word)  # (..., context_len, embedding_size)
        # Sum over the context-word axis (second to last), not a fixed dim=1, so a
        # single unbatched context does not get summed over its embedding axis.
        context_vector = embeds.sum(dim=-2)
        logits = self.linear(context_vector)  # purely linear projection
        return F.log_softmax(logits, dim=-1)


class GloVeModel(nn.Module):
    """GloVe model: focal/context embeddings and biases trained with a weighted least-squares loss.

    ``forward`` returns the mean over the batch of
    ``f(X_ij) * (w_i . w~_j + b_i + b~_j - log X_ij) ** 2``, where
    ``f(x) = min((x / x_max) ** 0.75, 1)``.
    """

    def __init__(self, embedding_size, vocab_size, x_max = 100):
        super().__init__()
        self.x_max = x_max
        self._focal_embeddings = nn.Embedding(vocab_size, embedding_size)
        self._context_embeddings = nn.Embedding(vocab_size, embedding_size)
        # Biases are stored as float64; left unchanged so existing checkpoints'
        # state-dict dtypes still match.
        self._focal_biases = nn.Embedding(vocab_size, 1).type(torch.float64)
        self._context_biases = nn.Embedding(vocab_size, 1).type(torch.float64)

    def forward(self, focal_input, context_input, coocurrence_count):
        """Return the mean weighted GloVe loss over a batch of (focal, context) pairs.

        Args:
            focal_input: indices of the focal words, one per pair.
            context_input: indices of the context words, aligned with ``focal_input``.
            coocurrence_count: co-occurrence counts X_ij aligned with the pairs;
                presumably strictly positive, since ``log`` of 0 is -inf.

        Returns:
            A scalar tensor holding the mean loss.
        """
        x_max = max(self.x_max, 1)  # guard against division by zero / negative caps
        focal_embed = self._focal_embeddings(focal_input)
        context_embed = self._context_embeddings(context_input)
        # Bias lookups carry a trailing size-1 axis; drop it so they align
        # element-wise with the per-pair dot products. Otherwise (batch,) + (batch, 1)
        # broadcasts to a (batch, batch) matrix.
        focal_bias = self._focal_biases(focal_input).squeeze(-1)
        context_bias = self._context_biases(context_input).squeeze(-1)

        # Weighting function f(x) = (x / x_max) ** 0.75, capped at 1.
        weight_factor = torch.clamp(torch.pow(coocurrence_count / x_max, 0.75), max=1)

        embedding_products = torch.sum(focal_embed * context_embed, dim=-1)
        log_cooccurrences = torch.log(coocurrence_count)

        # GloVe residual: the model's score should match log X_ij, hence the minus.
        distance_expr = (
            embedding_products + focal_bias + context_bias - log_cooccurrences
        ) ** 2

        single_losses = weight_factor * distance_expr
        return torch.mean(single_losses)

In [19]:
def get_emb(model, name, idx):
    """Look up embedding vector(s) for ``idx`` and return them as a NumPy array.

    GloVe models (``name == "glove"``) read from ``_focal_embeddings``; every other
    model reads from ``embeddings_input``. ``idx`` is anything ``torch.tensor``
    accepts (an int or a list of ints). The result is detached and moved to CPU.
    """
    table = model._focal_embeddings if name == "glove" else model.embeddings_input
    lookup = torch.tensor(idx)
    vectors = table(lookup)
    return vectors.detach().cpu().numpy()

def save_word_embeddings(model_name: str, neg: "int | str", model_dir: str = "../model"):
    """Load a trained checkpoint and pickle a mapping from each word to its embedding.

    Reads ``{model_dir}/model_{model_name}_neg{neg}.pth``, which must contain
    ``"word_to_ix"`` and ``"model_state_dict"``. It rebuilds the matching model with
    100-dimensional embeddings and writes a ``defaultdict(list)`` of
    ``{word: numpy vector}`` to ``{model_dir}/{model_name}_{neg}.embedding``.

    Args:
        model_name: ``"cbow"``, ``"skipgram"`` or ``"glove"``.
        neg: value interpolated into the checkpoint and output file names.
        model_dir: directory holding the checkpoints; defaults to ``"../model"``.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture.
    """
    import pickle  # used below but not imported at the top of the notebook

    model_classes = {
        "cbow": CBOWModel,
        "skipgram": SkipGramModel,
        "glove": GloVeModel,
    }
    if model_name not in model_classes:
        raise ValueError(
            f"unknown model_name {model_name!r}; expected one of {sorted(model_classes)}"
        )

    saved_params = torch.load(
        f"{model_dir}/model_{model_name}_neg{neg}.pth",
        map_location=torch.device("cpu"),
    )
    word_idx = saved_params["word_to_ix"]

    model = model_classes[model_name](100, len(word_idx))
    model.load_state_dict(saved_params["model_state_dict"])
    model.eval()

    # Look each word up by its own vocabulary index (not ``neg``), in one batched call.
    word_embeddings = defaultdict(list)
    words = list(word_idx)
    if words:  # torch.tensor([]) is a float tensor and cannot index an embedding
        with torch.no_grad():
            vectors = get_emb(model, model_name, [word_idx[w] for w in words])
        word_embeddings.update(zip(words, vectors))

    with open(f"{model_dir}/{model_name}_{neg}.embedding", "wb") as f:
        pickle.dump(word_embeddings, f, pickle.DEFAULT_PROTOCOL)
    print(f"{model_name} embeddings saved")

## Save the embeddings

In [20]:
# Export embeddings for every (architecture, neg) checkpoint pair, in nested-loop order.
for m, n in [(arch, neg) for arch in ["cbow", "skipgram", "glove"] for neg in [0, 10]]:
    save_word_embeddings(m, n)

cbow embeddings saved
cbow embeddings saved
skipgram embeddings saved
skipgram embeddings saved
glove embeddings saved
glove embeddings saved
