datafusion/dataframe/parquet.rs

use std::sync::Arc;

use crate::datasource::file_format::{
    format_as_file_type, parquet::ParquetFormatFactory,
};

use super::{
    DataFrame, DataFrameWriteOptions, DataFusionError, LogicalPlanBuilder, RecordBatch,
};

use datafusion_common::config::TableParquetOptions;
use datafusion_common::not_impl_err;
use datafusion_expr::dml::InsertOp;

impl DataFrame {
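    /// Execute the `LogicalPlan` represented by this `DataFrame` and write
    /// the results to one or more Parquet files at `path`.
    ///
    /// Only [`InsertOp::Append`] is supported; other insert operations
    /// return a "not implemented" error.
    ///
    /// A minimal usage sketch (the `example.csv` input and `output.parquet`
    /// destination are illustrative, not shipped with this crate):
    ///
    /// ```no_run
    /// # use datafusion::prelude::*;
    /// # use datafusion::error::Result;
    /// # use datafusion::dataframe::DataFrameWriteOptions;
    /// # #[tokio::main]
    /// # async fn main() -> Result<()> {
    /// let ctx = SessionContext::new();
    /// ctx.read_csv("example.csv", CsvReadOptions::new())
    ///     .await?
    ///     .write_parquet(
    ///         "output.parquet",
    ///         DataFrameWriteOptions::new(),
    ///         None, // use the session's default TableParquetOptions
    ///     )
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```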
    pub async fn write_parquet(
        self,
        path: &str,
        options: DataFrameWriteOptions,
        writer_options: Option<TableParquetOptions>,
    ) -> Result<Vec<RecordBatch>, DataFusionError> {
        if options.insert_op != InsertOp::Append {
            return not_impl_err!(
                "{} is not implemented for DataFrame::write_parquet.",
                options.insert_op
            );
        }

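        // Use the caller-supplied writer options when present; otherwise fall
        // back to the session's default Parquet format settings.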
        let format = if let Some(parquet_opts) = writer_options {
            Arc::new(ParquetFormatFactory::new_with_options(parquet_opts))
        } else {
            Arc::new(ParquetFormatFactory::new())
        };

        let file_type = format_as_file_type(format);

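        // If a sort order was requested, apply it before writing so the
        // output is produced in that order.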
        let plan = if options.sort_by.is_empty() {
            self.plan
        } else {
            LogicalPlanBuilder::from(self.plan)
                .sort(options.sort_by)?
                .build()?
        };

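        // Wrap the (possibly sorted) plan in a COPY TO node targeting `path`,
        // then execute it; collect() drives the write and returns its output.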
        let plan = LogicalPlanBuilder::copy_to(
            plan,
            path.into(),
            file_type,
            Default::default(),
            options.partition_by,
        )?
        .build()?;
        DataFrame {
            session_state: self.session_state,
            plan,
            projection_requires_validation: self.projection_requires_validation,
        }
        .collect()
        .await
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::super::Result;
    use super::*;
    use crate::arrow::util::pretty;
    use crate::execution::context::SessionContext;
    use crate::execution::options::ParquetReadOptions;
    use crate::test_util::{self, register_aggregate_csv};

    use datafusion_common::file_options::parquet_writer::parse_compression_string;
    use datafusion_execution::config::SessionConfig;
    use datafusion_expr::{col, lit};

    #[cfg(feature = "parquet_encryption")]
    use datafusion_common::config::ConfigFileEncryptionProperties;
    use object_store::local::LocalFileSystem;
    use parquet::file::reader::FileReader;
    use tempfile::TempDir;
    use url::Url;

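    // Checks that a filter applied to a Parquet-backed view survives into the
    // physical plan as a FilterExec.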
    #[tokio::test]
    async fn filter_pushdown_dataframe() -> Result<()> {
        let ctx = SessionContext::new();

        ctx.register_parquet(
            "test",
            &format!(
                "{}/alltypes_plain.snappy.parquet",
                test_util::parquet_test_data()
            ),
            ParquetReadOptions::default(),
        )
        .await?;

        ctx.register_table("t1", ctx.table("test").await?.into_view())?;

        let df = ctx
            .table("t1")
            .await?
            .filter(col("id").eq(lit(1)))?
            .select_columns(&["bool_col", "int_col"])?;

        let plan = df.explain(false, false)?.collect().await?;
        let formatted = pretty::pretty_format_batches(&plan)?.to_string();
        assert!(formatted.contains("FilterExec: id@0 = 1"));

        Ok(())
    }

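    // Writes the test table once per supported compression codec, then reads
    // the file's metadata back to confirm the codec was actually applied.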
    #[tokio::test]
    async fn write_parquet_with_compression() -> Result<()> {
        let test_df = test_util::test_table().await?;
        let output_path = "file://local/test.parquet";
        let test_compressions = vec![
            "snappy",
            "brotli(1)",
            "lz4",
            "lz4_raw",
            "gzip(6)",
            "zstd(1)",
        ];
        for compression in test_compressions.into_iter() {
            let df = test_df.clone();
            let tmp_dir = TempDir::new()?;
            let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
            let local_url = Url::parse("file://local").unwrap();
            let ctx = &test_df.session_state;
            ctx.runtime_env().register_object_store(&local_url, local);
            let mut options = TableParquetOptions::default();
            options.global.compression = Some(compression.to_string());
            df.write_parquet(
                output_path,
                DataFrameWriteOptions::new().with_single_file_output(true),
                Some(options),
            )
            .await?;

            let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?;

            let reader =
                parquet::file::serialized_reader::SerializedFileReader::new(file)
                    .unwrap();

            let parquet_metadata = reader.metadata();

            let written_compression =
                parquet_metadata.row_group(0).column(0).compression();

            assert_eq!(written_compression, parse_compression_string(compression)?);
        }

        Ok(())
    }

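    // Uses a small execution batch size so that max_row_group_size values of
    // 1..10 are honored exactly, then verifies the first row group's size.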
    #[tokio::test]
    async fn write_parquet_with_small_rg_size() -> Result<()> {
        let ctx = SessionContext::new_with_config(SessionConfig::from_string_hash_map(
            &HashMap::from_iter(
                [("datafusion.execution.batch_size", "10")]
                    .iter()
                    .map(|(s1, s2)| ((*s1).to_string(), (*s2).to_string())),
            ),
        )?);
        register_aggregate_csv(&ctx, "aggregate_test_100").await?;
        let test_df = ctx.table("aggregate_test_100").await?;

        let output_path = "file://local/test.parquet";

        for rg_size in 1..10 {
            let df = test_df.clone();
            let tmp_dir = TempDir::new()?;
            let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
            let local_url = Url::parse("file://local").unwrap();
            let ctx = &test_df.session_state;
            ctx.runtime_env().register_object_store(&local_url, local);
            let mut options = TableParquetOptions::default();
            options.global.max_row_group_size = rg_size;
            options.global.allow_single_file_parallelism = true;
            df.write_parquet(
                output_path,
                DataFrameWriteOptions::new().with_single_file_output(true),
                Some(options),
            )
            .await?;

            let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?;

            let reader =
                parquet::file::serialized_reader::SerializedFileReader::new(file)
                    .unwrap();

            let parquet_metadata = reader.metadata();

            let written_rows = parquet_metadata.row_group(0).num_rows();

            assert_eq!(written_rows as usize, rg_size);
        }

        Ok(())
    }

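    // Round-trips an encrypted Parquet file: writes with footer and per-column
    // encryption keys, then reads back with matching decryption properties,
    // both with and without single-file write parallelism.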
    #[rstest::rstest]
    #[cfg(feature = "parquet_encryption")]
    #[tokio::test]
    async fn roundtrip_parquet_with_encryption(
        #[values(false, true)] allow_single_file_parallelism: bool,
    ) -> Result<()> {
        use parquet::encryption::decrypt::FileDecryptionProperties;
        use parquet::encryption::encrypt::FileEncryptionProperties;

        let test_df = test_util::test_table().await?;

        let schema = test_df.schema();
        let footer_key = b"0123456789012345".to_vec();
        let column_key = b"1234567890123450".to_vec();

        let mut encrypt = FileEncryptionProperties::builder(footer_key.clone());
        let mut decrypt = FileDecryptionProperties::builder(footer_key.clone());

        for field in schema.fields().iter() {
            encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone());
            decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone());
        }

        let encrypt = encrypt.build()?;
        let decrypt = decrypt.build()?;

        let df = test_df.clone();
        let tmp_dir = TempDir::new()?;
        let tempfile = tmp_dir.path().join("roundtrip.parquet");
        let tempfile_str = tempfile.into_os_string().into_string().unwrap();

        let mut options = TableParquetOptions::default();
        options.crypto.file_encryption =
            Some(ConfigFileEncryptionProperties::from(&encrypt));
        options.global.allow_single_file_parallelism = allow_single_file_parallelism;

        df.write_parquet(
            tempfile_str.as_str(),
            DataFrameWriteOptions::new().with_single_file_output(true),
            Some(options),
        )
        .await?;
        let num_rows_written = test_df.count().await?;

        let ctx: SessionContext = SessionContext::new();
        let read_options =
            ParquetReadOptions::default().file_decryption_properties((&decrypt).into());

        ctx.register_parquet("roundtrip_parquet", &tempfile_str, read_options.clone())
            .await?;

        let df_enc = ctx.sql("SELECT * FROM roundtrip_parquet").await?;
        let num_rows_read = df_enc.count().await?;

        assert_eq!(num_rows_read, num_rows_written);

        let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?;

        let selected = encrypted_parquet_df
            .clone()
            .select_columns(&["c1", "c2", "c3"])?
            .filter(col("c2").gt(lit(4)))?;

        let num_rows_selected = selected.count().await?;
        assert_eq!(num_rows_selected, 14);

        Ok(())
    }
}