#import "Basic";
#import "String";
#import "File";
#import "Hash_Table";
#import "Math";

// Entry point for Day 20: parses the module list from the input file, registers
// every conjunction ('&') module's inputs in its memory, then hands off to `solve`.
// test: when true, read the example input instead of the real puzzle input.
solve_day20 :: (test: bool) {
    path := ifx test then "inputs/day20_test.txt" else "inputs/day20.txt";

    contents, ok := read_entire_file(path);
    if !ok {
        // Without this check a missing file would silently report zero answers.
        print("Could not read %\n", path);
        return;
    }

    lines := split(contents, "\n");
    modules: Table(string, Module);
    broadcast_targets: []string;

    // Lines are "broadcaster -> a, b" or "<type><name> -> a, b" where <type> is one byte.
    for lines {
        line := it;

        // Tolerate CRLF line endings so the last output name carries no trailing '\r'.
        if line.count > 0 && line[line.count - 1] == #char "\r" {
            line.count -= 1;
        }
        if line.count == 0  continue;

        parts := split(line, " -> ");
        // Skip malformed lines: no separator, or a left side too short for type + name.
        if parts.count < 2 || parts[0].count < 2  continue;

        outputs := split(parts[1], ", ");

        if parts[0] == "broadcaster" {
            broadcast_targets = outputs;
        } else {
            // First byte is the module type ('%' or '&'); the rest is the name.
            name := slice(parts[0], 1, parts[0].count - 1);
            table_add(*modules, name, .{
                type    = parts[0][0],
                name    = name,
                outputs = outputs,
            });
        }
    }

    // A conjunction must know all of its inputs up front; each starts remembering LOW.
    for module, name: *modules {
        for output: module.outputs {
            output_module := table_find_pointer(*modules, output);

            if output_module && output_module.type == #char "&" {
                table_set(*output_module.memory_conj, name, .LOW);
            }
        }
    }

    solve(broadcast_targets, modules);
}

// Simulates button presses and prints both answers.
// Part 1: after PART_1_PRESSES presses, prints (LOW pulses) * (HIGH pulses), counting
//         the button's own LOW pulse each press.
// Part 2: prints the LCM of the press numbers on which each input of the module
//         feeding "rx" first sent it a HIGH pulse. NOTE(review): this assumes the
//         puzzle-input structure (a single conjunction feeding "rx", whose inputs fire
//         periodically starting at their cycle length). Prints 0 if no such module
//         exists or the cycles are not all seen within MAX_PRESSES.
solve :: (broadcast_targets: []string, modules: Table(string, Module)) {
    PART_1_PRESSES :: 1000;
    MAX_PRESSES    :: 10000;

    high := 0;
    low  := 0;
    deque: Deque(Queue_Entry);

    // Find the module whose outputs include "rx"; stays zero-initialised if none does.
    feed: Module;
    for module: modules {
        for module.outputs {
            if it == "rx" {
                feed = module;
                break module;
            }
        }
    }

    // seen[input]: number of HIGH pulses this input of `feed` has sent to it.
    // cycle_lengths[input]: press number on which the first such pulse arrived.
    seen: Table(string, s64);
    cycle_lengths: Table(string, s64);
    for v, input_name: feed.memory_conj {
        table_add(*seen, input_name, 0);
    }

    // With no inputs to track, Part 2 can never be found; don't spin to MAX_PRESSES.
    part_2_possible := seen.count > 0;
    part_2 := 0;

    for presses: 1..MAX_PRESSES {
        low += 1; // The button's own LOW pulse to the broadcaster.

        for broadcast_targets {
            deque_add_last(*deque, .{
                origin = "broadcaster",
                target = it,
                pulse  = .LOW,
            });
        }

        while deque.count > 0 {
            entry := deque_remove_first(*deque);

            if entry.pulse == .HIGH {
                high += 1;
            } else {
                low += 1;
            }

            // Pulses to names that are not parsed modules are counted but go nowhere.
            target_module := table_find_pointer(*modules, entry.target);
            if !target_module  continue;

            if part_2_possible && part_2 == 0 && entry.target == feed.name && entry.pulse == .HIGH {
                seen_ptr := table_find_pointer(*seen, entry.origin);
                if seen_ptr {
                    << seen_ptr += 1;
                }

                if !table_find_pointer(*cycle_lengths, entry.origin) {
                    table_add(*cycle_lengths, entry.origin, presses);
                }

                all_seen := true;
                for seen_count: seen {
                    all_seen = all_seen && seen_count > 0;
                }

                if all_seen {
                    part_2 = 1;
                    for cycle: cycle_lengths {
                        part_2 = lcm(part_2, cycle);
                    }
                }
            }

            if target_module.type == #char "%" {
                // Flip-flops ignore HIGH; a LOW toggles state and emits the new state.
                if entry.pulse == .LOW {
                    target_module.memory = ifx target_module.memory == .ON then .OFF else .ON;
                    outgoing := ifx target_module.memory == Memory_States.ON then Pulse_States.HIGH else Pulse_States.LOW;

                    for target_module.outputs {
                        deque_add_last(*deque, .{
                            origin = target_module.name,
                            target = it,
                            pulse  = outgoing,
                        });
                    }
                }
            } else {
                // Conjunctions remember the last pulse per input; they emit LOW only when all are HIGH.
                table_set(*target_module.memory_conj, entry.origin, entry.pulse);

                all_high := true;
                for v, k: target_module.memory_conj {
                    all_high = all_high && v == .HIGH;
                }

                for target_module.outputs {
                    deque_add_last(*deque, .{
                        origin = target_module.name,
                        target = it,
                        pulse  = ifx all_high then Pulse_States.LOW else Pulse_States.HIGH,
                    });
                }
            }
        }

        // Report Part 1 at the end of the press, so an early Part 2 result cannot skip it.
        if presses == PART_1_PRESSES {
            print("Part 1: %\n", low * high);
        }

        // Stop once both answers are known (or Part 2 is impossible).
        if presses >= PART_1_PRESSES && (part_2 != 0 || !part_2_possible)  break;
    }

    print("Part 2: %\n", part_2);
}

// Restores every module to its starting state: flip-flops ('%') are switched OFF,
// and every remembered input of any other (conjunction) module is set back to LOW.
// Iterating the table by value yields copies, so each module is fetched by pointer
// before it is mutated; otherwise the assignments would be lost.
reset_modules :: (modules: Table(string, Module)) {
    for module, name: modules {
        target := table_find_pointer(*modules, name);
        if !target  continue;

        if target.type == #char "%" {
            target.memory = .OFF;
        } else {
            // Overwrites existing keys only, so the table is not resized mid-iteration.
            for pulse, input_name: target.memory_conj {
                table_set(*target.memory_conj, input_name, .LOW);
            }
        }
    }
}

// On/off state of a flip-flop ('%') module.
Memory_States :: enum u32 {
    OFF;  // implicit value 0
    ON;   // implicit value 1
}

// Level of a pulse sent from one module to another.
Pulse_States :: enum u32 {
    LOW;   // implicit value 0
    HIGH;  // implicit value 1
}

// A pulse waiting in the propagation queue.
Queue_Entry :: struct {
    origin, target: string;   // names of the sending and receiving modules
    pulse:          Pulse_States;
}

// A parsed module from the input (everything except the broadcaster).
Module :: struct {
    name:        string;
    type:        u8;                          // leading byte of the definition, e.g. '%' or '&'
    memory       := Memory_States.OFF;        // flip-flop on/off state
    outputs:     []string;                    // names of the modules this one sends to
    memory_conj: Table(string, Pulse_States); // conjunction memory: input name -> last pulse received
}