#import "Basic";
#import "File";
#import "String";
#import "Math";
#import "Sort";
#import "Hash";
#import "Hash_Table";

// Compares two paths by accumulated heat_loss with the operands reversed (b - a).
// The result is negative when `a` has lost more heat than `b`. Under a
// "smaller result sorts first" convention, this puts larger heat_loss first.
// Both u32 values are widened to s64 before subtracting, so the difference cannot wrap.
cmp_path :: (a: Path, b: Path) -> s64 {
    loss_a := cast(s64) a.heat_loss;
    loss_b := cast(s64) b.heat_loss;
    return loss_b - loss_a;
}

// Module-level variable of type `void`; nothing in this section reads or writes it.
// NOTE(review): looks like a leftover placeholder — confirm no other file uses it before removing.
void_dummy: void;

// Advent of Code day 17 driver.
//
// Reads the digit grid from "inputs/day17_test.txt" (test == true) or
// "inputs/day17.txt", then runs the `walk`-driven search twice:
//   - part 1 with straight runs starting at steps_left = 2,
//   - part 2 with straight runs starting at steps_left = 9.
// Prints "Part 1: <n>" and "Part 2: <n>". A part prints 0 if the search
// never reports reaching the bottom-right cell.
//
// Returns early (after printing a message) if the file cannot be read or is empty.
// Asserts that every row has the same width, because the grid is stored
// row-major in one string and indexed as y * width + x.
solve_day17 :: (test: bool) {
    file_name := ifx test then "inputs/day17_test.txt" else "inputs/day17.txt";
    contents, ok := read_entire_file(file_name);
    if !ok {
        print("Could not read %\n", file_name);
        return;
    }
    defer free(contents);

    // Drop trailing line terminators so a final "\n" (or "\r\n") does not become an empty row.
    while contents.count > 0 {
        last := contents[contents.count - 1];
        if last != #char "\n" && last != #char "\r"  break;
        contents.count -= 1;
    }
    if contents.count == 0 {
        print("% is empty\n", file_name);
        return;
    }

    lines := split(contents, "\n");

    // Strip a trailing '\r' from each row so CRLF files do not add a bogus column.
    for i: 0..lines.count - 1 {
        line := lines[i];
        if line.count > 0 && line[line.count - 1] == #char "\r" {
            line.count -= 1;
            lines[i] = line;
        }
    }

    // Every row must be the same width for y * width + x indexing to be valid.
    row_width := lines[0].count;
    for lines {
        assert(it.count == row_width, "Row % has % cells; expected %.\n", it_index, it.count, row_width);
    }

    // Concatenate the rows verbatim. `append` is used rather than `print_to_builder`,
    // which would treat each row as a format string.
    builder: String_Builder;
    for lines {
        append(*builder, it);
    }
    grid_text := builder_to_string(*builder);
    defer free(grid_text);

    grid := Grid.{
        width = row_width,
        height = lines.count,
        grid = grid_text,
    };

    walked: Table(Stepped_On, bool, hash, cmp_stepped_on);
    defer deinit(*walked);

    // NOTE(review): assumes Priority_Queue pops the lowest heat_loss first with this
    // comparator, so the first goal reported by `walk` is the minimum — confirm.
    queue: Priority_Queue(Path, (p1, p2)  => cast(s64)p1.heat_loss - cast(s64)p2.heat_loss);

    for part2: bool.[false, true] {
        // Further straight moves allowed after the first step. `walk` resets to
        // the same value after each turn.
        initial_steps_left: u32 = xx ifx part2 then 9 else 2;

        // Start in the top-left corner, heading right and heading down.
        add(*queue, .{ x = 0, y = 0, d_x = 1, d_y = 0, heat_loss = 0, steps_left = initial_steps_left });
        add(*queue, .{ x = 0, y = 0, d_x = 0, d_y = 1, heat_loss = 0, steps_left = initial_steps_left });

        final_heat_loss: u32 = 0;
        while queue.count > 0 {
            removed, item := remove_top(*queue);
            d, success := walk(item, *queue, grid, *walked, false, part2);
            if success {
                final_heat_loss = d.heat_loss;
                break;
            }
        }

        if part2 {
            print("Part 2: %\n", final_heat_loss);
        } else {
            print("Part 1: %\n", final_heat_loss);
        }

        // Release queued paths and clear the visited set before the next part (or exit).
        deinit(*queue);
        table_reset(*walked);
    }
}

// Advances `path` one cell in its current direction (d_x, d_y) and pushes follow-up
// paths onto `queue`.
//
// Behavior, in order:
//   1. If the step would leave the grid, returns (dummy, false) and pushes nothing.
//   2. Adds the destination cell's digit to heat_loss.
//   3. If steps_left > 0, decrements it and pushes the path to keep going straight.
//      This push happens even when step 4 then returns.
//   4. Returns (new_pos, true) if the destination is the bottom-right cell.
//      In part 2 this additionally requires path.steps_left (before the decrement)
//      to be <= required_steps_left. That is equivalent to the turn test in step 5.
//   5. Otherwise, checks whether this (x, y, d_x, d_y) has been recorded in `walked`.
//      If it has not, and the decremented steps_left is below required_steps_left,
//      the key is recorded and both perpendicular turns are pushed with steps_left
//      reset to 9 (part 2) or 2 (part 1). Straight continuations from step 3 are
//      never deduplicated; only turning points are.
//   6. Returns (dummy, false).
//
// `dummy` is uninitialized (---), so callers must only read the returned Path when
// success is true.
//
// history: not used; the block that would use it is commented out below.
// NOTE(review): heat_loss assumes each grid byte is an ASCII digit '0'..'9'.
walk :: (path: Path, queue: *Priority_Queue, grid: Grid, walked: *Table, history: bool, part2: bool) -> (path: Path, success: bool) {
    // A turn (or finishing) is allowed only once steps_left drops below this.
    // In part 1 steps_left never exceeds 2, so turns are always allowed there.
    // In part 2, given a starting value of 9, this requires at least four straight cells.
    required_steps_left :: 6;

    dummy: Path = ---;

    // Reject steps that would move off the grid.
    if path.x + path.d_x >= grid.width || path.x + path.d_x < 0  return dummy, false;
    if path.y + path.d_y >= grid.height || path.y + path.d_y < 0  return dummy, false;

    new_pos := path;

    new_pos.x += new_pos.d_x;
    new_pos.y += new_pos.d_y;

    // TODO: unfinished history tracking (the `history` parameter is otherwise unused).
    //if history {
    //    array_add(*new_post, .{ x = new_pos.x, y = new_pos.y, new_pos.d_x = 
    //}

    // Grid is row-major: column x, row y.
    new_pos.heat_loss += grid.grid[new_pos.y * grid.width + new_pos.x] - #char "0";

    // Visited-set key: the cell entered plus the direction used to enter it (steps_left excluded).
    map_key := Stepped_On.{ x = new_pos.x, y = new_pos.y, d_x = new_pos.d_x, d_y = new_pos.d_y };

    // Keep going straight while the straight-run budget allows it.
    if new_pos.steps_left > 0 {
        new_pos.steps_left -= 1;
        add(queue, new_pos);
    }

    // Reached the bottom-right cell. Part 2 also needs the minimum straight run.
    if new_pos.x == grid.width - 1 && new_pos.y == grid.height - 1 && !(part2 && path.steps_left > required_steps_left)  return new_pos, true;

    if !table_contains(walked, map_key) {

        if new_pos.steps_left < required_steps_left {
            table_add(walked, map_key, true);
            new_pos.steps_left = xx ifx part2 then 9 else 2;
            // Swapping d_x/d_y gives one perpendicular direction; negating both gives the other.
            new_pos.d_x = path.d_y;
            new_pos.d_y = path.d_x;

            add(queue, new_pos);
            new_pos.d_x *= -1;
            new_pos.d_y *= -1;
            add(queue, new_pos);
        }
    }

    return dummy, false;
}

// Hash function for Stepped_On keys in the visited-set Table.
//
// Fields are folded in one at a time: each step re-hashes the running hash XOR-ed
// with the next field. The result therefore depends on field order. A plain XOR of
// per-field hashes would be symmetric: (x, y, 1, 0) and (x, y, 0, 1) would always
// collide, as would (x, y, ...) and (y, x, ...), and x == y would cancel to zero.
hash :: (s: Stepped_On) -> u32 {
    h := get_hash(s.x);
    h = get_hash(cast(s64) h ^ cast(s64) s.y);
    h = get_hash(cast(s64) h ^ cast(s64) s.d_x);
    h = get_hash(cast(s64) h ^ cast(s64) s.d_y);
    return h;
}

// Equality for Stepped_On keys in the visited-set Table.
// Two keys match only when both the cell (x, y) and the direction (d_x, d_y) are equal.
cmp_stepped_on :: (s1: Stepped_On, s2: Stepped_On) -> bool {
    if s1.x != s2.x || s1.y != s2.y  return false;
    return s1.d_x == s2.d_x && s1.d_y == s2.d_y;
}

// Visited-set key: a grid cell together with the direction of travel used to enter it.
Stepped_On :: struct {
    x, y:     s32;  // column and row of the cell
    d_x, d_y: s32;  // direction of travel
}

// One search state: current cell, direction of travel, accumulated heat loss,
// and how many more straight steps the current run may take.
Path :: struct {
    x, y:                  s32;  // column and row
    d_x, d_y:              s32;  // direction of travel
    heat_loss, steps_left: u32;
    history: [..][5]s32;         // NOTE(review): not read or written in this section
}