Given a simple feature selection code below, I want to know the selected columns after the feature selection (The dataset includes a header V1 ... V20)
import pandas as pd
from sklearn.feature_selection import SelectFromModel, SelectKBest, f_regression
def feature_selection(data):
    """Keep the 10 columns of ``data`` that score best against ``Class``.

    ``data`` is indexed by column label: the ``Class`` column is used as the
    target and every remaining column is a candidate feature.  Candidates
    are scored with ``f_regression`` and the output of
    ``SelectKBest.fit_transform`` (the reduced feature matrix) is returned.
    """
    y = data['Class']
    X = data.drop(['Class'], axis=1)
    fs = SelectKBest(score_func=f_regression, k=10)
    # Applying feature selection
    X_selected = fs.fit_transform(X, y)
    # TODO: determine the columns being selected
    return X_selected
data = pd.read_csv("../dataset.csv")
new_data = feature_selection(data)
I appreciate any help.

I have used the iris dataset for my example but you can probably easily modify your code to match your use case.
The SelectKBest method has the scores_ attribute I used to sort the features.
Feel free to ask for any clarifications.
import pandas as pd
import numpy as np
from sklearn.feature_selection import SelectFromModel, SelectKBest, f_regression
from sklearn.datasets import load_iris
def feature_selection(data):
    """Select the ``k`` best-scoring features from an ``(X, y)`` pair.

    ``data[0]`` is used as the feature matrix and ``data[1]`` as the target
    (the layout produced by ``load_iris(return_X_y=True)``).  The reduced
    feature matrix returned by ``SelectKBest.fit_transform`` is returned.
    """
    y = data[1]
    X = data[0]
    column_names = ["A", "B", "C", "D"]  # Here you should use your dataframe's column names
    k = 2
    fs = SelectKBest(score_func=f_regression, k=k)
    # Applying feature selection
    X_selected = fs.fit_transform(X, y)
    # Find top features
    # I create a list like [[ColumnName1, Score1] , [ColumnName2, Score2], ...]
    # Then I sort in descending order on the score
    # Barring tied scores, the first ``k`` entries name the columns that
    # SelectKBest kept (``fs.get_support()`` gives the same answer as a mask).
    top_features = sorted(zip(column_names, fs.scores_), key=lambda x: x[1], reverse=True)
    return X_selected
data = load_iris(return_X_y=True)
new_data = feature_selection(data)

I don't know of a built-in method, but it can be easily coded.
# Number of selected features = number of COLUMNS of the transformed matrix
# (shape[0] would be the number of samples).
n_columns_selected = X_new.shape[1]
# Sort (score, name) pairs by ascending score and keep the names of the
# n best-scoring columns.
new_columns = [name for _, name in sorted(zip(fs.scores_, X.columns))[-n_columns_selected:]]
# new_columns order is perturbed, we need to restore it. We use the names of the columns of X as a reference
new_columns = sorted(new_columns, key=list(X.columns).index)


Saving an sklearn.svm.SVR model as JSON instead of pickling

I have a trained SVR model which needs to be saved in a JSON format instead of pickling.
The idea behind JSONifying the trained model is to simply capture the state of the weights and other 'fitted' attributes. Then, I can set these attributes later to make predictions. Here is an implementation of it I did:
# assume SVR has been trained
regressor = SVR().fit(X_train, y_train)

# saving the regressor params in a JSON file for later retrieval
with open('saved_regressor_params.json', 'w', encoding='utf-8') as outfile:
    json.dump(regressor.get_params(), outfile)

# finding the fitted attributes of SVR()
# if an attribute is trailed by '_', it's a fitted attribute
attrs = [i for i in dir(regressor) if i.endswith('_') and not i.endswith('__')]
remove_list = ['coef_', '_repr_html_', '_repr_mimebundle_']  # unnecessary attributes
for attr in remove_list:
    if attr in attrs:
        attrs.remove(attr)

# convert NumPy arrays to plain lists (JSON cannot encode ndarrays) and
# save trained attribute values into JSON file
attr_dict = {i: getattr(regressor, i) for i in attrs}
for k in attr_dict:
    if isinstance(attr_dict[k], np.ndarray):
        attr_dict[k] = attr_dict[k].tolist()

# dump JSON for prediction
# The file name matches the one the loading code reads back
# ('saved_regressor.json'); the previous name used an undefined `index`.
with open('saved_regressor.json', 'w', encoding='utf-8') as outfile:
    json.dump(attr_dict, outfile, separators=(',', ':'))
This would create two separate json files. One file called saved_regressor_params.json which saves certain required parameters for SVR and another is called saved_regressor.json which stores attributes and their trained values as objects. Example (saved_regressor.json):
Later, I can create a new SVR() model and simply set these parameters and attributes into it by calling them from the existing JSON files we just created. Then, call in the predict() method to predict. Like so (in a new file):
predict_svr = SVR()

# load the json from the files
with open('saved_regressor_params.json', 'r', encoding='utf-8') as infile:
    params = json.load(infile)
with open('saved_regressor.json', 'r', encoding='utf-8') as infile:
    attributes = json.load(infile)

# setting params
predict_svr.set_params(**params)

# setting attributes
# Lists were ndarrays before serialization, so turn them back into arrays;
# everything else is assigned unchanged.  The `else` matters: without it
# every array would immediately be overwritten by the raw list.
for k in attributes:
    if isinstance(attributes[k], list):
        setattr(predict_svr, k, np.array(attributes[k]))
    else:
        setattr(predict_svr, k, attributes[k])
However, during this process, a particular attribute called: n_support_ cannot be set due to some reason. And even if I ignore n_support_ attribute, it creates additional errors. (Is my logic wrong or am I missing something here?)
Therefore, I am looking for different ways or ingenious methods to save an SVR model into JSON.
I have tried the existing third party helper libraries like: sklearn_json. These libraries tend to export perfectly for linear models but not for support vectors.
Making a reproducible example missing in the OP, based on the docs (version 1.1.2)
from sklearn.svm import SVR
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
import numpy as np
# Small random regression problem, seeded so the example is reproducible
n_samples, n_features = 10, 5
rng = np.random.RandomState(0)
y = rng.randn(n_samples)
X = rng.randn(n_samples, n_features)
# Fit the estimator whose state is serialized below
regressor = SVR(C=1.0, epsilon=0.2).fit(X, y)
Then a sketch of a JSON serialization/deserialization
import json
# serialize
# Dump the estimator's whole instance state; ndarrays become nested lists
# because JSON has no array type.
serialized = json.dumps({
    k: v.tolist() if isinstance(v, np.ndarray) else v
    for k, v in regressor.__dict__.items()
})
# deserialize
# Every JSON list is turned back into an ndarray (the dtype is whatever
# NumPy infers, not necessarily the original one).
regressor2 = SVR()
regressor2.__dict__ = {
    k: np.asarray(v) if isinstance(v, list) else v
    for k, v in json.loads(serialized).items()
}
# test
assert np.all(regressor.predict(X) == regressor2.predict(X))
EDIT: Serialization preserving data type
A not so elegant solution to address the first issue mentioned in a comment is to save the data type together with the data.
import json
# serialize
# Each ndarray is stored as [data, 'np.ndarray', dtype] so the original
# dtype can be restored on load.
serialized = json.dumps({
    k: [v.tolist(), 'np.ndarray', str(v.dtype)] if isinstance(v, np.ndarray) else v
    for k, v in regressor.__dict__.items()
})
# deserialize
# Only 3-element lists carrying the 'np.ndarray' tag are rebuilt as arrays;
# the length check keeps short or empty lists from raising IndexError.
regressor2 = SVR()
regressor2.__dict__ = {
    k: np.asarray(v[0], dtype=v[2])
    if isinstance(v, list) and len(v) == 3 and v[1] == 'np.ndarray'
    else v
    for k, v in json.loads(serialized).items()
}
# test
assert np.all(regressor.predict(X) == regressor2.predict(X))

How to handle alphanumeric values in machine learning

I am trying to find the best algorithm for my claims data. The claims data include some diagnosis codes which are alphanumeric, like 'EA43454'. When I run the code below to evaluate the models:
# candidate classifiers, as (short name, estimator) pairs
models = []
models.append(('LR', LogisticRegression()))
models.append(('LDA', LinearDiscriminantAnalysis()))
models.append(('KNN', KNeighborsClassifier()))
models.append(('CART', DecisionTreeClassifier()))
models.append(('NB', GaussianNB()))
models.append(('SVM', SVC()))
# evaluate each model in turn
results = []
names = []
scoring = 'accuracy'
for name, model in models:
    kfold = model_selection.KFold(n_splits=10, random_state=None)
    cv_results = model_selection.cross_val_score(model, X, y, cv=kfold, scoring=scoring)
    # keep the per-fold scores and the model name for later comparison
    results.append(cv_results)
    names.append(name)
    msg = "%s: %f (%f)" % (name, cv_results.mean(), cv_results.std())
    print(msg)
I get the error
ValueError: could not convert string to float: 'U0003'
How to handle these alphanumeric values?
You need to convert your strings to an indicator variable (dummy variables). Each value of the string variable has to be associated with a number so that the models can train on that data.
Scikit-learn has several preprocessors to help you with this such as OneHotEncoder. You can also use pandas.get_dummies, but using sklearn's own classes is more composable - for example, you can use them as part of a pipeline.
import pandas as pd
import numpy as np
from sklearn.preprocessing import OneHotEncoder
# Toy data set: one categorical column ("animal") and one numeric column ("age")
rng = np.random.default_rng()
animal_values = rng.choice(["cat", "dog"], size=10)
age_values = rng.integers(1, 20, size=10)
animals = pd.DataFrame({"animal": animal_values, "age": age_values})

# One-hot encode only the categorical part (the numeric "age" column is left out)
categorical_part = animals.drop(columns=["age"])
animals_ohe = OneHotEncoder().fit_transform(categorical_part)

Output of Chi2 test is not showing in DataFrame format:

I am trying to loop the chi-square test over several columns, but the outcome is not shown as required, that is, as a DataFrame.
All the columns end up in one row.
Please help.
# Import the function
from scipy.stats import chi2_contingency
# Run a chi-square independence test of every object-dtype column against
# 'Final_Comments' and record the decision at the 5% level.
chi2_check = []
object_columns = df_clean.select_dtypes(['object']).columns
for i in object_columns:
    contingency = pd.crosstab(df_clean['Final_Comments'], df_clean[i])
    # index 1 of chi2_contingency's result is the p-value
    if chi2_contingency(contingency)[1] < 0.05:
        chi2_check.append('Reject Null Hypothesis')
    else:
        chi2_check.append('Fail to Reject Null Hypothesis')
# One row per tested column: build the frame from two equal-length
# sequences keyed by column label, instead of a list of two objects
# (which lays everything out in a single row).
res = pd.DataFrame({'Column': object_columns, 'Hypothesis': chi2_check})

Error while trying to transpose the matrix

Code for a single raster file:
import geopandas as gpd
#import os
import rasterio
import scipy.sparse as sparse
import pandas as pd
import numpy as np
# Create an empty pandas dataframe called 'table'
table = pd.DataFrame(index = np.arange(0,1))
# Read the points shapefile using GeoPandas
stations = gpd.read_file(r'E:/anakonda/Shape files/AAQ_st1/AAQ_ST1.shp')
# Expose each station's point coordinates as plain 'lon' / 'lat' columns
stations['lon'] = stations['geometry'].x
stations['lat'] = stations['geometry'].y
Matrix = pd.DataFrame()
# Iterate through the rasters and save the data as individual arrays to a Matrix
# NOTE(review): the next two statements are incomplete as shown (unbalanced
# ')' and an empty right-hand side); presumably they were
# `rasterio.open(...)` and `dataset.read(1)` - confirm against the original.
dataset ='E:/anakonda/LST_day/MOD11A1.006_LST_Day_1km_doy2019082_aid0001.tif')
data_array =
data_array_sparse = sparse.coo_matrix(data_array, shape = (351, 545))
# NOTE(review): nothing in this snippet adds a column to `Matrix`, and
# neither `row` nor `col` is assigned anywhere in it; the loop body and the
# two statements after it have also lost their indentation.
for records_date in Matrix.columns.tolist():
a = Matrix
LST_day_value = a.loc[int(row)][int(col)]
table[records_date] = LST_day_value
transpose_mat = table.T
transpose_mat.rename(columns = {0: 'LST_Day(Kel)'}, inplace = True)
Error code lines:
LST_day_value = a.loc[int(row)][int(col)]
Errors Shown:
Undefined Name 'row' (pyflakes E)
Undefined Name 'col' (pyflakes E)
NameError: name 'transpose_mat' is not defined
I'm using the above code to create a raster time series for MODIS LST data. The code ran well until the 'transposing the matrix' step; the error shown is given below the code. I'm new to Python, so kindly help me with this issue.
import os
import geopandas as gpd
import rasterio
import scipy.sparse as sparse
import pandas as pd
import numpy as np

# Create an empty pandas dataframe called 'table'
table = pd.DataFrame(index = np.arange(0,1))
# Read the points shapefile using GeoPandas
stations = gpd.read_file(r'E:/anakonda/Shape files/AAQ_st1/AAQ_ST1.shp')
# Expose each station's point coordinates as plain 'lon' / 'lat' columns
stations['lon'] = stations['geometry'].x
stations['lat'] = stations['geometry'].y
Matrix = pd.DataFrame()

# Iterate through the rasters and save the data as individual arrays to a Matrix
for files in os.listdir(r'E:/anakonda/LST_Night'):
    if files[-4: ] == '.tif':
        dataset = rasterio.open('E:/anakonda/LST_Night'+'\\'+files)
        # first band of the raster as a 2-D array
        data_array = dataset.read(1)
        data_array_sparse = sparse.coo_matrix(data_array, shape = (351,545))
        # column label: the file name without its last 20 characters
        data = files[ :-20]
        Matrix[data] = data_array_sparse.toarray().tolist()
        print('Processing is done for the raster: '+ files[:-20])

# Iterate through the stations and get the corresponding row and column for the related x, y coordinates
# (`dataset` here is the last raster opened by the loop above)
for index, station in stations.iterrows():
    station_name = str(station['Station'])
    lon = float(station['lon'])
    lat = float(station['lat'])
    x,y = (lon, lat)
    # NOTE(review): `row`/`col` are not checked against the 351 x 545 grid.
    # Presumably a station outside the raster extent (or in a different CRS)
    # yields an out-of-range row and `a.loc[...]` then raises the KeyError
    # reported for this script - confirm the station coordinates.
    row, col = dataset.index(x, y)
    print('Processing: '+ station_name)
    # Pick the LST value from each stored raster array and record it into the previously created 'table'
    for records_date in Matrix.columns.tolist():
        a = Matrix[records_date]
        LST_Night_value = a.loc[int(row)][int(col)]
        table[records_date] = LST_Night_value
    transpose_mat = table.T
    transpose_mat.rename(columns = {0: 'LstNight(Kel)'}, inplace = True)
This is the error shown:
```File "C:\Anaconda\envs\timeseries\lib\site-packages\pandas\core\indexes\", line 357, in get_loc
raise KeyError(key) from err
KeyError: 2278```

Where can I get the pretrained word embeddings for BERT?

I know that BERT has total vocabulary size of 30522 which contains some words and subwords. I want to get the initial input embeddings of BERT. So, my requirement is to get the table of size [30522, 768] to which I can index by token id to get its embeddings. Where can I get this table?
The BertModels have get_input_embeddings():
import torch
from transformers import BertModel, BertTokenizer
# Pretrained tokenizer (for the vocabulary) and model (for the embedding layer)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')

# Map every vocabulary token to the vector the input-embedding layer
# returns for its id.
embedding_layer = bert.get_input_embeddings()
token_embedding = {
    token: embedding_layer(torch.tensor(token_id))
    for token, token_id in tokenizer.get_vocab().items()
}
tensor([ 1.3630e-02, -2.6490e-02, -2.3503e-02, -7.7876e-03, 8.5892e-03,
-7.6645e-03, -9.8808e-03, 6.0184e-03, 4.6921e-03, -3.0984e-02,
1.8883e-02, -6.0093e-03, -1.6652e-02, 1.1684e-02, -3.6245e-02,
8.3482e-03, -1.2112e-03, 1.0322e-02, 1.6692e-02, -3.0354e-02,
-1.2372e-02, -2.5173e-02, -8.9602e-03, 8.1994e-03, -2.0011e-02,
-1.5901e-02, -3.8394e-03, 1.4241e-03, 7.0500e-03, 1.6092e-03,
-2.7764e-03, 9.4931e-03, -2.2768e-02, 1.9317e-02, -1.3442e-02,
-2.3763e-02, -1.4617e-02, 9.7735e-03, -2.2428e-03, 3.0642e-02,
6.7829e-03, -2.6471e-03, -1.8553e-02, -1.2363e-02, 7.6489e-03,
-2.5461e-03, -3.1498e-01, 6.3761e-03, 4.8914e-02, -7.7636e-03,
6.0919e-02, 2.1346e-02, -3.9741e-02, 2.2853e-01, 2.6502e-02,
-1.0144e-03, -7.8480e-03, -1.9995e-03, 1.7057e-02, -3.3270e-02,
4.5421e-03, 6.1751e-03, -1.0077e-01, -2.0973e-02, -1.4512e-04,
-9.6657e-03, 1.0871e-02, -1.4786e-02, 2.6437e-04, 2.1166e-02,
1.6492e-02, -5.1928e-03, -1.1857e-02, -9.9159e-03, -1.4363e-02,
-1.2405e-02, -1.2973e-02, 2.6778e-02, -1.0986e-02, 1.0572e-02,
-2.5566e-02, 5.2494e-03, 1.5890e-02, -5.1504e-03, -7.5859e-03,
2.0259e-02, -7.0155e-03, 1.6359e-02, 1.7487e-02, 5.4297e-03,
-8.6403e-03, 2.8821e-02, -7.8964e-03, 1.9259e-02, 2.3868e-02,
-4.3472e-03, 5.5662e-02, -2.1940e-02, 4.1779e-03, -5.7216e-03,
2.6712e-02, -5.0371e-03, 2.4923e-02, -1.3429e-02, -8.4337e-03,
9.8188e-02, -1.2940e-03, 1.2865e-02, -1.5930e-03, 3.6437e-03,
1.5569e-02, 1.8620e-02, -9.0643e-03, -1.9740e-02, 1.0530e-02,
-2.7359e-03, -7.5283e-03, 1.1492e-03, 2.6162e-03, -6.2757e-03,
-8.6096e-03, 6.6221e-01, -3.2235e-03, -4.1309e-02, 3.3047e-03,
-2.5040e-03, 1.2838e-04, -6.8073e-03, 6.0291e-03, -9.8468e-03,
8.0641e-03, -1.9815e-03, 2.5801e-02, 5.7429e-03, -1.0712e-02,
2.9176e-02, 5.9414e-03, 2.4795e-02, -1.7887e-02, 7.3183e-01,
1.0964e-02, 5.9942e-03, -4.6157e-02, 4.0131e-02, -9.7481e-03,
-8.9496e-01, 1.6385e-02, -1.9816e-03, 1.4691e-02, -1.9837e-02,
-1.7611e-02, -4.5263e-04, -1.8605e-02, -1.5660e-02, -1.0709e-02,
1.8016e-02, -3.4149e-03, -1.2632e-02, 4.2877e-03, -3.9169e-01,
1.0016e-02, -1.0955e-02, 4.5133e-03, -5.1150e-03, 4.9968e-03,
1.7852e-02, 1.1313e-02, 2.6519e-03, 3.3658e-01, -1.8168e-02,
1.3170e-02, 7.3927e-03, 5.2521e-03, -9.6230e-03, 1.2844e-02,
4.1554e-01, -9.7247e-03, -4.2439e-03, 5.5287e-04, 1.8271e-02,
-1.3889e-03, -2.0502e-03, -8.1946e-03, -6.5979e-06, -7.2764e-04,
-1.4625e-03, -6.9872e-03, -6.9633e-03, -8.0701e-03, 1.9936e-02,
4.8370e-03, 8.6883e-03, -4.9246e-02, -2.0028e-02, 1.4124e-03,
1.0444e-02, -1.1236e-02, -4.4654e-03, -2.0491e-02, -2.7654e-02,
-3.7079e-02, 1.3215e-02, 6.9498e-02, -3.1109e-02, 7.0562e-03,
1.0887e-02, -7.8090e-03, -1.0501e-02, -4.8735e-03, -6.8399e-04,
1.4717e-02, 4.4342e-03, 1.6012e-02, -1.0427e-02, -2.5767e-02,
-2.2699e-01, 8.6569e-02, 2.3453e-02, 4.6362e-02, 3.5609e-03,
2.1353e-02, 2.3703e-02, -2.0252e-02, 2.1580e-02, 7.2652e-03,
2.0933e-01, 1.2108e-02, 1.0869e-02, 7.0568e-03, -3.1132e-02,
2.0505e-02, 3.2248e-03, -2.2724e-03, 5.5342e-03, 3.0563e-03,
1.9542e-02, 1.2827e-03, 1.5952e-02, -1.5458e-02, -3.8455e-03,
-4.9417e-03, -1.0446e-02, 7.0516e-03, 2.2467e-03, -9.3643e-03,
1.9163e-02, 1.4239e-02, -1.5816e-02, 8.7413e-03, 2.4737e-02,
-7.3777e-03, -4.0975e-02, 9.4948e-03, 1.4700e-02, 2.6819e-02,
1.0706e-02, 1.0621e-02, -7.1816e-03, -8.5402e-03, 1.2261e-02,
-4.8679e-03, -9.6136e-03, 7.8765e-04, 3.8504e-02, -7.7485e-03,
-6.5018e-03, 3.4352e-03, 2.2931e-04, 5.7456e-03, -4.8441e-03,
-9.0898e-03, 8.6298e-03, 5.4740e-03, 2.2274e-02, -2.1218e-02,
-2.6795e-02, -3.5337e-03, 1.0785e-02, 1.2475e-02, -6.1160e-03,
1.0729e-02, -9.7955e-03, 1.8543e-02, -6.0488e-03, -4.5744e-03,
2.7089e-03, 1.5632e-02, -1.2928e-02, -3.0778e-03, -1.0325e-02,
-7.9550e-03, -6.3065e-02, 2.1062e-02, -6.6717e-03, 8.4616e-03,
1.4475e-02, 1.1477e-01, -2.2838e-02, -3.7491e-02, -3.6218e-02,
-3.1994e-02, -8.9252e-03, 3.1720e-02, -1.1260e-02, -1.2980e-01,
-1.0315e-03, -4.7242e-03, -2.0092e-02, -9.4521e-01, -2.2178e-02,
-4.4297e-04, 1.9711e-02, 3.3402e-02, -1.0513e-02, 1.4492e-02,
-1.9697e-02, -9.8452e-03, -1.7347e-02, 2.3472e-02, 7.6570e-02,
1.9504e-02, 9.3617e-03, 8.2672e-03, -1.0471e-02, -1.9932e-03,
2.0000e-02, 2.0485e-02, 1.0977e-02, 1.7720e-02, 1.3532e-02,
7.3682e-03, 3.4906e-04, 1.8772e-03, 1.9976e-02, -3.2041e-02,
-8.9169e-03, 1.2900e-02, -1.3331e-02, 6.6207e-03, -5.7063e-03,
-1.1482e-02, 8.3907e-03, -6.4162e-03, 1.5816e-02, 7.8921e-03,
4.4177e-03, 2.2568e-02, 1.0239e-02, -3.0194e-04, 1.3294e-02,
-2.1606e-02, 3.8832e-03, 2.4475e-02, 4.3808e-02, -2.1031e-03,
-1.2163e-02, -4.0786e-02, 1.5565e-02, 1.4750e-02, 1.6645e-02,
2.8083e-02, 1.8920e-03, -1.4733e-04, -2.6208e-02, 2.3780e-02,
1.8657e-04, -2.2931e-03, 3.0334e-03, -1.7294e-02, -2.3001e-02,
8.6004e-03, -3.3497e-02, 2.5660e-02, -1.9225e-02, -2.7186e-02,
-2.1020e-02, -3.5213e-02, -1.8228e-03, -8.2840e-03, 1.1212e-02,
1.0387e-02, -3.4194e-01, -1.9705e-03, 1.1558e-02, 5.1976e-03,
7.4498e-03, 5.7142e-03, 2.8401e-02, -7.7551e-03, 1.0682e-02,
-1.2657e-02, -1.8065e-02, 2.6681e-03, 3.3947e-03, -4.5565e-02,
-2.1170e-02, -1.7830e-02, 3.4679e-03, -2.2051e-02, -5.4176e-03,
-1.1517e-02, -3.4155e-02, -3.0335e-03, -1.3915e-02, 6.2173e-03,
-1.1101e-02, -1.5308e-02, 9.2188e-03, -7.5665e-03, 6.5685e-03,
8.0935e-03, 3.1139e-03, -5.5047e-03, -3.1347e-02, 2.2140e-02,
1.0865e-02, -2.7849e-02, -4.9580e-03, 1.8804e-03, 1.0007e-01,
-1.8013e-03, -4.8792e-03, 1.5534e-02, -2.0179e-02, -1.2351e-02,
-1.3871e-02, 1.1439e-02, -9.0208e-03, 1.2580e-02, -2.5973e-02,
-2.0398e-02, -1.9464e-03, 4.3189e-03, 2.0707e-02, 5.0029e-03,
-1.0679e-02, 1.2298e-02, 1.0269e-02, 2.2228e-02, 2.9754e-02,
-2.6392e-03, 1.9286e-02, -1.5137e-02, 2.1914e-01, 1.3030e-02,
-7.4460e-03, -9.6818e-04, 2.9736e-02, 9.8722e-03, -5.6688e-03,
4.2518e-03, 1.8941e-02, -6.3909e-03, 8.0590e-03, -6.7893e-03,
6.0878e-03, -5.3970e-03, 7.5776e-04, 1.1374e-03, -5.0035e-03,
-1.6159e-03, 1.6764e-02, 9.1251e-03, 1.3020e-02, -1.0368e-02,
2.2141e-02, -2.5411e-03, -1.5227e-02, 2.3444e-02, 8.4076e-04,
-1.1465e-01, 2.7017e-03, -4.4961e-03, 2.9762e-04, -3.9612e-03,
8.9038e-05, 2.8683e-02, 5.0068e-03, 1.6509e-02, 7.8983e-04,
5.7728e-03, 3.2685e-02, -1.0457e-01, 1.2989e-02, 1.1278e-02,
1.1943e-02, 1.5258e-02, -6.2411e-04, 1.0682e-04, 1.2087e-02,
7.2984e-03, 2.7758e-02, 1.7572e-02, -6.0345e-03, 1.7211e-02,
1.4121e-02, 6.4663e-02, 9.1813e-03, 3.2555e-03, -3.2667e-02,
2.9132e-02, -1.7770e-02, 1.5302e-03, -2.9944e-02, -2.0706e-02,
-3.6528e-03, -1.5497e-02, 1.5223e-02, -1.4751e-02, -2.2381e-02,
6.9636e-03, -8.0838e-03, -2.4583e-03, -2.0677e-02, 8.8132e-03,
-6.9554e-04, 1.6965e-02, 1.8535e-01, 3.5843e-04, 1.0812e-02,
-4.2391e-03, 8.1779e-03, 3.4144e-02, -1.8996e-03, 2.9939e-03,
3.6898e-04, -1.0144e-02, -5.7416e-03, -5.7676e-03, 1.7565e-01,
-1.5793e-03, -2.6617e-02, -1.2572e-02, 3.0421e-04, -1.2132e-02,
-1.4168e-02, 1.2154e-02, 8.4700e-03, -1.6284e-02, 2.6983e-03,
-6.8554e-03, 2.7829e-01, 2.4060e-02, 1.1130e-02, 7.6095e-04,
3.1341e-01, 2.1668e-02, 1.0277e-02, -3.0065e-02, -8.3565e-03,
5.2488e-03, -1.1287e-02, -1.8266e-02, 1.1814e-02, 1.2662e-02,
2.9036e-04, 7.0254e-04, -1.4084e-02, 1.2925e-02, 3.9504e-03,
-7.9568e-03, 3.2794e-02, 7.3839e-03, 2.4609e-02, 9.6109e-03,
-8.7206e-03, 9.2571e-03, -3.5850e-03, -8.9996e-03, 2.3120e-03,
-1.8475e-02, -1.9610e-02, 1.1994e-02, 6.7156e-03, 1.9903e-02,
3.0703e-02, -4.9538e-03, -6.1673e-02, -6.4986e-03, -2.1317e-02,
-3.3650e-03, 2.3200e-03, -6.2224e-03, 3.7458e-03, 1.1542e-02,
-1.0181e-02, -8.4711e-03, 1.1603e-02, -5.6247e-03, -1.0220e-02,
-8.6501e-04, -1.2285e-02, -8.7487e-03, -1.1265e-02, 1.6322e-02,
1.5160e-02, 1.8882e-02, 5.1557e-03, -8.8616e-03, 4.2153e-03,
-1.9450e-02, -8.7365e-03, -9.7867e-03, 1.1667e-02, 5.0613e-03,
2.8221e-03, -7.1795e-03, 9.3306e-03, -4.9663e-02, 1.7708e-02,
-2.0959e-02, -3.3989e-02, 2.2581e-03, 5.1748e-03, -1.0133e-01,
2.1052e-03, 5.5644e-03, 1.3607e-03, 8.8388e-03, 1.0244e-02,
-3.8072e-03, 5.9209e-03, 6.7993e-03, 1.1594e-02, -1.1802e-02,
-2.4233e-03, -5.1504e-03, -1.1903e-02, 1.4075e-02, -4.0701e-03,
-2.9465e-02, -1.7579e-03, 4.3654e-03, 1.0429e-02, 3.7096e-02,
8.6493e-03, 1.5871e-02, 1.8034e-02, -3.2165e-03, -2.1941e-02,
2.6274e-02, -7.6941e-03, -5.9618e-03, -1.4179e-02, 8.0281e-03,
1.1293e-02, -6.6936e-05, 1.2899e-02, 1.0056e-02, -6.3919e-04,
2.0299e-02, 3.1528e-03, -4.8988e-03, 3.2754e-03, -1.1003e-01,
1.8414e-02, 2.2272e-03, -2.2185e-02, -4.8672e-03, 1.9643e-03,
3.0928e-02, -8.9599e-03, -1.1446e-02, -1.3794e-02, 7.1943e-03,
-5.8965e-03, 2.2605e-03, -2.6114e-02, -5.6616e-03, 6.5073e-03,
9.2219e-02, -6.7243e-03, 4.4427e-04, 7.2846e-03, -1.1021e-02,
7.8802e-04, -3.8878e-03, 1.0489e-02, 9.2883e-03, 1.8895e-02,
2.1808e-02, 6.2590e-04, -2.6519e-02, 7.0343e-04, -2.9067e-02,
-9.1515e-03, 1.0418e-03, 8.3222e-03, -8.7548e-03, -2.0637e-03,
-1.1450e-02, -8.8985e-04, -4.4062e-03, 2.3629e-02, -2.7221e-02,
3.2008e-02, 6.6325e-03, -1.1302e-02, -1.0138e-03, -1.6902e-01,
-8.4473e-03, 2.8536e-02, 1.4117e-03, -1.2136e-02, -1.4781e-02,
4.9960e-03, 3.3916e-02, 5.2710e-03, 1.7382e-02, -4.6315e-03,
1.1680e-02, -9.1395e-03, 1.8310e-02, 1.2321e-02, -2.4871e-02,
1.1535e-02, 5.0308e-03, 5.5028e-03, -7.2184e-03, -5.5210e-03,
1.7085e-02, 5.7236e-03, 1.7463e-03, 1.9969e-03, 6.1670e-03,
2.9347e-03, 1.3946e-02, -1.9984e-03, 1.0091e-02, 1.0388e-03,
-6.1902e-03, 3.0905e-02, 6.6038e-03, -9.1223e-02, -1.8411e-02,
5.4185e-03, 2.4396e-02, 1.5696e-02, -1.2742e-02, 1.8126e-02,
-2.6138e-02, 1.1170e-02, -1.3058e-02, -1.9386e-02, -5.9828e-03,
1.9176e-02, 1.9962e-03, -2.1538e-03, 3.3003e-02, 1.8407e-02,
-5.9498e-03, -3.2533e-03, -1.8917e-02, -1.5897e-02, -4.7057e-03,
5.4162e-03, -3.0037e-02, 8.6773e-03, -1.7942e-03, 6.6826e-03,
-1.1929e-02, -1.4076e-02, 1.6709e-02, 1.6860e-03, -3.3842e-03,
8.6805e-03, 7.1340e-03, 1.5147e-02], grad_fn=<EmbeddingBackward>)
To get a context-sensitive word embedding for a given input sentence/text, here is the code:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
def get_word_idx(sent: str, word: str) -> int:
    """Return the position of ``word`` among the space-separated words of ``sent``.

    The sentence is split on single spaces, so ``word`` must match one of the
    resulting pieces exactly.

    Raises:
        ValueError: if ``word`` is not one of the pieces.
    """
    return sent.split(" ").index(word)
def get_hidden_states(encoded, token_ids_word, model, layers):
    """Push input IDs through model. Stack and sum the requested `layers`.
    Select only those subword token outputs that belong to our word of interest
    and average them.

    `encoded` is unpacked as keyword arguments to `model`, whose output must
    expose `hidden_states`; `layers` are indices into those hidden states and
    `token_ids_word` indexes the token positions to average over."""
    # Inference only: no gradients needed for the forward pass
    with torch.no_grad():
        output = model(**encoded)
    # Get all hidden states
    states = output.hidden_states
    # Stack and sum all requested layers
    output = torch.stack([states[i] for i in layers]).sum(0).squeeze()
    # Only select the tokens that constitute the requested word
    word_tokens_output = output[token_ids_word]
    return word_tokens_output.mean(dim=0)
def get_word_vector(sent, idx, tokenizer, model, layers):
    """Get a word vector by first tokenizing the input sentence, getting all token idxs
    that make up the word of interest, and then `get_hidden_states`.

    `idx` is the word position to look up; it is compared against the
    per-token word ids reported by the tokenizer's encoding."""
    encoded = tokenizer.encode_plus(sent, return_tensors="pt")
    # get all token idxs that belong to the word of interest
    token_ids_word = np.where(np.array(encoded.word_ids()) == idx)
    return get_hidden_states(encoded, token_ids_word, model, layers)
def main(layers=None):
    """Return the embedding of "cookies" in a fixed example sentence.

    `layers` lists the hidden-state indices to sum; when it is None the last
    four layers are used.
    """
    # Use last four layers by default
    layers = [-4, -3, -2, -1] if layers is None else layers
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    model = AutoModel.from_pretrained("bert-base-cased", output_hidden_states=True)
    sent = "I like cookies ."
    idx = get_word_idx(sent, "cookies")
    word_embedding = get_word_vector(sent, idx, tokenizer, model, layers)
    return word_embedding


if __name__ == '__main__':
    main()
More details can be found here.
