/*
 * Decompiled with CFR 0.152.
 */
package com.expedient.adventofcodejade.solutions.year2024;

import com.expedient.adventofcodejade.BaseSolution;
import com.expedient.adventofcodejade.Triplet;
import com.expedient.adventofcodejade.common.Coordinate;
import com.expedient.adventofcodejade.common.Direction;
import com.expedient.adventofcodejade.common.Grid;
import com.expedient.adventofcodejade.common.Pair;
import com.expedient.adventofcodejade.common.PuzzleInput;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;

public class SolutionDay16
extends BaseSolution {
    public SolutionDay16(PuzzleInput input, PuzzleInput sampleInputOne, PuzzleInput sampleInputTwo) {
        super(input, sampleInputOne, sampleInputTwo);
    }

    public static boolean isTraversable(Character c) {
        return c.charValue() == '.' || c.charValue() == 'S' || c.charValue() == 'E';
    }

    public static List<Pair<Coordinate, Integer>> getNeighbors(Grid<Character> grid, Coordinate current, Direction currentDirection, Set<Pair<Coordinate, Direction>> visited) {
        List<Coordinate> directNeighbors = grid.matchNeighbors(current, SolutionDay16::isTraversable, true);
        directNeighbors = directNeighbors.stream().filter(c -> !visited.contains(new Pair<Coordinate, Direction>((Coordinate)c, current.directionToCoordinate((Coordinate)c)))).toList();
        ArrayList<Pair<Coordinate, Integer>> weightedNeighbors = new ArrayList<Pair<Coordinate, Integer>>();
        for (Coordinate neighbor : directNeighbors) {
            int cost = 2;
            if ((currentDirection == Direction.RIGHT || currentDirection == Direction.LEFT) && neighbor.row() != current.row() || (currentDirection == Direction.UP || currentDirection == Direction.DOWN) && neighbor.col() != current.col()) {
                cost += 1000;
            }
            weightedNeighbors.add(new Pair<Coordinate, Integer>(neighbor, cost));
        }
        return weightedNeighbors;
    }

    public Map<Pair<Coordinate, Direction>, List<Pair<Coordinate, Direction>>> findAllMazeSolutions(Grid<Character> grid, Coordinate startPoint) {
        HashSet<Pair<Coordinate, Direction>> visited = new HashSet<Pair<Coordinate, Direction>>();
        PriorityQueue<Triplet<Coordinate, Integer, Direction>> queue = new PriorityQueue<Triplet<Coordinate, Integer, Direction>>(new NeighborComparator());
        HashMap<Pair<Coordinate, Direction>, Integer> distance = new HashMap<Pair<Coordinate, Direction>, Integer>();
        List<Coordinate> mazeCoords = grid.matchCoordinates(SolutionDay16::isTraversable);
        for (Coordinate coordinate : mazeCoords) {
            if (coordinate.equals(startPoint)) {
                distance.put(new Pair<Coordinate, Direction>(startPoint, Direction.RIGHT), 0);
                continue;
            }
            for (Direction direction : Direction.all()) {
                distance.put(new Pair<Coordinate, Direction>(coordinate, direction), 0x3FFFFFFF);
            }
        }
        HashMap<Pair<Coordinate, Direction>, List<Pair<Coordinate, Direction>>> previous = new HashMap<Pair<Coordinate, Direction>, List<Pair<Coordinate, Direction>>>();
        for (Coordinate c : mazeCoords) {
            for (Direction d : Direction.all()) {
                previous.put(new Pair<Coordinate, Direction>(c, d), new ArrayList());
            }
        }
        queue.add(new Triplet<Coordinate, Integer, Direction>(startPoint, 0, Direction.RIGHT));
        while (!queue.isEmpty()) {
            Triplet triplet = (Triplet)queue.poll();
            Coordinate coordinate = (Coordinate)triplet.one();
            Direction currentDirection = (Direction)((Object)triplet.three());
            int queuedDistance = (Integer)triplet.two();
            if (queuedDistance > (Integer)distance.get(new Pair<Coordinate, Direction>(coordinate, currentDirection))) continue;
            List<Pair<Coordinate, Integer>> neighbors = SolutionDay16.getNeighbors(grid, coordinate, currentDirection, visited);
            for (Pair<Coordinate, Integer> neighbor : neighbors) {
                if (neighbor.one().equals(startPoint)) continue;
                Direction newDirection = coordinate.directionToCoordinate(neighbor.one());
                int newDistance = (Integer)distance.get(new Pair<Coordinate, Direction>(coordinate, currentDirection)) + neighbor.two();
                if (newDistance > (Integer)distance.get(new Pair<Coordinate, Direction>(neighbor.one(), newDirection))) continue;
                if (newDistance < (Integer)distance.get(new Pair<Coordinate, Direction>(neighbor.one(), newDirection))) {
                    distance.put(new Pair<Coordinate, Direction>(neighbor.one(), newDirection), newDistance);
                    previous.put(new Pair<Coordinate, Direction>(neighbor.one(), newDirection), new ArrayList());
                }
                List list = (List)previous.get(new Pair<Coordinate, Direction>(neighbor.one(), newDirection));
                list.add(new Pair<Coordinate, Direction>(coordinate, currentDirection));
                queue.add(new Triplet<Coordinate, Integer, Direction>(neighbor.one(), newDistance, newDirection));
            }
            visited.add(new Pair<Coordinate, Direction>(coordinate, currentDirection));
        }
        return previous;
    }

    @Override
    public Integer partOne(PuzzleInput input) {
        SolutionDay16Input in = SolutionDay16Input.fromPuzzleInput(input);
        Map<Pair<Coordinate, Direction>, List<Pair<Coordinate, Direction>>> previous = this.findAllMazeSolutions(in.grid(), in.startPoint());
        Coordinate currentPosition = in.endPoint();
        ArrayList<Pair<Coordinate, Direction>> pathTaken = new ArrayList<Pair<Coordinate, Direction>>();
        for (Direction d : Direction.all()) {
            try {
                Direction currentDirection = d;
                while (!currentPosition.equals(in.startPoint())) {
                    List<Pair<Coordinate, Direction>> current = previous.get(new Pair<Coordinate, Direction>(currentPosition, currentDirection));
                    pathTaken.add(new Pair<Coordinate, Direction>(currentPosition, currentDirection));
                    currentPosition = current.getFirst().one();
                    currentDirection = current.getFirst().two();
                }
                break;
            }
            catch (NullPointerException nullPointerException) {
            }
        }
        Direction cd = Direction.RIGHT;
        int total = 0;
        for (Pair pair : pathTaken) {
            if (pair.two() != cd) {
                total += 1000;
                cd = (Direction)((Object)pair.two());
            }
            ++total;
        }
        return total;
    }

    @Override
    public Integer partTwo(PuzzleInput input) {
        SolutionDay16Input in = SolutionDay16Input.fromPuzzleInput(input);
        Map<Pair<Coordinate, Direction>, List<Pair<Coordinate, Direction>>> previous = this.findAllMazeSolutions(in.grid(), in.startPoint());
        HashSet<Coordinate> seats = new HashSet<Coordinate>();
        Coordinate currentPosition = in.endPoint();
        for (Direction d : Direction.all()) {
            LinkedList<Pair<Coordinate, Direction>> toCheck = new LinkedList<Pair<Coordinate, Direction>>();
            HashSet<Pair> alreadyChecked = new HashSet<Pair>();
            toCheck.add(new Pair<Coordinate, Direction>(currentPosition, d));
            while (!toCheck.isEmpty()) {
                Pair c = (Pair)toCheck.poll();
                if (alreadyChecked.contains(c)) continue;
                alreadyChecked.add(c);
                currentPosition = (Coordinate)c.one();
                Direction currentDirection = (Direction)((Object)c.two());
                seats.add(currentPosition);
                if (currentPosition == in.startPoint()) continue;
                List<Pair<Coordinate, Direction>> stepList = previous.get(new Pair<Coordinate, Direction>(currentPosition, currentDirection));
                for (Pair<Coordinate, Direction> step : stepList) {
                    toCheck.add(new Pair<Coordinate, Direction>(step.one(), step.two()));
                }
            }
        }
        return seats.size();
    }

    public static class NeighborComparator
    implements Comparator<Triplet<Coordinate, Integer, Direction>> {
        @Override
        public int compare(Triplet<Coordinate, Integer, Direction> o1, Triplet<Coordinate, Integer, Direction> o2) {
            if (o1.two() < o2.two()) {
                return -1;
            }
            if (o1.two() > o2.two()) {
                return 1;
            }
            return 0;
        }
    }

    public record SolutionDay16Input(Grid<Character> grid, Coordinate startPoint, Coordinate endPoint) {
        public static SolutionDay16Input fromPuzzleInput(PuzzleInput input) {
            Grid<Character> grid = Grid.fromStringList(input.getLines());
            Coordinate endPoint = grid.matchCoordinates(c -> c.charValue() == 'E').getFirst();
            Coordinate startPoint = grid.matchCoordinates(c -> c.charValue() == 'S').getFirst();
            return new SolutionDay16Input(grid, startPoint, endPoint);
        }
    }
}

