Artem Goncharov 0f870d5a5f Import IntString from openstack_sdk
We are moving IntString, NumString, BoolString from cli to sdk so that we can
also use them in tui.

Change-Id: Ib2dcdd2f54481feacb7848037fb26fae4ef5d738
2025-04-03 14:59:11 +02:00

835 lines
30 KiB
Python

# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
import logging
from pathlib import Path
import re
import subprocess
from typing import Type, Any
from codegenerator.base import BaseGenerator
from codegenerator import common
from codegenerator import model
from codegenerator.common import BaseCompoundType
from codegenerator.common import BaseCombinedType
from codegenerator.common import BasePrimitiveType
from codegenerator.common import rust as common_rust
from codegenerator.rust_sdk import TypeManager as SdkTypeManager
from codegenerator import rust_sdk
# Field names that are never flagged as "wide" in list output: a response
# field whose local name (or known FQAN alias) is listed here is always shown.
BASIC_FIELDS = [
    # identification
    "name", "title",
    # timestamps
    "created_at", "updated_at",
    # state information
    "state", "status", "operating_status",
]
class String(common_rust.String):
    """TUI string primitive (Rust ``String``)."""

    type_hint: str = "String"

    def get_sdk_setter(
        self, source_var_name: str, sdk_mod_path: str, into: bool = False
    ) -> str:
        """Return the Rust expression handing the string to an SDK setter.

        :param source_var_name: Rust expression holding the source value.
        :param sdk_mod_path: Unused for plain strings.
        :param into: Emit ``.into()`` when true, ``.clone()`` otherwise.
        """
        conversion = "into" if into else "clone"
        return f"{source_var_name}.{conversion}()"
class IntString(common.BasePrimitiveType):
    """TUI Integer or String

    Maps to the Rust ``openstack_sdk::types::IntString`` type. The response
    type manager collapses ``oneOf [string, integer]`` schemas into it.
    """

    # Rust import bringing the type into scope of the generated module
    imports: set[str] = {"openstack_sdk::types::IntString"}
    type_hint: str = "IntString"
    clap_macros: set[str] = set()
class NumString(common.BasePrimitiveType):
    """TUI Number or String

    Maps to the Rust ``openstack_sdk::types::NumString`` type. The response
    type manager collapses ``oneOf [string, number]`` schemas into it.
    """

    # Rust import bringing the type into scope of the generated module
    imports: set[str] = {"openstack_sdk::types::NumString"}
    type_hint: str = "NumString"
    clap_macros: set[str] = set()
class BoolString(common.BasePrimitiveType):
    """TUI Boolean or String

    Maps to the Rust ``openstack_sdk::types::BoolString`` type. The response
    type manager collapses ``oneOf [string, boolean]`` schemas into it.
    """

    # Rust import bringing the type into scope of the generated module
    imports: set[str] = {"openstack_sdk::types::BoolString"}
    type_hint: str = "BoolString"
    clap_macros: set[str] = set()
class ArrayInput(common_rust.Array):
    """TUI request array able to feed its items into an SDK builder."""

    original_data_type: (
        common_rust.BaseCompoundType
        | common_rust.BaseCombinedType
        | common_rust.BasePrimitiveType
        | None
    ) = None

    def get_sdk_setter(
        self,
        source_var_name: str,
        sdk_mod_path: list[str],
        into: bool = False,
        ord_num: int = 0,
    ) -> str:
        """Return the Rust expression converting the array for an SDK setter.

        :param source_var_name: Rust expression holding the source array.
        :param sdk_mod_path: SDK module path segments; they are joined with
            ``::`` to qualify compound item types.
        :param into: For combined items emit ``.iter()`` and for other items
            ``.into_iter()`` instead of ``.iter().cloned()``.
        :param ord_num: Nesting depth of the array, used to build distinct
            closure variable names (``x1``, ``x2``, ...) for nested arrays.
        :returns: Rust expression string.
        """
        ord_num += 1
        item_type = self.item_type
        result: str = source_var_name
        if isinstance(item_type, common_rust.BaseCompoundType):
            sdk_item_type = f"{'::'.join(sdk_mod_path)}::{item_type.name}"
            # NOTE(review): `flat_map` over the `Result` returned by
            # `try_from` keeps only successful conversions; failed ones are
            # silently dropped from the resulting Vec.
            result += (
                ".iter().flat_map(|x| TryFrom::try_from(x))"
                f".collect::<Vec<{sdk_item_type}>>()"
            )
        elif isinstance(item_type, ArrayInput) and isinstance(
            item_type.item_type, common_rust.BasePrimitiveType
        ):
            result += ".iter().cloned()"
        elif isinstance(item_type, common_rust.BaseCombinedType):
            result += ".iter()" if into else ".iter().cloned()"
            closure_var = f"x{ord_num}"
            if isinstance(item_type, ArrayInput):
                # Propagate the depth so nested closures get distinct names
                inner = item_type.get_sdk_setter(
                    closure_var, sdk_mod_path, into=True, ord_num=ord_num
                )
            else:
                inner = item_type.get_sdk_setter(
                    closure_var, sdk_mod_path, into=True
                )
            result += f".map(|{closure_var}| {inner}).collect::<Vec<_>>()"
        else:
            result += ".into_iter()" if into else ".iter().cloned()"
            result += ".map(Into::into).collect::<Vec<_>>()"
        return result
class StructField(rust_sdk.StructField):
    """TUI request structure field able to set itself on an SDK builder."""

    def get_sdk_setter(
        self,
        sdk_field: rust_sdk.StructField,
        source_var: str,
        dest_var: str,
        sdk_mod_path: list[str],
        into: bool = False,
    ) -> str:
        """Return Rust code setting this field on an SDK builder.

        :param sdk_field: Matching field of the SDK structure.
        :param source_var: Rust variable holding the source structure.
        :param dest_var: Rust variable holding the SDK builder.
        :param sdk_mod_path: SDK module path segments joined with ``::``.
        :param into: Forwarded to the data type setter; for required
            structure fields it also suppresses the ``&`` borrow.
        :returns: Rust statement; optional fields are wrapped into
            ``if let Some(val) = &<source_var>.<field> { ... }``.
        """
        if self.is_optional:
            source = "val"
            prefix = f"if let Some(val) = &{source_var}.{self.local_name} {{"
            suffix = "}\n"
        else:
            # Use the given source variable (previously `value` was
            # hard-coded here while the optional branch honoured it)
            source = f"{source_var}.{self.local_name}"
            prefix = ""
            suffix = ""

        if isinstance(sdk_field.data_type, rust_sdk.Struct):
            # Nested structures are converted via their TryFrom impl
            if not self.is_optional and not into:
                source = f"&{source}"
            sdk_type = (
                f"{'::'.join(sdk_mod_path)}::{sdk_field.data_type.name}"
            )
            value = f"TryInto::<{sdk_type}>::try_into({source})?"
        else:
            value = self.data_type.get_sdk_setter(
                source, sdk_mod_path, into=into
            )
        return f"{prefix}{dest_var}.{sdk_field.local_name}({value});{suffix}"
class Struct(rust_sdk.Struct):
    """TUI request structure

    Renders Rust ``TryFrom`` implementations converting a reference to this
    structure into the corresponding SDK structure and its builder.
    """

    field_type_class_: Type[StructField] | StructField = StructField
    # NOTE(review): only compound types are accepted here while
    # ArrayInput.original_data_type also allows combined and primitive types;
    # confirm this is intended.
    original_data_type: BaseCompoundType | None = None
    is_required: bool = False

    @property
    def static_lifetime(self):
        """Return Rust `<'lc>` lifetimes representation"""
        return f"<{', '.join(self.lifetimes)}>" if self.lifetimes else ""

    def get_sdk_builder_try_from(
        self, sdk_struct: rust_sdk.Struct, sdk_mod_path: list[str]
    ) -> str:
        """Render ``impl TryFrom<&Self> for <SdkStruct>Builder``.

        The generated conversion starts from the builder default and sets
        every field using ``get_set_sdk_struct_fields``.
        """
        sdk_mod = "::".join(sdk_mod_path)
        parts = [
            f"impl TryFrom<&{self.name}> for {sdk_mod}::{sdk_struct.name}Builder{sdk_struct.static_lifetime_anonymous} {{",
            "type Error = Report;\n",
            f"fn try_from(value: &{self.name}) -> Result<Self, Self::Error> {{",
            "let mut ep_builder = Self::default();\n",
            self.get_set_sdk_struct_fields(
                sdk_struct, "value", "ep_builder", sdk_mod_path
            ),
            "Ok(ep_builder)",
            "}\n",
            "}",
        ]
        return "".join(parts)

    def get_set_sdk_struct_fields(
        self,
        sdk_struct: rust_sdk.Struct,
        source_var: str,
        dest_var: str,
        sdk_mod_path: list[str],
    ) -> str:
        """Render the statements copying every field into the SDK builder.

        :param sdk_struct: SDK structure matching this one.
        :param source_var: Rust variable holding the source structure.
        :param dest_var: Rust variable holding the SDK builder.
        :param sdk_mod_path: SDK module path segments.
        """
        # NOTE(review): fields are paired by position, i.e. both structures
        # are assumed to list the same fields in the same order (the
        # generator feeds both type managers with the same models) - confirm.
        return "".join(
            field_data.get_sdk_setter(
                sdk_field_data, source_var, dest_var, sdk_mod_path, into=False
            )
            for field_data, sdk_field_data in zip(
                self.fields.values(), sdk_struct.fields.values()
            )
        )

    def get_sdk_type_try_from(
        self, sdk_struct: rust_sdk.Struct, sdk_mod_path: list[str]
    ) -> str:
        """Render ``impl TryFrom<&Self> for <SdkStruct>``.

        The conversion goes through the builder ``TryFrom`` implementation
        and builds the SDK structure, wrapping build errors into a report.
        """
        sdk_mod = "::".join(sdk_mod_path)
        parts = [
            f"impl TryFrom<&{self.name}> for {sdk_mod}::{sdk_struct.name}{sdk_struct.static_lifetime_anonymous} {{",
            "type Error = Report;\n",
            f"fn try_from(value: &{self.name}) -> Result<Self, Self::Error> {{\n",
            f"let ep_builder: {sdk_mod}::{sdk_struct.name}Builder = TryFrom::try_from(value)?;\n",
            f'ep_builder.build().wrap_err("cannot prepare request element `{self.name}`")',
            "}\n",
            "}",
        ]
        return "".join(parts)
class StructFieldResponse(common_rust.StructField):
    """Response Structure Field rendered as a table column"""

    @property
    def type_hint(self):
        """Rust type of the field, wrapped in ``Option<>`` when optional."""
        hint = self.data_type.type_hint
        must_wrap = self.is_optional and not hint.startswith("Option<")
        return f"Option<{hint}>" if must_wrap else hint

    @property
    def serde_macros(self):
        """``#[serde(...)]`` attribute with rename/default options."""
        options = []
        if self.remote_name != self.local_name:
            options.append(f'rename="{self.remote_name}"')
        if self.is_optional or self.data_type.type_hint.startswith("Option<"):
            options.append("default")
        return "#[serde(" + ", ".join(sorted(options)) + ")]"

    def get_structable_macros(
        self,
        struct: "StructResponse",
        service_name: str,
        resource_name: str,
        operation_type: str,
    ):
        """Return the ``#[structable(...)]`` attribute of the field.

        For list operations fields outside of ``BASIC_FIELDS`` (checked by
        local name and by FQAN alias) are marked ``wide`` when the structure
        has an ``id`` field, or otherwise when the field is not among the
        last 10 fields. ``state``/``operating_status`` become the status
        column when no ``status`` field exists.
        """
        options = {f'title="{self.remote_name.upper()}"'}
        if self.is_optional or self.data_type.type_hint.startswith("Option<"):
            options.add("optional")
        # Fully Qualified Attribute Name and its known alias
        fqan: str = ".".join(
            [service_name, resource_name, self.remote_name]
        ).lower()
        alias = common.FQAN_ALIAS_MAP.get(fqan)

        if operation_type not in ["list", "list_from_struct"]:
            return f"#[structable({', '.join(sorted(options))})]"

        field_names = struct.fields.keys()
        is_basic = self.local_name in BASIC_FIELDS or alias in BASIC_FIELDS
        # NOTE(review): the slice selects the *last* 10 fields although the
        # original intent was described as "the first 10" - confirm.
        if not is_basic and (
            "id" in field_names
            or self.local_name not in list(field_names)[-10:]
        ):
            options.add("wide")
        if (
            self.local_name in ("state", "operating_status")
            and "status" not in field_names
        ):
            options.add("status")
        return f"#[structable({', '.join(sorted(options))})]"
class StructResponse(common_rust.Struct):
    """Response structure (always requires ``serde::Deserialize``)."""

    field_type_class_: Type[common_rust.StructField] = StructFieldResponse

    @property
    def imports(self):
        """Rust imports of the structure and of all its field types."""
        return {"serde::Deserialize"}.union(
            *(field.data_type.imports for field in self.fields.values())
        )
class TypeManager(common_rust.TypeManager):
    """Rust TUI type manager

    The class is responsible for converting ADT models into types suitable
    for the Rust TUI request structures. It can be linked with the Rust SDK
    type manager to pair every TUI type with its SDK counterpart.
    """

    primitive_type_mapping: dict[Type[model.PrimitiveType], Type[Any]] = {
        model.PrimitiveString: String,
        model.ConstraintString: String,
    }

    data_type_mapping = {
        model.Struct: Struct,
        model.Array: ArrayInput,
        model.CommaSeparatedList: ArrayInput,
    }

    request_parameter_class: Type[common_rust.RequestParameter] = (
        common_rust.RequestParameter
    )
    sdk_type_manager: SdkTypeManager | None = None

    def get_local_attribute_name(self, name: str) -> str:
        """Get localized attribute name

        Dots become underscores, the name is converted to snake_case and
        Rust keywords are prefixed with an underscore.
        """
        name = name.replace(".", "_")
        attr_name = "_".join(
            x.lower() for x in re.split(common.SPLIT_NAME_RE, name)
        )
        if attr_name in ["type", "self", "enum", "ref", "default"]:
            attr_name = f"_{attr_name}"
        return attr_name

    def get_remote_attribute_name(self, name: str) -> str:
        """Get the attribute name on the SDK side"""
        return self.get_local_attribute_name(name)

    def link_sdk_type_manager(self, sdk_type_manager: SdkTypeManager) -> None:
        """Link the SDK type manager used by ``get_subtypes_with_sdk``."""
        self.sdk_type_manager = sdk_type_manager

    def get_subtypes_with_sdk(self):
        """Get all subtypes excluding TLA together with their SDK types

        Yields ``(tui_type, sdk_type)`` tuples: enums, structures and string
        enums are yielded as is, options are yielded as their wrapped enum.
        ``sdk_type`` is the SDK type registered under the same reference or
        ``None`` when no SDK type manager is linked.
        """
        for ref, typ in self.refs.items():
            # Skip anonymous entries and the top level "Body" structure
            # before touching `ref.name` (a falsy key used to crash here).
            if not ref or ref.name == "Body":
                continue
            if isinstance(
                typ, (common_rust.Enum, Struct, common_rust.StringEnum)
            ):
                subtype = typ
            elif isinstance(typ, self.option_type_class) and isinstance(
                typ.item_type, common_rust.Enum
            ):
                subtype = typ.item_type
            else:
                continue
            # Only look up SDK types for entries that are actually yielded
            sdk_type = (
                self.sdk_type_manager.refs[ref]
                if self.sdk_type_manager
                else None
            )
            yield (subtype, sdk_type)
class ResponseTypeManager(common_rust.TypeManager):
    """Rust TUI response type manager

    Converts ADT models of the operation response into Rust types. Arrays,
    dictionaries and structure fields of non-primitive types are represented
    as ``JsonValue``.
    """

    primitive_type_mapping: dict[
        Type[model.PrimitiveType], Type[BasePrimitiveType]
    ] = {
        model.PrimitiveString: common_rust.String,
        model.ConstraintString: common_rust.String,
    }

    data_type_mapping = {
        model.Struct: StructResponse,
        model.Array: common_rust.JsonValue,
        model.Dictionary: common_rust.JsonValue,
    }

    def get_model_name(self, model_ref: model.Reference | None) -> str:
        """Get the localized model type name

        In order to avoid collision between structures in request and
        response we prefix all types with `Response`

        :returns str: Type name
        """
        if not model_ref:
            return "Response"
        return "Response" + "".join(
            x.capitalize()
            for x in re.split(common.SPLIT_NAME_RE, model_ref.name)
        )

    def _simplify_oneof_combinations(self, type_model, kinds):
        """Simplify certain known oneOf combinations

        ``oneOf`` of a string with a number, an integer or a boolean is
        collapsed into ``NumString``, ``IntString`` or ``BoolString``
        respectively before the base class simplification runs.
        """
        kinds_classes = [x["class"] for x in kinds]
        if (
            common_rust.String in kinds_classes
            and common_rust.Number in kinds_classes
        ):
            # oneOf [string, number] => NumString
            kinds.clear()
            kinds.append({"local": NumString(), "class": NumString})
        elif (
            common_rust.String in kinds_classes
            and common_rust.Integer in kinds_classes
        ):
            # oneOf [string, integer] => IntString
            kinds.clear()
            kinds.append({"local": IntString(), "class": IntString})
        elif (
            common_rust.String in kinds_classes
            and common_rust.Boolean in kinds_classes
        ):
            # oneOf [string, boolean] => BoolString
            kinds.clear()
            kinds.append({"local": BoolString(), "class": BoolString})
        super()._simplify_oneof_combinations(type_model, kinds)

    def _get_struct_type(self, type_model: model.Struct) -> common_rust.Struct:
        """Convert model.Struct into Rust `Struct`

        Field types that are not primitives are replaced by ``JsonValue`` and
        their models recorded in ``ignored_models``.
        """
        struct_class = self.data_type_mapping[model.Struct]
        mod = struct_class(
            name=self.get_model_name(type_model.reference),
            description=common_rust.sanitize_rust_docstrings(
                type_model.description
            ),
        )
        field_class = mod.field_type_class_
        for field_name, field in type_model.fields.items():
            is_nullable: bool = False
            field_data_type = self.convert_model(field.data_type)
            if isinstance(field_data_type, self.option_type_class):
                # Unwrap Option into "is_nullable"
                # NOTE: but perhaps Option<Option> is better (not set vs set
                # explicitly to None)
                is_nullable = True
                if isinstance(field_data_type.item_type, common_rust.Array):
                    # Unwrap Option<Vec...> into Vec (nullability is kept in
                    # "is_nullable")
                    field_data_type = field_data_type.item_type
                elif not isinstance(
                    field_data_type.item_type, BasePrimitiveType
                ):
                    # Everything more complex than a primitive goes to Value
                    field_data_type = common_rust.JsonValue(
                        **field_data_type.model_dump()
                    )
                    self.ignored_models.append(field.data_type)
            elif not isinstance(field_data_type, BasePrimitiveType):
                field_data_type = common_rust.JsonValue(
                    **field_data_type.model_dump()
                )
                self.ignored_models.append(field.data_type)
            f = field_class(
                local_name=self.get_local_attribute_name(field_name),
                remote_name=self.get_remote_attribute_name(field_name),
                description=common_rust.sanitize_rust_docstrings(
                    field.description
                ),
                data_type=field_data_type,
                is_optional=not field.is_required,
                is_nullable=is_nullable,
            )
            mod.fields[field_name] = f
        if type_model.additional_fields:
            definition = type_model.additional_fields
            # Structure allows additional fields
            if isinstance(definition, bool):
                mod.additional_fields_type = self.primitive_type_mapping[
                    model.PrimitiveAny
                ]
            else:
                mod.additional_fields_type = self.convert_model(definition)
        return mod

    def get_subtypes(self):
        """Get all subtypes excluding TLA

        Types sharing the same ``base_type`` and ``type_hint`` are emitted
        only once.
        """
        emitted_data: set[str] = set()
        for k, v in self.refs.items():
            if (
                k
                and isinstance(
                    v,
                    (
                        common_rust.Enum,
                        common_rust.Struct,
                        common_rust.StringEnum,
                        common_rust.Dictionary,
                        common_rust.Array,
                    ),
                )
                and k.name != "Body"
            ):
                key = v.base_type + v.type_hint
                if key not in emitted_data:
                    emitted_data.add(key)
                    yield v

    def get_imports(self):
        """Get complete set of additional imports required by all models in scope"""
        imports: set[str] = super().get_imports()
        # The response side never parses JSON command arguments
        imports.discard("crate::common::parse_json")
        return imports
class RustTuiGenerator(BaseGenerator):
    """Generator of the Rust ``openstack_tui`` cloud worker modules"""

    def __init__(self):
        super().__init__()

    def _format_code(self, *args):
        """Format code using Rustfmt

        Formatting is best effort: the ``rustfmt`` exit status is not checked.

        :param *args: Path to the code to format
        """
        for path in args:
            subprocess.run(["rustfmt", "--edition", "2021", path])

    def _render_command(
        self, context: dict, impl_template: str, impl_dest: Path
    ):
        """Render command code"""
        self._render(impl_template, context, impl_dest.parent, impl_dest.name)

    def generate(
        self, res, target_dir, openapi_spec=None, operation_id=None, args=None
    ):
        """Generate code for the Rust openstack_tui

        One module is rendered (and formatted) for every operation variant.

        :returns: generator yielding ``(mod_path, mod_name, path, class_name)``
            for every rendered module.
        """
        logging.debug(
            "Generating Rust TUI code for %s in %s [%s]",
            operation_id,
            target_dir,
            args,
        )

        if not openapi_spec:
            openapi_spec = common.get_openapi_spec(args.openapi_yaml_spec)
        if not operation_id:
            operation_id = args.openapi_operation_id
        (path, method, spec) = common.find_openapi_operation(
            openapi_spec, operation_id
        )
        path_resources = common.get_resource_names_from_url(path)
        resource_name = path_resources[-1]

        mime_type = None
        openapi_parser = model.OpenAPISchemaParser()
        operation_params: list[model.RequestParameter] = []
        type_manager: TypeManager | None = None
        sdk_type_manager: SdkTypeManager | None = None
        # Collect all operation parameters
        for param in openapi_spec["paths"][path].get(
            "parameters", []
        ) + spec.get("parameters", []):
            if (
                ("{" + param["name"] + "}") in path and param["in"] == "path"
            ) or param["in"] != "path":
                # Respect path params that appear in path and not path params
                param_ = openapi_parser.parse_parameter(param)
                if param_.name in [
                    f"{resource_name}_id",
                    f"{resource_name.replace('_', '')}_id",
                ]:
                    path = path.replace(param_.name, "id")
                    # for i.e. routers/{router_id} we want local_name to be `id` and not `router_id`
                    param_.name = "id"
                operation_params.append(param_)

        # Process body information
        # List of operation variants (based on the body)
        operation_variants = common.get_operation_variants(
            spec, args.operation_name
        )

        api_ver_matches: re.Match | None = None
        path_elements = path.lstrip("/").split("/")
        api_ver: dict[str, str | int] = {}
        ver_prefix: str | None = None
        if path_elements:
            api_ver_matches = re.match(common.VERSION_RE, path_elements[0])
            if api_ver_matches and api_ver_matches.groups():
                # Remember the version prefix to discard it in the template
                ver_prefix = path_elements[0]

        for operation_variant in operation_variants:
            logging.debug("Processing variant %s", operation_variant)
            # TODO(gtema): if we are in MV variants filter out unsupported query
            # parameters
            # TODO(gtema): previously we were ensuring `router_id` path param
            # is renamed to `id`

            # Per-variant flags: reset on every iteration so that a value set
            # for one variant does not leak into the following ones.
            is_json_patch: bool = False
            is_list_paginated = False

            additional_imports: set[str] = set()
            if api_ver_matches:
                api_ver = {
                    "major": api_ver_matches.group(1),
                    "minor": api_ver_matches.group(3) or 0,
                }
            else:
                api_ver = {}

            service_name = common.get_rust_service_type_from_str(
                args.service_type
            )
            operation_name: str = (
                args.operation_type
                if args.operation_type != "action"
                else args.module_name
            )
            operation_name = "".join(
                x.title()
                for x in re.split(r"[-_]", operation_name.replace("os-", ""))
            )
            class_name = f"{service_name}{''.join(x.title() for x in path_resources)}{operation_name}".replace(
                "_", ""
            )
            response_class_name = f"{service_name}{''.join(x.title() for x in path_resources)}".replace(
                "_", ""
            )
            operation_body = operation_variant.get("body")
            type_manager = TypeManager()
            sdk_type_manager = SdkTypeManager()
            type_manager.set_parameters(operation_params)
            response_type_manager: common_rust.TypeManager = (
                ResponseTypeManager()
            )
            sdk_type_manager.set_parameters(operation_params)
            mod_name = "_".join(
                x.lower()
                for x in re.split(
                    common.SPLIT_NAME_RE,
                    (
                        args.module_name
                        or args.operation_name
                        or args.operation_type.value
                        or method
                    ),
                )
            )

            if operation_body:
                min_ver = operation_body.get("x-openstack", {}).get("min-ver")
                if min_ver:
                    mod_name += "_" + min_ver.replace(".", "")
                    v = min_ver.split(".")
                    if not len(v) == 2:
                        raise RuntimeError(
                            "Version information is not in format MAJOR.MINOR"
                        )
                    api_ver = {"major": v[0], "minor": v[1]}

                # There is request body. Get the ADT from jsonschema
                # if args.operation_type != "action":
                (_, all_types) = openapi_parser.parse(
                    operation_body, ignore_read_only=True
                )
                # and feed them into the TypeManager
                type_manager.set_models(all_types)
                sdk_type_manager.set_models(all_types)
                # else:
                #    logging.warn("Ignoring response type of action")

            type_manager.link_sdk_type_manager(sdk_type_manager)

            if method == "patch":
                # There might be multiple supported mime types. We only select ones we are aware of
                mime_type = operation_variant.get("mime_type")
                if not mime_type:
                    raise RuntimeError(
                        "No supported mime types for patch operation found"
                    )
                is_json_patch = mime_type != "application/json"

            mod_path = common.get_rust_sdk_mod_path(
                args.service_type,
                args.api_version,
                args.alternative_module_path or path,
            )

            response_key: str | None = None
            response_def: dict | None = {}

            # Get basic information about response
            if args.operation_type == "list":
                response = common.find_response_schema(
                    spec["responses"],
                    args.response_key or resource_name,
                    # The operation name is only passed for "action"
                    # operations, which never reach this "list" branch.
                    None,
                )

                if response:
                    if args.response_key:
                        response_key = (
                            args.response_key
                            if args.response_key != "null"
                            else None
                        )
                    else:
                        response_key = resource_name
                    response_def, _ = common.find_resource_schema(
                        response, None, response_key
                    )

                    if response_def:
                        if response_def.get("type", "object") == "object" or (
                            # BS metadata is defined with type: ["object",
                            # "null"]
                            isinstance(response_def.get("type"), list)
                            and "object" in response_def["type"]
                        ):
                            (_, response_types) = openapi_parser.parse(
                                response_def
                            )
                            response_type_manager.set_models(response_types)
                            additional_imports.add("serde_json::Value")

            sdk_mod_path_base = [
                "openstack_sdk",
                "api",
            ] + common.get_rust_sdk_mod_path(
                args.service_type, args.api_version, args.module_path or path
            )
            sdk_mod_path: list[str] = sdk_mod_path_base.copy()
            mod_suffix: str = ""
            sdk_mod_path.append((args.sdk_mod_name or mod_name) + mod_suffix)
            additional_imports.add(
                "::".join(sdk_mod_path) + "::RequestBuilder"
            )
            additional_imports.add(
                "openstack_sdk::{AsyncOpenStack, api::QueryAsync}"
            )
            if args.operation_type == "list":
                if "limit" in [
                    k for (k, _) in type_manager.get_parameters("query")
                ]:
                    is_list_paginated = True
                    additional_imports.add(
                        "openstack_sdk::api::{paged, Pagination}"
                    )
                additional_imports.add("structable_derive::StructTable")
                additional_imports.add("crate::utils::StructTable")
                additional_imports.add("crate::utils::OutputConfig")
            elif args.operation_type == "delete":
                additional_imports.add("openstack_sdk::api::ignore")
                additional_imports.add(
                    "crate::cloud_worker::ConfirmableRequest"
                )

            additional_imports.update(response_type_manager.get_imports())
            # Deserialize is already in template since it is uncoditionally required
            additional_imports.discard("serde::Deserialize")
            additional_imports.discard("serde::Serialize")

            # Strip the version prefix (a real prefix, not a character set
            # as `str.lstrip` would treat it) from the URL
            url = path.lstrip("/")
            if ver_prefix:
                url = url.removeprefix(ver_prefix)
            url = url.lstrip("/")

            context = {
                "additional_imports": additional_imports,
                "operation_id": operation_id,
                "operation_type": spec.get(
                    "x-openstack-operation-type", args.operation_type
                ),
                "command_description": common_rust.sanitize_rust_docstrings(
                    common.make_ascii_string(spec.get("description"))
                ),
                "class_name": class_name,
                "response_class_name": response_class_name,
                "sdk_service_name": service_name,
                "resource_name": resource_name,
                "response_type_manager": response_type_manager,
                "url": url,
                "method": method,
                "type_manager": type_manager,
                "sdk_type_manager": sdk_type_manager,
                "sdk_mod_path": sdk_mod_path,
                "response_key": response_key,
                "response_list_item_key": args.response_list_item_key,
                "mime_type": mime_type,
                "is_json_patch": is_json_patch,
                "api_ver": api_ver,
                "is_list_paginated": is_list_paginated,
            }

            work_dir = Path(target_dir, "rust", "openstack_tui", "src")
            impl_path = Path(
                work_dir, "cloud_worker", "/".join(mod_path), f"{mod_name}.rs"
            )

            # Generate methods for the GET resource command
            self._render_command(context, "rust_tui/impl.rs.j2", impl_path)

            self._format_code(impl_path)

            yield (mod_path, mod_name, path, class_name)

    def generate_mod(
        self, target_dir, mod_path, mod_list, url, resource_name, service_name
    ):
        """Generate collection module (include individual modules)

        :param mod_list: Mapping of module name to class name; an empty class
            name is replaced by ``<Service><ModPath><Name>ApiRequest``.
        """
        work_dir = Path(target_dir, "rust", "openstack_tui", "src")
        impl_path = Path(
            work_dir,
            "cloud_worker",
            "/".join(mod_path[0:-1]),
            f"{mod_path[-1]}.rs",
        )

        service_name = "".join(x.title() for x in service_name.split("_"))
        new_mod_list: dict[str, dict[str, str]] = {}
        for mod_name, class_name in mod_list.items():
            name = "".join(x.title() for x in mod_name.split("_"))
            full_name = "".join(x.title() for x in mod_path[2:]) + name
            if not class_name:
                class_name = f"{service_name}{full_name}ApiRequest"
            new_mod_list[mod_name] = {"name": name, "class_name": class_name}

        context = {
            "mod_list": new_mod_list,
            "mod_path": mod_path,
            "url": url,
            "resource_name": resource_name,
            "service_name": service_name,
        }

        # Generate methods for the GET resource command
        self._render_command(context, "rust_tui/mod.rs.j2", impl_path)

        self._format_code(impl_path)