t-SNE sklearn


t-Distributed Stochastic Neighbor Embedding (t-SNE) is an unsupervised, non-linear technique primarily used for data exploration and visualizing high-dimensional data.

In simpler terms, t-SNE produces a low-dimensional map that reflects how the data is arranged in the original high-dimensional space.

t-Distributed Stochastic Neighbor Embedding (t-SNE) Overview

t-Distributed Stochastic Neighbor Embedding (t-SNE) is a non-linear dimensionality reduction technique particularly well suited for the visualization of high-dimensional datasets. It converts similarities between data points into joint probabilities and minimizes the Kullback-Leibler divergence between the joint probabilities of the low-dimensional embedding and those of the high-dimensional data. Because the cost function is not convex, different runs can produce different map layouts, though the overall patterns and clusters should remain consistent.
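A minimal sketch of this non-convexity, using the scikit-learn estimator covered later in this note: two runs that differ only in their random seed generally produce different layouts, even though the cluster structure they reveal is comparable.

from sklearn.datasets import load_digits
from sklearn.manifold import TSNE

X = load_digits().data

# Two runs that differ only in the random seed: the coordinates come out
# different (the cost function is non-convex), but the clusters each map
# reveals should look similar.
emb_a = TSNE(n_components=2, init='random', random_state=0).fit_transform(X)
emb_b = TSNE(n_components=2, init='random', random_state=1).fit_transform(X)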

How t-SNE Works

  1. Similarity Computation in High-Dimensional Space: t-SNE starts by converting the Euclidean distances between data points in the high-dimensional space into conditional probabilities that represent similarities. The similarity of datapoint \(x_j\) to datapoint \(x_i\) is the conditional probability \(p_{j|i}\) that \(x_i\) would pick \(x_j\) as its neighbor if neighbors were picked in proportion to their probability density under a Gaussian centered at \(x_i\).
  2. Similarity Computation in Low-Dimensional Space: It then defines a similar probability distribution over the map points in the low-dimensional space, using a Student t-distribution with one degree of freedom to compute similarities \(q_{ij}\).
  3. Minimization of Kullback-Leibler Divergence: The Kullback-Leibler divergence between the two distributions \(P\) (in the high-dimensional space) and \(Q\) (in the low-dimensional space) is minimized by gradient descent. This divergence measures how much one probability distribution diverges from a second, reference distribution (see the formulas after this list).
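Concretely, following van der Maaten and Hinton (2008), the high-dimensional similarities, their symmetrized form, the low-dimensional similarities, and the cost being minimized are:

\[ p_{j|i} = \frac{\exp\left(-\lVert x_i - x_j \rVert^2 / 2\sigma_i^2\right)}{\sum_{k \neq i} \exp\left(-\lVert x_i - x_k \rVert^2 / 2\sigma_i^2\right)}, \qquad p_{ij} = \frac{p_{j|i} + p_{i|j}}{2n} \]

\[ q_{ij} = \frac{\left(1 + \lVert y_i - y_j \rVert^2\right)^{-1}}{\sum_{k \neq l} \left(1 + \lVert y_k - y_l \rVert^2\right)^{-1}} \]

\[ C = \mathrm{KL}(P \parallel Q) = \sum_{i \neq j} p_{ij} \log \frac{p_{ij}}{q_{ij}} \]

Each bandwidth \(\sigma_i\) is set so that the conditional distribution \(p_{\cdot|i}\) has a fixed perplexity chosen by the user.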

Why t-SNE?

t-SNE is particularly useful because it preserves the local structure of the data well: similar data points are grouped together, and dissimilar groups end up separated in the low-dimensional representation. Some broader, global structure is also reflected, although distances between well-separated clusters should not be over-interpreted. This makes t-SNE especially useful for visualizing high-dimensional data in 2D or 3D.
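How local or global the resulting map looks is governed mainly by the perplexity parameter (the same parameter used in the examples below). A minimal sketch of sweeping it, with illustrative, untuned values:

from sklearn.datasets import load_digits
from sklearn.manifold import TSNE

X = load_digits().data

# Perplexity roughly sets the effective number of neighbors each point
# considers: small values emphasize very local structure, larger values
# pull in more context. The values here are illustrative, not tuned.
embeddings = {
    p: TSNE(n_components=2, perplexity=p, random_state=42).fit_transform(X)
    for p in (5, 30, 50)
}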

Implementing t-SNE in Python with Scikit-Learn

Here's a basic example of using t-SNE with Scikit-Learn:

from sklearn.manifold import TSNE
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt

# Load a dataset
digits = load_digits()
X = digits.data
y = digits.target

# Initialize and fit t-SNE (n_iter was renamed to max_iter in scikit-learn 1.5)
tsne = TSNE(n_components=2, perplexity=30.0, n_iter=1000, random_state=42)
X_tsne = tsne.fit_transform(X)

# Plotting the result
plt.figure(figsize=(10, 6))
plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=y, cmap='viridis', alpha=0.5)
plt.colorbar()
plt.title('t-SNE visualization of digit data')
plt.show()

In this code snippet, t-SNE is applied to the digits dataset (a set of 8x8 images of digits) to reduce its dimensionality for visualization purposes. The result is a 2D map where similar digits are located close to each other, demonstrating t-SNE's ability to capture the inherent structure of the data.
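After fitting, the estimator also exposes the final value of the cost it minimized, which can be handy when comparing runs:

print(X_tsne.shape)          # (1797, 2): one 2D point per input sample
print(tsne.kl_divergence_)   # Kullback-Leibler divergence after optimization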

3D Visualization

The same approach extends to three components, with the embedding plotted on Matplotlib's 3D axes:

from sklearn.manifold import TSNE
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (only needed on Matplotlib < 3.2)

# Load a dataset
digits = load_digits()
X = digits.data
y = digits.target

# Reduce dimensions with t-SNE (n_iter was renamed to max_iter in scikit-learn 1.5)
tsne = TSNE(n_components=3, perplexity=30.0, n_iter=1000, random_state=42)
X_tsne = tsne.fit_transform(X)

# Plotting in 3D
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

# Scatter plot
scatter = ax.scatter(X_tsne[:, 0], X_tsne[:, 1], X_tsne[:, 2], c=y, cmap='viridis', alpha=0.5)

# Legend
legend1 = ax.legend(*scatter.legend_elements(), title="Digits")
ax.add_artist(legend1)

# Title and labels
ax.set_title('3D t-SNE visualization of digit data')
ax.set_xlabel('Component 1')
ax.set_ylabel('Component 2')
ax.set_zlabel('Component 3')

plt.show()
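To save the figure instead of (or in addition to) displaying it, Matplotlib's standard call works here too; the filename is illustrative:

fig.savefig('tsne_3d.png', dpi=150)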