# Source code for eva.eva_interactive

#!/usr/bin/env python

# (C) Copyright 2023 United States Government as represented by the Administrator of the
# National Aeronautics and Space Administration. All Rights Reserved.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.


# --------------------------------------------------------------------------------------------------


import os
import netCDF4 as nc
import xarray as xr
import re
import numpy as np

from eva.data.data_collections import DataCollections
from eva.utilities.logger import Logger
from eva.utilities.timing import Timing
from eva.data.eva_dataset_base import EvaDatasetFactory
from eva.transforms.arithmetic import arithmetic, generate_arithmetic_config
from eva.transforms.accept_where import accept_where, generate_accept_where_config

import eva.plotting.hvplot.interactive_plot_tools as plot


# --------------------------------------------------------------------------------------------------


class EvaInteractive():
    """
    Interactive front end for loading, transforming and plotting EVA data collections.

    Collections are loaded with ``load_collection`` and stored by name; the transform and
    plotting methods then operate on those stored collections.

    Attributes:
        logger (Logger): Logger used for informational and abort messages.
        timer (Timing): Timing instance handed to the EVA dataset objects.
        dc_dict (dict): Loaded ``DataCollections`` instances keyed by collection name.
        fn_dict (dict): First filename of each loaded collection, keyed by collection name.
        ch_required_dict (dict): For each collection name, True when the first file of the
            collection has a ``Channel`` dimension, False otherwise.
        var_cache (list): Variable names found by the most recent successful
            ``retrieve_var_list`` lookup.
    """

    def __init__(self):
        """
        Create an empty interactive session (no collections loaded yet).
        """
        self.logger = Logger('EvaInteractive')
        self.timer = Timing()
        self.dc_dict = {}
        self.fn_dict = {}
        self.ch_required_dict = {}
        self.var_cache = []

    # ----------------------------------------------------------------------------------------------

    def load_collection(self, collection_name, filenames, eva_class_name, control_file=None):
        """
        Load a data collection and register it under ``collection_name``.

        Args:
            collection_name (str): Name under which the collection is stored.
            filenames (str or list): Filename(s) containing the data. A single string is
                wrapped into a one-element list.
            eva_class_name (str): Name of the EVA dataset class used to read the files.
            control_file (str, optional): Control file forwarded to the dataset's
                ``generate_default_config``. Only passed on when truthy. Default is None.
        """
        # Accept a bare filename as well as a list of filenames
        if isinstance(filenames, str):
            filenames = [filenames]

        creator = EvaDatasetFactory()
        data_collection = DataCollections()
        eva_object = creator.create_eva_object(eva_class_name, 'data', self.logger, self.timer)

        if control_file:
            config = eva_object.generate_default_config(filenames, collection_name, control_file)
        else:
            config = eva_object.generate_default_config(filenames, collection_name)

        eva_object.execute(config, data_collection, self.timer)

        self.dc_dict[collection_name] = data_collection
        self.fn_dict[collection_name] = filenames[0]

        # These dataset types are never inspected for a channel dimension
        no_ch_dataspaces = ['JediLog', 'MonDataSpace']

        if eva_class_name in no_ch_dataspaces:
            self.ch_required_dict[collection_name] = False
        else:
            # Inspect the first file only; the context manager closes the handle again
            with nc.Dataset(filenames[0]) as ds:
                self.ch_required_dict[collection_name] = 'Channel' in ds.dimensions

    # ----------------------------------------------------------------------------------------------

    def get_data_collection(self, collection_name):
        """
        Return the data collection stored under ``collection_name``.

        Args:
            collection_name (str): Name of the data collection to retrieve.

        Returns:
            DataCollections: The stored data collection.

        Note:
            An unknown name is reported through ``self.logger.abort``.
        """
        if collection_name in self.dc_dict:
            return self.dc_dict[collection_name]

        self.logger.abort(f'Collection name \'{collection_name}\' does not exist. ')

    # ----------------------------------------------------------------------------------------------

    def print_data_collection(self, collection_name):
        """
        Display the content of the data collection stored under ``collection_name``.

        Args:
            collection_name (str): Name of the data collection to print.

        Note:
            An unknown name is reported through ``self.logger.abort``.
        """
        if collection_name in self.dc_dict:
            self.dc_dict[collection_name].display_collections()
        else:
            self.logger.abort(f'Collection name \'{collection_name}\' does not exist. ')

    # ----------------------------------------------------------------------------------------------

    def retrieve_var_list(self, collection, group):
        """
        Return the variable names of ``group`` in the first file of ``collection``.

        Args:
            collection (str): Name of a previously loaded data collection.
            group (str): Name of the group to look up in the file.

        Returns:
            list: Variable names of the group. When the group is not present in the file,
            ``self.var_cache`` is returned unchanged, i.e. the result of the last
            successful lookup (or an empty list if there was none).
        """
        # The file is only needed for this lookup, so close it as soon as we are done
        with nc.Dataset(self.fn_dict[collection]) as ds:
            if group in ds.groups:
                self.var_cache = list(ds[group].variables)

        return self.var_cache

    # ----------------------------------------------------------------------------------------------

    def arithmetic(self, new_name, expression, collection, var_list=None):
        """
        Apply an arithmetic transform to a loaded data collection.

        Args:
            new_name (str): Name of the new variable to be created.
            expression (str): Arithmetic expression to be evaluated.
            collection (str): Name of the data collection.
            var_list (list, optional): Variable names to use for the transform. When omitted
                or empty, the list is looked up from the group named by the text preceding
                the first operator or parenthesis in ``expression``.
        """
        # Derive the variable list from the leading group name when none was supplied
        if not var_list:
            group = re.split(r'\(|\)|-|\*|\+|\/', expression)[0]
            var_list = self.retrieve_var_list(collection, group)

        # Generate default config for transform
        arithmetic_config = generate_arithmetic_config(new_name, expression, collection, var_list)

        # Execute transform (resolves to the imported transform function, not this method)
        arithmetic(arithmetic_config, self.dc_dict[collection])
        self.logger.info(f'Added \'{new_name}\' to data collection \'{collection}\'.')

    # ----------------------------------------------------------------------------------------------

    def accept_where(self, new_name, starting_field, where, collection, var_list=None):
        """
        Apply an 'accept_where' transform to a loaded data collection.

        Args:
            new_name (str): Name of the new variable to be created.
            starting_field (str): Field to which the 'where' conditions are applied.
            where (list): Condition expressions, each made of exactly three tokens
                separated by single spaces.
            collection (str): Name of the data collection.
            var_list (list, optional): Variable names to use for the transform. When omitted
                or empty, the list is looked up from the group given by the first token of
                the last expression in ``where``.

        Note:
            A malformed expression, or an empty ``where`` combined with no ``var_list``, is
            reported through ``self.logger.abort``.
        """
        # Validate every expression; the first token of the last one is kept as the group
        group = None
        for expression in where:
            try:
                group, _, _ = expression.split(' ')
            except (AttributeError, TypeError, ValueError):
                self.logger.abort(f'Failed to split \'{expression}\'. Check that '
                                  'it has the correct format')

        # Set var_list if empty
        if not var_list:
            if group is None:
                # Nothing to derive the group from; fail clearly instead of with a NameError
                self.logger.abort(f'Cannot determine variables for \'{new_name}\': '
                                  '\'where\' is empty and no var_list was provided.')
            var_list = self.retrieve_var_list(collection, group)

        # Generate default config for transform
        accept_where_config = generate_accept_where_config(new_name, starting_field, where,
                                                           collection, var_list)

        # Execute transform (resolves to the imported transform function, not this method)
        accept_where(accept_where_config, self.dc_dict[collection])
        self.logger.info(f'Added \'{new_name}\' to data collection \'{collection}\'.')

    # ----------------------------------------------------------------------------------------------

    def print_statistics(self, df):
        """
        Print minimum, maximum and standard deviation for each column of ``df``.

        Args:
            df (DataFrame): The DataFrame containing data. Column labels are converted with
                ``str`` so non-string labels are accepted.
        """
        for column in df:
            col = df[column]
            print("name: " + str(column) +
                  "\n\t minimum: " + str(col.min()) +
                  "\n\t maximum: " + str(col.max()) +
                  "\n\t std: " + str(col.std()))

    # ----------------------------------------------------------------------------------------------

    def map_gridded(self):
        """
        Placeholder for a gridded map plot; currently only prints a message.
        """
        print('map gridded')

    # ----------------------------------------------------------------------------------------------

    def line_plot(self, plot_list):
        """
        Generate a line plot from the loaded data collections.

        Args:
            plot_list (list): List of dictionaries specifying the plot configuration.

        Returns:
            The object returned by ``plot.hvplot_line_plot``.
        """
        return plot.hvplot_line_plot(self.dc_dict, plot_list, self.ch_required_dict, self.logger)

    # ----------------------------------------------------------------------------------------------

    def histogram(self, plot_list):
        """
        Generate a histogram from the loaded data collections.

        Args:
            plot_list (list): List of dictionaries specifying the plot configuration.

        Returns:
            The object returned by ``plot.hvplot_histogram``.
        """
        return plot.hvplot_histogram(self.dc_dict, plot_list, self.ch_required_dict, self.logger)

    # ----------------------------------------------------------------------------------------------

    def map_scatter(self, plot_entry):
        """
        Generate a scatter plot on a map from the loaded data collections.

        Args:
            plot_entry (dict): Dictionary specifying the plot configuration.

        Returns:
            The object returned by ``plot.hvplot_map_scatter``.
        """
        return plot.hvplot_map_scatter(self.dc_dict, plot_entry, self.logger)

    # ----------------------------------------------------------------------------------------------

    def density_plot(self, plot_list):
        """
        Generate a density plot from the loaded data collections.

        Args:
            plot_list (list): List of dictionaries specifying the plot configuration.

        Returns:
            The object returned by ``plot.hvplot_density_plot``.
        """
        return plot.hvplot_density_plot(self.dc_dict, plot_list, self.ch_required_dict,
                                        self.logger)

    # ----------------------------------------------------------------------------------------------

    def scatter(self, x, y):
        """
        Generate a scatter plot of ``y`` against ``x`` from the loaded data collections.

        Args:
            x (str): Name of the variable for the x-coordinate.
            y (str): Name of the variable for the y-coordinate.

        Returns:
            The object returned by ``plot.hvplot_scatter``.
        """
        return plot.hvplot_scatter(self.dc_dict, x, y, self.ch_required_dict, self.logger)
# ----------------------------------------------------------------------------------------------