/**
  One-to-one end-to-end encryption ("Olm"): D wrappers around the libolm
  account and session APIs.
  */
4 module matrix.session;
5 
6 import matrix.olm;
7 
8 import std.file : read;
9 import std.experimental.allocator : processAllocator;
10 import std.exception : assumeUnique;
11 
12 import std.stdio; // TODO debug only!
13 
/**
  Copy a NUL-terminated C string into a newly allocated immutable D string.

  A null pointer is treated as an empty string.
  */
immutable(char)[] cstr2dstr(inout(char)* cstr)
{
	import core.stdc.string : strlen;
	// Only measure the string when there is one; a null pointer gives length 0.
	const n = (cstr is null) ? 0 : strlen(cstr);
	return cstr[0 .. n].idup;
}
19 
/**
  An Olm account: long-term identity keys plus a set of one-time keys.

  Obtain instances through `create` (fresh keys) or `unpickle` (restore from
  `pickle` output); the constructor itself is private. Every libolm call whose
  result can signal failure is routed through `error_check`, which throws an
  `Exception` carrying libolm's last error message for this account.
  */
class Account {
	/// libolm account handle, placed in memory obtained from `processAllocator`.
	/// NOTE(review): nothing in this class releases that memory or wipes the
	/// key material in it; confirm the installed allocator makes this acceptable.
	OlmAccount* account;

	/// Allocate and initialise an empty libolm account object.
	private this() {
		const len = olm_account_size();
		auto mem = processAllocator.allocate(len);
		this.account = olm_account(mem.ptr);
	}

	/**
	  Create a fresh account and generate its identity keys.

	  Throws: Exception with libolm's error message if account creation fails.
	  */
	static public Account create() {
		auto a = new Account();
		const rnd_len = olm_create_account_random_length(a.account);
		auto rnd_mem = read_random(rnd_len);
		// Report the size of the buffer we actually hold (as
		// generate_one_time_keys does) rather than the size we asked for, so
		// libolm is never told about bytes that do not exist.
		const r = olm_create_account(a.account, rnd_mem.ptr, rnd_mem.length);
		// The result used to be ignored; surface failures instead of
		// returning a possibly unusable account.
		a.error_check(r);
		return a;
	}

	/**
	  Serialize the account, locked with `key`.

	  Returns: the pickled account as a string, suitable for `unpickle`.
	  Throws: Exception with libolm's error message on failure.
	  */
	public string pickle(string key) {
		char[] ret;
		ret.length = olm_pickle_account_length(this.account);
		const r = olm_pickle_account(this.account,
			key.ptr, key.length, ret.ptr, ret.length);
		error_check(r);
		return assumeUnique(ret);
	}

	/**
	  Restore an account from `pickle` output, unlocked with `key`.

	  Throws: Exception with libolm's error message on failure (e.g. wrong key).
	  */
	static public Account unpickle(string key, string pickle) {
		auto a = new Account();
		// libolm overwrites the input buffer while decoding, so give it a copy.
		char[] p = pickle.dup;
		const r = olm_unpickle_account(a.account,
			key.ptr, key.length, p.ptr, p.length);
		a.error_check(r);
		return a;
	}

	/**
	  Returns: a JSON string of the identity keys, an object with
	  "curve25519" and "ed25519" members.
	  Throws: Exception with libolm's error message on failure.
	  */
	public @property string identity_keys() {
		char[] ret;
		ret.length = olm_account_identity_keys_length(this.account);
		const r = olm_account_identity_keys(this.account,
			 ret.ptr, ret.length);
		error_check(r);
		return assumeUnique(ret);
	}

	/**
	  Sign `msg` with this account's signing key.

	  Returns: the signature as a string of olm_account_signature_length chars.
	  Throws: Exception with libolm's error message on failure.
	  */
	public string sign(string msg) {
		char[] ret;
		ret.length = olm_account_signature_length(this.account);
		const r = olm_account_sign(this.account,
			msg.ptr, msg.length, ret.ptr, ret.length);
		error_check(r);
		return assumeUnique(ret);
	}

	/**
	  Returns: a JSON string of the not-yet-published one-time keys (pre keys),
	  e.g. `{"curve25519":{}}` when there are none.
	  Throws: Exception with libolm's error message on failure.
	  */
	public @property string one_time_keys() {
		char[] ret;
		ret.length = olm_account_one_time_keys_length(this.account);
		const r = olm_account_one_time_keys(this.account,
			ret.ptr, ret.length);
		error_check(r);
		return assumeUnique(ret);
	}

	/// Mark the current one-time keys as published; afterwards
	/// `one_time_keys` no longer lists them.
	public void mark_keys_as_published() {
		olm_account_mark_keys_as_published(this.account);
	}

	/// Returns: libolm's upper bound on how many one-time keys this account can hold.
	public size_t max_number_of_one_time_keys() {
		return olm_account_max_number_of_one_time_keys(this.account);
	}

	/**
	  Generate `count` new one-time keys. They appear in `one_time_keys`
	  until `mark_keys_as_published` is called.

	  Throws: Exception with libolm's error message if generation fails.
	  */
	public void generate_one_time_keys(size_t count) {
		const rnd_len = olm_account_generate_one_time_keys_random_length(this.account, count);
		auto rnd_mem = read_random(rnd_len);
		const r = olm_account_generate_one_time_keys(this.account,
			count, rnd_mem.ptr, rnd_mem.length);
		// The result used to be discarded, hiding failures (e.g. a short
		// random buffer); check it like every other call in this class.
		error_check(r);
	}

	/// Throw an Exception carrying libolm's last account error if `x` is
	/// the olm_error() sentinel; otherwise do nothing.
	private void error_check(size_t x) {
		if (x == olm_error()) {
			auto errmsg = olm_account_last_error(this.account);
			throw new Exception(cstr2dstr(errmsg));
		}
	}
}
99 
unittest {
	import std.json : parseJSON;
	auto a = Account.create();
	auto key = "foobar";
	auto p = a.pickle(key);
	auto a2 = Account.unpickle(key, p);
	// The restored account must carry exactly the same identity keys.
	assert(a2.identity_keys == a.identity_keys);
	auto id_keys = parseJSON(a.identity_keys());
	assert ("curve25519" in id_keys);
	assert ("ed25519" in id_keys);
	auto msg = "Hello World!";
	auto sig_msg = a.sign(msg);
	// ed25519 signatures are deterministic, so the restored account (same
	// signing key) must produce the identical signature.
	assert(a2.sign(msg) == sig_msg);
	// TODO verify the signature against the published ed25519 key

	auto max = a.max_number_of_one_time_keys();
	assert (max > 10);
	auto otks = a.one_time_keys();
	// none generated so far
	assert (otks == "{\"curve25519\":{}}");
	const key_count = 11;
	a.generate_one_time_keys(key_count);
	auto j_otks = parseJSON(a.one_time_keys());
	assert ("curve25519" in j_otks);
	assert(j_otks["curve25519"].object.length == key_count);
	// Reading the keys again without publishing must not change them.
	auto j_otks2 = parseJSON(a.one_time_keys());
	assert (j_otks == j_otks2);
	a.mark_keys_as_published();
	// Published keys are no longer listed.
	auto j_otks3 = parseJSON(a.one_time_keys());
	assert ("curve25519" in j_otks3);
	assert(j_otks3["curve25519"].object.length == 0);
}
130 
/**
  An Olm session between two parties, used to encrypt and decrypt messages.

  A session is set up with `create_outbound`, `create_inbound`,
  `create_inbound_from` or `unpickle`. Every libolm result that can signal
  failure goes through `error_check`, which throws an `Exception` carrying
  libolm's last error message for this session.
  */
class Session {
	/// libolm session handle, placed in memory obtained from `processAllocator`.
	OlmSession* session;

	/// Allocate a blank libolm session object (not yet inbound or outbound).
	public this() {
		auto block = processAllocator.allocate(olm_session_size());
		session = olm_session(block.ptr);
	}

	/// Serialize the session, locked with `key`.
	/// Throws: Exception with libolm's error message on failure.
	public string pickle(string key) {
		auto buf = new char[](olm_pickle_session_length(session));
		error_check(olm_pickle_session(session,
			key.ptr, key.length, buf.ptr, buf.length));
		return assumeUnique(buf);
	}

	/// Restore a session from `pickle` output, unlocked with `key`.
	/// Throws: Exception with libolm's error message on failure.
	static public Session unpickle(string key, string pickle) {
		auto sess = new Session();
		// libolm overwrites the input buffer while decoding, so hand it a copy.
		auto scratch = pickle.dup;
		sess.error_check(olm_unpickle_session(sess.session,
			key.ptr, key.length, scratch.ptr, scratch.length));
		return sess;
	}

	/// Start an outbound session from account `a` to the party owning
	/// `identity_key`, using one of their published `one_time_key`s.
	/// Throws: Exception with libolm's error message on failure.
	static public Session create_outbound(Account a, string identity_key, string one_time_key) {
		auto sess = new Session();
		auto entropy = read_random(olm_create_outbound_session_random_length(sess.session));
		sess.error_check(olm_create_outbound_session(sess.session, a.account,
			identity_key.ptr, identity_key.length,
			one_time_key.ptr, one_time_key.length,
			entropy.ptr, entropy.length));
		return sess;
	}

	/// Build an inbound session for account `a` from a received
	/// one-time-key (pre-key) message.
	/// Throws: Exception with libolm's error message on failure.
	static public Session create_inbound(Account a, string one_time_key_msg) {
		auto sess = new Session();
		// libolm overwrites the message buffer, so work on a copy.
		auto scratch = one_time_key_msg.dup;
		sess.error_check(olm_create_inbound_session(sess.session, a.account,
			scratch.ptr, scratch.length));
		return sess;
	}

	/// Like `create_inbound`, but also passes the sender's `identity_key` to libolm.
	/// Throws: Exception with libolm's error message on failure.
	static public Session create_inbound_from(Account a, string identity_key, string one_time_key_msg) {
		auto sess = new Session();
		// libolm overwrites the message buffer, so work on a copy.
		auto scratch = one_time_key_msg.dup;
		sess.error_check(olm_create_inbound_session_from(sess.session, a.account,
			identity_key.ptr, identity_key.length,
			scratch.ptr, scratch.length));
		return sess;
	}

	/// Returns: the session identifier as reported by libolm.
	/// Throws: Exception with libolm's error message on failure.
	public @property string id() {
		auto buf = new char[](olm_session_id_length(session));
		error_check(olm_session_id(session, buf.ptr, buf.length));
		return assumeUnique(buf);
	}

	/// Returns: true when libolm reports (result 1) that the pre-key message
	/// belongs to this session.
	/// Throws: Exception with libolm's error message on failure.
	public bool matches_inbound(string one_time_key_msg) {
		auto scratch = one_time_key_msg.dup; // consumed by libolm
		const verdict = olm_matches_inbound_session(session,
			scratch.ptr, scratch.length);
		error_check(verdict);
		return verdict == 1;
	}

	/// Like `matches_inbound`, but also passes the sender's `identity_key`.
	/// Throws: Exception with libolm's error message on failure.
	public bool matches_inbound_from(string identity_key, string one_time_key_msg) {
		auto scratch = one_time_key_msg.dup; // consumed by libolm
		const verdict = olm_matches_inbound_session_from(session,
			identity_key.ptr, identity_key.length,
			scratch.ptr, scratch.length);
		error_check(verdict);
		return verdict == 1;
	}

	/**
	  Encrypt `plaintext` for the other party.

	  Params:
	    plaintext = message to encrypt
	    msg_type  = receives the message type libolm assigns; pass it to `decrypt`
	  Returns: the ciphertext.
	  Throws: Exception with libolm's error message on failure.
	  */
	public string encrypt(string plaintext, out size_t msg_type) {
		auto entropy = read_random(olm_encrypt_random_length(session));
		// TODO use enum for msg_type?
		msg_type = olm_encrypt_message_type(session);
		error_check(msg_type);
		auto cipher = new char[](olm_encrypt_message_length(session, plaintext.length));
		error_check(olm_encrypt(session,
			plaintext.ptr, plaintext.length,
			entropy.ptr, entropy.length,
			cipher.ptr, cipher.length));
		return assumeUnique(cipher);
	}

	/**
	  Decrypt `cypher`, a message of type `msg_type` as produced by `encrypt`.

	  Returns: the plaintext, trimmed to the length libolm reports.
	  Throws: Exception with libolm's error message on failure.
	  */
	public string decrypt(size_t msg_type, string cypher) {
		// Both libolm calls below overwrite the message buffer, so each one
		// gets its own fresh copy.
		auto probe = cypher.dup;
		const capacity = olm_decrypt_max_plaintext_length(session,
			msg_type, probe.ptr, probe.length);
		error_check(capacity);
		auto plain = new char[](capacity);
		auto input = cypher.dup;
		const written = olm_decrypt(session, msg_type,
			input.ptr, input.length,
			plain.ptr, plain.length);
		error_check(written);
		return assumeUnique(plain[0 .. written]);
	}

	/// Throw an Exception carrying libolm's last session error when `x`
	/// equals the olm_error() sentinel.
	private void error_check(size_t x) {
		if (x != olm_error())
			return;
		throw new Exception(cstr2dstr(olm_session_last_error(session)));
	}
}
244 
unittest {
	// Round-trip a blank session through pickle/unpickle; the id must survive.
	auto original = new Session();
	string passphrase = "foobar";
	auto restored = Session.unpickle(passphrase, original.pickle(passphrase));
	assert(original.id == restored.id);

	// Encryption on the blank session must not throw.
	size_t type;
	string plaintext = "Hello World!";
	auto ciphertext = original.encrypt(plaintext, type);
	// TODO test decrypt
}
257 
unittest {
	import std.json : parseJSON;

	// Setup account 1
	auto a1 = Account.create();
	string a1_id_key = parseJSON(a1.identity_keys)["curve25519"].str;
	a1.generate_one_time_keys(3);
	string a1_otk;
	foreach (k,v; parseJSON(a1.one_time_keys)["curve25519"].object) {
		a1_otk = v.str;
		break;
	}
	a1.mark_keys_as_published();

	// Setup account 2
	auto a2 = Account.create();
	string a2_id_key = parseJSON(a2.identity_keys)["curve25519"].str;
	a2.generate_one_time_keys(3);
	string a2_otk;
	foreach (k,v; parseJSON(a2.one_time_keys)["curve25519"].object) {
		a2_otk = v.str;
		break;
	}
	a2.mark_keys_as_published();

	/** Now a2 publishes his identity and one time keys,
	  * such that a1 can encrypt a message for a2. */

	// exchange
	auto s1_out = Session.create_outbound(a1, a2_id_key, a2_otk);
	auto msg = "Hello World!";
	size_t msg_type;
	auto cipher = s1_out.encrypt(msg, msg_type);
	// The first message of a fresh outbound session is a pre-key message
	// (OLM_MESSAGE_TYPE_PRE_KEY == 0 in libolm's olm.h).
	assert(msg_type == 0);
	auto s2_in = Session.create_inbound(a2, cipher);
	assert(s1_out.id == s2_in.id);
	// The session built from the pre-key message must recognise it, both on
	// its own and when checked against the sender's identity key.
	assert(s2_in.matches_inbound(cipher));
	assert(s2_in.matches_inbound_from(a1_id_key, cipher));
	auto plain = s2_in.decrypt(msg_type, cipher);
	assert(plain == msg);

	// Reply in the other direction. Having received a message, a2 now sends
	// normal messages (OLM_MESSAGE_TYPE_MESSAGE == 1 in olm.h).
	auto reply = "Hello back!";
	size_t reply_type;
	auto reply_cipher = s2_in.encrypt(reply, reply_type);
	assert(reply_type == 1);
	assert(s1_out.decrypt(reply_type, reply_cipher) == reply);
}