TrajStat icon indicating copy to clipboard operation
TrajStat copied to clipboard

Improve the speed of the CWT and PSCF analyses with spatial indexing techniques

Open maerzhahaha opened this issue 9 months ago • 2 comments

I studied your code recently and found that the CWT and PSCF analyses are very slow when processing large amounts of data. I used spatial indexing to reduce the time consumption: a project that originally took 30 minutes to calculate Nij now takes about 30 s once a spatial index is introduced. A Python version of the code that mimics your CWT analysis is as follows:

import geopandas as gpd
from shapely.geometry import Point
from rtree import index  # 用于空间索引
from tqdm import tqdm  # 用于进度条
import numpy as np

def point_in_polygon(polygon, point):
    """
    Test whether a point lies inside a polygon with the ray-crossing method.

    Every edge of the polygon (including the closing edge that runs from the
    last vertex back to the first) is examined. An edge counts as a crossing
    when the point's x coordinate falls in the edge's half-open x-range and
    the point lies strictly below the edge. The point is inside when the
    number of crossings is odd.

    Because the comparisons are half-open/strict, a point lying exactly on
    the boundary is not treated uniformly: it may be reported as inside or
    outside depending on which edge it sits on.

    Args:
        polygon: Sequence of vertices, e.g. [(x1, y1), (x2, y2), ...].
        point: Coordinates of the point as (x, y).

    Returns:
        True if the crossing count is odd, False otherwise. Always False
        when the polygon has fewer than 3 vertices.
    """
    px, py = point
    if len(polygon) < 3:
        return False  # fewer than 3 vertices cannot form a polygon

    crossings = 0
    prev_x, prev_y = polygon[-1]  # start with the closing edge (last -> first)

    for cur_x, cur_y in polygon:
        # Orient the edge so that its "left" end has the smaller x.
        if cur_x > prev_x:
            left_x, left_y, right_x, right_y = prev_x, prev_y, cur_x, cur_y
        else:
            left_x, left_y, right_x, right_y = cur_x, cur_y, prev_x, prev_y

        # Half-open x-range test; vertical edges can never satisfy it.
        spans_x = (cur_x < px) == (px <= prev_x)
        if spans_x and (py - left_y) * (right_x - left_x) < (right_y - left_y) * (px - left_x):
            crossings += 1

        prev_x, prev_y = cur_x, cur_y

    return crossings % 2 == 1

def _segment_masks(counts, end_points):
    """Yield ``(j, mask)`` selecting the cells whose count falls in segment j.

    Segment j (not the last) covers ``end_points[j + 1] < count <= end_points[j]``;
    the last segment covers ``count <= end_points[-1]``.
    """
    last = len(end_points) - 1
    for j, upper in enumerate(end_points):
        mask = counts <= upper
        if j < last:
            mask &= counts > end_points[j + 1]
        yield j, mask


def calculate_nij_cwt_wcwt(pscf_path, traj_paths, field_name, a_type, a_criterion, a_n_data, end_points, reduce_ratios, traj_num_end_points, traj_num_reduce_ratios, output_path):
    """
    Compute Nij, N_Traj, CWT and WCWT for every grid cell and save the result.

    For each cell of the PSCF grid layer:
    - Nij is the number of trajectory vertices that fall inside the cell,
    - N_Traj is the number of trajectory lines with at least one vertex in
      the cell,
    - CWT is the mean of the trajectories' ``field_name`` values over the
      counted vertices,
    - WCWT is CWT multiplied by a weight chosen from the Nij segments and
      then by a second weight chosen from the N_Traj segments.

    Args:
        pscf_path: Path of the PSCF grid layer (polygon cells) shapefile.
        traj_paths: List of trajectory layer shapefile paths.
        field_name: Attribute of the trajectory layers holding the value that
            is filtered and averaged (e.g. a pollutant concentration).
        a_type: Type tag; when it equals "Mij" trajectories whose value is
            below ``a_criterion`` are skipped. Any other value disables the
            threshold filter.
        a_criterion: Threshold used when ``a_type == "Mij"``.
        a_n_data: Marker for invalid data (e.g. -9999.0); trajectories whose
            value equals it are skipped.
        end_points: Nij segment breakpoints (e.g. [80, 20, 10]). Segment j
            covers ``end_points[j + 1] < Nij <= end_points[j]``; the last
            segment covers ``Nij <= end_points[-1]``.
        reduce_ratios: Weight applied to CWT for each Nij segment.
        traj_num_end_points: N_Traj segment breakpoints, same convention.
        traj_num_reduce_ratios: Weight applied for each N_Traj segment.
        output_path: Path the grid layer with the added "Nij", "N_Traj",
            "CWT" and "WCWT" columns is written to.
    """
    # 1. Read the PSCF grid cells and make sure they are in WGS84.
    pscf_gdf = gpd.read_file(pscf_path)
    pscf_gdf = pscf_gdf.to_crs("EPSG:4326")
    n_cells = len(pscf_gdf)
    print(f"PSCF 网格单元数: {n_cells:,}")

    # 2. Per-cell accumulators.
    nijs = np.zeros(n_cells, dtype=int)
    n_trajs = np.zeros(n_cells, dtype=int)
    cwts = np.zeros(n_cells, dtype=float)

    # 3. Build the spatial index over the cell bounding boxes. The index is
    #    keyed by row *position* (not by the DataFrame index label) because
    #    the accumulators above are addressed by position.
    print("构建空间索引...")
    cell_geoms = list(pscf_gdf.geometry)
    spatial_idx = index.Index()
    for cell_pos, geom in enumerate(cell_geoms):
        spatial_idx.insert(cell_pos, geom.bounds)

    # Exterior rings are extracted once per cell, on first use, instead of
    # once per trajectory vertex. Only (x, y) is kept so that rings carrying
    # a Z coordinate still work with point_in_polygon.
    cell_rings = {}

    # 4. Process every trajectory layer.
    for traj_path in traj_paths:
        traj_gdf = gpd.read_file(traj_path)
        traj_gdf = traj_gdf.to_crs(pscf_gdf.crs)  # same CRS as the grid

        records = zip(traj_gdf[field_name], traj_gdf.geometry)
        for value, line in tqdm(records, total=len(traj_gdf), desc=f"处理 {traj_path}"):
            # Skip invalid data: the explicit marker as well as missing
            # values (None / NaN), which would otherwise turn the cell sums
            # into NaN or raise in the comparisons below.
            if value is None or value != value or value == a_n_data:
                continue

            # Threshold filter, only active for the "Mij" type.
            if a_type == "Mij" and value < a_criterion:
                continue

            # Only non-empty single LineStrings are processed.
            if line is None or line.is_empty or line.geom_type != "LineString":
                continue

            visited_cells = set()  # cells already credited to this trajectory

            for coord in line.coords:
                x, y = coord[0], coord[1]  # drop any Z coordinate

                # The index returns the cells whose bounding box contains
                # the point; confirm with the exact polygon test.
                for cell_pos in spatial_idx.intersection((x, y, x, y)):
                    ring = cell_rings.get(cell_pos)
                    if ring is None:
                        ring = [(c[0], c[1]) for c in cell_geoms[cell_pos].exterior.coords]
                        cell_rings[cell_pos] = ring

                    if point_in_polygon(ring, (x, y)):
                        nijs[cell_pos] += 1
                        cwts[cell_pos] += value
                        if cell_pos not in visited_cells:
                            n_trajs[cell_pos] += 1
                            visited_cells.add(cell_pos)
                        break  # a vertex is credited to one cell only

    # 5. CWT: turn the accumulated sums into means where a cell was hit.
    counted = nijs > 0
    cwts[counted] /= nijs[counted]

    # 6. WCWT, first pass: weight CWT by the Nij segment. Cells that match
    #    no segment keep their CWT unchanged.
    wcwts = cwts.copy()
    for j, mask in _segment_masks(nijs, end_points):
        if mask.any():
            wcwts[mask] = cwts[mask] * reduce_ratios[j]

    # 7. WCWT, second pass: additionally weight by the N_Traj segment.
    for j, mask in _segment_masks(n_trajs, traj_num_end_points):
        if mask.any():
            wcwts[mask] *= traj_num_reduce_ratios[j]

    # 8. Attach the results to the grid layer and write it out.
    pscf_gdf["Nij"] = nijs
    pscf_gdf["N_Traj"] = n_trajs
    pscf_gdf["CWT"] = cwts
    pscf_gdf["WCWT"] = wcwts

    pscf_gdf.to_file(output_path)
    print(f"结果已保存至: {output_path}")

# Example invocation
def main():
    """Run the CWT/WCWT computation once for every pollutant field."""
    # Nij breakpoints and the weight applied within each segment
    end_points = [80, 20, 10]
    reduce_ratios = [0.7, 0.42, 0.05]
    # N_Traj breakpoints and the weight applied within each segment
    traj_num_end_points = [3, 1]
    traj_num_reduce_ratios = [0.5, 0.2]

    # ---------------- parameters to edit ----------------
    # PSCF grid layer
    # (alternative grid noted by the author: C:/Users/GIGA RTX/Desktop/SWY/HYSPLIT/O3/grid2020.shp)
    pscf_path = r"C:\Users\GIGA RTX\Desktop\SWY\HYSPLIT\O3\O3T_2018.shp"
    # Trajectory layers
    traj_paths = [r"C:\Users\GIGA RTX\Desktop\SWY\HYSPLIT\calculate_tra\calculate_tra\2020\2020_2_5.shp"]
    # Trajectory attribute fields to analyse
    field_names = ["SO2", "O3", "CO", "PM2d5", "NO2"]
    # ---------------- end of parameters ----------------

    for field_name in field_names:
        calculate_nij_cwt_wcwt(
            pscf_path=pscf_path,
            traj_paths=traj_paths,
            field_name=field_name,  # trajectory field used for filtering/averaging
            a_type="Nij",  # type tag
            a_criterion=0,  # threshold
            a_n_data=-9999.0,  # invalid-data marker
            end_points=end_points,
            reduce_ratios=reduce_ratios,
            traj_num_end_points=traj_num_end_points,
            traj_num_reduce_ratios=traj_num_reduce_ratios,
            # One output file per field; a single shared path would let each
            # field overwrite the result of the previous one.
            output_path=f"tt25_WCWT_{field_name}.shp",
        )


if __name__ == "__main__":
    main()


maerzhahaha avatar Mar 11 '25 06:03 maerzhahaha

Great job! Would it be possible for you to write similar code in Java and submit a PR to TrajStat for better performance?

Yaqiang avatar Mar 11 '25 08:03 Yaqiang

Thanks for the recognition; your open-source work has greatly facilitated our research. However, I am not familiar with Java. If I find some spare time, I will try to provide a Java version.

maerzhahaha avatar Mar 11 '25 09:03 maerzhahaha