diff --git a/xcipher.go b/xcipher.go index 7396c20..610c705 100644 --- a/xcipher.go +++ b/xcipher.go @@ -1524,44 +1524,35 @@ func (x *XCipher) EncryptToBytes(plainReader io.Reader, additionalData []byte) ( // DecryptToBytes Stream decryption and return []byte func (x *XCipher) DecryptToBytes(encryptedReader io.Reader, additionalData []byte) ([]byte, error) { - // Read and verify format header first - header := make([]byte, headerSize) - if _, err := io.ReadFull(encryptedReader, header); err != nil { - return nil, fmt.Errorf("read header failed: %w", err) + // 读取整个加密数据 + encryptedData, err := io.ReadAll(encryptedReader) + if err != nil { + return nil, fmt.Errorf("read data failed: %w", err) } - // Verify magic number and version - if err := verifyDataFormat(header); err != nil { - return nil, err + if len(encryptedData) < headerSize { + return nil, ErrInvalidFormat + } + + magic := binary.BigEndian.Uint32(encryptedData[0:4]) + if magic != magicNumber { + return nil, ErrInvalidMagicNumber + } + + version := binary.BigEndian.Uint32(encryptedData[4:8]) + if version != currentVersion { + return nil, ErrUnsupportedVersion } - // Create a memory buffer buf := new(bytes.Buffer) - // Streaming decryption is performed, and the results are written to the BUFF - if err := x.DecryptStream(encryptedReader, buf, additionalData); err != nil { + if err := x.DecryptStream( + bytes.NewReader(encryptedData[headerSize:]), + buf, + additionalData, + ); err != nil { return nil, fmt.Errorf("decrypt failed: %w", err) } - // Returns bytes directly in the buffer return buf.Bytes(), nil } - -// Add helper function to verify data format integrity -func verifyDataFormat(data []byte) error { - if len(data) < headerSize+minCiphertextSize { - return ErrInvalidFormat - } - - magic := binary.BigEndian.Uint32(data[0:4]) - if magic != magicNumber { - return ErrInvalidMagicNumber - } - - version := binary.BigEndian.Uint32(data[4:8]) - if version != currentVersion { - return ErrUnsupportedVersion - } - - return nil -} diff --git a/xcipher_test.go b/xcipher_test.go index e30f161..071bb9a 100644 --- a/xcipher_test.go +++ b/xcipher_test.go @@ -7,7 +7,9 @@ import ( "errors" "fmt" "golang.org/x/crypto/chacha20poly1305" + "io" "io/ioutil" + "net/http" "os" "path/filepath" "testing" @@ -642,3 +644,261 @@ func TestAutoParallelDecision(t *testing.T) { }) } } + +// TestNetworkImageStreamProcessing tests encrypting and decrypting an image from network stream +func TestNetworkImageStreamProcessing(t *testing.T) { + // Generate random key + key, err := generateRandomKey() + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + + // Initialize cipher + xcipher := NewXCipher(key) + + // Create output directory if not exists + outputDir := "testdata" + if err := os.MkdirAll(outputDir, 0755); err != nil { + t.Fatalf("Failed to create output directory: %v", err) + } + + // Define file paths in local directory + originalPath := filepath.Join(outputDir, "original.jpg") + encryptedPath := filepath.Join(outputDir, "encrypted.bin") + decryptedPath := filepath.Join(outputDir, "decrypted.jpg") + + // Download image from URL + imageURL := "https://cdn.picui.cn/vip/2025/03/20/67dbc6154b20f.jpg" + resp, err := http.Get(imageURL) + if err != nil { + t.Fatalf("Failed to download image: %v", err) + } + defer resp.Body.Close() + + // Save original image + originalFile, err := os.Create(originalPath) + if err != nil { + t.Fatalf("Failed to create original file: %v", err) + } + + // Create a TeeReader to save original image while reading + imageReader := io.TeeReader(resp.Body, originalFile) + + // Create encrypted output file + encryptedFile, err := os.Create(encryptedPath) + if err != nil { + t.Fatalf("Failed to create encrypted file: %v", err) + } + + // 使用简单的 EncryptStream 方法加密 + err = xcipher.EncryptStream(imageReader, encryptedFile, []byte("123456")) + if err != nil { + t.Fatalf("Failed to encrypt image stream: %v", err) + } + + // Close files + originalFile.Close() + encryptedFile.Close() + + t.Logf("Original image saved to: %s", originalPath) + t.Logf("Encrypted file saved to: %s", encryptedPath) + + // Open encrypted file for reading + encryptedFile, err = os.Open(encryptedPath) + if err != nil { + t.Fatalf("Failed to open encrypted file: %v", err) + } + defer encryptedFile.Close() + + // Create decrypted output file + decryptedFile, err := os.Create(decryptedPath) + if err != nil { + t.Fatalf("Failed to create decrypted file: %v", err) + } + defer decryptedFile.Close() + + err = xcipher.DecryptStream(encryptedFile, decryptedFile, []byte("123456")) + if err != nil { + t.Fatalf("Failed to decrypt image stream: %v", err) + } + + t.Logf("Decrypted file saved to: %s", decryptedPath) + + // Get file sizes + originalInfo, err := os.Stat(originalPath) + if err != nil { + t.Fatalf("Failed to stat original file: %v", err) + } + + encryptedInfo, err := os.Stat(encryptedPath) + if err != nil { + t.Fatalf("Failed to stat encrypted file: %v", err) + } + + decryptedInfo, err := os.Stat(decryptedPath) + if err != nil { + t.Fatalf("Failed to stat decrypted file: %v", err) + } + + // Print file sizes + t.Logf("File sizes:") + t.Logf("- Original: %d bytes", originalInfo.Size()) + t.Logf("- Encrypted: %d bytes", encryptedInfo.Size()) + t.Logf("- Decrypted: %d bytes", decryptedInfo.Size()) + + // Verify the decrypted file matches the original + originalData, err := os.ReadFile(originalPath) + if err != nil { + t.Fatalf("Failed to read original file: %v", err) + } + + decryptedData, err := os.ReadFile(decryptedPath) + if err != nil { + t.Fatalf("Failed to read decrypted file: %v", err) + } + + if !bytes.Equal(originalData, decryptedData) { + t.Fatal("Decrypted file does not match original file") + } + + // Check if it's a valid JPEG file by checking signature + if len(decryptedData) < 2 || decryptedData[0] != 0xFF || decryptedData[1] != 0xD8 { + t.Fatal("Decrypted file is not a valid JPEG image") + } + + t.Log("Successfully verified: decrypted file matches original and is a valid JPEG image") + + encryptedData, err := os.ReadFile(encryptedPath) + if err != nil { + t.Fatalf("Failed to read encrypted file: %v", err) + } + + var decryptedBuffer bytes.Buffer + + encryptedReader := bytes.NewReader(encryptedData) + + err = xcipher.DecryptStream(encryptedReader, &decryptedBuffer, []byte("123456")) + if err != nil { + t.Fatalf("Failed to decrypt image stream: %v", err) + } + + decryptedBytes := decryptedBuffer.Bytes() + + previewLen := 100 + if len(decryptedBytes) < previewLen { + previewLen = len(decryptedBytes) + } + t.Logf("Decrypted data preview (first %d bytes): %v", previewLen, decryptedBytes[:previewLen]) + t.Logf("Total decrypted data length: %d bytes", len(decryptedBytes)) + + if len(decryptedBytes) < 2 || decryptedBytes[0] != 0xFF || decryptedBytes[1] != 0xD8 { + t.Fatal("Decrypted data is not a valid JPEG image") + } + + originalData, err = os.ReadFile(originalPath) + if err != nil { + t.Fatalf("Failed to read original file: %v", err) + } + + if !bytes.Equal(originalData, decryptedBytes) { + t.Fatal("Decrypted data does not match original file") + } + + t.Log("Successfully verified: decrypted data matches original and is a valid JPEG image") +} + +func TestEncryptAndDecryptToBytes(t *testing.T) { + + key := make([]byte, chacha20poly1305.KeySize) + copy(key, "this-is-32-byte-testing-key-data!") + + cipher := NewXCipher(key) + + testData := []byte("Hello, this is test data for encryption!") + + t.Run("First Encryption Test", func(t *testing.T) { + reader := bytes.NewReader(testData) + encrypted, err := cipher.EncryptToBytes(reader, nil) + if err != nil { + t.Fatalf("First encryption failed: %v", err) + } + if len(encrypted) <= headerSize { + t.Fatal("First encrypted data too short") + } + + // 解密并验证 + decReader := bytes.NewReader(encrypted) + decrypted, err := cipher.DecryptToBytes(decReader, nil) + if err != nil { + t.Fatalf("First decryption failed: %v", err) + } + + if !bytes.Equal(decrypted, testData) { + t.Fatal("First decrypted data doesn't match original") + } + }) + + t.Run("Second Encryption Test", func(t *testing.T) { + reader := bytes.NewReader(testData) + encrypted2, err := cipher.EncryptToBytes(reader, nil) + if err != nil { + t.Fatalf("Second encryption failed: %v", err) + } + if len(encrypted2) <= headerSize { + t.Fatal("Second encrypted data too short") + } + + // 解密并验证 + decReader := bytes.NewReader(encrypted2) + decrypted2, err := cipher.DecryptToBytes(decReader, nil) + if err != nil { + t.Fatalf("Second decryption failed: %v", err) + } + + if !bytes.Equal(decrypted2, testData) { + t.Fatal("Second decrypted data doesn't match original") + } + }) +} + +func TestEncryptToBytesWithDifferentSizes(t *testing.T) { + + key := make([]byte, chacha20poly1305.KeySize) + copy(key, "this-is-32-byte-testing-key-data!") + + cipher := NewXCipher(key) + + testSizes := []int{ + 10, // 10 bytes + 1024, // 1KB + 64 * 1024, // 64KB + 1024 * 1024, // 1MB + } + + for _, size := range testSizes { + t.Run(fmt.Sprintf("Size_%d", size), func(t *testing.T) { + + testData := make([]byte, size) + _, err := rand.Read(testData) + if err != nil { + t.Fatalf("Failed to generate test data: %v", err) + } + + reader := bytes.NewReader(testData) + encrypted, err := cipher.EncryptToBytes(reader, nil) + if err != nil { + t.Fatalf("Encryption failed for size %d: %v", size, err) + } + + decReader := bytes.NewReader(encrypted) + decrypted, err := cipher.DecryptToBytes(decReader, nil) + if err != nil { + t.Fatalf("Decryption failed for size %d: %v", size, err) + } + + if !bytes.Equal(decrypted, testData) { + t.Fatalf("Data mismatch for size %d", size) + } + }) + } +}