#import bevy_sprite::mesh2d_view_bindings
#import bevy_sprite::mesh2d_bindings mesh

#import bevy_sprite::mesh2d_functions mesh2d_position_local_to_world
#import bevy_sprite::mesh2d_functions mesh2d_position_world_to_clip

// Uniform parameters for the star material (bound at group 1, binding 4).
// One i32 of data plus three i32 padding words, i.e. 16 bytes total.
// NOTE(review): presumably mirrors a host-side uniform struct; padding names are
// not significant for layout, but confirm the host side also totals 16 bytes.
struct StarMaterial {
    // Number of slices the brightness texture's x axis is divided into; the
    // vertex shader divides (star_idx + uv.x) by this to pick a star's slice.
    star_count: i32,
    // Explicit padding. Each member must have a distinct name (WGSL rejects
    // duplicate struct member names, so these cannot all be `_padding1`).
    _padding1: i32,
    _padding2: i32,
    _padding3: i32,
};

// Material resources, all in bind group 1.

// Brightness atlas. The vertex shader uses its height as the quad size and maps
// UVs into one of `material.star_count` horizontal slices of it; the fragment
// shader samples it for brightness (red channel) and alpha.
@group(1) @binding(0)
var brightness_texture: texture_2d<f32>;
@group(1) @binding(1)
var brightness_sampler: sampler;
// Color palette. Its width (in texels) is treated as the number of palette
// entries. It is looked up with u == v, so presumably a single-row texture — confirm.
@group(1) @binding(2)
var colorscheme: texture_2d<f32>;
@group(1) @binding(3)
var colorscheme_sampler: sampler;
// Per-material uniform parameters (see StarMaterial).
@group(1) @binding(4)
var<uniform> material: StarMaterial;

// Per-vertex input for one corner of a star quad.
struct Vertex {
    // Mesh-local anchor position; the vertex shader offsets it by the corner's
    // UV (centered on 0.5), so presumably this is the star's center.
    @location(0) position: vec3<f32>,
    // Corner UV; 0.5 is subtracted to get the corner offset, so presumably in [0, 1].
    @location(1) uv: vec2<f32>,
    // Palette entry used by the fragment shader; 0 selects palette replacement,
    // any other value tints the brightness texel by that entry.
    @location(2) color_idx: i32,
    // Which horizontal slice of the brightness atlas this star samples.
    @location(3) star_idx: i32,
};

// Values passed from `vertex` to `fragment`; the @location slots must match
// those declared in FragmentInput.
struct VertexOutput {
    @builtin(position) clip_position: vec4<f32>,
    @location(0) world_position: vec4<f32>,
    // Not assigned by `vertex`; it keeps the zero value from `var out` initialization.
    @location(1) world_normal: vec3<f32>,
    // UV inside this star's slice of the brightness atlas.
    @location(2) uv: vec2<f32>,
    // Integer inter-stage variables must use flat interpolation per the WGSL spec.
    @location(3) @interpolate(flat) color_idx: i32,
};

// Vertex stage: expands each vertex into a corner of a square star quad and
// computes the UV into that star's slice of the brightness atlas.
@vertex
fn vertex(vertex: Vertex) -> VertexOutput {
    // Function-scope `var` without an initializer is zero-initialized, so
    // `out.world_normal` (never assigned below) is vec3(0.0).
    var out: VertexOutput;

    // Brightness atlas size in texels.
    let dims = vec2<f32>(textureDimensions(brightness_texture));

    // Offset the vertex from its anchor by the centered corner UV, scaled by the
    // atlas height on both axes, giving a dims.y x dims.y quad in local units.
    let extent = vec3<f32>((vertex.uv - vec2(0.5)) * vec2(dims.y), 0.0);

    out.world_position = mesh2d_position_local_to_world(mesh.model, vec4<f32>(vertex.position + extent, 1.0));
    out.clip_position = mesh2d_position_world_to_clip(out.world_position);

    // Remap uv.x into slice `star_idx` of `star_count` equal horizontal slices.
    // NOTE(review): assumes each slice is dims.y texels wide (square frames) so the
    // quad size above matches one frame; confirm against the atlas layout.
    out.uv = vec2<f32>((f32(vertex.star_idx) + vertex.uv.x) / f32(material.star_count), vertex.uv.y);
    out.color_idx = vertex.color_idx;

    return out;
}

// Fragment-stage inputs; the @location slots must match those written by the
// vertex stage (VertexOutput).
struct FragmentInput {
    @location(0) world_position: vec4<f32>,
    @location(1) world_normal: vec3<f32>,
    // UV inside this star's slice of the brightness atlas.
    @location(2) uv: vec2<f32>,
    // Integer inter-stage variables must use flat interpolation per the WGSL spec.
    @location(3) @interpolate(flat) color_idx: i32,
}

// Fragment stage: colors a star texel using the colorscheme palette.
//   color_idx == 0 -> the brightness (red channel) is rounded to a palette step
//                     and replaced by that palette color.
//   otherwise      -> the brightness texel's rgb is multiplied by palette entry
//                     `color_idx`.
// Alpha always comes from the brightness texture.
@fragment
fn fragment(
    in: FragmentInput
) -> @location(0) vec4<f32> {
    let palette_size = f32(textureDimensions(colorscheme).x);
    let texel = textureSample(brightness_texture, brightness_sampler, in.uv);

    // Palette coordinates: quantized brightness, and the per-star entry index.
    let level_u = round(texel.r * palette_size) / palette_size;
    let entry_u = f32(in.color_idx) / palette_size;

    // Both palette lookups run unconditionally (before any divergence on
    // color_idx), keeping textureSample in uniform control flow.
    let level_color = textureSample(colorscheme, colorscheme_sampler, vec2(level_u)).rgb;
    let entry_color = textureSample(colorscheme, colorscheme_sampler, vec2(entry_u)).rgb;

    let rgb = select(texel.rgb * entry_color, level_color, in.color_idx == 0);
    return vec4(rgb, texel.a);
}



