import io
import pkgutil
import os
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import matplotlib.pyplot as plt
DATA_SOURCES: dict[str, str] = {
"DoOR_adult": "data/DoOR/processed_door_adult.csv",
"DoOR_adult_sfr_subtracted": "data/DoOR/processed_door_adult_sfr_subtracted.csv",
"Dweck_adult_chem": "data/Dweck2018/adult_chem2glom.csv",
"Dweck_adult_fruit": "data/Dweck2018/adult_fruit2glom.csv",
"Dweck_larva_chem": "data/Dweck2018/larva_chem2or.csv",
"Dweck_larva_fruit": "data/Dweck2018/larva_fruit2or.csv",
"Nern2024": "data/Nern2024/ME-columnar-cells-hex-location.csv",
"Matsliah2024": "data/Matsliah2024/fafb_right_vis_cols.csv",
"Badel2016_PN": "data/Badel2016/Badel2016.csv",
"Zhao2024": "data/Zhao2024/ucl_hex_right_20240701_tomale.csv",
}
[docs]
def load_dataset(dataset: str) -> pd.DataFrame:
"""
Load the dataset from the package data folder. These datasets have been
preprocessed to work with connectomics data. The preprocessing scripts are
in this repository: https://github.com/YijieYin/interpret_connectome.
Args:
dataset: (str) The name of the dataset to load. Options are:
- 'DoOR_adult': mapping from glomeruli to chemicals, from Munch and Galizia DoOR dataset (https://www.nature.com/articles/srep21841).
- 'DoOR_adult_sfr_subtracted': mapping from glomeruli to chemicals, with spontaneous firing rate subtracted. There are therefore negative values.
- 'Dweck_adult_chem': mapping from glomeruli to chemicals extracted from fruits, from Dweck et al. 2018 (https://www.cell.com/cell-reports/abstract/S2211-1247(18)30663-6). Firing rates normalised to between 0 and 1.
- 'Dweck_adult_fruit': mapping from glomeruli to fruits, from Dweck et al. 2018. Number of responses normalised to between 0 and 1.
- 'Dweck_larva_chem': mapping from olfactory receptors to chemicals, from Dweck et al. 2018. Firing rates normalised to between 0 and 1.
- 'Dweck_larva_fruit': mapping from olfactory receptors to fruits from Dweck et al. 2018. Number of responses normalised to between 0 and 1.
- 'Nern2024': columnar coordinates of individual cells from a collection of columnar cell types within the medulla of the right optic lobe, from Nern et al. 2024 (https://www.biorxiv.org/content/10.1101/2024.04.16.589741v2).
- 'Matsliah2024': columnar coordinates of individual cells from a collection of columnar cell types in the right optic lobe from FAFB, from Matsliah et al. 2024 (https://www.nature.com/articles/s41586-024-07981-1).
- 'Badel2016_PN': mapping from olfactory projection neurons to odours, from Badel et al. 2016 (https://www.cell.com/neuron/fulltext/S0896-6273(16)30201-X).
- 'Zhao2024': mapping from hexagonal coordinates to 3D coordinates, update from Zhao et al. 2022 (https://www.biorxiv.org/content/10.1101/2022.12.14.520178v1).
Returns:
pd.DataFrame: The dataset as a pandas DataFrame. For the adult, the glomeruli
are in the rows. For the larva, receptors are in the rows.
"""
try:
data = pkgutil.get_data("connectome_interpreter", DATA_SOURCES[dataset])
except KeyError as exc:
raise ValueError(
"Dataset not recognized. Please choose from {}".format(
list(DATA_SOURCES.keys())
)
) from exc
return pd.read_csv(io.BytesIO(data), index_col=0)
[docs]
def map_to_experiment(df, dataset=None, custom_experiment=None):
"""
Map the connectomics data to experimental data. For example, if odour1
excites neuron1 0.5, and neuron2 0.6; both neuron1 and neuron2 output to
neuron3 (0.7 and 0.8 respectively), then the output of neuron3 to odour1
is 0.5*0.7 + 0.6*0.8 = 0.83. The result would only be 1 if a stimulus
excites neurons 100%, and those neurons constitue 100% of the downstream
neuron's input.
Args:
df (pd.DataFrame): The connectivity data. Standardised input (e.g. glomeruli,
receptors) in rows, observations (target neurons) in columns.
dataset (str): The name of the dataset to load. Options are:
- 'DoOR_adult': mapping from glomeruli to chemicals, from Munch and Galizia DoOR dataset (https://www.nature.com/articles/srep21841).
- 'DoOR_adult_sfr_subtracted': mapping from glomeruli to chemicals, with spontaneous firing rate subtracted. There are therefore negative values.
- 'Dweck_adult_chem': mapping from glomeruli to chemicals extracted from fruits, from Dweck et al. 2018 (https://www.cell.com/cell-reports/abstract/S2211-1247(18)30663-6). Firing rates normalised to between 0 and 1.
- 'Dweck_adult_fruit': mapping from glomeruli to fruits, from Dweck et al. 2018. Number of responses normalised to between 0 and 1.
- 'Dweck_larva_chem': mapping from olfactory receptors to chemicals, from Dweck et al. 2018. Firing rates normalised to between 0 and 1.
- 'Dweck_larva_fruit': mapping from olfactory receptors to fruits, from Dweck et al. 2018. Number of responses normalised to between 0 and 1.
- 'Nern2024': columnar coordinates of individual cells from a collection of columnar cell types within the medulla of the right optic lobe, from Nern et al. 2024.
- 'Badel2016_PN': mapping from olfactory projection neurons to odours, from Badel et al. 2016 (https://www.cell.com/neuron/fulltext/S0896-6273(16)30201-X).
custom_experiment (pd.DataFrame): A custom experimental dataset to compare the
connectomics data to. The row indices of this dataframe must match the row
indices of df. They are the units of comparison (e.g. glomeruli).
Returns:
pd.DataFrame: The similarity between the connectomics data and the experimental
data. Rows are neurons, columns are external stimulus.
"""
# try:
# from sklearn.metrics.pairwise import cosine_similarity
# except ImportError as e:
# raise ImportError(
# "To use this function, please install scikit-learn. You can
# install it with 'pip install scikit-learn'.") from e
if dataset is not None and custom_experiment is not None:
raise ValueError(
"Please provide either a dataset or a custom_experiment, not both."
)
if dataset is None and custom_experiment is None:
raise ValueError("Please provide either a dataset or a custom_experiment.")
if dataset is not None:
data = load_dataset(dataset)
else:
data = custom_experiment
# take the intersection of glomeruli
data = data[data.index.isin(df.index)]
df_intersect = df[df.index.isin(data.index)]
df_intersect = df_intersect.reindex(data.index)
# multiply the correpsonding values using matmul
target2chem = np.dot(df_intersect.values.T, data.values)
# Assign appropriate column names
target2chem = pd.DataFrame(
target2chem, index=df_intersect.columns, columns=data.columns
)
return target2chem
[docs]
def hex_heatmap(
df: pd.Series | pd.DataFrame,
style: dict | None = None,
sizing: dict | None = None,
dpi: int = 72,
custom_colorscale: list | None = None,
global_min: float | None = None,
global_max: float | None = None,
dataset: str | None = "mcns_right",
) -> go.Figure:
"""
Generate a hexagonal heat map plot of the data. The index of the data
should be formatted as strings of the form '-12,34', where the first
number is the x-coordinate and the second number is the y-coordinate.
Args:
df : pd.Series | pd.DataFrame
The data to plot. Each column will generate a separate frame in
the plot.
style : dict, default=None
Dict containing styling formatting variables. Possible keys are:
- 'font_type': str, default='arial'
- 'linecolor': str, default='black'
- 'papercolor': str, default='rgba(255,255,255,255)' (white)
sizing : dict, default=None
Dict containing size formatting variables. Possible keys are:
- 'fig_width': int, default=260 (mm)
- 'fig_height': int, default=220 (mm)
- 'fig_margin': int, default=0 (mm)
- 'fsize_ticks_pt': int, default=20 (points)
- 'fsize_title_pt': int, default=20 (points)
- 'markersize': int, default=18 if dataset='mcns_right', 20 if dataset='fafb_right'
- 'ticklen': int, default=15
- 'tickwidth': int, default=5
- 'axislinewidth': int, default=3
- 'markerlinewidth': int, default=0.9
- 'cbar_thickness': int, default=20
- 'cbar_len': float, default=0.75
dpi : int, default=72
Dots per inch for the output figure. Standard is 72 for screen/SVG/PDF.
Use higher values (e.g., 300) for print-quality output.
custom_colorscale : list, default=None
Custom colorscale for the heatmap. If None, defaults to white-to-blue
colorscale [[0, "rgb(255, 255, 255)"], [1, "rgb(0, 20, 200)"]].
global_min : float, default=None
Global minimum value for the color scale.
If None, the minimum value of the data is used but if that is negative, use 0.
global_max : float, default=None
Global maximum value for the color scale.
If None, the maximum value of the data is used.
dataset : str, default='mcns_right'
The dataset to use for the hexagon locations. Options are:
- 'mcns_right': columnar coordinates of individual cells from columnar cell types: L1, L2, L3, L5, Mi1, Mi4, Mi9, C2, C3, Tm1, Tm2, Tm4, Tm9, Tm20, T1, within the medulla of the right optic lobe, from Nern et al. 2024.
- 'fafb_right': columnar coordinates of individual cells from columnar cell types, in the right optic lobe of FAFB, from Matsliah et al. 2024.
Returns:
fig : go.Figure
"""
def bg_hex():
"""
Generate a scatter plot of the background hexagons."
"""
goscatter = go.Scatter(
x=background_hex["x"],
y=background_hex["y"],
mode="markers",
marker_symbol=symbol_number,
marker={
"size": sizing["markersize"],
"color": "white",
"line": {
"width": sizing["markerlinewidth"],
"color": "lightgrey",
},
},
showlegend=False,
)
return goscatter
def data_hex(aseries):
"""
Generate a scatter plot of the data hexagons.""
"""
goscatter = go.Scatter(
x=x_vals,
y=y_vals,
mode="markers",
marker_symbol=symbol_number,
customdata=np.stack([x_vals, y_vals, aseries.values], axis=-1),
hovertemplate="x: %{customdata[0]}<br>y: %{customdata[1]}<br>value: %{customdata[2]}",
marker={
"cmin": global_min,
"cmax": global_max,
"size": sizing["markersize"],
"color": aseries.values,
"line": {
"width": sizing["markerlinewidth"],
"color": "lightgrey",
},
"colorbar": {
"orientation": "v",
"outlinecolor": style["linecolor"],
"outlinewidth": sizing["axislinewidth"],
"thickness": sizing["cbar_thickness"],
"len": sizing["cbar_len"],
"tickmode": "array",
"ticklen": sizing["ticklen"],
"tickwidth": sizing["tickwidth"],
"tickcolor": style["linecolor"],
"tickfont": {
"size": fsize_ticks_px,
"family": style["font_type"],
"color": style["linecolor"],
},
"tickformat": ".5f",
"title": {
"font": {
"family": style["font_type"],
"size": fsize_title_px,
"color": style["linecolor"],
},
"side": "right",
},
},
"colorscale": custom_colorscale,
},
showlegend=False,
)
return goscatter
# Default styling and sizing parameters to use if not specified.
default_style = {
"font_type": "arial",
"markerlinecolor": "rgba(0,0,0,0)", # transparent
"linecolor": "black",
"papercolor": "rgba(255,255,255,255)",
}
if dataset == "mcns_right":
markersize = 18
elif dataset == "fafb_right":
markersize = 20
else:
# raise error
raise ValueError(
"Dataset not recognized. Currently available datasets are 'mcns_right', "
"'fafb_right'."
)
default_sizing = {
"fig_width": 260, # units = mm
"fig_height": 220, # units = mm
"fig_margin": 0,
"fsize_ticks_pt": 20,
"fsize_title_pt": 20,
"markersize": markersize,
"ticklen": 15,
"tickwidth": 5,
"axislinewidth": 3,
"markerlinewidth": 0.5, # 0.9,
"cbar_thickness": 20,
"cbar_len": 0.75,
}
# If style is provided, update default_style with user values
if style is not None:
default_style.update(style)
style = default_style
if sizing is not None:
default_sizing.update(sizing)
sizing = default_sizing
# Constants for unit conversion
POINTS_PER_INCH = 72 # Typography standard: 1 point = 1/72 inch
MM_PER_INCH = 25.4 # Standard conversion: 1 inch = 25.4 mm
# sizing of the figure and font
pixelsperinch = dpi # Use the provided DPI value
pixelspermm = pixelsperinch / MM_PER_INCH
# Default colorscale
if custom_colorscale is None:
custom_colorscale = [[0, "rgb(255, 255, 255)"], [1, "rgb(0, 20, 200)"]]
area_width = (sizing["fig_width"] - sizing["fig_margin"]) * pixelspermm
area_height = (sizing["fig_height"] - sizing["fig_margin"]) * pixelspermm
fsize_ticks_px = sizing["fsize_ticks_pt"] * (1 / POINTS_PER_INCH) * pixelsperinch
fsize_title_px = sizing["fsize_title_pt"] * (1 / POINTS_PER_INCH) * pixelsperinch
# Get global min and max for consistent color scale
# minimum of 0 and df.values.min()
vals = df.to_numpy()
if global_min is None:
global_min = min(0, vals.min())
if global_max is None:
global_max = vals.max()
# Symbol number to choose to plot hexagons
symbol_number = 15
# load all hex coordinates
if dataset == "mcns_right":
background_hex = load_dataset("Nern2024")
elif dataset == "fafb_right":
background_hex = load_dataset("Matsliah2024")
else:
# raise error
raise ValueError(
"Dataset not recognized. Currently available datasets are 'mcns_right', "
"'fafb_right'."
)
# only get the unique combination of 'x' and 'y' columns
background_hex = background_hex.drop_duplicates(subset=["x", "y"])
# initiate plot
fig = go.Figure()
fig.update_layout(
autosize=False,
height=area_height,
width=area_width,
margin={"l": 0, "r": 0, "b": 0, "t": 0, "pad": 0},
paper_bgcolor=style["papercolor"],
plot_bgcolor=style["papercolor"],
)
fig.update_xaxes(
showgrid=False, showticklabels=False, showline=False, visible=False
)
fig.update_yaxes(
showgrid=False, showticklabels=False, showline=False, visible=False
)
# Convert index values (formatted as '-12,34') into separate x and y coordinates
df = df[(df.index != "nan") & (~df.index.isnull())]
coords = [tuple(map(float, idx.split(","))) for idx in df.index]
x_vals, y_vals = zip(*coords) # Separate into x and y lists
if isinstance(df, pd.Series) or len(df.columns) == 1:
if isinstance(df, pd.DataFrame):
df = df.iloc[:, 0]
fig.add_trace(bg_hex())
fig.add_trace(data_hex(df))
elif isinstance(df, pd.DataFrame):
# Adjust figure size - add extra height for slider
slider_height = 100 # pixels
area_height += slider_height
# Create frames for slider
frames = []
slider_steps = []
# Add base layout
fig.update_layout(
autosize=False,
height=area_height,
width=area_width,
margin={
"l": 0,
"r": 0,
"b": slider_height,
"t": 0,
"pad": 0,
}, # Add bottom margin for slider
paper_bgcolor=style["papercolor"],
plot_bgcolor=style["papercolor"],
sliders=[
{
"active": 0,
"currentvalue": {
"font": {"size": 16},
"visible": True,
"xanchor": "right",
},
"pad": {"b": 10, "t": 0}, # Adjusted padding
"len": 0.9,
"x": 0.1,
"y": 0, # Move slider below plot
"steps": [],
}
],
)
# Create frames for each column
for i, col_name in enumerate(df.columns):
series = df[col_name]
frame_data = [
bg_hex(),
data_hex(series),
]
frames.append(go.Frame(data=frame_data, name=str(i)))
# Add to slider
slider_steps.append(
{
"args": [
[str(i)],
{"frame": {"duration": 0, "redraw": True}, "mode": "immediate"},
],
"label": col_name,
"method": "animate",
}
)
# Set initial display to first column
if i == 0:
fig.add_traces(frame_data)
# Update slider with all steps
fig.layout.sliders[0].steps = slider_steps
fig.frames = frames
# Update axes
fig.update_xaxes(
showgrid=False, showticklabels=False, showline=False, visible=False
)
fig.update_yaxes(
showgrid=False, showticklabels=False, showline=False, visible=False
)
else:
# raise error
raise ValueError("df must be a pd.Series or pd.DataFrame")
return fig
[docs]
def looming_stimulus(start_coords, all_coords, n_time=4):
"""
Generate a list of lists of coordinates for a looming stimulus. The stimulus starts
at the start_coords and expands outwards in a hexagonal pattern. The stimulus
expands for n_time steps. Currently the expansion happens one layer at a time.
Args:
start_coords (list): List of strings of the form 'x,y' where x and y are the
coordinates of the starting hexes for the stimulus.
all_coords (list): List of strings of the form 'x,y' where x and y are the
coordinates of all hexes in the grid.
n_time (int): Default=4. Number of time steps for the stimulus to expand.
Returns:
stim_str (list): List of lists of strings of the form 'x,y' where x and y are
the coordinates of the hexes that are stimulated at each time step.
"""
coords = [tuple(map(float, idx.split(","))) for idx in all_coords]
x_vals, y_vals = zip(*coords) # Separate into x and y lists
# sort and rank x_vals
x_sorted = sorted(list(set(x_vals)))
x_to_rank = {x: rank for rank, x in enumerate(x_sorted)}
rank_to_x = {rank: x for rank, x in enumerate(x_sorted)}
y_sorted = sorted(list(set(y_vals)))
y_to_rank = {y: rank for rank, y in enumerate(y_sorted)}
rank_to_y = {rank: y for rank, y in enumerate(y_sorted)}
start = [tuple(map(float, idx.split(","))) for idx in start_coords]
stimulus = []
stimulus.append(start)
for atime in range(n_time):
for x, y in start:
start_copy = start.copy()
# hexes above and below x
if y_to_rank[y] + 2 in rank_to_y:
start_copy.append((x, rank_to_y[y_to_rank[y] + 2]))
if y_to_rank[y] - 2 in rank_to_y:
start_copy.append((x, rank_to_y[y_to_rank[y] - 2]))
# hexes to the left
if x_to_rank[x] + 1 in rank_to_x:
if y_to_rank[y] + 1 in rank_to_y:
start_copy.append(
(rank_to_x[x_to_rank[x] + 1], rank_to_y[y_to_rank[y] + 1])
)
if y_to_rank[y] - 1 in rank_to_y:
start_copy.append(
(rank_to_x[x_to_rank[x] + 1], rank_to_y[y_to_rank[y] - 1])
)
# hexes to the right
if x_to_rank[x] - 1 in rank_to_x:
if y_to_rank[y] + 1 in rank_to_y:
start_copy.append(
(rank_to_x[x_to_rank[x] - 1], rank_to_y[y_to_rank[y] + 1])
)
if y_to_rank[y] - 1 in rank_to_y:
start_copy.append(
(rank_to_x[x_to_rank[x] - 1], rank_to_y[y_to_rank[y] - 1])
)
start = list(set(start_copy))
stimulus.append(start)
stim_str = []
for atime in range(n_time):
stim_atime = []
for x, y in stimulus[atime]:
# Format x and y to remove .0 if they're integers
x_str = str(int(x)) if x == int(x) else str(x)
y_str = str(int(y)) if y == int(y) else str(y)
stim_atime.append(f"{x_str},{y_str}")
stim_str.append(stim_atime)
return stim_str
[docs]
def make_sine_stim(phase=0, amplitude=1, n=8):
"""
Generate a dictionary of values representing a sine wave stimulus with a given phase
and amplitude. The sine wave is defined over n points, starting from the given phase.
Args:
phase (int): Phase of the sine wave in degrees. Default is 0.
amplitude (float): Amplitude of the sine wave. Default is 1.
n (int): Number of points in the sine wave. Default is 8.
Returns:
dict: A dictionary where keys are indices from 1 to n, and values are the
corresponding sine wave values.
"""
x = (phase % 180) / 180 * np.pi
x = np.linspace(x, x + np.pi, n)
y = amplitude * abs(np.sin(x))
return dict(zip(range(1, n + 1), y))
[docs]
def plot_mollweide_projection(
data: pd.Series | pd.DataFrame,
fig_size: tuple = (900, 700),
custom_colorscale: str = "Viridis",
global_min: float | None = None,
global_max: float | None = None,
dataset: str = "Zhao2024",
marker_size: int = 8,
) -> go.Figure:
"""
Generates a heatmap to visualize the value of column features per column using the
mollweide projection.
Args:
data (pd.Series | pd.DataFrame): Data with index formatted as strings of the
form '-12,34', where the first number is the x-coordinate and the second
number is the y-coordinate. The data to plot. Each column will generate a
separate frame in the plot.
fig_size (tuple): Size of the figure in pixels (width, height).
custom_colorscale (str): Name of the Plotly colorscale to use.
global_min (float | None): Global minimum value for the color scale. If this
minumum is >0, 0 is used.
global_max (float | None): Global maximum value for the color scale. If None,
the maximum value of the data is used.
dataset (str): The dataset to use for the hexagon locations. Options are:
- 'Zhao2024': mapping from hexagonal coordinates to 3D coordinates, update from Zhao et al. 2022 (https://www.biorxiv.org/content/10.1101/2022.12.14.520178v1).
marker_size (int): Size of markers in the plot.
Returns:
go.Figure: A Plotly figure object containing the mollweide projection heatmap.
"""
def cart2sph(xyz: np.array) -> np.array:
"""
Convert Cartesian to spherical coordinates.
Theta is polar angle (from +z), phi is angle from +x to +y.
"""
r = np.sqrt((xyz**2).sum(1))
theta = np.arccos(xyz[:, 2])
phi = np.arctan2(xyz[:, 1], xyz[:, 0])
phi[phi < 0] = phi[phi < 0] + 2 * np.pi
return np.stack((r, theta, phi), axis=1)
def sph2Mollweide(thetaphi: np.array) -> np.array:
"""
Spherical (viewed from outside) to Mollweide,
cf. https://mathworld.wolfram.com/MollweideProjection.html
"""
azim = thetaphi[:, 1]
azim[azim > np.pi] = azim[azim > np.pi] - 2 * np.pi # longitude/azimuth
elev = np.pi / 2 - thetaphi[:, 0] # lattitude/elevation in radian
N = len(azim) # number of points
xy = np.zeros((N, 2)) # output
for i in range(N):
theta = np.arcsin(2 * elev[i] / np.pi)
if np.abs(np.abs(theta) - np.pi / 2) < 0.001:
xy[i,] = [
2 * np.sqrt(2) / np.pi * azim[i] * np.cos(theta),
np.sqrt(2) * np.sin(theta),
]
else:
# to calculate theta
dtheta = 1
while dtheta > 1e-3:
theta_new = theta - (
2 * theta + np.sin(2 * theta) - np.pi * np.sin(elev[i])
) / (2 + 2 * np.cos(2 * theta))
dtheta = np.abs(theta_new - theta)
theta = theta_new
xy[i,] = [
2 * np.sqrt(2) / np.pi * azim[i] * np.cos(theta),
np.sqrt(2) * np.sin(theta),
]
return xy
def create_mollweide_guidelines():
"""
Create Mollweide projection guidelines as plotly traces
"""
traces = []
# Create meridians
ww = np.stack((np.linspace(0, 180, 19), np.repeat(-180, 19)), axis=1)
w = np.stack((np.linspace(180, 0, 19), np.repeat(-90, 19)), axis=1)
m = np.stack((np.linspace(0, 180, 19), np.repeat(0, 19)), axis=1)
e = np.stack((np.linspace(180, 0, 19), np.repeat(90, 19)), axis=1)
ee = np.stack((np.linspace(0, 180, 19), np.repeat(180, 19)), axis=1)
pts = np.vstack((ww, w, m, e, ee))
rtp = np.insert(pts / 180 * np.pi, 0, np.repeat(1, pts.shape[0]), axis=1)
meridians_xy = sph2Mollweide(rtp[:, 1:3])
traces.append(
go.Scatter(
x=meridians_xy[:, 0],
y=meridians_xy[:, 1],
mode="lines",
line=dict(color="lightgrey", width=0.5),
showlegend=False,
hoverinfo="skip",
)
)
# Create parallels
for lat in [45, 90, 135]:
pts = np.stack((np.repeat(lat, 37), np.linspace(-180, 180, 37)), axis=1)
rtp = np.insert(pts / 180 * np.pi, 0, np.repeat(1, pts.shape[0]), axis=1)
parallel_xy = sph2Mollweide(rtp[:, 1:3])
traces.append(
go.Scatter(
x=parallel_xy[:, 0],
y=parallel_xy[:, 1],
mode="lines",
line=dict(color="lightgrey", width=0.5),
showlegend=False,
hoverinfo="skip",
)
)
return traces
def create_data_scatter(series_data, x_coords, y_coords, column_name=None):
"""Create scatter plot for data points"""
return go.Scatter(
x=x_coords,
y=y_coords,
mode="markers",
marker=dict(
color=series_data.values,
colorscale=custom_colorscale,
cmin=global_min,
cmax=global_max,
size=marker_size,
colorbar=dict(
title=dict(text=column_name if column_name else "Value", side="right"),
),
),
customdata=np.stack([x_coords, y_coords, series_data.values], axis=-1),
hovertemplate="x: %{customdata[0]}<br>y: %{customdata[1]}<br>value: %{customdata[2]}<extra></extra>",
showlegend=False,
)
# Clean data - remove NaN indices
data = data[(data.index != "nan") & (~data.index.isnull())]
# Convert string indices to coordinate arrays
coords = [tuple(map(float, idx.split(","))) for idx in data.index]
coord_array = np.array(coords)
# Get global min and max for consistent color scale
vals = data.to_numpy() if isinstance(data, pd.DataFrame) else data.values
if global_min is None:
global_min = min(0, vals.min())
if global_max is None:
global_max = vals.max()
# Load eyemap data and convert coordinates
ucl_hex = load_dataset(dataset)
rtp2 = cart2sph(ucl_hex[["x", "y", "z"]].values)
xy = sph2Mollweide(rtp2[:, 1:3])
xy[:, 0] = -xy[:, 0] # flip x axis
xypq_moll = np.concatenate((xy, ucl_hex[["p", "q"]].values), axis=1)
xypq_moll = pd.DataFrame(xypq_moll, columns=["x", "y", "p", "q"])
xypq_moll[["p", "q"]] = xypq_moll[["p", "q"]].astype(int)
# Convert data coordinates to Mollweide
hex1_id = (coord_array[:, 1] - coord_array[:, 0]) / 2
hex2_id = (coord_array[:, 1] + coord_array[:, 0]) / 2
coord_df = pd.DataFrame({"hex1_id": hex1_id, "hex2_id": hex2_id}, index=data.index)
merged_coords = coord_df.merge(
xypq_moll, left_on=["hex1_id", "hex2_id"], right_on=["q", "p"], how="left"
)
x_mollweide = merged_coords["x"].values
y_mollweide = merged_coords["y"].values
# Create figure
fig = go.Figure()
# Add guidelines
guidelines = create_mollweide_guidelines()
for trace in guidelines:
fig.add_trace(trace)
# Handle single series vs DataFrame
if isinstance(data, pd.Series) or (
isinstance(data, pd.DataFrame) and len(data.columns) == 1
):
if isinstance(data, pd.DataFrame):
data = data.iloc[:, 0]
# Single plot
fig.add_trace(create_data_scatter(data, x_mollweide, y_mollweide))
elif isinstance(data, pd.DataFrame):
# Multiple columns - create frames for slider
frames = []
slider_steps = []
# Create frames for each column
for i, col_name in enumerate(data.columns):
series = data[col_name]
# Create frame data (guidelines + data scatter)
frame_traces = guidelines + [
create_data_scatter(series, x_mollweide, y_mollweide, col_name)
]
frames.append(go.Frame(data=frame_traces, name=str(i)))
# Add slider step
slider_steps.append(
{
"args": [
[str(i)],
{"frame": {"duration": 0, "redraw": True}, "mode": "immediate"},
],
"label": col_name,
"method": "animate",
}
)
# Set initial display to first column
if i == 0:
fig.add_trace(
create_data_scatter(series, x_mollweide, y_mollweide, col_name)
)
# Add frames to figure
fig.frames = frames
# Add slider
fig.update_layout(
sliders=[
{
"active": 0,
"currentvalue": {
"font": {"size": 16},
"visible": True,
"xanchor": "right",
},
"pad": {"b": 10, "t": 50},
"len": 0.9,
"x": 0.1,
"y": 0,
"steps": slider_steps,
}
]
)
# Update layout
fig.update_layout(
width=fig_size[0],
height=fig_size[1],
xaxis=dict(
range=[-np.pi, np.pi],
scaleanchor="y",
scaleratio=1,
showgrid=False,
showticklabels=False,
showline=False,
visible=False,
),
yaxis=dict(
range=[-np.pi / 2, np.pi / 2],
showgrid=False,
showticklabels=False,
showline=False,
visible=False,
),
plot_bgcolor="white",
paper_bgcolor="white",
margin=dict(
l=0,
r=0,
t=50,
b=50 if isinstance(data, pd.DataFrame) and len(data.columns) > 1 else 0,
),
)
return fig