plotters icon indicating copy to clipboard operation
plotters copied to clipboard

heatmap

Open wangjiawen2013 opened this issue 2 years ago • 2 comments

Hi, can anyone tell me how to plot a heatmap using plotters?

wangjiawen2013 avatar Jul 21 '23 05:07 wangjiawen2013

Here is an example (it renders the heatmap below to `wave.png`); feel free to use any of it:

wave

use std::{error::Error, ops::Neg};

use colorous::{Gradient, BLUES};
use ndarray::{prelude::*, stack, Zip};
use plotters::{coord::types::RangedCoordf64, prelude::*};

/// A point (or vector) in the plane.
///
/// `Debug` and `PartialEq` are derived so points can be printed and compared
/// (e.g. in assertions); `Copy` keeps pass-by-value cheap.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Point {
    /// Horizontal coordinate.
    x: f64,
    /// Vertical coordinate.
    y: f64,
}

impl Neg for Point {
    type Output = Self;

    fn neg(self) -> Self::Output {
        Point {
            x: -self.x,
            y: -self.y,
        }
    }
}

impl Point {
    /// Euclidean distance from the origin, computed with `hypot` to avoid
    /// intermediate overflow/underflow from squaring.
    fn norm(&self) -> f64 {
        let Self { x, y } = *self;
        y.hypot(x)
    }
}

/// An axis-aligned rectangle in the plane, stored as its coordinate bounds.
///
/// Small and plain-data, so it derives `Copy` along with `Debug`/`PartialEq`
/// for printing and comparison.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Domain {
    /// Smallest x coordinate.
    x_min: f64,
    /// Largest x coordinate.
    x_max: f64,
    /// Smallest y coordinate.
    y_min: f64,
    /// Largest y coordinate.
    y_max: f64,
}

impl Domain {
    fn new(point_1: Point, point_2: Point) -> Self {
        Domain {
            x_min: point_1.x.min(point_2.x),
            x_max: point_1.x.max(point_2.x),
            y_min: point_1.y.min(point_2.y),
            y_max: point_1.y.max(point_2.y),
        }
    }

    fn width(&self) -> f64 {
        self.x_max - self.x_min
    }

    fn height(&self) -> f64 {
        self.y_max - self.y_min
    }

    fn area(&self) -> f64 {
        self.width() * self.height()
    }
}

/// Returns a solid-fill style of the given color with no outline stroke.
fn filled_style<C: Into<RGBAColor>>(color: C) -> ShapeStyle {
    let rgba: RGBAColor = color.into();
    ShapeStyle {
        stroke_width: 0,
        filled: true,
        color: rgba,
    }
}

/// A linear color scale that maps values in `[min, max]` onto a `colorous` gradient.
struct Colorbar {
    /// Gradient sampled at 0.0 for `min` and 1.0 for `max`.
    gradient: Gradient,
    /// Value mapped to the start of the gradient.
    min: f64,
    /// Value mapped to the end of the gradient.
    max: f64,
}

impl Colorbar {
    /// Number of band intervals the colorbar strip is divided into when drawn.
    const BANDS: u32 = 256;

    /// Maps `value` to a color on the gradient.
    ///
    /// `value` is clamped to `[min, max]` and then rescaled linearly to the
    /// gradient parameter. When `min == max` that rescaling is `0 / 0`, so
    /// every value maps to the start of the gradient instead of handing NaN
    /// to `eval_continuous`.
    fn color(&self, value: f64) -> RGBColor {
        let &Self { min, max, gradient } = self;
        let value = value.max(min).min(max);
        let t = (value - min) / (max - min);
        // Only a degenerate range (0 / 0) can produce NaN here.
        let t = if t.is_nan() { 0.0 } else { t };
        let (r, g, b) = gradient.eval_continuous(t).as_tuple();
        RGBColor(r, g, b)
    }

    /// Draws the colorbar as a vertical strip with a labelled value axis.
    ///
    /// The y axis spans `[min, max]`. The strip is painted as `BANDS + 1`
    /// horizontal rectangles, each centered on a sampled value and one band tall.
    ///
    /// # Panics
    ///
    /// Panics if building the chart or any drawing operation fails.
    fn draw<DB: DrawingBackend>(&self, text_color: RGBAColor, mut chart_builder: ChartBuilder<DB>) {
        let &Self { min, max, .. } = self;
        let step = (max - min) / f64::from(Self::BANDS);
        let mut chart_context = chart_builder
            .margin_top(10)
            .x_label_area_size(25)
            .y_label_area_size(40)
            .build_cartesian_2d(0.0..1.0, min..max)
            .unwrap();
        chart_context
            .configure_mesh()
            .set_all_tick_mark_size(5)
            .disable_x_axis()
            .disable_x_mesh()
            .disable_y_mesh()
            .axis_style(&text_color)
            .label_style("sans-serif".into_font().color(&text_color))
            .draw()
            .unwrap();
        let plotting_area = chart_context.plotting_area();
        // Step through bands by integer index. A float range ending at
        // `max + step` can, through rounding, yield an extra band past `max`.
        for band in 0..=Self::BANDS {
            let value = min + step * f64::from(band);
            let rectangle = Rectangle::new(
                [(0.0, value - step / 2.0), (1.0, value + step / 2.0)],
                filled_style(self.color(value)),
            );
            plotting_area.draw(&rectangle).unwrap();
        }
    }
}

fn meshgrid<A: Clone>(x: &Array1<A>, y: &Array1<A>) -> (Array2<A>, Array2<A>) {
    (
        stack(Axis(0), &vec![x.view(); y.len()]).unwrap(),
        stack(Axis(1), &vec![y.view(); x.len()]).unwrap(),
    )
}

/// Draws `function`, sampled on a regular grid over `domain`, as a heatmap.
///
/// Grid points start at the domain's lower-left corner and are spaced `|step|`
/// apart in both directions. Each point `(x, y)` is painted as a `|step|`-sized
/// square centered on it, colored by `colorbar.color(function(Point { x, y }))`.
/// Axes and labels use `text_color`. Returns the chart context so callers can
/// draw further series on top.
///
/// # Errors
///
/// Returns an error if `domain.area()` or `step` is zero, subnormal, infinite
/// or NaN, or if any plotters chart-building or drawing operation fails.
fn heatmap<'b, 'c, F: Fn(Point) -> f64, DB: DrawingBackend>(
    function: F,
    domain: &Domain,
    step: f64,
    colorbar: &Colorbar,
    text_color: RGBAColor,
    mut chart_builder: ChartBuilder<'b, 'c, DB>,
) -> Result<ChartContext<'b, DB, Cartesian2d<RangedCoordf64, RangedCoordf64>>, Box<dyn Error>> {
    if !domain.area().is_normal() || !step.is_normal() {
        return Err("invalid domain or step".into());
    }

    let step = step.abs();
    let half = step / 2.0;
    let x_values = Array1::range(domain.x_min, domain.x_max + step, step);
    let y_values = Array1::range(domain.y_min, domain.y_max + step, step);
    let (x_grid, y_grid) = meshgrid(&x_values, &y_values);

    // Plotters' error types are generic over the backend and not guaranteed to
    // be `'static`, so they are stringified before being boxed by `?`.
    let mut chart_context = chart_builder
        .margin_top(10)
        .x_label_area_size(25)
        .y_label_area_size(40)
        .build_cartesian_2d(domain.x_min..domain.x_max, domain.y_min..domain.y_max)
        .map_err(|e| e.to_string())?;

    chart_context
        .configure_mesh()
        .disable_x_mesh()
        .disable_y_mesh()
        .set_all_tick_mark_size(5)
        .axis_style(&text_color)
        .label_style("sans-serif".into_font().color(&text_color))
        .draw()
        .map_err(|e| e.to_string())?;

    let plotting_area = chart_context.plotting_area();

    // Both grids share the same shape, so zipping their logical iterators pairs
    // up the coordinates of each grid point.
    for (&x, &y) in x_grid.iter().zip(y_grid.iter()) {
        let color = colorbar.color(function(Point { x, y }));
        let rectangle = Rectangle::new(
            [(x - half, y - half), (x + half, y + half)],
            filled_style(color),
        );
        plotting_area.draw(&rectangle).map_err(|e| e.to_string())?;
    }

    Ok(chart_context)
}

fn main() {
    let drawing_area = BitMapBackend::new("wave.png", (512, 512)).into_drawing_area();
    drawing_area.fill(&WHITE).unwrap();

    let pixel_width = drawing_area.dim_in_pixel().0;
    let (left, right) = drawing_area.split_horizontally(pixel_width - 60);
    let colorbar = Colorbar {
        min: -2.0,
        max: 2.0,
        gradient: BLUES,
    };
    colorbar.draw(BLACK.into(), ChartBuilder::on(&right));

    let mut chart_builder = ChartBuilder::on(&left);
    chart_builder.margin(10);
    let corner = Point { x: 25.0, y: 25.0 };
    let domain = Domain::new(-corner, corner);
    heatmap(
        |p| p.norm().cos(),
        &domain,
        0.1,
        &colorbar,
        BLACK.into(),
        chart_builder,
    )
    .expect("failure while drawing heatmap");

    drawing_area.present().expect("failure while writing file");
}

qoheniac avatar Jul 28 '23 09:07 qoheniac

Thanks a lot, I'll try it!

wangjiawen2013 avatar Aug 17 '23 07:08 wangjiawen2013