datafusion/dataframe/parquet.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use crate::datasource::file_format::{
    format_as_file_type, parquet::ParquetFormatFactory,
};

use super::{
    DataFrame, DataFrameWriteOptions, DataFusionError, LogicalPlanBuilder, RecordBatch,
};

use datafusion_common::config::TableParquetOptions;
use datafusion_common::not_impl_err;
use datafusion_expr::dml::InsertOp;

impl DataFrame {
    /// Execute the `DataFrame` and write the results to Parquet file(s).
    ///
    /// # Example
    /// ```
    /// # use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # use std::fs;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// use datafusion::dataframe::DataFrameWriteOptions;
    /// let ctx = SessionContext::new();
    /// // Sort the data by column "b" and write it to a new location
    /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new())
    ///     .await?
    ///     .sort(vec![col("b").sort(true, true)])? // sort by b asc, nulls first
    ///     .write_parquet(
    ///         "output.parquet",
    ///         DataFrameWriteOptions::new(),
    ///         None, // can also specify parquet writing options here
    ///     )
    ///     .await?;
    /// # fs::remove_file("output.parquet")?;
    /// # Ok(())
    /// # }
    /// ```
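    ///
    /// The third argument customizes the Parquet writer itself. A minimal
    /// sketch that sets the compression codec through [`TableParquetOptions`]
    /// (the `"zstd(1)"` codec string is only an illustration; any codec string
    /// accepted by the Parquet writer options can be used):
    ///
    /// ```
    /// # use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # use std::fs;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// use datafusion::dataframe::DataFrameWriteOptions;
    /// use datafusion::config::TableParquetOptions;
    /// let ctx = SessionContext::new();
    /// // Configure the writer: here, compress the output with zstd level 1
    /// let mut parquet_options = TableParquetOptions::default();
    /// parquet_options.global.compression = Some("zstd(1)".to_string());
    /// ctx.read_csv("tests/data/example.csv", CsvReadOptions::new())
    ///     .await?
    ///     .write_parquet(
    ///         "output_compressed.parquet",
    ///         DataFrameWriteOptions::new(),
    ///         Some(parquet_options),
    ///     )
    ///     .await?;
    /// # fs::remove_file("output_compressed.parquet")?;
    /// # Ok(())
    /// # }
    /// ```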
    pub async fn write_parquet(
        self,
        path: &str,
        options: DataFrameWriteOptions,
        writer_options: Option<TableParquetOptions>,
    ) -> Result<Vec<RecordBatch>, DataFusionError> {
        if options.insert_op != InsertOp::Append {
            return not_impl_err!(
                "{} is not implemented for DataFrame::write_parquet.",
                options.insert_op
            );
        }

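        // Build the Parquet format factory, honoring any writer options
        // supplied by the caller.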
        let format = if let Some(parquet_opts) = writer_options {
            Arc::new(ParquetFormatFactory::new_with_options(parquet_opts))
        } else {
            Arc::new(ParquetFormatFactory::new())
        };

        let file_type = format_as_file_type(format);

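        // Apply the requested output ordering, if any, before writing.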
        let plan = if options.sort_by.is_empty() {
            self.plan
        } else {
            LogicalPlanBuilder::from(self.plan)
                .sort(options.sort_by)?
                .build()?
        };

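        // Wrap the plan in a `CopyTo` node targeting `path`, then execute it
        // by collecting the resulting DataFrame.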
        let plan = LogicalPlanBuilder::copy_to(
            plan,
            path.into(),
            file_type,
            Default::default(),
            options.partition_by,
        )?
        .build()?;
        DataFrame {
            session_state: self.session_state,
            plan,
            projection_requires_validation: self.projection_requires_validation,
        }
        .collect()
        .await
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::super::Result;
    use super::*;
    use crate::arrow::util::pretty;
    use crate::execution::context::SessionContext;
    use crate::execution::options::ParquetReadOptions;
    use crate::test_util::{self, register_aggregate_csv};

    use datafusion_common::file_options::parquet_writer::parse_compression_string;
    use datafusion_execution::config::SessionConfig;
    use datafusion_expr::{col, lit};

    #[cfg(feature = "parquet_encryption")]
    use datafusion_common::config::ConfigFileEncryptionProperties;
    use object_store::local::LocalFileSystem;
    use parquet::file::reader::FileReader;
    use tempfile::TempDir;
    use url::Url;

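    // A filter applied to a view over a Parquet table should survive into the
    // physical plan (visible as a `FilterExec` in the EXPLAIN output).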
    #[tokio::test]
    async fn filter_pushdown_dataframe() -> Result<()> {
        let ctx = SessionContext::new();

        ctx.register_parquet(
            "test",
            &format!(
                "{}/alltypes_plain.snappy.parquet",
                test_util::parquet_test_data()
            ),
            ParquetReadOptions::default(),
        )
        .await?;

        ctx.register_table("t1", ctx.table("test").await?.into_view())?;

        let df = ctx
            .table("t1")
            .await?
            .filter(col("id").eq(lit(1)))?
            .select_columns(&["bool_col", "int_col"])?;

        let plan = df.explain(false, false)?.collect().await?;
        // Filters all the way to Parquet
        let formatted = pretty::pretty_format_batches(&plan)?.to_string();
        assert!(formatted.contains("FilterExec: id@0 = 1"));

        Ok(())
    }

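    // Write with each tested compression codec and confirm that the file
    // metadata records the codec that was requested.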
    #[tokio::test]
    async fn write_parquet_with_compression() -> Result<()> {
        let test_df = test_util::test_table().await?;
        let output_path = "file://local/test.parquet";
        let test_compressions = vec![
            "snappy",
            "brotli(1)",
            "lz4",
            "lz4_raw",
            "gzip(6)",
            "zstd(1)",
        ];
        for compression in test_compressions.into_iter() {
            let df = test_df.clone();
            let tmp_dir = TempDir::new()?;
            let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
            let local_url = Url::parse("file://local").unwrap();
            let ctx = &test_df.session_state;
            ctx.runtime_env().register_object_store(&local_url, local);
            let mut options = TableParquetOptions::default();
            options.global.compression = Some(compression.to_string());
            df.write_parquet(
                output_path,
                DataFrameWriteOptions::new().with_single_file_output(true),
                Some(options),
            )
            .await?;

            // Check that file actually used the specified compression
            let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?;

            let reader =
                parquet::file::serialized_reader::SerializedFileReader::new(file)
                    .unwrap();

            let parquet_metadata = reader.metadata();

            let written_compression =
                parquet_metadata.row_group(0).column(0).compression();

            assert_eq!(written_compression, parse_compression_string(compression)?);
        }

        Ok(())
    }

    #[tokio::test]
    async fn write_parquet_with_small_rg_size() -> Result<()> {
        // This test verifies that writing a parquet file with a row group size
        // that is small relative to datafusion.execution.batch_size does not panic
        let ctx = SessionContext::new_with_config(SessionConfig::from_string_hash_map(
            &HashMap::from_iter(
                [("datafusion.execution.batch_size", "10")]
                    .iter()
                    .map(|(s1, s2)| ((*s1).to_string(), (*s2).to_string())),
            ),
        )?);
        register_aggregate_csv(&ctx, "aggregate_test_100").await?;
        let test_df = ctx.table("aggregate_test_100").await?;

        let output_path = "file://local/test.parquet";

        for rg_size in 1..10 {
            let df = test_df.clone();
            let tmp_dir = TempDir::new()?;
            let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
            let local_url = Url::parse("file://local").unwrap();
            let ctx = &test_df.session_state;
            ctx.runtime_env().register_object_store(&local_url, local);
            let mut options = TableParquetOptions::default();
            options.global.max_row_group_size = rg_size;
            options.global.allow_single_file_parallelism = true;
            df.write_parquet(
                output_path,
                DataFrameWriteOptions::new().with_single_file_output(true),
                Some(options),
            )
            .await?;

            // Check that file actually used the correct rg size
            let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?;

            let reader =
                parquet::file::serialized_reader::SerializedFileReader::new(file)
                    .unwrap();

            let parquet_metadata = reader.metadata();

            let written_rows = parquet_metadata.row_group(0).num_rows();

            assert_eq!(written_rows as usize, rg_size);
        }

        Ok(())
    }

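    // Round-trip an encrypted Parquet file: write with footer and per-column
    // keys, then read it back with matching decryption properties (with and
    // without single-file parallelism) and verify row counts and filtering.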
    #[rstest::rstest]
    #[cfg(feature = "parquet_encryption")]
    #[tokio::test]
    async fn roundtrip_parquet_with_encryption(
        #[values(false, true)] allow_single_file_parallelism: bool,
    ) -> Result<()> {
        use parquet::encryption::decrypt::FileDecryptionProperties;
        use parquet::encryption::encrypt::FileEncryptionProperties;

        let test_df = test_util::test_table().await?;

        let schema = test_df.schema();
        let footer_key = b"0123456789012345".to_vec(); // 128-bit (16 byte) key
        let column_key = b"1234567890123450".to_vec(); // 128-bit (16 byte) key

        let mut encrypt = FileEncryptionProperties::builder(footer_key.clone());
        let mut decrypt = FileDecryptionProperties::builder(footer_key.clone());

        for field in schema.fields().iter() {
            encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone());
            decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone());
        }

        let encrypt = encrypt.build()?;
        let decrypt = decrypt.build()?;

        let df = test_df.clone();
        let tmp_dir = TempDir::new()?;
        let tempfile = tmp_dir.path().join("roundtrip.parquet");
        let tempfile_str = tempfile.into_os_string().into_string().unwrap();

        // Write encrypted parquet using write_parquet
        let mut options = TableParquetOptions::default();
        options.crypto.file_encryption =
            Some(ConfigFileEncryptionProperties::from(&encrypt));
        options.global.allow_single_file_parallelism = allow_single_file_parallelism;

        df.write_parquet(
            tempfile_str.as_str(),
            DataFrameWriteOptions::new().with_single_file_output(true),
            Some(options),
        )
        .await?;
        let num_rows_written = test_df.count().await?;

        // Read encrypted parquet
        let ctx: SessionContext = SessionContext::new();
        let read_options =
            ParquetReadOptions::default().file_decryption_properties((&decrypt).into());

        ctx.register_parquet("roundtrip_parquet", &tempfile_str, read_options.clone())
            .await?;

        let df_enc = ctx.sql("SELECT * FROM roundtrip_parquet").await?;
        let num_rows_read = df_enc.count().await?;

        assert_eq!(num_rows_read, num_rows_written);

        // Read encrypted parquet and subset rows + columns
        let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?;

        // Select three columns and filter the results
        // Test that the filter works as expected
        let selected = encrypted_parquet_df
            .clone()
            .select_columns(&["c1", "c2", "c3"])?
            .filter(col("c2").gt(lit(4)))?;

        let num_rows_selected = selected.count().await?;
        assert_eq!(num_rows_selected, 14);

        Ok(())
    }
}