From 806b5e1880ae100c7cb71d51718fc063ca7f9496 Mon Sep 17 00:00:00 2001
From: mwiegand <mwiegand@seibert-media.net>
Date: Tue, 15 Feb 2022 09:36:57 +0100
Subject: [PATCH] ssh: dont set rendom bytes to zero

---
 libs/ssh.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/libs/ssh.py b/libs/ssh.py
index 1bbcbf6..6325c03 100644
--- a/libs/ssh.py
+++ b/libs/ssh.py
@@ -1,19 +1,22 @@
 from base64 import b64decode, b64encode
 from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
-from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, PublicFormat, NoEncryption
+from hashlib import sha3_256
 
 
 def generate_ed25519_key_pair(secret):
     privkey_bytes = Ed25519PrivateKey.from_private_bytes(secret)
     
     nondeterministic_privatekey = privkey_bytes.private_bytes(
-        encoding=serialization.Encoding.PEM,
-        format=serialization.PrivateFormat.OpenSSH,
-        encryption_algorithm=serialization.NoEncryption()
+        encoding=Encoding.PEM,
+        format=PrivateFormat.OpenSSH,
+        encryption_algorithm=NoEncryption()
     ).decode()
-    nondeterministic_bytes = b64decode(''.join(nondeterministic_privatekey.split('\n')[1:-2]))
+
     # handle random 32bit number, occuring twice in a row
-    deterministic_bytes = nondeterministic_bytes[:98] + b'00000000' + nondeterministic_bytes[106:]
+    nondeterministic_bytes = b64decode(''.join(nondeterministic_privatekey.split('\n')[1:-2]))
+    random_bytes = sha3_256(secret).digest()[0:4]
+    deterministic_bytes = nondeterministic_bytes[:98] + random_bytes + random_bytes + nondeterministic_bytes[106:]
     deterministic_privatekey = '\n'.join([
         '-----BEGIN OPENSSH PRIVATE KEY-----',
         b64encode(deterministic_bytes).decode(),
@@ -21,8 +24,8 @@ def generate_ed25519_key_pair(secret):
     ])
 
     public_key = privkey_bytes.public_key().public_bytes(
-        encoding=serialization.Encoding.OpenSSH,
-        format=serialization.PublicFormat.OpenSSH,
+        encoding=Encoding.OpenSSH,
+        format=PublicFormat.OpenSSH,
     ).decode()
     
     return (deterministic_privatekey, public_key)