aboutsummaryrefslogtreecommitdiffstats
path: root/ecc.hs
blob: 19fe1a04edb16d70b251783684b1159a252d57e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}

import           Data.Bits
import           Data.Proxy
import           Data.Ratio   (denominator, numerator)
import           GHC.TypeLits
import           Text.Printf  (PrintfArg, printf)

-- Finite fields
-- Approach adapted from:
-- https://stackoverflow.com/questions/39823408/prime-finite-field-z-pz-in-haskell-with-operator-overloading

-- | An element of the integers modulo @n@, with the modulus carried in the
-- type.
--
-- The raw 'FieldElement' constructor stores its 'Integer' as given, without
-- reducing it; values produced through the 'Num' operations are reduced.
-- Because 'Eq' is derived, two elements are equal only when their stored
-- integers are identical.
newtype FieldElement (n :: Nat)
  = FieldElement Integer
  deriving (Eq)

instance KnownNat n => Num (FieldElement n) where
  -- Every arithmetic result is routed through 'fromInteger', which reduces
  -- it modulo n (into [0, n) for a positive modulus).
  fromInteger v = FieldElement (v `mod` modulus)
    where
      modulus = natVal (Proxy :: Proxy n)
  FieldElement p + FieldElement q = fromInteger (p + q)
  FieldElement p * FieldElement q = fromInteger (p * q)
  negate (FieldElement p) = fromInteger (negate p)
  -- A field has no ordering, so 'abs' is the identity and 'signum' is the
  -- constant 1; together these still satisfy abs x * signum x == x.
  abs = id
  signum = const 1

instance KnownNat n => Fractional (FieldElement n) where
  -- Multiplicative inverse via Fermat's little theorem:
  -- a^(n-2) == a^(-1) (mod n), which is only valid when n is prime.
  -- Zero has no inverse; without the explicit check, a^(n-2) would quietly
  -- return 0 (or 1 when n == 2), so fail loudly instead.
  recip e@(FieldElement v)
    | v `mod` modulus == 0 = error "FieldElement.recip: division by zero"
    | otherwise            = e ^ (modulus - 2)
    where
      modulus = natVal (Proxy :: Proxy n)
  -- A rational literal p/q is interpreted as p * q^(-1) in the field; this
  -- fails (through 'recip') when q is divisible by n.
  fromRational q = fromInteger (numerator q) / fromInteger (denominator q)

instance KnownNat n => Show (FieldElement n) where
  -- Elements of the secp256k1 base field are shown as zero-padded 64-digit
  -- hex; every other field uses a tagged decimal form.
  show (FieldElement v)
    | isSecp256k1Prime = printf "0x%064x" v
    | otherwise        = concat ["FieldElement_", show modulus, " ", show v]
    where
      modulus = natVal (Proxy :: Proxy n)
      isSecp256k1Prime = modulus == 2 ^ 256 - 2 ^ 32 - 977


-- | Identity on 'True'; aborts with the message "WRONG" on 'False'. This lets
-- a boolean check sit inside an expression and fail loudly.
assert :: Bool -> Bool
assert ok
  | ok        = True
  | otherwise = error "WRONG"

-- | Spot checks of addition, inequality and subtraction in GF(31):
-- 2 + 15 = 17 and 2 - 15 = -13 = 18 (mod 31).
aa :: (Bool, Bool, Bool)
aa =
  ( two + fifteen == FieldElement 17
  , two /= fifteen
  , two - fifteen == FieldElement 18
  )
  where
    two = FieldElement 2 :: FieldElement 31
    fifteen = FieldElement 15

-- | The product 19 * 24 in GF(31): 456 = 14 * 31 + 22, so this is 22.
bb :: FieldElement 31
bb = nineteen * twentyFour
  where
    nineteen = FieldElement 19 :: FieldElement 31
    twentyFour = FieldElement 24

-- Elliptic curve

-- | A point on a short Weierstrass curve y^2 = x^3 + a*x + b, or the point
-- at infinity.
--
-- NOTE: the record selectors are partial; applying 'x', 'y', 'a' or 'b' to
-- 'Infinity' throws at runtime.
data ECPoint a
  = Infinity
    -- ^ The point at infinity, the identity of the curve group.
  | ECPoint { x, y, a, b :: a }
    -- ^ An affine point @(x, y)@ together with the coefficients @a@ and @b@
    -- of the curve it is meant to lie on (not checked; see 'validECPoint').
  deriving (Eq)


-- | Generic rendering for points whose coordinates are printf-able numbers.
--
-- NOTE(review): the @%f@ directives accept only floating-point arguments;
-- for integral coordinate types (e.g. 'Integer') 'printf' raises a
-- bad-format error at runtime.
instance {-# OVERLAPPABLE #-} (PrintfArg a, Num a) => Show (ECPoint a) where
  show point = case point of
    Infinity            -> "ECPoint(Infinity)"
    ECPoint px py pa pb -> printf "ECPoint(%f, %f)_%f_%f" px py pa pb

-- | Points over a prime field. Points over the secp256k1 base field are shown
-- as @S256Point(x, y)@ with hex coordinates; all others also show the field
-- modulus and the curve coefficients in decimal.
instance {-# OVERLAPPING  #-} KnownNat n => Show (ECPoint (FieldElement n)) where
  show Infinity = "ECPoint(Infinity)"
  show p
    | onSecp256k1 = "S256Point" ++ coords
    | otherwise   = concat ["ECPoint_", show prime, coords, curveParams]
    where
      prime = natVal (Proxy :: Proxy n)
      onSecp256k1 = prime == 2 ^ 256 - 2 ^ 32 - 977
      coords = concat ["(", render (x p), ", ", render (y p), ")"]
      curveParams = concat ["a_", render (a p), "|b_", render (b p)]
      -- Renders the stored integer directly (no reduction modulo n).
      render (FieldElement v)
        | onSecp256k1 = printf "0x%064x" v
        | otherwise   = show v

-- | Whether a point satisfies its own curve equation y^2 = x^3 + a*x + b.
-- The point at infinity is always considered valid.
validECPoint :: (Eq a, Num a) => ECPoint a -> Bool
validECPoint point = case point of
  Infinity -> True
  ECPoint px py pa pb ->
    let lhs = py ^ 2
        rhs = px ^ 3 + pa * px + pb
    in  lhs == rhs

-- | The elliptic-curve group law for two points on the same curve.
--
-- Cases, checked in order:
--
-- * either operand is the identity;
-- * the points have different curve coefficients (error);
-- * P and -P, i.e. a vertical line (point at infinity);
-- * distinct x-coordinates (chord);
-- * a vertical tangent at y = 0 (point at infinity);
-- * doubling when P == Q (tangent).
add :: (Eq a, Fractional a) => ECPoint a -> ECPoint a -> ECPoint a
add Infinity q = q
add p Infinity = p
add p@(ECPoint x1 y1 a1 b1) q@(ECPoint x2 y2 a2 b2)
  | a1 /= a2 || b1 /= b2 = error "point not on same curve"
  | x1 == x2 && y1 /= y2 = Infinity
  | x1 /= x2             = lineThrough ((y2 - y1) / (x2 - x1))
  | x1 == x2 && y1 == 0  = Infinity
  | p == q               = lineThrough ((3 * x1 ^ 2 + a1) / (2 * y1))
  | otherwise            = error "Unexpected case of points"
  where
    -- The third intersection of the line with slope m through the operands,
    -- reflected across the x-axis.
    lineThrough m =
      let x3 = m ^ 2 - x1 - x2
          y3 = m * (x1 - x3) - y1
      in  ECPoint x3 y3 a1 b1


-- | Double-and-add: for @m >= 0@, combine @result@ with @m@ copies of
-- @value@, i.e. @result <> value <> ... <> value@ (by associativity), using
-- O(log m) applications of '<>'.
--
-- Negative multipliers are rejected. An arithmetic right shift of a negative
-- 'Integer' converges to -1 and never reaches 0, so without this check the
-- recursion would never terminate.
binaryExpansion :: (Semigroup a) => Integer -> a -> a -> a
binaryExpansion m value result
  | m < 0     = error "binaryExpansion: negative multiplier"
  | m == 0    = result
  | otherwise = binaryExpansion (m `shiftR` 1) (value <> value) accumulator
  where
    -- Fold in the current power-of-two multiple only when m's low bit is set.
    accumulator = if m .&. 1 == 1 then result <> value else result

-- | Scalar multiplication @m * P@ by double-and-add; @0 * P@ is the point
-- at infinity.
--
-- A negative scalar is computed as @|m| * (-P)@, where @-P@ is @P@ reflected
-- across the x-axis. Previously a negative @m@ made the bit-shifting loop
-- run forever.
scalarProduct :: (Eq a, Fractional a) => Integer -> ECPoint a -> ECPoint a
scalarProduct m ec
  | m < 0     = scalarProduct (negate m) (reflect ec)
  | otherwise = binaryExpansion m ec Infinity
  where
    -- Additive inverse of a curve point: (x, y) -> (x, -y).
    reflect Infinity              = Infinity
    reflect (ECPoint px py pa pb) = ECPoint px (negate py) pa pb

-- | Points combine under the elliptic-curve group law implemented by 'add'.
instance (Eq a, Fractional a) => Semigroup (ECPoint a) where
  p <> q = add p q

-- | The point at infinity is the identity element of 'add'.
instance (Eq a, Fractional a) => Monoid (ECPoint a) where
  mempty = Infinity

-- | The element 3 of GF(31), used as an x-coordinate below.
tre :: FieldElement 31
tre = FieldElement 3

-- | Ad-hoc checks of curve membership, equality and point addition on
-- y^2 = x^3 + 5x + 7.
--
-- NOTE(review): nothing ties p3 (or the point added to it) to GF(31). Under
-- the monomorphism restriction its coordinate type is left to defaulting,
-- which presumably picks Double. Confirm this is intended.
cc =
  ( validECPoint p1
  , validECPoint p2
  , validECPoint p3
  , p1 /= p2
  , p1 == p1
  , add Infinity p1
  , add p1 (ECPoint 3 7 5 7)
  , add (ECPoint 3 7 5 7) p3
  , add p3 p3
  )
  where
    p1 = ECPoint tre (-7) 5 7
    p2 = ECPoint 18 77 5 7
    p3 = ECPoint (-1) (-1) 5 7

-- | The same point (192, 105) on y^2 = x^3 + 7 over GF(223), built three ways:
-- with raw constructors, with mixed literals, and with plain literals.
dd, ee, ff :: ECPoint (FieldElement 223)
dd = ECPoint px py ca cb
  where
    ca = FieldElement 0
    cb = FieldElement 7
    px = FieldElement 192
    py = FieldElement 105
ee = ECPoint 192 105 (FieldElement 0) 7
ff = ECPoint 192 105 0 7

-- | A point on y^2 = x^3 + 7 over GF(223).
aPoint :: ECPoint (FieldElement 223)
aPoint = ECPoint 192 105 0 7

-- | Three ways of computing 5 * aPoint; each should agree with the others.
total, totalfold, totalmconcat :: ECPoint (FieldElement 223)
total = aPoint `add` (aPoint `add` (aPoint `add` (aPoint `add` aPoint)))
totalfold = foldr add Infinity (replicate 5 aPoint)
totalmconcat = mconcat (replicate 5 aPoint)

-- | The secp256k1 base field: integers modulo p = 2^256 - 2^32 - 977.
type S256Field = FieldElement (2 ^ 256 - 2 ^ 32 - 977)

-- | Scalars modulo n, the secp256k1 group order.
type NField
  = FieldElement
      0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141

-- | Points with secp256k1 base-field coordinates. The alias itself does not
-- constrain the curve coefficients; 's256point' fixes them to a = 0, b = 7.
type S256Point = ECPoint S256Field
-- | Build a point on secp256k1 (y^2 = x^3 + 7), failing with
-- @error "Invalid point"@ when the coordinates do not satisfy the equation.
s256point :: S256Field -> S256Field -> S256Point
s256point px py
  | validECPoint candidate = candidate
  | otherwise              = error "Invalid point"
  where
    candidate = ECPoint px py 0 7
-- | Sample values of the field and point types.
li :: S256Field
li = 12

ll :: ECPoint (FieldElement 31)
ll = Infinity

-- | NOTE(review): this point uses a = 5, b = 7 rather than secp256k1's
-- a = 0, b = 7, and it is built with the raw constructor, so 's256point'
-- validation is bypassed.
ri :: S256Point
ri = ECPoint 3 7 5 7


-- | The secp256k1 group order (the same constant as the modulus of 'NField').
ncons :: Integer
ncons = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141

-- | The secp256k1 generator point G, validated on construction by 's256point'.
gcons :: S256Point
gcons =
  s256point
    0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798
    0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8

-- | The integer stored in a field element, returned as-is (no reduction).
asInt :: KnownNat n => FieldElement n -> Integer
asInt element = case element of
  FieldElement value -> value

-- z = 0xbc62d4b80d9e36da29c16c5d4d9f11731f36052c72401a76c23c0fb5a9b74423
-- r = 0x37206a0610995c58074999cb9767b87af4c4978db68c06e8e6e81d282047a7c6
-- s = 0x8ca63759c1157ebeaec0d03cecca119fc9a75bf8e6d0fa65c841c8e2738cdaec ::NField
-- px = 0x04519fac3d910ca7e7138f7013706f619fa8f033e6ec6e09370ea38cee6a7574
-- py = 0x82b51eab8c27c66e26c858a079bcdf4f1ada34cec420cafc7eac1a42216fb6c4
-- point = s256point px py
-- u = z / s
-- v = r / s
-- signa = scalarProduct (asInt u) gcons <> scalarProduct (asInt v) point

-- | A sample secp256k1 public key, validated on construction by 's256point'.
pub :: S256Point
pub =
  s256point
    0x887387e452b8eacc4acfde10d9aaf7f6d9a0f975aabb10d006e4da568744d06c
    0x61de6d95231cd89026e286df3b6ae4a894a3378e393e93a0f45b666329a0ae34

-- | First sample signature: message hash z1 and signature (r1, s1), all
-- reduced modulo the group order.
z1, r1, s1 :: NField
z1 = 0xec208baa0fc1c19f708a9ca96fdeff3ac3f230bb4a7ba4aede4942ad003c0f60
r1 = 0xac8d1c87e51d0d441be8b3dd5b05c8795b48875dffe00b7ffcfac23010d3a395
s1 = 0x68342ceff8935ededd102dd876ffd6ba72d6a427a3edb13d26eb0781cb423c4

-- | u*G + v*P with u = z1/s1 and v = r1/s1; for a valid signature its
-- x-coordinate should match r1.
signa1 :: S256Point
signa1 = scalarProduct (asInt u) gcons <> scalarProduct (asInt v) pub
  where
    u = z1 / s1
    v = r1 / s1

-- | Second sample signature: message hash z2 and signature (r2, s2).
z2, r2, s2 :: NField
z2 = 0x7c076ff316692a3d7eb3c3bb0f8b1488cf72e1afcd929e29307032997a838a3d
r2 = 0xeff69ef2b1bd93a66ed5219add4fb51e11a840f404876325a1e8ffe0529a2c
s2 = 0xc7207fee197d27c618aea621406f6bf5ef6fca38681d82b2f06fddbdce6feab6


-- | An ECDSA signature: the pair (r, s) of scalars modulo the group order.
data Signature = Signature { r, s :: NField }

-- | Check an ECDSA signature over secp256k1.
--
-- Given the message hash @z@, signature @(r, s)@ and public key @P@, the
-- check proceeds as follows:
--
-- * compute u = z/s and v = r/s modulo the group order;
-- * accept iff the x-coordinate of u*G + v*P, reduced modulo the group
--   order, equals @r@.
--
-- Signatures with @r == 0@ or @s == 0@ are rejected. A sum that lands on the
-- point at infinity (which has no x-coordinate) is also rejected, instead of
-- crashing in the partial 'x' selector.
verifySignanture :: NField -> Signature -> S256Point -> Bool
verifySignanture z (Signature sigR sigS) key
  | sigR == 0 || sigS == 0 = False
  | otherwise =
      case target of
        Infinity         -> False
        -- x(R) lives in the base field; reduce it mod n before comparing.
        ECPoint rx _ _ _ -> (fromInteger (asInt rx) :: NField) == sigR
  where
    u = z / sigS
    v = sigR / sigS
    target = scalarProduct (asInt u) gcons <> scalarProduct (asInt v) key