diff --git a/pkg/notification/notification.go b/pkg/notification/notification.go index cedc418d24cb4e7e1e4ee8497977629c112ec242..abe2a71a099f17b9363f9de112495749763b94f2 100644 --- a/pkg/notification/notification.go +++ b/pkg/notification/notification.go @@ -1,15 +1,26 @@ package notification import ( - "crypto/hmac" + "crypto" + "crypto/rand" + "crypto/rsa" "crypto/sha256" - "encoding/hex" + "crypto/x509" + "encoding/base64" "encoding/json" + "encoding/pem" "errors" + "fmt" ) type OperationType int +var ( + ErrInvalidPEMFormat = errors.New("invalid key PEM format") + ErrMessageEmpty = errors.New("message is empty") + ErrSignEmpty = errors.New("sign is empty") +) + const ( Transfer OperationType = iota CommissionPayment @@ -44,42 +55,91 @@ func (m *Message) serializeWithoutSign() ([]byte, error) { return json.Marshal(tempMessage) } -func (m *Message) SignMessage(privateKey string) error { +func ParsePrivateKey(privateKeyPEM string) (*rsa.PrivateKey, error) { + block, _ := pem.Decode([]byte(privateKeyPEM)) + if block == nil || block.Type != "RSA PRIVATE KEY" { + return nil, ErrInvalidPEMFormat + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + return privateKey, nil +} + +func ParsePublicKey(publicKeyPEM string) (*rsa.PublicKey, error) { + block, _ := pem.Decode([]byte(publicKeyPEM)) + if block == nil || block.Type != "RSA PUBLIC KEY" { + return nil, ErrInvalidPEMFormat + } + + publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + + return publicKey, nil +} + +func (m *Message) SignMessage(privateKeyPEM string) error { if m == nil { - return errors.New("cannot sign: message is nil") + return ErrMessageEmpty + } + + privateKey, err := ParsePrivateKey(privateKeyPEM) + if err != nil { + return fmt.Errorf("failed to parse private key: %w", err) } messageBytes, err := m.serializeWithoutSign() if err != nil { - return err + return fmt.Errorf("failed to serialize message: %w", err) } - h := hmac.New(sha256.New, []byte(privateKey)) - h.Write(messageBytes) + hashed := sha256.Sum256(messageBytes) - signature := hex.EncodeToString(h.Sum(nil)) - m.Sign = &signature + signature, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, hashed[:]) + if err != nil { + return fmt.Errorf("failed to sign message: %w", err) + } + + signatureBase64 := base64.StdEncoding.EncodeToString(signature) + m.Sign = &signatureBase64 return nil } -func (m *Message) VerifySign(privateKey string) (bool, error) { +func (m *Message) VerifySign(publicKeyPEM string) (bool, error) { if m == nil { - return false, errors.New("cannot verify signature: message is nil") + return false, ErrMessageEmpty } if m.Sign == nil { - return false, errors.New("signature is missing") + return false, ErrSignEmpty + } + + publicKey, err := ParsePublicKey(publicKeyPEM) + if err != nil { + return false, fmt.Errorf("failed to parse public key: %v", err) } messageBytes, err := m.serializeWithoutSign() if err != nil { - return false, err + return false, fmt.Errorf("failed to serialize message: %v", err) } - h := hmac.New(sha256.New, []byte(privateKey)) - h.Write(messageBytes) + hashed := sha256.Sum256(messageBytes) - expectedSignature := hex.EncodeToString(h.Sum(nil)) + signature, err := base64.StdEncoding.DecodeString(*m.Sign) + if err != nil { + return false, fmt.Errorf("failed to decode signature: %v", err) + } + + err = rsa.VerifyPKCS1v15(publicKey, crypto.SHA256, hashed[:], signature) + if err != nil { + return false, fmt.Errorf("failed to verify signature: %v", err) + } - return hmac.Equal([]byte(expectedSignature), []byte(*m.Sign)), nil + return true, nil } diff --git a/pkg/notification/notification_test.go b/pkg/notification/notification_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4e840d473313c8621bb72a6f20e4208677c5aa8e --- /dev/null +++ b/pkg/notification/notification_test.go @@ -0,0 +1,120 @@ +package notification_test + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "testing" + + "github.com/stretchr/testify/assert" + + "git.ptb.bet/public-group/shared/pkg/notification" +) + +func generateTestKeys() (string, string, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", err + } + + privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: privateKeyBytes, + }) + + publicKeyBytes := x509.MarshalPKCS1PublicKey(&privateKey.PublicKey) + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: publicKeyBytes, + }) + + return string(privateKeyPEM), string(publicKeyPEM), nil +} + +func TestSignMessage(t *testing.T) { + privateKeyPEM, publicKeyPEM, err := generateTestKeys() + assert.NoError(t, err) + + msg := ¬ification.Message{ + Amount: "1000000", + TokenName: "usdt", + FromAddress: "asfdqf", + ToAddress: "ADdada", + TransactionHash: "12321eqwd", + Status: "confirmed", + ErrorNote: nil, + Operation: notification.Transfer, + } + + err = msg.SignMessage(privateKeyPEM) + assert.NoError(t, err) + assert.NotNil(t, msg.Sign) + + verified, err := msg.VerifySign(publicKeyPEM) + assert.NoError(t, err) + assert.True(t, verified) +} + +func TestVerifySignInvalid(t *testing.T) { + privateKeyPEM, publicKeyPEM, err := generateTestKeys() + assert.NoError(t, err) + + msg := ¬ification.Message{ + Amount: "1000000", + TokenName: "trx", + FromAddress: "asd", + ToAddress: "trdqwdq", + TransactionHash: "wdqdqwd12e", + Status: "success", + Operation: notification.Transfer, + } + + err = msg.SignMessage(privateKeyPEM) + assert.NoError(t, err) + assert.NotNil(t, msg.Sign) + + msg.Amount = "2000000" + + verified, err := msg.VerifySign(publicKeyPEM) + assert.Error(t, err) + assert.False(t, verified) +} + +func TestInvalidPEMFormat(t *testing.T) { + invalidPEM := "INVALID PEM" + + _, err := notification.ParsePrivateKey(invalidPEM) + assert.ErrorIs(t, err, notification.ErrInvalidPEMFormat) + + _, err = notification.ParsePublicKey(invalidPEM) + assert.ErrorIs(t, err, notification.ErrInvalidPEMFormat) +} + +func TestEmptyMessage(t *testing.T) { + var msg *notification.Message + + err := msg.SignMessage("private_key") + assert.ErrorIs(t, err, notification.ErrMessageEmpty) + + verified, err := msg.VerifySign("public_key") + assert.False(t, verified) + assert.ErrorIs(t, err, notification.ErrMessageEmpty) +} + +func TestEmptySignature(t *testing.T) { + msg := ¬ification.Message{ + Amount: "100.00", + TokenName: "ETH", + FromAddress: "0x123", + ToAddress: "0x456", + TransactionHash: "0xabc", + Status: "success", + Operation: notification.Transfer, + } + + verified, err := msg.VerifySign("public_key") + assert.False(t, verified) + assert.ErrorIs(t, err, notification.ErrSignEmpty) +}