🐛 Fixed the DecryptToBytes verification issue
This commit is contained in:
51
xcipher.go
51
xcipher.go
@@ -1524,44 +1524,35 @@ func (x *XCipher) EncryptToBytes(plainReader io.Reader, additionalData []byte) (
|
||||
|
||||
// DecryptToBytes Stream decryption and return []byte
|
||||
func (x *XCipher) DecryptToBytes(encryptedReader io.Reader, additionalData []byte) ([]byte, error) {
|
||||
// Read and verify format header first
|
||||
header := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(encryptedReader, header); err != nil {
|
||||
return nil, fmt.Errorf("read header failed: %w", err)
|
||||
// 读取整个加密数据
|
||||
encryptedData, err := io.ReadAll(encryptedReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read data failed: %w", err)
|
||||
}
|
||||
|
||||
// Verify magic number and version
|
||||
if err := verifyDataFormat(header); err != nil {
|
||||
return nil, err
|
||||
if len(encryptedData) < headerSize {
|
||||
return nil, ErrInvalidFormat
|
||||
}
|
||||
|
||||
magic := binary.BigEndian.Uint32(encryptedData[0:4])
|
||||
if magic != magicNumber {
|
||||
return nil, ErrInvalidMagicNumber
|
||||
}
|
||||
|
||||
version := binary.BigEndian.Uint32(encryptedData[4:8])
|
||||
if version != currentVersion {
|
||||
return nil, ErrUnsupportedVersion
|
||||
}
|
||||
|
||||
// Create a memory buffer
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
// Streaming decryption is performed, and the results are written to the BUFF
|
||||
if err := x.DecryptStream(encryptedReader, buf, additionalData); err != nil {
|
||||
if err := x.DecryptStream(
|
||||
bytes.NewReader(encryptedData[headerSize:]),
|
||||
buf,
|
||||
additionalData,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("decrypt failed: %w", err)
|
||||
}
|
||||
|
||||
// Returns bytes directly in the buffer
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Add helper function to verify data format integrity
|
||||
func verifyDataFormat(data []byte) error {
|
||||
if len(data) < headerSize+minCiphertextSize {
|
||||
return ErrInvalidFormat
|
||||
}
|
||||
|
||||
magic := binary.BigEndian.Uint32(data[0:4])
|
||||
if magic != magicNumber {
|
||||
return ErrInvalidMagicNumber
|
||||
}
|
||||
|
||||
version := binary.BigEndian.Uint32(data[4:8])
|
||||
if version != currentVersion {
|
||||
return ErrUnsupportedVersion
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
260
xcipher_test.go
260
xcipher_test.go
@@ -7,7 +7,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -642,3 +644,261 @@ func TestAutoParallelDecision(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNetworkImageStreamProcessing tests encrypting and decrypting an image from network stream
|
||||
func TestNetworkImageStreamProcessing(t *testing.T) {
|
||||
// Generate random key
|
||||
key, err := generateRandomKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
// Initialize cipher
|
||||
xcipher := NewXCipher(key)
|
||||
|
||||
// Create output directory if not exists
|
||||
outputDir := "testdata"
|
||||
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create output directory: %v", err)
|
||||
}
|
||||
|
||||
// Define file paths in local directory
|
||||
originalPath := filepath.Join(outputDir, "original.jpg")
|
||||
encryptedPath := filepath.Join(outputDir, "encrypted.bin")
|
||||
decryptedPath := filepath.Join(outputDir, "decrypted.jpg")
|
||||
|
||||
// Download image from URL
|
||||
imageURL := "https://cdn.picui.cn/vip/2025/03/20/67dbc6154b20f.jpg"
|
||||
resp, err := http.Get(imageURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to download image: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Save original image
|
||||
originalFile, err := os.Create(originalPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create original file: %v", err)
|
||||
}
|
||||
|
||||
// Create a TeeReader to save original image while reading
|
||||
imageReader := io.TeeReader(resp.Body, originalFile)
|
||||
|
||||
// Create encrypted output file
|
||||
encryptedFile, err := os.Create(encryptedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create encrypted file: %v", err)
|
||||
}
|
||||
|
||||
// 使用简单的 EncryptStream 方法加密
|
||||
err = xcipher.EncryptStream(imageReader, encryptedFile, []byte("123456"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encrypt image stream: %v", err)
|
||||
}
|
||||
|
||||
// Close files
|
||||
originalFile.Close()
|
||||
encryptedFile.Close()
|
||||
|
||||
t.Logf("Original image saved to: %s", originalPath)
|
||||
t.Logf("Encrypted file saved to: %s", encryptedPath)
|
||||
|
||||
// Open encrypted file for reading
|
||||
encryptedFile, err = os.Open(encryptedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open encrypted file: %v", err)
|
||||
}
|
||||
defer encryptedFile.Close()
|
||||
|
||||
// Create decrypted output file
|
||||
decryptedFile, err := os.Create(decryptedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create decrypted file: %v", err)
|
||||
}
|
||||
defer decryptedFile.Close()
|
||||
|
||||
err = xcipher.DecryptStream(encryptedFile, decryptedFile, []byte("123456"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt image stream: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Decrypted file saved to: %s", decryptedPath)
|
||||
|
||||
// Get file sizes
|
||||
originalInfo, err := os.Stat(originalPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to stat original file: %v", err)
|
||||
}
|
||||
|
||||
encryptedInfo, err := os.Stat(encryptedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to stat encrypted file: %v", err)
|
||||
}
|
||||
|
||||
decryptedInfo, err := os.Stat(decryptedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to stat decrypted file: %v", err)
|
||||
}
|
||||
|
||||
// Print file sizes
|
||||
t.Logf("File sizes:")
|
||||
t.Logf("- Original: %d bytes", originalInfo.Size())
|
||||
t.Logf("- Encrypted: %d bytes", encryptedInfo.Size())
|
||||
t.Logf("- Decrypted: %d bytes", decryptedInfo.Size())
|
||||
|
||||
// Verify the decrypted file matches the original
|
||||
originalData, err := os.ReadFile(originalPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read original file: %v", err)
|
||||
}
|
||||
|
||||
decryptedData, err := os.ReadFile(decryptedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read decrypted file: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(originalData, decryptedData) {
|
||||
t.Fatal("Decrypted file does not match original file")
|
||||
}
|
||||
|
||||
// Check if it's a valid JPEG file by checking signature
|
||||
if len(decryptedData) < 2 || decryptedData[0] != 0xFF || decryptedData[1] != 0xD8 {
|
||||
t.Fatal("Decrypted file is not a valid JPEG image")
|
||||
}
|
||||
|
||||
t.Log("Successfully verified: decrypted file matches original and is a valid JPEG image")
|
||||
|
||||
encryptedData, err := os.ReadFile(encryptedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read encrypted file: %v", err)
|
||||
}
|
||||
|
||||
var decryptedBuffer bytes.Buffer
|
||||
|
||||
encryptedReader := bytes.NewReader(encryptedData)
|
||||
|
||||
err = xcipher.DecryptStream(encryptedReader, &decryptedBuffer, []byte("123456"))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decrypt image stream: %v", err)
|
||||
}
|
||||
|
||||
decryptedBytes := decryptedBuffer.Bytes()
|
||||
|
||||
previewLen := 100
|
||||
if len(decryptedBytes) < previewLen {
|
||||
previewLen = len(decryptedBytes)
|
||||
}
|
||||
t.Logf("Decrypted data preview (first %d bytes): %v", previewLen, decryptedBytes[:previewLen])
|
||||
t.Logf("Total decrypted data length: %d bytes", len(decryptedBytes))
|
||||
|
||||
if len(decryptedBytes) < 2 || decryptedBytes[0] != 0xFF || decryptedBytes[1] != 0xD8 {
|
||||
t.Fatal("Decrypted data is not a valid JPEG image")
|
||||
}
|
||||
|
||||
originalData, err = os.ReadFile(originalPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read original file: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(originalData, decryptedBytes) {
|
||||
t.Fatal("Decrypted data does not match original file")
|
||||
}
|
||||
|
||||
t.Log("Successfully verified: decrypted data matches original and is a valid JPEG image")
|
||||
}
|
||||
|
||||
func TestEncryptAndDecryptToBytes(t *testing.T) {
|
||||
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
copy(key, "this-is-32-byte-testing-key-data!")
|
||||
|
||||
cipher := NewXCipher(key)
|
||||
|
||||
testData := []byte("Hello, this is test data for encryption!")
|
||||
|
||||
t.Run("First Encryption Test", func(t *testing.T) {
|
||||
reader := bytes.NewReader(testData)
|
||||
encrypted, err := cipher.EncryptToBytes(reader, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("First encryption failed: %v", err)
|
||||
}
|
||||
if len(encrypted) <= headerSize {
|
||||
t.Fatal("First encrypted data too short")
|
||||
}
|
||||
|
||||
// 解密并验证
|
||||
decReader := bytes.NewReader(encrypted)
|
||||
decrypted, err := cipher.DecryptToBytes(decReader, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("First decryption failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(decrypted, testData) {
|
||||
t.Fatal("First decrypted data doesn't match original")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Second Encryption Test", func(t *testing.T) {
|
||||
reader := bytes.NewReader(testData)
|
||||
encrypted2, err := cipher.EncryptToBytes(reader, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Second encryption failed: %v", err)
|
||||
}
|
||||
if len(encrypted2) <= headerSize {
|
||||
t.Fatal("Second encrypted data too short")
|
||||
}
|
||||
|
||||
// 解密并验证
|
||||
decReader := bytes.NewReader(encrypted2)
|
||||
decrypted2, err := cipher.DecryptToBytes(decReader, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Second decryption failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(decrypted2, testData) {
|
||||
t.Fatal("Second decrypted data doesn't match original")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEncryptToBytesWithDifferentSizes(t *testing.T) {
|
||||
|
||||
key := make([]byte, chacha20poly1305.KeySize)
|
||||
copy(key, "this-is-32-byte-testing-key-data!")
|
||||
|
||||
cipher := NewXCipher(key)
|
||||
|
||||
testSizes := []int{
|
||||
10, // 10 bytes
|
||||
1024, // 1KB
|
||||
64 * 1024, // 64KB
|
||||
1024 * 1024, // 1MB
|
||||
}
|
||||
|
||||
for _, size := range testSizes {
|
||||
t.Run(fmt.Sprintf("Size_%d", size), func(t *testing.T) {
|
||||
|
||||
testData := make([]byte, size)
|
||||
_, err := rand.Read(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate test data: %v", err)
|
||||
}
|
||||
|
||||
reader := bytes.NewReader(testData)
|
||||
encrypted, err := cipher.EncryptToBytes(reader, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption failed for size %d: %v", size, err)
|
||||
}
|
||||
|
||||
decReader := bytes.NewReader(encrypted)
|
||||
decrypted, err := cipher.DecryptToBytes(decReader, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Decryption failed for size %d: %v", size, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(decrypted, testData) {
|
||||
t.Fatalf("Data mismatch for size %d", size)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user