Skip to content

Commit cd80c22

Browse files
committed
gnd: Resolve entity references to the referenced entity's actual ID type
Entity reference fields (e.g., `token: Token!`) were hardcoded to use String for getters/setters regardless of the referenced entity's actual ID type. Now looks up the referenced entity's IdFieldKind so that e.g. a reference to an entity with `id: Bytes!` correctly generates `toBytes()`/`fromBytes()` instead of `toString()`/`fromString()`. Also fixes derived field getters on entities with Bytes IDs to use `toBytes().toHexString()` for the loader constructor argument.
1 parent 88a5c0d commit cd80c22

1 file changed

Lines changed: 202 additions & 11 deletions

File tree

gnd/src/codegen/schema.rs

Lines changed: 202 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ struct FieldInfo {
192192
pub struct SchemaCodeGenerator {
193193
entities: Vec<EntityInfo>,
194194
entity_names: std::collections::HashSet<String>,
195+
/// Maps entity name to its ID field kind, for resolving entity reference types.
196+
entity_id_kinds: std::collections::HashMap<String, IdFieldKind>,
195197
}
196198

197199
impl SchemaCodeGenerator {
@@ -202,12 +204,18 @@ impl SchemaCodeGenerator {
202204
pub fn new(document: &Document<'_, String>) -> Result<Self> {
203205
let mut entities = Vec::new();
204206
let mut entity_names = std::collections::HashSet::new();
207+
let mut entity_id_kinds = std::collections::HashMap::new();
205208

206-
// First pass: collect entity names
209+
// First pass: collect entity names and their ID types
207210
for def in &document.definitions {
208211
if let Definition::TypeDefinition(TypeDefinition::Object(obj)) = def {
209212
if is_entity_type(obj) {
210213
entity_names.insert(obj.name.clone());
214+
let id_field = obj.fields.iter().find(|f| f.name == "id");
215+
let id_kind = id_field
216+
.map(|f| IdFieldKind::from_type_name(&get_base_type_name(&f.field_type)))
217+
.unwrap_or(IdFieldKind::String);
218+
entity_id_kinds.insert(obj.name.clone(), id_kind);
211219
}
212220
}
213221
}
@@ -265,6 +273,7 @@ impl SchemaCodeGenerator {
265273
Ok(Self {
266274
entities,
267275
entity_names,
276+
entity_id_kinds,
268277
})
269278
}
270279

@@ -332,7 +341,7 @@ impl SchemaCodeGenerator {
332341

333342
// Generate field getters and setters
334343
for field in &entity.fields {
335-
if let Some(getter) = self.generate_field_getter(&entity.name, field) {
344+
if let Some(getter) = self.generate_field_getter(&entity.name, field, &entity.id_kind) {
336345
klass.add_method(getter);
337346
}
338347
if let Some(setter) = self.generate_field_setter(field) {
@@ -468,12 +477,17 @@ impl SchemaCodeGenerator {
468477
]
469478
}
470479

471-
fn generate_field_getter(&self, entity_name: &str, field: &FieldInfo) -> Option<Method> {
480+
fn generate_field_getter(
481+
&self,
482+
entity_name: &str,
483+
field: &FieldInfo,
484+
id_kind: &IdFieldKind,
485+
) -> Option<Method> {
472486
let safe_name = handle_reserved_word(&field.name);
473487

474488
// Handle derived fields
475489
if field.is_derived {
476-
return self.generate_derived_field_getter(entity_name, field, &safe_name);
490+
return self.generate_derived_field_getter(entity_name, field, &safe_name, id_kind);
477491
}
478492

479493
let value_type = self.value_type_from_field(field);
@@ -529,17 +543,23 @@ impl SchemaCodeGenerator {
529543
entity_name: &str,
530544
field: &FieldInfo,
531545
safe_name: &str,
546+
id_kind: &IdFieldKind,
532547
) -> Option<Method> {
533548
let loader_name = format!("{}Loader", field.base_type);
534549

550+
let id_conversion = match id_kind {
551+
IdFieldKind::Bytes => "this.get('id')!.toBytes().toHexString()",
552+
_ => "this.get('id')!.toString()",
553+
};
554+
535555
Some(Method::new(
536556
format!("get {}", safe_name),
537557
vec![],
538558
Some(NamedType::new(&loader_name).into()),
539559
format!(
540560
r#"
541-
return new {}('{}', this.get('id')!.toString(), '{}')"#,
542-
loader_name, entity_name, field.name
561+
return new {}('{}', {}, '{}')"#,
562+
loader_name, entity_name, id_conversion, field.name
543563
),
544564
))
545565
}
@@ -636,10 +656,10 @@ impl SchemaCodeGenerator {
636656
/// - Scalars: `String`, `Int`, `BigInt`, etc.
637657
/// - Arrays: `[String]`, `[Int]`, etc.
638658
/// - Nested arrays: `[[String]]`, `[[Int]]`, etc.
639-
/// - Entity references are converted to `String` (their ID type)
659+
/// - Entity references are converted to the referenced entity's ID type
640660
fn value_type_from_field(&self, field: &FieldInfo) -> String {
641-
let base = if self.entity_names.contains(&field.base_type) {
642-
"String".to_string() // Entity references are stored as string IDs
661+
let base = if let Some(id_kind) = self.entity_id_kinds.get(&field.base_type) {
662+
id_kind.gql_type_name().to_string()
643663
} else {
644664
field.base_type.clone()
645665
};
@@ -659,8 +679,8 @@ impl SchemaCodeGenerator {
659679
/// - Arrays: `Array<string>`, `Array<i32>`, etc.
660680
/// - Nested arrays: `Array<Array<string>>`, etc.
661681
fn type_from_field(&self, field: &FieldInfo) -> TypeExpr {
662-
let type_name = if self.entity_names.contains(&field.base_type) {
663-
"string" // Entity references are stored as string IDs
682+
let type_name = if let Some(id_kind) = self.entity_id_kinds.get(&field.base_type) {
683+
id_kind.type_name()
664684
} else {
665685
asc_type_for_value(&field.base_type)
666686
};
@@ -1092,4 +1112,175 @@ mod tests {
10921112
assert_eq!(gen.value_type_from_field(array_field), "[String]");
10931113
assert_eq!(gen.value_type_from_field(matrix_field), "[[String]]");
10941114
}
1115+
1116+
#[test]
1117+
fn test_bytes_id_entity_reference() {
1118+
let schema = r#"
1119+
type Token @entity {
1120+
id: Bytes!
1121+
name: String!
1122+
}
1123+
type Balance @entity {
1124+
id: ID!
1125+
token: Token!
1126+
amount: BigInt!
1127+
}
1128+
"#;
1129+
let doc = parse_schema::<String>(schema).unwrap();
1130+
let gen = SchemaCodeGenerator::new(&doc).unwrap();
1131+
1132+
let classes = gen.generate_types(true);
1133+
let balance = classes.iter().find(|c| c.name == "Balance").unwrap();
1134+
let output = balance.to_string();
1135+
1136+
// Getter should return Bytes and use toBytes()
1137+
assert!(
1138+
output.contains("value.toBytes()"),
1139+
"Bytes-ID entity reference getter should use toBytes(), got: {}",
1140+
output
1141+
);
1142+
1143+
// Setter should use Value.fromBytes()
1144+
assert!(
1145+
output.contains("Value.fromBytes("),
1146+
"Bytes-ID entity reference setter should use Value.fromBytes(), got: {}",
1147+
output
1148+
);
1149+
1150+
// Return type should be Bytes, not string
1151+
assert!(
1152+
output.contains("get token(): Bytes"),
1153+
"Bytes-ID entity reference getter should return Bytes, got: {}",
1154+
output
1155+
);
1156+
}
1157+
1158+
#[test]
1159+
fn test_int8_id_entity_reference() {
1160+
let schema = r#"
1161+
type Counter @entity {
1162+
id: Int8!
1163+
value: BigInt!
1164+
}
1165+
type Snapshot @entity {
1166+
id: ID!
1167+
counter: Counter!
1168+
}
1169+
"#;
1170+
let doc = parse_schema::<String>(schema).unwrap();
1171+
let gen = SchemaCodeGenerator::new(&doc).unwrap();
1172+
1173+
let classes = gen.generate_types(true);
1174+
let snapshot = classes.iter().find(|c| c.name == "Snapshot").unwrap();
1175+
let output = snapshot.to_string();
1176+
1177+
// Getter should return i64 and use toI64()
1178+
assert!(
1179+
output.contains("value.toI64()"),
1180+
"Int8-ID entity reference getter should use toI64(), got: {}",
1181+
output
1182+
);
1183+
1184+
// Setter should use Value.fromI64()
1185+
assert!(
1186+
output.contains("Value.fromI64("),
1187+
"Int8-ID entity reference setter should use Value.fromI64(), got: {}",
1188+
output
1189+
);
1190+
1191+
// Return type should be i64
1192+
assert!(
1193+
output.contains("get counter(): i64"),
1194+
"Int8-ID entity reference getter should return i64, got: {}",
1195+
output
1196+
);
1197+
}
1198+
1199+
#[test]
1200+
fn test_mixed_id_entity_references() {
1201+
let schema = r#"
1202+
type User @entity {
1203+
id: ID!
1204+
name: String!
1205+
}
1206+
type Token @entity {
1207+
id: Bytes!
1208+
owner: User!
1209+
}
1210+
"#;
1211+
let doc = parse_schema::<String>(schema).unwrap();
1212+
let gen = SchemaCodeGenerator::new(&doc).unwrap();
1213+
1214+
let classes = gen.generate_types(true);
1215+
let token = classes.iter().find(|c| c.name == "Token").unwrap();
1216+
let output = token.to_string();
1217+
1218+
// Token.owner references User which has String ID
1219+
assert!(
1220+
output.contains("get owner(): string"),
1221+
"Reference to String-ID entity should use string type, got: {}",
1222+
output
1223+
);
1224+
assert!(
1225+
output.contains("value.toString()"),
1226+
"Reference to String-ID entity should use toString(), got: {}",
1227+
output
1228+
);
1229+
}
1230+
1231+
#[test]
1232+
fn test_nullable_bytes_id_entity_reference() {
1233+
let schema = r#"
1234+
type Token @entity {
1235+
id: Bytes!
1236+
name: String!
1237+
}
1238+
type Balance @entity {
1239+
id: ID!
1240+
token: Token
1241+
amount: BigInt!
1242+
}
1243+
"#;
1244+
let doc = parse_schema::<String>(schema).unwrap();
1245+
let gen = SchemaCodeGenerator::new(&doc).unwrap();
1246+
1247+
let classes = gen.generate_types(true);
1248+
let balance = classes.iter().find(|c| c.name == "Balance").unwrap();
1249+
let output = balance.to_string();
1250+
1251+
// Nullable Bytes reference should be `Bytes | null`
1252+
assert!(
1253+
output.contains("get token(): Bytes | null"),
1254+
"Nullable Bytes-ID reference should return Bytes | null, got: {}",
1255+
output
1256+
);
1257+
}
1258+
1259+
#[test]
1260+
fn test_derived_field_with_bytes_id_parent() {
1261+
let schema = r#"
1262+
type Token @entity {
1263+
id: Bytes!
1264+
balances: [Balance!]! @derivedFrom(field: "token")
1265+
}
1266+
type Balance @entity {
1267+
id: ID!
1268+
token: Token!
1269+
amount: BigInt!
1270+
}
1271+
"#;
1272+
let doc = parse_schema::<String>(schema).unwrap();
1273+
let gen = SchemaCodeGenerator::new(&doc).unwrap();
1274+
1275+
let classes = gen.generate_types(true);
1276+
let token = classes.iter().find(|c| c.name == "Token").unwrap();
1277+
let output = token.to_string();
1278+
1279+
// Derived field getter on Bytes-ID entity should use toBytes().toHexString()
1280+
assert!(
1281+
output.contains("this.get('id')!.toBytes().toHexString()"),
1282+
"Derived field on Bytes-ID entity should use toBytes().toHexString(), got: {}",
1283+
output
1284+
);
1285+
}
10951286
}

0 commit comments

Comments
 (0)