aboutsummaryrefslogtreecommitdiff
path: root/lib/sig.ml
blob: b4f49d7ed9f8b42491f59d802ce9eac642d49dd6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
(* Taken from https://gopiandcode.uk/logs/log-writing-activitypub.html
   License is AGPL-v3
*)
open Containers
module StringMap = Map.Make (String)

(* Removes one pair of surrounding double quotes from [str], as used for the
   quoted parameter values of a Signature header ([keyId="abc"] -> [abc]).
   The input may come from an untrusted request header, so a string that is
   not wrapped in double quotes (including one shorter than two characters)
   is returned unchanged. Without this check it would be truncated or make
   [String.sub] raise [Invalid_argument]. *)
let drop_quotes str =
  let len = String.length str in
  if len >= 2 && Char.equal str.[0] '"' && Char.equal str.[len - 1] '"' then
    String.sub str 1 (len - 2)
  else str

(* Value for an HTTP [Digest] header of [body]: the prefix ["SHA-256="]
   followed by [Base64.encode_string] of the raw SHA-256 hash of the body
   bytes. *)
let body_digest body =
  let raw_hash =
    Cstruct.to_string
      (Mirage_crypto.Hash.SHA256.digest (Cstruct.of_string body))
  in
  "SHA-256=" ^ Base64.encode_string raw_hash

(* Converts Cohttp headers into a [StringMap] keyed by header name. Names are
   kept exactly as stored (no case normalisation), and because the result is
   a map, a name that occurs several times keeps only one value. *)
let req_headers headers =
  let pairs = Cohttp.Header.to_list headers in
  StringMap.of_list pairs

(* Splits [str] at its first ['='] into [Some (key, value)]. Any later ['=']
   characters (e.g. base64 padding) stay in the value. Returns [None] when
   [str] contains no ['=']. *)
let split_equals str =
  Option.map
    (fun pos ->
      let rest_len = String.length str - pos - 1 in
      (String.sub str 0 pos, String.sub str (pos + 1) rest_len))
    (String.index_opt str '=')

(* Builds the signing string for an HTTP signature.

   [signed_headers] is a space-separated list of names, for example
   (request-target) user-agent host date digest content-type.
   Each name becomes one line, and the lines are joined with newlines:
   - (request-target) -> the lowercased [meth], a space, then [path];
   - digest           -> the supplied [body_digest];
   - anything else    -> its value looked up by exact name in [headers],
                         or the empty string when absent. *)
let build_signed_string ~signed_headers ~meth ~path ~headers ~body_digest =
  let line_for name =
    match name with
    | "(request-target)" ->
        "(request-target): " ^ String.lowercase_ascii meth ^ " " ^ path
    | "digest" -> "digest: " ^ body_digest
    | other ->
        let value =
          match StringMap.find_opt other headers with
          | Some v -> v
          | None -> ""
        in
        other ^ ": " ^ value
  in
  let names = String.split_on_char ' ' signed_headers in
  String.concat "\n" (List.map line_for names)

(* Parses the value of a Signature header, e.g.
   [keyId="...",algorithm="rsa-sha256",headers="...",signature="..."],
   into a map from parameter name to its (unquoted) value.

   The value is split on [','] and each piece on its first ['=']. Pieces
   without ['='] are ignored. Each piece is trimmed first, so a header
   written with spaces after the commas ([a="x", b="y"]) yields the keys
   [a] and [b]. Untrimmed, it would yield keys with a leading space, and
   later lookups of [headers] or [signature] would silently fail. *)
let parse_signature signature =
  String.split_on_char ',' signature
  |> List.map String.trim
  |> List.filter_map split_equals
  |> List.map (Pair.map_snd drop_quotes)
  |> StringMap.of_list

(* Checks [signature] (raw signature bytes; [verify_request] passes the
   base64-decoded value) over [signed_string] with [pubkey], using SHA-256
   and the [`RSA_PKCS1] scheme. Returns [true] on success. On failure it
   logs the error, the signed string and the signature via [Dream.log] and
   returns [false]. *)
let verify ~signed_string ~signature pubkey =
  let message = `Message (Cstruct.of_string signed_string) in
  let signature_cs = Cstruct.of_string signature in
  match
    X509.Public_key.verify `SHA256 ~scheme:`RSA_PKCS1 ~signature:signature_cs
      pubkey message
  with
  | Error (`Msg e) ->
      Dream.log
        "error while verifying: %s\n\nsigned_string is:%s\n\nsignature is:%s\n"
        e signed_string signature;
      false
  | Ok () -> true

(* Despite its name, this SIGNS [str] rather than encrypting it. It produces
   a SHA-256 / [`RSA_PKCS1] signature of [str] with [privkey] and returns the
   base64 encoding of the signature bytes as the [result] of [Base64.encode].
   A failure of the signing step itself raises (via [Result.get_exn]). *)
let encrypt (privkey : X509.Private_key.t) str =
  let message = `Message (Cstruct.of_string str) in
  let raw_signature =
    X509.Private_key.sign `SHA256 ~scheme:`RSA_PKCS1 privkey message
    |> Result.get_exn
  in
  Base64.encode (Cstruct.to_string raw_signature)

(* Current time as a [Ptime.t]. It is read from CalendarLib's clock and
   converted through a Unix-epoch float. Raises (via [Option.get_exn_or])
   with the message "invalid date" if that float is outside Ptime's range. *)
let time_now () =
  let unix_seconds =
    CalendarLib.Calendar.to_unixfloat (CalendarLib.Calendar.now ())
  in
  Option.get_exn_or "invalid date" (Ptime.of_float_s unix_seconds)

(* Verifies the HTTP signature carried by an incoming Dream request.

   [taken_public_key] is the signer's public key, if one could be obtained.
   The signing string is rebuilt from:
   - the lowercased request method;
   - the request target;
   - the request headers (names lowercased), restricted to the names listed
     in the Signature header's [headers] parameter.
   The body is read in full and its digest recomputed, so a signed digest
   line is checked against the body actually received, not against the
   request's Digest header. The base64-decoded [signature] parameter is then
   checked with [verify].

   Resolves to [Ok false] when any of these is missing: the Signature header,
   the public key, or the [headers] or [signature] parameter. It also
   resolves to [Ok false] when the signature is not valid base64. Otherwise
   it resolves to [Ok (verify ...)]; it never resolves to [Error _].
   NOTE(review): the Date header is not checked for freshness here. *)
let verify_request taken_public_key (req : Dream.request) =
  (* Short-circuit to "not verified" as soon as a required piece is missing. *)
  let ( let* ) opt k =
    match opt with Some v -> k v | None -> Lwt.return (Ok false)
  in
  let method_lc =
    String.lowercase_ascii (Dream.method_to_string (Dream.method_ req))
  in
  let target = Dream.target req in
  let lowered_headers =
    StringMap.of_list
      (List.map
         (fun (name, value) -> (String.lowercase_ascii name, value))
         (Dream.all_headers req))
  in
  let* raw_signature_header = Dream.header req "Signature" in
  let params = parse_signature raw_signature_header in
  Lwt.bind (Dream.body req) (fun body ->
      let digest = body_digest body in
      let* public_key = taken_public_key in
      let* header_names = StringMap.find_opt "headers" params in
      let signed_string =
        build_signed_string ~signed_headers:header_names ~meth:method_lc
          ~path:target ~headers:lowered_headers ~body_digest:digest
      in
      let* encoded_signature = StringMap.find_opt "signature" params in
      let* signature = Result.to_opt (Base64.decode encoded_signature) in
      Lwt_result.return (verify ~signed_string ~signature public_key))

(* Computes the headers for an HTTP-signed (rsa-sha256) outgoing request and
   returns them merged with the caller's [headers], as an association list.

   - [priv_key]     : key used to produce the signature (see [encrypt]).
   - [key_id]       : value of the keyId signature parameter.
   - [headers]      : the caller's existing headers, as a [StringMap].
   - [body_str]     : request body, if any. When present, Content-Length and
                      Digest are generated and covered by the signature.
   - [current_time] : used for the Date header.
   - [method_]      : HTTP method name, lowercased for (request-target).
   - [uri]          : request URI; it must have a host.

   (request-target) is signed with the path AND query of [uri]
   ([Uri.path_and_query]). [verify_request] rebuilds it from [Dream.target],
   which includes the query string. Signing only [Uri.path] made every
   request to a URI with a query fail verification, and signed an empty
   target for a URI with an empty path.

   The generated headers replace any caller header whose name matches
   case-insensitively: Date, Host and Signature always, plus Digest and
   Content-Length when a body is given. Without this, a caller-supplied
   "date" would be sent next to the signed "Date".

   Raises if [uri] has no host ("no host for request") or if signing or
   base64 encoding fails. *)
let build_signed_headers ~priv_key ~key_id ~headers ?body_str ~current_time
    ~method_ ~uri () =
  let signed_headers =
    match body_str with
    | Some _ -> "(request-target) content-length host date digest"
    | None -> "(request-target) host date"
  in

  let body_str_len = Option.map Fun.(Int.to_string % String.length) body_str in
  let body_digest = Option.map body_digest body_str in

  let date = Http_date.to_utc_string current_time in
  let host = uri |> Uri.host |> Option.get_exn_or "no host for request" in

  let signature_string =
    let opt name vl =
      match vl with None -> Fun.id | Some vl -> StringMap.add name vl
    in
    let to_be_signed =
      build_signed_string ~signed_headers
        ~meth:(method_ |> String.lowercase_ascii)
        (* path plus query string: must match the target verifiers rebuild *)
        ~path:(Uri.path_and_query uri)
        ~headers:
          (opt "content-length" body_str_len
          @@ StringMap.add "date" date @@ StringMap.add "host" host @@ headers)
        ~body_digest:(Option.value body_digest ~default:"")
    in

    let signed_string = encrypt priv_key to_be_signed |> Result.get_exn in
    Printf.sprintf
      {|keyId="%s",algorithm="rsa-sha256",headers="%s",signature="%s"|} key_id
      signed_headers signed_string
  in
  (* Only the headers actually produced for this request. *)
  let generated =
    List.filter_map
      (fun (name, value) -> Option.map (fun v -> (name, v)) value)
      [
        ("Digest", body_digest);
        ("Date", Some date);
        ("Host", Some host);
        ("Signature", Some signature_string);
        ("Content-Length", body_str_len);
      ]
  in
  let is_generated name =
    let name = String.lowercase_ascii name in
    List.exists
      (fun (g, _) -> String.equal (String.lowercase_ascii g) name)
      generated
  in
  (* HTTP header names are case-insensitive, so drop caller headers that
     would otherwise be sent alongside the generated ones. *)
  let caller_headers =
    StringMap.filter (fun name _ -> not (is_generated name)) headers
  in
  List.fold_left
    (fun map (k, v) -> StringMap.add k v map)
    caller_headers generated
  |> StringMap.to_list

(* Signs an outgoing Cohttp request. It returns [headers] extended with the
   Date/Host/Signature (and, with a body, Digest/Content-Length) headers
   computed by [build_signed_headers] for [uri] and [method_].

   [body], when given, is read in full with [Cohttp_lwt.Body.to_string] to
   compute its digest and length. NOTE(review): a streaming body would be
   consumed here; presumably callers pass a string-backed body — confirm.

   [build_signed_headers] returns the caller's headers together with the
   generated ones, so each returned pair is merged rather than appended:
   - a pair whose value is already present under that name is left alone,
     which keeps the caller's headers (including repeated ones) intact;
   - any other pair is set with [Cohttp.Header.replace], so a stale
     caller-supplied value is not sent next to the signed one.
   Appending every pair with [Cohttp.Header.add] would duplicate every
   header the caller passed in. *)
let sign_headers ~priv_key ~key_id ?(body : Cohttp_lwt.Body.t option)
    ~(headers : Cohttp.Header.t) ~uri ~method_ () =
  let ( let* ) x f = Lwt.bind x f in

  let* body_str =
    match body with
    | None -> Lwt.return None
    | Some body -> Lwt.map Option.some (Cohttp_lwt.Body.to_string body)
  in
  let current_time = time_now () in

  let merge header (key, vl) =
    let already_present =
      List.exists (String.equal vl) (Cohttp.Header.get_multi header key)
    in
    if already_present then header else Cohttp.Header.replace header key vl
  in
  let headers =
    List.fold_left merge headers
      (build_signed_headers ~priv_key ~key_id ~headers:(req_headers headers)
         ?body_str ~current_time
         ~method_:(Cohttp.Code.string_of_method method_)
         ~uri ())
  in
  Lwt.return headers