🐛 Fixed the DecryptToBytes verification issue
This commit is contained in:
51
xcipher.go
51
xcipher.go
@@ -1524,44 +1524,35 @@ func (x *XCipher) EncryptToBytes(plainReader io.Reader, additionalData []byte) (
|
|||||||
|
|
||||||
// DecryptToBytes Stream decryption and return []byte
|
// DecryptToBytes Stream decryption and return []byte
|
||||||
func (x *XCipher) DecryptToBytes(encryptedReader io.Reader, additionalData []byte) ([]byte, error) {
|
func (x *XCipher) DecryptToBytes(encryptedReader io.Reader, additionalData []byte) ([]byte, error) {
|
||||||
// Read and verify format header first
|
// 读取整个加密数据
|
||||||
header := make([]byte, headerSize)
|
encryptedData, err := io.ReadAll(encryptedReader)
|
||||||
if _, err := io.ReadFull(encryptedReader, header); err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("read header failed: %w", err)
|
return nil, fmt.Errorf("read data failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify magic number and version
|
if len(encryptedData) < headerSize {
|
||||||
if err := verifyDataFormat(header); err != nil {
|
return nil, ErrInvalidFormat
|
||||||
return nil, err
|
}
|
||||||
|
|
||||||
|
magic := binary.BigEndian.Uint32(encryptedData[0:4])
|
||||||
|
if magic != magicNumber {
|
||||||
|
return nil, ErrInvalidMagicNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
version := binary.BigEndian.Uint32(encryptedData[4:8])
|
||||||
|
if version != currentVersion {
|
||||||
|
return nil, ErrUnsupportedVersion
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a memory buffer
|
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
|
|
||||||
// Streaming decryption is performed, and the results are written to the BUFF
|
if err := x.DecryptStream(
|
||||||
if err := x.DecryptStream(encryptedReader, buf, additionalData); err != nil {
|
bytes.NewReader(encryptedData[headerSize:]),
|
||||||
|
buf,
|
||||||
|
additionalData,
|
||||||
|
); err != nil {
|
||||||
return nil, fmt.Errorf("decrypt failed: %w", err)
|
return nil, fmt.Errorf("decrypt failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns bytes directly in the buffer
|
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add helper function to verify data format integrity
|
|
||||||
func verifyDataFormat(data []byte) error {
|
|
||||||
if len(data) < headerSize+minCiphertextSize {
|
|
||||||
return ErrInvalidFormat
|
|
||||||
}
|
|
||||||
|
|
||||||
magic := binary.BigEndian.Uint32(data[0:4])
|
|
||||||
if magic != magicNumber {
|
|
||||||
return ErrInvalidMagicNumber
|
|
||||||
}
|
|
||||||
|
|
||||||
version := binary.BigEndian.Uint32(data[4:8])
|
|
||||||
if version != currentVersion {
|
|
||||||
return ErrUnsupportedVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
260
xcipher_test.go
260
xcipher_test.go
@@ -7,7 +7,9 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -642,3 +644,261 @@ func TestAutoParallelDecision(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestNetworkImageStreamProcessing tests encrypting and decrypting an image from network stream
|
||||||
|
func TestNetworkImageStreamProcessing(t *testing.T) {
|
||||||
|
// Generate random key
|
||||||
|
key, err := generateRandomKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize cipher
|
||||||
|
xcipher := NewXCipher(key)
|
||||||
|
|
||||||
|
// Create output directory if not exists
|
||||||
|
outputDir := "testdata"
|
||||||
|
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||||
|
t.Fatalf("Failed to create output directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define file paths in local directory
|
||||||
|
originalPath := filepath.Join(outputDir, "original.jpg")
|
||||||
|
encryptedPath := filepath.Join(outputDir, "encrypted.bin")
|
||||||
|
decryptedPath := filepath.Join(outputDir, "decrypted.jpg")
|
||||||
|
|
||||||
|
// Download image from URL
|
||||||
|
imageURL := "https://cdn.picui.cn/vip/2025/03/20/67dbc6154b20f.jpg"
|
||||||
|
resp, err := http.Get(imageURL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to download image: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Save original image
|
||||||
|
originalFile, err := os.Create(originalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create original file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a TeeReader to save original image while reading
|
||||||
|
imageReader := io.TeeReader(resp.Body, originalFile)
|
||||||
|
|
||||||
|
// Create encrypted output file
|
||||||
|
encryptedFile, err := os.Create(encryptedPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create encrypted file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用简单的 EncryptStream 方法加密
|
||||||
|
err = xcipher.EncryptStream(imageReader, encryptedFile, []byte("123456"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to encrypt image stream: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close files
|
||||||
|
originalFile.Close()
|
||||||
|
encryptedFile.Close()
|
||||||
|
|
||||||
|
t.Logf("Original image saved to: %s", originalPath)
|
||||||
|
t.Logf("Encrypted file saved to: %s", encryptedPath)
|
||||||
|
|
||||||
|
// Open encrypted file for reading
|
||||||
|
encryptedFile, err = os.Open(encryptedPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open encrypted file: %v", err)
|
||||||
|
}
|
||||||
|
defer encryptedFile.Close()
|
||||||
|
|
||||||
|
// Create decrypted output file
|
||||||
|
decryptedFile, err := os.Create(decryptedPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create decrypted file: %v", err)
|
||||||
|
}
|
||||||
|
defer decryptedFile.Close()
|
||||||
|
|
||||||
|
err = xcipher.DecryptStream(encryptedFile, decryptedFile, []byte("123456"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to decrypt image stream: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Decrypted file saved to: %s", decryptedPath)
|
||||||
|
|
||||||
|
// Get file sizes
|
||||||
|
originalInfo, err := os.Stat(originalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to stat original file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encryptedInfo, err := os.Stat(encryptedPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to stat encrypted file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptedInfo, err := os.Stat(decryptedPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to stat decrypted file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print file sizes
|
||||||
|
t.Logf("File sizes:")
|
||||||
|
t.Logf("- Original: %d bytes", originalInfo.Size())
|
||||||
|
t.Logf("- Encrypted: %d bytes", encryptedInfo.Size())
|
||||||
|
t.Logf("- Decrypted: %d bytes", decryptedInfo.Size())
|
||||||
|
|
||||||
|
// Verify the decrypted file matches the original
|
||||||
|
originalData, err := os.ReadFile(originalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read original file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptedData, err := os.ReadFile(decryptedPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read decrypted file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(originalData, decryptedData) {
|
||||||
|
t.Fatal("Decrypted file does not match original file")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a valid JPEG file by checking signature
|
||||||
|
if len(decryptedData) < 2 || decryptedData[0] != 0xFF || decryptedData[1] != 0xD8 {
|
||||||
|
t.Fatal("Decrypted file is not a valid JPEG image")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Successfully verified: decrypted file matches original and is a valid JPEG image")
|
||||||
|
|
||||||
|
encryptedData, err := os.ReadFile(encryptedPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read encrypted file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var decryptedBuffer bytes.Buffer
|
||||||
|
|
||||||
|
encryptedReader := bytes.NewReader(encryptedData)
|
||||||
|
|
||||||
|
err = xcipher.DecryptStream(encryptedReader, &decryptedBuffer, []byte("123456"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to decrypt image stream: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptedBytes := decryptedBuffer.Bytes()
|
||||||
|
|
||||||
|
previewLen := 100
|
||||||
|
if len(decryptedBytes) < previewLen {
|
||||||
|
previewLen = len(decryptedBytes)
|
||||||
|
}
|
||||||
|
t.Logf("Decrypted data preview (first %d bytes): %v", previewLen, decryptedBytes[:previewLen])
|
||||||
|
t.Logf("Total decrypted data length: %d bytes", len(decryptedBytes))
|
||||||
|
|
||||||
|
if len(decryptedBytes) < 2 || decryptedBytes[0] != 0xFF || decryptedBytes[1] != 0xD8 {
|
||||||
|
t.Fatal("Decrypted data is not a valid JPEG image")
|
||||||
|
}
|
||||||
|
|
||||||
|
originalData, err = os.ReadFile(originalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read original file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(originalData, decryptedBytes) {
|
||||||
|
t.Fatal("Decrypted data does not match original file")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Successfully verified: decrypted data matches original and is a valid JPEG image")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptAndDecryptToBytes(t *testing.T) {
|
||||||
|
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
copy(key, "this-is-32-byte-testing-key-data!")
|
||||||
|
|
||||||
|
cipher := NewXCipher(key)
|
||||||
|
|
||||||
|
testData := []byte("Hello, this is test data for encryption!")
|
||||||
|
|
||||||
|
t.Run("First Encryption Test", func(t *testing.T) {
|
||||||
|
reader := bytes.NewReader(testData)
|
||||||
|
encrypted, err := cipher.EncryptToBytes(reader, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("First encryption failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(encrypted) <= headerSize {
|
||||||
|
t.Fatal("First encrypted data too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解密并验证
|
||||||
|
decReader := bytes.NewReader(encrypted)
|
||||||
|
decrypted, err := cipher.DecryptToBytes(decReader, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("First decryption failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(decrypted, testData) {
|
||||||
|
t.Fatal("First decrypted data doesn't match original")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Second Encryption Test", func(t *testing.T) {
|
||||||
|
reader := bytes.NewReader(testData)
|
||||||
|
encrypted2, err := cipher.EncryptToBytes(reader, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Second encryption failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(encrypted2) <= headerSize {
|
||||||
|
t.Fatal("Second encrypted data too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解密并验证
|
||||||
|
decReader := bytes.NewReader(encrypted2)
|
||||||
|
decrypted2, err := cipher.DecryptToBytes(decReader, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Second decryption failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(decrypted2, testData) {
|
||||||
|
t.Fatal("Second decrypted data doesn't match original")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptToBytesWithDifferentSizes(t *testing.T) {
|
||||||
|
|
||||||
|
key := make([]byte, chacha20poly1305.KeySize)
|
||||||
|
copy(key, "this-is-32-byte-testing-key-data!")
|
||||||
|
|
||||||
|
cipher := NewXCipher(key)
|
||||||
|
|
||||||
|
testSizes := []int{
|
||||||
|
10, // 10 bytes
|
||||||
|
1024, // 1KB
|
||||||
|
64 * 1024, // 64KB
|
||||||
|
1024 * 1024, // 1MB
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, size := range testSizes {
|
||||||
|
t.Run(fmt.Sprintf("Size_%d", size), func(t *testing.T) {
|
||||||
|
|
||||||
|
testData := make([]byte, size)
|
||||||
|
_, err := rand.Read(testData)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate test data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bytes.NewReader(testData)
|
||||||
|
encrypted, err := cipher.EncryptToBytes(reader, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Encryption failed for size %d: %v", size, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decReader := bytes.NewReader(encrypted)
|
||||||
|
decrypted, err := cipher.DecryptToBytes(decReader, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Decryption failed for size %d: %v", size, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(decrypted, testData) {
|
||||||
|
t.Fatalf("Data mismatch for size %d", size)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user