#import "Basic";
#import "String";
#import "File";
#import "Hash_Table";
#import "Hash";
#import "Math";

// NOTE(review): this global is not referenced anywhere in the visible code;
// it may be a leftover placeholder — confirm no other file uses it before removing.
dummy_void: void;

// Reads the garden grid from inputs/day21.txt and prints two answers:
//   Part 1: cells reachable in exactly 64 steps from the centre start.
//   Part 2: cells reachable in exactly 26501365 steps when the grid repeats
//           infinitely, computed from a handful of single-tile walks. The
//           results are combined as: fully covered odd/even-parity tiles,
//           the four tips of the reachable diamond, and the small and large
//           partially covered corner tiles along its edges.
// The part-2 formula assumes a square, odd-sized grid with the start 'S'
// exactly in the centre and (max_steps % size) == size / 2. These
// assumptions are asserted below. It also relies on the centre row and
// column being clear of rocks, which is NOT checked.
// NOTE(review): `test` is currently unused; the input path is fixed.
solve_day21 :: (test: bool) {
    contents, ok := read_entire_file("inputs/day21.txt");
    if !ok {
        print("Could not read inputs/day21.txt\n");
        return;
    }

    lines := split(contents, "\n");

    // Tolerate CRLF line endings and trailing blank lines; either would
    // otherwise make the grid look non-square and trip the assert below.
    for lines {
        if it.count > 0 && it[it.count - 1] == #char "\r"  lines[it_index].count -= 1;
    }
    while lines.count > 0 && lines[lines.count - 1].count == 0  lines.count -= 1;

    if lines.count == 0 {
        print("inputs/day21.txt contains no grid\n");
        return;
    }

    size   := lines.count;
    center := size / 2;   // exact centre because size is odd (asserted)
    last   := size - 1;

    max_steps := 26501365;

    assert(size == lines[0].count);
    assert(size % 2 == 1);
    assert(lines[center][center] == #char "S");
    assert(max_steps % size == size / 2);

    part_1 := walk(lines, .{x = center, y = center, steps = 64});

    print("Part 1: %\n", part_1);

    // Whole tiles crossed in one direction from the start tile, not counting
    // the start tile itself, minus the outermost partially covered tile.
    num_cells: s64 = max_steps / size - 1;

    // Fully covered tiles alternate parity; count how many of each there are.
    odd_cells  := pow(num_cells / 2 * 2 + 1, 2);
    even_cells := pow((num_cells + 1) / 2 * 2, 2);

    // Step budgets large enough to saturate a whole tile, one per parity.
    odd_points  := walk(lines, .{ x = center, y = center, steps = size * 2 + 1 });
    even_points := walk(lines, .{ x = center, y = center, steps = size * 2 });

    // Tips of the diamond: tiles entered at the midpoint of one edge
    // (x is the column, y is the row).
    tip_from_right  := walk(lines, .{ x = last,   y = center, steps = last });
    tip_from_left   := walk(lines, .{ x = 0,      y = center, steps = last });
    tip_from_bottom := walk(lines, .{ x = center, y = last,   steps = last });
    tip_from_top    := walk(lines, .{ x = center, y = 0,      steps = last });

    // Edge tiles of the diamond are entered at a corner. The "small" ones are
    // barely touched and the "large" ones are mostly covered. With an odd size,
    // size / 2 - 1 == (size - 1) / 2 - 1, so the step budgets match the
    // original ones.
    small_steps := size / 2 - 1;
    large_steps := size * 3 / 2 - 1;

    corner_starts := Coord.[.{x = last, y = 0}, .{x = last, y = last}, .{x = 0, y = 0}, .{x = 0, y = last}];

    small_corners: s64 = 0;
    large_corners: s64 = 0;
    for corner_starts {
        start := it;
        start.steps = small_steps;
        small_corners += walk(lines, start);
        start.steps = large_steps;
        large_corners += walk(lines, start);
    }

    part_2 := odd_cells * odd_points +
              even_cells * even_points +
              tip_from_right + tip_from_left + tip_from_bottom + tip_from_top +
              (num_cells + 1) * small_corners +
              num_cells * large_corners;

    print("Part 2: %\n", part_2);
}

// Integer power by repeated multiplication: returns base multiplied by
// itself `exp` times. Any `exp` <= 0 performs no multiplications and yields 1.
pow :: (base: s64, exp: s64) -> s64 {
    product: s64 = 1;
    remaining := exp;

    while remaining > 0 {
        product *= base;
        remaining -= 1;
    }

    return product;
}

// Breadth-first flood fill over `grid` (indexed grid[y][x], so x is the
// column and y the row) starting at `start_pos` with `start_pos.steps`
// steps available. Cells containing '#' and cells outside the grid are
// blocked.
// Each cell is visited once, at its BFS distance d <= steps. A cell is
// counted when (steps - d) is even. Those are exactly the cells one can
// finish on after `steps` moves, by stepping back and forth.
// Returns that count.
walk :: (grid: []string, start_pos: Coord) -> s64 {
    // Equality and hashing look only at x/y, so cells are deduplicated
    // regardless of the step count they carry.
    Coord_Table :: Table(Coord, bool, hash_tt, (c1, c2) => c1 == c2);
    seen: Coord_Table;
    defer deinit(*seen);

    queue: [..]Coord;
    defer array_free(queue);

    table_add(*seen, start_pos, false);
    array_add(*queue, start_pos);

    // Every cell is enqueued at most once (it is marked seen before being
    // queued), so a counter is equivalent to collecting the cells in a table.
    reachable: s64 = 0;

    // Pop from the front by advancing a head index instead of shifting the
    // array, which would make the BFS quadratic in the number of cells.
    head := 0;
    while head < queue.count {
        item := queue[head];
        head += 1;

        if item.steps % 2 == 0  reachable += 1;

        if item.steps == 0  continue;

        for coord: Coord.[.{x = item.x + 1, y = item.y}, .{x = item.x - 1, y = item.y}, .{x = item.x, y = item.y + 1}, .{x = item.x, y = item.y - 1}] {
            if coord.y < 0 || coord.y >= grid.count || coord.x < 0 || coord.x >= grid[0].count || grid[coord.y][coord.x] == #char "#" || table_contains(*seen, coord) {
                continue;
            }

            new := coord;
            new.steps = item.steps - 1;
            table_add(*seen, new, false);

            array_add(*queue, new);
        }
    }

    return reachable;
}

// Two coordinates are equal when they name the same cell. `steps` is
// deliberately not compared, so lookups in walk's `seen` table match a cell
// regardless of the step count it was stored with.
operator == :: (c1: Coord, c2: Coord) -> bool {
    if c1.x != c2.x  return false;
    return c1.y == c2.y;
}

// Component-wise sum of the positions. The result's `steps` keeps its
// default value (0); neither operand's `steps` is carried over.
operator + :: (c1: Coord, c2: Coord) -> Coord {
    combined: Coord;
    combined.x = c1.x + c2.x;
    combined.y = c1.y + c2.y;
    return combined;
}

// Hash for Coord keys in walk's table. Only x and y are hashed, matching
// `operator ==`, which also ignores `steps`.
// XOR-ing identical per-component hashes made the result symmetric:
// (x, y) and (y, x) collided, and every diagonal cell (x == y) hashed to 0.
// Rotating the y hash by 16 bits before combining breaks that symmetry.
hash_tt :: (c: Coord) -> u32 {
    hx := get_hash(c.x);
    hy := get_hash(c.y);
    return hx ^ ((hy << 16) | (hy >> 16));
}

// Packs a coordinate into one integer as y * 256 + x. The value is unique
// only while 0 <= x < 256; `steps` is not included.
to_s64 :: (c: Coord) -> s64 {
    return c.y * 256 + c.x;
}

// Debug helper: prints "x, y, steps" followed by a newline.
print_coord :: (c: Coord) {
    column := c.x;
    row := c.y;
    remaining := c.steps;
    print("%, %, %\n", column, row, remaining);
}

#scope_file
// A grid position plus a step count.
// `x` is the column and `y` the row: `walk` indexes `grid[coord.y][coord.x]`.
// `steps` is not part of a Coord's identity: `operator ==` and `hash_tt`
// only look at `x` and `y`.
Coord :: struct {
    x: s64;
    y: s64;
    steps: s64 = 0;  // remaining steps when used as a BFS node in `walk`; the start node's value is the step budget
}