Source code for chunk_sky_model


import numpy as np
import sys
import os
from enum import Enum

from wodenpy.skymodel.woden_skymodel import Component_Type_Counter, CompTypes
    
# Number of flux model types (power-law, curved power-law and list) that
# components can have; presumably consumed by importers of this module - confirm.
NUM_FLUX_TYPES = 3
    
class Components_Map(object):
    """
    Mapping information for a set of components, of either POINT, GAUSSIAN or
    SHAPELET type.

    Every index attribute starts out as ``False`` to flag "not set yet"; the
    chunking functions overwrite them with index arrays.

    :cvar power_orig_inds: Original sky model indexes of the power-law
        components in this chunk (array, or ``False`` if unset).
    :cvar curve_orig_inds: Original sky model indexes of the curved power-law
        components in this chunk (array, or ``False`` if unset).
    :cvar list_orig_inds: Original sky model indexes of the list-type
        components in this chunk (array, or ``False`` if unset).
    :cvar power_shape_orig_inds: SHAPELET only. Original component index of
        every basis function entry belonging to a power-law component
        (array, or ``False`` if unset).
    :cvar curve_shape_orig_inds: SHAPELET only. As above, for curved
        power-law components.
    :cvar list_shape_orig_inds: SHAPELET only. As above, for list-type
        components.
    :cvar power_shape_basis_inds: SHAPELET only. Index of each basis function
        entry relative to its own component, for power-law components
        (array, or ``False`` if unset).
    :cvar curve_shape_basis_inds: SHAPELET only. As above, for curved
        power-law components.
    :cvar list_shape_basis_inds: SHAPELET only. As above, for list-type
        components.
    :cvar min_comp_ind: Smallest original component index in this chunk
        (NaN until set).
    :cvar float lowest_file_num: Smallest line number in the original sky
        model file at which any component of this chunk appears. All lines
        before it can be skipped, which makes reading faster. NaN until set.
    :cvar int total_num_flux_entires: Total number of flux list entries, so
        the correct amount can be allocated when reading in full information.
        (The misspelt name is kept for compatibility.)
    :cvar int total_shape_coeffs: Total number of shapelet basis functions,
        so the correct amount can be allocated when reading in full
        information.
    """

    def __init__(self):
        """Set every field to its "not yet mapped" default."""
        # Original sky model indexes of the components in this chunk, split
        # by flux type
        self.power_orig_inds = False
        self.curve_orig_inds = False
        self.list_orig_inds = False

        # SHAPELET only: original component index of every basis function
        # entry. Initialised here so the attributes always exist, even when
        # a chunking function has not filled them in
        self.power_shape_orig_inds = False
        self.curve_shape_orig_inds = False
        self.list_shape_orig_inds = False

        # SHAPELET only: index of a basis function entry relative to its
        # own component
        self.power_shape_basis_inds = False
        self.curve_shape_basis_inds = False
        self.list_shape_basis_inds = False

        # Smallest original component index in the chunk
        self.min_comp_ind = np.nan

        # Smallest line in the original sky model file that any component of
        # this chunk appears on; reading can skip every earlier line
        self.lowest_file_num = np.nan

        # How many flux list entries there are in total, so the correct
        # amount can be allocated when reading in full information
        self.total_num_flux_entires = 0

        # How many shapelet basis functions there are in total, so the
        # correct amount can be allocated when reading in full information
        self.total_shape_coeffs = 0
class Skymodel_Chunk_Map(object):
    """A class representing a chunk of a sky model, containing information
    about the number of point, Gaussian, and shapelet components, as well as
    their respective power-law, curved power-law, or list type flux info.
    This class also provides methods for consolidating component and flux
    types into one array of original component indexes, and for printing
    information about the chunk.

    :cvar int n_point_lists: Number of POINT list-type components.
    :cvar int n_point_powers: Number of POINT power-law components.
    :cvar int n_point_curves: Number of POINT curved power-law components.
    :cvar int n_gauss_lists: Number of GAUSSIAN list-type components.
    :cvar int n_gauss_powers: Number of GAUSSIAN power-law components.
    :cvar int n_gauss_curves: Number of GAUSSIAN curved power-law components.
    :cvar int n_shape_lists: Number of SHAPELET list-type components.
    :cvar int n_shape_powers: Number of SHAPELET power-law components.
    :cvar int n_shape_curves: Number of SHAPELET curved power-law components.
    :cvar int n_shape_coeffs: Number of SHAPELET coefficients.
    :cvar int n_points: Number of POINT components
    :cvar int n_gauss: Number of GAUSSIAN components
    :cvar int n_shapes: Number of SHAPELET components
    :cvar int n_comps: Number of all components
    :cvar Components_Map point_components: Mapping object for POINT
        components (only created when ``n_points > 0``).
    :cvar Components_Map gauss_components: Mapping object for GAUSSIAN
        components (only created when ``n_gauss > 0``).
    :cvar Components_Map shape_components: Mapping object for SHAPELET
        components (only created when ``n_shapes > 0``).
    :cvar float lowest_file_number: Smallest known ``lowest_file_num`` over
        the component maps; set by ``make_all_orig_inds_array`` and NaN when
        none is known.
    :cvar int current_shape_basis_index: Number of basis function values
        already added while reading the full model.
    :cvar numpy.ndarray all_orig_inds: Original component index of every
        POINT, GAUSSIAN and shapelet basis function entry in the chunk;
        created by ``make_all_orig_inds_array``.
    """

    def __init__(self, n_point_powers = 0, n_point_curves = 0, n_point_lists = 0,
                 n_gauss_powers = 0, n_gauss_curves = 0, n_gauss_lists = 0,
                 n_shape_powers = 0, n_shape_curves = 0, n_shape_lists = 0,
                 n_shape_coeffs = 0):
        """Setup everything with zeros as default"""
        self.n_point_lists = n_point_lists
        self.n_point_powers = n_point_powers
        self.n_point_curves = n_point_curves
        self.n_gauss_lists = n_gauss_lists
        self.n_gauss_powers = n_gauss_powers
        self.n_gauss_curves = n_gauss_curves
        self.n_shape_lists = n_shape_lists
        self.n_shape_powers = n_shape_powers
        self.n_shape_curves = n_shape_curves
        self.n_shape_coeffs = n_shape_coeffs

        self.n_points = n_point_lists + n_point_powers + n_point_curves
        self.n_gauss = n_gauss_lists + n_gauss_powers + n_gauss_curves
        self.n_shapes = n_shape_lists + n_shape_powers + n_shape_curves
        self.n_comps = self.n_points + self.n_gauss + self.n_shapes

        ##Setup the POINT, GAUSS, and SHAPE classes only when needed
        ##TODO set these up regardless of size as empty things take up
        ##small RAM??
        if self.n_points > 0:
            self.point_components = Components_Map()
        if self.n_gauss > 0:
            self.gauss_components = Components_Map()
        if self.n_shapes > 0:
            self.shape_components = Components_Map()

        self.lowest_file_number = np.nan

        ##Used to count how many basis function values have already been
        ##added, gets updated by `use_libwoden.add_info_to_source_catalogue`
        ##when reading in the full model from the catalogue file
        self.current_shape_basis_index = 0

    def _copy_orig_inds(self, low_ind, orig_inds, num=None):
        """Write `orig_inds` into `self.all_orig_inds` starting at `low_ind`.

        `num` is the length of the span written (defaults to
        ``len(orig_inds)``). Returns the index just past the written span.
        """
        if num is None:
            num = len(orig_inds)
        self.all_orig_inds[low_ind:low_ind + num] = orig_inds
        return low_ind + num

    def make_all_orig_inds_array(self):
        """Look through all component and flux types and consolidate into one
        array `self.all_orig_inds` of original component indexes. Use this
        when reading in full information from the sky model.

        Layout is POINT, then GAUSSIAN, then SHAPELET basis functions; within
        each, power-law, then curved power-law, then list. Also sets
        `self.lowest_file_number` to the smallest known (non-NaN)
        `lowest_file_num` of the component maps, or NaN if none is known.
        """
        self.all_orig_inds = np.empty(self.n_points + self.n_gauss + self.n_shape_coeffs,
                                      dtype=int)
        lowest_file_lines = []

        if self.n_points > 0:
            comps = self.point_components
            lowest_file_lines.append(comps.lowest_file_num)
            low_ind = 0
            if self.n_point_powers > 0:
                low_ind = self._copy_orig_inds(low_ind, comps.power_orig_inds, self.n_point_powers)
            if self.n_point_curves > 0:
                low_ind = self._copy_orig_inds(low_ind, comps.curve_orig_inds, self.n_point_curves)
            if self.n_point_lists > 0:
                low_ind = self._copy_orig_inds(low_ind, comps.list_orig_inds, self.n_point_lists)

        if self.n_gauss > 0:
            comps = self.gauss_components
            lowest_file_lines.append(comps.lowest_file_num)
            low_ind = self.n_points
            if self.n_gauss_powers > 0:
                low_ind = self._copy_orig_inds(low_ind, comps.power_orig_inds, self.n_gauss_powers)
            if self.n_gauss_curves > 0:
                low_ind = self._copy_orig_inds(low_ind, comps.curve_orig_inds, self.n_gauss_curves)
            if self.n_gauss_lists > 0:
                low_ind = self._copy_orig_inds(low_ind, comps.list_orig_inds, self.n_gauss_lists)

        if self.n_shapes > 0:
            comps = self.shape_components
            lowest_file_lines.append(comps.lowest_file_num)
            low_ind = self.n_points + self.n_gauss
            ##SHAPELETs are stored per basis function, so the span lengths
            ##come from the per-basis-function index arrays themselves
            if self.n_shape_powers > 0:
                low_ind = self._copy_orig_inds(low_ind, comps.power_shape_orig_inds)
            if self.n_shape_curves > 0:
                low_ind = self._copy_orig_inds(low_ind, comps.curve_shape_orig_inds)
            if self.n_shape_lists > 0:
                low_ind = self._copy_orig_inds(low_ind, comps.list_shape_orig_inds)

        ##Drop unset (NaN) line numbers: min() over a mix of NaN and numbers
        ##depends on list order, and min([]) raises for an empty chunk.
        ##NaN is the only value not equal to itself.
        known_lines = [line for line in lowest_file_lines if line == line]
        self.lowest_file_number = min(known_lines) if known_lines else np.nan

    def print_info(self):
        """
        Prints information about the Skymodel_Chunk_Map object, including
        the number of points, Gaussians, and shapes, as well as the number
        of powers, curves, lists, and coefficients associated with each type
        of object.
        """
        print("n_points", self.n_points)
        print("\tn_point_powers", self.n_point_powers)
        print("\tn_point_curves", self.n_point_curves)
        print("\tn_point_lists", self.n_point_lists)

        print("n_gauss", self.n_gauss)
        print("\tn_gauss_powers", self.n_gauss_powers)
        print("\tn_gauss_curves", self.n_gauss_curves)
        print("\tn_gauss_lists", self.n_gauss_lists)

        print("n_shapes", self.n_shapes)
        print("\tn_shape_powers", self.n_shape_powers)
        print("\tn_shape_curves", self.n_shape_curves)
        print("\tn_shape_lists", self.n_shape_lists)
        print("\tn_shape_coeffs", self.n_shape_coeffs)
def increment_flux_type_counters(power_iter : int, curve_iter : int, list_iter : int,
                                 num_chunk_power : int, num_chunk_curve : int,
                                 num_chunk_list : int,
                                 num_power : int, num_curve : int, num_list : int,
                                 comps_per_chunk : int,
                                 lower_comp_ind : int, upper_comp_ind : int):
    """
    Work out which slice of each flux type falls inside one chunk.

    All components of one type (e.g. every POINT) are treated as one long
    sequence ordered POWER_LAW, then CURVED_POWER_LAW, then LIST. The chunk
    covers positions ``lower_comp_ind`` up to (not including)
    ``upper_comp_ind`` of that sequence. This function converts that window
    into a start index (``*_iter``) and a count (``num_chunk_*``) for each
    flux type.

    Any value that the branch taken does not assign is handed back exactly
    as it was passed in, so callers normally pass zeros.

    Parameters
    -----------
    power_iter, curve_iter, list_iter : int
        Incoming start indexes into the power-law, curved power-law and list
        components.
    num_chunk_power, num_chunk_curve, num_chunk_list : int
        Incoming counts of each flux type in the chunk.
    num_power, num_curve, num_list : int
        Total number of power-law, curved power-law and list components.
    comps_per_chunk : int
        Number of components per chunk.
    lower_comp_ind, upper_comp_ind : int
        First position covered by the chunk, and one past the last.

    Returns
    --------
    tuple of int
        ``(power_iter, curve_iter, list_iter,
        num_chunk_power, num_chunk_curve, num_chunk_list)``
    """
    # Case 1: power-law components alone fill the whole chunk
    if num_power > upper_comp_ind:
        power_iter = lower_comp_ind
        num_chunk_power = comps_per_chunk
        num_chunk_curve = 0
        num_chunk_list = 0

    # Case 2: the power-law components run out inside this chunk (possibly
    # right at its start); top up with curved, then list components
    elif num_power >= lower_comp_ind:
        power_iter = lower_comp_ind
        num_chunk_power = num_power - lower_comp_ind
        space_left = comps_per_chunk - num_chunk_power

        if num_curve >= space_left:
            curve_iter = 0
            num_chunk_curve = space_left
            num_chunk_list = 0
        else:
            # num_curve < space_left already holds in this branch
            if num_curve != 0:
                curve_iter = 0
                num_chunk_curve = num_curve
                space_left -= num_curve
            if num_list >= space_left:
                list_iter = 0
                num_chunk_list = space_left
            elif num_list != 0:
                list_iter = 0
                num_chunk_list = num_list

    # Case 3: earlier chunks already consumed every power-law component, so
    # re-express the window relative to the curved power-law components
    else:
        curve_lo = lower_comp_ind - num_power
        curve_hi = upper_comp_ind - num_power

        if num_curve >= curve_hi:
            curve_iter = curve_lo
            num_chunk_curve = comps_per_chunk
            num_chunk_list = 0
        elif num_curve > curve_lo:
            # curve_lo > 0 here, so num_curve is necessarily non-zero
            curve_iter = curve_lo
            num_chunk_curve = num_curve - curve_lo
            space_left = comps_per_chunk - num_chunk_curve
            list_iter = 0
            if num_list >= space_left:
                num_chunk_list = space_left
            else:
                num_chunk_list = num_list
        else:
            # Only list components are left; shift the window onto them
            list_lo = curve_lo - num_curve
            list_hi = curve_hi - num_curve
            if num_list > list_hi:
                list_iter = list_lo
                num_chunk_list = comps_per_chunk
            elif num_list > list_lo:
                list_iter = list_lo
                num_chunk_list = num_list - list_lo

    return power_iter, curve_iter, list_iter, num_chunk_power, num_chunk_curve, num_chunk_list
def fill_chunk_component(comp_type : CompTypes,
                         cropped_comp_counter : Component_Type_Counter,
                         power_iter: int, num_chunk_power : int,
                         curve_iter: int, num_chunk_curve : int,
                         list_iter: int, num_chunk_list : int) -> Skymodel_Chunk_Map:
    """
    Fills in the relevant fields inside a `Skymodel_Chunk_Map` based on the
    given component type and the number of components of each flux type that
    are required in the chunk. The returned map is used to read a chunk from
    the sky model.

    Parameters
    -------------
    comp_type : CompTypes
        The type of component to be filled in the chunk; must be
        `CompTypes.POINT` or `CompTypes.GAUSSIAN`.
    cropped_comp_counter : Component_Type_Counter
        The counter object that contains the indices of the cropped components.
    power_iter : int
        The starting index of the power components in the cropped component counter.
    num_chunk_power : int
        The number of power components to be included in the chunk.
    curve_iter : int
        The starting index of the curve components in the cropped component counter.
    num_chunk_curve : int
        The number of curve components to be included in the chunk.
    list_iter : int
        The starting index of the list components in the cropped component counter.
    num_chunk_list : int
        The number of list components to be included in the chunk.

    Returns
    --------
    chunk_map: Skymodel_Chunk_Map
        A `Skymodel_Chunk_Map` object that contains the filled-in `Components_Map`.

    Raises
    -------
    ValueError
        If `comp_type` is neither `CompTypes.POINT` nor `CompTypes.GAUSSIAN`.
    """
    if comp_type == CompTypes.POINT:
        chunk_map = Skymodel_Chunk_Map(n_point_powers = num_chunk_power,
                                       n_point_curves = num_chunk_curve,
                                       n_point_lists = num_chunk_list)
        components = chunk_map.point_components
        power_inds = cropped_comp_counter.orig_point_power_inds
        curve_inds = cropped_comp_counter.orig_point_curve_inds
        list_inds = cropped_comp_counter.orig_point_list_inds
        power_type, curve_type, list_type = (CompTypes.POINT_POWER,
                                             CompTypes.POINT_CURVE,
                                             CompTypes.POINT_LIST)
    elif comp_type == CompTypes.GAUSSIAN:
        chunk_map = Skymodel_Chunk_Map(n_gauss_powers = num_chunk_power,
                                       n_gauss_curves = num_chunk_curve,
                                       n_gauss_lists = num_chunk_list)
        components = chunk_map.gauss_components
        power_inds = cropped_comp_counter.orig_gauss_power_inds
        curve_inds = cropped_comp_counter.orig_gauss_curve_inds
        list_inds = cropped_comp_counter.orig_gauss_list_inds
        power_type, curve_type, list_type = (CompTypes.GAUSS_POWER,
                                             CompTypes.GAUSS_CURVE,
                                             CompTypes.GAUSS_LIST)
    else:
        # Previously this fell through and failed later with an unhelpful
        # UnboundLocalError on `chunk_map`
        raise ValueError("fill_chunk_component only supports CompTypes.POINT "
                         f"or CompTypes.GAUSSIAN, got {comp_type}")

    ##Positions of each flux type within the cropped sky model
    comp_types = cropped_comp_counter.comp_types
    cropped_power_inds = np.where(comp_types == power_type.value)[0]
    cropped_curve_inds = np.where(comp_types == curve_type.value)[0]
    cropped_list_inds = np.where(comp_types == list_type.value)[0]

    power_end = power_iter + num_chunk_power
    curve_end = curve_iter + num_chunk_curve
    list_end = list_iter + num_chunk_list

    ##Original sky model indexes of the components in this chunk
    components.power_orig_inds = power_inds[power_iter:power_end]
    components.curve_orig_inds = curve_inds[curve_iter:curve_end]
    components.list_orig_inds = list_inds[list_iter:list_end]

    ##Matching positions within the cropped sky model
    cropped_power_inds = cropped_power_inds[power_iter:power_end]
    cropped_curve_inds = cropped_curve_inds[curve_iter:curve_end]
    cropped_list_inds = cropped_list_inds[list_iter:list_end]

    ##How many flux list entries these list-type components hold in total
    components.total_num_flux_entires = np.sum(cropped_comp_counter.num_list_fluxes[cropped_list_inds])

    min_comp_inds = []
    if len(cropped_power_inds) > 0:
        min_comp_inds.append(components.power_orig_inds.min())
    if len(cropped_curve_inds) > 0:
        min_comp_inds.append(components.curve_orig_inds.min())
    if len(cropped_list_inds) > 0:
        min_comp_inds.append(components.list_orig_inds.min())

    components.min_comp_ind = np.min(min_comp_inds)

    return chunk_map
def map_chunk_pointgauss(cropped_comp_counter : Component_Type_Counter,
                         chunk_ind : int, comps_per_chunk : int,
                         point_source = False, gaussian_source = False) -> Skymodel_Chunk_Map:
    """
    For a given chunk index `chunk_ind`, work out how many of each type of
    COMPONENT to fit in the chunk, and then map specifics across to that one
    chunk.

    Exactly one of `point_source` or `gaussian_source` must be True.

    Parameters
    -----------
    cropped_comp_counter : Component_Type_Counter
        A cropped component counter object that contains information about
        the components to use in the chunk.
    chunk_ind : int
        The index of the chunk.
    comps_per_chunk : int
        The number of components per chunk.
    point_source : bool
        Map POINT components. Default is False.
    gaussian_source : bool
        Map GAUSSIAN components. Default is False.

    Returns
    ---------
    chunk_map : Skymodel_Chunk_Map
        Mapping information about the components in the chunk, with
        `all_orig_inds` already populated. If neither or both of
        `point_source` / `gaussian_source` are True, a message is printed
        and the integer 1 is returned instead.
    """
    if not point_source and not gaussian_source:
        print("You must set either `point_source` or `gaussian_source` to True")
        return 1
    if point_source and gaussian_source:
        print("You must set one `point_source` or `gaussian_source` to True, "
              "but both are set to True")
        return 1

    ##Lower and (exclusive) upper indexes of components covered in this chunk
    lower_comp_ind = chunk_ind * comps_per_chunk
    upper_comp_ind = (chunk_ind + 1) * comps_per_chunk

    if point_source:
        comp_type = CompTypes.POINT
        n_powers = cropped_comp_counter.num_point_flux_powers
        n_curves = cropped_comp_counter.num_point_flux_curves
        n_lists = cropped_comp_counter.num_point_flux_lists
    else:
        comp_type = CompTypes.GAUSSIAN
        n_powers = cropped_comp_counter.num_gauss_flux_powers
        n_curves = cropped_comp_counter.num_gauss_flux_curves
        n_lists = cropped_comp_counter.num_gauss_flux_lists

    ##Work out where in each flux type this chunk starts (*_iter) and how
    ##many of each it holds (num_chunk_*). Start from zeros:
    ##increment_flux_type_counters hands back unchanged anything it does not set
    (power_iter, curve_iter, list_iter,
     num_chunk_power, num_chunk_curve, num_chunk_list) = increment_flux_type_counters(
        0, 0, 0, 0, 0, 0,
        n_powers, n_curves, n_lists,
        comps_per_chunk, lower_comp_ind, upper_comp_ind)

    ##Using these numbers, build the map used to read in the full information
    ##from the sky model for this chunk
    chunk_map = fill_chunk_component(comp_type, cropped_comp_counter,
                                     power_iter, num_chunk_power,
                                     curve_iter, num_chunk_curve,
                                     list_iter, num_chunk_list)

    chunk_map.make_all_orig_inds_array()

    return chunk_map
def create_shape_basis_maps(cropped_comp_counter : Component_Type_Counter):
    """
    Build per-basis-function lookup arrays for every SHAPELET component.

    Basis functions are laid out power-law components first, then curved
    power-law, then list-type. Each component contributes
    ``num_shape_coeffs[comp_ind]`` consecutive entries.

    Parameters:
    -----------
    cropped_comp_counter : Component_Type_Counter
        Provides `total_shape_basis`, `num_shape_coeffs`, the cropped indexes
        `shape_{power,curve,list}_inds`, and the matching original indexes
        `orig_shape_{power,curve,list}_inds`.

    Returns:
    --------
    shape_basis_to_orig_comp_index_map : numpy.ndarray
        Original component index owning each basis function.
    shape_basis_to_comp_type_map : numpy.ndarray
        `CompTypes.SHAPE_*` value of the component owning each basis function.
    shape_basis_param_index : numpy.ndarray
        Position (0, 1, 2, ...) of each basis function within its own component.

    All three arrays have length `total_shape_basis` and float dtype, since
    they are created with `np.empty`.
    """
    n_basis_total = cropped_comp_counter.total_shape_basis
    shape_basis_to_orig_comp_index_map = np.empty(n_basis_total)
    shape_basis_to_comp_type_map = np.empty(n_basis_total)
    shape_basis_param_index = np.empty(n_basis_total)

    flux_groups = (
        (cropped_comp_counter.shape_power_inds,
         cropped_comp_counter.orig_shape_power_inds, CompTypes.SHAPE_POWER.value),
        (cropped_comp_counter.shape_curve_inds,
         cropped_comp_counter.orig_shape_curve_inds, CompTypes.SHAPE_CURVE.value),
        (cropped_comp_counter.shape_list_inds,
         cropped_comp_counter.orig_shape_list_inds, CompTypes.SHAPE_LIST.value),
    )

    start = 0
    for cropped_inds, orig_inds, type_value in flux_groups:
        for comp_ind, orig_comp_ind in zip(cropped_inds, orig_inds):
            n_basis = cropped_comp_counter.num_shape_coeffs[comp_ind]
            stop = start + n_basis
            shape_basis_to_orig_comp_index_map[start:stop] = orig_comp_ind
            shape_basis_to_comp_type_map[start:stop] = type_value
            shape_basis_param_index[start:stop] = np.arange(n_basis)
            start = stop

    return shape_basis_to_orig_comp_index_map, shape_basis_to_comp_type_map, shape_basis_param_index
def map_chunk_shapelets(cropped_comp_counter : Component_Type_Counter,
                        shape_basis_to_orig_comp_index_map : np.ndarray,
                        shape_basis_to_orig_type_map : np.ndarray,
                        shape_basis_param_index : np.ndarray,
                        chunk_ind : int,
                        coeffs_per_chunk : int):
    """
    Map the shapelet basis functions that fall in one chunk back to their
    components in the original sky model.

    SHAPELETs are chunked by basis function (coefficient) rather than by
    component, so a component's basis functions may be split across chunks.

    Parameters
    -----------
    cropped_comp_counter : Component_Type_Counter
        A Component_Type_Counter object containing information about the
        components in the cropped sky model.
    shape_basis_to_orig_comp_index_map : np.ndarray
        For every basis function, the index of its component in the original
        sky model.
    shape_basis_to_orig_type_map : np.ndarray
        For every basis function, the `CompTypes.SHAPE_*` value of its
        component.
    shape_basis_param_index : np.ndarray
        For every basis function, its position within its own component.
    chunk_ind : int
        The index of the chunk being mapped.
    coeffs_per_chunk : int
        The number of shapelet coefficients in each chunk.

    Returns
    --------
    chunk_map : Skymodel_Chunk_Map
        Map for this chunk with `shape_components` filled in and
        `all_orig_inds` populated.
    """
    ##Range of basis functions covered by this chunk (upper is exclusive)
    lower_coeff_ind = chunk_ind * coeffs_per_chunk
    upper_coeff_ind = (chunk_ind + 1) * coeffs_per_chunk

    orig_index_chunk = shape_basis_to_orig_comp_index_map[lower_coeff_ind:upper_coeff_ind]
    orig_type_chunk = shape_basis_to_orig_type_map[lower_coeff_ind:upper_coeff_ind]
    shape_basis_param_index_chunk = shape_basis_param_index[lower_coeff_ind:upper_coeff_ind]

    ##A full chunk, or whatever is left in the final one. Taking the slice
    ##length keeps this consistent with the index arrays below
    n_shape_coeffs = len(orig_index_chunk)

    ##Which basis functions in this chunk belong to each flux type
    is_power = orig_type_chunk == CompTypes.SHAPE_POWER.value
    is_curve = orig_type_chunk == CompTypes.SHAPE_CURVE.value
    is_list = orig_type_chunk == CompTypes.SHAPE_LIST.value

    ##Original component index of every basis function, per flux type
    power_shape_orig_inds = orig_index_chunk[is_power]
    curve_shape_orig_inds = orig_index_chunk[is_curve]
    list_shape_orig_inds = orig_index_chunk[is_list]

    ##The distinct components touched by this chunk, per flux type
    power_orig_inds = np.unique(power_shape_orig_inds).astype(int)
    curve_orig_inds = np.unique(curve_shape_orig_inds).astype(int)
    list_orig_inds = np.unique(list_shape_orig_inds).astype(int)

    ##Position of each basis function within its own component
    power_shape_basis_inds = shape_basis_param_index_chunk[is_power]
    curve_shape_basis_inds = shape_basis_param_index_chunk[is_curve]
    list_shape_basis_inds = shape_basis_param_index_chunk[is_list]

    num_chunk_list = len(list_orig_inds)

    chunk_map = Skymodel_Chunk_Map(n_shape_powers = len(power_orig_inds),
                                   n_shape_curves = len(curve_orig_inds),
                                   n_shape_lists = num_chunk_list,
                                   n_shape_coeffs = n_shape_coeffs)

    components = chunk_map.shape_components

    components.power_shape_orig_inds = power_shape_orig_inds
    components.curve_shape_orig_inds = curve_shape_orig_inds
    components.list_shape_orig_inds = list_shape_orig_inds

    components.power_shape_basis_inds = power_shape_basis_inds
    components.curve_shape_basis_inds = curve_shape_basis_inds
    components.list_shape_basis_inds = list_shape_basis_inds

    ##Indexes of the shapelet components in the original sky model
    components.power_orig_inds = power_orig_inds
    components.curve_orig_inds = curve_orig_inds
    components.list_orig_inds = list_orig_inds

    ##how many shapelet coeffs we have
    components.total_shape_coeffs = n_shape_coeffs

    ##if we have list type fluxes, count how many flux list entries in total
    ##are held by these components (located within the cropped sky model)
    if num_chunk_list > 0:
        cropped_list_inds = np.flatnonzero(np.isin(cropped_comp_counter.orig_comp_indexes,
                                                   list_orig_inds))
        components.total_num_flux_entires = np.sum(cropped_comp_counter.num_list_fluxes[cropped_list_inds])

    chunk_map.make_all_orig_inds_array()

    return chunk_map
def create_skymodel_chunk_map(comp_counter : Component_Type_Counter,
                              max_num_visibilities : int, num_baselines : int,
                              num_freqs : int, num_time_steps : int,
                              text_file=False) -> list:
    """
    Given all the information in `comp_counter`, make a map of how to split
    the whole sky model up into manageable chunks to fit in memory. The
    purpose of this function is to record what to 'malloc' in each
    `Components_t` and `Source_t` ctype class before we lazy-load all the
    values into them directly from the skymodel.

    POINT and GAUSSIAN components are chunked by component; SHAPELETs are
    chunked by basis function (coefficient).

    Parameters
    ----------
    comp_counter: Component_Type_Counter
        Object that contains information about the number of components of
        each type in the sky model.
    max_num_visibilities: int
        The maximum number of visibilities that can be loaded into memory at once.
    num_baselines: int
        The number of baselines in the observation.
    num_freqs: int
        The number of frequency channels in the observation.
    num_time_steps: int
        The number of time steps in the observation.
    text_file: Boolean
        Whether we are reading in from a text file (default False). Not
        currently used by this function.

    Returns
    -------
    list
        `Skymodel_Chunk_Map` objects: all POINT chunks, then all GAUSSIAN
        chunks, then all SHAPELET chunks.
    """
    ##The number of components per chunk is set by how many visibilities
    ##we can hold
    comps_per_chunk = int(np.floor(max_num_visibilities / (num_baselines * num_freqs * num_time_steps)))

    ##pray this never happens, probably means we're going to run out of
    ##GPU memory TODO don't pray, submit a warning?
    if comps_per_chunk < 1:
        comps_per_chunk = 1

    ##chunks numbers for each type of component
    num_point_chunks = int(np.ceil(comp_counter.total_point_comps / comps_per_chunk))
    num_gauss_chunks = int(np.ceil(comp_counter.total_gauss_comps / comps_per_chunk))
    ##we split SHAPELET by the basis components (number of coeffs)
    num_coeff_chunks = int(np.ceil(comp_counter.total_shape_basis / comps_per_chunk))

    chunked_skymodel_maps = [map_chunk_pointgauss(comp_counter, chunk_ind, comps_per_chunk,
                                                  point_source = True)
                             for chunk_ind in range(num_point_chunks)]

    chunked_skymodel_maps.extend(map_chunk_pointgauss(comp_counter, chunk_ind, comps_per_chunk,
                                                      gaussian_source = True)
                                 for chunk_ind in range(num_gauss_chunks))

    if num_coeff_chunks > 0:
        ##need some extra mapping arrays to be able to grab the SHAPELET
        ##component that matches each basis function
        shape_basis_to_orig_comp_index_map, shape_basis_to_comp_type_map, shape_basis_param_index = create_shape_basis_maps(comp_counter)

        chunked_skymodel_maps.extend(map_chunk_shapelets(comp_counter,
                                                         shape_basis_to_orig_comp_index_map,
                                                         shape_basis_to_comp_type_map,
                                                         shape_basis_param_index,
                                                         chunk_ind, comps_per_chunk)
                                     for chunk_ind in range(num_coeff_chunks))

    print(f"After chunking there are {len(chunked_skymodel_maps)} chunks")

    return chunked_skymodel_maps