#import "Basic";
#import "String";
#import "File";
#import "Hash_Table";
#import "Math";

// Entry point for day 23. solve1 answers part 1 and then hands the parsed grid
// to the junction-graph part 2 solver. The `test` flag is currently unused.
solve_day23 :: (test: bool) {
    // solve2_slow(); // brute-force part 2 on the raw grid, kept for reference
    solve1();
}

// Part 2: longest hike when slopes are ignored.
// The maze is compressed into a graph of junctions (tiles with more than two
// open neighbours, plus the entrance and the exit). The longest simple path over
// that graph is then found by exhaustive DFS.
//
// lines: the grid rows. Every row is assumed to be as wide as lines[0], and there
//        must be no trailing empty row.
solve2 :: (lines: []string) {
    dimension := Coord.{
        x = lines[0].count,
        y = lines.count,
    };

    // hash_coord packs x and y into 8 bits each.
    assert(dimension.x < 256);
    assert(dimension.y < 256);

    junctions: [64]Junction;
    junction_positions: Table(Coord, u8, hash_coord, comp_coord);

    // Fixed junction indices: 0 = entrance (top row), 1 = exit (bottom row).
    table_add(*junction_positions, Coord.{ x = dimension.x - 2, y = dimension.y - 1}, 1);
    table_add(*junction_positions, Coord.{ x = 1, y = 0 }, 0);

    find_junctions(lines, 1, 0, dimension, .South, junctions, *junction_positions);

    // The exit is entered from the north (neighbour slot 3), so every path to it
    // passes through that single junction. Search up to that junction, then add
    // the final corridor length.
    exit_junction := junctions[1];
    last_index    := exit_junction.indices[3];
    last_leg      := exit_junction.distances[3];
    last_junction := junctions[last_index];   // loop-invariant: looked up once

    // Bitmask of the last junction's neighbours other than the exit.
    // find_longest_path2 uses it for pruning.
    second_last_check: u64 = 0;
    for 0..3 {
        if last_junction.distances[it] && last_junction.indices[it] != 1 {
            second_last_check |= cast(u64) 1 << cast(u64) last_junction.indices[it];
        }
    }

    result := find_longest_path2(0, last_index, 0, junctions, 0, second_last_check);

    print("Part 2: %\n", result + last_leg);
}

// Depth-first search for the longest simple path in the junction graph.
//
// current_index:     junction being entered.
// end_index:         junction at which a path is considered complete.
// start_distance:    length of the path walked so far, up to current_index.
// junctions:         graph built by find_junctions. A slot whose distance is 0
//                    is unused.
// visited:           bitmask of junctions already on the path (bit i = junction i).
// second_last_check: bitmask of end_index's neighbours. If every one of them is
//                    already on the path, end_index can no longer be reached and
//                    the branch is pruned.
//
// Returns the longest total distance to end_index, or 0 if it is unreachable.
find_longest_path2 :: (current_index: s64, end_index: s64, start_distance: s64, junctions: []Junction, visited: u64, second_last_check: u64) -> s64 {
    if current_index == end_index  return start_distance;
    if (visited >> current_index) & 1  return 0;

    // Prune: all ways into end_index are already used by this path.
    // current_index is not in `visited` here, so entering end_index's last free
    // neighbour is never pruned.
    if second_last_check != 0 && (visited & second_last_check) == second_last_check  return 0;

    new_visited := visited | cast(u64) 1 << cast(u64) current_index;

    junction := junctions[current_index];

    result := 0;

    for 0..3 {
        if (new_visited >> junction.indices[it]) & 1  continue;

        if junction.distances[it] {
            new_result := find_longest_path2(junction.indices[it], end_index, start_distance + junction.distances[it], junctions, new_visited, second_last_check);
            result = max(result, new_result);
        }
    }

    return result;
}

// Walks the corridor that leaves the junction at (x, y) in direction `dir`.
// The walk stops at the next junction (a tile with more than two open
// neighbours) or at the exit. The corridor is recorded as an edge on both
// junctions. If the far junction is new, every other corridor leaving it is
// explored recursively.
//
// Neighbour slots in Junction.indices / Junction.distances:
//   0 = west, 1 = east, 2 = south, 3 = north.
//
// junction_positions maps grid positions to junction indices. It must already
// contain the start junction (and the exit).
find_junctions :: (lines: []string, x: s64, y: s64, dimension: Coord, dir: Direction, junctions: []Junction, junction_positions: *Table) {
    distance := 1;
    last_x   := x;
    last_y   := y;
    start_x  := x;
    start_y  := y;

    // Take the first step out of the start junction. Afterwards `dir` names the
    // opposite direction, which is what selects the slot on the start junction below.
    if dir == {
        case .West;
            dir = .East;
            x -= 1;
        case .East;
            dir = .West;
            x += 1;
        case .North;
            dir = .South;
            y -= 1;
        case .South;
            dir = .North;
            y += 1;
        case .None;
            assert(false);
    }

    while true {
        // valid_dirs: open neighbours other than the tile we just came from.
        // dirs: total number of open neighbours.
        valid_dirs: u32 = 0;
        dirs           := 0;
        tile_east      := get_tile(lines, x + 1, y,     dimension.x);
        tile_west      := get_tile(lines, x - 1, y,     dimension.x);
        tile_north     := get_tile(lines, x    , y - 1, dimension.x);
        tile_south     := get_tile(lines, x    , y + 1, dimension.x);

        if tile_east != #char "#" {
            if !(last_x == x + 1 && last_y == y) {
                valid_dirs |= cast(u32) Direction.East;
            }
            dirs += 1;
        }
        if tile_west != #char "#" {
            if !(last_x == x - 1 && last_y == y) {
                valid_dirs |= cast(u32) Direction.West;
            }
            dirs += 1;
        }
        if tile_south != #char "#" {
            if !(last_x == x && last_y == y + 1) {
                valid_dirs |= cast(u32) Direction.South;
            }
            dirs += 1;
        }
        if tile_north != #char "#" {
            if !(last_x == x && last_y == y - 1) {
                valid_dirs |= cast(u32) Direction.North;
            }
            dirs += 1;
        }

        at_exit := x == dimension.x - 2 && y == dimension.y - 1;
        assert(dirs > 1 || at_exit);

        if dirs > 2 || at_exit {
            // Slot on this junction for the edge back to the start junction,
            // derived from the tile we stepped in from.
            prev_dir_index := 0;

            if last_x + 1 == x && last_y == y        prev_dir_index = 0;
            else if last_x - 1 == x && last_y == y   prev_dir_index = 1;
            else if last_x == x && last_y - 1 == y   prev_dir_index = 2;
            else if last_x == x && last_y + 1 == y   prev_dir_index = 3;
            else assert(false);

            pos := Coord.{
                x = x,
                y = y,
            };

            // Look up this junction, registering it on first visit.
            index, found_pos := table_find(junction_positions, pos);
            added := !found_pos;

            if added {
                // New indices must fit the junctions array. The caller also uses
                // them as bit positions in a u64 visited mask.
                assert(junction_positions.count < junctions.count);
                index = cast(u8) junction_positions.count;
                table_add(junction_positions, pos, index);
            }

            old_pos := Coord.{
                x = start_x,
                y = start_y,
            };

            old_index, found_old_index := table_find(junction_positions, old_pos);

            assert(found_old_index);

            // Slot on the start junction for this corridor (`dir` was flipped above).
            old_junction_dir_index := 0;

            if dir == {
                case .East;
                    old_junction_dir_index = 0;
                case .West;
                    old_junction_dir_index = 1;
                case .North;
                    old_junction_dir_index = 2;
                case .South;
                    old_junction_dir_index = 3;
                case;
                    assert(false);
            }

            // Corridor lengths are stored as u16. Check before truncating.
            assert(distance < 65536);

            old_junction := *junctions[old_index];
            old_junction.indices[old_junction_dir_index]   = index;
            old_junction.distances[old_junction_dir_index] = cast(u16) distance;

            junctions[index].indices[prev_dir_index]   = old_index;
            junctions[index].distances[prev_dir_index] = cast(u16) distance;

            if added {
                if valid_dirs & cast(u32) Direction.East {
                    find_junctions(lines, x, y, dimension, .East, junctions, junction_positions);
                }
                if valid_dirs & cast(u32) Direction.West {
                    find_junctions(lines, x, y, dimension, .West, junctions, junction_positions);
                }
                if valid_dirs & cast(u32) Direction.North {
                    find_junctions(lines, x, y, dimension, .North, junctions, junction_positions);
                }
                if valid_dirs & cast(u32) Direction.South {
                    find_junctions(lines, x, y, dimension, .South , junctions, junction_positions);
                }
            }
            return;
        } else if dirs == 2 {
            // Plain corridor tile: exactly one way forward.
            last_x = x;
            last_y = y;
            distance += 1;

            if cast(Direction) valid_dirs == {
                case Direction.East;
                    x += 1;
                case Direction.West;
                    x -= 1;
                case Direction.North;
                    y -= 1;
                case Direction.South;
                    y += 1;
                case;
                    assert(false);
            }
        } else {
            // Dead end that is not the exit (only reachable with asserts
            // disabled): stop instead of looping forever.
            return;
        }
    }
}

// Two coordinates are equal when both their x and y components match.
operator == :: (c1: Coord, c2: Coord) -> bool {
    if c1.x != c2.x  return false;
    return c1.y == c2.y;
}

// Key-equality callback for the Coord-keyed hash table used in solve2.
comp_coord :: (c1: Coord, c2: Coord) -> bool {
    same_x := c1.x == c2.x;
    same_y := c1.y == c2.y;
    return same_x && same_y;
}

// Hash callback for the Coord-keyed hash table used in solve2.
// It puts x in the bits above bit 8 and ORs y into the low bits. Keys only stay
// distinct while y < 256, which solve2 asserts for the grid dimensions.
hash_coord :: (c: Coord) -> u32 {
    return (cast(u32) c.x) << 8 | cast(u32) c.y;
}

// Brute-force part 2. Every non-wall tile (including slopes) is treated as plain
// ground '.', and the same exhaustive grid DFS as part 1 is run on it.
// Much slower than solve2; kept for cross-checking.
solve2_slow :: () {
    width        = 0;
    height       = 0;
    longest_path = 0;   // solve1 may have left its own answer here

    input, success := read_entire_file("inputs/day23.txt");

    if !success {
        return;
    }

    lines := split(input, "\n");

    // A final newline leaves an empty last row; it is not part of the grid.
    if lines.count > 0 && lines[lines.count - 1].count == 0  lines.count -= 1;
    if lines.count == 0  return;
    assert(lines.count <= path_array.count);

    for row: 0..lines.count - 1 {
        line := lines[row];
        assert(line.count <= path_array[row].count);

        for col: 0..line.count - 1 {
            cell := *path_array[row][col];
            if line[col] == #char "#" {
                // Walls are pre-marked visited so the DFS never enters them.
                cell.path    = line[col];
                cell.visited = 1;
            } else {
                cell.path    = #char ".";
                cell.visited = 0;
            }
        }
    }

    width  = lines[0].count - 1;
    height = lines.count - 1;
    start  = .{ x = 1, y = 0 };
    end    = .{ x = width - 1, y = height };
    path_array[start.y][start.x].visited = 1;

    find_longest_path(start.x, start.y + 1, .South, 1);

    print("Part 2: %\n", longest_path);
}

// Part 1: longest hike from the entrance (1, 0) to the exit, where slope tiles
// can only be crossed in the direction they point. This is an exhaustive DFS on
// the grid. Afterwards the parsed rows are handed to solve2 for part 2.
solve1 :: () {
    width        = 0;
    height       = 0;
    longest_path = 0;

    input, success := read_entire_file("inputs/day23.txt");

    if !success {
        return;
    }

    lines := split(input, "\n");

    // A final newline leaves an empty last row. It would shift `end` onto a blank
    // row here and break solve2's exit position, so drop it.
    if lines.count > 0 && lines[lines.count - 1].count == 0  lines.count -= 1;
    if lines.count == 0  return;
    assert(lines.count <= path_array.count);

    for row: 0..lines.count - 1 {
        line := lines[row];
        assert(line.count <= path_array[row].count);

        for col: 0..line.count - 1 {
            cell := *path_array[row][col];
            cell.path = line[col];
            // Walls are pre-marked visited so the DFS never enters them.
            if line[col] == #char "#"  cell.visited = 1;
            else                       cell.visited = 0;
        }
    }

    width  = lines[0].count - 1;
    height = lines.count - 1;
    start  = .{ x = 1, y = 0 };
    end    = .{ x = width - 1, y = height };
    path_array[start.y][start.x].visited = 1;

    find_longest_path(start.x, start.y + 1, .South, 1);

    print("Part 1: %\n", longest_path);
    solve2(lines);
}

// Returns the grid character at (x, y), or 0 when the position lies exactly one
// step past the right edge (x == x_max) or the bottom edge (y == lines.count).
//
// x_max: grid width (length of each row).
// The row bound comes from lines.count rather than x_max, so non-square grids work.
get_tile :: (lines: []string, x: s64, y: s64, x_max: s64) -> u8 {
    y_max := lines.count;

    assert(x >= 0);
    assert(x <= x_max);
    assert(y >= 0);
    assert(y <= y_max);

    if y == y_max  return 0;
    if x == x_max  return 0;
    return lines[y][x];
}

// Part 1 DFS over the grid.
// (x, y) is the tile being stepped onto, `direction` is the move that reached
// it, and `visited` is the number of steps taken so far.
// The tile is marked as used for the duration of the call and released on the
// way out (backtracking). Reaching `end` records the best step count in
// longest_path.
find_longest_path :: (x: s64, y: s64, direction: Direction, visited: s64) {
    tile := *path_array[y][x];
    tile.visited = 1;
    defer tile.visited = 0;

    if x == end.x && y == end.y {
        longest_path = max(longest_path, visited);
        return;
    }

    steps := visited + 1;

    if is_valid(x, y - 1, .North)  find_longest_path(x, y - 1, .North, steps);
    if is_valid(x, y + 1, .South)  find_longest_path(x, y + 1, .South, steps);
    if is_valid(x + 1, y, .East)   find_longest_path(x + 1, y, .East,  steps);
    if is_valid(x - 1, y, .West)   find_longest_path(x - 1, y, .West,  steps);
}

// Reports whether the part 1 DFS may step onto (x, y) while moving in
// `direction`. The tile must not be marked visited (walls are pre-marked
// visited), and a slope tile may only be entered when moving the way it points.
is_valid :: (x: s64, y: s64, direction: Direction) -> bool {
    cell := path_array[y][x];
    if cell.visited == 1  return false;

    c := cell.path;
    if c == #char ">"  return direction == .East;
    if c == #char "<"  return direction == .West;
    if c == #char "^"  return direction == .North;
    if c == #char "v"  return direction == .South;

    // Plain ground (and any other character) can always be entered.
    return true;
}

#scope_file

// One grid cell for the part 1 / solve2_slow DFS.
Path :: struct {
    visited: s64;   // 1 while on the current DFS path, and permanently for walls ('#')
    path: u8;       // tile character (solve2_slow stores '.' for every non-wall tile)
}

// Grid position: x is the column, y is the row (lines[y][x]).
Coord :: struct {
    x: s64;
    y: s64;
}

// Distinct single-bit values, so several directions can be OR-ed into a mask
// (see valid_dirs in find_junctions).
Direction :: enum u32 {
    None  :: 0;
    North :: 8;
    South :: 16;
    East  :: 2;
    West  :: 1;
}

// Node of the compressed part 2 graph. As filled in by find_junctions, slot
// 0 = west, 1 = east, 2 = south, 3 = north neighbour.
Junction :: struct {
    distances: [4]u16;   // corridor length to the neighbour; 0 = no neighbour in this slot
    indices: [4]u8;      // junction index of the neighbour
}

// Indexed [row][column]; the puzzle input must fit in 200 x 200.
path_array : [200][200]Path;

// Entrance and exit tiles for the grid DFS (set by solve1 / solve2_slow).
start: Coord;
end:   Coord;

width        := 0;
height       := 0;
longest_path := 0;   // best step count found by find_longest_path

//#import "Compiler";
//build :: (file: string) {
//    #if #run get_current_workspace() > 2 return;
//    print("%\n", get_current_workspace());
//
//    set_build_options_dc(.{do_output = false});
//    build_options := get_build_options();
//    build_options.output_type = .NO_OUTPUT;
//    set_optimization(*build_options, .VERY_OPTIMIZED);
//    build_options.output_type = .EXECUTABLE;
//    build_options.output_executable_name = "day23";
//    build_options.output_path = ".";
//
//    workspace := compiler_create_workspace();
//    #if OS == .WINDOWS {
//        build_options.additional_linker_arguments = .[tprint("/STACK:%", 1024 * 1024 * 8)];
//    }
//    set_build_options(build_options, workspace);
//
//
//    add_build_file(file, workspace);
//
//}
//#run build(#file);