#import "Basic";
#import "File";
#import "String";
#import "Math";

// Advent of Code day 16: counts the tiles energized by light beams bouncing
// through a grid of mirrors ('/', '\') and splitters ('-', '|'), using walk_beam.
//
// Part 1: a single beam enters at the top-left tile heading right (+x).
// Part 2: the maximum over every edge tile, with the beam heading into the grid.
//
// `test` selects "inputs/day16_test.txt" instead of "inputs/day16.txt".
solve_day16 :: (test: bool) {
    contents, ok := read_entire_file(ifx test then "inputs/day16_test.txt" else "inputs/day16.txt");
    if !ok {
        log_error("Day 16: could not read the input file.\n");
        return;
    }
    defer free(contents.data);

    lines := split(contents, "\n"[0]);
    defer array_free(lines);

    // Normalise the rows: strip a trailing '\r' (CRLF input), then drop empty
    // lines at the end, since a final newline would otherwise add an empty row.
    for lines {
        if it.count > 0 && it[it.count - 1] == #char "\r"  lines[it_index].count -= 1;
    }
    while lines.count > 0 && lines[lines.count - 1].count == 0 {
        lines.count -= 1;
    }

    if lines.count == 0 {
        log_error("Day 16: the input is empty.\n");
        return;
    }

    // Flatten the rows into one row-major string; every row must be equally wide
    // for the `y * width + x` indexing in walk_beam to be valid.
    builder: String_Builder;
    for lines {
        assert(it.count == lines[0].count, "Line % has width %, expected %", it_index + 1, it.count, lines[0].count);
        append(*builder, it);  // append, not print_to_builder: a row is data, not a format string.
    }

    grid: Grid = .{
        width = lines[0].count,
        height = lines.count,
        grid = builder_to_string(*builder),
    };
    defer free(grid.grid.data);

    // One byte of direction bits per tile; day16_count_energized leaves it zeroed.
    walked: []u8 = NewArray(grid.grid.count, u8);
    defer array_free(walked);

    part1 := day16_count_energized(grid, walked, .{ x = 1, y = 0 }, .{ x = 0, y = 0 });

    // Jai ranges are inclusive, hence the `- 1` on the upper bounds.
    part2: u64 = 0;
    for column: 0..grid.width - 1 {
        part2 = max(part2, day16_count_energized(grid, walked, .{ x = 0, y =  1 }, .{ x = xx column, y = 0 }));
        part2 = max(part2, day16_count_energized(grid, walked, .{ x = 0, y = -1 }, .{ x = xx column, y = xx (grid.height - 1) }));
    }
    for row: 0..grid.height - 1 {
        part2 = max(part2, day16_count_energized(grid, walked, .{ x =  1, y = 0 }, .{ x = 0, y = xx row }));
        part2 = max(part2, day16_count_energized(grid, walked, .{ x = -1, y = 0 }, .{ x = xx (grid.width - 1), y = xx row }));
    }

    print("Part 1: %\n", part1);
    print("Part 2: %\n", part2);
}

// Fires one beam from `position` heading `direction` through `grid`, returns the
// number of tiles it energizes (non-zero `walked` entries), and zeroes `walked`
// again so the next call starts clean. Expects `walked` to be all zeros on entry.
day16_count_energized :: (grid: Grid, walked: []u8, direction: Direction, position: Direction) -> u64 {
    walk_beam(grid, walked, direction, position);

    energized: u64 = 0;
    for walked {
        if it > 0  energized += 1;
    }

    memset(walked.data, 0, walked.count);
    return energized;
}

// Traces a light beam through `grid`, starting at `position` and moving in
// `direction` (a unit step: x is the column offset, y the row offset). For every
// tile the beam crosses, it ORs the bit from dir_to_bit_flag into `walked`, so a
// tile is energized exactly when its `walked` entry is non-zero.
//
// - '/' and '\' turn the beam via forward_mirror / backward_mirror.
// - A '-' hit while moving vertically, or a '|' hit while moving horizontally,
//   splits the beam. The left/up half is traced recursively and this loop
//   keeps following the right/down half.
// - A splitter hit end-on is passed straight through, like an empty tile.
// - The beam stops when it leaves the grid or re-enters a tile in a direction
//   already recorded there. This is what makes looping paths terminate.
//
// `position` (a point stored in a Direction) may lie outside the grid; then
// nothing is marked.
walk_beam :: (grid: Grid, walked: []u8, direction: Direction, position: Direction) {
    walk_pos := position;
    walk_dir := direction;

    while walk_pos.x >= 0 && walk_pos.x < grid.width && walk_pos.y >= 0 && walk_pos.y < grid.height {
        index := walk_pos.y * grid.width + walk_pos.x;

        flag := dir_to_bit_flag(walk_dir);
        if (walked[index] & flag) == flag {
            break;  // Already travelled through this tile in this direction.
        }
        walked[index] |= flag;

        cur := grid.grid[index];

        if cur == #char "/" {
            walk_dir = forward_mirror(walk_dir);
        } else if cur == #char "\\" {
            walk_dir = backward_mirror(walk_dir);
        } else if cur == #char "-" && walk_dir.y != 0 {
            // Hit from the flat side: trace the left half now, follow the right half here.
            walk_beam(grid, walked, .{ x = -1, y = 0 }, .{ x = walk_pos.x - 1, y = walk_pos.y });
            walk_dir = Direction.{ x = 1, y = 0 };
        } else if cur == #char "|" && walk_dir.x != 0 {
            // Hit from the flat side: trace the upward half now, follow the downward half here.
            walk_beam(grid, walked, .{ x = 0, y = -1 }, .{ x = walk_pos.x, y = walk_pos.y - 1 });
            walk_dir = Direction.{ x = 0, y = 1 };
        }

        walk_pos.x += walk_dir.x;
        walk_pos.y += walk_dir.y;
    }
}

// Reflects a unit direction off a '/' mirror. The axes swap and both signs flip:
// (x, y) -> (-y, -x). For example, moving +x (1, 0) becomes (0, -1).
// Any input that is not a unit axis step trips the assert and, if asserts are
// disabled, yields a default-initialized Direction.
forward_mirror :: (dir: Direction) -> Direction {
    along_x := dir.y == 0 && (dir.x == 1 || dir.x == -1);
    along_y := dir.x == 0 && (dir.y == 1 || dir.y == -1);

    if !along_x && !along_y {
        assert(false, "Bad direction %", dir);
        return .{};
    }

    return .{ x = -dir.y, y = -dir.x };
}

// Reflects a unit direction off a '\' mirror. The axes swap and the signs are
// kept: (x, y) -> (y, x). For example, moving +x (1, 0) becomes (0, 1).
// Any input that is not a unit axis step trips the assert and, if asserts are
// disabled, yields a default-initialized Direction.
backward_mirror :: (dir: Direction) -> Direction {
    along_x := dir.y == 0 && (dir.x == 1 || dir.x == -1);
    along_y := dir.x == 0 && (dir.y == 1 || dir.y == -1);

    if along_x || along_y {
        return .{ x = dir.y, y = dir.x };
    }

    assert(false, "Bad direction %", dir);

    return .{};
}

// Maps a unit direction to the single bit that records it in a per-tile mask:
// +x -> 2, -x -> 4, +y -> 8, -y -> 16. Every other value maps to 0.
dir_to_bit_flag :: (dir: Direction) -> u8 {
    if dir.y == 0 {
        if dir.x ==  1  return 2;
        if dir.x == -1  return 4;
    } else if dir.x == 0 {
        if dir.y ==  1  return 8;
        if dir.y == -1  return 16;
    }

    return 0;
}