{-# LANGUAGE TypeFamilies, MultiWayIf #-}
import qualified Data.Vector as V
import qualified Data.IntMap as IntMap
import Control.Monad
import Control.Monad.State
import Data.List
import System.IO
import Data.Tuple.Homogenous
import Data.Foldable
import Text.Printf (printf)

-- | Number of states. Every 'Matrix' holds @dim * dim@ entries and the row
-- and column accessors produce vectors of @dim@ entries.
dim :: Int
dim = 4

-- | Vector of Laurent polynomials; the storage type wrapped by 'Matrix',
-- 'RowVector' and 'ColumnVector'.
type Vec = V.Vector LaurentPolynomial

-- | A polynomial in one variable with 'Integer' coefficients, stored as a map
-- from exponent to coefficient. Exponents may be negative (see 'x_inv').
-- Arithmetic may leave explicit zero coefficients in the map; 'normalize'
-- removes them.
newtype LaurentPolynomial = LaurentPolynomial (IntMap.IntMap Integer)
  deriving (Show)

-- | Ring operations on Laurent polynomials.
--
-- Results are not normalized: a coefficient that cancels to zero stays in the
-- map until 'normalize' is applied. 'abs' and 'signum' have no definition for
-- this type and raise an error when called.
instance Num LaurentPolynomial where
  -- Add coefficients exponent by exponent.
  LaurentPolynomial a + LaurentPolynomial b =
    LaurentPolynomial $ IntMap.unionWith (+) a b
  -- Multiply every pair of terms: exponents add, coefficients multiply, and
  -- products that land on the same exponent are accumulated. Iterating over
  -- (exponent, coefficient) pairs avoids the partial 'IntMap.!' lookups.
  LaurentPolynomial a * LaurentPolynomial b =
    LaurentPolynomial $ IntMap.fromListWith (+)
      [ (i + j, ci * cj)
      | (i, ci) <- IntMap.toList a
      , (j, cj) <- IntMap.toList b
      ]
  negate (LaurentPolynomial a) = LaurentPolynomial $ IntMap.map negate a
  -- Zero is the empty map; any other integer is a constant term.
  fromInteger 0 = LaurentPolynomial IntMap.empty
  fromInteger n = LaurentPolynomial $ IntMap.singleton 0 n
  abs _ = error "LaurentPolynomial: abs is not defined"
  signum _ = error "LaurentPolynomial: signum is not defined"

-- | Drop every term whose coefficient is zero.
normalize :: LaurentPolynomial -> LaurentPolynomial
normalize (LaurentPolynomial coeffs) =
  LaurentPolynomial (IntMap.filter nonZero coeffs)
  where
    nonZero c = c /= 0

-- | Render a polynomial as its non-zero terms in ascending exponent order,
-- joined by @" + "@. A polynomial with no non-zero terms renders as @"0"@.
prettyPolynomial :: LaurentPolynomial -> String
prettyPolynomial p
  | IntMap.null terms = "0"
  | otherwise = intercalate " + " (map renderTerm (IntMap.toAscList terms))
  where
    -- Zero coefficients are removed up front, so every term below is non-zero.
    LaurentPolynomial terms = normalize p
    -- A constant term shows only its coefficient; coefficients of 1 and -1
    -- are abbreviated to a bare sign.
    renderTerm (e, c)
      | e == 0 = show c
      | c == 1 = "x^" ++ show e
      | c == -1 = "-x^" ++ show e
      | otherwise = show c ++ " x^" ++ show e

-- | A dim x dim matrix stored in row-major order: entry (i, j) lives at
-- index @i * dim + j@ of the underlying vector.
newtype Matrix = Matrix Vec
-- | A row vector (one row of a 'Matrix', or a value of the same shape).
newtype RowVector = RowVector Vec
-- | A column vector (one column of a 'Matrix', or a value of the same shape).
newtype ColumnVector = ColumnVector Vec

-- | Extract row @i@ of a matrix as a contiguous slice of its storage.
getRow :: Matrix -> Int -> RowVector
getRow (Matrix entries) i = RowVector (V.slice start dim entries)
  where
    -- Row-major layout: row i begins at offset i * dim.
    start = i * dim

-- | Extract column @i@ of a matrix by picking entry @i@ out of every row.
getColumn :: Matrix -> Int -> ColumnVector
getColumn (Matrix entries) i = ColumnVector (V.generate dim entryInRow)
  where
    entryInRow r = entries V.! (r * dim + i)

-- | Render a matrix one row per line. The first line is prefixed with @"[ "@,
-- every later line with @", "@, and a closing @"]"@ follows the last line.
-- Entries within a row are separated by @", "@.
prettyMatrix :: Matrix -> String
prettyMatrix (Matrix entries) = concatMap renderRow indices ++ "]"
  where
    indices = [0 .. dim - 1]
    renderRow r =
      opener r ++ intercalate ", " (map (cell r) indices) ++ "\n"
    opener r
      | r == 0 = "[ "
      | otherwise = ", "
    cell r c = prettyPolynomial (entries V.! (r * dim + c))

-- | Products between vector-like operands. The associated type 'Result'
-- determines what multiplying an @a@ by a @b@ yields.
class DotProduct a b where
  -- | Type produced by multiplying an @a@ (left) by a @b@ (right).
  type Result a b
  -- Left-associative, at the same precedence level as (*).
  infixl 7 <.>
  -- | Multiply the left operand by the right operand.
  (<.>) :: a -> b -> Result a b
-- | Row times column: the sum of the element-wise products.
instance DotProduct RowVector ColumnVector where
  type Result RowVector ColumnVector = LaurentPolynomial
  RowVector row <.> ColumnVector col = V.sum products
    where
      products = V.zipWith (*) row col
-- | Row times matrix: entry @c@ of the result is the row dotted with column
-- @c@ of the matrix.
instance DotProduct RowVector Matrix where
  type Result RowVector Matrix = RowVector
  row <.> m = RowVector (V.generate dim againstColumn)
    where
      againstColumn c = row <.> getColumn m c
-- | Matrix times matrix: entry (r, c) of the result is row @r@ of the left
-- operand dotted with column @c@ of the right operand.
instance DotProduct Matrix Matrix where
  type Result Matrix Matrix = Matrix
  lhs <.> rhs = Matrix (V.fromList cells)
    where
      -- Enumerated row by row, matching the row-major layout of 'Matrix'.
      cells =
        [ getRow lhs r <.> getColumn rhs c
        | r <- [0 .. dim - 1]
        , c <- [0 .. dim - 1]
        ]

-- | Sum a polynomial's coefficients grouped by the sign of their exponent.
-- The components are, in order: positive exponents, negative exponents, and
-- exponent zero.
extractOutcomes :: LaurentPolynomial -> Tuple3 Integer
extractOutcomes (LaurentPolynomial coeffs) =
  fmap totalWhere (Tuple3 ((> 0), (< 0), (== 0)))
  where
    -- Total of the coefficients whose exponent satisfies the predicate.
    totalWhere keep =
      sum (IntMap.elems (IntMap.filterWithKey (\e _ -> keep e) coeffs))

-- | The monomial with exponent 1 and coefficient 1.
x :: LaurentPolynomial
x = LaurentPolynomial (IntMap.fromList [(1, 1)])

-- | The monomial with exponent -1 and coefficient 1.
x_inv :: LaurentPolynomial
x_inv = LaurentPolynomial (IntMap.fromList [(-1, 1)])

-- | The 4 x 4 matrix that 'loop' multiplies by at every step. Rows 0 and 3
-- hold the constant 1 in columns 0 and 1; rows 1 and 2 hold 'x' in column 2
-- and 'x_inv' in column 3. All other entries are 0.
transition_matrix :: Matrix
transition_matrix =
  Matrix (V.fromList (concat [plain, shifted, shifted, plain]))
  where
    plain, shifted :: [LaurentPolynomial]
    plain = [1, 1, 0, 0]
    shifted = [0, 0, x, x_inv]

-- | Row vector with a 1 in position 0 and 0 elsewhere; multiplying it by a
-- matrix yields that matrix's first row.
row_vec :: RowVector
row_vec = RowVector (V.fromList [1, 0, 0, 0])

-- | All-ones column vector; dotting a row vector with it sums the row's
-- entries.
col_vec :: ColumnVector
col_vec = ColumnVector (V.fromList [1, 1, 1, 1])

-- | Labels for the three outcomes, in the same component order as the result
-- of 'extractOutcomes'.
names :: Tuple3 String
names = Tuple3 ("Alice", "Bob", "Tie")

-- | Entry point: open @alice.csv@ for writing, emit the CSV header, then run
-- 'loop' starting from 'transition_matrix' at step 1. The final loop state is
-- not needed, so it is discarded.
main :: IO ()
main =
  withFile "alice.csv" WriteMode $ \csv -> do
    hPutStrLn csv "n,name,prob"
    evalStateT (loop csv) (transition_matrix, 1)
-- | Process steps while the step counter @n@ is at most 100.
--
-- The state is @(a, n)@: the current matrix and the step number. For each
-- step the scalar @row_vec <.> a <.> col_vec@ is split with
-- 'extractOutcomes', each outcome count is divided by @2^n@, and one CSV line
-- @n,name,prob@ per outcome is written to the given handle. At step 100 the
-- polynomial and a per-outcome summary are also printed to stdout. The state
-- then advances to @(a <.> transition_matrix, n + 1)@.
--
-- The explicit signature fixes the state type, so no pattern type annotation
-- (and hence no @ScopedTypeVariables@) is needed on the 'get' binding.
loop :: Handle -> StateT (Matrix, Int) IO ()
loop csv = do
  (a, n) <- get
  when (n <= 100) $ do
    let p = row_vec <.> a <.> col_vec
        outcomes = extractOutcomes p
        outcome_probs :: Tuple3 Double
        outcome_probs =
          (\w -> fromIntegral w / 2 ** fromIntegral n) <$> outcomes

    -- One CSV line per (name, probability) pair.
    lift $ sequence_ $
      (\name prob ->
        hPutStrLn csv $ intercalate "," [show n, name, show prob])
        <$> names <*> outcome_probs

    -- Human-readable summary on the final step only.
    when (n == 100) $ lift $ do
      putStrLn $ prettyPolynomial p
      sequence_ $
        (\name win win_approx ->
          printf "%s with probability %d/2^%d ≈ %.4f\n"
            name win n win_approx)
          <$> names <*> outcomes <*> outcome_probs

    put (a <.> transition_matrix, n + 1)
    loop csv